Transcript of "Bias Correction in Learned Generative Models using Likelihood-free Importance Weighting" (Aditya Grover, NeurIPS 2019)

Page 1:

Bias Correction in Learned Generative Models using Likelihood-free Importance Weighting

Aditya Grover, Stanford University

Joint work with Jiaming Song, Alekh Agarwal, Kenneth Tran, Ashish Kapoor, Eric Horvitz, Stefano Ermon

NeurIPS 2019

Page 2:

Transforming Science & Society

Hwang et al., 2018, Gómez-Bombarelli et al., 2016

Page 3:

What You See Is Not What You Always Get

Odena et al., 2016

Page 4:

Learning Generative Models

Given: samples $x_i \sim p_{\text{data}}$, $i = 1, 2, \ldots, n$, from a data distribution

Goal: choose a model family $\mathcal{M}$ and approximate the data distribution $p_{\text{data}}$ as closely as possible with some $p_\theta \in \mathcal{M}$:

$$\min_{p_\theta \in \mathcal{M}} d(p_{\text{data}}, p_\theta)$$

Page 5:

Challenges

• How to define the distance $d(\cdot)$ between distributions?
• Model mismatch: $p_{\text{data}} \notin \mathcal{M}$
• Optimization is imperfect
• Finite datasets: the empirical data distribution $\hat{p}_{\text{data}}$ is far from the true data distribution $p_{\text{data}}$

$$\min_{p_\theta \in \mathcal{M}} d(p_{\text{data}}, p_\theta)$$

$p_\theta \neq p_{\text{data}}$: generative models are biased w.r.t. $p_{\text{data}}$!

Page 6:

Evaluating Generative Models is Hard

• Density estimation
  § Not applicable for models with ill-defined/intractable likelihoods, e.g., GANs, VAEs
  § Not correlated with sample quality (Theis et al., 2015)
• Sample quality metrics, e.g., Inception Score (Salimans et al., 2016), FID (Heusel et al., 2017), KID (Binkowski et al., 2018)
• Downstream tasks, e.g., semi-supervised learning

Page 7:

Identifying Bias in Generative Modeling

• Let $f: \mathcal{X} \to \mathbb{R}$ be some real-valued function of interest
• We assume $f$ is unknown during training of the generative model
• Evidence of bias: $\mathbb{E}_{x \sim p_{\text{data}}}[f(x)] \neq \mathbb{E}_{x \sim p_\theta}[f(x)]$ (illustrated in the sketch below)
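To make this concrete, here is a minimal Python sketch (my own illustration, not from the talk): both distributions are hand-picked one-dimensional Gaussians, the model is slightly mis-specified, and the two Monte Carlo estimates of $\mathbb{E}[f(x)]$ visibly disagree.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: p_data = N(0, 1), learned model p_theta = N(0.3, 1.2).
# f(x) = x^2 is an arbitrary function of interest, unknown during training.
f = lambda x: x ** 2

x_data = rng.normal(0.0, 1.0, size=100_000)   # x ~ p_data
x_model = rng.normal(0.3, 1.2, size=100_000)  # x ~ p_theta

# Evidence of bias: the two Monte Carlo estimates of E[f(x)] disagree.
print(f(x_data).mean())   # ~1.00
print(f(x_model).mean())  # ~1.53 (= 0.3^2 + 1.2^2)
```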

Page 8:

Motivating Use Cases

• Model-based Off-Policy Evaluation
  § How do we safely evaluate a target policy given data from a different source policy?
  § Value estimates are an expectation w.r.t. the estimated generative dynamics model and the target policy
• Model-based Data Augmentation
  § Classifier trained on a mixture of real and generated data
  § Loss is augmented with an expectation w.r.t. generated data
• Fair and Sample-efficient Generation

$$\mathbb{E}_{x \sim p_{\text{data}}}[f(x)] \neq \mathbb{E}_{x \sim p_\theta}[f(x)]$$

Page 9:

Bias Mitigation

$$\mathbb{E}_{x \sim p_{\text{data}}}[f(x)] \neq \mathbb{E}_{x \sim p_\theta}[f(x)]$$

How to correct for bias due to model mismatch?

• Option 1: Train deeper models
  § Increases estimation error
  § Does not correct for distributional assumptions
• Option 2: Non-zero bias ≡ an instance of covariate shift. Can we use importance weighting?
  § Reweight samples $x \sim p_\theta$ by the density ratio
    $$w(x) := \frac{p_{\text{data}}(x)}{p_\theta(x)}$$
    so that $\mathbb{E}_{p_\theta}[w(x) f(x)] = \mathbb{E}_{p_{\text{data}}}[f(x)]$ (verified numerically below)
  § We don't know $p_{\text{data}}$
  § For many generative models, even $p_\theta$ is not known (e.g., VAEs, GANs)

Page 10:

Importance Weighting via Classification

Importance weights can be estimated via binary classification! Train a classifier to distinguish real ($Y = 1$) from generated ($Y = 0$) data.

• For a Bayes-optimal classifier $c^*: \mathcal{X} \times \mathcal{Y} \to [0, 1]$,
  $$w(x) = \frac{p_{\text{data}}(x)}{p_\theta(x)} = \frac{c^*(Y = 1 \mid x)}{c^*(Y = 0 \mid x)}$$
• Practical checklist
  ✓ Calibration
  ✓ Validation set
• Not the same as GAN training
  § Post-hoc bias correction
  § Don't throw away the discriminator! The new generative model is a function of both $w$ (classifier) and $p_\theta$ (sketch below)
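A minimal sketch of the classification route on the same toy Gaussians. This is an assumed setup, not the paper's exact architecture: logistic regression on quadratic features stands in for the Bayes-optimal classifier (the Gaussian log-density ratio is quadratic in $x$), and the calibration and validation steps from the checklist are omitted for brevity.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Real (Y=1) vs. generated (Y=0) samples from the toy densities.
x_real = rng.normal(0.0, 1.0, size=(50_000, 1))
x_gen = rng.normal(0.3, 1.2, size=(50_000, 1))
X = np.vstack([x_real, x_gen])
y = np.concatenate([np.ones(len(x_real)), np.zeros(len(x_gen))])

# Quadratic features make the Gaussian log-ratio learnable by a linear model.
feats = lambda x: np.hstack([x, x ** 2])
clf = LogisticRegression(max_iter=1000).fit(feats(X), y)

def lfiw_weights(x):
    """w(x) = c(Y=1|x) / c(Y=0|x), with the denominator clipped for stability."""
    p = clf.predict_proba(feats(x))[:, 1]
    return p / np.clip(1.0 - p, 1e-6, None)

x_new = rng.normal(0.3, 1.2, size=(10_000, 1))          # fresh model samples
print(np.mean(lfiw_weights(x_new) * x_new[:, 0] ** 2))  # ~1.0 again
```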

Page 11:

Synthetic Example

[Figure: synthetic example, shown for 100 samples and for 1000 samples]

Page 12:

Model-based Monte Carlo Evaluation

Goal: Evaluate $\mathbb{E}_{p_{\text{data}}}[f(x)]$ via $p_\theta$

• Default Monte Carlo estimator:
  $$\mathbb{E}_{p_{\text{data}}}[f(x)] \approx \frac{1}{n} \sum_i f(x_i), \quad x_i \sim p_\theta$$
• Likelihood-Free Importance Weighted (LFIW) estimator:
  $$\mathbb{E}_{p_{\text{data}}}[f(x)] \approx \frac{1}{n} \sum_i \hat{w}(x_i) f(x_i), \quad x_i \sim p_\theta$$
• Relative variance in $\hat{w}$ can be high. Self-normalized LFIW:
  $$\mathbb{E}_{p_{\text{data}}}[f(x)] \approx \sum_i \frac{\hat{w}(x_i)}{\sum_j \hat{w}(x_j)} f(x_i)$$

(all three estimators are sketched in code below)
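The three estimators in code, as a small sketch (function names are mine; `w` holds the classifier-based weight estimates $\hat{w}(x_i)$ from the previous slide):

```python
import numpy as np

def mc_estimate(f, x):
    """Default Monte Carlo estimator over model samples x ~ p_theta (biased)."""
    return np.mean(f(x))

def lfiw_estimate(f, x, w):
    """LFIW estimator: reweight each model sample by its estimated density ratio."""
    return np.mean(w * f(x))

def snlfiw_estimate(f, x, w):
    """Self-normalized LFIW: weights normalized to sum to one; trades a small
    bias for lower variance when the raw weights vary widely."""
    return np.sum((w / np.sum(w)) * f(x))
```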

Page 13:

Improved Sample Quality Metrics

Model       | Estimator | Inception Score (↑) | FID (↓)        | KID (↓)
Reference   |           | 11.09 ± 0.1263      | 5.20 ± 0.05    | 0.008 ± 0.0004
PixelCNN++  | Default   | 5.16 ± 0.0117       | 58.70 ± 0.0506 | 0.196 ± 0.0001
PixelCNN++  | LFIW      | 6.68 ± 0.0773       | 55.83 ± 0.9695 | 0.126 ± 0.0009
SNGAN       | Default   | 8.33 ± 0.0280       | 20.40 ± 0.0747 | 0.094 ± 0.0002
SNGAN       | LFIW      | 8.57 ± 0.0325       | 17.29 ± 0.0698 | 0.073 ± 0.0004

Standard error around the mean, computed over 10 runs. Dataset: CIFAR-10.

Page 14:

Importance Resampling Distribution

• Define the importance resampling distribution as
  $$p_{\theta,\phi}(x) \propto w_\phi(x)\, p_\theta(x)$$
• Normalization constant $Z_{\theta,\phi} = \mathbb{E}_{x \sim p_\theta}[w_\phi(x)]$
• Density estimation and sampling are intractable
• Particle-based approximation
  § Approximate the induced distribution with finite samples from $p_\theta$
  § Approximated via resampling methods, e.g., rejection sampling (Azadi et al., 2019), MCMC (Turner et al., 2019)

Page 15:

Sampling Importance Resampling

• Choose a finite sampling budget $T > 0$
• Draw a batch of $T$ points $x_1, x_2, \ldots, x_T$ from $p_\theta$ and estimate their importance weights $\hat{w}(x_i)$
• Define a categorical distribution $P(j) \propto \hat{w}(x_j)$
• Sample $j \sim P(j)$ and return $x_j$ (sketch below)

Example with $T = 3$: weights $\hat{w} = (0.6, 0.1, 0.7)$ normalize to $\approx (0.43, 0.07, 0.50)$, i.e.,

$P(j{=}1) \approx 0.43 \quad P(j{=}2) \approx 0.07 \quad P(j{=}3) = 0.50$
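A sketch of the SIR procedure (function and argument names are mine; `sampler` draws from $p_\theta$ and `weight_fn` stands in for the classifier-based weight estimator):

```python
import numpy as np

def sir_sample(sampler, weight_fn, budget, rng):
    """Sampling Importance Resampling with a finite budget T."""
    x = sampler(budget)       # x_1, ..., x_T ~ p_theta
    w = weight_fn(x)          # estimated importance weights w_hat(x_i)
    probs = w / w.sum()       # categorical distribution P(j) propto w_hat(x_j)
    j = rng.choice(budget, p=probs)
    return x[j]

# Toy usage reproducing the slide's example weights (0.6, 0.1, 0.7):
rng = np.random.default_rng(0)
sample = sir_sample(lambda t: rng.normal(size=t),
                    lambda x: np.array([0.6, 0.1, 0.7]),
                    budget=3, rng=rng)
```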

Page 16:

Are we guaranteed to do better?

• No
• When is $p_{\theta,\phi}$ a "better" fit than $p_\theta$?
• Better: KL divergence reduces, $\mathrm{KL}[p_{\text{data}}, p_{\theta,\phi}] \leq \mathrm{KL}[p_{\text{data}}, p_\theta]$
• Necessary and sufficient condition: $p_{\theta,\phi}$ is a better fit than $p_\theta$ iff
  $$\mathbb{E}_{x \sim p_{\text{data}}}[\log w_\phi(x)] \geq \log Z_{\theta,\phi}$$
• Necessary conditions:
  $$\mathbb{E}_{x \sim p_{\text{data}}}[\log w_\phi(x)] \geq \mathbb{E}_{x \sim p_\theta}[\log w_\phi(x)]$$
  $$\mathbb{E}_{x \sim p_{\text{data}}}[w_\phi(x)] \geq \mathbb{E}_{x \sim p_\theta}[w_\phi(x)]$$
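The necessary-and-sufficient condition can itself be estimated by Monte Carlo when held-out data is available. A sketch under that assumption (function name is mine; both arguments are arrays of $\log w_\phi$ values):

```python
import numpy as np

def resampling_helps(log_w_data, log_w_model):
    """Monte Carlo check of E_{p_data}[log w(x)] >= log Z_{theta,phi}.

    log_w_data:  log-weights at held-out samples x ~ p_data
    log_w_model: log-weights at samples x ~ p_theta
    """
    log_Z = np.log(np.mean(np.exp(log_w_model)))  # Z = E_{p_theta}[w(x)]
    return np.mean(log_w_data) >= log_Z
```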

Page 17:

Downstream Applications
• Data Augmentation
• Off-Policy Policy Evaluation

Page 18:

Data Augmentation

• Goal: Augment the training dataset for multi-class classification
• Dataset: Omniglot (1000+ classes, 20 examples/class)
• Procedure
  § Train a conditional generative model on Omniglot
  § Use generated data for training the downstream classifier (weighted loss sketched below)
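A sketch of how the importance-weighted augmentation loss might look in PyTorch. This is my own rendering under stated assumptions, not the paper's released code; `model`, `x_gen`, and `w_gen` are hypothetical names, with `w_gen` holding the LFIW weights of the generated batch.

```python
import torch
import torch.nn.functional as F

def augmented_loss(model, x_real, y_real, x_gen, y_gen, w_gen):
    """Real examples get unit weight; generated examples are reweighted
    by their estimated importance weights."""
    loss_real = F.cross_entropy(model(x_real), y_real)
    per_example = F.cross_entropy(model(x_gen), y_gen, reduction="none")
    loss_gen = (w_gen * per_example).mean()
    return loss_real + loss_gen
```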

Page 19:

Importance Weighted Data Augmentation

[Figure: for each of three Omniglot classes, real samples (random order) shown next to generated samples sorted by decreasing importance weight]

Page 20:

Importance Weighted Data Augmentation

Dataset                      | Accuracy
Real data only               | 0.6603 ± 0.0012
Generated data only          | 0.4431 ± 0.0054
Generated data + LFIW        | 0.4481 ± 0.0056
Real + generated data        | 0.6600 ± 0.0040
Real + generated data + LFIW | 0.6818 ± 0.0022

Standard error around the mean computed over 5 runs.

LFIW on the augmented data increases overall test accuracy!

Page 21:

Downstream Applications
• Data Augmentation
• Off-Policy Policy Evaluation

Page 22:

Off-Policy Policy Evaluation (OPE)

• Easy to obtain logged trajectory data $s_0, a_0, r_0, s_1, a_1, r_1, \ldots$
  $s_0 \sim \rho(s_0)$  [$\rho(\cdot)$ is the initial state distribution]
  $a_t \sim \pi_b(s_t)$  [$\pi_b(\cdot)$ is the behavioral policy]
  $r_t \sim R(s_t, a_t)$  [$R(\cdot)$ is the rewards model]
  $s_{t+1} \sim T(s_t, a_t)$  [$T(\cdot)$ is the transition dynamics model]
  Note: $\pi_b(\cdot)$, $T(\cdot)$ are unknown
• Goal: Evaluate the value $V^{\pi_e}$ of a target policy $\pi_e$

[Figure: two candidate treatments; logged data is available for Treatment 1 but not for Treatment 2]

Page 23:

Debiasing Model-based OPE

• Model-based approach
  § Estimate $T(\cdot)$ as $\hat{T}(\cdot)$
  § Generate target trajectories via $\hat{T}(\cdot)$ and $\pi_e$
  § Estimate $V^{\pi_e}$ by Monte Carlo
• Train a classifier to distinguish $(s, a, s')$ triplets from the logged data and the model's predictions
• Debiasing distributions over trajectories $\tau$ (sketch below)
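One plausible way this could look in code, as a hedged sketch only: the factorization of the trajectory weight into per-triplet density ratios is my assumption about the construction, and all names are hypothetical.

```python
import numpy as np

def debiased_value_estimate(rollouts, rewards, weight_fn, gamma=0.99):
    """LFIW-corrected Monte Carlo value estimate over model-based rollouts.

    rollouts:  list of trajectories, each a list of (s, a, s_next) triplets
               generated from the learned dynamics model and target policy
    rewards:   matching list of per-step reward sequences
    weight_fn: estimated density ratio for a single transition triplet
    """
    values = []
    for traj, rs in zip(rollouts, rewards):
        w = np.prod([weight_fn(t) for t in traj])            # trajectory weight
        ret = sum(gamma ** k * r for k, r in enumerate(rs))  # discounted return
        values.append(w * ret)
    return np.mean(values)
```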

Page 24:

Debiased Model-Based OPE

Environment         | HalfCheetah | Swimmer | Humanoid
Model-based         | 37.7        | 63.7    | 5753
Model-based w/ LFIW | 23.9        | 11      | 4798

Mean absolute error. Lower is better.

Page 25:

Summary

• Generative models are biased
• Likelihood-free importance weighting is a simple technique for bias mitigation that works well for many downstream applications
• Bias is a necessary evil for generalization. The key is to be able to control it!

• Future, Ongoing Work: Fair Generative Modeling via Weak Supervision. Aditya Grover*, Kristy Choi*, Rui Shu, Stefano Ermon. https://arxiv.org/abs/1910.12008