Improving a Variational Autoencoder with Normalizing Flows

In order to fully grasp the concepts explained here, I strongly recommend reading my three posts on Variational Autoencoders first (in the following order).


Theory of Vanilla VAEs

Recall that in a Vanilla VAE we feed $x$ into an encoder neural network and obtain $(\mu, \log\sigma)$. These are the amortized parameters of our approximate posterior distribution

$$q_\phi(z \mid x) = \mathcal{N}\!\left(z \mid \mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x))\right)$$

To get a latent sample $z \sim q_\phi(z \mid x)$ we need to use the reparametrization trick. This requires sampling $\epsilon \sim \mathcal{N}(0, I)$ and then scaling it and shifting it according to the output of the neural network

$$z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon.$$
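
As a concrete illustration, here is a minimal PyTorch sketch of the reparametrization trick, assuming the encoder outputs tensors mu and logvar (i.e. $\log\sigma^2$); the function name is mine, purely for illustration.

import torch

def reparametrize(mu, logvar):
  """Sample z = mu + sigma * eps with eps ~ N(0, I) (illustrative sketch)."""
  std = torch.exp(0.5 * logvar)     # recover sigma from log(sigma^2)
  eps = torch.randn_like(std)       # eps ~ N(0, I), same shape as sigma
  return mu + std * eps             # shift and scale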

To learn the parameters of our neural network, we aim to maximize the ELBO

$$\mathcal{L}_{\phi,\theta}(x) = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - \operatorname{KL}\!\left(q_\phi(z \mid x) \,\|\, p(z)\right)$$

The reconstruction error (the first term) is easy to compute in the Normal and Bernoulli cases. In what follows, we will assume that the likelihood is a product of Bernoullis, which is the usual set-up when working with MNIST. The likelihood is then

$$p_\theta(x \mid z) = \prod_{i=1}^{\dim(x)} p_i^{x_i} (1 - p_i)^{1 - x_i}$$

where $p = (p_1, \ldots, p_{\dim(x)})$ is the output of the decoder network, and $p \in [0, 1]^{\dim(x)}$. The reconstruction error can then be written as

$$\begin{aligned}
\mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right]
&= \mathbb{E}_{q_\phi(z \mid x)}\left[\log \prod_{i=1}^{\dim(x)} p_i(z)^{x_i} (1 - p_i(z))^{1 - x_i}\right] \\
&= \mathbb{E}_{q_\phi(z \mid x)}\left[\sum_{i=1}^{\dim(x)} x_i \log p_i(z) + (1 - x_i)\log(1 - p_i(z))\right] \\
&\approx \frac{1}{n_z}\sum_{j=1}^{n_z} \sum_{i=1}^{\dim(x)} x_i \log p_i(z^{(j)}) + (1 - x_i)\log(1 - p_i(z^{(j)})), \qquad z^{(j)} \sim q_\phi(z \mid x)
\end{aligned}$$

where $n_z$ is the number of samples we draw from $q_\phi(z \mid x)$. Usually we simply set $n_z = 1$, meaning we only sample one latent variable per datapoint. The objective function to minimize (I have flipped the sign) is therefore

$$\begin{aligned}
-\mathcal{L}_{\phi,\theta}(x)
&= -\sum_{i=1}^{\dim(x)} \left[x_i \log p_i(z) + (1 - x_i)\log(1 - p_i(z))\right] - \frac{1}{2}\sum_{j=1}^{\dim(z)}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right) \\
&= \operatorname{BCE}(p, x) - \frac{1}{2}\sum_{j=1}^{\dim(z)}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
\end{aligned}$$

Using PyTorch we can code it as

import torch
import torch.nn.functional as F

def vae_loss(image, reconstruction, mu, logvar):
  """Loss for the Variational AutoEncoder."""
  # Compute the binary cross-entropy (the reconstruction term).
  recon_loss = F.binary_cross_entropy(
      input=reconstruction.view(-1, 28*28),    # input is p(z) (the mean reconstruction)
      target=image.view(-1, 28*28),            # target is x   (the true image)
      reduction='sum'
  )
  # Compute the KL divergence between N(mu, diag(sigma^2)) and N(0, I) in closed form.
  kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
  return recon_loss + kl
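
For context, here is a hypothetical end-to-end usage of vae_loss with a toy encoder and decoder; the architecture, layer sizes, and variable names below are illustrative assumptions, not the model from this post.

import torch
import torch.nn as nn

# Toy modules (assumed sizes: 784 pixels, 20 latent dimensions).
encoder = nn.Linear(784, 2 * 20)                          # outputs (mu, logvar) stacked
decoder = nn.Sequential(nn.Linear(20, 784), nn.Sigmoid()) # outputs p in [0, 1]^784

image = torch.rand(32, 1, 28, 28)                         # dummy batch of images
mu, logvar = encoder(image.view(-1, 784)).chunk(2, dim=1)
z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)   # reparametrization trick
reconstruction = decoder(z)

loss = vae_loss(image, reconstruction, mu, logvar)
loss.backward()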

VAE with Normalizing Flows

This time, we not only want our encoder to output $(\mu, \log\sigma)$ to shift and scale $\epsilon \sim \mathcal{N}(0, I)$. We also want to feed

$$\mathcal{N}\!\left(z \mid \mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x))\right)$$

through a series of $K$ transformations, each of them depending on a set of parameters $\lambda_k$. Denoting $\lambda = (\lambda_1, \ldots, \lambda_K)$, we essentially want our Encoder to work as follows:

$$x \;\xrightarrow{\;\text{Encoder}\;}\; (\mu, \log\sigma, \lambda_1, \ldots, \lambda_K) = (\mu, \log\sigma, \lambda)$$
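
A minimal sketch of such an encoder is below, assuming a single hidden layer and a flat vector of flow parameters; the layer sizes and the class name FlowEncoder are illustrative, not the exact architecture used in this post.

import torch.nn as nn

class FlowEncoder(nn.Module):
  """Encoder returning (mu, logvar, lambda) for a flow-based posterior (sketch)."""
  def __init__(self, x_dim=784, h_dim=400, z_dim=20, flow_param_dim=50):
    super().__init__()
    self.hidden = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
    self.mu = nn.Linear(h_dim, z_dim)                     # mu
    self.logvar = nn.Linear(h_dim, z_dim)                 # log sigma^2
    self.flow_params = nn.Linear(h_dim, flow_param_dim)   # lambda = (lambda_1, ..., lambda_K)

  def forward(self, x):
    h = self.hidden(x.flatten(start_dim=1))
    return self.mu(h), self.logvar(h), self.flow_params(h)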

Then we would first use $(\mu, \log\sigma)$ to compute $z_0$ using the reparametrization trick

$$z_0 = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

and finally we would feed $z_0$ into the series of transformations to reach the final $z_K$

$$z_K = f_K \circ f_{K-1} \circ \cdots \circ f_2 \circ f_1(z_0).$$

This means that our approximating distribution is not

$$q_\phi(z \mid x) = \mathcal{N}\!\left(z \mid \mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x))\right)$$

anymore but, rather, can be found using the usual change-of-variable formula

$$\log q_\phi(z \mid x) = \log q_K(z_K) = \log q_0(z_0) - \sum_{k=1}^{K} \ln\left|\det \frac{\partial f_k}{\partial z_{k-1}}\right|$$

where the base distribution $q_0(z_0)$ is the old approximate posterior distribution $q_0(z_0) = \mathcal{N}\!\left(z_0 \mid \mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x))\right)$.
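
Concretely, one way to implement this (a sketch under the assumption that each flow layer returns both $z_k$ and its log-absolute-determinant term) is to compose the transformations in a loop and accumulate the sum:

def apply_flows(z0, flows, log_q0):
  """Push z_0 through f_1, ..., f_K and track log q_K(z_K) (illustrative sketch)."""
  z, log_q = z0, log_q0
  for flow in flows:
    z, log_det = flow(z)        # each flow is assumed to return (z_k, ln|det df_k/dz_{k-1}|)
    log_q = log_q - log_det     # change-of-variable formula
  return z, log_q               # z_K and log q_K(z_K)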

Thanks to the law of the unconscious statistician, we have

$$\begin{aligned}
\mathcal{L}_{\phi,\theta}(x)
&= \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - \operatorname{KL}\!\left(q_\phi(z \mid x) \,\|\, p(z)\right) \\
&= \mathbb{E}_{q_K(z_K)}\left[\log p_\theta(x \mid z_K)\right] - \mathbb{E}_{q_K(z_K)}\left[\log q_K(z_K) - \log p(z_K)\right] \\
&= \mathbb{E}_{q_0(z_0)}\left[\log p_\theta(x \mid z_K)\right] - \mathbb{E}_{q_0(z_0)}\left[\log q_K(z_K) - \log p(z_K)\right]
\end{aligned}$$

As usual, we can approximate this using Monte Carlo, and generally we only need one sample. By drawing $z_0 \sim q_0(z_0) = \mathcal{N}\!\left(\mu, \operatorname{diag}(\sigma^2)\right)$ we can approximate the ELBO as follows

$$\mathcal{L}_{\phi,\theta}(x) \approx \left[\sum_{i=1}^{\dim(x)} x_i \log p_i(z_K) + (1 - x_i)\log(1 - p_i(z_K))\right] - \log q_K(z_K) + \log p(z_K).$$

This means that our objective function is given by

$$-\mathcal{L}_{\phi,\theta}(x) = \operatorname{BCE}(p, x) + \log q_0(z_0) - \operatorname{LADJ} - \log p(z_K)$$

where the Log-Absolute-Determinant-Jacobian is the usual

$$\operatorname{LADJ} = \sum_{k=1}^{K} \ln\left|\det \frac{\partial f_k}{\partial z_{k-1}}\right|$$
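
Putting the pieces together, a minimal PyTorch sketch of this objective could look as follows, assuming a standard Normal prior $p(z) = \mathcal{N}(0, I)$ and that log_q0 and ladj have already been computed per batch element (for example with apply_flows above); all names here are illustrative, not from a particular library.

import math
import torch
import torch.nn.functional as F

def flow_vae_loss(image, reconstruction, log_q0, ladj, z_k):
  """Negative ELBO for a VAE with a normalizing-flow posterior (illustrative sketch)."""
  # BCE(p, x): same reconstruction term as in the vanilla VAE.
  bce = F.binary_cross_entropy(
      input=reconstruction.view(-1, 28*28),
      target=image.view(-1, 28*28),
      reduction='sum'
  )
  # log p(z_K) under a standard Normal prior N(0, I), summed over the batch.
  log_prior = -0.5 * torch.sum(z_k.pow(2) + math.log(2 * math.pi))
  # -L = BCE(p, x) + log q_0(z_0) - LADJ - log p(z_K)
  return bce + log_q0.sum() - ladj.sum() - log_prior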
