Transcript of aimran/icmla2019_presentation.pdf

  • MULTI-ADVERSARIAL VARIATIONAL AUTO-ENCODER NETWORKS

    Abdullah-Al-Zubaer Imran and Demetri Terzopoulos

    UCLA Computer Graphics & Vision Laboratory

  • Why Generative Modeling?

    • Limited training data

    • Leveraging unlabeled examples

    • Unsupervised learning

    • Semi-supervised learning

    • Latent codes help interpret the models

    12/17/2019 Multi-Adversarial Variational Autoencoder Networks 2

  • Deep Generative Models: VAE

    [Kingma et al. 2013]

    • Likelihood maximization

    • Encoder: variational inference

    • Decoder: sample generation

    • Efficient variational inference

    • Blurry samples

    • Re-parameterization: $z = \mu + \sigma \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$

    • Losses

      • Reconstruction: $\mathbb{E}_{q(z|x)}[\log p(x|z)]$

      • Regularization: $-\mathrm{KL}\left(q(z|x) \,\|\, p(z)\right)$
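A minimal NumPy sketch of the VAE pieces on this slide, assuming a diagonal-Gaussian encoder that outputs `mu` and `log_var`, a standard-normal prior p(z), and a Bernoulli decoder for pixels in [0, 1]. These modelling choices and all function names are illustrative assumptions, not taken from the presentation; the KL term uses the standard closed form for a diagonal Gaussian against N(0, I).

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """z = mu + sigma * eps with eps ~ N(0, I); sigma = exp(log_var / 2) (assumed parameterization)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian q, summed over latent dims."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)

def bernoulli_log_likelihood(x, x_prob, eps=1e-7):
    """log p(x|z) under an assumed Bernoulli decoder with per-pixel means x_prob."""
    x_prob = np.clip(x_prob, eps, 1.0 - eps)  # avoid log(0)
    return np.sum(x * np.log(x_prob) + (1 - x) * np.log(1 - x_prob), axis=-1)

def elbo(x, x_prob, mu, log_var):
    """Single-sample ELBO: reconstruction term minus the KL regularizer."""
    return bernoulli_log_likelihood(x, x_prob) - kl_to_standard_normal(mu, log_var)
```

In training, the negative ELBO averaged over a batch would serve as the loss; the reparameterization keeps `z` differentiable with respect to `mu` and `log_var`.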

  • Deep Generative Models: GAN

    [Goodfellow et al. 2014, Radford et al. 2015]

    • Mini-max game

    • Generator maps latent variables to data samples

    • Discriminator distinguishes generated and real samples

    • Sharper samples than VAEs

    • Unstable and difficult to optimize

    • Losses

    $\max_D V(D) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$

    $\min_G V(G) = \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$
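The two objectives above can be written as plain loss functions over discriminator outputs in (0, 1). This is a sketch, assuming `d_real` and `d_fake` are arrays of discriminator probabilities for real and generated samples; the names are hypothetical. The non-saturating generator loss is a separate, commonly used variant (suggested in the original GAN paper), included only for comparison with the min-max form on the slide.

```python
import numpy as np

EPS = 1e-7  # keeps log() finite when D outputs exactly 0 or 1

def discriminator_loss(d_real, d_fake):
    """Negative of V(D) = E[log D(x)] + E[log(1 - D(G(z)))]; minimizing it maximizes V(D)."""
    d_real = np.clip(d_real, EPS, 1 - EPS)
    d_fake = np.clip(d_fake, EPS, 1 - EPS)
    return -(np.mean(np.log(d_real)) + np.mean(np.log(1 - d_fake)))

def generator_loss_minimax(d_fake):
    """V(G) = E[log(1 - D(G(z)))], minimized directly by the generator."""
    d_fake = np.clip(d_fake, EPS, 1 - EPS)
    return np.mean(np.log(1 - d_fake))

def generator_loss_nonsaturating(d_fake):
    """Variant: minimize -E[log D(G(z))], which gives stronger gradients early in training."""
    d_fake = np.clip(d_fake, EPS, 1 - EPS)
    return -np.mean(np.log(d_fake))
```

When the discriminator is maximally uncertain (all outputs 0.5), `discriminator_loss` equals 2 log 2, the value at the game's theoretical equilibrium.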

  • Deep Generative Models: PixelRNN

    [Oord et al. 2016]

    • Autoregressive model

    • Simple and stable training process

    • Inefficient sampling

    • Assigns a probability to every pixel in the image, conditioned on all previous pixels

    • Softmax loss

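    $p(x) = \prod_{i=1}^{n^2} p(x_i \mid x_1, \ldots, x_{i-1})$

    $p(x_i \mid x_{<i}) = p(x_{i,R} \mid x_{<i})\, p(x_{i,G} \mid x_{<i}, x_{i,R})\, p(x_{i,B} \mid x_{<i}, x_{i,R}, x_{i,G})$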
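To make the chain-rule factorization concrete, the sketch below scores and samples binary "images" flattened to a pixel sequence. The conditional `toy_conditional` is a hypothetical stand-in for the PixelRNN network (which outputs a per-pixel softmax); the point is that scoring needs only the observed pixels, while sampling must generate pixels one at a time, which is why sampling is slow.

```python
import numpy as np

def toy_conditional(prefix):
    """Hypothetical p(x_i = 1 | x_1..x_{i-1}) for binary pixels: a smoothed mean of earlier pixels.
    It stands in for the learned network; any valid conditional works with the code below."""
    prefix = np.asarray(prefix, dtype=float)
    return (1.0 + np.sum(prefix)) / (2.0 + len(prefix))

def log_likelihood(x):
    """log p(x) = sum_i log p(x_i | x_<i), the chain-rule factorization of the slide."""
    total = 0.0
    for i, xi in enumerate(x):
        p1 = toy_conditional(x[:i])
        total += np.log(p1 if xi == 1 else 1.0 - p1)
    return total

def sample(n_pixels, rng):
    """Sequential sampling: pixel i can only be drawn after all earlier pixels exist."""
    x = []
    for _ in range(n_pixels):
        x.append(int(rng.random() < toy_conditional(x)))
    return np.array(x)
```

Because each factor is a proper conditional distribution, the probabilities of all 2^n binary sequences sum to 1, whatever conditional is plugged in.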

  • Aim

    • Improving deep generative models

    • Evaluation measures

    [Figures: VAE-GAN and PixelGAN Autoencoder, Larsen et al. 2016, Makhzani et al. 2017]

    Primary Aim

    (Efficient and stable generative modeling for medical image analysis)

    ✓ Combining generative models

    ✓ High-quality image generation

    ✓ Learning from limited labeled data

  • Proposed: MAVENs

    • Highlights

      • Ensemble of multiple discriminators in a VAE-GAN

      • Joint image generation and classification

    • Motivation

      • Instability in generative models

      • Mode-collapsed generation

      • Poor image quality in VAEs

      • Small labeled datasets

    • Objective

      • Improve sample quality and semi-supervised classification

      • Unified generative model

      • Variational inference with adversarial learning

    [Figure: basic comparisons of MAVEN with GAN, VAE, and VAE-GAN, including GAN mode-collapsed generation]

  • MAVENs: Architecture

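  • MAVENs: Objectives

    Here $n$ is the number of real classes, class $n + 1$ is the fake class, and $f(\cdot)$ denotes intermediate discriminator features.

    Discriminator Loss

    Supervised:

    $L_{D_{\mathrm{supervised}}} = -\mathbb{E}_{x, y \sim p_{data}} \log[p(y = i \mid x, i < n + 1)]$

    Unsupervised:

    $L_{D_{\mathrm{real}}} = -\mathbb{E}_{x \sim p_{data}} \log[1 - p(y = n + 1 \mid x)]$

    $L_{D_{\mathrm{fake\_G}}} = -\mathbb{E}_{\hat{x} \sim G} \log[p(y = n + 1 \mid \hat{x})]$

    $L_{D_{\mathrm{fake\_E}}} = -\mathbb{E}_{\tilde{x} \sim G} \log[p(y = n + 1 \mid \tilde{x})]$

    Generator Loss

    $L_{G_{\mathrm{fake\_G}}} = -\mathbb{E}_{\hat{x} \sim G} \log[1 - p(y = n + 1 \mid \hat{x})]$

    $L_{G_{\mathrm{fake\_E}}} = -\mathbb{E}_{\tilde{x} \sim G} \log[1 - p(y = n + 1 \mid \tilde{x})]$

    $L_{G_{\mathrm{feature}}} = \left\lVert \mathbb{E}_{x \sim p_{data}} f(x) - \mathbb{E}_{\hat{x} \sim G} f(\hat{x}) \right\rVert_2^2$

    Encoder Loss

    $L_{E_{\mathrm{feature}}} = \left\lVert \mathbb{E}_{x \sim p_{data}} f(x) - \mathbb{E}_{\tilde{x} \sim G} f(\tilde{x}) \right\rVert_2^2$

    $L_{E_{\mathrm{KL}}} = -\mathbb{E}_{q_\lambda(z|x)}\left[\log \frac{p(z)}{q_\lambda(z|x)}\right]$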
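A NumPy sketch of the discriminator, generator, and feature-matching losses above, assuming the discriminator emits n + 1 logits per image with the last column as the fake class. The logit layout, the `EPS` guard, and the function names are illustrative assumptions, not the authors' code. Following the slide, the same fake-class losses apply to both $\hat{x}$ and $\tilde{x}$ (presumably decoded prior samples and decoded encoder codes, respectively); the encoder's KL term equals KL(q‖p), the VAE regularizer.

```python
import numpy as np

EPS = 1e-7  # keeps log() finite

def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def d_supervised_loss(logits, labels):
    """-E log p(y = label | x, y < n+1): softmax renormalized over the n real-class logits."""
    p = softmax(logits[:, :-1])  # assumed layout: last column is the fake class
    return -np.mean(np.log(p[np.arange(len(labels)), labels] + EPS))

def p_fake(logits):
    """p(y = n+1 | x): probability mass on the fake class."""
    return softmax(logits)[:, -1]

def d_real_loss(logits_real):
    return -np.mean(np.log(1.0 - p_fake(logits_real) + EPS))

def d_fake_loss(logits_fake):
    # Applied to x_hat and to x_tilde alike, as on the slide.
    return -np.mean(np.log(p_fake(logits_fake) + EPS))

def g_adversarial_loss(logits_fake):
    """Generator wants its samples classified as any real class, i.e. not as class n+1."""
    return -np.mean(np.log(1.0 - p_fake(logits_fake) + EPS))

def feature_matching_loss(feat_real, feat_fake):
    """|| mean f(x) - mean f(x_fake) ||_2^2 over a batch of feature vectors (rows)."""
    diff = feat_real.mean(axis=0) - feat_fake.mean(axis=0)
    return float(np.sum(diff**2))
```

With all-zero logits over n = 2 real classes, each class gets probability 1/3, so `d_fake_loss` is log 3 and `d_real_loss` is -log(2/3).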

  • MAVENs: Implementation Details

    • Datasets

      • SVHN (32 × 32 × 3) [street-view house-number digits]

      • CIFAR10 (32 × 32 × 3) [natural images]

      • Chest X-ray (128 × 128 × 1) [normal, bacterial pneumonia, and viral pneumonia]

    • Baselines: DC-GAN and VAE-GAN

    • MAVENs with 2, 3, and 5 discriminators (2D, 3D, 5D)

      • Discriminator feedback combined by averaging (mean) or by random selection (rand)

    • Only 10% of the training data is used with its label information
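The two feedback schemes and the 10% labeled split can be sketched as follows. The slide does not say how the random discriminator is chosen or how the labeled subset is drawn, so this version assumes a uniform choice per call and a uniform random split; both helpers are hypothetical.

```python
import numpy as np

def combine_feedback(losses, mode, rng=None):
    """Aggregate the adversarial losses from K discriminators into one training signal.

    mode="mean": average over all discriminators.
    mode="rand": use one discriminator, assumed chosen uniformly at random per call.
    """
    losses = np.asarray(losses, dtype=float)
    if mode == "mean":
        return float(losses.mean())
    if mode == "rand":
        rng = np.random.default_rng() if rng is None else rng
        return float(losses[rng.integers(len(losses))])
    raise ValueError(f"unknown mode: {mode!r}")

def labeled_subset(n_examples, fraction, rng):
    """Hypothetical helper: split indices into a labeled part (fraction) and an unlabeled rest."""
    n_labeled = int(round(fraction * n_examples))
    perm = rng.permutation(n_examples)
    return perm[:n_labeled], perm[n_labeled:]
```

Averaging gives the generator a smoother signal from every discriminator at once, while random selection exposes it to a different critic each step.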

  • MAVENs: Evaluations

    • Image quality

      • Fréchet Inception Distance (FID)

        • Activations from the pool3 layer of the Inception-v3 model

        $\mathrm{FID} = \lVert \mu_{data} - \mu_{fake} \rVert^2 + \mathrm{Tr}\left(\Sigma_{data} + \Sigma_{fake} - 2(\Sigma_{data}\Sigma_{fake})^{1/2}\right)$

      • Descriptive Distribution Distance (DDD)

        • Compares the first four moments of the two distributions

        $\mathrm{DDD} = \sum_{i=1}^{4} -\log w_i \, \lvert \mu_i^{data} - \mu_i^{fake} \rvert$

    • Classification

      • Overall accuracy

      • Class-wise F1 scores

        $F_1 = \frac{2 \cdot \mathrm{precision} \cdot \mathrm{recall}}{\mathrm{precision} + \mathrm{recall}}$
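A NumPy sketch of the evaluation formulas, operating on precomputed statistics rather than on an Inception-v3 network. The FID helper assumes the means and covariances of pool3 activations are already available (e.g. from `np.cov(features, rowvar=False)`) and computes the trace term via the identity $\mathrm{Tr}((\Sigma_1\Sigma_2)^{1/2}) = \mathrm{Tr}((\Sigma_1^{1/2}\Sigma_2\Sigma_1^{1/2})^{1/2})$. For DDD, the slide does not define the weights $w_i$ or which moments are used, so this version assumes mean, variance, skewness, and kurtosis of a flat array and takes the weights as a parameter.

```python
import numpy as np

def _psd_sqrt(a):
    """Symmetric square root of a positive semi-definite matrix via eigendecomposition."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T

def fid(mu_data, sigma_data, mu_fake, sigma_fake):
    """||mu_d - mu_f||^2 + Tr(S_d + S_f - 2 (S_d S_f)^{1/2}) from feature statistics."""
    diff = mu_data - mu_fake
    root = _psd_sqrt(sigma_data)
    # Tr((S_d S_f)^{1/2}) = sum of sqrt eigenvalues of the symmetric PSD matrix S_d^{1/2} S_f S_d^{1/2}.
    eigs = np.linalg.eigvalsh(root @ sigma_fake @ root)
    tr_covmean = np.sum(np.sqrt(np.clip(eigs, 0.0, None)))
    return float(diff @ diff + np.trace(sigma_data) + np.trace(sigma_fake) - 2.0 * tr_covmean)

def first_four_moments(x):
    """Mean, variance, skewness, kurtosis of a flat array (an assumed choice of moments)."""
    x = np.asarray(x, dtype=float).ravel()
    mean = x.mean()
    var = x.var()
    centered = x - mean
    return np.array([mean, var, np.mean(centered**3) / var**1.5, np.mean(centered**4) / var**2])

def ddd(real, fake, weights):
    """sum_i -log(w_i) * |m_i(real) - m_i(fake)|; weights are caller-supplied assumptions."""
    gaps = np.abs(first_four_moments(real) - first_four_moments(fake))
    return float(np.sum(-np.log(np.asarray(weights, dtype=float)) * gaps))

def accuracy(y_true, y_pred):
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

def per_class_f1(y_true, y_pred, n_classes):
    """F1 = 2 * precision * recall / (precision + recall) per class (0 when undefined)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    scores = np.zeros(n_classes)
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        denom = precision + recall
        scores[c] = 2 * precision * recall / denom if denom > 0 else 0.0
    return scores
```

For example, with true labels [0, 0, 1, 1] and predictions [0, 1, 1, 1], accuracy is 0.75 and the per-class F1 scores are 2/3 and 0.8.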

  • SVHN Results: Generated Samples

    Model          FID Score         DDD Score
    DC-GAN         16.789 ± 0.303    0.343
    VAE-GAN        13.252 ± 0.001    0.329
    MAVEN-mean2D   11.675 ± 0.001    0.309
    MAVEN-mean3D   11.515 ± 0.065    0.300
    MAVEN-mean5D   10.909 ± 0.001    0.294
    MAVEN-rand2D   11.384 ± 0.001    0.316
    MAVEN-rand3D   10.791 ± 0.029    0.357
    MAVEN-rand5D   11.052 ± 0.751    0.323

    [Figure: generated SVHN samples from MAVEN-mean2D, MAVEN-mean3D, MAVEN-mean5D, MAVEN-rand2D, MAVEN-rand3D, MAVEN-rand5D]

  • SVHN Results: Classification

    Model          Acc     F1 score per class (digit)
                           0      1      2      3      4      5      6      7      8      9
    DC-GAN         0.876   0.860  0.920  0.890  0.840  0.890  0.870  0.830  0.890  0.820  0.840
    VAE-GAN        0.901   0.900  0.940  0.930  0.860  0.920  0.900  0.860  0.910  0.840  0.850
    MAVEN-mean2D   0.909   0.890  0.930  0.940  0.890  0.930  0.900  0.870  0.910  0.870  0.890
    MAVEN-mean3D   0.909   0.910  0.940  0.940  0.870  0.920  0.890  0.870  0.920  0.870  0.860
    MAVEN-mean5D   0.905   0.910  0.930  0.930  0.870  0.930  0.900  0.860  0.910  0.860  0.870
    MAVEN-rand2D   0.905   0.910  0.930  0.940  0.870  0.930  0.890  0.860  0.920  0.850  0.860
    MAVEN-rand3D   0.907   0.890  0.910  0.920  0.870  0.900  0.870  0.860  0.900  0.870  0.890
    MAVEN-rand5D   0.903   0.910  0.930  0.940  0.860  0.910  0.890  0.870  0.920  0.850  0.870

  • CIFAR10 Results: Generated Samples

    Model          FID Score         DDD Score
    DC-GAN         61.293 ± 0.209    0.265
    VAE-GAN        15.511 ± 0.125    0.224
    MAVEN-mean2D   12.743 ± 0.242    0.223
    MAVEN-mean3D   11.316 ± 0.808    0.190
    MAVEN-mean5D   12.123 ± 0.140    0.207
    MAVEN-rand2D   12.820 ± 0.584    0.194
    MAVEN-rand3D   12.620 ± 0.001    0.202
    MAVEN-rand5D   18.509 ± 0.001    0.215

    [Figure: generated CIFAR10 samples from MAVEN-mean2D, MAVEN-mean3D, MAVEN-mean5D, MAVEN-rand2D, MAVEN-rand3D, MAVEN-rand5D]

  • CIFAR10 Results: Classification

    Model          Acc     F1 score per class
                           airplane   automobile bird       cat        deer       dog        frog       horse      ship       truck
    DC-GAN         0.713   0.760      0.840      0.560      0.510      0.660      0.590      0.780      0.780      0.810      0.810
    VAE-GAN        0.743   0.770      0.850      0.640      0.560      0.690      0.620      0.820      0.770      0.860      0.830
    MAVEN-mean2D   0.761   0.800      0.860      0.650      0.590      0.750      0.680      0.810      0.780      0.850      0.850
    MAVEN-mean3D   0.759   0.770      0.860      0.670      0.580      0.700      0.690      0.800      0.810      0.870      0.830
    MAVEN-mean5D   0.771   0.800      0.860      0.650      0.610      0.710      0.640      0.810      0.790      0.880      0.820
    MAVEN-rand2D   0.757   0.780      0.860      0.650      0.530      0.720      0.650      0.810      0.800      0.870      0.860
    MAVEN-rand3D   0.756   0.780      0.860      0.640      0.580      0.720      0.650      0.830      0.800      0.870      0.830
    MAVEN-rand5D   0.762   0.810      0.850      0.680      0.600      0.720      0.660      0.840      0.800      0.850      0.820

  • Model vs. Real Distributions: Good Match

    [Figure: model vs. real distributions for SVHN and CIFAR10]

  • CXR Results: Generated Samples

    Model          FID Score          DDD Score
    DC-GAN         152.511 ± 0.370    0.145
    VAE-GAN        141.422 ± 0.580    0.107
    MAVEN-mean2D   141.339 ± 0.420    0.138
    MAVEN-mean3D   140.865 ± 0.983    0.018
    MAVEN-mean5D   147.316 ± 1.169    0.100

    M