Sampling-free Uncertainty Estimation in Gated Recurrent Units with Applications to Normative Modeling in Neuroimaging
Seong Jae Hwang, Ronak R. Mehta, Hyunwoo J. Kim, Sterling C. Johnson, Vikas Singh
MOTIVATION
1. Given a visually "good looking" sequence prediction, how can we tell that its trajectory is correct?
2. If it is, can we derive the degree of uncertainty of its prediction?
[Figure panels: SP-GRU — Input Sequence (t = 1 to 10), Ground Truth, Output Prediction (t = 11 to 20), Model Uncertainty Map]
Figure: Image sequence prediction with uncertainty. Given the first 10 frames of an input sequence (left), our model SP-GRU makes the Output Prediction and the pixel-level Model Uncertainty Map where bright regions indicate high uncertainty. SP-GRU estimates the uncertainty deterministically without sampling model parameters.
GOAL
Derive a recurrent neural network architecture capable of estimating uncertainty with the following properties:
1. Deterministically estimate uncertainties in a sampling-free manner (e.g., without Monte Carlo sampling)
2. Uncertainties of all intermediate neurons can be expressed in terms of a distribution
PRELIMINARIES
I Gated Recurrent Unit (GRU):
Reset Gate: r^t = σ(U_r x^t + W_r h^{t−1} + b_r)
Update Gate: z^t = σ(U_z x^t + W_z h^{t−1} + b_z)
State Candidate: h̃^t = tanh(U_h x^t + W_h (r^t ⊙ h^{t−1}) + b_h)
Cell State: h^t = (1 − z^t) ⊙ h̃^t + z^t ⊙ h^{t−1}
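For reference, one deterministic GRU step can be sketched in NumPy as follows (a minimal illustration of the equations above; the parameter names and sizes are ours):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, p):
    """One deterministic GRU step following the equations above."""
    r = sigmoid(p["Ur"] @ x_t + p["Wr"] @ h_prev + p["br"])             # reset gate
    z = sigmoid(p["Uz"] @ x_t + p["Wz"] @ h_prev + p["bz"])             # update gate
    h_cand = np.tanh(p["Uh"] @ x_t + p["Wh"] @ (r * h_prev) + p["bh"])  # state candidate
    return (1.0 - z) * h_cand + z * h_prev                              # cell state

# usage sketch: hidden size 3, input size 2, random weights
rng = np.random.default_rng(0)
p = {k: rng.standard_normal((3, 3) if k[0] == "W" else (3, 2))
     for k in ["Ur", "Wr", "Uz", "Wz", "Uh", "Wh"]}
p.update({k: np.zeros(3) for k in ["br", "bz", "bh"]})
h = gru_step(rng.standard_normal(2), np.zeros(3), p)
```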
I Exponential Families in Neural Networks: Let x ∈ X be a random variable with probability density/mass function (pdf/pmf) f_X. Then f_X is an exponential family distribution if
f_X(x | η) = h(x) exp(η^T T(x) − A(η))
with natural parameters η, base measure h(x), and sufficient statistics T(x). The log-partition function A(η) ensures that the distribution normalizes to 1.
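As a concrete example (ours, not from the poster), the Bernoulli distribution with parameter p fits this form:

```latex
f_X(x \mid p) = p^x (1-p)^{1-x}
              = \exp\!\Big( x \log\tfrac{p}{1-p} + \log(1-p) \Big),
```

so h(x) = 1, T(x) = x, η = log(p/(1−p)), and A(η) = log(1 + e^η) = −log(1 − p).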
[Figure: a single exponential family neuron — input a^l, weights W^l, output a^{l+1} ∼ EXPFAM(g^l(W^l a^l))]
Figure: A single exponential family neuron. Weights W^l are learned, and the output of a neuron is a sample generated from the exponential family defined a priori and by the natural parameters g^l(W^l a^l).
MOMENT MATCHING
I Linear Moment Matching (LMM): compute (1) the mean a_m following the standard linearity of expectations of random variables and (2) the variance a_s:
o_m = W_m a_m + b_m,   o_s = W_s a_s + b_s + (W_m ⊙ W_m) a_s + W_s (a_m ⊙ a_m)   (1)
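Eq. (1) is exact for element-wise independent weights, biases, and inputs, which can be checked against Monte Carlo sampling. The sketch below is our own illustration (Gaussian distributions and all names are our assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)
n_out, n_in = 4, 3
Wm = rng.standard_normal((n_out, n_in)); Ws = rng.uniform(0.1, 0.5, (n_out, n_in))
bm = rng.standard_normal(n_out);         bs = rng.uniform(0.1, 0.5, n_out)
am = rng.standard_normal(n_in);          a_s = rng.uniform(0.1, 0.5, n_in)

# Linear moment matching, Eq. (1)
om = Wm @ am + bm
o_s = Ws @ a_s + bs + (Wm * Wm) @ a_s + Ws @ (am * am)

# Monte Carlo check: sample W, b, a independently and propagate
S = 200_000
W = rng.standard_normal((S, n_out, n_in)) * np.sqrt(Ws) + Wm
b = rng.standard_normal((S, n_out)) * np.sqrt(bs) + bm
a = rng.standard_normal((S, n_in)) * np.sqrt(a_s) + am
o = np.einsum("sij,sj->si", W, a) + b
print(np.abs(o.mean(0) - om).max(), np.abs(o.var(0) - o_s).max())  # both small
```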
I Nonlinear Moment Matching (NMM): Using the fact that σ(x) ≈ Φ(ζx), where Φ(·) is the probit function and ζ = √(π/8) is a constant, approximate the sigmoid activations for a_m and a_s:
a_m ≈ σ_m(o_m, o_s) = σ( o_m / (1 + ζ² o_s)^{1/2} ),
a_s ≈ σ_s(o_m, o_s) = σ( ν(o_m + ω) / (1 + ζ² ν² o_s)^{1/2} ) − a_m²   (2)
where ν = 4 − 2√2 and ω = −log(√2 + 1). The hyperbolic tangent can be derived from tanh(x) = 2σ(2x) − 1.
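Eq. (2) can likewise be validated against Monte Carlo estimates of the sigmoid's mean and variance under a Gaussian pre-activation (our own sketch; the function names are ours):

```python
import numpy as np

zeta2 = np.pi / 8.0                       # ζ² with ζ = sqrt(π/8)
nu = 4.0 - 2.0 * np.sqrt(2.0)             # ν = 4 − 2√2
omega = -np.log(np.sqrt(2.0) + 1.0)       # ω = −log(√2 + 1)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigma_m(om, os):                       # Eq. (2), mean
    return sigmoid(om / np.sqrt(1.0 + zeta2 * os))

def sigma_s(om, os):                       # Eq. (2), variance
    return sigmoid(nu * (om + omega) / np.sqrt(1.0 + zeta2 * nu**2 * os)) \
        - sigma_m(om, os) ** 2

# Monte Carlo check for a Gaussian pre-activation x ~ N(0.5, 1.0)
rng = np.random.default_rng(0)
x = rng.normal(0.5, 1.0, 1_000_000)
y = sigmoid(x)
print(abs(y.mean() - sigma_m(0.5, 1.0)))   # should be small
print(abs(y.var() - sigma_s(0.5, 1.0)))    # should be small
```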
[Figure: moment-matching diagram — inputs (a^{l−1}_m, a^{l−1}_s) pass through LMM to (o^l_m, o^l_s), then through NMM to (a^l_m, a^l_s)]
Figure: Linear Moment Matching (LMM) and Nonlinear Moment Matching (NMM) are performed at the weight/bias sums and activations respectively.
SAMPLING-FREE PROBABILISTIC GRU (SP-GRU)
[Figure: SP-GRU cell diagram with mean and variance pathways for x^t, r^t, z^t, h̃^t, and h^t]
Figure: SP-GRU cell structure. Solid lines/boxes and red dotted lines/boxes correspond to operations and variables for mean m and variance s respectively. Circles are element-wise operators.
I Reset Gate:
  Linear: o^t_{r,m} = U_{r,m} x^t_m + W_{r,m} h^{t−1}_m + b_{r,m}
          o^t_{r,s} = U_{r,s} x^t_s + W_{r,s} h^{t−1}_s + b_{r,s} + [U_{r,m}]² x^t_s + U_{r,s} [x^t_m]² + [W_{r,m}]² h^{t−1}_s + W_{r,s} [h^{t−1}_m]²
  Nonlinear: r^t_m = σ_m(o^t_{r,m}, o^t_{r,s}),   r^t_s = σ_s(o^t_{r,m}, o^t_{r,s})
I Update Gate:
  Linear: o^t_{z,m} = U_{z,m} x^t_m + W_{z,m} h^{t−1}_m + b_{z,m}
          o^t_{z,s} = U_{z,s} x^t_s + W_{z,s} h^{t−1}_s + b_{z,s} + [U_{z,m}]² x^t_s + U_{z,s} [x^t_m]² + [W_{z,m}]² h^{t−1}_s + W_{z,s} [h^{t−1}_m]²
  Nonlinear: z^t_m = σ_m(o^t_{z,m}, o^t_{z,s}),   z^t_s = σ_s(o^t_{z,m}, o^t_{z,s})
I State Candidate:
  Linear: o^t_{h,m} = U_{h,m} x^t_m + W_{h,m} h^{t−1}_m + b_{h,m}
          o^t_{h,s} = U_{h,s} x^t_s + W_{h,s} h^{t−1}_s + b_{h,s} + [U_{h,m}]² x^t_s + U_{h,s} [x^t_m]² + [W_{h,m}]² h^{t−1}_s + W_{h,s} [h^{t−1}_m]²
  Nonlinear: h̃^t_m = tanh_m(o^t_{h,m}, o^t_{h,s}),   h̃^t_s = tanh_s(o^t_{h,m}, o^t_{h,s})
I Cell State (nonlinear transformation not needed):
  h^t_m = (1 − z^t_m) ⊙ h̃^t_m + z^t_m ⊙ h^{t−1}_m
  h^t_s = [1 − z^t_m]² ⊙ h̃^t_s + [z^t_m]² ⊙ h^{t−1}_s
Table: SP-GRU operations in mean and variance. ⊙ and [A]² denote the Hadamard product and A ⊙ A of a matrix/vector A respectively. Note the Cell State does not involve nonlinear operations.
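One SP-GRU step can be sketched by composing LMM and NMM as in the table. The code below is our own minimal illustration (all names and sizes are ours; the reset-gate pathway and its product moment matching are omitted for brevity):

```python
import numpy as np

zeta2 = np.pi / 8.0
nu = 4.0 - 2.0 * np.sqrt(2.0)
omega = -np.log(np.sqrt(2.0) + 1.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigma_m(om, os):
    return sigmoid(om / np.sqrt(1.0 + zeta2 * os))

def sigma_s(om, os):
    return sigmoid(nu * (om + omega) / np.sqrt(1.0 + zeta2 * nu**2 * os)) \
        - sigma_m(om, os) ** 2

def tanh_m(om, os):
    return 2.0 * sigma_m(2.0 * om, 4.0 * os) - 1.0   # tanh(x) = 2σ(2x) − 1

def tanh_s(om, os):
    return 4.0 * sigma_s(2.0 * om, 4.0 * os)          # Var(2Y − 1) = 4 Var(Y)

def lmm(U, W, b, xm, xs, hm, hs):
    """Linear transformation column of the table for o = U x + W h + b.
    Each parameter is a (mean, variance) pair."""
    (Um, Us), (Wm, Ws), (bm, bs) = U, W, b
    om = Um @ xm + Wm @ hm + bm
    os = (Us @ xs + Ws @ hs + bs
          + (Um * Um) @ xs + Us @ (xm * xm)
          + (Wm * Wm) @ hs + Ws @ (hm * hm))
    return om, os

def sp_gru_step(xm, xs, hm, hs, P):
    # update gate
    ozm, ozs = lmm(P["Uz"], P["Wz"], P["bz"], xm, xs, hm, hs)
    zm, zs = sigma_m(ozm, ozs), sigma_s(ozm, ozs)
    # state candidate (the reset-gate product would enter here via an
    # additional product moment-matching step; omitted in this sketch)
    ohm, ohs = lmm(P["Uh"], P["Wh"], P["bh"], xm, xs, hm, hs)
    cm, cs = tanh_m(ohm, ohs), tanh_s(ohm, ohs)
    # cell state: purely linear mixing, no nonlinear matching needed
    new_hm = (1.0 - zm) * cm + zm * hm
    new_hs = (1.0 - zm) ** 2 * cs + zm ** 2 * hs
    return new_hm, new_hs

# usage sketch: hidden size 4, input size 3, random small-variance parameters
rng = np.random.default_rng(0)
def pair(shape):
    return rng.standard_normal(shape) * 0.3, np.full(shape, 0.01)
P = {"Uz": pair((4, 3)), "Wz": pair((4, 4)), "bz": pair(4),
     "Uh": pair((4, 3)), "Wh": pair((4, 4)), "bh": pair(4)}
hm, hs = np.zeros(4), np.zeros(4)
for _ in range(5):
    hm, hs = sp_gru_step(rng.standard_normal(3), np.full(3, 0.01), hm, hs, P)
```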
EXPERIMENT 1: MOVING MNIST
Figure: (a) Angle deviation trajectories. (b) Speed deviation trajectories.
[Figure rows: angle deviations θ ∈ {20°, 25°, 30°, 35°}; speed deviations v ∈ {5.0%, 5.5%, 6.0%, 6.5%}; noise levels b ∈ {0.0, 0.2, 0.4, 0.6}; columns: Ground Truth/Input, Prediction, Uncertainty]
Figure: Predictions and uncertainties (frames 11, 15, and 20) from testing varying deviations from trained trajectories (first of four rows, blue). Top: angle. Middle: speed. Bottom: pixel-level noise. Right: the average sum of per-frame pixel-level variance using SP-GRU and MC-LSTM.
[Figure panels: "2 Moving Digits — Prediction" (left) and "3 Moving Digits (Out of Domain) — Prediction" (right)]
Figure: SP-GRU predictor results. Left 3 rows: 2 moving digits (top: ground truth, middle: mean prediction, bottom: uncertainty estimate). Right 3 rows: 3 moving digits, which are out of domain (i.e., not seen in training).
EXPERIMENT 2: NORMATIVE MODELING IN NEUROIMAGING
1. Brain Connectivity Sequence Sample Generation
[Figure: original subjects (unordered, PiB positive and negative) are ordered and binned by RAVLT progression to produce N connectivity sequence samples, separately for PiB− and PiB+]
Figure: The preprocessing procedure used to generate sample data for SP-GRU. Left: Global connectivity sequence samples. Right: PiB+/− connectivity sequence samples.
2. Normative Probability Map (NPM) (Marquand et al., 2016): For each subject i, the true response at time j for connectivity k is given by y_ijk, with a bin-level variance of σ_njk. The SP-GRU predicts a mean response ŷ_ijk and variance σ_ijk.
Normative Probability Map (NPM): z_ijk = (y_ijk − ŷ_ijk) / √(σ²_ijk + σ²_njk)
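The NPM is a standardized deviation score, which can be sketched directly (our own illustration with synthetic data; variable names are ours):

```python
import numpy as np

def npm(y_true, y_pred_mean, y_pred_var, bin_var):
    """Normative probability map z-score: deviation of the true response
    from the predicted mean, scaled by the combined predictive and
    bin-level uncertainty (variances passed directly)."""
    return (y_true - y_pred_mean) / np.sqrt(y_pred_var + bin_var)

# usage sketch (synthetic): one subject, 1761 connectivities at one time point
rng = np.random.default_rng(0)
y_true = rng.normal(0.5, 0.1, 1761)
z = npm(y_true, np.full(1761, 0.5), np.full(1761, 0.1**2), np.zeros(1761))
# if the model is well calibrated, z is approximately standard normal
```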
3. Normative Modeling in Neuroimaging: Pipeline
(1) Test Sample Inputs: for each sample i = 1, …, N and each t = 1, 2, 3, 4, a vector of 1761 connectivities.
(2) Predictions with Uncertainties: given t = 1, 2, 3, 4, the trained SP-GRU predicts t = 5, 6, 7, 8; for each i and t, 1761 means and variances.
(3) Sequential Normative Probability Maps: for each i and t = 5, …, 8, compute 1761 NPMs.
(4) Sequential Extreme Value Statistics: reduce the 1761 NPMs of each i and t to 1 EVS, a robust summary of the NPMs (mean of the top 5%).
(5) Extreme Value Distributions: for each t, (i) construct a histogram of the N EVS, (ii) fit a generalized extreme value (GEV) distribution, and (iii) derive confidence intervals.
(6) Sequential EVS of a New Subject N′: follow steps (1)-(4) on the new subject.
(7) Outlier Detection: an outlier EVS in at least one t ⇒ outlier subject.
Figure: Normative modeling pipeline for preclinical AD. (1) Given a set of test inputs (t = 1, 2, 3, 4), (2) use the pretrained SP-GRU to make mean and variance predictions for each connectivity and t = 5, 6, 7, 8. (3) Compute the NPM for each prediction, and (4) derive the EVS for each sample i and t. (5) Fit a GEV distribution and construct confidence intervals based on the N EVS for each t. (6) Given a new sample, derive its EVS following (1)-(4), and (7) check it against the confidence intervals from (5) to determine heterogeneity.
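Steps (4)-(7) can be sketched with SciPy's GEV implementation. This is our own illustration on synthetic z-scores, not the authors' code; the 95% cutoff and all names are our assumptions:

```python
import numpy as np
from scipy import stats

def evs(npm_z, top_frac=0.05):
    """Extreme value statistic: mean of the top 5% absolute NPM z-scores."""
    a = np.sort(np.abs(npm_z))[::-1]
    k = max(1, int(round(top_frac * a.size)))
    return a[:k].mean()

# N = 100 normative subjects, 1761 NPMs each, at one time point (synthetic)
rng = np.random.default_rng(0)
normative = np.array([evs(rng.standard_normal(1761)) for _ in range(100)])

# fit a generalized extreme value (GEV) distribution, take a 95% upper bound
c, loc, scale = stats.genextreme.fit(normative)
upper = stats.genextreme.ppf(0.95, c, loc=loc, scale=scale)

# a new subject is flagged at this t if its EVS exceeds the bound;
# flagging in at least one t marks the subject as an outlier
new_evs = evs(rng.standard_normal(1761) * 1.5)   # inflated deviations
print(new_evs > upper)
```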
4. Outlier Detection: Cognitively Healthy (PiB−) vs. At-Risk (PiB+). Detected outliers: 9 of 100 samples in PiB− and 19 of 100 samples in PiB+.
Implication: Larger absolute fluctuations in DWI connectivity may be a good indicator of disease risk as measured by amyloid burden.
Research supported in part by NIH (R01AG040396, R01AG021155, R01AG027161, P50AG033514, R01AG059312, R01EB022883, R01AG062336), the Center for Predictive and Computational Phenotyping (U54AI117924), NSF CAREER Award (1252725), and a predoctoral fellowship to RRM via T32LM012413.