Enabling hyperparameter optimization in sequential autoencoders for spiking neural data
Mohammad Reza Keshtkaran, Chethan Pandarinath
Georgia Tech and Emory University · Email: mkeshtk@emory.edu

[Figure: Sample validation loss vs. true performance (R²); insets show held-out samples and held-out inferred rates (neurons × time × trials).]

[Figure: Decoding accuracy (R², X- and Y-velocity) vs. number of training trials (92 to 1836), and % R² improvement from HP optimization, for panels labeled "5% of trials used" and "16% of trials used", comparing Smoothing, SAE with fixed HPs, and HP-optimized SAE.]

Sequential autoencoder (SAE) for modeling neural dynamics

Pathological overfitting in SAEs

Denoising autoencoders do not address overfitting in SAEs

Two solutions to address pathological overfitting in SAEs: 1- Coordinated dropout (CD), 2- Sample validation

HP optimization trains accurate models with 5-10x less data

Experimental setup - decoding arm velocity

Two tasks: 1- Curved reaching (Dataset 1), 2- Random target (Dataset 2)

HP optimization performed using population based training (PBT) [3]
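The poster does not spell out the PBT loop itself; below is a minimal sketch of the exploit/explore cycle described in [3], under assumed worker objects (with `weights`, `hps`, and `score` attributes) and user-supplied `partial_train` and `evaluate` functions.

```python
import copy
import random

def population_based_training(population, partial_train, evaluate, generations=20):
    """Minimal PBT loop [3]: train workers in segments, then let the worst workers
    copy (exploit) the best workers and perturb (explore) their hyperparameters."""
    for _ in range(generations):
        for worker in population:
            partial_train(worker)              # train for a fixed number of steps
            worker.score = evaluate(worker)    # e.g. sample-validation likelihood
        population.sort(key=lambda w: w.score, reverse=True)
        n_replace = max(1, len(population) // 4)
        for loser, winner in zip(population[-n_replace:], population[:n_replace]):
            loser.weights = copy.deepcopy(winner.weights)            # exploit
            loser.hps = {k: v * random.choice([0.8, 1.25])           # explore
                         for k, v in winner.hps.items()}
    return population[0]
```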

[Figure: Coordinated dropout schematic. For each training example at training step n, a random subset of samples (individual spike-count entries) is kept at the SAE input ("train on these samples"; the rest are dropped), and the weight update uses gradients only from reconstruction of the complementary samples ("to reconstruct these samples"; gradients for the input samples are blocked).]

[Schematic: SAE applied to spiking data (Neuron 1 … Neuron n, over time and trials). An encoder maps observed spikes to a distribution (µ, σ) over the initial state g0; a generator produces factors f_t via W_fac and firing rates r_t via W_rate.]

[Figure: Experimental pipeline. Record neural activity → observed spikes → SAE → estimated spike rates → decode arm velocity with a linear decoder → estimated vs. true trajectory.]
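A minimal sketch of this decoding step, assuming SAE-inferred rates and measured velocities are available as arrays; ridge regression stands in for the linear decoder, which the poster does not specify further.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

def decode_velocity(rates_train, vel_train, rates_test, vel_test, alpha=1.0):
    # Linear decoder from SAE-inferred firing rates to hand velocity (x, y).
    # rates_*: (samples, neurons); vel_*: (samples, 2)
    decoder = Ridge(alpha=alpha).fit(rates_train, vel_train)
    vel_pred = decoder.predict(rates_test)
    return r2_score(vel_test, vel_pred, multioutput="raw_values")  # R² per axis
```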


Sample validation: evaluate the network by how well it can predict the rates for samples it has never seen during training or evaluation.
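A minimal sketch of this idea, assuming spikes and inferred rates are numpy arrays; the held-out samples are removed from the model's input and used only to score it (array and function names are illustrative).

```python
import numpy as np

def make_heldout_mask(shape, holdout_frac=0.2, seed=0):
    # Randomly mark individual (trial, time, neuron) samples as held out.
    rng = np.random.default_rng(seed)
    return rng.random(shape) < holdout_frac   # True = held-out sample

def sample_validation_loss(spikes, inferred_rates, heldout_mask, eps=1e-8):
    # Poisson negative log-likelihood evaluated only on held-out samples,
    # i.e. samples the network never received at its input.
    nll = inferred_rates - spikes * np.log(inferred_rates + eps)
    return nll[heldout_mask].mean()
```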

Simulation setup [2]

[Figure: Synthetic data generation and estimation of firing rates. A nonlinear dynamical system produces ground truth firing rates (neurons × time, across trials); Poisson sampling (noisy) yields synthetic spikes. The SAE infers firing rates from the synthetic spikes, which are compared against the ground truth rates (evaluate using R²).]
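A minimal sketch of such a data generator. The poster does not name the dynamical system; a Lorenz attractor is assumed here as a stand-in, with an exponential readout and Poisson sampling.

```python
import numpy as np

def synthetic_rates_and_spikes(n_neurons=30, n_steps=100, dt=0.01, seed=0):
    # Ground-truth rates from a nonlinear dynamical system (Lorenz attractor
    # assumed here), followed by Poisson sampling to produce noisy spikes.
    rng = np.random.default_rng(seed)
    x = np.array([1.0, 1.0, 1.0])
    latents = np.zeros((n_steps, 3))
    for t in range(n_steps):
        dx = np.array([10.0 * (x[1] - x[0]),
                       x[0] * (28.0 - x[2]) - x[1],
                       x[0] * x[1] - (8.0 / 3.0) * x[2]])
        x = x + dt * dx
        latents[t] = x
    latents = (latents - latents.mean(0)) / latents.std(0)
    readout = rng.normal(size=(3, n_neurons))
    rates = np.exp(0.3 * latents @ readout) * 5.0 * dt   # mean spike count per bin
    spikes = rng.poisson(rates)                          # noisy synthetic spikes
    return rates, spikes
```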

Simulation example, linear over-complete autoencoder:

low-d data: Q ∈ ℝ^(m×n); high-d observation: X ∈ ℝ^(o×n); latent variable: Z ∈ ℝ^(h×n), with m < o ≤ h

encoder: Z = W₁ᵀ X;  decoder: X′ = W₂ Z;  auto-encoder (reconstruction): X′ = W₂ W₁ᵀ X

For this linear network, CD is equivalent to preventing the diagonal weights of the combined map W = W₂ W₁ᵀ from being trained.

Data generation: Y = W_pᵀ Q;  observations: X = Y + N, where N ~ N(0, σ)

True loss: ||Y − X′||²;  Train loss: ||X − X′||²

[Figure: Train loss, true loss, and the noise lower bound vs. training epoch, for standard training and for training with coordinated dropout.]
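A runnable sketch of this simulation under assumed sizes (m=3, o=20), training the combined map W directly for simplicity. With standard training the train loss falls toward zero while the true loss stays at the noise level; holding the diagonal of W untrained (the CD-equivalent constraint) forces the model onto the shared structure.

```python
import numpy as np

def linear_ae_cd_demo(m=3, o=20, n=500, sigma=0.5, lr=0.01, epochs=2000, seed=0):
    rng = np.random.default_rng(seed)
    Q = rng.normal(size=(m, n))                  # low-d data
    Wp = rng.normal(size=(m, o))
    Y = Wp.T @ Q                                 # noiseless high-d signal
    X = Y + sigma * rng.normal(size=Y.shape)     # noisy observations
    for use_cd in (False, True):
        W = np.zeros((o, o))                     # combined autoencoder map, X' = W X
        for _ in range(epochs):
            grad = 2.0 * (W @ X - X) @ X.T / n
            if use_cd:
                np.fill_diagonal(grad, 0.0)      # CD-equivalent: diagonal never trained
            W -= lr * grad
        Xp = W @ X
        print(f"CD={use_cd}:  train loss {np.mean((X - Xp) ** 2):.3f}"
              f"  true loss {np.mean((Y - Xp) ** 2):.3f}")

linear_ae_cd_demo()
```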

[Diagram: LFADS architecture. Encoders map the observed spikes to an initial state and, via a controller, to inferred inputs; the decoder generates latent factors and inferred rates, and a reconstruction loss (Poisson) is computed against the observed spikes.]

Main hyperparameters: KL penalties, L2 penalties, dropout, learning rate.

Training maximizes the evidence lower bound, ELBO (L), composed of a reconstruction log-likelihood term and a KL penalty.
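The equations themselves appear only as images on the poster; a standard form of this objective, following [1, 2] (notation assumed: spikes x, rates r, initial state g0, inferred inputs u), is:

```latex
\mathcal{L} \;=\;
\underbrace{\mathbb{E}_{q(g_0,\,u \mid x)}\Big[\textstyle\sum_{t,c} \log \mathrm{Poisson}\big(x_{t,c} \mid r_{t,c}\big)\Big]}_{\text{reconstruction log-likelihood}}
\;-\;
\underbrace{\mathrm{KL}\big(q(g_0 \mid x)\,\|\,p(g_0)\big) + \textstyle\sum_{t}\mathrm{KL}\big(q(u_t \mid x)\,\|\,p(u_t)\big)}_{\text{KL penalty}}
```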

Acknowledgements: This work was supported by NSF NCS 1835364 and the DARPA Intelligent Neural Interfaces program.

References
[1] Pandarinath, Chethan, Daniel J. O'Shea, Jasmine Collins, et al. "Inferring Single-Trial Neural Population Dynamics Using Sequential Auto-Encoders." Nature Methods, 2018.
[2] Sussillo, David, Rafal Jozefowicz, L. F. Abbott, and Chethan Pandarinath. "LFADS - Latent Factor Analysis via Dynamical Systems." arXiv preprint arXiv:1608.06315 (2016).
[3] Jaderberg, Max, Valentin Dalibard, Simon Osindero, et al. "Population Based Training of Neural Networks." arXiv preprint arXiv:1711.09846 (2017).

Conclusions
• SAEs are prone to a particular type of overfitting that cannot be detected using the standard validation loss.
• The lack of a reliable validation metric has prevented HP optimization in SAEs.
• We developed two solutions, "coordinated dropout" and "sample validation", to address pathological overfitting in SAEs and enable HP optimization.
• HP optimization led to accurate models while using 5-10-fold less training data.

CD forces the network to model only the shared structure underlying the observations. At each training step, CD passes a random subset of samples to the input (by applying dropout), and then updates the network weights using only the gradients computed for reconstruction of the complementary subset of samples.
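A minimal sketch of one such training step, assuming a PyTorch model that maps binned spike counts (a float tensor) to positive inferred rates; the model, keep probability, and variable names are illustrative.

```python
import torch
import torch.nn.functional as F

def coordinated_dropout_step(model, spikes, keep_prob=0.7):
    # spikes: float tensor of binned counts, e.g. (batch, time, neurons).
    # Randomly choose which individual (time, neuron) samples are fed to the input.
    keep = (torch.rand_like(spikes) < keep_prob).float()
    rates = model(spikes * keep / keep_prob)      # dropped samples are zeroed at the input
    # Reconstruction loss (Poisson NLL) only on the complementary, dropped samples,
    # so no gradient flows from a sample's own input to its reconstruction.
    nll = F.poisson_nll_loss(rates, spikes, log_input=False, reduction="none")
    return (nll * (1.0 - keep)).sum() / (1.0 - keep).sum()
```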

We used synthetic data with ground truth firing rates to demonstrate pathological overfitting.

[Figure: Pathological overfitting. True performance (R²) vs. standard validation loss; models with low validation loss can still have poor true performance.]

[Figure: Denoising autoencoders do not address the problem. True performance (R²) vs. validation loss with input corruption by zero masking (0%, 10%, 30%, 50%) or salt-and-pepper noise (0%, 10%, 20%, 30%).]
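For reference, minimal sketches of the two input corruptions tested; their exact forms are assumed here, since the poster gives only the corruption names and rates.

```python
import numpy as np

def zero_mask(x, p, rng):
    # Zero-masking: a fraction p of input samples is set to zero.
    return x * (rng.random(x.shape) >= p)

def salt_and_pepper(x, p, rng):
    # Salt-and-pepper: a fraction p of samples is set to the minimum or
    # maximum observed value with equal probability.
    out = x.copy()
    hit = rng.random(x.shape) < p
    out[hit] = np.where(rng.random(x.shape)[hit] < 0.5, x.min(), x.max())
    return out
```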

SAE model: Latent Factor Analysis via Dynamical Systems (LFADS) [1]

[Figure: 200 models, random hyperparameters. True performance (R²) vs. standard validation loss on synthetic spikes with known true rates; overfit, good-fit, and underfit examples are marked, with arrows indicating better loss, better models, and the overfitting regime.]


[Figure: Motivation. Record neural activity → neural spikes (point process; neurons × time) → SAE (maximize Poisson likelihood) → inferred rates → decode behavior (true vs. estimated hand position). Decoding accuracy degrades with less training data, and typical data sizes in neuroscience are small.]

[Figure: Example rate for one neuron with CD off vs. CD on; decoded movements for Smoothing, SAE with fixed HPs, and HP-optimized SAE.]