Hierarchical Disentangled Representations

Babak Esmaeili, Northeastern University ([email protected])

Hao Wu, Northeastern University ([email protected])

Sarthak Jain, Northeastern University ([email protected])

Siddharth Narayanaswamy, University of Oxford ([email protected])

Brooks Paige, Alan Turing Institute and University of Cambridge ([email protected])

Jan-Willem van de Meent, Northeastern University ([email protected])

Abstract

Deep latent-variable models learn representations of high-dimensional data in an unsupervised manner. A number of recent efforts have focused on learning representations that disentangle statistically independent axes of variation, often by introducing suitable modifications of the objective function. We synthesize this growing body of literature by formulating a generalization of the evidence lower bound that explicitly represents the trade-offs between sparsity of the latent code, bijectivity of representations, and coverage of the support of the empirical data distribution. Our objective is also suitable for learning hierarchical representations that disentangle blocks of variables whilst allowing for some degree of correlation within blocks. Experiments on a range of datasets demonstrate that learned representations contain interpretable features, are able to learn discrete attributes, and generalize to unseen combinations of factors.

1 Introduction

Deep generative models represent data x using a low-dimensional set of latent variables z (sometimes referred to as a code). The relationship between x and z is described by a conditional probability distribution pθ(x | z) parameterized by a deep neural network. These models have seen much recent success in training generative models that can simulate high-fidelity representations of complex data such as images [Gatys et al., 2015; Gulrajani et al., 2017], audio [Oord et al., 2016], and language [Bowman et al., 2016]. The smooth low-dimensional z can be used as a compressed representation for downstream tasks such as text classification [Xu et al., 2017], Bayesian optimization [Gómez-Bombarelli et al., 2018; Kusner et al., 2017], and lossy image compression [Theis et al., 2017]. The setting in which an approximate posterior distribution qφ(z | x) is learned simultaneously with the generative model is known as a variational autoencoder (VAE), where qφ(z | x) and pθ(x | z) represent probabilistic encoders and decoders respectively.

While deep generative models often provide high-fidelity reconstructions, the representation z is generally not directly amenable to human interpretation. In contrast to classical linear methods such as principal component or factor analysis, individual dimensions of z do not necessarily encode any particular semantically meaningful variation in x. This has motivated a search for ways of learning disentangled representations, where perturbations of individual dimensions of the latent code z perturb the corresponding x in an interpretable manner.

Recent work in learning disentangled representations using deep generative models has broadly followed two approaches, one (semi-)supervised and one unsupervised. In the supervised or semi-supervised setting, a generative model is specified which allows for inclusion of prior information on what may constitute a factor of variation. This supervision may take the form of partitioning the data into subsets which vary only along some particular qualitative dimension [Kulkarni et al., 2015; Bouchacourt et al., 2017], or may take the form of explicit labels of particular sources of variation for some or all of the data [Kingma et al., 2014; Siddharth et al., 2017]. In this latter work, the overall training objective is largely kept the same, and instead the focus is on defining structured VAEs [Johnson et al., 2016] that incorporate a graphical model structure which partitions the latent variables z into interpretable and uninterpretable subsets of

Figure 1: Correspondence between the generative model pθ(x, z) and the inference model qφ(z, x) in variational autoencoders. The generative model combines a prior p(z) over latent variables with a likelihood pθ(x | z). In the inference model, a distribution q(x) over a finite sample set approximates an unknown data distribution. The VAE objective minimizes the KL divergence between qφ(z, x) and pθ(x, z), which means that the inference marginal qφ(z) must match the prior, the generative marginal pθ(x) must match the empirical data distribution q(x), and the generative posterior pθ(z | x) must match the encoder distribution qφ(z | x).

dimensions, and makes explicit any hierarchical or structured relationship between these latent variables.

In contrast, unsupervised methods for learning disentangled factors do not require specification of which aspects of variation in the data we may wish to extract. Instead, these methods modify the objective function, penalizing specific terms in order to induce representations in which latent variables naturally coincide with preconceived notions of disentangled factors. The β-VAE [Higgins et al., 2016] modifies the VAE objective to encourage independence between the dimensions of the latent z. Three recent papers [Kim and Mnih, 2018; Gao et al., 2018; Chen et al., 2018] all aim to address limitations of the β-VAE by specifically focusing on optimizing the term corresponding to the total correlation, a divergence KL(qφ(z) || ∏d qφ(zd)); although the approaches differ, they share the common goal of inducing interpretable, disentangled representations by minimizing this term, thus encouraging independence across the dimensions of z.

As a further illustrative example, consider disentangling MNIST and Google Street-View House Numbers (SVHN) images into a constituent digit and a "style", an abstract variable which represents any aspects of the image not captured by the number itself. This is usually achieved by supervising the latent variable which corresponds to the digit and learning the style in an unsupervised manner [Kingma et al., 2014; Siddharth et al., 2017]. Alternatively, we could consider an unsupervised approach where we encode some one-hot encoded notion of digit as a set of dimensions of z subject to a sparsity constraint; that is, a "digit" (or "identity") is simply some representation for which only one entry is active at a time.

In this paper we reinterpret the standard VAE objective as a KL divergence between a generative model and its corresponding inference model. Doing so enables us both to synthesize various generalizations of the VAE objective and to more clearly highlight the trade-offs associated with its optimization. Like recent approaches by Kim and Mnih [2018], Chen et al. [2018], and Gao et al. [2018], we identify minimization of the total correlation as a means of inducing disentangled representations. We additionally derive a hierarchical decomposition of the variational lower bound (named Hierarchically Factorized VAE or HFVAE) that enables us to induce different levels of statistical independence between groups of variables and between individual variables in the same group.

We evaluate our methodology on a variety of datasets including dSprites, MNIST, FMNIST and CelebA. Qualitative evaluation shows that our objective indeed uncovers interpretable features, whereas quantitative metrics demonstrate improvements over the state of the art. Moreover, we show that the learned disentangled representations can recover combinations of features that were not present in the training set.

2 Background

Variational autoencoders (VAEs) jointly optimize two models. The generative model defines a distribution on a set of latent variables z and observed data x in terms of a prior p(z) and a likelihood pθ(x | z), which is often referred to as the decoder model,

pθ(x, z) = pθ(x | z)p(z). (1)

This distribution is estimated in tandem with an encoder, a conditional distribution qφ(z | x) which performs approximate inference in this model. This relationship between the generative model and the inference model motivates the "autoencoder" view of deep generative models.

If we denote the empirical data distribution as q(x), then together with the encoder the inference model defines a

$$
\begin{aligned}
\mathcal{L}(\theta, \phi)
&= \mathbb{E}_{q_\phi(z,x)}\!\left[\log \frac{p_\theta(x,z)}{q_\phi(z,x)}\right]
 = \mathbb{E}_{q_\phi(z,x)}\!\left[\log \frac{p_\theta(x,z)}{p_\theta(x)\,p(z)}
 + \log \frac{q_\phi(z)\,q(x)}{q_\phi(z,x)}
 + \log \frac{p_\theta(x)}{q(x)}
 + \log \frac{p(z)}{q_\phi(z)}\right], \\
&= \mathbb{E}_{q_\phi(z,x)}\Big[\underbrace{\log \frac{p_\theta(x \mid z)}{p_\theta(x)}}_{1}
 - \underbrace{\log \frac{q_\phi(z \mid x)}{q_\phi(z)}}_{2}\Big]
 - \underbrace{\mathrm{KL}\!\left(q(x) \,\|\, p_\theta(x)\right)}_{3}
 - \underbrace{\mathrm{KL}\!\left(q_\phi(z) \,\|\, p(z)\right)}_{4}.
\end{aligned}
$$

Figure 2: ELBO decomposition. The VAE objective can be understood as minimizing the KL divergence between a generative model pθ(x, z) = pθ(x | z)p(z) and an inference model qφ(z, x) = qφ(z | x)q(x). We can decompose this objective into 4 terms. Term 1, which can be intuitively thought of as the uniqueness of the reconstruction, is regularized by the mutual information 2, which represents the uniqueness of the encoding. Minimizing the KL in term 3 is equivalent to maximizing the marginal likelihood Eq(x)[log pθ(x)]. Combined maximization of 1 + 3 is equivalent to maximizing Eqφ(z,x)[log pθ(x | z)]. Term 4 matches the inference marginal qφ(z) to the prior p(z), which in turn ensures realism for samples x ∼ pθ(x) from the generative model.

joint distribution qφ(z, x), with

$$
q_\phi(z,x) := q_\phi(z \mid x)\,q(x),
\qquad
q(x) := \frac{1}{N} \sum_{n=1}^{N} \delta_{x_n}(x).
\tag{2}
$$

A VAE optimizes these two models using a single objective L(θ, φ), known as the evidence lower bound (ELBO),

$$
\begin{aligned}
\mathcal{L}(\theta, \phi)
&:= \frac{1}{N} \sum_{n=1}^{N} \mathbb{E}_{q_\phi(z \mid x_n)}\!\left[\log \frac{p_\theta(x_n, z)}{q_\phi(z \mid x_n)}\right], \\
&= \mathbb{E}_{q(x)}\big[\log p_\theta(x) - \mathrm{KL}\left(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\right)\big], \\
&\leq \mathbb{E}_{q(x)}\left[\log p_\theta(x)\right].
\end{aligned}
\tag{3}
$$
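For illustration, a minimal PyTorch sketch of a single-sample Monte Carlo estimate of Equation (3) follows, written in the standard equivalent form E[log pθ(x | z)] − KL(qφ(z | x) || p(z)). This is not the implementation used for the experiments in this paper; the Encoder and Decoder below are placeholder architectures.

import torch
import torch.nn as nn
from torch.distributions import Normal, Bernoulli, kl_divergence

class Encoder(nn.Module):
    """Placeholder diagonal-Gaussian encoder q_phi(z | x)."""
    def __init__(self, x_dim=784, z_dim=10, h_dim=400):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.loc = nn.Linear(h_dim, z_dim)
        self.log_scale = nn.Linear(h_dim, z_dim)

    def forward(self, x):
        h = self.net(x)
        return Normal(self.loc(h), self.log_scale(h).exp())

class Decoder(nn.Module):
    """Placeholder Bernoulli decoder p_theta(x | z)."""
    def __init__(self, x_dim=784, z_dim=10, h_dim=400):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU(),
                                 nn.Linear(h_dim, x_dim))

    def forward(self, z):
        return Bernoulli(logits=self.net(z))

def elbo(encoder, decoder, x, prior):
    """Single-sample Monte Carlo estimate of Eq. (3) for a batch x."""
    q_z = encoder(x)                              # q_phi(z | x)
    z = q_z.rsample()                             # reparameterized sample
    log_px_z = decoder(z).log_prob(x).sum(-1)     # log p_theta(x | z)
    kl = kl_divergence(q_z, prior).sum(-1)        # KL(q_phi(z | x) || p(z))
    return (log_px_z - kl).mean()

if __name__ == "__main__":
    encoder, decoder = Encoder(), Decoder()
    prior = Normal(torch.zeros(10), torch.ones(10))
    x = torch.rand(32, 784).bernoulli()           # stand-in batch of binarized images
    print(elbo(encoder, decoder, x, prior))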

3 Objective Decomposition

A number of recent efforts have considered alternate decompositions of the VAE objective, both to formulate a better understanding of the trade-offs in optimizing VAE architectures and to identify generalizations that help induce desirable features such as disentangled representations. In order to summarize these efforts, we propose to express the VAE objective as a KL divergence

$$
\begin{aligned}
\mathcal{L}(\theta, \phi)
&:= -\mathrm{KL}\left(q_\phi(z,x) \,\|\, p_\theta(x,z)\right), \\
&= \mathbb{E}_{q_\phi(z,x)}\!\left[\log \frac{p_\theta(x,z)}{q_\phi(z \mid x)}\right]
 - \mathbb{E}_{q(x)}\left[\log q(x)\right].
\end{aligned}
\tag{4}
$$

This definition is equivalent to previous definitions since it only differs from the expression in Equation (3) by a constant, which is the entropy of the empirical data distribution q(x),

H(x) = −Eq(x)[log q(x)] = log N. (5)

An advantage of this interpretation as a KL divergence is that it becomes more apparent what it means to optimize the objective with respect to the generative model parameters θ and the inference model parameters φ (Figure 1). In particular, it is clear that the KL is minimized when pθ(x, z) = qφ(z, x), which in turn implies equality for the marginal pθ(x) = q(x). It also implies qφ(z) = p(z) for the marginal of the inference model,

$$
q_\phi(z) = \int q_\phi(z,x)\, dx = \frac{1}{N} \sum_{n=1}^{N} q_\phi(z \mid x_n).
\tag{6}
$$

To better understand the trade-offs involved in optimizing the VAE objective, we can perform a decomposition (Figure 2) similar to the one obtained by Hoffman and Johnson [2016]. This decomposition contains 4 terms. Terms 3 and 4 enforce consistency between the marginal distributions over x and z. Minimizing the KL in term 3 maximizes the marginal likelihood Eq(x)[log pθ(x)], whereas minimizing 4 ensures that the inference marginal qφ(z) approximates the prior p(z).

Terms 1 and 2 enforce consistency between the conditional distributions. Intuitively speaking, term 1 maximizes the identifiability of the values z that generate each xn; when we sample z ∼ qφ(z | xn), then the likelihood pθ(xn | z) under the generative model should be higher than the marginal likelihood pθ(xn). Term 2 regularizes term 1 by minimizing the mutual information I(z; x) between x and z in the inference model, which means that qφ(z | xn) maps each xn to less identifiable values.

Note that evaluation of term 1 is intractable in practice, since we are not able to pointwise evaluate pθ(x). We can circumvent this intractability by combining 1 + 3

Figure 3: Illustration of the role of each of the terms in the decomposition from Figure 2. Each panel shows the effect of removing one term from the objective. A: Removing 1 means that we no longer require a unique z for each xn. Term 2 will then minimize I(x; z), which means that each xn is mapped to the prior. B: Removing 2 eliminates the constraint that I(x; z) must be small under the inference model, causing each xn to be mapped to a smaller region in z space. C: Removing 3 eliminates the constraint that pθ(x) must match q(x). D: Removing 4 eliminates the constraint that qφ(z) must match p(z).

into a single term, which recovers the reconstruction error

$$
\operatorname*{argmax}_{\theta,\phi}\;
\mathbb{E}_{q_\phi(z,x)}\!\left[\log \frac{p_\theta(x \mid z)}{p_\theta(x)}
+ \log \frac{p_\theta(x)}{q(x)}\right]
= \operatorname*{argmax}_{\theta,\phi}\;
\mathbb{E}_{q_\phi(z,x)}\left[\log p_\theta(x \mid z)\right].
$$

Given that the terms 2, 3 and 4 are bounded from above by 0, we can interpret the decomposition in Figure 2 as a Lagrangian relaxation of the constrained optimization problem (see, e.g., Alemi et al. [2016])

$$
\begin{aligned}
\max_{\theta,\phi}\;& \mathbb{E}_{q_\phi(x,z)}\!\left[\log \frac{p_\theta(x \mid z)}{p_\theta(x)}\right] && \text{(term 1)} \\
\text{s.t.}\;\;& I(x; z) = 0, && \text{(term 2)} \\
& \mathrm{KL}\left(q(x) \,\|\, p_\theta(x)\right) = 0, && \text{(term 3)} \\
& \mathrm{KL}\left(q_\phi(z) \,\|\, p(z)\right) = 0, && \text{(term 4)}
\end{aligned}
$$

for which we can define the Lagrangian

L(θ, φ) = 1 − λ2 · 2 − λ3 · 3 − λ4 · 4. (7)

We can now adjust each of the Lagrange multipliers λ2, λ3, λ4 in order to control which of our constraints we are willing to relax more, and which of our constraints we would like to relax less. To build intuition for what it means to relax each of these constraints, Figure 3 shows the effect of removing each term from the objective. When we remove 3 or 4 we can learn models in which pθ(x) deviates from q(x) or qφ(z) deviates from p(z). When we remove 1 we remove the requirement that pθ(xn | z) should be higher when z ∼ qφ(z | xn) than when z ∼ p(z). Provided the decoder model is sufficiently expressive, we will then learn a generative model that ignores the latent code z. This type of solution does in fact arise in certain cases, even when 1 is included in the objective, particularly when using auto-regressive decoder architectures [Chen et al., 2016].

When we remove 2, we learn a model that minimizes overlap between encoder distributions qφ(z | xn) when conditioned on different data points xn in order to maximize 1. This maximizes the mutual information I(x; z), which is bounded from above by log N. In practice 2 often saturates to log N, even when included in the objective, which suggests that maximizing 1 outweighs this cost, at least for the encoder/decoder architectures that are commonly considered in present-day models.

4 Hierarchically Factorized VAEs

In the context of this paper, we are interested in using the decomposition from Figure 2 to define an objective that will induce disentangled representations by encouraging statistical independence between features. The β-VAE objective [Higgins et al., 2016] aims to achieve this goal by defining the Lagrangian relaxation

$$
\mathcal{L}_{\beta\text{-VAE}}(\theta, \phi)
= \mathbb{E}_{q(x)}\Big[\mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right]
- \beta\, \mathrm{KL}\left(q_\phi(z \mid x) \,\|\, p(z)\right)\Big].
$$

We can express this objective in the terms of Figure 2 as

Lβ-VAE(θ, φ) = 1 + 3 + β(2 + 4).

In order to induce disentangled representations, the authors set β > 1. This works well in certain cases, but

$$
\begin{aligned}
-\mathrm{KL}\left(q_\phi(z) \,\|\, p(z)\right)
&= -\mathbb{E}_{q_\phi(z)}\!\left[\log \frac{q_\phi(z)}{p(z)}\right]
= -\mathbb{E}_{q_\phi(z)}\!\left[\log \frac{q_\phi(z)}{\prod_d q_\phi(z_d)}
+ \log \frac{\prod_d q_\phi(z_d)}{\prod_d p(z_d)}
+ \log \frac{\prod_d p(z_d)}{p(z)}\right] \\
&= \underbrace{\mathbb{E}_{q_\phi(z)}\!\left[\log \frac{p(z)}{\prod_d p(z_d)}
- \log \frac{q_\phi(z)}{\prod_d q_\phi(z_d)}\right]}_{A}
- \underbrace{\sum_d \mathrm{KL}\left(q_\phi(z_d) \,\|\, p(z_d)\right)}_{B}.
\end{aligned}
$$

$$
-\mathrm{KL}\left(q_\phi(z_d) \,\|\, p(z_d)\right)
= \underbrace{\mathbb{E}_{q_\phi(z_d)}\!\left[\log \frac{p(z_d)}{\prod_e p(z_{d,e})}
- \log \frac{q_\phi(z_d)}{\prod_e q_\phi(z_{d,e})}\right]}_{i}
- \underbrace{\sum_e \mathrm{KL}\left(q_\phi(z_{d,e}) \,\|\, p(z_{d,e})\right)}_{ii}.
$$

Figure 4: Hierarchical KL decomposition. We can decompose term 4 into subcomponents A and B. Term A matches the total correlation between variables in the inference model to the total correlation under the generative model. Term B minimizes the KL divergence between the inference marginal and prior marginal for each variable zd. When the variables zd contain sub-variables zd,e, we can recursively decompose the KL on the marginals zd into term i, which matches the within-group total correlations, and term ii, which minimizes the per-dimension KL divergences.

it has the drawback that it also increases the strength of 2, which means that the encoder model may discard more information about x in order to minimize the mutual information I(x; z).
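For reference, the β-VAE relaxation amounts to scaling the per-datapoint KL term by β, which in the notation of Figure 2 weights 2 and 4 jointly; a minimal sketch follows, assuming distribution-returning encoder and decoder modules as in the earlier ELBO sketch.

from torch.distributions import kl_divergence

def beta_vae_objective(encoder, decoder, x, prior, beta=4.0):
    """Sketch of the beta-VAE objective E[log p(x | z)] - beta * KL(q(z | x) || p(z)),
    i.e. 1 + 3 + beta * (2 + 4) in the notation of Figure 2."""
    q_z = encoder(x)
    z = q_z.rsample()
    log_px_z = decoder(z).log_prob(x).sum(-1)
    kl = kl_divergence(q_z, prior).sum(-1)
    return (log_px_z - beta * kl).mean()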

4.1 KL decomposition

Looking at the β-VAE objective, it seems intuitive that increasing the weight of term 4 is likely to aid disentanglement. One notion of disentanglement is that there should be a low degree of correlation between different latent variables zd. If we choose a prior p(z) = ∏d p(zd) in which different variables are independent, then minimizing the KL term should induce an inference marginal qφ(z) = ∏d qφ(zd) in which the zd are also independent.

As also noted by Kim and Mnih [2018] and Chen et al. [2018], we can introduce an additional level of decomposition for 4 in the objective (Figure 4). As with term 1 + 2, the term A consists of two components. The second of these terms minimizes the total correlation (TC), which is a generalization of the mutual information to more than two variables,

$$
\mathrm{TC}(z)
= \mathbb{E}_{q_\phi(z)}\!\left[\log \frac{q_\phi(z)}{\prod_d q_\phi(z_d)}\right]
= \mathrm{KL}\Big(q_\phi(z) \,\Big\|\, \prod_d q_\phi(z_d)\Big).
\tag{8}
$$

Minimizing the total correlation means that we will learn a qφ(z) in which different zd are statistically independent, thereby inducing disentanglement.
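To make this term concrete, the following sketch estimates the total correlation in Equation (8) from a minibatch by approximating the marginals qφ(z) and qφ(zd) with simple averages of encoder densities over the batch; the estimator we actually use (Section 4.2, Equation (13)) additionally reweights the term for the batch item that generated each sample. The [B, B, D] tensor of encoder log-densities is assumed to be computed elsewhere.

import math
import torch

def minibatch_total_correlation(log_qz_matrix):
    """Minibatch estimate of TC(z) = E_q[log q(z) - sum_d log q(z_d)] from Eq. (8).

    log_qz_matrix[b, b', d] = log q_phi(z_d^(b) | x^(b')): the log-density of
    dimension d of the b-th latent sample under the encoder conditioned on the
    b'-th batch item. Marginals are approximated by averaging over b'."""
    B = log_qz_matrix.size(0)
    log_B = math.log(B)
    # log q(z^(b)) ~= logsumexp_{b'} sum_d log q(z_d^(b) | x^(b')) - log B
    log_qz = torch.logsumexp(log_qz_matrix.sum(dim=-1), dim=1) - log_B
    # sum_d log q(z_d^(b)) ~= sum_d [logsumexp_{b'} log q(z_d^(b) | x^(b')) - log B]
    log_qz_product = (torch.logsumexp(log_qz_matrix, dim=1) - log_B).sum(dim=-1)
    return (log_qz - log_qz_product).mean()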

The first component of A is not present in the objectives defined by Kim and Mnih [2018] and Chen et al. [2018]. It maximizes the probability of z under the prior p(z), relative to its probability under the product of marginals ∏d p(zd). Note that this form is analogous to 1, which maximizes pθ(x, z) relative to the product of marginals pθ(x)p(z). Maximizing A with respect to φ will match the total correlation in qφ(z) to the total correlation in p(z). When p(z) = ∏d p(zd) we recover the term from Kim and Mnih [2018] and Chen et al. [2018].

In cases where zd itself represents a group of variables, rather than a single variable, we can continue to decompose into another set of terms i and ii, which match the total correlation for zd and the KL divergences for the constituent variables zd,e. We could in principle continue this decomposition for any number of levels. This provides an opportunity to induce hierarchies of disentangled features. For example, when modeling MNIST we may wish to emphasize the constraint that the digit identity (which is a discrete variable) should be uncorrelated with variables that characterize the handwriting style. That said, for other variables it is less obvious that there should be no correlations. For example, there could be a causal relationship between the pen stroke width and the size of openings. In this case we may want to impose a higher level of regularization on the correlation between digit identity and style variables than on the correlation between style variables themselves.

4.2 Approximation of the Objective

In order to induce hierarchically factored representations we will employ an objective of the form

L(θ, φ) = 1 + 3 + ii + α · 2 + β · A + γ · i. (9)
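The sketch below shows one way this objective can be assembled from per-sample log-density estimates; the input names, the default weights, and the reduction of 1 + 3 to the reconstruction term log pθ(x | z) (up to a constant) follow the discussion above, while the inference marginals are assumed to be approximated as described below. This is an illustrative assembly, not the implementation used for the experiments.

def hfvae_objective(log_px_z, log_qz_x, log_qz, log_qz_groups, log_qz_dims,
                    log_pz, log_pz_groups, log_pz_dims,
                    alpha=1.0, beta=5.0, gamma=3.0):
    """Sketch of Eq. (9). All inputs are per-sample estimates of shape [batch]:
      log_px_z       log p(x | z)          (terms 1 + 3, up to a constant)
      log_qz_x       log q(z | x)
      log_qz         log q(z)              (estimated, e.g. as in Eq. (13))
      log_qz_groups  sum_d log q(z_d)
      log_qz_dims    sum_{d,e} log q(z_{d,e})
      log_pz, log_pz_groups, log_pz_dims: the same quantities under the prior."""
    term_2 = -(log_qz_x - log_qz)                                  # mutual information term
    term_A = (log_pz - log_pz_groups) - (log_qz - log_qz_groups)   # between-group TC matching
    term_i = (log_pz_groups - log_pz_dims) - (log_qz_groups - log_qz_dims)  # within-group TC
    term_ii = log_pz_dims - log_qz_dims                            # per-dimension KL terms
    return (log_px_z + term_ii + alpha * term_2
            + beta * term_A + gamma * term_i).mean()

With α = β = γ = 1 this assembly reduces, up to constants, to the standard VAE objective in Equation (4), which provides a useful sanity check.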

Paper                                     Objective
Kingma and Welling [2013];                1 + 2 + 3 + 4
  Rezende et al. [2014]
Higgins et al. [2016]                     1 + 3 + β(2 + 4)
Kumar et al. [2017]                       1 + 2 + 3 + λ · 4
Zhao et al. [2017]                        1 + 3 + λ · 4
Gao et al. [2018]                         1 + 2 + 3 + 4 − λ · 2_a
Achille and Soatto [2018]                 1 + 3 + β · 2 + γ · A*
Kim and Mnih [2018];                      1 + 2 + 3 + B + β · A
  Chen et al. [2018]
HFVAE (this paper)                        1 + 3 + ii + α · 2 + β · A + γ · i

Table 1: Comparison of objectives in autoencoding deep generative models. The objective in this paper is most closely related to recent work by Kim and Mnih [2018] and Chen et al. [2018], but incorporates an additional level of hierarchical decomposition. The asterisk A* indicates that the prior factorizes, i.e. p(z) = ∏d p(zd). The notation 2_a refers to restriction of the mutual information 2 to a subset of "Anchor" variables z_a.

In this objective α controls the amount of I(x; z) regularization. We include it for completeness, although we find in practice that I(x; z) saturates in the cases that we consider in our experimental evaluation. The term β controls the TC regularization between groups of variables, whereas γ controls the TC regularization for individual variables.

In order to optimize this objective, we need to approximate the inference marginals qφ(z), qφ(zd), and qφ(zd,e). Computing these quantities exactly requires a full pass over the dataset, since qφ(z) is a mixture over all data points in the training set,

$$
q_\phi(z) = \frac{1}{N} \sum_{n=1}^{N} q_\phi(z \mid x_n).
\tag{10}
$$

We approximate qφ(z) with a Monte Carlo estimate of qφ(z) over the same batch of samples that we use to approximate all other terms in the objective L(θ, φ). For simplicity we will consider the term

$$
\mathbb{E}_{q_\phi(z,x)}\left[\log q_\phi(z)\right]
\simeq \frac{1}{B} \sum_{b=1}^{B} \log q_\phi(z^{(b)}),
\tag{11}
$$

where z^(b) is sampled via the normal construction, by selecting a batch of items x^(b) and then sampling from the inference model

$$
x^{(b)} \sim q(x), \qquad z^{(b)} \sim q_\phi(z \mid x^{(b)}).
\tag{12}
$$

We define the estimate q̂φ(z^(b)) of qφ(z^(b)) as (see Appendix A.1)

$$
\hat{q}_\phi(z^{(b)})
= \frac{1}{N}\, q_\phi(z^{(b)} \mid x^{(b)})
+ \frac{N-1}{N(B-1)} \sum_{b' \neq b} q_\phi(z^{(b)} \mid x^{(b')}).
\tag{13}
$$

We can think of this approximation as a partially stratified sample, in which we deterministically include the term xn = x^(b) and compute a Monte Carlo estimate over the remaining terms, treating samples x^(b') for b' ≠ b as samples from the distribution q(x | x ≠ x^(b)).
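A minimal sketch of this stratified estimate is given below, assuming that the [B, B] matrix of encoder log-densities log qφ(z^(b) | x^(b')) has been computed elsewhere; the diagonal entries receive weight 1/N and the off-diagonal entries weight (N − 1)/(N(B − 1)), as in Equation (13).

import math
import torch

def log_q_marginal_estimate(log_qz_matrix, N):
    """Stratified estimate log q_hat(z^(b)) of Eq. (13) for every b in the batch.

    log_qz_matrix[b, b'] = log q_phi(z^(b) | x^(b')); N is the dataset size."""
    B = log_qz_matrix.size(0)
    log_w = torch.full((B, B), math.log((N - 1) / (N * (B - 1))))
    log_w.fill_diagonal_(math.log(1.0 / N))
    # q_hat(z^(b)) = sum_{b'} w_{b,b'} q(z^(b) | x^(b')), computed in log space
    return torch.logsumexp(log_w + log_qz_matrix, dim=1)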

We now substitute log q̂φ(z) for log qφ(z) in Equation (11). By Jensen's inequality this yields a lower bound on our original expectation, since

$$
\mathbb{E}\left[\frac{1}{B} \sum_{b=1}^{B} \log \hat{q}_\phi(z^{(b)})\right]
\leq \frac{1}{B} \sum_{b=1}^{B} \log \mathbb{E}\left[\hat{q}_\phi(z^{(b)})\right].
$$

The fact that our approximation is a lower bound means that minimizing ii is equivalent to maximizing an upper bound (note that 4 = −KL(qφ(z) || p(z))). That said, the estimator is consistent, which means that in practice the bias is likely sufficiently small given the batch sizes that are needed to approximate the inference marginal (512–1024).

5 Related Work

This work is closely related to a number of recently proposed generalizations of VAE objectives, which we summarize in Table 1. As noted above, the β-VAE objective [Higgins et al., 2016] uses a multiplier β > 1 for the terms 2 and 4. In settings where we are not primarily interested in inducing disentangled representations, the β-VAE objective has also been used with β < 1 in order to increase the quality of reconstructions (see, e.g., [Alemi et al., 2016; Engel et al., 2017; Liang et al., 2018]). This also decreases the relative weight of 2, but in practice this does not influence the learned representation in cases where I(x; z) saturates anyway.

Zhao et al. [2017] consider an objective that eliminates the mutual information 2 entirely and assigns an additional weight to the KL divergence between qφ(z) and p(z) in 4. Kumar et al. [2017] approximate the KL divergence in 4 by matching the covariance of qφ(z) and

Figure 5: MNIST and FMNIST reconstructions for zd values ranging from −3 to 3. Rows contain both different samples from the dataset and different dimensions d.

p(z). Recent work by Gao et al. [2018] connects VAEs to the principle of correlation explanation, and defines an objective that reduces the mutual information regularization in 2 for a subset of "Anchor" variables z_a. Achille and Soatto [2018] approach VAEs from an information bottleneck perspective and introduce a TC term into the objective.

The KL decomposition in Figure 4 is very similar to the one that was recently introduced by Kim and Mnih [2018] and Chen et al. [2018]. It induces disentangled representations by increasing the weight of A to minimize the total correlation. Relative to these two approaches, we here take a slightly more general perspective:

1. We impose a hierarchical structure on our model by identifying groups of variables zd, which allows us to control the weight of A relative to the weight of term i to ensure that the total correlation between groups of variables is enforced more rigorously than the total correlation within groups of variables.

Figure 6: The HFVAE objective prunes unnecessary features. Shown is the mutual information 2 for each individual dimension, I(x; zd), for informative dimensions (left) and unused dimensions (right). For uninformative dimensions the mutual information decreases to 0.

Model                         FactorVAE   Eastwood
VAE                           0.71        0.40
β-VAE (β = 4.0)               0.72        0.71
β-SVAE (β = 4.0)              0.47        0.19
β-TCVAE (β = 4.0)             0.71        0.31
HFVAE (β = 4.0, γ = 3.0)      0.78        0.74

Table 2: Disentanglement scores for the dSprites dataset [Higgins et al., 2016] using the metrics proposed by Kim and Mnih [2018] and Eastwood and Williams [2018].

2. Rather than a diagonal Gaussian prior p(z), we consider priors that combine discrete and continuous variables and can incorporate parameters pθ(z) to learn a covariance structure (for continuous variables) or conditional probabilities (for discrete variables).

6 Experiments

In order to assess how the HFVAE objective performs relative to existing approaches, we evaluate on a variety of datasets and tasks. We consider 4 datasets:

dSprites [Higgins et al., 2016]: 737,280 binary 64 × 64 images of 2D shapes with ground truth factors.

MNIST [LeCun et al., 2010]: 60,000 gray-scale 32 × 32 images of handwritten digits.

Fashion-MNIST [Xiao et al., 2017]: 60,000 gray-scale 32 × 32 images of clothing items divided into 10 classes.

CelebA [Liu et al., 2015]: 202,599 RGB 64 × 64 × 3 images of celebrity faces.

On these datasets we compare a number of objectives and priors. The objectives we compare are the standard VAE objective [Kingma and Welling, 2013; Rezende et al., 2014], the β-VAE objective [Higgins et al., 2016], the

Figure 7: Manipulation of the thickness variable over the range −3 to 3 (columns: Input; β-VAE, β = 4; β-SVAE, β = 4; HFVAE, β = 12, γ = 4; HFVAE samples from the prior). The β-VAE is not able to maintain digit identity as we vary thickness. The SVAE and HFVAE, which incorporate a discrete variable into the prior, are able to maintain the digit identity across the entire range.

Figure 8: Interpretable factors in CelebA for a HFVAE (β = 5.0, γ = 3.0) and a β-VAE (β = 8.0). Columns: Orientation, Smiling, and Sunglasses (z16, z15, and z12 varying, respectively); rows: HFVAE and β-VAE.

TC-VAE objective [Chen et al., 2018], and our HFVAE objective. The priors p(z) that we employ are:

VAE, β-VAE, TCVAE: A diagonal Gaussian prior with 10 dimensions (dSprites, MNIST, FMNIST), or 20 dimensions (CelebA).

β-SVAE, HFVAE: A mixed Gaussian-discrete prior. MNIST, FMNIST: one 10-dimensional Gaussian variable and one discrete variable with 10 classes. dSprites: one 10-dimensional Gaussian variable and one discrete variable with 3 classes. CelebA: one 20-dimensional Gaussian variable and 2 Bernoulli variables.
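As an illustration of the mixed prior, the sketch below draws joint samples of a Gaussian style variable and a Concrete (relaxed one-hot categorical) identity variable in the MNIST/FMNIST configuration; the temperature value is an arbitrary choice for this illustration and is not taken from our experimental setup.

import torch
from torch.distributions import Normal, RelaxedOneHotCategorical

def sample_mixed_prior(batch_size, z_dim=10, n_classes=10, temperature=0.66):
    """Draw samples from a mixed Gaussian/Concrete prior p(z) = p(z_n) p(z_c)."""
    style = Normal(torch.zeros(z_dim), torch.ones(z_dim)).sample((batch_size,))
    identity = RelaxedOneHotCategorical(
        temperature=torch.tensor(temperature),
        probs=torch.full((n_classes,), 1.0 / n_classes)).sample((batch_size,))
    return torch.cat([style, identity], dim=-1)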

We train using Adam with default parameters and batch sizes between 512 and 1024. Model architectures are described in Appendix B.

6.1 Interpretability of Features

We begin with a qualitative evaluation of the features that are identified when training with the HFVAE objective. Figure 5 shows results for the MNIST data and the FMNIST data. For the MNIST data the representation recovers 7 interpretable features. For the remaining 3 features the mutual information term I(x; zd) decreases to 0, as can be seen in Figure 6. Similarly, for the CelebA dataset (see Figure 8) we uncover interpretable features such as the orientation of the face, variation from smiling to non-smiling, and the presence of sunglasses.

6.2 Quantitative Metrics on dSprites

As a quantitative assessment of the quality of learned representations, we evaluate the metrics proposed by Kim and Mnih [2018] and Eastwood and Williams [2018]. For the Eastwood and Williams metric, we used a random forest as the regressor from features to ground-truth

Figure 9: Generalization to unseen combinations of factors (panels labeled Pruning and Generalization). An HFVAE is trained on the full dataset, after which a subset of the data is pruned. We then test generalization of a model trained on the pruned dataset to the removed portion of the data.

factors. In Table 2 we list these metrics for each of the model types and objectives defined above. The HFVAE outperforms the other approaches.
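For reference, a sketch of the regressor step of the Eastwood and Williams metric follows: one random forest is fit per ground-truth factor and the per-latent feature importances are collected into an importance matrix. The use of scikit-learn, the hyperparameters, and the omission of the final disentanglement, completeness, and informativeness scores are simplifications relative to the metric's full definition.

import numpy as np
from sklearn.ensemble import RandomForestRegressor

def importance_matrix(latents, factors):
    """latents: [n_samples, n_latents]; factors: [n_samples, n_factors].
    Returns R[i, k], the importance of latent i for predicting factor k."""
    R = np.zeros((latents.shape[1], factors.shape[1]))
    for k in range(factors.shape[1]):
        regressor = RandomForestRegressor(n_estimators=100)
        regressor.fit(latents, factors[:, k])
        R[:, k] = regressor.feature_importances_
    return R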

6.3 Unsupervised Learning of Discrete Labels

The ability of the model to disentangle discrete variables depends on how separable the classes of objects are for a given dataset. In the case of Fashion-MNIST, the classes (i.e. the clothing items) are distinctive enough that our model can capture them in the discrete variable. On the other hand, in the case of the dSprites dataset, the shapes are very similar, which makes it difficult for our model to capture them effectively in a discrete variable. The MNIST dataset resides in the middle of this spectrum (for example, it is possible to convert a 9 into a 4 under some continuous variation).

In Figure 7, we sample one data point for each digit and vary its thickness variable for each of the models: β-VAE, β-structured VAE (β-SVAE), and HFVAE. Clearly, a structured VAE does a better job at disentangling digit versus thickness. The major advantage of the HFVAE compared to the β-structured VAE is the ability to generate samples from the prior which capture the digit. We found that the β-structured VAE failed to disentangle the digit from style, and the samples generated from the discrete prior do not correspond to a distinct digit.

6.4 Zero-shot Generalization

One of the desiderata of disentangled representations is that they should not only capture distinct interpretable factors of variation, but also allow generalization to previously unseen combinations of features. For example, we can imagine a pink elephant even when we have not encountered such an object before.

To test whether our learned representations exhibit these generalization properties, we perform the following experiment. We first train the HFVAE on the MNIST dataset. We then prune some fraction of the data. We will here consider the case where we prune 7s with a high stroke thickness and 0s with a narrow character width. We then train a new HFVAE model on the pruned dataset and use the removed portion of the data as a test set to evaluate generalization properties.

Figure 9 shows the results of this experiment. As we can see, the model trained on pruned data is able to reconstruct digits with values for the stroke and character width that were never seen during training. The histograms for the feature values show that the encoder model is similarly able to extract features from previously unseen examples.

7 Discussion

Much of the work on learning disentangled representations thus far has focused on cases where the distinct

factors of variation take the form of a set of scalar variables that are uncorrelated and for which supervision could in principle be obtained. As we begin to apply these techniques to real world datasets, we are likely to encounter many forms of correlations between latent variables, particularly when there are causal dependencies between these variables. This work presents a first step in the direction of enabling the learning of more structured disentangled representations. By enforcing statistical independence between groups of variables, we are now, in principle, able to disentangle variables that have higher-dimensional representations. An avenue of future work is to develop datasets that allow us to more rigorously test our ability to extract such higher-dimensional variables.

References

A. Achille and S. Soatto. Information Dropout: Learning Optimal Representations Through Noisy Computation. IEEE Transactions on Pattern Analysis and Machine Intelligence, PP(99):1-1, 2018.

Alexander A. Alemi, Ian Fischer, Joshua V. Dillon, and Kevin Murphy. Deep Variational Information Bottleneck. arXiv:1612.00410 [cs, math], December 2016.

Diane Bouchacourt, Ryota Tomioka, and Sebastian Nowozin. Multi-level variational autoencoder: Learning disentangled representations from grouped observations. arXiv preprint arXiv:1705.08841, 2017.

Samuel R Bowman, Luke Vilnis, Oriol Vinyals, Andrew M Dai, Rafal Jozefowicz, and Samy Bengio. Generating sentences from a continuous space. CoNLL 2016, page 10, 2016.

Tian Qi Chen, Xuechen Li, Roger Grosse, and David Duvenaud. Isolating sources of disentanglement in variational autoencoders. arXiv preprint arXiv:1802.04942, 2018.

Xi Chen, Diederik P. Kingma, Tim Salimans, Yan Duan, Prafulla Dhariwal, John Schulman, Ilya Sutskever, and Pieter Abbeel. Variational lossy autoencoder. arXiv preprint arXiv:1611.02731, 2016.

Cian Eastwood and Christopher K. I. Williams. A Framework for the Quantitative Evaluation of Disentangled Representations. In International Conference on Learning Representations, February 2018.

Jesse Engel, Matthew Hoffman, and Adam Roberts. Latent Constraints: Learning to Generate Conditionally from Unconditional Generative Models. arXiv:1711.05772 [cs, stat], November 2017.

Shuyang Gao, Rob Brekelmans, Greg Ver Steeg, and Aram Galstyan. Auto-encoding total correlation explanation. arXiv preprint arXiv:1802.05822, 2018.

Leon A Gatys, Alexander S Ecker, and Matthias Bethge. A neural algorithm of artistic style. arXiv preprint arXiv:1508.06576, 2015.

Rafael Gómez-Bombarelli, Jennifer N Wei, David Duvenaud, José Miguel Hernández-Lobato, Benjamín Sánchez-Lengeling, Dennis Sheberla, Jorge Aguilera-Iparraguirre, Timothy D Hirzel, Ryan P Adams, and Alán Aspuru-Guzik. Automatic chemical design using a data-driven continuous representation of molecules. ACS Central Science, 2018.

Ishaan Gulrajani, Kundan Kumar, Faruk Ahmed, Adrien Ali Taiga, Francesco Visin, David Vazquez, and Aaron Courville. PixelVAE: A latent variable model for natural images. In International Conference on Machine Learning, 2017.

Irina Higgins, Loic Matthey, Arka Pal, Christopher Burgess, Xavier Glorot, Matthew Botvinick, Shakir Mohamed, and Alexander Lerchner. beta-VAE: Learning basic visual concepts with a constrained variational framework. In International Conference on Learning Representations, 2016.

Matthew D. Hoffman and Matthew J. Johnson. ELBO surgery: Yet another way to carve up the variational evidence lower bound. In Workshop in Advances in Approximate Bayesian Inference, NIPS, 2016.

Matthew Johnson, David K Duvenaud, Alex Wiltschko, Ryan P Adams, and Sandeep R Datta. Composing graphical models with neural networks for structured representations and fast inference. In Advances in Neural Information Processing Systems, pages 2946-2954, 2016.

Hyunjik Kim and Andriy Mnih. Disentangling by factorising. arXiv preprint arXiv:1802.05983, 2018.

Diederik P. Kingma and Max Welling. Auto-encoding variational bayes. In International Conference on Learning Representations, 2013.

Diederik P Kingma, Shakir Mohamed, Danilo Jimenez Rezende, and Max Welling. Semi-supervised learning with deep generative models. In Advances in Neural Information Processing Systems, pages 3581-3589, 2014.

Tejas D Kulkarni, William F Whitney, Pushmeet Kohli, and Josh Tenenbaum. Deep convolutional inverse graphics network. In Advances in Neural Information Processing Systems, pages 2539-2547, 2015.

Abhishek Kumar, Prasanna Sattigeri, and Avinash Balakrishnan. Variational inference of disentangled latent concepts from unlabeled observations. arXiv preprint arXiv:1711.00848, 2017.

Matt J Kusner, Brooks Paige, and José Miguel Hernández-Lobato. Grammar variational autoencoder. In International Conference on Machine Learning, 2017.

Yann LeCun, Corinna Cortes, and CJ Burges. MNIST handwritten digit database. AT&T Labs [Online]. Available: http://yann.lecun.com/exdb/mnist, 2, 2010.

Dawen Liang, Rahul G. Krishnan, Matthew D. Hoffman, and Tony Jebara. Variational Autoencoders for Collaborative Filtering. arXiv:1802.05814 [cs, stat], February 2018.

Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of the IEEE International Conference on Computer Vision, pages 3730-3738, 2015.

Aaron van den Oord, Sander Dieleman, Heiga Zen, Karen Simonyan, Oriol Vinyals, Alex Graves, Nal Kalchbrenner, Andrew Senior, and Koray Kavukcuoglu. WaveNet: A generative model for raw audio. arXiv preprint arXiv:1609.03499, 2016.

Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of The 31st International Conference on Machine Learning, pages 1278-1286, 2014.

N Siddharth, Brooks Paige, Jan-Willem Van de Meent, Alban Desmaison, Frank Wood, Noah D Goodman, Pushmeet Kohli, and Philip HS Torr. Learning disentangled representations with semi-supervised deep generative models. In Advances in Neural Information Processing Systems, 2017.

Lucas Theis, Wenzhe Shi, Andrew Cunningham, and Ferenc Huszár. Lossy image compression with compressive autoencoders. arXiv preprint arXiv:1703.00395, 2017.

Han Xiao, Kashif Rasul, and Roland Vollgraf. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms, 2017.

Weidi Xu, Haoze Sun, Chao Deng, and Ying Tan. Variational autoencoder for semi-supervised text classification. In AAAI, pages 3358-3364, 2017.

Shengjia Zhao, Jiaming Song, and Stefano Ermon. InfoVAE: Information maximizing variational autoencoders. arXiv preprint arXiv:1706.02262, 2017.

A Appendix

A.1 Approximating the Inference Marginal

We will here derive a Monte Carlo estimator for the entropy of the marginal qφ(z) of the inference model

Hφ[z] = −Eqφ(z) [log qφ(z)] . (14)

As with other terms in the objective, we can approximate this expectation by sampling z(b) ∼ qφ(z) using,

x(b) ∼ q(x), b = 1, . . . , B, (15)

z(b) ∼ qφ(z | x(b)). (16)

We now additionally need to approximate the values

$$
\log q_\phi(z^{(b)}) = \log\left[\frac{1}{N} \sum_{n=1}^{N} q_\phi(z^{(b)} \mid x_n)\right].
\tag{17}
$$

We will do so by pulling the term for which xn = x^(b) out of the sum

$$
q_\phi(z^{(b)})
= \frac{1}{N}\, q_\phi(z^{(b)} \mid x^{(b)})
+ \frac{1}{N} \sum_{x_n \neq x^{(b)}} q_\phi(z^{(b)} \mid x_n).
$$

As also noted by Chen et al. [2018], the intuition behind this decomposition is that qφ(z^(b) | x^(b)) will in general be much larger than qφ(z^(b) | xn).

We can approximate the second term using a Monte Carlo estimate from samples x^(b,c) ∼ q(x | x ≠ x^(b)),

$$
\frac{1}{N-1} \sum_{x_n \neq x^{(b)}} q_\phi(z^{(b)} \mid x_n)
\simeq \frac{1}{C} \sum_{c=1}^{C} q_\phi(z^{(b)} \mid x^{(b,c)}).
$$

Note here that we have written 1/(N − 1) instead of 1/N in order to ensure that the sum defines an expected value over the distribution q(x | x ≠ x^(b)).

In practice, we can replace the samples x^(b,c) with the samples b' ≠ b from the original batch, which yields an estimator over C = B − 1 samples

$$
\hat{q}_\phi(z^{(b)})
= \frac{1}{N}\, q_\phi(z^{(b)} \mid x^{(b)})
+ \frac{N-1}{N(B-1)} \sum_{b' \neq b} q_\phi(z^{(b)} \mid x^{(b')}).
$$

Note that this estimator is unbiased, which is to say that

$$
\mathbb{E}\left[\hat{q}_\phi(z^{(b)})\right] = q_\phi(z^{(b)}).
\tag{18}
$$
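The following small numerical check illustrates this unbiasedness claim under purely illustrative settings: a "dataset" of N one-dimensional, unit-variance Gaussian encoder distributions with random means, a fixed anchor item x^(b), and many random batches that all contain the anchor.

import torch

torch.manual_seed(0)
N, B, trials = 50, 8, 20000
locs = torch.randn(N)                                  # encoder means, one per x_n

def q(z, mu):                                          # unit-variance Gaussian density
    return torch.exp(-0.5 * (z - mu) ** 2) / (2 * torch.pi) ** 0.5

b = 3                                                  # index of the anchor item x^(b)
z = locs[b] + torch.randn(())                          # a sample z^(b) ~ q(z | x^(b))
exact = q(z, locs).mean()                              # q(z^(b)) = (1/N) sum_n q(z^(b) | x_n)

others = torch.tensor([n for n in range(N) if n != b])
estimate = 0.0
for _ in range(trials):
    idx = others[torch.randperm(N - 1)[:B - 1]]        # remaining batch items x^(b')
    estimate += (q(z, locs[b]) / N
                 + (N - 1) / (N * (B - 1)) * q(z, locs[idx]).sum())
print(exact.item(), (estimate / trials).item())        # the two values should be close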

In order to compute the entropy, we now define an estimator Ĥφ[z], which defines an upper bound on Hφ[z],

$$
H_\phi[z] \simeq \hat{H}_\phi[z]
= -\frac{1}{B} \sum_{b=1}^{B} \log \hat{q}_\phi(z^{(b)}),
\qquad
\mathbb{E}\big[\hat{H}_\phi[z]\big] \geq H_\phi[z].
\tag{19}
$$

The upper bound relationship follows from Jensen's inequality, which states that

$$
\mathbb{E}\left[\log \hat{q}_\phi(z)\right]
\leq \log \mathbb{E}\left[\hat{q}_\phi(z)\right] = \log q_\phi(z).
\tag{20}
$$

B Model Architectures

We used two hidden variables for each of the datasets. One variable is modeled with a Normal distribution and represents style (denoted zn), and one is modeled with a Concrete distribution to detect categories (denoted zc). We used the Adam optimizer with learning rate 1e-3 and default settings.

Encoder                               Decoder
Input 28 × 28 grayscale image         Input zn ∈ R^10, zc ∈ (0, 1)^10
FC. 400 ReLU                          FC. 200 ReLU
FC. 2 × 200 ReLU, FC. 10 (zc)         FC. 400 ReLU
FC. 2 × 10 (zn)                       FC. 28 × 28 Sigmoid

Table 3: Encoder and decoder architecture for MNIST and Fashion-MNIST data.
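A possible PyTorch rendering of this architecture is sketched below; how the "2 × 200" hidden layer and the two output heads are wired together is our reading of the table rather than a detail it specifies.

import torch
import torch.nn as nn

class MnistEncoder(nn.Module):
    """Sketch of the Table 3 encoder: a 400-unit trunk, a 10-way head for z_c,
    and a 2 x 200 hidden layer feeding the 2 x 10 Gaussian parameters of z_n."""
    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 400), nn.ReLU())
        self.hidden = nn.Sequential(nn.Linear(400, 2 * 200), nn.ReLU())
        self.zc_logits = nn.Linear(400, 10)
        self.zn_params = nn.Linear(2 * 200, 2 * 10)

    def forward(self, x):
        h = self.trunk(x)
        loc, log_scale = self.zn_params(self.hidden(h)).chunk(2, dim=-1)
        return loc, log_scale, self.zc_logits(h)

class MnistDecoder(nn.Module):
    """Sketch of the Table 3 decoder: FC 200 -> FC 400 -> 28 x 28 Sigmoid."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(10 + 10, 200), nn.ReLU(),
            nn.Linear(200, 400), nn.ReLU(),
            nn.Linear(400, 28 * 28), nn.Sigmoid())

    def forward(self, zn, zc):
        return self.net(torch.cat([zn, zc], dim=-1)).view(-1, 1, 28, 28)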

Encoder                               Decoder
Input 64 × 64 binary image            Input zn ∈ R^10, zc ∈ (0, 1)^3
FC. 1200 ReLU                         FC. 400 Tanh
FC. 1200 ReLU                         FC. 1200 Tanh
FC. 2 × 400 ReLU, FC. 3 (zc)          FC. 1200 Tanh
FC. 2 × 10 (zn)                       FC. 64 × 64 Sigmoid

Table 4: Encoder and decoder architecture for dSprites data.

Encoder                                          Decoder
Input 64 × 64 RGB image                          Input zn ∈ R^20, zc ∈ {0, 1}^2
4 × 4 conv, 32 BatchNorm ReLU, stride 2          FC. 256 ReLU
4 × 4 conv, 32 BatchNorm ReLU, stride 2          FC. (4 × 4 × 64) Tanh
4 × 4 conv, 64 BatchNorm ReLU, stride 2          4 × 4 upconv, 64 BatchNorm ReLU, stride 2
4 × 4 conv, 64 BatchNorm ReLU, stride 2          4 × 4 upconv, 32 BatchNorm ReLU, stride 2
FC. 2 × 256 ReLU, 2 FC. (zc)                     4 × 4 upconv, 32 BatchNorm ReLU, stride 2
FC. 2 × 20 ReLU                                  4 × 4 upconv, 3, stride 2

Table 5: Encoder and decoder architecture for CelebA data.

C Latent Traversals

Figure 10: Qualitative results for disentanglement in the MNIST dataset. In each case, one particular zd varies from −3 to 3 while the others are fixed at 0.

Figure 11: Qualitative results for disentanglement in the Fashion-MNIST dataset. In each case, one particular zd varies from −3 to 3 while the others are fixed at 0.