Source: pitt.edu/~sjh95/project/spgru/uai2019_poster.pdf (UAI 2019 poster)

Sampling-free Uncertainty Estimation in Gated Recurrent Units with Applications to Normative Modeling in Neuroimaging

Seong Jae Hwang, Ronak R. Mehta, Hyunwoo J. Kim, Sterling C. Johnson, Vikas Singh

MOTIVATION

1. Given a visually 'good-looking' sequence prediction, how can we tell that its trajectory is correct?

2. If it is, can we derive the degree of uncertainty in its prediction?

[Figure panels: SP-GRU. Input Sequence (t = 1, ..., 10); Ground Truth, Output Prediction, and Model Uncertainty Map (t = 11, ..., 20).]

Figure: Image sequence prediction with uncertainty. Given the first 10 frames of an input sequence (left), our model SP-GRU makes the Output Prediction and the pixel-level Model Uncertainty Map, where bright regions indicate high uncertainty. SP-GRU estimates the uncertainty deterministically without sampling model parameters.

GOAL

Derive a recurrent neural network architecture capable of estimating uncertainty with the following properties:

1. Deterministically estimate uncertainties in a sampling-free manner (e.g., without Monte Carlo sampling)

2. Uncertainties of all intermediate neurons can be expressed in terms of a distribution

PRELIMINARIES

▶ Gated Recurrent Unit (GRU):

Reset Gate: r_t = σ(W_r x_t + b_r)
Update Gate: z_t = σ(W_z x_t + b_z)
State Candidate: ĥ_t = tanh(U_ĥ x_t + W_ĥ (r_t ⊙ h_{t−1}) + b_ĥ)
Cell State: h_t = (1 − z_t) ⊙ ĥ_t + z_t ⊙ h_{t−1}
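As a concrete reference, the four equations above can be sketched in NumPy. This is a minimal illustration, not the paper's implementation: parameter names and dimensions are ours, and the gates depend only on x_t exactly as written above.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, params):
    """One GRU step following the poster's equations.

    `params` holds weight matrices/biases (W_r, b_r, W_z, b_z, U_h, W_h, b_h);
    all names are illustrative.
    """
    r_t = sigmoid(params["W_r"] @ x_t + params["b_r"])   # reset gate
    z_t = sigmoid(params["W_z"] @ x_t + params["b_z"])   # update gate
    h_cand = np.tanh(params["U_h"] @ x_t
                     + params["W_h"] @ (r_t * h_prev)    # ⊙ is element-wise
                     + params["b_h"])
    h_t = (1.0 - z_t) * h_cand + z_t * h_prev            # convex combination
    return h_t

# Tiny example: 2-dim input, 3-dim hidden state.
rng = np.random.default_rng(0)
p = {"W_r": rng.normal(size=(3, 2)), "b_r": np.zeros(3),
     "W_z": rng.normal(size=(3, 2)), "b_z": np.zeros(3),
     "U_h": rng.normal(size=(3, 2)),
     "W_h": rng.normal(size=(3, 3)), "b_h": np.zeros(3)}
h = gru_step(rng.normal(size=2), np.zeros(3), p)
```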

▶ Exponential Families in Neural Networks: Let x ∈ X be a random variable with probability density/mass function (pdf/pmf) f_X. Then f_X is an exponential family distribution if

f_X(x | η) = h(x) exp(ηᵀ T(x) − A(η))

with natural parameters η, base measure h(x), and sufficient statistics T(x). The log-partition function A(η), constant in x, ensures that the distribution normalizes to 1.
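For concreteness, the Gaussian N(μ, σ²) fits this template with η = (μ/σ², −1/(2σ²)), T(x) = (x, x²), h(x) = 1/√(2π), and A(η) = −η₁²/(4η₂) − ½ log(−2η₂). A quick numerical check against the usual density formula (a worked example of the definition, not code from the poster):

```python
import math

def gaussian_expfam_pdf(x, mu, var):
    """N(mu, var) written in exponential-family form
    f(x|eta) = h(x) * exp(eta^T T(x) - A(eta))."""
    eta1, eta2 = mu / var, -1.0 / (2.0 * var)     # natural parameters
    T = (x, x * x)                                # sufficient statistics
    h = 1.0 / math.sqrt(2.0 * math.pi)            # base measure
    A = -eta1**2 / (4.0 * eta2) - 0.5 * math.log(-2.0 * eta2)  # log-partition
    return h * math.exp(eta1 * T[0] + eta2 * T[1] - A)

def gaussian_pdf(x, mu, var):
    """The familiar Gaussian density, for comparison."""
    return math.exp(-(x - mu) ** 2 / (2.0 * var)) / math.sqrt(2.0 * math.pi * var)
```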

[Diagram: activation a^l passes through weights W^l and link function g^l; the next-layer activation is sampled as a^{l+1} ∼ EXPFAM(g^l(W^l a^l)).]

Figure: A single exponential family neuron. Weights W^l are learned, and the output of a neuron is a sample generated from the exponential family defined a priori and by the natural parameters g^l(W^l a^{l−1}).

MOMENT MATCHING

▶ Linear Moment Matching (LMM): (1) the pre-activation mean o_m follows from the standard linearity of random variable expectations, and (2) the variance o_s accounts for uncertainty in both the weights and the activations:

o_m = W_m a_m + b_m,   o_s = W_s a_s + b_s + (W_m ⊙ W_m) a_s + W_s (a_m ⊙ a_m)   (1)
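Eq. (1) in code (a minimal sketch; names are ours). For a deterministic layer, W_s = 0 and b_s = 0 recover the usual affine map, and a deterministic input additionally has a_s = 0.

```python
import numpy as np

def lmm(a_m, a_s, W_m, W_s, b_m, b_s):
    """Linear Moment Matching (Eq. 1): propagate the mean and variance of
    activations a through a linear layer whose weights themselves have
    mean W_m and element-wise variance W_s."""
    o_m = W_m @ a_m + b_m
    o_s = W_s @ a_s + b_s + (W_m * W_m) @ a_s + W_s @ (a_m * a_m)
    return o_m, o_s
```

For scalar W ~ (mean 2, var 0.5) and a ~ (mean 3, var 0.25), independence gives Var(Wa) = E[W²]E[a²] − E[W]²E[a]² = 5.625, which Eq. (1) reproduces.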

▶ Nonlinear Moment Matching (NMM): Using the fact that σ(x) ≈ Φ(ζx), where Φ(·) is the probit function and ζ = √(π/8) is a constant, approximate the sigmoid's output mean a_m and variance a_s:

a_m ≈ σ_m(o_m, o_s) = σ( o_m / (1 + ζ² o_s)^{1/2} ),
a_s ≈ σ_s(o_m, o_s) = σ( ν(o_m + ω) / (1 + ζ² ν² o_s)^{1/2} ) − a_m²   (2)

where ν = 4 − 2√2 and ω = −log(√2 + 1). The hyperbolic tangent case can be derived from tanh(x) = 2σ(2x) − 1.
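Eq. (2) in code, with the tanh case obtained from tanh(x) = 2σ(2x) − 1 (so the mean maps affinely and the variance scales by 4). A sketch under our naming; the match to Monte Carlo moments is approximate by construction.

```python
import numpy as np

ZETA2 = np.pi / 8.0                       # zeta^2 with zeta = sqrt(pi/8)
NU = 4.0 - 2.0 * np.sqrt(2.0)
OMEGA = -np.log(np.sqrt(2.0) + 1.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def nmm_sigmoid(o_m, o_s):
    """Nonlinear Moment Matching (Eq. 2) for the sigmoid."""
    a_m = sigmoid(o_m / np.sqrt(1.0 + ZETA2 * o_s))
    a_s = sigmoid(NU * (o_m + OMEGA) / np.sqrt(1.0 + ZETA2 * NU**2 * o_s)) - a_m**2
    return a_m, a_s

def nmm_tanh(o_m, o_s):
    """tanh via tanh(x) = 2*sigmoid(2x) - 1:
    mean(2Y - 1) = 2*mean(Y) - 1 and Var(2Y - 1) = 4*Var(Y)."""
    m, s = nmm_sigmoid(2.0 * o_m, 4.0 * o_s)
    return 2.0 * m - 1.0, 4.0 * s
```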

[Diagram: (a^{l−1}_m, a^{l−1}_s) → LMM → (o^l_m, o^l_s) → NMM → (a^l_m, a^l_s).]

Figure: Linear Moment Matching (LMM) and Nonlinear Moment Matching (NMM) are performed at the weight/bias sums and at the activations, respectively.

SAMPLING-FREE PROBABILISTIC GRU (SP-GRU)

Figure: SP-GRU cell structure. Solid lines/boxes and red dotted lines/boxes correspond to operations and variables for the mean m and variance s, respectively. Circles are element-wise operators.

Reset Gate
  Linear: o^t_{r,m} = U_{r,m} x^t_m + W_{r,m} h^{t−1}_m + b_{r,m}
          o^t_{r,s} = U_{r,s} x^t_s + W_{r,s} h^{t−1}_s + b_{r,s} + [U_{r,m}]² x^t_s + U_{r,s} [x^t_m]² + [W_{r,m}]² h^{t−1}_s + W_{r,s} [h^{t−1}_m]²
  Nonlinear: r^t_m = σ_m(o^t_{r,m}, o^t_{r,s}),  r^t_s = σ_s(o^t_{r,m}, o^t_{r,s})

Update Gate
  Linear: o^t_{z,m} = U_{z,m} x^t_m + W_{z,m} h^{t−1}_m + b_{z,m}
          o^t_{z,s} = U_{z,s} x^t_s + W_{z,s} h^{t−1}_s + b_{z,s} + [U_{z,m}]² x^t_s + U_{z,s} [x^t_m]² + [W_{z,m}]² h^{t−1}_s + W_{z,s} [h^{t−1}_m]²
  Nonlinear: z^t_m = σ_m(o^t_{z,m}, o^t_{z,s}),  z^t_s = σ_s(o^t_{z,m}, o^t_{z,s})

State Candidate
  Linear: o^t_{ĥ,m} = U_{ĥ,m} x^t_m + W_{ĥ,m} h^{t−1}_m + b_{ĥ,m}
          o^t_{ĥ,s} = U_{ĥ,s} x^t_s + W_{ĥ,s} h^{t−1}_s + b_{ĥ,s} + [U_{ĥ,m}]² x^t_s + U_{ĥ,s} [x^t_m]² + [W_{ĥ,m}]² h^{t−1}_s + W_{ĥ,s} [h^{t−1}_m]²
  Nonlinear: ĥ^t_m = tanh_m(o^t_{ĥ,m}, o^t_{ĥ,s}),  ĥ^t_s = tanh_s(o^t_{ĥ,m}, o^t_{ĥ,s})

Cell State
  h^t_m = (1 − z^t_m) ⊙ ĥ^t_m + z^t_m ⊙ h^{t−1}_m
  h^t_s = [(1 − z^t_s)]² ⊙ ĥ^t_s + [z^t_s]² ⊙ h^{t−1}_s
  Nonlinear: not needed.

Table: SP-GRU operations in mean and variance. ⊙ and [A]² denote the Hadamard product and A ⊙ A of a matrix/vector A, respectively. Note the Cell State does not involve nonlinear operations.
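The table above composes into one deterministic forward step. A NumPy sketch following the rows as transcribed (in particular, the candidate's linear row as printed does not gate h^{t−1} by r^t, and we keep that; all parameter names are illustrative):

```python
import numpy as np

ZETA2 = np.pi / 8.0
NU = 4.0 - 2.0 * np.sqrt(2.0)
OMEGA = -np.log(np.sqrt(2.0) + 1.0)

def _sig(x):
    return 1.0 / (1.0 + np.exp(-x))

def _sig_mm(o_m, o_s):                    # NMM for the sigmoid (Eq. 2)
    a_m = _sig(o_m / np.sqrt(1.0 + ZETA2 * o_s))
    a_s = _sig(NU * (o_m + OMEGA) / np.sqrt(1.0 + ZETA2 * NU**2 * o_s)) - a_m**2
    return a_m, a_s

def _tanh_mm(o_m, o_s):                   # tanh(x) = 2*sigmoid(2x) - 1
    m, s = _sig_mm(2.0 * o_m, 4.0 * o_s)
    return 2.0 * m - 1.0, 4.0 * s

def _linear_mm(x_m, x_s, h_m, h_s, U_m, U_s, W_m, W_s, b_m, b_s):
    # Linear rows of the table: moments of U x_t + W h_{t-1} + b.
    o_m = U_m @ x_m + W_m @ h_m + b_m
    o_s = (U_s @ x_s + W_s @ h_s + b_s + (U_m**2) @ x_s + U_s @ (x_m**2)
           + (W_m**2) @ h_s + W_s @ (h_m**2))
    return o_m, o_s

def sp_gru_step(x_m, x_s, h_m, h_s, P):
    """One SP-GRU step; P maps 'r'/'z'/'h' to (U_m, U_s, W_m, W_s, b_m, b_s)."""
    # Reset gate (computed for completeness; the transcribed candidate row
    # does not use it).
    r_m, r_s = _sig_mm(*_linear_mm(x_m, x_s, h_m, h_s, *P["r"]))
    z_m, z_s = _sig_mm(*_linear_mm(x_m, x_s, h_m, h_s, *P["z"]))
    c_m, c_s = _tanh_mm(*_linear_mm(x_m, x_s, h_m, h_s, *P["h"]))
    # Cell state: element-wise convex combination; squared for the variance.
    new_m = (1.0 - z_m) * c_m + z_m * h_m
    new_s = (1.0 - z_s) ** 2 * c_s + z_s ** 2 * h_s
    return new_m, new_s

# Tiny example: 2-dim input, 3-dim hidden state, small nonnegative variances.
rng = np.random.default_rng(2)
def _par(n_in, n_h):
    return (rng.normal(size=(n_h, n_in)), 0.01 * np.abs(rng.normal(size=(n_h, n_in))),
            rng.normal(size=(n_h, n_h)), 0.01 * np.abs(rng.normal(size=(n_h, n_h))),
            np.zeros(n_h), np.zeros(n_h))
P = {"r": _par(2, 3), "z": _par(2, 3), "h": _par(2, 3)}
m, s = sp_gru_step(rng.normal(size=2), 0.01 * np.abs(rng.normal(size=2)),
                   np.zeros(3), np.zeros(3), P)
```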

EXPERIMENT 1: MOVING MNIST

Figure: (a) Angle deviation trajectories. (b) Speed deviation trajectories.

Figure: Predictions and uncertainties (frames 11, 15, and 20) from testing varying deviations from trained trajectories (first of four rows, blue). Top: angle (θ = 20°, 25°, 30°, 35°). Middle: speed (v = 5.0%, 5.5%, 6.0%, 6.5%). Bottom: pixel-level noise (b = 0.0, 0.2, 0.4, 0.6). Right: the average sum of per-frame pixel-level variance using SP-GRU and MC-LSTM.

Figure: SP-GRU predictor results. Left 3 rows: 2 moving digits (top: ground truth, middle: mean prediction, bottom: uncertainty estimate). Right 3 rows: 3 moving digits, which are out of domain (i.e., not seen in training).

EXPERIMENT 2: NORMATIVE MODELING IN NEUROIMAGING

1. Brain Connectivity Sequence Sample Generation

[Figure panels: original subjects (unordered, PiB positive and negative) are ordered and binned by RAVLT progression to produce N connectivity sequence samples (i = 1, ..., N), separately for PiB− and PiB+.]

Figure: The preprocessing procedure used to generate sample data for SP-GRU. Left: global connectivity sequence samples. Right: PiB+/− connectivity sequence samples.

2. Normative Probability Map (NPM) (Marquand et al., 2016): For each subject i, the true response at time j for connectivity k is y_ijk, with a bin-level standard deviation of σ_njk. The SP-GRU predicts a mean response ȳ_ijk and uncertainty σ_ijk.

Normative Probability Map (NPM): z_ijk = (y_ijk − ȳ_ijk) / √(σ²_ijk + σ²_njk).
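A minimal sketch of the NPM and of the extreme value statistic (EVS) used in the pipeline below (mean of the top 5% of NPM values; taking absolute values first is our assumption, and all names are ours):

```python
import numpy as np

def npm(y_true, y_pred_mean, pred_var, bin_var):
    """Normative probability map: z = (y - y_hat) / sqrt(sigma_pred^2 + sigma_bin^2).
    Note pred_var/bin_var are variances (sigma^2), not standard deviations."""
    return (y_true - y_pred_mean) / np.sqrt(pred_var + bin_var)

def evs(z, top=0.05):
    """Extreme value statistic: mean of the top `top` fraction of |z|
    (the poster's 'mean of top 5%')."""
    a = np.sort(np.abs(np.ravel(z)))[::-1]   # sort descending by magnitude
    k = max(1, int(round(top * a.size)))
    return a[:k].mean()
```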

3. Normative Modeling in Neuroimaging: Pipeline

(1) Test sample inputs (each i and t: 1761 connectivities), t = 1, 2, 3, 4.
(2) Predictions with uncertainties: given t = 1, 2, 3, 4, the trained SP-GRU predicts t = 5, 6, 7, 8 (each i and t: 1761 means and variances).
(3) Sequential normative probability maps (each i and t: 1761 NPMs).
(4) Sequential extreme value statistics (each i and t: 1 EVS); the EVS is a robust summary of the NPMs (mean of the top 5%).
(5) Extreme value distributions: 1. construct a histogram of the EVS for each t; 2. fit generalized extreme value distributions (GED); 3. derive confidence intervals.
(6) Sequential EVS of a new subject N′ (following the above pipeline on the new subject).
(7) Outlier detection: an outlier EVS in at least one t ⇒ outlier subject.

Figure: Normative modeling pipeline for preclinical AD. (1) Given a set of test inputs (t = 1, 2, 3, 4), (2) use the pretrained SP-GRU to make mean and variance predictions for each connectivity and t = 5, 6, 7, 8. (3) Compute the NPM for each prediction, and (4) derive the EVS for each sample i and t. (5) Fit a GED and construct confidence intervals based on the N EVS for each t. (6) Given a new sample, derive its EVS following (1)-(4), and (7) check the confidence intervals from (5) to determine heterogeneity.
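Steps (5)-(7) can be sketched with a method-of-moments Gumbel fit. The Gumbel is the shape-zero special case of the generalized extreme value family, so this simplifies the poster's GED fit; alpha = 0.05 and all function names are illustrative choices of ours.

```python
import math

EULER_GAMMA = 0.5772156649015329  # Euler-Mascheroni constant

def fit_gumbel(values):
    """Method-of-moments Gumbel fit: mean = loc + gamma*scale,
    variance = (pi^2 / 6) * scale^2."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    scale = math.sqrt(6.0 * var) / math.pi
    loc = mean - EULER_GAMMA * scale
    return loc, scale

def gumbel_quantile(p, loc, scale):
    """Inverse CDF of the Gumbel distribution."""
    return loc - scale * math.log(-math.log(p))

def outlier_thresholds(evs_per_t, alpha=0.05):
    """Upper (1 - alpha) bound per time point, fitted on the N training EVS."""
    return {t: gumbel_quantile(1.0 - alpha, *fit_gumbel(v))
            for t, v in evs_per_t.items()}

def is_outlier(new_evs, thresholds):
    # Step (7): outlier EVS at ANY time point => outlier subject.
    return any(new_evs[t] > thr for t, thr in thresholds.items())
```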

4. Outlier Detection: Cognitively Healthy (PiB−) vs. At-Risk (PiB+). Detected outliers: 9 of 100 samples in PiB− and 19 of 100 samples in PiB+.
Implication: Larger absolute fluctuations in DWI connectivity may be a good indicator of disease risk as measured by amyloid burden.

Research supported in part by NIH (R01AG040396, R01AG021155, R01AG027161, P50AG033514, R01AG059312, R01EB022883, R01AG062336), the Center for Predictive and Computational Phenotyping (U54AI117924), NSF CAREER Award (1252725), and a predoctoral fellowship to RRM via T32LM012413.