(Talk by Emtiyaz Khan) Bayesian Deep Learning using Weight-Perturbation in Adam (Poster #190)
Fast and Scalable Bayesian Deep Learning by Weight-Perturbation in Adam
Mohammad Emtiyaz Khan*, RIKEN Center for AI Project (AIP), Tokyo, Japan
Didrik Nielsen* (AIP RIKEN), Voot Tangkaratt* (AIP RIKEN), Wu Lin (UBC), Yarin Gal (University of Oxford), Akash Srivastava (University of Edinburgh)
*Equal contribution
Bayesian Deep Learning
Compute averages over the samples from the posterior distribution.
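Here is a minimal sketch of what "computing averages over posterior samples" means. It estimates the posterior predictive probability of a logistic-regression model by averaging the model's output over samples of the weight. The posterior $\mathcal{N}(1.0, 0.5^2)$ for a single scalar weight is a hypothetical stand-in. In practice, the samples would come from an approximate posterior such as a variational $q(\theta)$.

```python
import math
import random

def sigmoid(z):
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-z))

def predictive_mean(x, posterior_samples):
    """Monte Carlo estimate of p(y=1 | x, D) = E_{p(theta|D)}[sigmoid(theta * x)]."""
    return sum(sigmoid(theta * x) for theta in posterior_samples) / len(posterior_samples)

# Hypothetical posterior for a single logistic-regression weight: N(1.0, 0.5^2).
rng = random.Random(0)
samples = [rng.gauss(1.0, 0.5) for _ in range(10_000)]

print(predictive_mean(2.0, samples))  # posterior-averaged prediction
print(sigmoid(1.0 * 2.0))             # plug-in prediction at the posterior mean
```

The averaged prediction differs from simply plugging in the posterior mean. Accounting for weight uncertainty in this way is the point of the averaging.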
Approximate Bayesian Inference
Convert Bayesian inference into an optimization problem using variational inference (VI), then solve it with gradient-based methods.
Examples: Bayes by Backprop (Blundell et al. 2015), Practical VI (Graves 2011), Black-box VI (Ranganath et al. 2014), and many more.
Approximate Bayesian inference requires more computation, memory, and implementation effort than MLE. Is it possible to reduce these costs?
By replacing gradients with natural gradients.
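One standard identity helps explain why natural gradients can be cheap. Assume $q$ is a minimal exponential family, which covers the Gaussian $q$ used in this talk. Then the natural gradient with respect to the natural parameters $\eta$ equals the ordinary gradient with respect to the expectation parameters $m$. The symbols $\eta$, $T$, $A$, $h$, $m$ and $F$ are introduced here for the derivation:

```latex
\begin{aligned}
q(\theta) &= h(\theta)\exp\big(\eta^\top T(\theta) - A(\eta)\big), \\
m(\eta) &:= \mathbb{E}_q[T(\theta)] = \nabla_\eta A(\eta), \qquad
F(\eta) = \nabla^2_\eta A(\eta) = \frac{\partial m}{\partial \eta}, \\
\nabla_\eta \mathcal{L} &= \Big(\frac{\partial m}{\partial \eta}\Big)^{\!\top} \nabla_m \mathcal{L} = F(\eta)\, \nabla_m \mathcal{L}
\;\;\Longrightarrow\;\;
F(\eta)^{-1} \nabla_\eta \mathcal{L} = \nabla_m \mathcal{L}.
\end{aligned}
```

For one Gaussian coordinate $q = \mathcal{N}(\mu, \sigma^2)$, we have $\eta = (\mu/\sigma^2,\, -1/(2\sigma^2))$ and $m = (\mu,\, \mu^2 + \sigma^2)$. A natural-gradient step $\eta \leftarrow \eta + \rho\, \nabla_m \mathcal{L}$ therefore needs no explicit Fisher matrix.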
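Maximum Likelihood Estimation (MLE)
RMSprop for MLE:
$$\max_{\theta} \sum_{i=1}^{N} \log p(\mathcal{D}_i \mid \theta) \qquad \text{(log-likelihood)}$$
$$
\begin{aligned}
\theta &\leftarrow \mu \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_\theta \log p(\mathcal{D}_i \mid \theta) && \text{(backprop on minibatches)} \\
s &\leftarrow (1-\beta)\, s + \beta\, g^2 && \text{(scale vector: gradient magnitude)} \\
\mu &\leftarrow \mu + \alpha\, \frac{g}{\sqrt{s} + \delta} && \text{(adaptive gradient update)}
\end{aligned}
$$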
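The RMSprop loop can be sketched in plain Python for a toy problem. None of these choices come from the talk:

- a single scalar weight;
- a 1-D linear-regression likelihood $\log p(y \mid x, \theta) = -\tfrac12 (y - x\theta)^2 + \text{const}$;
- synthetic data from the hypothetical helper `make_data`;
- illustrative hyperparameters.

For vector weights, every operation would apply elementwise.

```python
import math
import random

def make_data(n=200, theta_true=2.0, seed=0):
    """Hypothetical toy data: y = theta_true * x + N(0, 1) noise, x ~ Uniform(-1, 1)."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    ys = [theta_true * x + rng.gauss(0.0, 1.0) for x in xs]
    return xs, ys

def grad_log_lik(theta, x, y):
    """d/dtheta of log p(y | x, theta) = -0.5 * (y - x * theta)**2 + const (unit noise)."""
    return (y - x * theta) * x

def rmsprop_mle(xs, ys, alpha=0.005, beta=0.1, delta=1e-8, batch=10, epochs=100, seed=1):
    """RMSprop ascent on sum_i log p(D_i | theta) with minibatches of size `batch` (M)."""
    rng = random.Random(seed)
    idx = list(range(len(xs)))
    mu, s = 0.0, 0.0
    for _ in range(epochs):
        rng.shuffle(idx)
        for start in range(0, len(idx), batch):
            mb = idx[start:start + batch]
            theta = mu                                    # theta <- mu: no perturbation for MLE
            g = sum(grad_log_lik(theta, xs[i], ys[i]) for i in mb) / len(mb)
            s = (1 - beta) * s + beta * g * g             # scale: running average of g^2
            mu = mu + alpha * g / (math.sqrt(s) + delta)  # adaptive gradient step
    return mu

xs, ys = make_data()
mu_hat = rmsprop_mle(xs, ys)
print(mu_hat)
```

With a constant $\alpha$, the iterate keeps jittering around the maximum-likelihood estimate $\sum_i x_i y_i / \sum_i x_i^2$ rather than settling exactly on it.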
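Gaussian Mean-Field Variational Inference
$$p(\theta) = \mathcal{N}(0, I/\lambda) \qquad \text{(known prior precision } \lambda\text{)}$$
$$p(\theta \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid \theta)\, p(\theta)}{\int p(\mathcal{D} \mid \theta)\, p(\theta)\, d\theta} \approx q(\theta) = \mathcal{N}(\mu, \sigma^2) \qquad \text{(covariance matrix } = \operatorname{diag}(\sigma^2)\text{)}$$
$$\max_{\mu, \sigma^2} \mathcal{L}(\mu, \sigma^2) := \underbrace{\sum_{i=1}^{N} \mathbb{E}_q\big[\log p(\mathcal{D}_i \mid \theta)\big]}_{\text{data-fit term}} - \underbrace{\mathrm{KL}\big[q(\theta) \,\|\, p(\theta)\big]}_{\text{regularizer}}$$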
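To make the objective concrete, here is a sketch that estimates $\mathcal{L}(\mu, \sigma^2)$ for a scalar weight. It uses Monte Carlo for the data-fit term and the closed-form Gaussian KL for the regularizer. The unit-variance Gaussian likelihood, the five data points, and the helper names `log_lik`, `kl_gauss` and `elbo_mc` are illustrative assumptions.

```python
import math
import random

def log_lik(theta, x, y):
    """log N(y | x * theta, 1): unit-variance Gaussian likelihood (toy assumption)."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (y - x * theta) ** 2

def kl_gauss(mu, sigma2, lam):
    """Closed-form KL[N(mu, sigma2) || N(0, 1/lam)] for a scalar weight."""
    return 0.5 * (lam * sigma2 + lam * mu * mu - 1.0 - math.log(lam * sigma2))

def elbo_mc(mu, sigma2, xs, ys, lam=1.0, n_samples=1000, seed=0):
    """Monte Carlo estimate of sum_i E_q[log p(D_i|theta)] - KL[q || p]."""
    rng = random.Random(seed)
    sigma = math.sqrt(sigma2)
    data_fit = 0.0
    for _ in range(n_samples):
        theta = mu + sigma * rng.gauss(0.0, 1.0)   # theta ~ q(theta) = N(mu, sigma2)
        data_fit += sum(log_lik(theta, x, y) for x, y in zip(xs, ys))
    return data_fit / n_samples - kl_gauss(mu, sigma2, lam)

xs = [-1.0, -0.5, 0.0, 0.5, 1.0]
ys = [-2.1, -0.9, 0.2, 1.1, 1.9]
print(elbo_mc(2.0, 0.1, xs, ys))   # q concentrated near the data's slope
print(elbo_mc(0.0, 1.0, xs, ys))   # q equal to the prior: zero KL, poor data fit
```

For this likelihood the data-fit term also has a closed form, $\sum_i \big(-\tfrac12 \log 2\pi - \tfrac12[(y_i - x_i\mu)^2 + x_i^2\sigma^2]\big)$, which gives a handy check on the Monte Carlo estimate.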
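MLE vs Gradient-Based VI
RMSprop for max-likelihood: maximize $\sum_{i=1}^{N} \log p(\mathcal{D}_i \mid \theta)$ using $\theta \leftarrow \mu$; $g \leftarrow \frac{1}{M}\sum_i \nabla_\theta \log p(\mathcal{D}_i \mid \theta)$; $s \leftarrow (1-\beta)s + \beta g^2$; $\mu \leftarrow \mu + \alpha\, g/(\sqrt{s} + \delta)$.
Gradient-based variational inference (Graves 2011, Blundell et al. 2015):
$$\max_{\mu, \sigma^2} \mathcal{L}(\mu, \sigma^2) := \sum_{i=1}^{N} \mathbb{E}_q\big[\log p(\mathcal{D}_i \mid \theta)\big] - \mathrm{KL}\big[q(\theta) \,\|\, p(\theta)\big]$$
$$
\begin{aligned}
\mu &\leftarrow \mu + \alpha\, \frac{\nabla_\mu \mathcal{L}}{\sqrt{s_\mu} + \delta} \\
\sigma &\leftarrow \sigma + \alpha\, \frac{\nabla_\sigma \mathcal{L}}{\sqrt{s_\sigma} + \delta}
\end{aligned}
$$
Here $s_\mu$ and $s_\sigma$ are separate RMSprop scale vectors for $\nabla_\mu \mathcal{L}$ and $\nabla_\sigma \mathcal{L}$.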
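Here is a sketch of the gradient-based VI update, in the spirit of Bayes by Backprop but not the cited authors' code. The reparameterization $\theta = \mu + \sigma\epsilon$ gives unbiased gradients of $\mathcal{L}$ with respect to $\mu$ and $\sigma$, and each parameter gets its own RMSprop scale. The sketch assumes:

- a scalar weight;
- a unit-noise linear-regression likelihood;
- full-batch gradients;
- prior precision $\lambda = 1$;
- hypothetical helper names and hyperparameters;
- a small floor on $\sigma$ to keep it positive (the slide does not say how positivity is handled).

```python
import math
import random

def make_data(n=100, theta_true=2.0, seed=0):
    """Hypothetical toy data: y = theta_true * x + N(0, 1) noise, x ~ Uniform(-1, 1)."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return xs, [theta_true * x + rng.gauss(0.0, 1.0) for x in xs]

def bbb_rmsprop(xs, ys, lam=1.0, alpha=0.005, beta=0.1, delta=1e-8, iters=4000, seed=1):
    """Gradient-based mean-field VI: RMSprop on (mu, sigma) with reparameterized gradients."""
    rng = random.Random(seed)
    mu, sigma = 0.0, 1.0
    s_mu, s_sigma = 0.0, 0.0                       # one RMSprop scale per variational parameter
    for _ in range(iters):
        eps = rng.gauss(0.0, 1.0)
        theta = mu + sigma * eps                   # theta ~ N(mu, sigma^2) via reparameterization
        g = sum((y - x * theta) * x for x, y in zip(xs, ys))  # d/dtheta sum_i log p(D_i|theta)
        grad_mu = g - lam * mu                     # KL contributes -lam * mu
        grad_sigma = g * eps - (lam * sigma - 1.0 / sigma)    # KL contributes -(lam*sigma - 1/sigma)
        s_mu = (1 - beta) * s_mu + beta * grad_mu ** 2
        s_sigma = (1 - beta) * s_sigma + beta * grad_sigma ** 2
        mu += alpha * grad_mu / (math.sqrt(s_mu) + delta)
        sigma += alpha * grad_sigma / (math.sqrt(s_sigma) + delta)
        sigma = max(sigma, 1e-3)                   # assumption: floor keeps sigma positive
    return mu, sigma

xs, ys = make_data()
mu, sigma = bbb_rmsprop(xs, ys)
print(mu, sigma)
```

This toy model is conjugate, so the exact posterior is $\mathcal{N}(\sum_i x_i y_i / P,\ 1/P)$ with $P = \sum_i x_i^2 + \lambda$. The returned $(\mu, \sigma)$ can therefore be compared with $(\sum_i x_i y_i / P,\ 1/\sqrt{P})$.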
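MLE vs Natural-Gradient VI
RMSprop for max-likelihood: $\theta \leftarrow \mu$; $g \leftarrow \frac{1}{M}\sum_i \nabla_\theta \log p(\mathcal{D}_i \mid \theta)$; $s \leftarrow (1-\beta)s + \beta g^2$; $\mu \leftarrow \mu + \alpha\, g/(\sqrt{s} + \delta)$.
Natural-gradient VI (Khan, Lin 2017; Khan, Nielsen 2018): Variational Online-Newton (VON; Khan et al. 2017)
$$
\begin{aligned}
\theta &\leftarrow \mu + \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}\big(0,\, \operatorname{diag}(1/(Ns + \lambda))\big) \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_\theta \log p(\mathcal{D}_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s + \beta\, \frac{1}{M} \sum_{i} \big[-\nabla^2_{\theta\theta} \log p(\mathcal{D}_i \mid \theta)\big] \\
\mu &\leftarrow \mu + \alpha\, \frac{g - \lambda\mu/N}{s + \lambda/N}
\end{aligned}
$$
The updates are written for maximization. The prior contributes $-\lambda\mu/N$ per data point, and $s$ tracks the negative Hessian, so the precision $Ns + \lambda$ stays positive for log-concave likelihoods.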
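Here is a scalar sketch of the VON loop. It assumes a toy unit-noise linear-regression model, for which $-\partial^2_\theta \log p(\mathcal{D}_i \mid \theta) = x_i^2$ is available in closed form. The updates are written for maximizing the log-likelihood, so the prior enters as $-\lambda\mu/N$. Other choices in the sketch:

- each minibatch is drawn uniformly at random;
- the prior precision is $\lambda = 1$;
- `make_data`, `von` and all hyperparameters are illustrative names and values, not the authors' code.

```python
import math
import random

def make_data(n=100, theta_true=2.0, seed=0):
    """Hypothetical toy data: y = theta_true * x + N(0, 1) noise, x ~ Uniform(-1, 1)."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return xs, [theta_true * x + rng.gauss(0.0, 1.0) for x in xs]

def von(xs, ys, lam=1.0, alpha=0.02, beta=0.1, batch=10, iters=2000, seed=1):
    """Variational Online-Newton sketch for one weight; returns (mu, sigma) of q(theta)."""
    rng = random.Random(seed)
    n = len(xs)
    mu, s = 0.0, 1.0                                     # initial scale is an arbitrary choice
    for _ in range(iters):
        mb = rng.sample(range(n), batch)                 # minibatch of M = batch examples
        sigma = 1.0 / math.sqrt(n * s + lam)             # q(theta) = N(mu, 1 / (N s + lam))
        theta = mu + sigma * rng.gauss(0.0, 1.0)         # perturb the weight
        g = sum((ys[i] - xs[i] * theta) * xs[i] for i in mb) / batch
        neg_hess = sum(xs[i] ** 2 for i in mb) / batch   # -d2/dtheta2 log p(D_i|theta) = x_i^2
        s = (1 - beta) * s + beta * neg_hess
        mu += alpha * (g - lam * mu / n) / (s + lam / n) # Newton-like step on the mean
    return mu, 1.0 / math.sqrt(n * s + lam)

xs, ys = make_data()
mu, sigma = von(xs, ys)
print(mu, sigma)
```

In this conjugate toy model, $Ns + \lambda$ tracks the exact posterior precision $\sum_i x_i^2 + \lambda$. The returned $\sigma$ therefore lands close to the exact posterior standard deviation.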
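MLE vs Natural-Gradient VI
RMSprop for max-likelihood: $\theta \leftarrow \mu$; $g \leftarrow \frac{1}{M}\sum_i \nabla_\theta \log p(\mathcal{D}_i \mid \theta)$; $s \leftarrow (1-\beta)s + \beta g^2$; $\mu \leftarrow \mu + \alpha\, g/(\sqrt{s} + \delta)$.
Natural-gradient VI (Khan, Lin 2017; Khan, Nielsen 2018): Variational Online Gauss-Newton (VOGN)
$$
\begin{aligned}
\theta &\leftarrow \mu + \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}\big(0,\, \operatorname{diag}(1/(Ns + \lambda))\big) \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_\theta \log p(\mathcal{D}_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s + \beta\, \frac{1}{M} \sum_{i} \big[\nabla_\theta \log p(\mathcal{D}_i \mid \theta)\big]^2 \\
\mu &\leftarrow \mu + \alpha\, \frac{g - \lambda\mu/N}{s + \lambda/N}
\end{aligned}
$$
VOGN replaces the Hessian with squared per-example gradients. These are nonnegative and need only first-order backprop.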
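Here is a scalar sketch of VOGN on a toy unit-noise linear-regression model with prior precision $\lambda = 1$. Updates are written for maximization, so the prior enters as $-\lambda\mu/N$. Relative to a Newton-style loop, only the `s` update differs: it averages the squared per-example gradients instead of the negative Hessian. The helper names and hyperparameters are illustrative. Squared gradients are not the Hessian, so the returned $\sigma$ only approximates the exact posterior standard deviation here.

```python
import math
import random

def make_data(n=100, theta_true=2.0, seed=0):
    """Hypothetical toy data: y = theta_true * x + N(0, 1) noise, x ~ Uniform(-1, 1)."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return xs, [theta_true * x + rng.gauss(0.0, 1.0) for x in xs]

def vogn(xs, ys, lam=1.0, alpha=0.02, beta=0.1, batch=10, iters=2000, seed=1):
    """Variational Online Gauss-Newton sketch for one weight; returns (mu, sigma) of q(theta)."""
    rng = random.Random(seed)
    n = len(xs)
    mu, s = 0.0, 1.0
    for _ in range(iters):
        mb = rng.sample(range(n), batch)
        theta = mu + rng.gauss(0.0, 1.0) / math.sqrt(n * s + lam)    # theta ~ N(mu, 1/(N s + lam))
        grads = [(ys[i] - xs[i] * theta) * xs[i] for i in mb]         # per-example gradients
        g = sum(grads) / batch
        s = (1 - beta) * s + beta * sum(gi * gi for gi in grads) / batch  # mean of squared gradients
        mu += alpha * (g - lam * mu / n) / (s + lam / n)
    return mu, 1.0 / math.sqrt(n * s + lam)

xs, ys = make_data()
mu, sigma = vogn(xs, ys)
print(mu, sigma)
```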
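MLE vs Natural-Gradient VI
RMSprop for max-likelihood: $\theta \leftarrow \mu$; $g \leftarrow \frac{1}{M}\sum_i \nabla_\theta \log p(\mathcal{D}_i \mid \theta)$; $s \leftarrow (1-\beta)s + \beta g^2$; $\mu \leftarrow \mu + \alpha\, g/(\sqrt{s} + \delta)$.
Natural-gradient VI (Khan, Lin 2017; Khan, Nielsen 2018): Variational RMSprop (Vprop)
$$
\begin{aligned}
\theta &\leftarrow \mu + \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}\big(0,\, \operatorname{diag}(1/(Ns + \lambda))\big) \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_\theta \log p(\mathcal{D}_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s + \beta\, g^2 \\
\mu &\leftarrow \mu + \alpha\, \frac{g - \lambda\mu/N}{\sqrt{s} + \lambda/N}
\end{aligned}
$$
Compared with RMSprop, Vprop only perturbs $\theta$, adds the prior term $-\lambda\mu/N$, and uses $\lambda/N$ in place of $\delta$.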
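Here is a scalar sketch of Vprop on a toy unit-noise linear-regression model with prior precision $\lambda = 1$. The hyperparameters are illustrative and the helper names are hypothetical. Updates are written for maximization, so the prior enters as $-\lambda\mu/N$.

Running the sketch with two minibatch sizes relates to the summary's advice. Squaring an averaged minibatch gradient makes $s$ shrink as $M$ grows, which inflates $\sigma$.

```python
import math
import random

def make_data(n=100, theta_true=2.0, seed=0):
    """Hypothetical toy data: y = theta_true * x + N(0, 1) noise, x ~ Uniform(-1, 1)."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return xs, [theta_true * x + rng.gauss(0.0, 1.0) for x in xs]

def vprop(xs, ys, batch, lam=1.0, alpha=0.005, beta=0.01, iters=5000, seed=1):
    """Variational RMSprop sketch for one weight; returns (mu, sigma) of q(theta)."""
    rng = random.Random(seed)
    n = len(xs)
    mu, s = 0.0, 1.0
    for _ in range(iters):
        mb = rng.sample(range(n), batch)
        theta = mu + rng.gauss(0.0, 1.0) / math.sqrt(n * s + lam)   # perturb: theta ~ N(mu, 1/(N s + lam))
        g = sum((ys[i] - xs[i] * theta) * xs[i] for i in mb) / batch
        s = (1 - beta) * s + beta * g * g                            # same scale update as RMSprop
        mu += alpha * (g - lam * mu / n) / (math.sqrt(s) + lam / n)  # lam / N in place of delta
    return mu, 1.0 / math.sqrt(n * s + lam)

xs, ys = make_data()
mu_small, sigma_small = vprop(xs, ys, batch=1)
mu_big, sigma_big = vprop(xs, ys, batch=50)
print(sigma_small, sigma_big)
```

On this toy problem, `sigma_big` should come out clearly larger than `sigma_small`, even though both runs see the same data.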
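Variational Adam (Vadam)
Adam for max-likelihood:
$$
\begin{aligned}
\theta &\leftarrow \mu \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_\theta \log p(\mathcal{D}_i \mid \theta) \\
s &\leftarrow (1-\beta_2)\, s + \beta_2\, g^2 \\
m &\leftarrow (1-\beta_1)\, m + \beta_1\, g \\
\hat{m} &\leftarrow m / \big(1 - (1-\beta_1)^t\big) \\
\hat{s} &\leftarrow s / \big(1 - (1-\beta_2)^t\big) \\
\mu &\leftarrow \mu + \alpha\, \frac{\hat{m}}{\sqrt{\hat{s}} + \delta}
\end{aligned}
$$
Vadam for VI:
$$
\begin{aligned}
\theta &\leftarrow \mu + \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}\big(0,\, \operatorname{diag}(1/(Ns + \lambda))\big) \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_\theta \log p(\mathcal{D}_i \mid \theta) \\
s &\leftarrow (1-\beta_2)\, s + \beta_2\, g^2 \\
m &\leftarrow (1-\beta_1)\, m + \beta_1\, (g - \lambda\mu/N) \\
\hat{m} &\leftarrow m / \big(1 - (1-\beta_1)^t\big) \\
\hat{s} &\leftarrow s / \big(1 - (1-\beta_2)^t\big) \\
\mu &\leftarrow \mu + \alpha\, \frac{\hat{m}}{\sqrt{\hat{s}} + \lambda/N}
\end{aligned}
$$
Here $t$ is the iteration counter.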
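Here is a scalar sketch of Vadam on a toy unit-noise linear-regression model with prior precision $\lambda = 1$. It uses minibatch size 1, following the summary's advice, and illustrative hyperparameters. In this sketch's $(1-\beta)$ convention, Adam's usual defaults would correspond to $\beta_1 = 0.1$ and $\beta_2 = 0.001$. A larger $\beta_2 = 0.01$ is used here only to speed up the toy run.

Apart from the scale, it is ordinary Adam with three changes:

- $\theta$ is sampled from $q$;
- the momentum includes $-\lambda\mu/N$;
- $\lambda/N$ replaces $\delta$.

Sampling with the running (uncorrected) $s$ is an assumption about how to read the slide. This is not the authors' implementation, which is in the repository linked at the end of the talk.

```python
import math
import random

def make_data(n=100, theta_true=2.0, seed=0):
    """Hypothetical toy data: y = theta_true * x + N(0, 1) noise, x ~ Uniform(-1, 1)."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return xs, [theta_true * x + rng.gauss(0.0, 1.0) for x in xs]

def vadam(xs, ys, batch=1, lam=1.0, alpha=0.005, beta1=0.1, beta2=0.01, iters=5000, seed=1):
    """Vadam sketch for one weight; returns (mu, sigma) of q(theta)."""
    rng = random.Random(seed)
    n = len(xs)
    mu, m, s = 0.0, 0.0, 0.0
    for t in range(1, iters + 1):
        mb = rng.sample(range(n), batch)
        theta = mu + rng.gauss(0.0, 1.0) / math.sqrt(n * s + lam)   # perturb weights before backprop
        g = sum((ys[i] - xs[i] * theta) * xs[i] for i in mb) / batch
        s = (1 - beta2) * s + beta2 * g * g                          # likelihood gradient only
        m = (1 - beta1) * m + beta1 * (g - lam * mu / n)             # momentum includes the prior term
        m_hat = m / (1 - (1 - beta1) ** t)                           # Adam-style bias corrections
        s_hat = s / (1 - (1 - beta2) ** t)
        mu += alpha * m_hat / (math.sqrt(s_hat) + lam / n)           # lam / N replaces delta
    return mu, 1.0 / math.sqrt(n * s + lam)

xs, ys = make_data()
mu, sigma = vadam(xs, ys)
print(mu, sigma)
```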
Summary: Uncertainty using Adam
Perturb the weights before backprop. Choose a small minibatch size.
Bayesian logistic regression on "Breast-Cancer" (N=683, D=8)
As we reduce the minibatch size, Vadam gives performance similar to VOGN.
[Figure: error in posterior approximation, Vadam vs. VOGN]
[Figure: log2 loss vs. epoch (0 to 2) for BBVI, Vadam, and CVI-GGN]
1 layer, 64 hidden units with ReLU on Breast Cancer [N=683, D=10]
Methods: BBVI, Vadam, VOGN
VOGN shows fast convergence.
Poster tonight (Hall B #190). Code available at https://github.com/emtiyaz/vadam/
Also check out "Noisy Natural Gradient as Variational Inference" by Zhang et al. at this conference.