Composing graphical models and neural networks for structured representations and fast inference
modeling idea
use PGM priors to organize the latent space, along with neural net observation models for flexible representations
inference idea
use PGMs to synthesize information from recognition nets instead of making a single inference net do everything
TL;DR variational autoencoders + latent graphical models
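The modeling idea can be illustrated with a minimal generative sketch. It assumes, purely for illustration, a Gaussian mixture as the PGM prior and a tiny one-hidden-layer network as the observation model; the function names, the scalar latent, and the parameter layout are all hypothetical and are not taken from the poster.

```python
import math
import random

def sample_latent(rng, weights, means, stds):
    """Draw (z, x) from a Gaussian mixture prior: the structured PGM part."""
    z = rng.choices(range(len(weights)), weights=weights)[0]
    x = rng.gauss(means[z], stds[z])
    return z, x

def decoder(x, params):
    """Hypothetical one-hidden-layer network mapping latent x to the mean of y."""
    w1, b1, w2, b2 = params
    hidden = [math.tanh(w * x + b) for w, b in zip(w1, b1)]
    return sum(w * h for w, h in zip(w2, hidden)) + b2

def sample_observation(rng, x, params, noise_std):
    """Draw y from a Gaussian centered on the decoder output: the flexible part."""
    return rng.gauss(decoder(x, params), noise_std)

rng = random.Random(0)
params = ([1.0, -1.0], [0.0, 0.0], [1.0, 1.0], 0.5)  # illustrative weights
z, x = sample_latent(rng, [0.5, 0.5], [-2.0, 2.0], [0.1, 0.1])
y = sample_observation(rng, x, params, 0.05)
```

The discrete variable z keeps the latent space organized into clusters, while the network is free to warp each cluster into whatever shape the observations need.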
probabilistic graphical models
+ structured representations
+ priors and uncertainty
+ data and computational efficiency within rigid model classes
– rigid assumptions may not fit
– feature engineering
– more flexible models can require slow top-down inference
deep neural networks
– neural net “goo”
– difficult parameterization
– can require lots of data
+ flexible, high capacity
+ feature learning
+ recognition networks for fast bottom-up inference
automatically learn representations in which structured PGMs fit well
supervised learning
unsupervised learning
motivation
natural gradient SVI
\[ q^*(x) \triangleq \arg\max_{q(x)} \mathcal{L}[\, q(\theta)\, q(x) \,] \]
+ optimal local factor
– expensive for general obs.
+ exploit conj. graph structure
+ arbitrary inference queries
+ natural gradients
variational autoencoders
\[ q^*(x) \triangleq \mathcal{N}(x \mid \mu(y;\phi), \Sigma(y;\phi)) \]
– suboptimal local factor
+ fast for general obs.
– does all local inference
– limited inference queries
– no natural gradients
structured VAEs (this work)
\[ q^*(x) \triangleq \ ? \]
± optimal given conj. evidence
+ fast for general obs.
+ exploit conj. graph structure
+ arbitrary inference queries
+ some natural gradients
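The "natural gradients" entries refer to updates on the natural parameters of a conjugate exponential-family factor. A minimal sketch of the standard conjugate form (prior parameters plus rescaled expected sufficient statistics, minus the current parameters) follows. The function names and the toy Gaussian-mean example are assumptions for illustration, and the sketch omits any additional terms the structured VAE objective may contribute.

```python
def natural_gradient(eta, eta_prior, minibatch_stats, num_data):
    """Natural gradient of a conjugate ELBO w.r.t. global natural parameters eta.

    minibatch_stats holds one tuple of expected sufficient statistics per
    data point in the minibatch; num_data is the size of the full data set.
    """
    scale = num_data / len(minibatch_stats)
    totals = [sum(column) for column in zip(*minibatch_stats)]
    return [p + scale * s - e for p, s, e in zip(eta_prior, totals, eta)]

def svi_step(eta, eta_prior, minibatch_stats, num_data, step_size):
    """Take one natural-gradient step on eta."""
    grad = natural_gradient(eta, eta_prior, minibatch_stats, num_data)
    return [e + step_size * g for e, g in zip(eta, grad)]

# Toy example: unknown mean with prior N(0, 1) and unit-variance observations.
# Natural parameters are taken w.r.t. the statistics (theta, theta**2).
eta_prior = [0.0, -0.5]
data = [1.0, 2.0, 3.0]
stats = [(x, -0.5) for x in data]
eta = svi_step(list(eta_prior), eta_prior, stats, num_data=len(data), step_size=1.0)
posterior_var = -0.5 / eta[1]
posterior_mean = eta[0] * posterior_var
```

With the full data set as the minibatch and a step size of 1, a single step lands exactly on the conjugate posterior, which is why these updates are cheap once the expected statistics are available.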
[Figure: graphical model with latent chains z1, ..., z7 and x1, ..., x7, observations y1, ..., y7, and global parameters including θ; example behavior labels: rear, dart]
inference
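\[ \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x) \triangleq \mathbb{E}_{q(\theta) q(\gamma) q(x)} \left[ \log \frac{p(\theta, \gamma, x)\, p(y \mid x, \gamma)}{q(\theta)\, q(\gamma)\, q(x)} \right] \]
\[ \widehat{\mathcal{L}}(\eta_\theta, \eta_x, \phi) \triangleq \mathbb{E}_{q(\theta) q(\gamma) q(x)} \left[ \log \frac{p(\theta, \gamma, x)\, \exp\{\psi(x; y, \phi)\}}{q(\theta)\, q(\gamma)\, q(x)} \right] \]
where $\psi(x; y, \phi)$ is a conjugate potential for $p(x \mid \theta)$
\[ \eta_x^*(\eta_\theta, \phi) \triangleq \arg\max_{\eta_x} \widehat{\mathcal{L}}(\eta_\theta, \eta_x, \phi) \]
\[ \mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi) \triangleq \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x^*(\eta_\theta, \phi)) \]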
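Because the recognition potential is conjugate to the prior on x, the optimal local factor has a closed form. The sketch below shows the simplest case under loud assumptions: a single scalar latent, a fixed Gaussian prior (θ held at a point value rather than averaged under q(θ)), and a hypothetical recognition function that emits the natural parameters of a Gaussian potential. In that case the maximization reduces to adding natural parameters.

```python
import math

def gaussian_to_natural(mean, var):
    """Natural parameters of N(mean, var) w.r.t. the statistics (x, x**2)."""
    return mean / var, -0.5 / var

def natural_to_gaussian(eta1, eta2):
    """Inverse of gaussian_to_natural; requires eta2 < 0."""
    var = -0.5 / eta2
    return eta1 * var, var

def recognition_potential(y, phi):
    """Hypothetical recognition net: maps y to natural parameters of psi."""
    w, b, log_prec = phi
    prec = math.exp(log_prec)       # exponentiate to keep the precision positive
    mean = math.tanh(w * y + b)
    return prec * mean, -0.5 * prec

def optimal_local_factor(prior_mean, prior_var, y, phi):
    """Combine prior and recognition potential by summing natural parameters."""
    p1, p2 = gaussian_to_natural(prior_mean, prior_var)
    r1, r2 = recognition_potential(y, phi)
    return natural_to_gaussian(p1 + r1, p2 + r2)

mean, var = optimal_local_factor(0.0, 1.0, y=0.5, phi=(1.0, 0.0, 0.0))
```

The recognition network only has to summarize the evidence from y; the prior contributes its own structure through the sum, so the network does not need to relearn it.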
learning to parse mouse behavior from depth video
step 1: compute local evidence with recognition networks
1(a) sample minibatch
1(b) apply recognition networks…
1(c) …to get PGM potentials
step 2: run fast PGM inference
2(a) run local mean field
2(b) and message passing
2(c) to fixed point
step 3: compute unbiased ELBO gradients with respect to all parameters
3(a) sample and compute flat gradients w.r.t. $\gamma$ and $\phi$
3(b) use already-computed values to get natural gradient w.r.t. $\eta_\theta$
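The message-passing part of step 2 can be sketched for a much simpler model than the switching LDS above: a scalar linear dynamical system with no discrete states, so no local mean field iteration is needed. The function name, the scalar setting, and the (h, J) encoding of each potential as exp(h x - J x^2 / 2) are assumptions made for illustration. Only the forward (filtering) pass is shown.

```python
def filter_with_potentials(potentials, a, q, m0, v0):
    """Forward messages for x_{t+1} = a * x_t + noise, noise ~ N(0, q).

    Each potential is a pair (h, J) standing for exp(h * x - 0.5 * J * x**2),
    for example one per frame as produced by a recognition network in step 1.
    Returns the filtered (mean, variance) at every time step.
    """
    mean_pred, var_pred = m0, v0
    filtered = []
    for h, J in potentials:
        # Condition: add the natural parameters of prediction and potential.
        prec = 1.0 / var_pred + J
        lin = mean_pred / var_pred + h
        var = 1.0 / prec
        mean = lin * var
        filtered.append((mean, var))
        # Predict: push the filtered belief through the linear dynamics.
        mean_pred = a * mean
        var_pred = a * a * var + q
    return filtered

# Three identical potentials on a static state (a = 1, q = 0), prior N(0, 1).
beliefs = filter_with_potentials([(1.0, 1.0)] * 3, a=1.0, q=0.0, m0=0.0, v0=1.0)
```

Since every potential is Gaussian in x, each update is a closed-form operation on two numbers, which is what makes inference fast once the recognition networks have produced the potentials.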
Matthew James Johnson, David Duvenaud, Alexander B. Wiltschko, Sandeep R. Datta, Ryan P. Adams
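Proposition (log evidence lower bound)
For all $\eta_\theta$, $\eta_\gamma$, and $\phi$,
\[ \mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi) \le \max_{\eta_x} \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x) \le \log p(y), \]
and
\[ \max_{\phi} \mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi) = \max_{\eta_x} \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x) \]
if $\exists\, \phi \in \mathbb{R}^m$ with $\psi(x; y, \phi) = \mathbb{E}_{q(\gamma)} \log p(y \mid x, \gamma)$.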
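A short justification of the two inequalities in the proposition, using only the definitions of the two objectives from the inference section. The first line holds because the SVAE objective evaluates the mean field objective at one particular choice of local parameters. The second is the standard decomposition of an evidence lower bound into the log evidence minus a KL divergence.

```latex
\begin{align*}
\mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi)
  &= \mathcal{L}\big(\eta_\theta, \eta_\gamma, \eta_x^*(\eta_\theta, \phi)\big)
   \le \max_{\eta_x} \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x), \\
\mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x)
  &= \log p(y) - \mathrm{KL}\big( q(\theta)\, q(\gamma)\, q(x) \,\big\|\, p(\theta, \gamma, x \mid y) \big)
   \le \log p(y).
\end{align*}
```

The second line holds for every value of the local parameters, so it also holds at the maximizing one, which gives the full chain.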
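[Figure: local graphical model over x_n and y_n with parameters θ and γ, shown once with the likelihood term $\mathbb{E}_{q(\gamma)} \log p(y_n \mid x_n, \gamma)$ and once with the recognition potential $\psi(x_n; y_n, \phi)$ attached to x_n]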
[Figure: sequence labeled /b/ /ax/ /n/ /ae/ /n/ /ax/]
[Figure: image panels with axes in mm; panel titles: dart, pause, rear]
main idea
learn to summarize complicated evidence with simple conjugate potentials (as in CRFs)
SVAE approach
fit a latent switching linear dynamical system (SLDS) and a neural network image model for observations
[Figure legend: flat gradient, natural gradient]
github.com/mattjj/svae