A Note on BPTT for LSTM LM
Tomonari MASADA @ Nagasaki University
January 15, 2015
1 Forward pass
$K$ is the vocabulary size, $N$ is the number of hidden layers, and $D_n$ is the number of hidden units in the $n$-th layer. The input sequence is $X \equiv (x_1, \ldots, x_T)$, where each $x_t \equiv (x_{t1}, \ldots, x_{tK})^\top$ is a one-hot vector.
• ι : input gates
• ϕ : forget gates
• η : cells
• ω : output gates
• h : cell outputs
$f$ and $g$ are sigmoid-shaped activation functions, for example the logistic function or $\tanh$ (see Section 3).
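\begin{align}
\iota^n_t &= f\big(W^n_{x\iota} x_t + W^n_{h^-\iota} h^{n-1}_t + W^n_{h\iota} h^n_{t-1} + w^n_{\eta\iota} \odot \eta^n_{t-1} + b^n_\iota\big) \equiv f(a^n_{\iota,t}) \tag{1} \\
\phi^n_t &= f\big(W^n_{x\phi} x_t + W^n_{h^-\phi} h^{n-1}_t + W^n_{h\phi} h^n_{t-1} + w^n_{\eta\phi} \odot \eta^n_{t-1} + b^n_\phi\big) \equiv f(a^n_{\phi,t}) \tag{2} \\
\eta^n_t &= \phi^n_t \odot \eta^n_{t-1} + \iota^n_t \odot g\big(W^n_{x\eta} x_t + W^n_{h^-\eta} h^{n-1}_t + W^n_{h\eta} h^n_{t-1} + b^n_\eta\big) \equiv \phi^n_t \odot \eta^n_{t-1} + \iota^n_t \odot g(a^n_{\eta,t}) \tag{3} \\
\omega^n_t &= f\big(W^n_{x\omega} x_t + W^n_{h^-\omega} h^{n-1}_t + W^n_{h\omega} h^n_{t-1} + w^n_{\eta\omega} \odot \eta^n_t + b^n_\omega\big) \equiv f(a^n_{\omega,t}) \tag{4} \\
h^n_t &= \omega^n_t \odot g(\eta^n_t) , \tag{5}
\end{align}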
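where $\odot$ is the element-wise product, the superscript $n$ denotes the $n$-th layer, and a subscript $h^-$ marks weights applied to $h^{n-1}_t$, the output of the layer below. These can be rewritten for $d = 1, \ldots, D_n$ separately: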
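\begin{align}
\iota^n_{t,d} &= f\big(w^{n\top}_{x\iota,d} x_t + w^{n\top}_{h^-\iota,d} h^{n-1}_t + w^{n\top}_{h\iota,d} h^n_{t-1} + w^n_{\eta\iota,d}\, \eta^n_{t-1,d} + b^n_{\iota,d}\big) \equiv f(a^n_{\iota,t,d}) \tag{6} \\
\phi^n_{t,d} &= f\big(w^{n\top}_{x\phi,d} x_t + w^{n\top}_{h^-\phi,d} h^{n-1}_t + w^{n\top}_{h\phi,d} h^n_{t-1} + w^n_{\eta\phi,d}\, \eta^n_{t-1,d} + b^n_{\phi,d}\big) \equiv f(a^n_{\phi,t,d}) \tag{7} \\
\eta^n_{t,d} &= \phi^n_{t,d}\, \eta^n_{t-1,d} + \iota^n_{t,d}\, g\big(w^{n\top}_{x\eta,d} x_t + w^{n\top}_{h^-\eta,d} h^{n-1}_t + w^{n\top}_{h\eta,d} h^n_{t-1} + b^n_{\eta,d}\big) \equiv \phi^n_{t,d}\, \eta^n_{t-1,d} + \iota^n_{t,d}\, g(a^n_{\eta,t,d}) \tag{8} \\
\omega^n_{t,d} &= f\big(w^{n\top}_{x\omega,d} x_t + w^{n\top}_{h^-\omega,d} h^{n-1}_t + w^{n\top}_{h\omega,d} h^n_{t-1} + w^n_{\eta\omega,d}\, \eta^n_{t,d} + b^n_{\omega,d}\big) \equiv f(a^n_{\omega,t,d}) \tag{9} \\
h^n_{t,d} &= \omega^n_{t,d}\, g(\eta^n_{t,d}) . \tag{10}
\end{align}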
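As a rough illustration, one time step of one layer in Eqs. (1)-(5) could be written in NumPy as below. This is only a sketch under assumptions that the note does not make: $f$ is taken to be the logistic function and $g$ to be $\tanh$, the parameter dictionary and its key names (`Wx_*`, `Wb_*` for the weights on the layer below, `Wh_*`, `w_*` for the peepholes, `b_*`) are hypothetical, and a zero vector stands in for the output of the layer below in the usage lines.

```python
import numpy as np

GATES = ("iota", "phi", "eta", "omega")

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def init_params(K, D_below, D, rng, scale=0.1):
    """Hypothetical parameter container for one layer with D units."""
    p = {}
    for gate in GATES:
        p["Wx_" + gate] = scale * rng.standard_normal((D, K))        # applied to x_t
        p["Wb_" + gate] = scale * rng.standard_normal((D, D_below))  # applied to h^{n-1}_t
        p["Wh_" + gate] = scale * rng.standard_normal((D, D))        # applied to h^n_{t-1}
        p["b_" + gate] = np.zeros(D)
    for gate in ("iota", "phi", "omega"):
        p["w_" + gate] = scale * rng.standard_normal(D)              # peephole weights
    return p

def lstm_step(x_t, h_below, h_prev, eta_prev, p, f=sigmoid, g=np.tanh):
    """One time step of one layer; returns (h_t, eta_t)."""
    def pre(gate):
        # Terms shared by all four pre-activations a^n_{gate,t}
        return (p["Wx_" + gate] @ x_t + p["Wb_" + gate] @ h_below
                + p["Wh_" + gate] @ h_prev + p["b_" + gate])

    iota = f(pre("iota") + p["w_iota"] * eta_prev)      # Eq. (1)
    phi = f(pre("phi") + p["w_phi"] * eta_prev)         # Eq. (2)
    eta = phi * eta_prev + iota * g(pre("eta"))         # Eq. (3)
    omega = f(pre("omega") + p["w_omega"] * eta)        # Eq. (4), peephole on eta_t
    h = omega * g(eta)                                  # Eq. (5)
    return h, eta

rng = np.random.default_rng(0)
K, D_below, D = 5, 3, 4
p = init_params(K, D_below, D, rng)
x_t = np.eye(K)[2]                                      # one-hot vector for word index 2
h_t, eta_t = lstm_step(x_t, np.zeros(D_below), np.zeros(D), np.zeros(D), p)
```

Note that the output gate reads the freshly updated state $\eta^n_t$, while the input and forget gates read $\eta^n_{t-1}$; the order of the five assignments mirrors that dependency.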
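Word probabilities for each $t$ are obtained by applying a softmax to the output scores $\hat{y}_t$ (with $\exp$ applied element-wise):
\begin{align}
\hat{y}_t = b_y + \sum_{n=1}^{N} W^n_{hy} h^n_t , \qquad y_t \equiv \frac{\exp(\hat{y}_t)}{\sum_{k'=1}^{K} \exp(\hat{y}_{t,k'})} . \tag{11}
\end{align}
These can be rewritten for $k = 1, \ldots, K$ separately:
\begin{align}
\hat{y}_{t,k} = b_{y,k} + \sum_{n=1}^{N} \sum_{d=1}^{D_n} w^n_{hy,kd}\, h^n_{t,d} , \qquad y_{t,k} \equiv \frac{\exp(\hat{y}_{t,k})}{\sum_{k'=1}^{K} \exp(\hat{y}_{t,k'})} . \tag{12}
\end{align}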
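The output layer of Eqs. (11)-(12) might be sketched as follows. The function name `word_probabilities` and the list-of-arrays layout are assumptions made for illustration; subtracting the maximum score before exponentiating is a standard stability trick that is not part of the note and does not change the result.

```python
import numpy as np

def word_probabilities(h_layers, W_hy, b_y):
    """h_layers[n]: output h^n_t of layer n, shape (D_n,); W_hy[n]: shape (K, D_n)."""
    y_hat = b_y + sum(W @ h for W, h in zip(W_hy, h_layers))   # scores before the softmax
    z = np.exp(y_hat - y_hat.max())    # subtracting the max leaves the softmax unchanged
    return y_hat, z / z.sum()

rng = np.random.default_rng(0)
K, dims = 6, [4, 3]                                            # N = 2 layers
h_layers = [rng.standard_normal(D) for D in dims]
W_hy = [rng.standard_normal((K, D)) for D in dims]
y_hat, y = word_probabilities(h_layers, W_hy, np.zeros(K))
```

Every hidden layer contributes to the scores through its own matrix, which is why the backward pass later sends the output error to all $N$ layers directly.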
2 Negative log likelihood
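\begin{align}
L(X) = -\sum_{t=1}^{T} \log y_{t,x_{t+1}}
= -\sum_{t=1}^{T} \log \frac{\exp(\hat{y}_{t,x_{t+1}})}{\sum_{k=1}^{K} \exp(\hat{y}_{t,k})}
= -\sum_{t=1}^{T} \hat{y}_{t,x_{t+1}} + \sum_{t=1}^{T} \log \Big\{ \sum_{k=1}^{K} \exp(\hat{y}_{t,k}) \Big\} \tag{13}
\end{align}
Here the subscript $x_{t+1}$ stands for the vocabulary index of the word at position $t+1$, and $\hat{y}_{t,k}$ is the score of word $k$ before the softmax.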
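The last form of Eq. (13) can be evaluated directly from the pre-softmax scores. The sketch below assumes the scores for all time steps are stacked in a `(T, K)` array and that `targets[t]` holds the index of the word at position $t+1$; both the layout and the function name are hypothetical. The per-step maximum is factored out of the log-sum-exp for numerical stability.

```python
import numpy as np

def negative_log_likelihood(Y_hat, targets):
    """Y_hat: (T, K) scores before the softmax; targets[t]: index of word x_{t+1}."""
    m = Y_hat.max(axis=1)                                        # per-step shift for stability
    log_norm = m + np.log(np.exp(Y_hat - m[:, None]).sum(axis=1))
    picked = Y_hat[np.arange(len(targets)), targets]
    return -picked.sum() + log_norm.sum()                        # last form of Eq. (13)

rng = np.random.default_rng(0)
T, K = 4, 5
Y_hat = rng.standard_normal((T, K))
targets = np.array([1, 0, 3, 4])
loss = negative_log_likelihood(Y_hat, targets)
```

With all scores equal, every word has probability $1/K$ and the loss reduces to $T \log K$, which is a convenient sanity check.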
3 Backward pass
When $f(a) = \frac{1}{1+\exp(-a)}$, $f'(a) = f(a)(1 - f(a))$. When $f(a) = \tanh(a)$, $f'(a) = 1 - f(a)^2$.
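Both derivative identities are easy to confirm numerically. The snippet below is a small check with central finite differences; the step size `eps` and the grid of test points are arbitrary choices.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def d_sigmoid(a):
    s = sigmoid(a)
    return s * (1.0 - s)              # f'(a) = f(a)(1 - f(a))

def d_tanh(a):
    return 1.0 - np.tanh(a) ** 2      # f'(a) = 1 - f(a)^2

a = np.linspace(-3.0, 3.0, 13)
eps = 1e-6
err_sigmoid = np.max(np.abs((sigmoid(a + eps) - sigmoid(a - eps)) / (2 * eps) - d_sigmoid(a)))
err_tanh = np.max(np.abs((np.tanh(a + eps) - np.tanh(a - eps)) / (2 * eps) - d_tanh(a)))
```

Both identities express the derivative through the function value itself, so a backward pass can reuse the activations stored during the forward pass.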
3.1
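\begin{align}
\frac{\partial L(X)}{\partial \hat{y}_{t,k}}
&= -\frac{\partial \hat{y}_{t,x_{t+1}}}{\partial \hat{y}_{t,k}} + \frac{\partial}{\partial \hat{y}_{t,k}} \log \Big\{ \sum_{k'=1}^{K} \exp(\hat{y}_{t,k'}) \Big\} \notag \\
&= -\frac{\partial \hat{y}_{t,x_{t+1}}}{\partial \hat{y}_{t,k}} + \frac{\sum_{k'=1}^{K} \exp(\hat{y}_{t,k'}) \frac{\partial \hat{y}_{t,k'}}{\partial \hat{y}_{t,k}}}{\sum_{k'=1}^{K} \exp(\hat{y}_{t,k'})} \notag \\
&= -\big\{ \delta(x_{t+1} = k) - y_{t,k} \big\} \tag{14}
\end{align}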
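Eq. (14) says that the gradient with respect to the pre-softmax scores is the predicted distribution minus the one-hot target. A minimal numerical check for a single time step, with hypothetical helper names and an arbitrary random score vector, could look like this:

```python
import numpy as np

def softmax(y_hat):
    z = np.exp(y_hat - y_hat.max())
    return z / z.sum()

def step_loss(y_hat, target):
    return -np.log(softmax(y_hat)[target])           # one term of Eq. (13)

def step_grad(y_hat, target):
    grad = softmax(y_hat)                             # y_{t,k}
    grad[target] -= 1.0                               # minus delta(x_{t+1} = k)
    return grad                                       # Eq. (14)

rng = np.random.default_rng(0)
K, target, eps = 5, 2, 1e-6
y_hat = rng.standard_normal(K)
numeric = np.array([
    (step_loss(y_hat + eps * e, target) - step_loss(y_hat - eps * e, target)) / (2 * eps)
    for e in np.eye(K)                                # perturb one score at a time
])
max_err = np.max(np.abs(numeric - step_grad(y_hat, target)))
```

The components of this gradient always sum to zero, because shifting all scores by the same constant leaves the softmax unchanged.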
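For $w^n_{hy,kd}$, neither $b_{y,k}$ nor any $h^{n'}_{t,d'}$ depends on the output weights, so only the term with $n' = n$ and $d' = d$ survives:
\begin{align}
\frac{\partial \hat{y}_{t,k}}{\partial w^n_{hy,kd}}
= \frac{\partial b_{y,k}}{\partial w^n_{hy,kd}}
+ \sum_{n'=1}^{N} \sum_{d'=1}^{D_{n'}} \Big( h^{n'}_{t,d'} \frac{\partial w^{n'}_{hy,kd'}}{\partial w^n_{hy,kd}} + w^{n'}_{hy,kd'} \frac{\partial h^{n'}_{t,d'}}{\partial w^n_{hy,kd}} \Big)
= h^n_{t,d} \tag{15}
\end{align}
\begin{align}
\therefore \frac{\partial L(X)}{\partial w^n_{hy,kd}}
= \sum_{t=1}^{T} \frac{\partial L(X)}{\partial \hat{y}_{t,k}} \frac{\partial \hat{y}_{t,k}}{\partial w^n_{hy,kd}}
= -\sum_{t=1}^{T} h^n_{t,d} \big\{ \delta(x_{t+1} = k) - y_{t,k} \big\} \tag{16}
\end{align}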
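In matrix form, Eq. (16) is a sum over time of outer products between the score gradient of Eq. (14) and the layer output. The sketch below checks this against finite differences for a single layer ($N = 1$) whose outputs `H` are held fixed, which is legitimate because the hidden outputs do not depend on the output weights. All names and the `(T, D_n)` layout are assumptions.

```python
import numpy as np

def softmax_rows(Y_hat):
    Z = np.exp(Y_hat - Y_hat.max(axis=1, keepdims=True))
    return Z / Z.sum(axis=1, keepdims=True)

def loss(W_hy, b_y, H, targets):
    Y = softmax_rows(H @ W_hy.T + b_y)
    return -np.log(Y[np.arange(len(targets)), targets]).sum()

def output_weight_grad(W_hy, b_y, H, targets):
    """H: (T, D_n) outputs h^n_t of one layer; returns dL/dW^n_hy, shape (K, D_n)."""
    delta = softmax_rows(H @ W_hy.T + b_y)                 # y_{t,k}
    delta[np.arange(len(targets)), targets] -= 1.0         # Eq. (14) for every t
    return delta.T @ H                                     # Eq. (16): sum over t

rng = np.random.default_rng(0)
T, K, D = 6, 5, 4
H = rng.standard_normal((T, D))
W_hy, b_y = rng.standard_normal((K, D)), rng.standard_normal(K)
targets = rng.integers(0, K, size=T)

analytic = output_weight_grad(W_hy, b_y, H, targets)
numeric, eps = np.zeros_like(W_hy), 1e-6
for k in range(K):
    for d in range(D):
        E = np.zeros_like(W_hy)
        E[k, d] = eps                                      # perturb a single weight
        numeric[k, d] = (loss(W_hy + E, b_y, H, targets)
                         - loss(W_hy - E, b_y, H, targets)) / (2 * eps)
max_err = np.max(np.abs(analytic - numeric))
```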
3.2 Output errors
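\begin{align}
\frac{\partial L(X)}{\partial h^n_{t,d}}
&= \sum_{k=1}^{K} \frac{\partial L(X)}{\partial \hat{y}_{t,k}} \frac{\partial \hat{y}_{t,k}}{\partial h^n_{t,d}} \notag \\
&\quad + \sum_{d'=1}^{D_n} \frac{\partial L(X)}{\partial a^n_{\omega,t+1,d'}} \frac{\partial a^n_{\omega,t+1,d'}}{\partial h^n_{t,d}}
 + \sum_{d'=1}^{D_n} \frac{\partial L(X)}{\partial a^n_{\eta,t+1,d'}} \frac{\partial a^n_{\eta,t+1,d'}}{\partial h^n_{t,d}} \notag \\
&\quad + \sum_{d'=1}^{D_n} \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d'}} \frac{\partial a^n_{\phi,t+1,d'}}{\partial h^n_{t,d}}
 + \sum_{d'=1}^{D_n} \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d'}} \frac{\partial a^n_{\iota,t+1,d'}}{\partial h^n_{t,d}} \notag \\
&\quad + \sum_{d'=1}^{D_{n+1}} \frac{\partial L(X)}{\partial a^{n+1}_{\omega,t,d'}} \frac{\partial a^{n+1}_{\omega,t,d'}}{\partial h^n_{t,d}}
 + \sum_{d'=1}^{D_{n+1}} \frac{\partial L(X)}{\partial a^{n+1}_{\eta,t,d'}} \frac{\partial a^{n+1}_{\eta,t,d'}}{\partial h^n_{t,d}} \notag \\
&\quad + \sum_{d'=1}^{D_{n+1}} \frac{\partial L(X)}{\partial a^{n+1}_{\phi,t,d'}} \frac{\partial a^{n+1}_{\phi,t,d'}}{\partial h^n_{t,d}}
 + \sum_{d'=1}^{D_{n+1}} \frac{\partial L(X)}{\partial a^{n+1}_{\iota,t,d'}} \frac{\partial a^{n+1}_{\iota,t,d'}}{\partial h^n_{t,d}} \notag \\
&= \sum_{k=1}^{K} w^n_{hy,kd} \frac{\partial L(X)}{\partial \hat{y}_{t,k}} \notag \\
&\quad + \sum_{d'=1}^{D_n} w^n_{h\omega,d'd} \frac{\partial L(X)}{\partial a^n_{\omega,t+1,d'}}
 + \sum_{d'=1}^{D_n} w^n_{h\eta,d'd} \frac{\partial L(X)}{\partial a^n_{\eta,t+1,d'}} \notag \\
&\quad + \sum_{d'=1}^{D_n} w^n_{h\phi,d'd} \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d'}}
 + \sum_{d'=1}^{D_n} w^n_{h\iota,d'd} \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d'}} \notag \\
&\quad + \sum_{d'=1}^{D_{n+1}} w^{n+1}_{h^-\omega,d'd} \frac{\partial L(X)}{\partial a^{n+1}_{\omega,t,d'}}
 + \sum_{d'=1}^{D_{n+1}} w^{n+1}_{h^-\eta,d'd} \frac{\partial L(X)}{\partial a^{n+1}_{\eta,t,d'}} \notag \\
&\quad + \sum_{d'=1}^{D_{n+1}} w^{n+1}_{h^-\phi,d'd} \frac{\partial L(X)}{\partial a^{n+1}_{\phi,t,d'}}
 + \sum_{d'=1}^{D_{n+1}} w^{n+1}_{h^-\iota,d'd} \frac{\partial L(X)}{\partial a^{n+1}_{\iota,t,d'}} \tag{17}
\end{align}
Here $d'$ indexes the units being summed over, and $w_{\cdot,d'd}$ is the $d$-th element of the weight vector of unit $d'$.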
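Each sum in Eq. (17) is a transposed matrix-vector product, so the whole output error for one layer and one time step can be collected in a few lines. The sketch below uses hypothetical names and dictionary layouts. It also assumes boundary cases that the note leaves implicit: the terms at $t+1$ are taken to be zero at $t = T$, and the terms from layer $n+1$ are dropped for the top layer $n = N$.

```python
import numpy as np

GATES = ("omega", "eta", "phi", "iota")

def output_error(d_yhat, W_hy, W_h, delta_next, W_up=None, delta_up=None):
    """dL/dh^n_t as a vector of length D_n.

    d_yhat:           (K,) gradient w.r.t. the pre-softmax scores, Eq. (14)
    W_hy:             (K, D_n) output weights of layer n
    W_h[gate]:        (D_n, D_n) recurrent weights of layer n
    delta_next[gate]: (D_n,) dL/da^n_{gate,t+1}
    W_up[gate]:       (D_{n+1}, D_n) weights with which layer n+1 reads h^n_t
    delta_up[gate]:   (D_{n+1},) dL/da^{n+1}_{gate,t}
    """
    e = W_hy.T @ d_yhat
    for gate in GATES:
        e = e + W_h[gate].T @ delta_next[gate]
        if W_up is not None:                     # skipped for the top layer
            e = e + W_up[gate].T @ delta_up[gate]
    return e

rng = np.random.default_rng(0)
K, D, D_up = 5, 4, 3
d_yhat = rng.standard_normal(K)
W_hy = rng.standard_normal((K, D))
W_h = {gate: rng.standard_normal((D, D)) for gate in GATES}
delta_next = {gate: rng.standard_normal(D) for gate in GATES}
W_up = {gate: rng.standard_normal((D_up, D)) for gate in GATES}
delta_up = {gate: rng.standard_normal(D_up) for gate in GATES}
e = output_error(d_yhat, W_hy, W_h, delta_next, W_up, delta_up)

# Component d = 0 written out as the explicit sums of Eq. (17)
d = 0
e0 = sum(W_hy[k, d] * d_yhat[k] for k in range(K))
for gate in GATES:
    e0 += sum(W_h[gate][d2, d] * delta_next[gate][d2] for d2 in range(D))
    e0 += sum(W_up[gate][d2, d] * delta_up[gate][d2] for d2 in range(D_up))
```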
3.3 Output gates
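\begin{align}
\frac{\partial L(X)}{\partial a^n_{\omega,t,d}} &= \frac{\partial L(X)}{\partial h^n_{t,d}} \frac{\partial h^n_{t,d}}{\partial a^n_{\omega,t,d}} \tag{18} \\
\frac{\partial h^n_{t,d}}{\partial a^n_{\omega,t,d}} &= g(\eta^n_{t,d})\, f'(a^n_{\omega,t,d}) \tag{19}
\end{align}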
3.4 States
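\begin{align}
\frac{\partial L(X)}{\partial \eta^n_{t,d}}
&= \frac{\partial L(X)}{\partial h^n_{t,d}} \frac{\partial h^n_{t,d}}{\partial \eta^n_{t,d}}
 + \frac{\partial L(X)}{\partial \eta^n_{t+1,d}} \frac{\partial \eta^n_{t+1,d}}{\partial \eta^n_{t,d}}
 + \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d}} \frac{\partial a^n_{\phi,t+1,d}}{\partial \eta^n_{t,d}}
 + \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d}} \frac{\partial a^n_{\iota,t+1,d}}{\partial \eta^n_{t,d}} \notag \\
&= \frac{\partial L(X)}{\partial h^n_{t,d}} \frac{\partial h^n_{t,d}}{\partial \eta^n_{t,d}}
 + \phi^n_{t+1,d} \frac{\partial L(X)}{\partial \eta^n_{t+1,d}}
 + w^n_{\eta\phi,d} \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d}}
 + w^n_{\eta\iota,d} \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d}} \tag{20}
\end{align}
\begin{align}
\frac{\partial h^n_{t,d}}{\partial \eta^n_{t,d}}
&= g(\eta^n_{t,d}) \frac{\partial \omega^n_{t,d}}{\partial \eta^n_{t,d}} + \omega^n_{t,d}\, g'(\eta^n_{t,d}) \notag \\
&= w^n_{\eta\omega,d}\, f'(a^n_{\omega,t,d})\, g(\eta^n_{t,d}) + \omega^n_{t,d}\, g'(\eta^n_{t,d}) \notag \\
&= w^n_{\eta\omega,d} \frac{\partial h^n_{t,d}}{\partial a^n_{\omega,t,d}} + \omega^n_{t,d}\, g'(\eta^n_{t,d}) \tag{21}
\end{align}
\begin{align}
\therefore \frac{\partial L(X)}{\partial \eta^n_{t,d}}
&= \frac{\partial L(X)}{\partial h^n_{t,d}} \Big\{ w^n_{\eta\omega,d} \frac{\partial h^n_{t,d}}{\partial a^n_{\omega,t,d}} + \omega^n_{t,d}\, g'(\eta^n_{t,d}) \Big\}
 + \phi^n_{t+1,d} \frac{\partial L(X)}{\partial \eta^n_{t+1,d}}
 + w^n_{\eta\phi,d} \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d}}
 + w^n_{\eta\iota,d} \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d}} \notag \\
&= w^n_{\eta\omega,d} \frac{\partial L(X)}{\partial a^n_{\omega,t,d}}
 + \omega^n_{t,d}\, g'(\eta^n_{t,d}) \frac{\partial L(X)}{\partial h^n_{t,d}}
 + \phi^n_{t+1,d} \frac{\partial L(X)}{\partial \eta^n_{t+1,d}}
 + w^n_{\eta\phi,d} \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d}}
 + w^n_{\eta\iota,d} \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d}} \tag{22}
\end{align}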
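The two-term structure of Eq. (21) comes from the state $\eta^n_{t,d}$ reaching the cell output along two routes: directly through $g$, and through the output-gate peephole. A scalar check with central differences, assuming a logistic $f$ and $g = \tanh$ and arbitrarily chosen values, could be written as below. The name `a_rest` is hypothetical and stands for the part of $a^n_{\omega,t,d}$ that does not depend on $\eta^n_{t,d}$.

```python
import numpy as np

def f(a):
    return 1.0 / (1.0 + np.exp(-a))

g = np.tanh

def cell_output(eta, a_rest, w_peep):
    """h as a function of eta alone, following Eqs. (9)-(10)."""
    return f(a_rest + w_peep * eta) * g(eta)

def dh_deta(eta, a_rest, w_peep):
    omega = f(a_rest + w_peep * eta)
    dh_da_omega = g(eta) * omega * (1.0 - omega)                   # Eq. (19)
    return w_peep * dh_da_omega + omega * (1.0 - g(eta) ** 2)      # Eq. (21)

eta, a_rest, w_peep, eps = 0.7, -0.3, 1.5, 1e-6
numeric = (cell_output(eta + eps, a_rest, w_peep)
           - cell_output(eta - eps, a_rest, w_peep)) / (2 * eps)
err = abs(numeric - dh_deta(eta, a_rest, w_peep))
```

Setting `w_peep` to zero removes the first term, which recovers the state derivative of an LSTM without peephole connections.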
3.5 Cells, forget gates, and input gates
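\begin{align}
\frac{\partial L(X)}{\partial a^n_{\eta,t,d}} &= \frac{\partial \eta^n_{t,d}}{\partial a^n_{\eta,t,d}} \frac{\partial L(X)}{\partial \eta^n_{t,d}} = \iota^n_{t,d}\, g'(a^n_{\eta,t,d}) \frac{\partial L(X)}{\partial \eta^n_{t,d}} \tag{23} \\
\frac{\partial L(X)}{\partial a^n_{\phi,t,d}} &= \frac{\partial \eta^n_{t,d}}{\partial a^n_{\phi,t,d}} \frac{\partial L(X)}{\partial \eta^n_{t,d}} = f'(a^n_{\phi,t,d})\, \eta^n_{t-1,d} \frac{\partial L(X)}{\partial \eta^n_{t,d}} \tag{24} \\
\frac{\partial L(X)}{\partial a^n_{\iota,t,d}} &= \frac{\partial \eta^n_{t,d}}{\partial a^n_{\iota,t,d}} \frac{\partial L(X)}{\partial \eta^n_{t,d}} = f'(a^n_{\iota,t,d})\, g(a^n_{\eta,t,d}) \frac{\partial L(X)}{\partial \eta^n_{t,d}} \tag{25}
\end{align}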
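Putting Sections 3.1 to 3.5 together, the recursion can be tested end to end with a finite-difference gradient check. The sketch below does this for a single layer ($N = 1$, so all terms involving the layers above and below drop out) and checks only the input-gate bias, whose gradient is the sum over $t$ of the input-gate deltas. It assumes a logistic $f$, $g = \tanh$, zero initial state and output, and hypothetical parameter names. The input-gate delta uses $\partial \eta^n_{t,d} / \partial a^n_{\iota,t,d} = f'(a^n_{\iota,t,d})\, g(a^n_{\eta,t,d})$, which is what differentiating Eq. (8) gives.

```python
import numpy as np

def f(a):                                   # logistic function for the gates
    return 1.0 / (1.0 + np.exp(-a))

g = np.tanh                                 # cell activation, so g'(a) = 1 - g(a)^2
GATES = ("iota", "phi", "eta", "omega")

def forward(p, X, targets):
    """Run Eqs. (6)-(13) for a single layer; returns the loss and per-step values."""
    D = p["b_eta"].size
    h, eta, loss, cache = np.zeros(D), np.zeros(D), 0.0, []
    for x_t, target in zip(X, targets):
        a = {k: p["Wx_" + k] @ x_t + p["Wh_" + k] @ h + p["b_" + k] for k in GATES}
        a["iota"] += p["w_iota"] * eta      # peepholes on eta_{t-1}
        a["phi"] += p["w_phi"] * eta
        eta_prev = eta
        eta = f(a["phi"]) * eta_prev + f(a["iota"]) * g(a["eta"])
        a["omega"] += p["w_omega"] * eta    # peephole on eta_t
        h = f(a["omega"]) * g(eta)
        y_hat = p["b_y"] + p["W_hy"] @ h
        y = np.exp(y_hat - y_hat.max())
        y /= y.sum()
        loss -= np.log(y[target])
        cache.append((a, eta_prev, eta, y))
    return loss, cache

def grad_b_iota(p, targets, cache):
    """dL/db_iota = sum over t of dL/da_{iota,t}, via the backward pass of Section 3."""
    D = p["b_eta"].size
    d_next = {k: np.zeros(D) for k in GATES}          # dL/da at t+1, zero beyond t = T
    d_eta_next, phi_next, grad = np.zeros(D), np.zeros(D), np.zeros(D)
    for t in reversed(range(len(cache))):
        a, eta_prev, eta, y = cache[t]
        iota, phi, omega = f(a["iota"]), f(a["phi"]), f(a["omega"])
        d_yhat = y.copy()
        d_yhat[targets[t]] -= 1.0                                     # Eq. (14)
        e_h = p["W_hy"].T @ d_yhat                                    # Eq. (17) with N = 1
        for k in GATES:
            e_h = e_h + p["Wh_" + k].T @ d_next[k]
        d = {"omega": e_h * g(eta) * omega * (1 - omega)}             # Eqs. (18)-(19)
        d_eta = (p["w_omega"] * d["omega"] + omega * (1 - g(eta) ** 2) * e_h
                 + phi_next * d_eta_next
                 + p["w_phi"] * d_next["phi"] + p["w_iota"] * d_next["iota"])  # Eq. (22)
        d["eta"] = iota * (1 - g(a["eta"]) ** 2) * d_eta              # cell input
        d["phi"] = phi * (1 - phi) * eta_prev * d_eta                 # forget gate
        d["iota"] = iota * (1 - iota) * g(a["eta"]) * d_eta           # input gate, from Eq. (8)
        grad += d["iota"]
        d_next, d_eta_next, phi_next = d, d_eta, phi
    return grad

rng = np.random.default_rng(0)
K, D, T = 5, 4, 6
p = {"W_hy": 0.5 * rng.standard_normal((K, D)), "b_y": np.zeros(K)}
for k in GATES:
    p["Wx_" + k] = 0.5 * rng.standard_normal((D, K))
    p["Wh_" + k] = 0.5 * rng.standard_normal((D, D))
    p["b_" + k] = 0.1 * rng.standard_normal(D)
for k in ("iota", "phi", "omega"):
    p["w_" + k] = 0.5 * rng.standard_normal(D)
seq = rng.integers(0, K, size=T + 1)                  # random word indices
X, targets = np.eye(K)[seq[:-1]], seq[1:]             # inputs and next-word targets

loss, cache = forward(p, X, targets)
analytic = grad_b_iota(p, targets, cache)
numeric, eps = np.zeros(D), 1e-6
for i in range(D):
    p["b_iota"][i] += eps
    loss_plus = forward(p, X, targets)[0]
    p["b_iota"][i] -= 2 * eps
    loss_minus = forward(p, X, targets)[0]
    p["b_iota"][i] += eps                             # restore the parameter
    numeric[i] = (loss_plus - loss_minus) / (2 * eps)
max_err = np.max(np.abs(analytic - numeric))
```

The same pattern extends to the other parameters: each weight gradient is the sum over $t$ of the corresponding delta times the quantity that the weight multiplies in Eqs. (6)-(9).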