MULTI-ADVERSARIAL VARIATIONAL AUTO-ENCODER NETWORKS (web.cs.ucla.edu/~aimran/icmla2019_presentation.pdf), 12/17/2019

Transcript
Page 1

MULTI-ADVERSARIAL VARIATIONAL AUTO-ENCODER NETWORKS

Abdullah-Al-Zubaer Imran and Demetri Terzopoulos

UCLA Computer Graphics & Vision Laboratory

Page 2

Why Generative Modeling?

• Limited training data

• Leveraging unlabeled examples

• Unsupervised learning

• Semi-supervised learning

• Latent codes help in understanding the models

12/17/2019 Multi-Adversarial Variational Autoencoder Networks 2

Page 3

Deep Generative Models: VAE

[Kingma et al. 2013]

• Likelihood maximization

• Encoder: variational inference

• Decoder: sample generation

• Efficient variational inference

• Blurry samples

• Re-parameterization: $z = \mu + \sigma \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$

• Losses

  • Reconstruction: $\mathbb{E}_{q(z \mid x)}[\log p(x \mid z)]$

  • Regularization: $-\mathrm{KL}(q(z \mid x) \,\|\, p(z))$
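The VAE bullets above can be made concrete with a small NumPy sketch: the re-parameterization trick, the closed-form KL term for a diagonal Gaussian q(z|x) against p(z) = N(0, I), and a single-sample estimate of the objective (reconstruction minus KL). The Gaussian decoder with fixed unit variance and all function names are assumptions for illustration; the slide does not specify the decoder likelihood.

```python
import numpy as np


def reparameterize(mu, log_var, rng):
    """z = mu + sigma * eps with eps ~ N(0, I); all randomness is isolated in eps."""
    sigma = np.exp(0.5 * log_var)
    eps = rng.standard_normal(mu.shape)
    return mu + sigma * eps


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian q, summed over latent dims."""
    return 0.5 * np.sum(mu ** 2 + np.exp(log_var) - 1.0 - log_var, axis=-1)


def gaussian_log_likelihood(x, x_mean, var=1.0):
    """log p(x|z) under an ASSUMED Gaussian decoder with fixed variance `var`."""
    d = x.shape[-1]
    sq = np.sum((x - x_mean) ** 2, axis=-1)
    return -0.5 * (sq / var + d * np.log(2.0 * np.pi * var))


def elbo(x, x_mean, mu, log_var):
    """Single-sample objective to maximize: reconstruction term minus KL regularizer."""
    return gaussian_log_likelihood(x, x_mean) - kl_to_standard_normal(mu, log_var)
```

With `rng = np.random.default_rng(0)`, an encoder output `mu, log_var` of shape (batch, latent_dim) gives `z = reparameterize(mu, log_var, rng)` of the same shape; gradients can flow through `mu` and `log_var` because the noise is sampled separately.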

Page 4

Deep Generative Models: GAN

[Goodfellow et al. 2014, Radford et al. 2015]

• Minimax game

• Generator maps latent variables to data samples

• Discriminator distinguishes generated from real samples

• Sharp generated images

• Unstable and difficult to optimize

• Losses

$\max_D V(D) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$

$\min_G V(G) = \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$
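A minimal NumPy sketch of the two GAN objectives, written as quantities to minimize and computed from discriminator outputs D(x) and D(G(z)) in (0, 1). The non-saturating generator loss is a common practical substitute (also suggested by Goodfellow et al.), not something stated on this slide; the function names are hypothetical.

```python
import numpy as np

EPS = 1e-12  # keeps log() finite when D is exactly 0 or 1


def discriminator_loss(d_real, d_fake):
    """Negation of max_D V(D): -(E[log D(x)] + E[log(1 - D(G(z)))])."""
    return -(np.mean(np.log(d_real + EPS)) + np.mean(np.log(1.0 - d_fake + EPS)))


def generator_loss_minimax(d_fake):
    """min_G V(G) as written on the slide: E[log(1 - D(G(z)))]."""
    return np.mean(np.log(1.0 - d_fake + EPS))


def generator_loss_nonsaturating(d_fake):
    """Common substitute with stronger early gradients: -E[log D(G(z))]."""
    return -np.mean(np.log(d_fake + EPS))
```

When the generator matches the data distribution, the optimal discriminator outputs 1/2 everywhere and `discriminator_loss` equals 2 log 2 (about 1.386).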

Page 5

Deep Generative Models: PixelRNN

[van den Oord et al. 2016]

• Autoregressive model

• Simple and stable training process

• Inefficient sampling

• Assigns a probability to every pixel, conditioned on all previous pixels

• Softmax loss

$p(x) = \prod_{i=1}^{n^2} p(x_i \mid x_1, \dots, x_{i-1})$

$p(x_i \mid x_{<i}) = p(x_{i,R} \mid x_{<i})\, p(x_{i,G} \mid x_{<i}, x_{i,R})\, p(x_{i,B} \mid x_{<i}, x_{i,R}, x_{i,G})$
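The autoregressive factorization turns the image log-likelihood into a sum of per-step terms, and the softmax loss is its negation. A NumPy sketch, assuming the conditional logits were already produced by some causal network (not shown) and that each sub-pixel intensity is modeled with a 256-way softmax; function names are hypothetical.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def autoregressive_log_likelihood(logits, values):
    """log p(x) = sum_i log p(x_i | x_<i) with a 256-way softmax per step.

    values: flat int array of sub-pixel intensities in raster order with channels
            innermost (R, G, B per pixel), e.g. image.reshape(-1) for shape (H, W, 3).
    logits: (len(values), 256); row i is ASSUMED to come from a causal model
            that conditioned only on values[:i].
    """
    log_probs = log_softmax(logits)
    return float(log_probs[np.arange(len(values)), values].sum())
```

An RGB image of n × n pixels contributes 3n² softmax terms, one per channel, in the R, G, B order of the second equation.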

Page 6

Aim

• Improving deep generative models

• Evaluation measures

(Figure panels: VAE-GAN [Larsen et al. 2016] and PixelGAN Autoencoder [Makhzani et al. 2017])

Primary Aim: efficient and stable generative modeling for medical image analysis

✓ Combining generative models

✓ High quality image generation

✓ Learning from limited labeled data

Page 7

Proposed: MAVENs

• Highlights

  • Ensemble of multiple discriminators in a VAE-GAN

  • Joint image generation and classification

• Motivation

  • Instability in generative models

  • Mode-collapsed generation

  • Poor image quality in VAEs

  • Small labeled datasets

• Objective

  • Improve samples and semi-supervised classification

  • Unified generative model

  • Variational inference with adversarial learning

(Figure: basic comparisons of MAVEN with GAN, VAE, and VAE-GAN, including a panel showing GAN mode-collapsed generation)

Page 8

MAVENs: Architecture

Page 9

MAVENs: Objectives

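Discriminator Loss

The discriminator is an (n + 1)-way classifier: classes 1 to n are the real class labels and class n + 1 denotes "fake".

Supervised:

$L_D^{\text{supervised}} = -\mathbb{E}_{x, y \sim p_{\text{data}}} \log[p(y = i \mid x, i < n + 1)]$

Unsupervised:

$L_D^{\text{real}} = -\mathbb{E}_{x \sim p_{\text{data}}} \log[1 - p(y = n + 1 \mid x)]$

$L_D^{\text{fake\_G}} = -\mathbb{E}_{\hat{x} \sim G} \log[p(y = n + 1 \mid \hat{x})]$

$L_D^{\text{fake\_E}} = -\mathbb{E}_{\tilde{x} \sim G} \log[p(y = n + 1 \mid \tilde{x})]$

Generator Loss

$L_G^{\text{fake\_G}} = -\mathbb{E}_{\hat{x} \sim G} \log[1 - p(y = n + 1 \mid \hat{x})]$

$L_G^{\text{fake\_E}} = -\mathbb{E}_{\tilde{x} \sim G} \log[1 - p(y = n + 1 \mid \tilde{x})]$

$L_G^{\text{feature}} = \big\| \mathbb{E}_{x \sim p_{\text{data}}} f(x) - \mathbb{E}_{\hat{x} \sim G} f(\hat{x}) \big\|_2^2$

Encoder Loss

$L_E^{\text{feature}} = \big\| \mathbb{E}_{x \sim p_{\text{data}}} f(x) - \mathbb{E}_{\tilde{x} \sim G} f(\tilde{x}) \big\|_2^2$

$L_E^{\text{KL}} = -\mathbb{E}_{q_\lambda(z \mid x)} \log \frac{p(z)}{q_\lambda(z \mid x)}$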
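The objectives above follow the semi-supervised (n + 1)-class GAN pattern. The NumPy sketch below evaluates them from raw discriminator logits under stated assumptions: the last logit column is the fake class n + 1; x̂ are samples G(z) from the prior and x̃ are samples decoded from encoder outputs; f(·) is an intermediate discriminator feature layer; expectations are batch means. Function names are hypothetical and this is not the authors' code.

```python
import numpy as np

EPS = 1e-12  # keeps log() finite when a probability underflows to 0


def softmax(logits):
    """Row-wise softmax with max-subtraction for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def discriminator_losses(logits_lab, labels, logits_real, logits_gen, logits_rec):
    """Return (L_sup, L_real, L_fake_G, L_fake_E); logits are (batch, n + 1)."""
    # Conditioning on i < n + 1 renormalizes over the n real classes only.
    p_real_classes = softmax(logits_lab[:, :-1])
    l_sup = -np.mean(np.log(p_real_classes[np.arange(len(labels)), labels] + EPS))
    # Unlabeled real data should not be assigned to the fake class.
    l_real = -np.mean(np.log(1.0 - softmax(logits_real)[:, -1] + EPS))
    # Prior samples (x_hat) and encoder-path samples (x_tilde) should be called fake.
    l_fake_g = -np.mean(np.log(softmax(logits_gen)[:, -1] + EPS))
    l_fake_e = -np.mean(np.log(softmax(logits_rec)[:, -1] + EPS))
    return l_sup, l_real, l_fake_g, l_fake_e


def generator_losses(logits_gen, logits_rec, feat_real, feat_gen, feat_rec):
    """Return (L_G_fake_G, L_G_fake_E, L_G_feature, L_E_feature)."""
    # The generator wants its samples NOT to land in the fake class.
    l_fake_g = -np.mean(np.log(1.0 - softmax(logits_gen)[:, -1] + EPS))
    l_fake_e = -np.mean(np.log(1.0 - softmax(logits_rec)[:, -1] + EPS))
    # Feature matching: squared L2 distance between batch-mean features.
    mean_real = feat_real.mean(axis=0)
    l_feat_g = float(np.sum((mean_real - feat_gen.mean(axis=0)) ** 2))
    l_feat_e = float(np.sum((mean_real - feat_rec.mean(axis=0)) ** 2))
    return l_fake_g, l_fake_e, l_feat_g, l_feat_e
```

The encoder's KL term has the same closed form as in a plain VAE when q_λ(z|x) is a diagonal Gaussian and p(z) is a standard normal.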

Page 10

MAVENs: Implementation Details

• Datasets

  • SVHN (32 × 32 × 3) [street-view house-number digits]

  • CIFAR10 (32 × 32 × 3) [natural images]

  • Chest X-ray (128 × 128 × 1) [normal, bacterial pneumonia, and viral pneumonia]

• Baselines: DC-GAN and VAE-GAN

• MAVENs with 2, 3, and 5 discriminators

  • Discriminator feedback combined by mean or by random selection

• Labels used for only 10% of the training data
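A minimal sketch of how the two feedback options might combine the generator's adversarial losses from K discriminators. The function name, the per-call uniform choice in "rand" mode, and the idea of aggregating scalar losses (rather than, say, gradients) are assumptions; the slides only name the two options.

```python
import numpy as np


def aggregate_feedback(losses, mode, rng):
    """Combine one adversarial loss per discriminator into a single training signal.

    mode="mean": average over all K discriminators.
    mode="rand": use one discriminator chosen uniformly at random on each call.
    """
    losses = np.asarray(losses, dtype=float)
    if mode == "mean":
        return float(losses.mean())
    if mode == "rand":
        return float(losses[rng.integers(len(losses))])
    raise ValueError(f"unknown feedback mode: {mode!r}")
```

For MAVEN-mean3D, for example, `losses` would hold three values, one per discriminator, and the generator would be updated on their average.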


Page 11

MAVENs: Evaluations

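• Image quality

  • Fréchet Inception Distance (FID)

    • Activations from the pool3 layer of the Inception-v3 model

$\mathrm{FID} = \|\mu_{\text{data}} - \mu_{\text{fake}}\|_2^2 + \mathrm{Tr}\big(\Sigma_{\text{data}} + \Sigma_{\text{fake}} - 2(\Sigma_{\text{data}}\Sigma_{\text{fake}})^{1/2}\big)$

  • Descriptive Distribution Distance (DDD)

    • Compares the first four moments of the two distributions

$\mathrm{DDD} = \sum_{i=1}^{4} -\log w_i \, \lvert \mu_i^{\text{data}} - \mu_i^{\text{fake}} \rvert$

• Classification

  • Overall accuracy

  • Class-wise F1 score

$F_1 = \frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}$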
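A NumPy sketch of the three evaluation measures. FID is computed from the means and covariances of feature activations; extracting Inception-v3 pool3 features is assumed to happen elsewhere. The trace of (Σ_data Σ_fake)^{1/2} is taken as the sum of square roots of the eigenvalues of the product, which is equivalent to a matrix square root for PSD covariances. The slides do not define the DDD weights w_i or exactly which moments are compared, so the moment choice (mean, variance, skewness, kurtosis) and the `weights` argument are assumptions. All function names are hypothetical.

```python
import numpy as np


def fid(mu_data, sigma_data, mu_fake, sigma_fake):
    """||mu_d - mu_f||^2 + Tr(S_d + S_f - 2 (S_d S_f)^{1/2})."""
    diff = mu_data - mu_fake
    # Eigenvalues of S_d @ S_f are real and >= 0 for PSD covariances, so the
    # trace of the matrix square root is the sum of their square roots.
    eig = np.linalg.eigvals(sigma_data @ sigma_fake)
    tr_sqrt = np.sum(np.sqrt(np.clip(eig.real, 0.0, None)))
    return float(diff @ diff + np.trace(sigma_data) + np.trace(sigma_fake) - 2.0 * tr_sqrt)


def fid_from_activations(act_data, act_fake):
    """act_*: (num_samples, feature_dim) arrays, e.g. Inception-v3 pool3 features."""
    return fid(act_data.mean(axis=0), np.cov(act_data, rowvar=False),
               act_fake.mean(axis=0), np.cov(act_fake, rowvar=False))


def ddd(data, fake, weights):
    """ASSUMED DDD form: sum_i -log(w_i) * |m_i(data) - m_i(fake)|, i = 1..4.

    Moments are mean, variance, skewness and kurtosis of the flattened values;
    these choices and the weights (0 < w_i < 1) are assumptions.
    """
    def moments(a):
        a = np.asarray(a, dtype=float).ravel()
        m, s = a.mean(), a.std()
        z = (a - m) / s
        return np.array([m, s ** 2, np.mean(z ** 3), np.mean(z ** 4)])

    w = np.asarray(weights, dtype=float)
    return float(np.sum(-np.log(w) * np.abs(moments(data) - moments(fake))))


def classwise_f1(y_true, y_pred, num_classes):
    """Per-class F1 = 2PR / (P + R); a class with P + R = 0 scores 0."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    scores = []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        p = tp / (tp + fp) if tp + fp else 0.0
        r = tp / (tp + fn) if tp + fn else 0.0
        scores.append(2 * p * r / (p + r) if p + r else 0.0)
    return np.array(scores)
```

With identity covariances and means (0, 0) and (3, 4), the trace term vanishes and FID reduces to the squared mean distance, 25.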

Page 12

SVHN Results: Generated Samples

(nD = number of discriminators; mean/rand = discriminator feedback by mean or random selection; lower FID and DDD are better)

Model          FID score         DDD score
DC-GAN         16.789 ± 0.303    0.343
VAE-GAN        13.252 ± 0.001    0.329
MAVEN-mean2D   11.675 ± 0.001    0.309
MAVEN-mean3D   11.515 ± 0.065    0.300
MAVEN-mean5D   10.909 ± 0.001    0.294
MAVEN-rand2D   11.384 ± 0.001    0.316
MAVEN-rand3D   10.791 ± 0.029    0.357
MAVEN-rand5D   11.052 ± 0.751    0.323

(Figure: generated SVHN samples from MAVEN-mean2D, MAVEN-mean3D, MAVEN-mean5D, MAVEN-rand2D, MAVEN-rand3D, and MAVEN-rand5D)

Page 13

SVHN Results: Classification

Acc = overall accuracy; the remaining columns are class-wise F1 scores for digits 0 to 9.

Model          Acc    0      1      2      3      4      5      6      7      8      9
DC-GAN         0.876  0.860  0.920  0.890  0.840  0.890  0.870  0.830  0.890  0.820  0.840
VAE-GAN        0.901  0.900  0.940  0.930  0.860  0.920  0.900  0.860  0.910  0.840  0.850
MAVEN-mean2D   0.909  0.890  0.930  0.940  0.890  0.930  0.900  0.870  0.910  0.870  0.890
MAVEN-mean3D   0.909  0.910  0.940  0.940  0.870  0.920  0.890  0.870  0.920  0.870  0.860
MAVEN-mean5D   0.905  0.910  0.930  0.930  0.870  0.930  0.900  0.860  0.910  0.860  0.870
MAVEN-rand2D   0.905  0.910  0.930  0.940  0.870  0.930  0.890  0.860  0.920  0.850  0.860
MAVEN-rand3D   0.907  0.890  0.910  0.920  0.870  0.900  0.870  0.860  0.900  0.870  0.890
MAVEN-rand5D   0.903  0.910  0.930  0.940  0.860  0.910  0.890  0.870  0.920  0.850  0.870

Page 14

CIFAR10 Results: Generated Samples

Model          FID score         DDD score
DC-GAN         61.293 ± 0.209    0.265
VAE-GAN        15.511 ± 0.125    0.224
MAVEN-mean2D   12.743 ± 0.242    0.223
MAVEN-mean3D   11.316 ± 0.808    0.190
MAVEN-mean5D   12.123 ± 0.140    0.207
MAVEN-rand2D   12.820 ± 0.584    0.194
MAVEN-rand3D   12.620 ± 0.001    0.202
MAVEN-rand5D   18.509 ± 0.001    0.215

(Figure: generated CIFAR10 samples from MAVEN-mean2D, MAVEN-mean3D, MAVEN-mean5D, MAVEN-rand2D, MAVEN-rand3D, and MAVEN-rand5D)

Page 15

CIFAR10 Results: Classification

Acc = overall accuracy; the remaining columns are class-wise F1 scores.

Model          Acc    airplane  automobile  bird   cat    deer   dog    frog   horse  ship   truck
DC-GAN         0.713  0.760     0.840       0.560  0.510  0.660  0.590  0.780  0.780  0.810  0.810
VAE-GAN        0.743  0.770     0.850       0.640  0.560  0.690  0.620  0.820  0.770  0.860  0.830
MAVEN-mean2D   0.761  0.800     0.860       0.650  0.590  0.750  0.680  0.810  0.780  0.850  0.850
MAVEN-mean3D   0.759  0.770     0.860       0.670  0.580  0.700  0.690  0.800  0.810  0.870  0.830
MAVEN-mean5D   0.771  0.800     0.860       0.650  0.610  0.710  0.640  0.810  0.790  0.880  0.820
MAVEN-rand2D   0.757  0.780     0.860       0.650  0.530  0.720  0.650  0.810  0.800  0.870  0.860
MAVEN-rand3D   0.756  0.780     0.860       0.640  0.580  0.720  0.650  0.830  0.800  0.870  0.830
MAVEN-rand5D   0.762  0.810     0.850       0.680  0.600  0.720  0.660  0.840  0.800  0.850  0.820

Page 16

Model vs. Real Distributions: Good Match

(Figure panels: SVHN, CIFAR10)

Page 17

CXR Results: Generated Samples

Model          FID score         DDD score
DC-GAN         152.511 ± 0.370   0.145
VAE-GAN        141.422 ± 0.580   0.107
MAVEN-mean2D   141.339 ± 0.420   0.138
MAVEN-mean3D   140.865 ± 0.983   0.018
MAVEN-mean5D   147.316 ± 1.169   0.100
MAVEN-rand2D   154.501 ± 0.345   0.038
MAVEN-rand3D   158.749 ± 0.297   0.179
MAVEN-rand5D   152.778 ± 1.254   0.180

(Figure: generated chest X-ray samples from MAVEN-mean2D, MAVEN-mean3D, MAVEN-mean5D, MAVEN-rand2D, MAVEN-rand3D, and MAVEN-rand5D)

Page 18

CXR Results: Classification

Page 19

Model vs. Real Distributions: Not-So-Good Match

(Figure panel: CXR)

Page 20

Conclusions & Future Work

• Significance

  • New generative model

  • Improved image quality and classification

  • Evaluation measure for deep generative models

• Limitations

  • Performance on medical image data

  • Execution time

• What's Next

  • Hyper-parameters for medical images

  • Constrained generation

  • Complex image analysis tasks

  • Generative multi-tasking

Page 21

Questions?