
Page 1

Variational Inference for Diffusion Processes

Cedric Archambeau

Xerox Research Centre Europe
[email protected]

Joint work with Manfred Opper.

Statlearn ’11, Grenoble, March 2011

Page 2

Stochastic differential systems

Many real dynamical systems are continuous in time:

Data assimilation (e.g. numerical weather prediction)

Systems biology (e.g. cellular stress response, transcription factors)

fMRI brain image data (e.g. voxel based activity)

Modelled by stochastic differential equations (SDEs):

dx(t) = f(x(t), t) dt + D^{1/2}(x(t), t) dw(t),

where dw(t) is a Wiener process (Brownian motion):

dw(t) = lim_{Δt→0} ε_t √Δt,   ε_t ∼ N(0, I).

Deterministic drift f and stochastic diffusion component D

Continuous-time limit of discrete-time state-space model
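In practice such SDEs are integrated on a time grid. Below is a minimal Euler–Maruyama sketch (numpy only); the drift and diffusion in the example, an Ornstein–Uhlenbeck process, are illustrative choices, not tied to any system on these slides.

```python
import numpy as np

def euler_maruyama(f, D_sqrt, x0, T, dt, rng):
    """Simulate dx = f(x, t) dt + D^{1/2}(x, t) dw on a regular time grid."""
    n_steps = int(round(T / dt))
    d = len(x0)
    x = np.empty((n_steps + 1, d))
    x[0] = x0
    for k in range(n_steps):
        t = k * dt
        dw = rng.normal(0.0, np.sqrt(dt), size=d)   # Wiener increment ~ N(0, dt I)
        x[k + 1] = x[k] + f(x[k], t) * dt + D_sqrt(x[k], t) @ dw
    return x

# Example: a 1-D Ornstein-Uhlenbeck process, f(x, t) = -x, D^{1/2} = 0.5.
rng = np.random.default_rng(0)
path = euler_maruyama(f=lambda x, t: -x,
                      D_sqrt=lambda x, t: 0.5 * np.eye(1),
                      x0=np.array([1.0]), T=10.0, dt=0.01, rng=rng)
```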

Page 3

Stochastic differential systems

Why should we bother?

A lot of theory, few (effective) data-driven approaches

Time discretisation is unavoidable in practice

Physics models enforce continuity constraints, such that the number of observations can be relatively small

High frequency fluctuations can be incorporated into the diffusion

Any discrete representation can be chosen a posteriori

Easy to handle irregular sampling/missing data

Bayesian approaches are natural:

The SDE induces a non-Gaussian prior over sample paths

Define a noise model (or likelihood) and simulate the posterior process over trajectories via MCMC (Beskos et al., 2009)

Or develop fast deterministic approximations

Page 4

Overview

Setting, notations and variational inference

Partially observed diffusion processes

Gaussian variational approximation

Experiments and conclusion

Page 5

Bayesian inference (framework and notations)

Predictions are made by averaging over all possible models:

p(y∗|y) = ∫ p(y∗|x) p(x|y) dx.

The latent variables are inferred using Bayes’ rule:

p(x|y) = p(y|x) p(x) / p(y),   p(y) = ∫ p(y, x) dx,

where p(x|y) is the posterior, p(y|x) the likelihood, p(x) the prior, and p(y) the marginal likelihood.

Type II maximum likelihood estimation of the (hyper)parameters θ:

θ_{ML2} = argmax_θ ln p(y|θ).

The marginals are in general analytically intractable:

1. We can use Markov chain Monte Carlo to simulate the integrals; potentially exact, but often slow.
2. Or we can focus on fast(er) approximate inference schemes, such as variational inference.

Page 6

Approximate Bayesian inference (variational inference)

For any distribution q(x) ≈ p(x|y), we optimise a lower bound to the log-marginal likelihood:

ln p(y|θ) = ln ∫ p(y, x|θ) dx ≥ ∫ q(x) ln [p(y, x|θ) / q(x)] dx =: −F(q, θ).

(Variational) EM minimises the variational free energy iteratively and monotonically (Beal, 2003):

F(q, θ) = − ln p(y|θ) + KL[q(x)‖p(x|y, θ)],

F(q, θ) = −⟨ln p(y, x|θ)⟩_{q(x)} − H[q(x)],

where KL[q‖p] = E_q{ln(q/p)} is the Kullback-Leibler divergence and H[q] = −E_q{ln q} the entropy.

An alternative approach is to minimise F(q, θ) directly with your favourite optimisation algorithm:

F(q, θ) = −⟨ln p(y|x, θ)⟩_{q(x)} + KL[q(x)‖p(x|θ)].
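The two decompositions of F above can be checked numerically on a toy conjugate model. A sketch, assuming a scalar Gaussian prior, Gaussian likelihood and Gaussian q (all numbers illustrative):

```python
import numpy as np

def kl_gauss(m1, s1, m2, s2):
    """KL[N(m1, s1) || N(m2, s2)] for scalar Gaussians (variances s1, s2)."""
    return 0.5 * (np.log(s2 / s1) + (s1 + (m1 - m2) ** 2) / s2 - 1.0)

# Illustrative conjugate model: x ~ N(0, s0), y | x ~ N(x, r).
s0, r, y = 2.0, 0.5, 1.3
s_post = 1.0 / (1.0 / s0 + 1.0 / r)           # exact posterior variance
m_post = s_post * y / r                       # exact posterior mean
log_marg = -0.5 * (np.log(2 * np.pi * (s0 + r)) + y ** 2 / (s0 + r))

m_q, s_q = 0.7, 0.4                           # an arbitrary Gaussian q(x)
exp_loglik = -0.5 * (np.log(2 * np.pi * r) + ((y - m_q) ** 2 + s_q) / r)

F1 = -exp_loglik + kl_gauss(m_q, s_q, 0.0, s0)       # -<ln p(y|x)> + KL[q||prior]
F2 = -log_marg + kl_gauss(m_q, s_q, m_post, s_post)  # -ln p(y) + KL[q||posterior]
assert np.isclose(F1, F2)                            # the two decompositions agree
```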

Page 7

Variational inference (continued)


Monotonic decrease of F; convergence is easy to monitor (unlike MCMC)

Deterministic, but different from Laplace approximation

Usually q is assumed to have a factorised form (q(x) ≈ p(x|y))

KL is wrt q; underestimation of correlations between latent variables

Example: variational treatment of Student-t mixtures

Page 8

Partially observed diffusion process

[Figure: a sample path W(t) over t ∈ [0, 1].]

Model data by a latent diffusion process:

dx(t) = f(x(t), t) dt + D^{1/2}(x(t), t) dw(t),

where f and D have a known functional form.

Discrete-time likelihood (observation operator):

y_n = H x(t_n) + η_n.

Goal: infer the states x(t) and learn the parameters of f and D given the data.
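Continuing the Euler–Maruyama sketch from earlier, noisy discrete-time observations of the latent path can be generated as follows; H, R and the observation times are hypothetical choices for illustration.

```python
import numpy as np

# Assumes `path` and `rng` from the Euler-Maruyama sketch above (dt = 0.01, T = 10).
dt, T = 0.01, 10.0
H = np.array([[1.0]])                  # observation operator (hypothetical)
R = 0.04                               # observation noise variance (hypothetical)
t_obs = np.arange(0.5, T, 0.5)         # sparse sampling; irregular grids work too
idx = np.round(t_obs / dt).astype(int)
y = path[idx] @ H.T + rng.normal(0.0, np.sqrt(R), size=(len(idx), 1))
```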

Page 9

Variational inference for diffusion processes

We are interested in the posterior measure over the sample paths:

dP(x(t) | y_1, …, y_N) / dP(x(t)) = (1/Z) ∏_n P(y_n | x(t_n)).

This quantity is non-Gaussian when f is nonlinear (and in general intractable).

For an approximate measure Q(·), we minimise the variational free energy over a certain time interval:

F(Q, θ) = −⟨ln P(y_1, …, y_N | x(t), θ)⟩_{Q(x(t))} + KL[Q(x(t))‖P(x(t))],

where t ∈ [0, T].

What is a suitable Q(·)?

Page 10

Gaussian variational approximation

We restrict ourselves to a state-independent diffusion matrix D.

Consider the following linear, but time-dependent SDE:

dx(t) = g(x(t), t) dt + D^{1/2}(t) dw(t),

where g(x(t), t) := −A(t) x(t) + b(t).

It induces a non-stationary Gaussian measure, with marginal mean and marginal covariance satisfying a set of ODEs:

ṁ(t) = −A(t) m(t) + b(t),
Ṡ(t) = −A(t) S(t) − S(t) Aᵀ(t) + D(t).

We view A(t) and b(t) as variational parameters and approximate the posterior process by this non-stationary Gaussian process.
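These moment ODEs can be integrated with any standard scheme; a forward-Euler sketch, with constant A, b and D chosen purely for illustration:

```python
import numpy as np

def forward_moments(A, b, D, m0, S0, T, dt):
    """Forward-Euler integration of dm/dt = -A(t) m + b(t) and
       dS/dt = -A(t) S - S A(t)^T + D."""
    n_steps = int(round(T / dt))
    d = len(m0)
    m = np.empty((n_steps + 1, d))
    S = np.empty((n_steps + 1, d, d))
    m[0], S[0] = m0, S0
    for k in range(n_steps):
        t = k * dt
        At, bt = A(t), b(t)
        m[k + 1] = m[k] + (-At @ m[k] + bt) * dt
        S[k + 1] = S[k] + (-At @ S[k] - S[k] @ At.T + D) * dt
    return m, S

# Constant variational parameters (illustrative values): the moments converge to
# m = A^{-1} b and to the S solving the Lyapunov equation A S + S A^T = D.
m, S = forward_moments(A=lambda t: np.array([[2.0]]),
                       b=lambda t: np.array([0.5]),
                       D=np.array([[0.25]]),
                       m0=np.array([1.0]), S0=np.array([[0.1]]),
                       T=5.0, dt=0.01)
```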

Page 11

Gaussian process

Multivariate Gaussian:

Probability density over D random variables (based on correlations).

Characterized by a mean vector µ and covariance matrix Σ:

f ≡ (f_1, …, f_D)ᵀ ∼ N(µ, Σ).

Gaussian process (GP):

Probability measure over random functions (≈ infinitely long vector).

Marginal over any finite subset of variables is a consistent finite-dimensional Gaussian!

Characterized by a mean function and a covariance function (kernel):

f (·) ∼ GP(m(·), k(·, ·)).

Gaussian processes for ML (Rasmussen and Williams, 2006)

A and b specify the kernel (in general no closed-form solution)
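For background, a minimal sketch drawing sample functions from a GP on a finite grid. The squared-exponential kernel here is an assumed, standard choice; as noted above, the kernel induced by A and b generally has no closed form.

```python
import numpy as np

def rbf_kernel(t1, t2, length_scale=0.5, variance=1.0):
    """Squared-exponential covariance k(t, t'); an assumed, standard choice."""
    return variance * np.exp(-0.5 * ((t1[:, None] - t2[None, :]) / length_scale) ** 2)

t = np.linspace(0.0, 5.0, 200)
K = rbf_kernel(t, t) + 1e-8 * np.eye(len(t))     # jitter for numerical stability
rng = np.random.default_rng(1)
samples = rng.multivariate_normal(np.zeros(len(t)), K, size=3)  # 3 sample functions
```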

Page 12

Consistency constraints and smoothing algorithm

The objective function is of the form

F(Q, θ) = ∫ E_obs(t) dt + ∫ E_sde(t) dt + KL[q(x_0)‖p(x_0)],

where

E_sde(t) = (1/2) ⟨(f_t − g_t)ᵀ D^{−1} (f_t − g_t)⟩_{Q(x_t)}.

The diffusion matrix of the linear SDE is by construction equal to the diffusion matrix of the original SDE (so that F is finite). We enforce consistent Gaussian marginals by using the following ODEs as constraints (forward propagation):

ṁ(t) = −A(t) m(t) + b(t),
Ṡ(t) = −A(t) S(t) − S(t) Aᵀ(t) + D(t).

Differentiating the Lagrangian leads to a set of ODEs for the Lagrange multipliers (backward propagation):

λ̇(t) = −∇_m E_sde(t) + Aᵀ(t) λ(t),   λ_n⁺ = λ_n⁻ − ∇_m E_obs(t)|_{t=t_n},
Ψ̇(t) = −∇_S E_sde(t) + 2 Ψ(t) A(t),   Ψ_n⁺ = Ψ_n⁻ − ∇_S E_obs(t)|_{t=t_n}.

Page 13

Optimal Gaussian variational approximation

The non-linear SDE is reduced to a set of linear ODEs describing the evolution of the means, covariances and Lagrange multipliers.

The smoothing algorithm consists of a forward and a backward integration for fixed A(t) and b(t).

The observations are incorporated in the backward pass (cf. jump conditions).

The optimal Gaussian variational approximation is obtained by optimising F wrt the variational parameters A(t) and b(t).

At equilibrium, the variational parameters satisfy the following conditions:

A = −⟨∂f/∂x⟩ + 2DΨ,
b = ⟨f(x)⟩ + A m − D λ.

The variational solution is closely related to statistical linearisation:

{A, b} ← argmin_{A, b} ⟨‖f(x) + A x − b‖²⟩.
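This statistical linearisation problem has a closed-form solution in terms of Gaussian moments of f; a simple Monte Carlo sketch (the double-well drift and the moments m, S are illustrative choices):

```python
import numpy as np

def statistical_linearisation(f, m, S, n_samples=100_000, rng=None):
    """Monte Carlo statistical linearisation of f under q = N(m, S):
       minimise <||f(x) + A x - b||^2>_q, i.e. approximate f(x) by -A x + b."""
    rng = rng or np.random.default_rng(0)
    x = rng.multivariate_normal(m, S, size=n_samples)
    fx = np.apply_along_axis(f, 1, x)
    cov_fx = (fx - fx.mean(0)).T @ (x - m) / n_samples   # Cov(f(x), x)
    A = -cov_fx @ np.linalg.inv(S)                       # from the normal equations
    b = fx.mean(0) + A @ m                               # stationarity wrt b
    return A, b

# Linearise the double-well drift f(x) = 4 x (theta - x^2) around N(0.8, 0.05):
theta = 1.0
A, b = statistical_linearisation(lambda x: 4 * x * (theta - x ** 2),
                                 m=np.array([0.8]), S=np.array([[0.05]]))
```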

Page 14

Illustration of the statistical linearisation principle

[Figure: statistical linearisation of a nonlinear function. Legend: density p(y); nonlinear f(y); Taylor linearisation f(µ) + ∇f (y − µ); statistical linearisation Ay + b.]

Page 15

Related approaches

Continuous-time sigma-point Kalman smoothers (KS; Sarkka and Sottinen, 2008):

Unscented KS and central difference KS.

Gaussian approximation of the transition density.

No feedback loop to adjust the sigma points.

Perfect simulation approaches (Beskos et al., 2009):

No discrete time approximation of the transition density.

Transition density is non-Gaussian.

Drift is restricted to derive from a potential.

Convergence is difficult to monitor, potentially slower.

Other approaches include particle smoothers, hybrid MCMC (Eyink et al., 2004), etc.

Page 16

Diffusions with multiplicative noise

Apply an explicit transformation to obtain a diffusion process with constant diffusion matrix; such a transformation does not always exist in the multivariate case.

Construct a Gaussian variational approximation based on the following ODEs, which hold for any non-linear SDE:

ṁ(t) = −A(t) m(t) + b(t),
Ṡ(t) = −A(t) S(t) − S(t) Aᵀ(t) + ⟨D(x(t), t)⟩_{Q(x_t)}.

The smoothing algorithm is analogous; the expressions for A(t) and b(t) are more involved.

Page 17

Bi-stable dynamical system

The deterministic drift is defined as

f(t, x) = 4x(θ − x²),   θ > 0.

The system is driven by the stochastic noise.

[Figures: the double-well potential u(x); a sample path x(t); panels comparing the initialisation and the variational smoother (state vs. time), along with the variational parameters A, b and Lagrange multipliers λ, Ψ over time.]
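A minimal simulation of the bi-stable system, matching the small-noise setting θ = 1, σ = 0.5 used in the comparison on the next slide:

```python
import numpy as np

# Euler-Maruyama simulation of dx = 4 x (theta - x^2) dt + sigma dw,
# in the small-noise regime theta = 1, sigma = 0.5.
theta, sigma = 1.0, 0.5
dt, T = 0.01, 8.0
rng = np.random.default_rng(2)
x = np.empty(int(T / dt) + 1)
x[0] = -1.0                                  # start in the left well
for k in range(len(x) - 1):
    drift = 4 * x[k] * (theta - x[k] ** 2)
    x[k + 1] = x[k] + drift * dt + sigma * np.sqrt(dt) * rng.normal()
# With small noise the path lingers near one well and occasionally hops to
# the other -- exactly the multimodal behaviour the smoother must capture.
```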

Page 18

Comparison to hybrid Markov chain Monte Carlo (Eyink et al., 2004)

Reference solution:
Based on a discrete approximation
Generates complete sample paths from the posterior
Modified MCMC scheme to increase the acceptance rate (molecular dynamics)
Still requires generating on the order of 100,000 samples for good results
Hard to check convergence

[Figure: smoother estimate against the hybrid MCMC reference. (a) θ = 1, σ = 0.5. (b) Large noise.]

Page 19

Comparison to the continuous-time unscented Kalman smoother

[Figure: variational smoother versus the continuous-time unscented Kalman smoother (state x vs. time t).]

Page 20

Failure mode

[Figure: a failure case of the smoother (state x vs. time t).]

Page 21

Stochastic Lorenz attractor

The Lorenz attractor:

f_x = σ(y − x),   σ > 0,
f_y = ρx − y − xz,   ρ > 0,
f_z = xy − βz,   β > 0.

The deterministic system is chaotic for the classic parameter values; we additionally drive it with stochastic noise.
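A sketch simulating the stochastic Lorenz system with Euler–Maruyama (classic deterministic parameters σ = 10, ρ = 28, β = 8/3; the additive noise level sigma_w is an illustrative choice):

```python
import numpy as np

# Euler-Maruyama simulation of the stochastic Lorenz system with additive noise.
sigma, rho, beta, sigma_w = 10.0, 28.0, 8.0 / 3.0, 2.0
dt, T = 0.001, 5.0
rng = np.random.default_rng(3)
X = np.empty((int(T / dt) + 1, 3))
X[0] = [1.0, 1.0, 25.0]
for k in range(len(X) - 1):
    x, y, z = X[k]
    f = np.array([sigma * (y - x), rho * x - y - x * z, x * y - beta * z])
    X[k + 1] = X[k] + f * dt + sigma_w * np.sqrt(dt) * rng.normal(size=3)
```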

[Figures: a 3-D trajectory of the stochastic Lorenz attractor, and variational smoother estimates of the components x1, x2, x3 over time.]

Page 22

Parameter inference

(Variational) EM fails for the diffusion coefficient:

lim_{δ→0} ∑_{i=1}^{T/δ} (x_{iδ} − x_{(i−1)δ})(x_{iδ} − x_{(i−1)δ})ᵀ = ∫_0^T D(x(t), t) dt   a.s.

Type II ML based on gradient techniques is OK, as we change the sample paths together with the diffusion coefficient.
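The quadratic-variation identity can be checked numerically: the summed squared increments recover ∫D dt regardless of the drift, which is why a fixed sample path pins down D. A sketch with constant D (values illustrative):

```python
import numpy as np

# The summed squared increments recover D * T regardless of the drift,
# so the sample path determines the diffusion coefficient almost surely.
D_true, T, dt = 0.25, 10.0, 1e-4
rng = np.random.default_rng(4)
x = np.empty(int(T / dt) + 1)
x[0] = 0.0
for k in range(len(x) - 1):
    drift = -x[k]                            # any drift: it drops out in the limit
    x[k + 1] = x[k] + drift * dt + np.sqrt(D_true * dt) * rng.normal()
qv = np.sum(np.diff(x) ** 2)                 # quadratic variation of the path
print(qv, D_true * T)                        # both close to 2.5
```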

Cheap estimate of the posterior (sanity check; Lappalainen and Miskin, 2000):

q(θ) = e^{−F(Q,θ)} p(θ) / ∫ e^{−F(Q,θ)} p(θ) dθ.

[Figure: approximate posterior over the diffusion coefficient D.]

Page 23

Conclusion

Stochastic process models are very powerful when the number of observations is small compared to the complexity of the dynamics.

Gaussian variational approximation for non-linear SDEs boils down to solving a set of ODEs.

Any preferred integration scheme can be used; no discrete-time approximation of the transition density.

Can be viewed as a generalisation of the sigma-point Kalman smoother for certain instantiations of the statistical linearisation principle.

Considerably faster than (most) MCMC schemes.

The diffusion matrix can be estimated; multiplicative noise is OK (in principle).

Error bars are underestimated.

Page 24

References

C. Archambeau, M. Opper: Approximate inference for continuous time Markov processes. In Inference and Estimation in Probabilistic Time-Series Models. Cambridge University Press, 2011.

C. Archambeau, M. Opper, Y. Shen, D. Cornford, J. Shawe-Taylor: Variational Inference for Diffusion Processes. NIPS 20, pp. 17–24, 2008.

A. Beskos et al.: Monte-Carlo maximum likelihood estimation for discretely observed diffusion processes. Annals of Statistics, 37:1, pp. 223–245, 2009.

G. L. Eyink, J. L. Restrepo and F. J. Alexander: A mean field approximation in data assimilation for nonlinear dynamics. Physica D, 194:347–368, 2004.

I. Karatzas and S. E. Shreve: Brownian Motion and Stochastic Calculus. Springer, 1998.

H. Lappalainen and J. W. Miskin: Ensemble learning. In M. Girolami, editor, Advances in Independent Component Analysis, pp. 76–92. Springer-Verlag, 2000.

C. E. Rasmussen and C. K. I. Williams: Gaussian Processes for Machine Learning. MIT Press, 2006.

S. Sarkka and T. Sottinen: Application of Girsanov Theorem to Particle Filtering of Discretely Observed Continuous-Time Non-Linear Systems. Bayesian Analysis, 3:3, pp. 555–584, 2008.

Page 25

Informal proof for KL[Q(x(t))‖P(x(t))]

Consider the Euler–Maruyama discrete approximation of the SDEs:

Δx_k = f_k Δt + D^{1/2} Δw_k,   Δw_k ∼ N(0, Δt I),
Δx_k = g_k Δt + D^{1/2} Δw_k,   Δw_k ∼ N(0, Δt I),

where Δx_k ≡ x_{k+1} − x_k.

The joint distributions of discrete sample paths {x_k}_{k≥0} for the true process and its approximation follow from the Markov property:

p(x_0, …, x_K | D) = p(x_0) ∏_{k≥0} N(x_{k+1} | x_k + f_k Δt, D Δt),
q(x_0, …, x_K | D) = q(x_0) ∏_{k≥0} N(x_{k+1} | x_k + g_k Δt, D Δt).

The KL between the two discretised processes is then given by

KL[q‖p] = KL[q(x_0)‖p(x_0)] − ∑_{k≥0} ∫ q(x_k) ⟨ln [p(x_{k+1}|x_k) / q(x_{k+1}|x_k)]⟩_{q(x_{k+1}|x_k)} dx_k

= KL[q(x_0)‖p(x_0)] + (1/2) ∑_{k≥0} ⟨(f_k − g_k)ᵀ D^{−1} (f_k − g_k)⟩_{q(x_k)} Δt.

Passing to the limit is ok! (Formal proof based on the Girsanov theorem.)
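The discretised KL formula can be verified by Monte Carlo for two scalar linear SDEs, where each transition KL is a KL between Gaussians; all parameters below are illustrative:

```python
import numpy as np

# Monte Carlo check of the discretised KL between two scalar linear SDEs:
# prior drift f(x) = -gamma x, approximating drift g(x) = -a x + b, shared D.
gamma, a, b, D = 1.0, 2.0, 0.5, 0.25
dt, K, n_mc = 0.01, 500, 20_000
rng = np.random.default_rng(5)
x = rng.normal(0.0, 1.0, size=n_mc)          # x0 ~ q(x0) = p(x0), so KL at t=0 is 0
kl = 0.0
for k in range(K):
    f, g = -gamma * x, -a * x + b
    kl += np.mean((f - g) ** 2) / (2 * D) * dt   # <(f-g)^2>_q dt / (2 D) per step
    x = x + g * dt + np.sqrt(D * dt) * rng.normal(size=n_mc)  # propagate under q
print(kl)                                    # approx KL[q||p] over [0, K*dt]
```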