A Note on BPTT for LSTM LM

Tomonari MASADA @ Nagasaki University

January 15, 2015

1 Forward pass

$K$ is the vocabulary size. $N$ is the number of hidden layers. $D_n$ is the number of hidden units at the $n$-th layer. The input sequence is $X \equiv (x_1, \ldots, x_T)$, where each $x_t \equiv (x_{t1}, \ldots, x_{tK})^\top$ is a one-hot vector.

• ι : input gates

• ϕ : forget gates

• η : cells

• ω : output gates

• h : cell outputs

$f$ and $g$ are sigmoid-shaped functions; typically $f$ is the logistic sigmoid and $g$ is $\tanh$ (their derivatives are given at the start of Section 3).

\begin{align*}
\iota^n_t &= f\bigl(W^n_{x\iota} x_t + W^n_{h^-\iota} h^{n-1}_t + W^n_{h\iota} h^n_{t-1} + w^n_{\eta\iota} \odot \eta^n_{t-1} + b^n_\iota\bigr) \equiv f(a^n_{\iota,t}) \tag{1}\\
\phi^n_t &= f\bigl(W^n_{x\phi} x_t + W^n_{h^-\phi} h^{n-1}_t + W^n_{h\phi} h^n_{t-1} + w^n_{\eta\phi} \odot \eta^n_{t-1} + b^n_\phi\bigr) \equiv f(a^n_{\phi,t}) \tag{2}\\
\eta^n_t &= \phi^n_t \odot \eta^n_{t-1} + \iota^n_t \odot g\bigl(W^n_{x\eta} x_t + W^n_{h^-\eta} h^{n-1}_t + W^n_{h\eta} h^n_{t-1} + b^n_\eta\bigr) \equiv \phi^n_t \odot \eta^n_{t-1} + \iota^n_t \odot g(a^n_{\eta,t}) \tag{3}\\
\omega^n_t &= f\bigl(W^n_{x\omega} x_t + W^n_{h^-\omega} h^{n-1}_t + W^n_{h\omega} h^n_{t-1} + w^n_{\eta\omega} \odot \eta^n_t + b^n_\omega\bigr) \equiv f(a^n_{\omega,t}) \tag{4}\\
h^n_t &= \omega^n_t \odot g(\eta^n_t), \tag{5}
\end{align*}

where $\odot$ denotes the element-wise product and the superscript $n$ indexes the layer. These can be rewritten for each $d = 1, \ldots, D_n$ separately:

\begin{align*}
\iota^n_{t,d} &= f\bigl(w^{n\top}_{x\iota,d} x_t + w^{n\top}_{h^-\iota,d} h^{n-1}_t + w^{n\top}_{h\iota,d} h^n_{t-1} + w^n_{\eta\iota,d}\, \eta^n_{t-1,d} + b^n_{\iota,d}\bigr) \equiv f(a^n_{\iota,t,d}) \tag{6}\\
\phi^n_{t,d} &= f\bigl(w^{n\top}_{x\phi,d} x_t + w^{n\top}_{h^-\phi,d} h^{n-1}_t + w^{n\top}_{h\phi,d} h^n_{t-1} + w^n_{\eta\phi,d}\, \eta^n_{t-1,d} + b^n_{\phi,d}\bigr) \equiv f(a^n_{\phi,t,d}) \tag{7}\\
\eta^n_{t,d} &= \phi^n_{t,d}\, \eta^n_{t-1,d} + \iota^n_{t,d}\, g\bigl(w^{n\top}_{x\eta,d} x_t + w^{n\top}_{h^-\eta,d} h^{n-1}_t + w^{n\top}_{h\eta,d} h^n_{t-1} + b^n_{\eta,d}\bigr) \equiv \phi^n_{t,d}\, \eta^n_{t-1,d} + \iota^n_{t,d}\, g(a^n_{\eta,t,d}) \tag{8}\\
\omega^n_{t,d} &= f\bigl(w^{n\top}_{x\omega,d} x_t + w^{n\top}_{h^-\omega,d} h^{n-1}_t + w^{n\top}_{h\omega,d} h^n_{t-1} + w^n_{\eta\omega,d}\, \eta^n_{t,d} + b^n_{\omega,d}\bigr) \equiv f(a^n_{\omega,t,d}) \tag{9}\\
h^n_{t,d} &= \omega^n_{t,d}\, g(\eta^n_{t,d}). \tag{10}
\end{align*}
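For concreteness, here is a minimal NumPy sketch of one time step of Eqs. (1)-(5) for a single interior layer, taking $f$ as the logistic sigmoid and $g = \tanh$. All names (`lstm_step`, the keys of `p`) are illustrative, not from the note; for $n = 1$ the `h_below` term would simply be dropped.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_step(p, x_t, h_below, h_prev, eta_prev):
    """One time step of Eqs. (1)-(5) for a single layer.

    p: parameter dict; W_* are (D, .) matrices, w_eta_* are length-D
    peephole vectors, b_* are length-D biases.
    """
    # (1) input gate, with a peephole from the previous cell state
    iota = sigmoid(p['W_xi'] @ x_t + p['W_hi_below'] @ h_below
                   + p['W_hi'] @ h_prev + p['w_eta_i'] * eta_prev + p['b_i'])
    # (2) forget gate
    phi = sigmoid(p['W_xf'] @ x_t + p['W_hf_below'] @ h_below
                  + p['W_hf'] @ h_prev + p['w_eta_f'] * eta_prev + p['b_f'])
    # (3) cell state: keep a gated fraction of the old state, add the gated candidate
    eta = phi * eta_prev + iota * np.tanh(p['W_xc'] @ x_t + p['W_hc_below'] @ h_below
                                          + p['W_hc'] @ h_prev + p['b_c'])
    # (4) output gate, with a peephole from the *current* cell state
    omega = sigmoid(p['W_xo'] @ x_t + p['W_ho_below'] @ h_below
                    + p['W_ho'] @ h_prev + p['w_eta_o'] * eta + p['b_o'])
    # (5) cell output
    h = omega * np.tanh(eta)
    return h, eta
```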

Word probabilities for each $t$ are obtained as follows:

$$\hat{y}_t = b_y + \sum_{n=1}^{N} W^n_{hy} h^n_t, \qquad y_{t,k} \equiv \frac{\exp(\hat{y}_{t,k})}{\sum_{k'=1}^{K} \exp(\hat{y}_{t,k'})}. \tag{11}$$

These can be rewritten for $k = 1, \ldots, K$ separately:

$$\hat{y}_{t,k} = b_{y,k} + \sum_{n=1}^{N} \sum_{d=1}^{D_n} w^n_{hy,kd}\, h^n_{t,d}, \qquad y_{t,k} \equiv \frac{\exp(\hat{y}_{t,k})}{\sum_{k'=1}^{K} \exp(\hat{y}_{t,k'})}. \tag{12}$$
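In code, Eqs. (11)-(12) are a sum of per-layer projections followed by a softmax. A small sketch (names illustrative); the max-shift is a standard numerical-stability trick and leaves the softmax unchanged:

```python
import numpy as np

def word_probs(b_y, W_hy_list, h_list):
    """Eqs. (11)-(12): logits summed over all N layers, then softmax.

    W_hy_list[n]: (K, D_n) matrix W^n_{hy}; h_list[n]: layer output h^n_t.
    """
    y_hat = b_y + sum(W @ h for W, h in zip(W_hy_list, h_list))
    z = y_hat - y_hat.max()        # shift logits so exp() cannot overflow
    p = np.exp(z)
    return y_hat, p / p.sum()      # logits and softmax probabilities
```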


2 Negative log likelihood

$$L(X) = -\sum_{t=1}^{T} \log y_{t,x_{t+1}} = -\sum_{t=1}^{T} \log \frac{\exp(\hat{y}_{t,x_{t+1}})}{\sum_{k=1}^{K} \exp(\hat{y}_{t,k})} = -\sum_{t=1}^{T} \hat{y}_{t,x_{t+1}} + \sum_{t=1}^{T} \log\Bigl\{\sum_{k=1}^{K} \exp(\hat{y}_{t,k})\Bigr\} \tag{13}$$
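A sketch of the rightmost form of Eq. (13) (names illustrative); the log-sum-exp is evaluated with the usual max-shift:

```python
import numpy as np

def neg_log_likelihood(y_hat_seq, next_words):
    """Eq. (13): -sum_t y_hat[t, x_{t+1}] + sum_t log-sum-exp(y_hat[t]).

    y_hat_seq: (T, K) logits; next_words: length-T indices of the targets x_{t+1}.
    """
    m = y_hat_seq.max(axis=1)                               # per-step shift
    lse = m + np.log(np.exp(y_hat_seq - m[:, None]).sum(axis=1))
    return -y_hat_seq[np.arange(len(next_words)), next_words].sum() + lse.sum()
```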

3 Backward pass

When $f(a) = \frac{1}{1+\exp(-a)}$, $f'(a) = f(a)\bigl(1 - f(a)\bigr)$. When $f(a) = \tanh(a)$, $f'(a) = 1 - f(a)^2$.
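Both identities are easy to sanity-check numerically with a central finite difference:

```python
import numpy as np

a = np.linspace(-3.0, 3.0, 7)
eps = 1e-6
sig = lambda x: 1.0 / (1.0 + np.exp(-x))

num = (sig(a + eps) - sig(a - eps)) / (2 * eps)
assert np.allclose(num, sig(a) * (1 - sig(a)), atol=1e-8)   # f' = f(1 - f)

num = (np.tanh(a + eps) - np.tanh(a - eps)) / (2 * eps)
assert np.allclose(num, 1 - np.tanh(a) ** 2, atol=1e-8)     # g' = 1 - g^2
```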

3.1 Output-layer weights

$$\frac{\partial L(X)}{\partial \hat{y}_{t,k}} = -\frac{\partial \hat{y}_{t,x_{t+1}}}{\partial \hat{y}_{t,k}} + \frac{\partial}{\partial \hat{y}_{t,k}} \log\Bigl\{\sum_{k'=1}^{K} \exp(\hat{y}_{t,k'})\Bigr\} = -\frac{\partial \hat{y}_{t,x_{t+1}}}{\partial \hat{y}_{t,k}} + \frac{\exp(\hat{y}_{t,k})}{\sum_{k'=1}^{K} \exp(\hat{y}_{t,k'})} = -\bigl\{\delta(x_{t+1} = k) - y_{t,k}\bigr\} \tag{14}$$

For $w^n_{hy,kd}$,

$$\frac{\partial \hat{y}_{t,k}}{\partial w^n_{hy,kd}} = \frac{\partial b_{y,k}}{\partial w^n_{hy,kd}} + \sum_{n'=1}^{N} \sum_{d'=1}^{D_{n'}} \Bigl( h^{n'}_{t,d'} \frac{\partial w^{n'}_{hy,kd'}}{\partial w^n_{hy,kd}} + w^{n'}_{hy,kd'} \frac{\partial h^{n'}_{t,d'}}{\partial w^n_{hy,kd}} \Bigr) = h^n_{t,d}, \tag{15}$$

since $h^{n'}_{t,d'}$ does not depend on $w^n_{hy,kd}$, so only the term with $n' = n$ and $d' = d$ survives.

$$\therefore\ \frac{\partial L(X)}{\partial w^n_{hy,kd}} = \sum_{t=1}^{T} \frac{\partial L(X)}{\partial \hat{y}_{t,k}} \frac{\partial \hat{y}_{t,k}}{\partial w^n_{hy,kd}} = -\sum_{t=1}^{T} h^n_{t,d} \bigl\{\delta(x_{t+1} = k) - y_{t,k}\bigr\} \tag{16}$$
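Eqs. (14) and (16) together give the familiar softmax cross-entropy rule: the logit error is $y_{t,k} - \delta(x_{t+1} = k)$, and the weight gradient is its outer product with $h^n_t$, summed over $t$. A hedged NumPy sketch (names illustrative):

```python
import numpy as np

def grad_W_hy(h_seq, y_seq, next_words):
    """Eq. (16), vectorized over t: dL/dW^n_hy = sum_t (y_t - onehot(x_{t+1})) h_t^T.

    h_seq: (T, D_n) layer-n outputs; y_seq: (T, K) softmax outputs.
    """
    T = len(next_words)
    err = y_seq.copy()                    # Eq. (14): dL/dy_hat_t = y_t - onehot(x_{t+1})
    err[np.arange(T), next_words] -= 1.0
    return err.T @ h_seq                  # (K, D_n), summed over t
```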

3.2 Output errors

\begin{align*}
\frac{\partial L(X)}{\partial h^n_{t,d}}
&= \sum_{k=1}^{K} \frac{\partial L(X)}{\partial \hat{y}_{t,k}} \frac{\partial \hat{y}_{t,k}}{\partial h^n_{t,d}}
+ \sum_{d'=1}^{D_n} \frac{\partial L(X)}{\partial a^n_{\omega,t+1,d'}} \frac{\partial a^n_{\omega,t+1,d'}}{\partial h^n_{t,d}}
+ \sum_{d'=1}^{D_n} \frac{\partial L(X)}{\partial a^n_{\eta,t+1,d'}} \frac{\partial a^n_{\eta,t+1,d'}}{\partial h^n_{t,d}}
+ \sum_{d'=1}^{D_n} \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d'}} \frac{\partial a^n_{\phi,t+1,d'}}{\partial h^n_{t,d}}
+ \sum_{d'=1}^{D_n} \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d'}} \frac{\partial a^n_{\iota,t+1,d'}}{\partial h^n_{t,d}} \\
&\quad + \sum_{d'=1}^{D_{n+1}} \frac{\partial L(X)}{\partial a^{n+1}_{\omega,t,d'}} \frac{\partial a^{n+1}_{\omega,t,d'}}{\partial h^n_{t,d}}
+ \sum_{d'=1}^{D_{n+1}} \frac{\partial L(X)}{\partial a^{n+1}_{\eta,t,d'}} \frac{\partial a^{n+1}_{\eta,t,d'}}{\partial h^n_{t,d}}
+ \sum_{d'=1}^{D_{n+1}} \frac{\partial L(X)}{\partial a^{n+1}_{\phi,t,d'}} \frac{\partial a^{n+1}_{\phi,t,d'}}{\partial h^n_{t,d}}
+ \sum_{d'=1}^{D_{n+1}} \frac{\partial L(X)}{\partial a^{n+1}_{\iota,t,d'}} \frac{\partial a^{n+1}_{\iota,t,d'}}{\partial h^n_{t,d}} \\
&= \sum_{k=1}^{K} w^n_{hy,kd} \frac{\partial L(X)}{\partial \hat{y}_{t,k}}
+ \sum_{d'=1}^{D_n} w^n_{h\omega,d'd} \frac{\partial L(X)}{\partial a^n_{\omega,t+1,d'}}
+ \sum_{d'=1}^{D_n} w^n_{h\eta,d'd} \frac{\partial L(X)}{\partial a^n_{\eta,t+1,d'}}
+ \sum_{d'=1}^{D_n} w^n_{h\phi,d'd} \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d'}}
+ \sum_{d'=1}^{D_n} w^n_{h\iota,d'd} \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d'}} \\
&\quad + \sum_{d'=1}^{D_{n+1}} w^{n+1}_{h^-\omega,d'd} \frac{\partial L(X)}{\partial a^{n+1}_{\omega,t,d'}}
+ \sum_{d'=1}^{D_{n+1}} w^{n+1}_{h^-\eta,d'd} \frac{\partial L(X)}{\partial a^{n+1}_{\eta,t,d'}}
+ \sum_{d'=1}^{D_{n+1}} w^{n+1}_{h^-\phi,d'd} \frac{\partial L(X)}{\partial a^{n+1}_{\phi,t,d'}}
+ \sum_{d'=1}^{D_{n+1}} w^{n+1}_{h^-\iota,d'd} \frac{\partial L(X)}{\partial a^{n+1}_{\iota,t,d'}} \tag{17}
\end{align*}

At $t = T$ the terms involving $t+1$ vanish, and at the top layer $n = N$ the terms involving layer $n+1$ vanish.
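The nine sums in Eq. (17) are matrix-vector products. A sketch for an interior layer at an interior time step, assuming the gate errors $\partial L/\partial a$ at $(n, t+1)$ and $(n+1, t)$ have already been computed (at the boundaries the corresponding dictionaries would hold zero vectors); names are illustrative:

```python
import numpy as np

GATES = ('omega', 'eta', 'phi', 'iota')

def output_error(W_hy, dL_dyhat, W_rec, delta_next, W_above, delta_above):
    """Eq. (17): accumulate dL/dh^n_t from everything h^n_t feeds into.

    W_hy: (K, D_n).  W_rec[g]: W^n_{hg}, (D_n, D_n).  W_above[g]: W^{n+1}_{h^-g},
    (D_{n+1}, D_n).  delta_next[g]: dL/da^n_{g,t+1}.  delta_above[g]: dL/da^{n+1}_{g,t}.
    """
    dh = W_hy.T @ dL_dyhat                   # first sum: through the softmax logits
    for g in GATES:
        dh += W_rec[g].T @ delta_next[g]     # same layer, next time step
        dh += W_above[g].T @ delta_above[g]  # layer above, same time step
    return dh
```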


3.3 Output gates

$$\frac{\partial L(X)}{\partial a^n_{\omega,t,d}} = \frac{\partial L(X)}{\partial h^n_{t,d}} \frac{\partial h^n_{t,d}}{\partial a^n_{\omega,t,d}} \tag{18}$$

$$\frac{\partial h^n_{t,d}}{\partial a^n_{\omega,t,d}} = g(\eta^n_{t,d})\, f'(a^n_{\omega,t,d}) \tag{19}$$
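Eqs. (18)-(19) reduce to a one-liner once $f$ and $g$ are fixed; a sketch with logistic $f$ and $g = \tanh$ (names illustrative):

```python
import numpy as np

def delta_omega(dL_dh, a_omega, eta):
    """Eqs. (18)-(19): dL/da_omega = dL/dh * g(eta) * f'(a_omega), elementwise."""
    f_a = 1.0 / (1.0 + np.exp(-a_omega))
    return dL_dh * np.tanh(eta) * f_a * (1.0 - f_a)
```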

3.4 States

\begin{align*}
\frac{\partial L(X)}{\partial \eta^n_{t,d}}
&= \frac{\partial L(X)}{\partial h^n_{t,d}} \frac{\partial h^n_{t,d}}{\partial \eta^n_{t,d}}
+ \frac{\partial L(X)}{\partial \eta^n_{t+1,d}} \frac{\partial \eta^n_{t+1,d}}{\partial \eta^n_{t,d}}
+ \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d}} \frac{\partial a^n_{\phi,t+1,d}}{\partial \eta^n_{t,d}}
+ \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d}} \frac{\partial a^n_{\iota,t+1,d}}{\partial \eta^n_{t,d}} \\
&= \frac{\partial L(X)}{\partial h^n_{t,d}} \frac{\partial h^n_{t,d}}{\partial \eta^n_{t,d}}
+ \phi^n_{t+1,d} \frac{\partial L(X)}{\partial \eta^n_{t+1,d}}
+ w^n_{\eta\phi,d} \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d}}
+ w^n_{\eta\iota,d} \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d}} \tag{20}
\end{align*}

\begin{align*}
\frac{\partial h^n_{t,d}}{\partial \eta^n_{t,d}}
&= g(\eta^n_{t,d}) \frac{\partial \omega^n_{t,d}}{\partial \eta^n_{t,d}} + \omega^n_{t,d}\, g'(\eta^n_{t,d}) \\
&= w^n_{\eta\omega,d}\, f'(a^n_{\omega,t,d})\, g(\eta^n_{t,d}) + \omega^n_{t,d}\, g'(\eta^n_{t,d}) \\
&= w^n_{\eta\omega,d} \frac{\partial h^n_{t,d}}{\partial a^n_{\omega,t,d}} + \omega^n_{t,d}\, g'(\eta^n_{t,d}) \tag{21}
\end{align*}

\begin{align*}
\therefore\ \frac{\partial L(X)}{\partial \eta^n_{t,d}}
&= \frac{\partial L(X)}{\partial h^n_{t,d}} \Bigl\{ w^n_{\eta\omega,d} \frac{\partial h^n_{t,d}}{\partial a^n_{\omega,t,d}} + \omega^n_{t,d}\, g'(\eta^n_{t,d}) \Bigr\}
+ \phi^n_{t+1,d} \frac{\partial L(X)}{\partial \eta^n_{t+1,d}}
+ w^n_{\eta\phi,d} \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d}}
+ w^n_{\eta\iota,d} \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d}} \\
&= w^n_{\eta\omega,d} \frac{\partial L(X)}{\partial a^n_{\omega,t,d}}
+ \omega^n_{t,d}\, g'(\eta^n_{t,d}) \frac{\partial L(X)}{\partial h^n_{t,d}}
+ \phi^n_{t+1,d} \frac{\partial L(X)}{\partial \eta^n_{t+1,d}}
+ w^n_{\eta\phi,d} \frac{\partial L(X)}{\partial a^n_{\phi,t+1,d}}
+ w^n_{\eta\iota,d} \frac{\partial L(X)}{\partial a^n_{\iota,t+1,d}} \tag{22}
\end{align*}
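Eq. (22) is the backward recursion for the cell-state error; a sketch over all $d$ at once, with $g = \tanh$ (names illustrative):

```python
import numpy as np

def cell_state_error(dL_dh, dL_da_omega, omega, eta, w_eta_o, w_eta_f, w_eta_i,
                     phi_next, dL_deta_next, dL_da_phi_next, dL_da_iota_next):
    """Eq. (22); every argument is a length-D vector.

    At t = T the four look-ahead arguments are zero vectors.
    """
    g_prime = 1.0 - np.tanh(eta) ** 2           # g' for g = tanh
    return (w_eta_o * dL_da_omega               # output-gate peephole at time t
            + omega * g_prime * dL_dh           # through h = omega * g(eta)
            + phi_next * dL_deta_next           # state carried into eta_{t+1}
            + w_eta_f * dL_da_phi_next          # forget-gate peephole at t+1
            + w_eta_i * dL_da_iota_next)        # input-gate peephole at t+1
```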

3.5 Cells, forget gates, and input gates

$$\frac{\partial L(X)}{\partial a^n_{\eta,t,d}} = \frac{\partial \eta^n_{t,d}}{\partial a^n_{\eta,t,d}} \frac{\partial L(X)}{\partial \eta^n_{t,d}} = \iota^n_{t,d}\, g'(a^n_{\eta,t,d}) \frac{\partial L(X)}{\partial \eta^n_{t,d}} \tag{23}$$

$$\frac{\partial L(X)}{\partial a^n_{\phi,t,d}} = \frac{\partial \eta^n_{t,d}}{\partial a^n_{\phi,t,d}} \frac{\partial L(X)}{\partial \eta^n_{t,d}} = f'(a^n_{\phi,t,d})\, \eta^n_{t-1,d} \frac{\partial L(X)}{\partial \eta^n_{t,d}} \tag{24}$$

$$\frac{\partial L(X)}{\partial a^n_{\iota,t,d}} = \frac{\partial \eta^n_{t,d}}{\partial a^n_{\iota,t,d}} \frac{\partial L(X)}{\partial \eta^n_{t,d}} = f'(a^n_{\iota,t,d})\, g(a^n_{\eta,t,d}) \frac{\partial L(X)}{\partial \eta^n_{t,d}}, \tag{25}$$

where the factor $g(a^n_{\eta,t,d})$ in (25) comes from Eq. (8), since $\partial \eta^n_{t,d} / \partial \iota^n_{t,d} = g(a^n_{\eta,t,d})$.
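All three of Eqs. (23)-(25) reuse $\partial L/\partial \eta^n_{t,d}$, so they are naturally computed together; a sketch with logistic $f$ and $g = \tanh$ (names illustrative):

```python
import numpy as np

def gate_deltas(dL_deta, iota, a_eta, a_phi, a_iota, eta_prev):
    """Eqs. (23)-(25): errors at the cell-input and gate pre-activations."""
    sig = lambda a: 1.0 / (1.0 + np.exp(-a))
    s_phi, s_iota = sig(a_phi), sig(a_iota)
    d_a_eta = iota * (1.0 - np.tanh(a_eta) ** 2) * dL_deta          # Eq. (23)
    d_a_phi = s_phi * (1.0 - s_phi) * eta_prev * dL_deta            # Eq. (24)
    d_a_iota = s_iota * (1.0 - s_iota) * np.tanh(a_eta) * dL_deta   # Eq. (25)
    return d_a_eta, d_a_phi, d_a_iota
```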
