Computing Gradient
Hung-yi Lee (李宏毅)
Introduction
• Backpropagation: an efficient way to compute the gradient
• Prerequisite
• Backpropagation for feedforward net:
• http://speech.ee.ntu.edu.tw/~tlkagk/courses/MLDS_2015_2/Lecture/DNN%20backprop.ecm.mp4/index.html
• Simple version: https://www.youtube.com/watch?v=ibJpTrp5mcE
• Backpropagation through time for RNN: http://speech.ee.ntu.edu.tw/~tlkagk/courses/MLDS_2015_2/Lecture/RNN%20training%20(v6).ecm.mp4/index.html
• Understanding backpropagation by computational graph
• Tensorflow, Theano, CNTK, etc.
Computational Graph
Computational Graph
• A “language” describing a function
• Node: variable (scalar, vector, tensor, ……)
• Edge: operation (simple function)
a = f(b, c)
[Graph: nodes b and c are connected by the edge f to node a]
Example: y = f(g(h(x)))
u = h(x), v = g(u), y = f(v)
[Graph: x →h→ u →g→ v →f→ y]
Computational Graph
• Example: e = (a+b) ∗ (b+1)
c = a + b
d = b + 1
e = c ∗ d
[Graph: a and b feed the “+” node that produces c; b and the constant 1 feed the “+” node that produces d; c and d feed the “∗” node that produces e. With a = 2 and b = 1: c = 3, d = 2, e = 6.]
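A minimal sketch of how such a graph can be represented and evaluated forward (my own illustration in plain Python; the Node class and names are hypothetical, not from the lecture):

```python
# Representing e = (a+b)*(b+1) as a tiny computational graph
# and evaluating it with a forward pass.
class Node:
    def __init__(self, op=None, parents=(), value=None):
        self.op = op            # function applied to the parent values
        self.parents = parents  # upstream nodes (the incoming edges)
        self.value = value      # value after the forward pass

    def forward(self):
        if self.op is not None:
            self.value = self.op(*[p.forward() for p in self.parents])
        return self.value

a = Node(value=2.0)
b = Node(value=1.0)
c = Node(op=lambda x, y: x + y, parents=(a, b))   # c = a + b
d = Node(op=lambda x: x + 1.0, parents=(b,))      # d = b + 1
e = Node(op=lambda x, y: x * y, parents=(c, d))   # e = c * d

print(e.forward())  # 6.0, matching the values on the slide (a = 2, b = 1)
```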
Review: Chain Rule
• Case 1: z = h(y), y = g(x)   (graph: x → y → z)
dz/dx = (dz/dy)(dy/dx)
• Case 2: z = k(x, y), x = g(s), y = h(s)   (graph: s → x and s → y, both feeding z)
dz/ds = (∂z/∂x)(dx/ds) + (∂z/∂y)(dy/ds)
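A quick worked instance of Case 2 (my own example, not from the slides): take z = k(x, y) = xy with x = g(s) = s² and y = h(s) = s³. Then
dz/ds = (∂z/∂x)(dx/ds) + (∂z/∂y)(dy/ds) = y·2s + x·3s² = 2s⁴ + 3s⁴ = 5s⁴,
which matches differentiating z = s²·s³ = s⁵ directly.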
Computational Graph
• Example: e = (a+b) ∗ (b+1)
Compute ∂e/∂b: sum over all paths from b to e
Edge derivatives: ∂c/∂a = 1, ∂c/∂b = 1, ∂d/∂b = 1, ∂e/∂c = d = b+1, ∂e/∂d = c = a+b
∂e/∂b = (∂e/∂c)(∂c/∂b) + (∂e/∂d)(∂d/∂b) = 1×(b+1) + 1×(a+b)
Computational Graph
• Example: e = (a+b) ∗ (b+1)
Compute ∂e/∂b with a = 3, b = 2
Forward pass values: c = 5, d = 3, e = 15
Edge derivatives: ∂c/∂a = 1, ∂c/∂b = 1, ∂d/∂b = 1, ∂e/∂c = d = 3, ∂e/∂d = c = 5
Computational Graph
• Example: e = (a+b) ∗ (b+1), with a = 3, b = 2
Compute ∂e/∂b (reverse mode): start with 1 at the output node e and move backward, multiplying the edge derivatives. The value arriving at c is 1×3 (= d) and the value arriving at d is 1×5 (= c); at b the two paths sum to 3×1 + 5×1 = 8 = ∂e/∂b.
Computational Graph
• Example: e = (a+b) ∗ (b+1), with a = 3, b = 2
Compute ∂e/∂a (reverse mode): again start with 1 at e. The value arriving at c is 3 (= d), and the only path to a contributes 3×1 = 3 = ∂e/∂a.
Computational Graph
• Example: e = (a+b) ∗ (b+1), with a = 3, b = 2
Compute ∂e/∂a and ∂e/∂b together (reverse mode): start with 1 at e and sweep the graph backward once. The values 3 (= d) and 5 (= c) arrive at c and d, giving ∂e/∂a = 3 and ∂e/∂b = 3 + 5 = 8.
What is the benefit? A single backward pass yields the derivative of the output with respect to every input, which is exactly what gradient descent needs.
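Below is a minimal reverse-mode sketch (my own illustration, not the lecture's code; the Var/add/mul/backward names are hypothetical) showing that one backward pass starting with 1 at e produces ∂e/∂a and ∂e/∂b at the same time:

```python
# Reverse-mode differentiation of e = (a+b)*(b+1).
class Var:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0
        self._backward = lambda: None

def add(x, y):
    out = Var(x.value + y.value)
    def _backward():
        x.grad += out.grad        # d(x+y)/dx = 1
        y.grad += out.grad        # d(x+y)/dy = 1
    out._backward = _backward
    out.parents = (x, y)
    return out

def mul(x, y):
    out = Var(x.value * y.value)
    def _backward():
        x.grad += y.value * out.grad   # d(xy)/dx = y
        y.grad += x.value * out.grad   # d(xy)/dy = x
    out._backward = _backward
    out.parents = (x, y)
    return out

def backward(root):
    # topological order, then propagate "start with 1" from the output
    order, seen = [], set()
    def visit(v):
        if id(v) in seen:
            return
        seen.add(id(v))
        for p in getattr(v, "parents", ()):
            visit(p)
        order.append(v)
    visit(root)
    root.grad = 1.0
    for v in reversed(order):
        v._backward()

a, b = Var(3.0), Var(2.0)
one = Var(1.0)
e = mul(add(a, b), add(b, one))   # e = (a+b)*(b+1)
backward(e)
print(a.grad, b.grad)             # 3.0 8.0, as on the slides
```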
Computational Graph
• Parameter sharing: the same parameter appears in different nodes
y = x·e^(x²)
[Graph: three copies of x appear as separate nodes; two of them are multiplied to give u = x², exp gives e^(x²), and the result is multiplied by the third x to give y.]
Differentiating through each copy separately:
through the outer x: ∂y/∂x = e^(x²)
through each of the two x’s inside x²: ∂y/∂x = x·e^(x²)·x
Summing the contributions of all copies: ∂y/∂x = e^(x²) + x·e^(x²)·2x
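A quick numerical check of this shared-parameter rule (my own sketch, assuming numpy; the value of x is made up):

```python
# y = x * exp(x^2); each copy of x contributes one term and the terms are summed.
import numpy as np

x = 1.5
path1 = np.exp(x**2)              # through the outer x
path2 = x * np.exp(x**2) * x      # through the first x inside x^2
path3 = x * np.exp(x**2) * x      # through the second x inside x^2
analytic = path1 + path2 + path3  # = exp(x^2) + x*exp(x^2)*2x

eps = 1e-6                        # central finite difference as a reference
numeric = ((x + eps) * np.exp((x + eps)**2)
           - (x - eps) * np.exp((x - eps)**2)) / (2 * eps)
print(analytic, numeric)          # the two values agree
```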
Computational Graph for Feedforward Net
Review: Backpropagation
∂C/∂w_ij^l = (∂z_i^l/∂w_ij^l)(∂C/∂z_i^l)
Forward pass: ∂z_i^l/∂w_ij^l = a_j^(l−1)
z^1 = W^1 x + b^1, a^1 = σ(z^1), and in general z^l = W^l a^(l−1) + b^l
Backward pass: error signal δ_i^l = ∂C/∂z_i^l
δ^L = σ′(z^L) ⊙ ∇_y C
δ^l = σ′(z^l) ⊙ (W^(l+1))^T δ^(l+1), e.g. δ^(L−1) = σ′(z^(L−1)) ⊙ (W^L)^T δ^L
Review: Backpropagation
[Last layer: the output layer L has neurons 1, 2, …, n; layer L−1 has neurons 1, 2, …, m.]
δ^L = σ′(z^L) ⊙ ∇_y C, where ∇_y C = [∂C/∂y_1, ∂C/∂y_2, …, ∂C/∂y_n]^T
δ^(L−1) = σ′(z^(L−1)) ⊙ (W^L)^T δ^L
∂C/∂w_ij^l = (∂z_i^l/∂w_ij^l)(∂C/∂z_i^l), with error signal δ_i^l = ∂C/∂z_i^l
(we do not use softmax here)
Feedforward Network
[Graph: x →(W^1, b^1)→ z^1 →σ→ a^1 →(W^2, b^2)→ z^2 →σ→ … → y]
z^1 = W^1 x + b^1, z^2 = W^2 a^1 + b^2
y = σ(W^L ⋯ σ(W^2 σ(W^1 x + b^1) + b^2) ⋯ + b^L)
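A small numpy sketch of this forward pass (my own illustration; the layer sizes and random values are made up):

```python
# Forward pass of the two-layer feedforward net on the slide.
import numpy as np

def sigma(z):                         # element-wise sigmoid
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
x  = rng.normal(size=(3,))            # input
W1 = rng.normal(size=(4, 3)); b1 = np.zeros(4)
W2 = rng.normal(size=(2, 4)); b2 = np.zeros(2)

z1 = W1 @ x + b1
a1 = sigma(z1)
z2 = W2 @ a1 + b2
y  = sigma(z2)                        # y = sigma(W2 sigma(W1 x + b1) + b2)
print(y.shape)                        # (2,)
```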
Loss Function of Feedforward Network
[Same graph as above, with the target ŷ and the cost node C appended after y.]
z^1 = W^1 x + b^1, z^2 = W^2 a^1 + b^2
C = L(y, ŷ)
Gradient of Cost Function
C = L(y, ŷ)
To compute the gradient ∂C/∂W^1, ∂C/∂W^2, ∂C/∂b^1, ∂C/∂b^2 …
Compute the partial derivative on each edge of the graph, then use reverse mode (the output C is always a scalar).
∂a^1/∂z^1 = ? Both a^1 and z^1 are vectors, so this derivative is a Jacobian matrix.
Jacobian Matrix
y = f(x), x = [x_1, x_2, x_3]^T, y = [y_1, y_2]^T
∂y/∂x = [ ∂y_1/∂x_1  ∂y_1/∂x_2  ∂y_1/∂x_3 ;  ∂y_2/∂x_1  ∂y_2/∂x_2  ∂y_2/∂x_3 ]
(number of rows = size of y, number of columns = size of x)
Example: y = f(x) = [x_1 + x_2·x_3, 2·x_3]^T
∂y/∂x = [ 1  x_3  x_2 ;  0  0  2 ]
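The example Jacobian can be checked with finite differences (my own sketch, assuming numpy; the test point is made up):

```python
# f(x) = [x1 + x2*x3, 2*x3]; df/dx should be [[1, x3, x2], [0, 0, 2]].
import numpy as np

def f(x):
    return np.array([x[0] + x[1] * x[2], 2.0 * x[2]])

x = np.array([1.0, 2.0, 3.0])
eps = 1e-6
J = np.zeros((2, 3))
for j in range(3):                     # one column per input dimension
    d = np.zeros(3); d[j] = eps
    J[:, j] = (f(x + d) - f(x - d)) / (2 * eps)
print(np.round(J, 4))                  # [[1. 3. 2.] [0. 0. 2.]]
```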
Gradient of Cost Function
Last layer: C = L(y, ŷ), with one-hot target ŷ = [0, 0, …, 1, …]^T (the 1 is at position r)
Cross entropy: C = −log y_r
∂C/∂y_i = ?
i ≠ r: ∂C/∂y_i = 0
i = r: ∂C/∂y_r = −1/y_r
∂C/∂y = [0 ⋯ −1/y_r ⋯]
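A small sketch of this gradient vector (my own illustration, assuming numpy; the values of y are made up):

```python
# dC/dy for the cross entropy C = -log(y_r) with a one-hot target.
import numpy as np

y     = np.array([0.1, 0.7, 0.2])    # network output (already a distribution)
y_hat = np.array([0.0, 1.0, 0.0])    # one-hot target, so r = 1
r = int(np.argmax(y_hat))

dC_dy = np.zeros_like(y)
dC_dy[r] = -1.0 / y[r]                # every other entry stays 0
print(dC_dy)                          # [ 0.  -1.428...  0. ]
```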
Gradient of Cost Function
y = σ(z^2), i.e. y_i = σ(z_i^2)
∂y/∂z^2 is a Jacobian matrix (i-th row, j-th column: ∂y_i/∂z_j^2)
i ≠ j: ∂y_i/∂z_j^2 = 0
i = j: ∂y_i/∂z_i^2 = σ′(z_i^2)
So ∂y/∂z^2 is a square, diagonal matrix with σ′(z_i^2) on the diagonal.
How about softmax? ☺
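A sketch of the resulting diagonal Jacobian (my own illustration, assuming numpy and the sigmoid for σ):

```python
# dy/dz2 for an element-wise sigmoid is a diagonal Jacobian.
import numpy as np

def sigma(z):
    return 1.0 / (1.0 + np.exp(-z))

z2 = np.array([0.5, -1.0, 2.0])
dy_dz2 = np.diag(sigma(z2) * (1.0 - sigma(z2)))   # sigma'(z_i) on the diagonal
print(dy_dz2)
```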
Gradient of Cost Function
z^2 = W^2 a^1   (graph: a^1 →(W^2)→ z^2 →σ→ y → C, with target ŷ)
Needed on the edges: ∂y/∂z^2, ∂z^2/∂a^1, ∂z^2/∂W^2
∂z^2/∂a^1 is a Jacobian matrix (i-th row, j-th column: ∂z_i^2/∂a_j^1)
z_i^2 = w_i1^2 a_1^1 + w_i2^2 a_2^1 + ⋯ + w_in^2 a_n^1
∂z_i^2/∂a_j^1 = W_ij^2, so ∂z^2/∂a^1 = W^2
Gradient of Cost Function
z^2 = W^2 a^1, where z^2 has m components and W^2 is m×n
Considering W^2 as an m×n vector (flattened row by row), ∂z^2/∂W^2 is an m × (m·n) Jacobian.
z_i^2 = w_i1^2 a_1^1 + w_i2^2 a_2^1 + ⋯ + w_in^2 a_n^1
∂z_i^2/∂W_jk^2 = ?
i = j: ∂z_i^2/∂W_ik^2 = a_k^1
i ≠ j: ∂z_i^2/∂W_jk^2 = 0
The entry ∂z_i^2/∂W_jk^2 sits in row i, column (j−1)×n+k.
(Considering ∂z^2/∂W^2 as a tensor makes things easier.)
Gradient of Cost Function
∂z^2/∂W^2 written out: row i contains the entries of a^1 in the columns of block j = i and zeros in the other blocks, e.g.
row i = 1: [ a_1^1 a_2^1 ⋯ a_n^1 | 0 ⋯ 0 | 0 ⋯ 0 | ⋯ ]
row i = 2: [ 0 ⋯ 0 | a_1^1 a_2^1 ⋯ a_n^1 | 0 ⋯ 0 | ⋯ ]
(column blocks j = 1, 2, 3, …; the whole matrix is m × (m·n), with W^2 considered as an m×n vector)
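A sketch of this flattened-W² Jacobian for small m and n (my own illustration, assuming numpy; Python's 0-based indexing turns the column (j−1)×n+k into the slice i·n:(i+1)·n):

```python
# dz2/dW2 when W2 is flattened to a length m*n vector:
# row i has a1 in the columns of block j = i and zeros elsewhere.
import numpy as np

m, n = 2, 3
a1 = np.array([1.0, 2.0, 3.0])                  # a^1, length n
J = np.zeros((m, m * n))                        # dz2/dW2, shape m x (m*n)
for i in range(m):
    J[i, i * n:(i + 1) * n] = a1                # dz_i/dW_ik = a_k, zero for j != i
print(J)
# [[1. 2. 3. 0. 0. 0.]
#  [0. 0. 0. 1. 2. 3.]]
```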
Gradient of Cost Function
[Full graph: x →(W^1)→ z^1 →σ→ a^1 →(W^2)→ z^2 →σ→ y → C, with target ŷ]
Edge derivatives: ∂z^1/∂W^1, ∂a^1/∂z^1, ∂z^2/∂a^1 = W^2, ∂z^2/∂W^2, ∂y/∂z^2, ∂C/∂y
∂C/∂W^1 = (∂C/∂y)(∂y/∂z^2) W^2 (∂a^1/∂z^1)(∂z^1/∂W^1) = [⋯ ∂C/∂W_ij^1 ⋯]
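Putting the pieces together, here is a sketch that chains these edge derivatives from the scalar C back to W¹ and W² (my own illustration, assuming numpy, sigmoid activations, and the cross-entropy loss from the earlier slide; sizes and values are made up):

```python
# Multiply the edge Jacobians from the output toward the inputs.
# Because C is a scalar, every intermediate product stays a vector.
import numpy as np

def sigma(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
x = rng.normal(size=(3,))
W1, b1 = rng.normal(size=(4, 3)), np.zeros(4)
W2, b2 = rng.normal(size=(2, 4)), np.zeros(2)

# forward
z1 = W1 @ x + b1;  a1 = sigma(z1)
z2 = W2 @ a1 + b2; y  = sigma(z2)
C = -np.log(y[1])                               # cross entropy, target class r = 1

# backward
dC_dy  = np.array([0.0, -1.0 / y[1]])           # dC/dy
dC_dz2 = dC_dy * sigma(z2) * (1 - sigma(z2))    # times the diagonal dy/dz2
dC_dW2 = np.outer(dC_dz2, a1)                   # dz2/dW2 puts a1 into each row block
dC_da1 = W2.T @ dC_dz2                          # dz2/da1 = W2
dC_dz1 = dC_da1 * sigma(z1) * (1 - sigma(z1))   # times the diagonal da1/dz1
dC_dW1 = np.outer(dC_dz1, x)

print(dC_dW1.shape, dC_dW2.shape)               # (4, 3) (2, 4)
```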
Question
Q: Is there only a backward pass for a computational graph?
Q: Do we get the same results from the two different approaches?
Computational Graph for Recurrent Network
Recurrent Network
[Unrolled network: the same cell f is applied at every time step t = 1, 2, 3, …, mapping (x^t, h^(t−1)) to (y^t, h^t).]
y^t, h^t = f(x^t, h^(t−1); W_i, W_h, W_o)
h^t = σ(W_i x^t + W_h h^(t−1))
y^t = softmax(W_o h^t)
(biases are ignored here)
Recurrent Network
y^t, h^t = f(x^t, h^(t−1); W_i, W_h, W_o)
h^t = σ(W_i x^t + W_h h^(t−1)), y^t = softmax(W_o h^t)
[Computational graph of one step: W_i x^1 and W_h h^0 are produced by two “×” nodes and summed to z^1 by a “+” node; h^1 = σ(z^1); a third “×” node gives W_o h^1, and softmax gives y^1.]
Recurrent Network
[The one-step graph is abstracted into a single node f with inputs x^1, h^0, W_i, W_h, W_o and outputs y^1, h^1.]
y^t, h^t = f(x^t, h^(t−1); W_i, W_h, W_o)
h^t = σ(W_i x^t + W_h h^(t−1)), y^t = softmax(W_o h^t)
Recurrent Network
[The cell f is unrolled for three steps. Step t takes x^t and h^(t−1), produces y^t and h^t, and y^t is compared with the target ŷ^t to give the cost C^t. The parameters W_i, W_h, W_o are shared by all steps.]
C = C^1 + C^2 + C^3
[Backward pass on the unrolled graph: start with 1 at C (and hence at each C^t, since C = C^1 + C^2 + C^3), then multiply the edge derivatives ∂C^1/∂y^1, ∂C^2/∂y^2, ∂C^3/∂y^3, ∂y^1/∂h^1, ∂y^2/∂h^2, ∂y^3/∂h^3, ∂h^3/∂h^2, ∂h^2/∂h^1, and ∂h^1/∂W_h, ∂h^2/∂W_h, ∂h^3/∂W_h along every path from W_h to C. Because W_h is shared, each unrolled copy contributes its own term to ∂C/∂W_h.]
Recurrent Network
Contribution of the first copy of W_h:
∂C/∂W_h^(1) = (∂h^1/∂W_h) [ (∂C^1/∂y^1)(∂y^1/∂h^1) + (∂C^2/∂y^2)(∂y^2/∂h^2)(∂h^2/∂h^1) + (∂C^3/∂y^3)(∂y^3/∂h^3)(∂h^3/∂h^2)(∂h^2/∂h^1) ]
∂C/∂W_h^(2) = ⋯   ∂C/∂W_h^(3) = ⋯   (same pattern, starting from h^2 and h^3)
∂C/∂W_h = ∂C/∂W_h^(1) + ∂C/∂W_h^(2) + ∂C/∂W_h^(3)
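A sketch of this accumulation over the unrolled steps (my own illustration, assuming numpy; for brevity the output layer is replaced by a stand-in gradient of ones arriving at each h^t):

```python
# Backpropagation through time: the same Wh is used at every step,
# so dC/dWh is the sum of the contributions from each unrolled copy.
import numpy as np

def sigma(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(2)
d = 3                                   # hidden size (= input size here, for brevity)
Wi, Wh = rng.normal(size=(d, d)), rng.normal(size=(d, d))
xs = [rng.normal(size=(d,)) for _ in range(3)]

# forward pass, keeping what the backward pass needs
hs, zs = [np.zeros(d)], []
for x in xs:
    z = Wi @ x + Wh @ hs[-1]            # z^t = Wi x^t + Wh h^(t-1)
    zs.append(z)
    hs.append(sigma(z))                 # h^t = sigma(z^t)

# backward pass: pretend dC/dh^t = 1 at every step (stand-in for dC^t/dy^t * dy^t/dh^t)
dWh = np.zeros_like(Wh)
dh = np.zeros(d)
for t in reversed(range(3)):
    dh = dh + np.ones(d)                        # gradient arriving at h^t
    dz = dh * sigma(zs[t]) * (1 - sigma(zs[t])) # through the sigmoid
    dWh += np.outer(dz, hs[t])                  # contribution of step t to dC/dWh
    dh = Wh.T @ dz                              # propagate to h^(t-1)
print(dWh)
```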
Reference
• Textbook: Deep Learning
• Chapter 6.5
• Calculus on Computational Graphs: Backpropagation
• https://colah.github.io/posts/2015-08-Backprop/
• On chain rule, computational graphs, and backpropagation
• http://outlace.com/Computational-Graph/
Acknowledgement
• Thanks to 翁丞世 for finding the errors in these slides.