Download - Seong Jae Hwang Ronak R. Mehta Hyunwoo J. Kim Stering C ...pitt.edu/~sjh95/project/spgru/uai2019_poster.pdfSeong Jae Hwang Ronak R. Mehta Hyunwoo J. Kim Stering C. Johnson Vikas Singh


Sampling-free Uncertainty Estimation in Gated Recurrent Units with Applications to Normative Modeling in Neuroimaging

Seong Jae Hwang, Ronak R. Mehta, Hyunwoo J. Kim, Sterling C. Johnson, Vikas Singh

MOTIVATION

1. Given a visually 'good looking' sequence prediction, how can we tell that its trajectory is correct?

2. If it is, can we derive the degree of uncertainty on its prediction?

[Figure panels: Input Sequence (t = 1 to t = 10); Ground Truth, SP-GRU Output Prediction, and Model Uncertainty Map (t = 11 to t = 20).]

Figure: Image sequence prediction with uncertainty. Given the first 10 frames of an input sequence (left), our model SP-GRU makes the Output Prediction and the pixel-level Model Uncertainty Map, where bright regions indicate high uncertainty. SP-GRU estimates the uncertainty deterministically without sampling model parameters.

GOAL

Derive a recurrent neural network architecture capable of estimating uncertainty with the following properties:

1. Deterministically estimate uncertainties in a sampling-free manner (e.g., without Monte Carlo sampling)

2. Uncertainties of all intermediate neurons can be expressed in terms of a distribution

PRELIMINARIES

• Gated Recurrent Unit (GRU):

Reset Gate: $r^t = \sigma(W_r x^t + b_r)$

Update Gate: $z^t = \sigma(W_z x^t + b_z)$

State Candidate: $\hat{h}^t = \tanh(U_h x^t + W_h (r^t \odot h^{t-1}) + b_h)$

Cell State: $h^t = (1 - z^t) \odot \hat{h}^t + z^t \odot h^{t-1}$
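The GRU equations above can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' code: the parameter dictionary keys, the dimensions, and the random initialization are assumptions made for the example, and the gates follow the equations exactly as written on the poster (driven by the input x only).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, p):
    """One GRU step following the equations as written above.

    p is a dict with hypothetical keys: W_r, b_r, W_z, b_z act on the
    input x_t; U_h acts on x_t and W_h on the reset-gated previous state.
    """
    r_t = sigmoid(p["W_r"] @ x_t + p["b_r"])      # reset gate
    z_t = sigmoid(p["W_z"] @ x_t + p["b_z"])      # update gate
    h_hat = np.tanh(p["U_h"] @ x_t + p["W_h"] @ (r_t * h_prev) + p["b_h"])
    return (1.0 - z_t) * h_hat + z_t * h_prev     # cell state

# Toy run with assumed sizes and random parameters.
rng = np.random.default_rng(0)
d_in, d_h = 3, 4
params = {
    "W_r": rng.normal(size=(d_h, d_in)), "b_r": np.zeros(d_h),
    "W_z": rng.normal(size=(d_h, d_in)), "b_z": np.zeros(d_h),
    "U_h": rng.normal(size=(d_h, d_in)), "b_h": np.zeros(d_h),
    "W_h": rng.normal(size=(d_h, d_h)),
}
h = np.zeros(d_h)
for x in rng.normal(size=(5, d_in)):
    h = gru_step(x, h, params)
```

Because the cell state is a convex combination of a tanh output and the previous state, the hidden state stays inside [-1, 1] when it starts at zero.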

• Exponential Families in Neural Networks: Let $x \in X$ be a random variable with probability density/mass function (pdf/pmf) $f_X$. Then $f_X$ is an exponential family distribution if

$$f_X(x \mid \eta) = h(x) \exp(\eta^T T(x) - A(\eta))$$

with natural parameters $\eta$, base measure $h(x)$, and sufficient statistics $T(x)$. The constant $A(\eta)$ (log-partition function) ensures that the distribution normalizes to 1.
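As a concrete instance of the exponential family form, the univariate Gaussian can be written with η = (μ/σ², −1/(2σ²)), T(x) = (x, x²), h(x) = 1/√(2π), and A(η) = −η₁²/(4η₂) − ½ log(−2η₂). The sketch below checks this numerically against the usual Gaussian density; the function names are illustrative and the Gaussian is chosen only as a familiar example, not because the poster singles it out.

```python
import math

def gaussian_expfam_pdf(x, mu, sigma2):
    """Gaussian density evaluated as h(x) * exp(eta^T T(x) - A(eta))."""
    eta1 = mu / sigma2                    # first natural parameter
    eta2 = -1.0 / (2.0 * sigma2)          # second natural parameter
    h = 1.0 / math.sqrt(2.0 * math.pi)    # base measure
    t1, t2 = x, x * x                     # sufficient statistics T(x)
    log_partition = -eta1 ** 2 / (4.0 * eta2) - 0.5 * math.log(-2.0 * eta2)
    return h * math.exp(eta1 * t1 + eta2 * t2 - log_partition)

def gaussian_pdf(x, mu, sigma2):
    """Textbook Gaussian density, used only as a reference."""
    return math.exp(-(x - mu) ** 2 / (2.0 * sigma2)) / math.sqrt(2.0 * math.pi * sigma2)

difference = abs(gaussian_expfam_pdf(0.3, 1.0, 2.0) - gaussian_pdf(0.3, 1.0, 2.0))
```

The two evaluations agree up to floating-point error, which illustrates how A(η) absorbs the normalization.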

[Diagram: the input $a^l$ is multiplied by the weights $W^l$ and passed through $g^l$, giving $a^{l+1} \sim \mathrm{EXPFAM}(g^l(W^l a^l))$.]

Figure: A single exponential family neuron. Weights $W^l$ are learned, and the output of a neuron is a sample generated from the exponential family defined a priori and by the natural parameters $g^l(W^l a^{l-1})$.

MOMENT MATCHING

• Linear Moment Matching (LMM): propagate (1) the mean $a_m$, following the standard linearity of random variable expectations, and (2) the variance $a_s$ as:

$$o_m = W_m a_m + b_m, \qquad o_s = W_s a_s + b_s + (W_m \odot W_m) a_s + W_s (a_m \odot a_m) \tag{1}$$
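Equation (1) translates directly into array operations. The following is a small sketch under the assumption that means and variances are stored as separate NumPy arrays; the function name and the toy numbers are made up for illustration.

```python
import numpy as np

def linear_moment_matching(a_m, a_s, W_m, W_s, b_m, b_s):
    """Propagate mean and variance through a linear layer, as in Eq. (1)."""
    o_m = W_m @ a_m + b_m
    o_s = W_s @ a_s + b_s + (W_m * W_m) @ a_s + W_s @ (a_m * a_m)
    return o_m, o_s

# Toy inputs (assumed values).
a_m = np.array([1.0, 2.0])              # input means
a_s = np.array([0.1, 0.2])              # input variances
W_m = np.array([[1.0, -1.0],
                [0.5, 2.0]])            # weight means
W_s = np.full((2, 2), 0.01)             # weight variances
b_m, b_s = np.zeros(2), np.zeros(2)     # bias mean and variance

o_m, o_s = linear_moment_matching(a_m, a_s, W_m, W_s, b_m, b_s)
```

With the weight and bias variances set to zero, the variance reduces to (W_m ⊙ W_m) a_s, the familiar rule for a fixed linear map applied to independent inputs.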

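• Nonlinear Moment Matching (NMM): Using the fact that $\sigma(x) \approx \Phi(\zeta x)$, where $\Phi(\cdot)$ is the probit function (the standard normal CDF) and $\zeta = \sqrt{\pi/8}$ is a constant, approximate the sigmoid functions for $a_m$ and $a_s$:

$$a_m \approx \sigma_m(o_m, o_s) = \sigma\left(\frac{o_m}{(1 + \zeta^2 o_s)^{1/2}}\right), \qquad a_s \approx \sigma_s(o_m, o_s) = \sigma\left(\frac{\nu (o_m + \omega)}{(1 + \zeta^2 \nu^2 o_s)^{1/2}}\right) - a_m^2 \tag{2}$$

where $\nu = 4 - 2\sqrt{2}$ and $\omega = -\log(\sqrt{2} + 1)$. The hyperbolic tangent can be derived from $\tanh(x) = 2\sigma(2x) - 1$.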
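Equation (2) and the tanh identity can be sketched as follows. The sigmoid functions follow Eq. (2) as stated; the tanh versions are one way to apply tanh(x) = 2σ(2x) − 1, derived here by noting that if o has mean o_m and variance o_s then 2o has mean 2·o_m and variance 4·o_s. That rescaling and all function names are assumptions of this sketch, not taken from the poster.

```python
import numpy as np

ZETA2 = np.pi / 8.0                      # zeta squared
NU = 4.0 - 2.0 * np.sqrt(2.0)
OMEGA = -np.log(np.sqrt(2.0) + 1.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_mean(o_m, o_s):
    """Approximate mean of sigmoid(o) for o with mean o_m and variance o_s."""
    return sigmoid(o_m / np.sqrt(1.0 + ZETA2 * o_s))

def sigmoid_var(o_m, o_s):
    """Approximate variance of sigmoid(o): second moment minus squared mean."""
    second = sigmoid(NU * (o_m + OMEGA) / np.sqrt(1.0 + ZETA2 * NU ** 2 * o_s))
    return second - sigmoid_mean(o_m, o_s) ** 2

def tanh_mean(o_m, o_s):
    """Mean of tanh(o) via tanh(x) = 2*sigmoid(2x) - 1."""
    return 2.0 * sigmoid_mean(2.0 * o_m, 4.0 * o_s) - 1.0

def tanh_var(o_m, o_s):
    """Variance of tanh(o); the affine map 2*s - 1 scales variance by 4."""
    return 4.0 * sigmoid_var(2.0 * o_m, 4.0 * o_s)

mean_example = sigmoid_mean(2.0, 4.0)    # pulled toward 0.5 relative to sigmoid(2.0)
var_example = sigmoid_var(0.5, 1.0)
```

Increasing the input variance o_s shrinks the effective argument of the sigmoid, so the predicted mean moves toward 0.5, which matches the intuition that a noisier pre-activation gives a less decisive gate.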

[Diagram: $(a_m^{l-1}, a_s^{l-1})$ → LMM → $(o_m^l, o_s^l)$ → NMM → $(a_m^l, a_s^l)$.]

Figure: Linear Moment Matching (LMM) and Nonlinear Moment Matching (NMM) are performed at the weights/bias sums and activations respectively.

SAMPLING-FREE PROBABILISTIC GRU (SP-GRU)

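[Diagram: SP-GRU cell. The inputs $x_m^t$, $x_s^t$ and the previous states $h_m^{t-1}$, $h_s^{t-1}$ feed the reset gate ($r_m^t$, $r_s^t$), the update gate ($z_m^t$, $z_s^t$, together with $1 - z_m^t$, $1 - z_s^t$), and the state candidate ($\hat{h}_m^t$, $\hat{h}_s^t$), producing the new states $h_m^t$, $h_s^t$.]

Figure: SP-GRU cell structure. Solid lines/boxes and red dotted lines/boxes correspond to operations and variables for mean m and variance s respectively. Circles are element-wise operators.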

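Operation: Reset Gate
  Linear Transformation:
    $o^t_{r,m} = U_{r,m} x^t_m + W_{r,m} h^{t-1}_m + b_{r,m}$
    $o^t_{r,s} = U_{r,s} x^t_s + W_{r,s} h^{t-1}_s + b_{r,s} + [U_{r,m}]^2 x^t_s + U_{r,s} [x^t_m]^2 + [W_{r,m}]^2 h^{t-1}_s + W_{r,s} [h^{t-1}_m]^2$
  Nonlinear Transformation:
    $r^t_m = \sigma_m(o^t_{r,m}, o^t_{r,s})$
    $r^t_s = \sigma_s(o^t_{r,m}, o^t_{r,s})$

Operation: Update Gate
  Linear Transformation:
    $o^t_{z,m} = U_{z,m} x^t_m + W_{z,m} h^{t-1}_m + b_{z,m}$
    $o^t_{z,s} = U_{z,s} x^t_s + W_{z,s} h^{t-1}_s + b_{z,s} + [U_{z,m}]^2 x^t_s + U_{z,s} [x^t_m]^2 + [W_{z,m}]^2 h^{t-1}_s + W_{z,s} [h^{t-1}_m]^2$
  Nonlinear Transformation:
    $z^t_m = \sigma_m(o^t_{z,m}, o^t_{z,s})$
    $z^t_s = \sigma_s(o^t_{z,m}, o^t_{z,s})$

Operation: State Candidate
  Linear Transformation:
    $o^t_{h,m} = U_{h,m} x^t_m + W_{h,m} h^{t-1}_m + b_{h,m}$
    $o^t_{h,s} = U_{h,s} x^t_s + W_{h,s} h^{t-1}_s + b_{h,s} + [U_{h,m}]^2 x^t_s + U_{h,s} [x^t_m]^2 + [W_{h,m}]^2 h^{t-1}_s + W_{h,s} [h^{t-1}_m]^2$
  Nonlinear Transformation:
    $\hat{h}^t_m = \tanh_m(o^t_{h,m}, o^t_{h,s})$
    $\hat{h}^t_s = \tanh_s(o^t_{h,m}, o^t_{h,s})$

Operation: Cell State
  Linear Transformation:
    $h^t_m = (1 - z^t_m) \odot \hat{h}^t_m + z^t_m \odot h^{t-1}_m$
    $h^t_s = [(1 - z^t_s)]^2 \odot \hat{h}^t_s + [z^t_s]^2 \odot h^{t-1}_s$
  Nonlinear Transformation: Not Needed

Table: SP-GRU operations in mean and variance. $\odot$ and $[A]^2$ denote the Hadamard product and $A \odot A$ of a matrix/vector $A$ respectively. Note the Cell State does not involve nonlinear operations.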
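One SP-GRU step can be sketched by chaining the table's rows: a linear transformation in mean and variance, a nonlinear moment match, and the cell state update. This is a rough illustration that follows the table literally, not the authors' implementation. The dictionary keys, sizes, initial values, and input variance are assumptions; the tanh rescaling is derived from tanh(x) = 2σ(2x) − 1; and the clipping of variances at zero is a numerical safeguard added here. The reset gate rows have the same form as the update gate, but the transcribed State Candidate row does not show where the reset gate enters, so the sketch leaves it out rather than guess.

```python
import numpy as np

ZETA2 = np.pi / 8.0
NU = 4.0 - 2.0 * np.sqrt(2.0)
OMEGA = -np.log(np.sqrt(2.0) + 1.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_moments(o_m, o_s):
    """Nonlinear moment matching for the sigmoid: returns (mean, variance)."""
    a_m = sigmoid(o_m / np.sqrt(1.0 + ZETA2 * o_s))
    a_s = sigmoid(NU * (o_m + OMEGA) / np.sqrt(1.0 + ZETA2 * NU ** 2 * o_s)) - a_m ** 2
    return a_m, np.maximum(a_s, 0.0)     # clip: safeguard added in this sketch

def tanh_moments(o_m, o_s):
    """tanh(x) = 2*sigmoid(2x) - 1, so rescale inputs and outputs accordingly."""
    a_m, a_s = sigmoid_moments(2.0 * o_m, 4.0 * o_s)
    return 2.0 * a_m - 1.0, 4.0 * a_s

def preactivation(x_m, x_s, h_m, h_s, g):
    """Linear Transformation rows for one gate; g uses hypothetical keys."""
    o_m = g["U_m"] @ x_m + g["W_m"] @ h_m + g["b_m"]
    o_s = (g["U_s"] @ x_s + g["W_s"] @ h_s + g["b_s"]
           + g["U_m"] ** 2 @ x_s + g["U_s"] @ x_m ** 2
           + g["W_m"] ** 2 @ h_s + g["W_s"] @ h_m ** 2)
    return o_m, o_s

def sp_gru_step(x_m, x_s, h_m, h_s, params):
    """Update gate, state candidate, and cell state rows of the table."""
    z_m, z_s = sigmoid_moments(*preactivation(x_m, x_s, h_m, h_s, params["z"]))
    c_m, c_s = tanh_moments(*preactivation(x_m, x_s, h_m, h_s, params["h"]))
    new_h_m = (1.0 - z_m) * c_m + z_m * h_m
    new_h_s = (1.0 - z_s) ** 2 * c_s + z_s ** 2 * h_s   # as transcribed in the table
    return new_h_m, new_h_s

def init_gate(rng, d_h, d_in):
    """Random weight means and small constant variances (assumed values)."""
    return {"U_m": rng.normal(scale=0.5, size=(d_h, d_in)),
            "W_m": rng.normal(scale=0.5, size=(d_h, d_h)),
            "b_m": np.zeros(d_h),
            "U_s": np.full((d_h, d_in), 0.01),
            "W_s": np.full((d_h, d_h), 0.01),
            "b_s": np.full(d_h, 0.01)}

rng = np.random.default_rng(0)
d_in, d_h = 3, 4
params = {"z": init_gate(rng, d_h, d_in), "h": init_gate(rng, d_h, d_in)}
h_m, h_s = np.zeros(d_h), np.zeros(d_h)
for x_m in rng.normal(size=(5, d_in)):
    x_s = np.full(d_in, 0.05)            # assumed per-input variance
    h_m, h_s = sp_gru_step(x_m, x_s, h_m, h_s, params)
```

A single forward pass yields both a mean state and a variance state, with no sampling of weights, which is the sampling-free property the poster sets as its goal.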

EXPERIMENT 1: MOVING MNIST

Figure: (a) Angle deviation trajectories. (b) Speed deviation trajectories.

[Figure panels: rows of Ground Truth, Prediction, and Uncertainty for angle deviations θ = 20°, 25°, 30°, 35°; rows of Ground Truth, Prediction, and Uncertainty for speed deviations v = 5.0%, 5.5%, 6.0%, 6.5%; rows of Input, Prediction, and Uncertainty for noise levels b = 0.0, 0.2, 0.4, 0.6.]

Figure: Predictions and uncertainties (frames 11, 15, and 20) from testing varying deviations from trained trajectories (first of four rows, blue). Top: angle. Middle: speed. Bottom: pixel-level noise. Right: the average sum of per frame pixel-level variance using SP-GRU and MC-LSTM.

[Figure panels: 2 Moving Digits Prediction (left); 3 Moving Digits (Out of Domain) Prediction (right).]

Figure: SP-GRU predictor results. Left 3 rows: 2 moving digits (top: ground truth, middle: mean prediction, bottom: uncertainty estimate). Right 3 rows: 3 moving digits which are out of domain (i.e., not seen in training).

EXPERIMENT 2: NORMATIVE MODELING IN NEUROIMAGING

1. Brain Connectivity Sequence Sample Generation

[Figure panels. Left: Original Subjects → Ordered and Binned by RAVLT Progression → N Samples (i = 1, ..., N). Right: Original Subjects (Unordered, PiB Positive and Negative) → PiB Negative and PiB Positive groups, each Ordered and Binned by RAVLT Progression → N Samples of PiB Negative and N Samples of PiB Positive (i = 1, ..., N).]

Figure: The preprocessing procedure used to generate sample data for SP-GRU. Left: Global connectivity sequence samples. Right: PiB+/- connectivity sequence samples.

2. Normative Probability Map (NPM) (Marquand et al., 2016): For each subject $i$, our true response at time $j$ for connectivity $k$ is given by $y_{ijk}$, with a bin-level variance of $\sigma^2_{njk}$. The SP-GRU predicts a mean response $\hat{y}_{ijk}$ and variance $\sigma^2_{ijk}$.

Normative Probability Map (NPM): $z_{ijk} = (y_{ijk} - \hat{y}_{ijk}) / \sqrt{\sigma^2_{ijk} + \sigma^2_{njk}}$.
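The NPM is an element-wise z-score that pools the predictive variance with the bin-level variance. A minimal sketch, assuming all four quantities are NumPy arrays of the same shape; the function name and the toy values are illustrative.

```python
import numpy as np

def normative_probability_map(y_true, y_pred, var_pred, var_bin):
    """z = (y - y_hat) / sqrt(predictive variance + bin-level variance)."""
    return (y_true - y_pred) / np.sqrt(var_pred + var_bin)

# Toy values for two connectivities of one subject at one time step.
y_true = np.array([1.0, 2.0])
y_pred = np.array([0.5, 2.5])
var_pred = np.array([0.09, 0.16])    # variance predicted by the model
var_bin = np.array([0.16, 0.09])     # bin-level variance
z = normative_probability_map(y_true, y_pred, var_pred, var_bin)
```

In this toy case both pooled variances equal 0.25, so a deviation of 0.5 in either direction maps to a z value of magnitude 1.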

3. Normative Modeling in Neuroimaging: Pipeline

(1) Test Sample Inputs (each i and t: 1761 connectivities), for samples i = 1, ..., N at t = 1, 2, 3, 4.

(2) Predictions with Uncertainties (each i and t: 1761 means and variances): the trained SP-GRU predictor is given t = 1, 2, 3, 4 and predicts t = 5, 6, 7, 8.

(3) Sequential Normative Probability Maps (each i and t: 1761 NPMs).

(4) Sequential Extreme Value Statistics (each i and t: 1 EVS). The EVS is a robust summary of the 1761 NPMs (mean of top 5%).

(5) Extreme Value Distributions, for t = 5, 6, 7, 8:
1. Construct histogram of EVS for each t
2. Fit generalized extreme value distributions (GED)
3. Derive confidence intervals

(6) Sequential EVS of a New Subject N' (following the above pipeline on a new subject).

(7) Outlier Detection: Outlier EVS in at least one t ⇒ Outlier subject.

Figure: Normative modeling pipeline for preclinical AD. (1) Given a set of test inputs (t = 1,2,3,4), (2) use the pretrained SP-GRU to make mean and variance predictions for each connectivity and t = 5,6,7,8. (3) Compute NPM for each prediction, and (4) derive EVS for each sample i and t. (5) Fit GED and construct confidence intervals based on N EVS for each t. (6) Given a new sample, derive EVS following (1)-(4), and (7) check the confidence intervals from (5) to determine heterogeneity.
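Steps (4) to (7) of the pipeline can be sketched as below. Two points are assumptions of this sketch rather than details from the poster: the "top 5%" is taken over absolute NPM values, and the per-time-step interval is computed from empirical quantiles as a simple stand-in for the fitted generalized extreme value distribution. The function names and the synthetic data are illustrative only.

```python
import numpy as np

def extreme_value_statistic(npm, top_frac=0.05):
    """Mean of the largest `top_frac` of |NPM| values (absolute value assumed)."""
    mags = np.sort(np.abs(np.ravel(npm)))[::-1]
    k = max(1, int(np.ceil(top_frac * mags.size)))
    return float(mags[:k].mean())

def evs_table(npms, top_frac=0.05):
    """npms has shape (N samples, T time steps, K connectivities) -> (N, T) EVS."""
    n, t, _ = npms.shape
    return np.array([[extreme_value_statistic(npms[i, j], top_frac)
                      for j in range(t)] for i in range(n)])

def empirical_intervals(evs, alpha=0.05):
    """Per-time-step interval from empirical quantiles.

    Stand-in for step (5): the poster fits a generalized extreme value
    distribution and derives confidence intervals from that fit.
    """
    lo = np.quantile(evs, alpha / 2.0, axis=0)
    hi = np.quantile(evs, 1.0 - alpha / 2.0, axis=0)
    return lo, hi

def is_outlier(evs_new, lo, hi):
    """Step (7): outlier EVS in at least one time step => outlier subject."""
    return bool(np.any((evs_new < lo) | (evs_new > hi)))

# Synthetic NPMs: 100 samples, 4 predicted time steps, 1761 connectivities.
rng = np.random.default_rng(0)
reference = evs_table(rng.normal(size=(100, 4, 1761)))
lo, hi = empirical_intervals(reference)
new_subject = evs_table(3.0 * rng.normal(size=(1, 4, 1761)))[0]
flagged = is_outlier(new_subject, lo, hi)
```

Summarizing each map by the mean of its largest deviations keeps the statistic sensitive to localized abnormalities while being less noisy than a single maximum.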

4. Outlier Detection: Cognitively Healthy (PiB-) vs. At-Risk (PiB+)

Detected outliers: 9 of 100 samples in PiB- and 19 of 100 samples in PiB+.

Implication: Larger absolute fluctuations in DWI connectivity may be a good indicator for disease risk as measured by amyloid burden.

Research supported in part by NIH (R01AG040396, R01AG021155, R01AG027161, P50AG033514, R01AG059312, R01EB022883, R01AG062336), the Center for Predictive and Computational Phenotyping (U54AI117924), NSF CAREER Award (1252725), and a predoctoral fellowship to RRM via T32LM012413.