Variational Autoencoder with Planar Flows

$$ \newcommand{\vect}[1]{\boldsymbol{\mathbf{#1}}} \newcommand{\vx}{\vect{x}} \newcommand{\vz}{\vect{z}} \newcommand{\vphi}{\vect{\phi}} \newcommand{\vtheta}{\vect{\theta}} \newcommand{\vmu}{\vect{\mu}} \newcommand{\vsigma}{\vect{\sigma}} \newcommand{\N}{\mathcal{N}} \newcommand{\encoder}{q_{\vphi}(\vz \mid \vx)} \newcommand{\vepsilon}{\vect{\epsilon}} \newcommand{\snd}{\N(\vect{0}, \vect{I})} \newcommand{\muz}{\vmu_{\vphi}(\vx)} \newcommand{\sigmaz}{\vsigma^2_{\vphi}(\vx)} \newcommand{\elbo}{\mathcal{L}_{\vphi, \vtheta}(\vx)} \newcommand{\Ebb}{\mathbb{E}} \newcommand{\eencoder}[1]{\Ebb_{\encoder}\left[#1\right]} \newcommand{\decoder}{p_{\vtheta}(\vx \mid \vz)} \newcommand{\kl}[2]{\text{KL}\left(#1 \parallel #2\right)} \newcommand{\prior}{p(\vz)} \newcommand{\vlambda}{\vect{\lambda}} \newcommand{\vw}{\vect{w}} \newcommand{\vu}{\vect{u}} \newcommand{\Eqk}[1]{\Ebb_{q_K(\vz_K)}\left[#1\right]} \newcommand{\vuhat}{\widehat{\vu}} $$

Formulas

The Planar Flow that we are going to use is given by

$$ f(\vz_{k-1}) = \vz_{k-1} + \vu_k \tanh(\vw^\top_k \vz_{k-1} + b_k) $$

where $\vlambda_k := \left\{\vw_k, \vu_k, b_k\right\}$ are the flow parameters. It is shown in the paper that the log-absolute-determinant of the Jacobian (LADJ) of a single transformation is $\log|1 + \vu^\top_k \psi_k(\vz_{k-1})|$, where $\psi_k(\vz) = (1 - \tanh^2(\vw^\top_k \vz + b_k))\vw_k$. Since $\log q_K(\vz_K) = \log q_0(\vz_0) - \sum^K_{k=1}\log|\det J_{f_k}|$, the correction that enters the objective for $K$ transformations is

$$ \text{LADJ} = -\sum^K_{k=1} \log |1 + \vu^\top_k(1 - \tanh^2(\vw^\top_k \vz_{k-1} + b_k))\vw_k|. $$
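As a quick sanity check (my own, not from the paper), we can compare this analytic quantity for a single transformation against a Jacobian computed by autograd in two dimensions:

import torch

torch.manual_seed(0)
w, u, b = torch.randn(2), torch.randn(2), torch.randn(())

def f(z):
    # One planar transformation: f(z) = z + u * tanh(w^T z + b)
    return z + u * torch.tanh(w @ z + b)

z = torch.randn(2)
J = torch.autograd.functional.jacobian(f, z)       # (2, 2) Jacobian of one step
_, ladj_autograd = torch.slogdet(J)                # log |det J|

psi = (1 - torch.tanh(w @ z + b)**2) * w           # psi(z) = h'(w^T z + b) w
ladj_analytic = torch.log(torch.abs(1 + u @ psi))
print(ladj_autograd.item(), ladj_analytic.item())  # the two values agree

The agreement follows from the matrix determinant lemma: $\det(\vect{I} + \vu_k \psi_k^\top) = 1 + \psi_k^\top \vu_k$.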

However, to make sure that the transformations are actually invertible, we need to replace each $\vu_k$ with the vector

$$ \vuhat_k = \vu_k + [-1 + \log(1 + e^{\vw^\top_k\vu_k}) - \vw^\top_k\vu_k]\frac{\vw_k}{\|\vw_k\|^2} $$

where $\vu_k$ is output by our encoder neural network. This reparametrization guarantees $\vw^\top_k \vuhat_k = -1 + \log(1 + e^{\vw^\top_k \vu_k}) > -1$, which is exactly the condition required for the planar transformation to be invertible.
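As another quick check (again mine, not from the paper), we can verify this condition numerically for random parameters:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
w, u = torch.randn(2), torch.randn(2)

wu = w @ u
uhat = u + (-1 + F.softplus(wu) - wu) * w / w.norm()**2

# By construction w^T uhat = -1 + softplus(w^T u), which is always > -1
print(w @ uhat, -1 + F.softplus(wu))  # identical values

With $\vuhat_k$ in place of $\vu_k$, the objective function becomes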

$$ \begin{align} -\elbo &\approx -\left[\sum^{\text{dim}(\vx)}_{i=1} x_i\log p_i(\vz_K) + (1 - x_i)\log(1 - p_i(\vz_K)) \right] \newline &\quad -\frac{\text{dim}(\vz)}{2}\log(2\pi) - \frac{1}{2}\sum^{\text{dim}(\vz)}_{i=1} \log \sigma^2_i - \frac{1}{2}(\vz_0 - \vmu)^\top \text{Diag}\left(\frac{1}{\vsigma^2}\right)(\vz_0 - \vmu) \newline &\quad -\sum^K_{k=1} \log |1 + \vuhat^\top_k(1 - \tanh^2(\vw^\top_k \vz_{k-1} + b_k))\vw_k| \newline &\quad +\frac{\text{dim}(\vz)}{2}\log(2\pi) + \frac{1}{2}\vz_K^\top\vz_K \end{align} $$

where the expectation is approximated with a single sample $\vz_0 = \vmu + \vsigma \odot \vepsilon$, $\vepsilon \sim \snd$. The four lines are, in order, the Bernoulli reconstruction term $-\log \decoder$, the initial log-density $\log q_0(\vz_0)$, the flow correction (the LADJ term above), and the negative log-prior $-\log \prior$ evaluated at $\vz_K$.

Coding

In Python (using PyTorch) we can code this as follows.

import numpy as np                             
import torch                                   
import torchvision                            
import torchvision.transforms as transforms    
from torch.utils.data import DataLoader        
from torchvision.utils import make_grid       
from torchvision.datasets import MNIST       
import matplotlib.pyplot as plt               
import torch.nn as nn                         
import torch.nn.functional as F                
import torch.optim as optim                   
from torch.distributions import MultivariateNormal
from math import log, pi
from matplotlib.colors import ListedColormap

# Neural Network Architecture               
e_hidden = 500       
d_hidden = 500        
latent_dim = 2
K = 3                    # Normalizing Flow depth

# Optimizer 
learning_rate = 0.001 
weight_decay = 1e-5 

# Learning
epochs = 100
batch_size = 100

# Use GPU/CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Prepare data
t = transforms.Compose([transforms.ToTensor()])

# Use transformation for both training and test set
trainset = MNIST(root='./data', train=True, download=True, transform=t)
testset  = MNIST(root='./data', train=False, download=True, transform=t)

# Load train and test set
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader  = DataLoader(testset, batch_size=batch_size, shuffle=True)
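Before defining the model it can be useful to confirm the batch shapes coming out of the loader; the expected sizes below assume batch_size = 100 as set above.

images, labels = next(iter(trainloader))
print(images.shape, labels.shape)   # torch.Size([100, 1, 28, 28]) torch.Size([100])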

## Main VAE class
class VAE_PF_amortized(nn.Module):
    def __init__(self, e_hidden, d_hidden, latent_dim, K=K):
        """VAE with Planar Flows - Define layers of the architecture."""
        super(VAE_PF_amortized, self).__init__()
        # Encoding Layers
        self.e_input2hidden = nn.Linear(in_features=784, out_features=e_hidden)
        
        # mu and log var for q_0 (latent space)
        self.e_hidden2mean = nn.Linear(in_features=e_hidden, out_features=latent_dim)
        self.e_hidden2logvar = nn.Linear(in_features=e_hidden, out_features=latent_dim)
        
        # w, u, b for normalizing flows
        self.e_hidden2w = nn.Linear(in_features=e_hidden, out_features=latent_dim*K) 
        self.e_hidden2u = nn.Linear(in_features=e_hidden, out_features=latent_dim*K) 
        self.e_hidden2b = nn.Linear(in_features=e_hidden, out_features=K)
        
        # Decoding Layers
        self.d_latent2hidden = nn.Linear(in_features=latent_dim, out_features=d_hidden)
        self.d_hidden2image = nn.Linear(in_features=d_hidden, out_features=784)
        
        # Store setting
        self.K = K
        self.latent_dim = latent_dim
        
    def encode(self, x):
        """Maps a data batch (batch_size, 784) to q_0 parameters (mu, logvar) and planar flow 
        parameters (w, u, b)."""
        x = F.relu(self.e_input2hidden(x))
        
        # mu, sigma for latent space
        mu, logvar = self.e_hidden2mean(x), self.e_hidden2logvar(x)
        
        # parameters for normalizing flow
        w, u, b = self.e_hidden2w(x), self.e_hidden2u(x), self.e_hidden2b(x)
        
        # Reshape to facilitate dot products later on
        batch_size = x.size(0)
        w = w.view(batch_size, self.K, 1, self.latent_dim)
        u = u.view(batch_size, self.K, self.latent_dim, 1)
        b = b.view(batch_size, self.K, 1, 1)
        return mu, logvar, w, u, b
    
    def flow(self, z0, w, u, b):
        """Describes how a latent sample z_0 ~ N(mu, logvar) gets transformed by a sequence of K planar
        flows into z_K."""
        # Compute batch_size so that this works also at test time
        bs = z0.size(0)     # batch size: batch_size at training time, number of test images at test time
        
        z_k = z0                                            # (batch_size, latent_dim, 1)
        ladj_sum = torch.zeros((bs, 1), device=z0.device)   # (batch_size, 1)
         
        for k in range(self.K):
            # Grab parameters for this flow
            w_k = w[:, k, :, :]   # (batch_size, 1, 2)
            u_k = u[:, k, :, :]   # (batch_size, 2, 1)
            b_k = b[:, k, :, :]   # (batch_size, 1, 1)
                        
            # Compute uhat to make f() invertible
            uw = torch.bmm(w_k, u_k)   # (batch_size, 1, 1)
            m_uw = -1 + F.softplus(uw) # size (batch_size, 1, 1)
            uhat_k = u_k + ((m_uw - uw)* w_k.transpose(2, 1) / (torch.norm(w_k, dim=2, keepdim=True)**2))

            # Compute z_k = f(z_{k-1}); keep shape (batch_size, latent_dim, 1) for the next bmm
            wz_plus_b = torch.bmm(w_k, z_k) + b_k              # (batch_size, 1, 1)
            z_k_plus_1 = z_k + uhat_k * torch.tanh(wz_plus_b)  # (batch_size, latent_dim, 1)

            # Compute Log-Absolute-Determinant-Jacobian & add it to running sum
            h_prime = 1 - torch.tanh(wz_plus_b)**2             # tanh'(x) = 1 - tanh(x)^2
            ladj = -(1 + torch.bmm(h_prime*w_k, uhat_k)).abs().add(1e-8).log().squeeze(2)
            ladj_sum += ladj

            # Set z_{k-1} <- z_k for the next flow
            z_k = z_k_plus_1
        
        # Drop the trailing singleton dimension: (batch_size, latent_dim, 1) --> (batch_size, latent_dim)
        z0, z_k = z0.squeeze(2), z_k.squeeze(2)
        return z0, z_k, ladj_sum
    
    def decode(self, z0):
        """Decodes a latent sample z0 by first feeding through the Planar Flows to obtain z_K and
        then feeding it through the decoder NN to obtain the mean reconstruction.
        NOTE: This should only be used at TEST TIME not at training time."""
        # We need (w, u, b) for the flow, but when generating images there is no encoder
        # output to provide them, so we sample them randomly
        bs = z0.size(0)   # Batch size for test set
        u = torch.randn((bs, self.K, self.latent_dim, 1), device=z0.device)
        w = torch.randn((bs, self.K, 1, self.latent_dim), device=z0.device)
        b = torch.randn((bs, self.K, 1, 1), device=z0.device)
        
        # Transform via a Normalizing Flow
        z0, z_k, ladj_sum = self.flow(z0.unsqueeze(2), w, u, b)
        
        # Decode z_K to a mean reconstruction
        return torch.sigmoid(self.d_hidden2image(torch.relu(self.d_latent2hidden(z_k))))
        
        
    def forward(self, x):
        """Describes the forward process of VAE+PF."""
        # Shape Flatten image to [batch_size, input_features]
        x = x.view(-1, 784)
        
        mu, logvar, w, u, b = self.encode(x)  # (batch_size, latent_dim)
        
        # Sample z0 from latent space using mu and logvar. Will have dimensions (batch_size, latent_dim, 1)
        if self.training:
            z0 = torch.randn_like(mu).mul(torch.exp(0.5*logvar)).add_(mu).unsqueeze(2) 
        else:
            z0 = mu.unsqueeze(2)
        
        # Feed z_0 through NF to get z_K
        z0, z_k, ladj_sum = self.flow(z0, w, u, b) # (batch_size, 2), (batch_size, 2), (batch_size, 1)
        
        # Feed z_K through Decoder to get mean reconstruction
        recon = torch.sigmoid(self.d_hidden2image(torch.relu(self.d_latent2hidden(z_k))))
        
        
        # KL estimate = log_q0 + LADJ - log_pK  (ladj_sum already carries the minus sign)
        # log q_0(z_0) for the diagonal Gaussian; batch multiplication of (batch_size, 1, latent_dim)
        # and (batch_size, latent_dim, 1) gives (batch_size, 1, 1)
        log_q0 = -0.5*self.latent_dim*log(2*pi) - 0.5*logvar.sum(dim=1, keepdim=True) \
                 - 0.5*torch.bmm(((z0 - mu)/torch.exp(logvar)).unsqueeze(1), (z0 - mu).unsqueeze(2)).squeeze(2)
        # log p(z_K) under the standard normal prior
        log_pK = -0.5*self.latent_dim*log(2*pi) - 0.5*torch.bmm(z_k.unsqueeze(1), z_k.unsqueeze(2)).squeeze(2)  # (batch_size, 1)
        # return reconstruction and terms for kl divergence
        return recon, log_q0, ladj_sum, log_pK
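
The forward pass returns the four terms of the bound rather than a scalar loss. The following is a minimal training-loop sketch (my own, not from the original code) showing how one might combine them, using binary cross-entropy for the Bernoulli reconstruction term:

model = VAE_PF_amortized(e_hidden, d_hidden, latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

model.train()
for epoch in range(epochs):
    for images, _ in trainloader:
        x = images.to(device).view(-1, 784)
        optimizer.zero_grad()
        recon, log_q0, ladj_sum, log_pK = model(x)
        # -E[log p(x | z_K)]: Bernoulli reconstruction term, summed over pixels
        recon_loss = F.binary_cross_entropy(recon, x, reduction='none').sum(dim=1, keepdim=True)
        # Single-sample KL estimate: log q0(z0) + LADJ - log p(z_K)
        kl = log_q0 + ladj_sum - log_pK
        loss = (recon_loss + kl).mean()
        loss.backward()
        optimizer.step()

At test time, model.eval() followed by model.decode(torch.randn(64, latent_dim, device=device)) samples 64 images from the prior, with the flow parameters drawn randomly as described in decode.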