Introduction

In the previous post we motivated the idea of a Variational Auto Encoder. Here we will have a go at implementing a very simple model to get a sense of all the components and steps involved. This is quite a long post, but much of it is just code. All the code in this post can be found in this Colab notebook.

import torch
import math
from torch import nn
import tqdm.notebook as tqdm
from easydict import EasyDict

Data

Here is a simple 2D data distribution:

\[\mathbf{x} = \left[x_1, x_2\right]^T\] \[\mathbf{x} \sim \mathcal{N}(\boldsymbol{\mu}, \Sigma)\]

where $\Sigma$ is a full covariance matrix, meaning that it can have non-zero off-diagonal entries.

import numpy as np
import matplotlib.pyplot as plt

mu = np.array([1., 0.5])
cov = np.array([[1, 0.5],
                [0.5, 1]])

def is_pos_def(x):
    # A symmetric matrix is positive definite iff all its eigenvalues are > 0
    return np.all(np.linalg.eigvals(x) > 0)

assert is_pos_def(cov)

# Draw samples from the true distribution and plot a 2D histogram
x = np.random.multivariate_normal(
    mu, cov,
    size=10000
)
plt.figure(figsize=(8, 8))
plt.hist2d(x[:, 0], x[:, 1], bins=100, cmap='Blues');
# Raw string so that \m and \s are not treated as escape sequences
plt.title(r'samples $\mathbf{x}^{(i)} \sim N(\mu, \Sigma)$');
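As a quick sanity check that is not part of the original notebook, the empirical mean and covariance of a large sample should be close to mu and cov. This sketch draws its own samples with a seeded NumPy Generator (an assumption made so that it runs on its own and is reproducible):

```python
import numpy as np

mu = np.array([1.0, 0.5])
cov = np.array([[1.0, 0.5],
                [0.5, 1.0]])

rng = np.random.default_rng(0)  # seeded so the check is reproducible
x = rng.multivariate_normal(mu, cov, size=100_000)

emp_mu = x.mean(axis=0)
emp_cov = np.cov(x, rowvar=False)  # rows are samples, columns are x_1, x_2

print(emp_mu)   # should be close to mu
print(emp_cov)  # should be close to cov, including the 0.5 off-diagonal entries
```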


However, we assume that, conditioned on the latent variable $\mathbf{z}$, $\mathbf{x}$ has a simpler diagonal covariance; in other words, $x_1$ and $x_2$ are independent conditioned on $\mathbf{z}$.

\[\mathbf{z} \sim \mathcal{N}(0, I)\] \[\mathbf{x} \vert \mathbf{z} \sim \mathcal{N}(\boldsymbol{\mu}(\mathbf{z}), \text{diag}(\boldsymbol{\sigma}^2(\mathbf{z})))\]

plt.figure(figsize=(8, 6))
plt.hist(np.random.normal(size=[10000]), bins=100)
plt.title(r'samples $z^{(i)} \sim \mathcal{N}(0, I)$');
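To see why a diagonal conditional can still produce correlated data, consider a hypothetical linear decoder $\boldsymbol{\mu}(\mathbf{z}) = W\mathbf{z} + \mathbf{b}$ with a fixed diagonal variance. This is a simplified stand-in for the MLP decoder used below, not the model we train. Integrating out $\mathbf{z}$ gives $\text{Cov}[\mathbf{x}] = WW^T + \text{diag}(\boldsymbol{\sigma}^2)$, and the values of W, b and s2 below are hand-picked so that this equals the covariance above:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Hypothetical linear decoder: mean(z) = W z + b, fixed diagonal variance s2
W = np.array([[np.sqrt(0.5), 0.0],
              [np.sqrt(0.5), 0.0]])
b = np.array([1.0, 0.5])
s2 = np.array([0.5, 0.5])

z = rng.standard_normal((n, 2))                # z ~ N(0, I)
eps = rng.standard_normal((n, 2))
x = z @ W.T + b + np.sqrt(s2) * eps            # x | z ~ N(W z + b, diag(s2))

# x_1 and x_2 are independent given z, yet marginally correlated:
analytic_cov = W @ W.T + np.diag(s2)           # [[1, 0.5], [0.5, 1]] up to rounding
empirical_cov = np.cov(x, rowvar=False)
```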


For the approximate posterior $q_\phi(\mathbf{z} \vert \mathbf{x})$ we choose a normal distribution with a diagonal covariance; together with the standard normal prior on $\mathbf{z}$, this means the KL-divergence term in the loss can be computed exactly.

Neural network representation

We will use a simple feed-forward network to represent each of $p_\theta(\mathbf{x} \vert \mathbf{z})$ and $q_\phi(\mathbf{z} \vert \mathbf{x})$. This is the formulation from the VAE paper, where the parameters of each Gaussian are obtained via a 2-layer MLP. Note that the network predicts the log of the variance rather than the variance itself: any real-valued output maps to a positive variance, and the output stays within a smaller range of values than the raw variance would.

\[\log p_{\theta}\left(\mathbf{x} \vert \mathbf{z} \right) = \log \mathcal{N}(\mathbf{x}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2))\] \[\mathbf{h} = \text{Dense}(\mathbf{z})\] \[\left[\boldsymbol{\mu}; \log \boldsymbol{\sigma}^2\right] = \text{Dense}(\tanh(\mathbf{h}))\]

Similarly, for the approximate posterior over $\mathbf{z}$:

\[\log q_{\phi}\left(\mathbf{z} \vert \mathbf{x} \right) = \log \mathcal{N}(\mathbf{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2))\] \[\mathbf{h} = \text{Dense}(\mathbf{x})\] \[\left[\boldsymbol{\mu}; \log \boldsymbol{\sigma}^2\right] = \text{Dense}(\tanh(\mathbf{h}))\]

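class GaussianMLP(nn.Module):
    """One tanh hidden layer; outputs [mu; log sigma^2] concatenated on the last dim."""
    def __init__(self, input_dim, hidden_dim, latent_dim):
        # latent_dim is the dimensionality of the Gaussian being parameterised
        # (the latent z for the encoder, the data x for the decoder)
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(in_features=input_dim,
                      out_features=hidden_dim),
            nn.Tanh()
        )
        # 2 * latent_dim outputs: the mean and the log-variance
        self.param_layer = nn.Linear(in_features=hidden_dim,
                                     out_features=latent_dim * 2)

    def forward(self, inputs):
        shared = self.shared(inputs)
        params = self.param_layer(shared)
        return params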
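With the settings used later (input_dim=2, hidden_dim=512, latent_dim=2) each network is small. As a worked example, this sketch rebuilds the same layer stack with nn.Sequential (gaussian_mlp_stack is a hypothetical helper, not part of the post) and counts its parameters; the decoder has the same count because its input and output sizes are also 2:

```python
from torch import nn

def gaussian_mlp_stack(input_dim, hidden_dim, out_dim):
    # Hypothetical helper mirroring GaussianMLP's layers: Linear -> Tanh -> Linear
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.Tanh(),
        nn.Linear(hidden_dim, out_dim * 2),  # mean and log-variance
    )

encoder = gaussian_mlp_stack(2, 512, 2)
n_params = sum(p.numel() for p in encoder.parameters())
# (2*512 weights + 512 biases) + (512*4 weights + 4 biases) = 1536 + 2052 = 3588
print(n_params)  # 3588
```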

First let us create a model comprising these two distributions, where $q_{\phi}\left(\mathbf{z} \vert \mathbf{x} \right)$ is called the encoder and $p_{\theta}\left (\mathbf{x} \vert \mathbf{z} \right)$ is called the decoder.

class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder  # network for q_phi(z | x)
        self.decoder = decoder  # network for p_theta(x | z)

Now let us add a method that runs the encoder and splits its output into the mean and log-variance of $q_\phi(\mathbf{z} \vert \mathbf{x})$:

def encode(self, inputs):
    params_z = self.encoder(inputs)
    mu_z = params_z[:, :(params_z.size(1)//2)]
    log_sig2_z = params_z[:, (params_z.size(1)//2):]
    return mu_z, log_sig2_z

VAE.encode = encode

And one for decoding, which does the same for $p_\theta(\mathbf{x} \vert \mathbf{z})$:

def decode(self, latents):
    params_x = self.decoder(latents)
    mu_x = params_x[:, :(params_x.size(1)//2)]
    log_sig2_x = params_x[:, (params_x.size(1)//2):]
    return mu_x, log_sig2_x

VAE.decode = decode
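Both methods split the network output in half along the feature dimension. This sketch, using a random tensor as a stand-in for the network output, checks that the slicing matches torch.chunk and that exponentiating a log-variance always yields a positive variance and standard deviation, whatever the sign or size of the raw output:

```python
import torch

torch.manual_seed(0)
B, Z = 4, 2
# Stand-in for an encoder/decoder output of shape [B, 2 * Z] laid out as [mu; log sigma^2]
params = torch.randn(B, 2 * Z) * 10  # deliberately large, including negative values

# Slicing as in encode/decode ...
mu = params[:, :params.size(1) // 2]
log_sig2 = params[:, params.size(1) // 2:]
# ... is equivalent to splitting into two chunks along dim 1
mu_c, log_sig2_c = torch.chunk(params, 2, dim=1)

# Any real-valued log-variance maps to a strictly positive variance and std
var = torch.exp(log_sig2)
std = torch.exp(log_sig2 / 2)
```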

Training

  • Sample training examples from the ground truth distribution $\mathbf{x}^{(i)} \sim p(\mathbf{x})$
  • Predict parameters for $q_\phi(\mathbf{z} \vert \mathbf{x})$
  • Sample latents from the approximate distribution $\mathbf{z}^{(i)} \sim q_\phi(\mathbf{z} \vert \mathbf{x}^{(i)})$
  • Predict parameters for $p_\theta(\mathbf{x} \vert \mathbf{z})$
  • Calculate the loss and run the backward pass

Since we are predicting the parameters of a distribution, we don’t need to sample from $p_\theta(\mathbf{x} \vert \mathbf{z})$ to calculate the loss. We only need to evaluate its log density at the training example, which can be done e.g. via the log_prob method of torch.distributions.Normal.
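Because the covariance is diagonal, the log density of $p_\theta(\mathbf{x} \vert \mathbf{z})$ is the sum of independent per-dimension Normal log densities. This sketch, using random stand-in tensors in double precision, checks that summing Normal.log_prob over the last dimension matches both the formula written out by hand and a MultivariateNormal with diagonal covariance:

```python
import math
import torch
from torch.distributions import MultivariateNormal, Normal

torch.manual_seed(0)
dt = torch.float64
x = torch.randn(5, 2, dtype=dt)            # [B, D] data points
mu_x = torch.randn(5, 2, dtype=dt)         # stand-in predicted means
log_sig2_x = torch.randn(5, 2, dtype=dt)   # stand-in predicted log-variances
std = torch.exp(log_sig2_x / 2)

# Factorised Gaussian: sum the per-dimension log densities
log_p = Normal(mu_x, std).log_prob(x).sum(dim=-1)

# Written out: -0.5 * [log(2*pi) + log sigma^2 + (x - mu)^2 / sigma^2], summed over D
manual = (-0.5 * (math.log(2 * math.pi) + log_sig2_x
                  + (x - mu_x) ** 2 / torch.exp(log_sig2_x))).sum(dim=-1)

# Same density as a multivariate normal with diagonal covariance
mvn = MultivariateNormal(mu_x, covariance_matrix=torch.diag_embed(torch.exp(log_sig2_x)))
log_p_mvn = mvn.log_prob(x)
```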

def forward(self, inputs, noise):
    # inputs: [B, D]
    # noise: [N, B, Z], N = number of Monte Carlo samples per input
    mu_z, log_sig2_z = self.encode(inputs)
    # [N, B, Z]
    latents = sample_latents(mu_z, log_sig2_z, noise)
    # Flatten to [N * B, Z] so the MLP decoder sees a 2D batch
    mu_x, log_sig2_x = self.decode(latents.reshape(-1, latents.size(-1)))
    # [N, B, Z], [B, Z], [B, Z], [N, B, D], [N, B, D]
    return (latents, mu_z, log_sig2_z,
            mu_x.reshape(-1, *inputs.size()),
            log_sig2_x.reshape(-1, *inputs.size()))

VAE.forward = forward
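The reshapes in forward are easiest to follow with concrete shapes. This sketch uses random tensors, with decoded as a hypothetical stand-in for the decoder output, to trace [B, Z] means broadcasting against [N, B, Z] noise, the flatten to [N * B, Z] and the reshape back to [N, B, D]:

```python
import torch

N, B, D, Z = 3, 8, 2, 2   # Monte Carlo samples, batch size, data dim, latent dim
mu_z = torch.zeros(B, Z)
log_sig2_z = torch.zeros(B, Z)
noise = torch.randn(N, B, Z)

# Broadcasting: [B, Z] with [N, B, Z] -> [N, B, Z]
latents = mu_z + torch.exp(log_sig2_z / 2) * noise

# Flatten the sample and batch dims so a 2D MLP sees [N * B, Z] ...
flat = latents.reshape(-1, latents.size(-1))

# ... and reshape a [N * B, D] decoder output back to [N, B, D],
# mirroring mu_x.reshape(-1, *inputs.size()) with inputs of shape [B, D]
decoded = torch.randn(N * B, D)   # hypothetical stand-in for the decoder mean
inputs = torch.randn(B, D)
mu_x = decoded.reshape(-1, *inputs.size())
```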

Reparameterization Trick

At this point you might have some questions:

  1. Where did this noise come from?
  2. What is the sample_latents function doing?

We will go into the mathematics behind all this in more detail in a subsequent post (coming soon), but in simple terms, consider this diagram of the steps involved. Remember that for the loss we don’t need to sample from $p_\theta(\mathbf{x} \vert \mathbf{z})$, so the graph stops after predicting the parameters of this distribution.

Figure showing VAE steps WITHOUT reparameterization trick

We can’t differentiate through the sampling step shown in blue. To get around that, the authors of the paper use a reparameterization trick: generate noise samples $\boldsymbol{\epsilon}^{(i)}$, e.g. from $\mathcal{N}(0, I)$, transform them into $\mathbf{z}^{(i)}$ and input these to the decoder. The transformation is differentiable with respect to the parameters predicted by the encoder. For a normal distribution you can convert samples $\boldsymbol{\epsilon} \sim \mathcal{N}(0, I)$ to samples $\mathbf{z}^{(i)} \sim \mathcal{N}(\boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2))$ by transforming each element as follows

\[z_j = \sigma_j \cdot \epsilon_j + \mu_j\]

Now the randomness enters only through $\boldsymbol{\epsilon}$, so we can differentiate through the blue part in the figure below.

Figure showing VAE steps WITH reparameterization trick

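The sample_latents function simply transforms the noise into latents. Since the encoder predicts $\log \sigma^2$, the standard deviation is recovered as $\sigma = \exp(\tfrac{1}{2}\log \sigma^2)$.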

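def sample_latents(mu, log_var, noise):
    # z = mu + sigma * eps, with sigma = exp(log_var / 2);
    # mu and log_var: [B, Z] broadcast against noise: [N, B, Z]
    return mu + torch.exp(log_var / 2) * noise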
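To see that gradients flow through the transformed samples, this sketch redefines sample_latents with hand-picked hypothetical values of mu and log_var ($\boldsymbol{\sigma} = [1, 2]$), backpropagates through the per-dimension sample mean and compares against the analytic gradients. For reference, torch.distributions.Normal also provides rsample, which draws reparameterized samples in the same spirit:

```python
import math
import torch

torch.manual_seed(0)

def sample_latents(mu, log_var, noise):
    # z = mu + sigma * eps, with sigma = exp(log_var / 2)
    return mu + torch.exp(log_var / 2) * noise

mu = torch.tensor([1.0, -2.0], requires_grad=True)
log_var = torch.tensor([0.0, math.log(4.0)], requires_grad=True)  # sigma = [1, 2]
eps = torch.randn(100_000, 2)

z = sample_latents(mu, log_var, eps)   # [100000, 2], differentiable in mu and log_var
loss = z.mean(dim=0).sum()
loss.backward()

# Analytically: d loss / d mu_j = 1 and d loss / d log_var_j = 0.5 * sigma_j * mean_i(eps_ij)
expected_log_var_grad = 0.5 * torch.exp(log_var.detach() / 2) * eps.mean(dim=0)

# Built-in reparameterized sampling; the result is attached to the graph
z_rs = torch.distributions.Normal(mu, torch.exp(log_var / 2)).rsample((5,))
```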

Loss

Finally, let us add a loss method that calls forward and then calculates the loss. Since we minimise the loss, it is the negative of this objective, which we saw earlier:

\[L(\theta, \phi) = -D_{KL}\left( q_\phi(\mathbf{z} \vert \mathbf{x})\Vert p_\theta(\mathbf{z})\right) + \mathbb{E}_{q_\phi(\mathbf{z} \vert \mathbf{x})} \left[\log p_\theta(\mathbf{x}\vert \mathbf{z}) \right]\]

In general we can’t calculate the expectation exactly, so we approximate it with a Monte Carlo estimate. For a single $\mathbf{x}^{(i)}$, using $T$ samples $\mathbf{z}^{(i, t)} \sim q_\phi(\mathbf{z} \vert \mathbf{x}^{(i)})$, the log likelihood term becomes:

\[\mathbb{E}_{q_\phi(\mathbf{z} \vert \mathbf{x}^{(i)})} \left[\log p_\theta(\mathbf{x}^{(i)}\vert \mathbf{z}) \right] \approx \frac{1}{T}\sum_{t=1}^T\log p_\theta(\mathbf{x}^{(i)} \vert \mathbf{z}^{(i, t)})\]
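To see the Monte Carlo approximation at work, consider a hypothetical one-dimensional case where the expectation has a closed form: $p(x \vert z) = \mathcal{N}(x; z, 1)$ and $q(z \vert x) = \mathcal{N}(m, s^2)$, for which $\mathbb{E}_q[\log p(x \vert z)] = -\tfrac{1}{2}\log 2\pi - \tfrac{1}{2}\left((x - m)^2 + s^2\right)$. Each single-sample estimate is unbiased, which is why a single sample per input (num_mc=1 in the config later on) can still work with stochastic gradient descent, only with more variance. The estimate tightens as $T$ grows:

```python
import math
import torch
from torch.distributions import Normal

torch.manual_seed(0)
x, m, s = 1.5, 0.5, 0.8   # hypothetical data point and q(z | x) = N(m, s^2)

# Exact value of E_q[log N(x; z, 1)]
exact = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s ** 2)

def mc_estimate(T):
    eps = torch.randn(T)
    z = m + s * eps                                # z^(t) ~ q(z | x), reparameterized
    log_p = Normal(z, 1.0).log_prob(torch.tensor(x))
    return log_p.mean().item()                     # (1/T) * sum_t log p(x | z^(t))

for T in (1, 10, 1_000, 100_000):
    print(T, mc_estimate(T), exact)
```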

You might have noticed that the noise input has shape [N, B, Z], where N plays the role of $T$ (cfg.num_mc in the code). That is because we sample an $\boldsymbol{\epsilon}^{(i, t)}$ corresponding to each $\mathbf{z}^{(i, t)}$.

It turns out that because both the prior $p_\theta(\mathbf{z}) = \mathcal{N}(0, I)$ and $q_\phi(\mathbf{z} \vert \mathbf{x})$ are normal distributions, the KL divergence term becomes a simple expression that can be calculated exactly. Here is the expression for a $J$-dimensional latent:

\[-D_{KL}\left( q_\phi(\mathbf{z} \vert \mathbf{x}^{(i)})\Vert p_\theta(\mathbf{z})\right) = \frac{1}{2}\sum_{j=1}^J\left(1 + \log((\sigma_j^{(i)})^2) - (\mu_j^{(i)})^2 - (\sigma_j^{(i)})^2\right)\]

We will derive all of this in a future post.
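The closed-form KL can be cross-checked numerically. This sketch compares the per-example sum of the expression above (negated, so that it is the KL itself, as in the loss code) with torch.distributions.kl_divergence between each Normal $q_\phi$ and the standard normal prior, using random stand-in means and log-variances:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu_z = torch.randn(6, 2)          # [B, J] stand-in encoder means
log_sig2_z = torch.randn(6, 2)    # [B, J] stand-in encoder log-variances

# Closed form: KL = -1/2 * sum_j (1 + log sigma_j^2 - mu_j^2 - sigma_j^2)
kl_closed = (-(1 + log_sig2_z - mu_z ** 2 - torch.exp(log_sig2_z)) / 2).sum(dim=-1)

# PyTorch's analytic KL between Normals, against the N(0, I) prior
q = Normal(mu_z, torch.exp(log_sig2_z / 2))
prior = Normal(torch.zeros_like(mu_z), torch.ones_like(mu_z))
kl_torch = kl_divergence(q, prior).sum(dim=-1)
```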

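def loss(self, x, noise):
    z, mu_z, log_sig2_z, mu_x, log_sig2_x = self(x, noise)
    # Closed-form KL(q_phi(z | x) || N(0, I)) per latent dimension
    KL = -(1 + log_sig2_z - mu_z**2 - torch.exp(log_sig2_z)) / 2
    # Sum over latent dims, average over the batch
    KL_term = KL.sum(dim=-1).mean()
    # Negative log likelihood; mu_x: [N, B, D] broadcasts against x: [B, D]
    recon = -torch.distributions.Normal(mu_x, torch.exp(log_sig2_x/2)).log_prob(x)
    # Sum over data dims, average over Monte Carlo samples and the batch
    recon = recon.sum(dim=-1).mean()
    return KL_term + recon, recon, KL_term

VAE.loss = loss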

Now we can combine all these parts into a training step

def train_step(cfg, vae, optim, batch):
    optim.zero_grad()
    # One noise sample per Monte Carlo draw, batch element and latent dimension
    noise = torch.randn([cfg.num_mc, batch.size(0), cfg.latent_dim])
    loss, recon_loss, KL_term = vae.loss(batch, noise)
    loss.backward()
    optim.step()
    # Same order as VAE.loss: (total, reconstruction, KL)
    return loss.detach().numpy(), recon_loss.detach().numpy(), KL_term.detach().numpy()

We can write a similar function to calculate losses during validation. Here we reuse the same noise values in every evaluation, so that test losses are comparable across epochs.

def run_test(data, model, batch_size, latent_size, num_mc, noise):
    # noise: [num_batches, N, B, Z], fixed across evaluations
    # (latent_size and num_mc are implied by the shape of noise and unused here)
    model.eval()
    num_iters = math.ceil(len(data) / batch_size)
    losses = []
    counts = []
    with torch.no_grad():  # no gradients needed for evaluation
        for idx in tqdm.trange(num_iters):
            slc = slice(idx * batch_size, (idx + 1) * batch_size)
            batch = torch.from_numpy(data[slc])
            loss = model.loss(batch, noise[idx, :, :batch.size(0)])
            losses.append([l.numpy() for l in loss])
            counts.append(len(batch))
    # Average (total, recon, KL) over batches, weighted by batch size
    return np.sum(np.array(losses) * np.array(counts)[:, None], axis=0) / np.sum(counts)
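The weighting by counts matters because the last batch is usually smaller. With 2,500 test points and batch size 128 there are 19 full batches and one of 68. This sketch uses hypothetical per-example losses to show that the count-weighted average of batch means recovers the mean over all examples:

```python
import numpy as np

rng = np.random.default_rng(0)
per_example = rng.exponential(size=2500)   # hypothetical per-example losses
batch_size = 128

batches = [per_example[i:i + batch_size] for i in range(0, len(per_example), batch_size)]
batch_means = np.array([b.mean() for b in batches])
counts = np.array([len(b) for b in batches])   # 19 batches of 128, one of 68

weighted = np.sum(batch_means * counts) / np.sum(counts)   # what run_test computes
naive = batch_means.mean()                                  # over-weights the short last batch
```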

Inference

  • Sample latents from the prior $\mathbf{z}^{(i)} \sim \mathcal{N}(0, I)$
  • Predict parameters for $p_\theta(\mathbf{x} \vert \mathbf{z})$
  • Sample from $p_\theta(\mathbf{x} \vert \mathbf{z})$

Note that we might also choose to return $\mu_{\mathbf{x}\vert\mathbf{z}}$ instead of sampling from $p_\theta(\mathbf{x} \vert \mathbf{z})$. This is common when dealing with high dimensional data like images, where the samples tend to be too noisy.

def generate_samples(cfg, vae, n_samples, noise=False):
    vae.eval()
    with torch.no_grad():
        # Sample latents from the prior N(0, I)
        z = torch.randn([n_samples, cfg.latent_dim])
        mu_x, log_sig2_x = vae.decode(z)
        if noise:
            # 1 sample of x per latent from p_theta(x | z)
            std_x = torch.exp(log_sig2_x / 2)
            return torch.distributions.Normal(mu_x, std_x).sample().cpu().numpy()
    # Otherwise return the predicted means
    return mu_x.cpu().numpy()
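Returning the mean also removes the decoder's own variance from the samples. With a hypothetical linear decoder (hand-picked W, b and s2, a simplified stand-in rather than the trained model), mean-only samples have covariance $WW^T$, while sampled outputs have $WW^T + \text{diag}(\boldsymbol{\sigma}^2)$:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
W = np.array([[np.sqrt(0.5), 0.0],
              [np.sqrt(0.5), 0.0]])   # hypothetical decoder mean: W z + b
b = np.array([1.0, 0.5])
s2 = np.array([0.5, 0.5])             # hypothetical fixed decoder variance

z = rng.standard_normal((n, 2))       # z ~ N(0, I) prior
mean_only = z @ W.T + b               # analogous to returning mu_x (noise=False)
with_noise = mean_only + np.sqrt(s2) * rng.standard_normal((n, 2))  # noise=True

cov_mean_only = np.cov(mean_only, rowvar=False)    # about W W^T
cov_with_noise = np.cov(with_noise, rowvar=False)  # about W W^T + diag(s2)
```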

Putting it together

Let us write a run_model function that sets up the models and goes through all the steps described above.

def run_model(cfg, train_data, test_data):
    
    # Initialise encoder and decoder
    encoder = GaussianMLP(input_dim=cfg.input_dim,
                               hidden_dim=cfg.hidden_dim,
                               latent_dim=cfg.latent_dim,
                               )
    decoder = GaussianMLP(input_dim=cfg.latent_dim,
                               hidden_dim=cfg.hidden_dim,
                               latent_dim=cfg.input_dim,
                               )
    
    # Initialise VAE 
    vae = VAE(encoder, decoder)
    
    # Initialise optimizer
    optim = torch.optim.Adam(params=vae.parameters(), lr=1e-3)
    
    # Initialise other values and data structures
    batch_size = cfg.batch_size
    
    num_train_iters = math.ceil(len(train_data) / batch_size)
    
    
    train_losses, test_losses, num = [], [], []
    
    
    # Noise to be reused for testing
    noise_val = torch.randn(
        [math.ceil(len(test_data) / batch_size), 
         cfg.num_mc, batch_size, cfg.latent_dim]
    )
    
    # Evaluate the untrained model to get a baseline test loss
    test_losses.append(run_test(test_data, vae, batch_size, cfg.latent_dim, cfg.num_mc, noise_val))
    
    print(f"Initial Test loss: {test_losses[-1][0]:.4f}")
    
    # For visualisation we might want to save samples after each step
    if cfg.sample_each_step:
        samples_per_step = []
    
    
    for epoch in range(cfg.num_epochs):
        vae.train()
        
        # Shuffle data each epoch
        train_data = train_data[np.random.permutation(len(train_data))]
        
        
        
        # Train the model for each batch, saving losses and if sample_each_step is True, samples
        for i_trn in tqdm.trange(num_train_iters):
            slc = slice(i_trn * batch_size, (i_trn + 1) * batch_size)
            batch = torch.from_numpy(train_data[slc])
            losses = train_step(cfg, vae, optim, batch)
            train_losses.append(losses)
            num.append(batch.size(0))
            if cfg.sample_each_step:
                samples_per_step.append(generate_samples(cfg, vae, n_samples=cfg.n_samples, noise=True))
            
        # Run validation
        test_losses.append(run_test(test_data, vae, batch_size, cfg.latent_dim, cfg.num_mc, noise_val))
        
        print(
            f"Epoch {epoch}",
            f"Training loss: {np.average([l[0] for l in train_losses[-num_train_iters:]], weights=num[-num_train_iters:]):.4f},",
            f"Test loss: {test_losses[-1][0]:.4f}"
        )
    
    # Generate some samples with and without noise
    samples_no_noise = generate_samples(cfg, vae, n_samples=cfg.n_samples, noise=False)
    samples_with_noise = generate_samples(cfg, vae, n_samples=cfg.n_samples, noise=True)
    result = (np.stack(train_losses), np.stack(test_losses),  samples_with_noise, samples_no_noise)
    return result + (samples_per_step,) if cfg.sample_each_step else result

Now let us draw a dataset from the true distribution and split it into 10,000 training and 2,500 test examples.

dataset = np.random.multivariate_normal(
    np.array(mu),
    np.array(cov),
    size=12500,
).astype('float32')

trn_data, tst_data = np.split(dataset, [10000])

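Define a few settings:

config = EasyDict(
    input_dim=2,            # dimensionality of x
    num_mc=1,               # Monte Carlo samples of z per input (T in the text)
    print_interval=100,     # not used by run_model
    batch_size=128,
    hidden_dim=512,
    latent_dim=2,           # dimensionality of z
    n_samples=1000,         # number of samples to generate for visualisation
    num_epochs=1,
    sample_each_step=True   # save generated samples after every training step
)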
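With these settings, a quick calculation gives the number of training steps, and hence the number of entries saved when sample_each_step is True. The step grid plotted further down uses N = 15 panels spaced stride = size // N apart, so this sketch also works out which steps are shown:

```python
import math

n_train, batch_size, num_epochs = 10_000, 128, 1    # values from the split and config above
steps_per_epoch = math.ceil(n_train / batch_size)   # 79
last_batch = n_train - (steps_per_epoch - 1) * batch_size   # 10000 - 78 * 128 = 16
total_steps = steps_per_epoch * num_epochs          # length of the per-step sample list: 79

N = 15                                              # generated panels in the 4x4 grid
stride = total_steps // N                           # 79 // 15 = 5
shown_steps = [i + 1 for i in range(total_steps)[::stride][:N]]   # 1, 6, ..., 71
```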

Finally we are ready to run

result = run_model(config, trn_data, tst_data)

Visualising the results

from scipy.stats import multivariate_normal

Let us generate some samples from the true distribution and also plot the contours of the probability density function.

mvn = multivariate_normal(mu, cov)
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
xx, yy = np.meshgrid(np.linspace(-3, 5, 100), np.linspace(-3, 7, 100))
p = mvn.pdf(np.stack([xx, yy], axis=-1))

pred_data = result[2]  # samples_with_noise from the trained model
true_data = tst_data[:len(pred_data)]
for vals, ax, dist in zip([true_data, pred_data], axes, ['true', 'learned']):
    ax.contour(xx, yy, p, levels=20, alpha=0.7, cmap='cool')
    ax.hist2d(*vals.T, bins=100, cmap='hot')
    ax.set_title(f'Histogram of samples from {dist} distribution', fontsize=16)
    # Restrict both panels to the range covered by both sample sets
    ax.set_xlim(max(true_data[:, 0].min(), pred_data[:, 0].min()),
                min(true_data[:, 0].max(), pred_data[:, 0].max()))
    ax.set_ylim(max(true_data[:, 1].min(), pred_data[:, 1].min()),
                min(true_data[:, 1].max(), pred_data[:, 1].max()))


Using the saved samples from each step we can see how the model improves over time.

fig, axes = plt.subplots(4, 4, figsize=(16, 16))
samples_per_step = result[-1]
N = 15  # number of generated panels (the first panel shows the test data)
size = len(samples_per_step)
stride = max(1, size // N)  # guard against fewer than N saved steps

ax = axes[0, 0]
ax.scatter(*tst_data.T, marker='.', alpha=.8)
ax.set_title('True', fontsize=16)
ax.set_xlim(-4, 6)
ax.set_ylim(-4, 5)

for ax, idx, vals in zip(axes.ravel()[1:], np.arange(size)[::stride][:N],
                         samples_per_step[::stride][:N]):
    ax.scatter(*vals.T, c='green', marker='.', alpha=.8)
    ax.set_title(f'Generated / step {idx + 1}', fontsize=16)
    ax.set_xlim(-4, 6)
    ax.set_ylim(-4, 5)


This is a simple task, and the model quickly learns the distribution.
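To put a number on how well the model has learned the distribution, one option is to compare the first two moments of true and generated samples, which fully characterise a Gaussian (so this is a reasonable check here, but not for non-Gaussian data). moment_gap is a hypothetical helper; in the notebook, true_data and pred_data from the histogram plot could be passed in. Here it is only demonstrated on synthetic samples:

```python
import numpy as np

def moment_gap(true_samples, generated_samples):
    """Hypothetical helper: max absolute gap in mean and in covariance."""
    mean_gap = np.abs(true_samples.mean(axis=0) - generated_samples.mean(axis=0)).max()
    cov_gap = np.abs(np.cov(true_samples, rowvar=False)
                     - np.cov(generated_samples, rowvar=False)).max()
    return mean_gap, cov_gap

rng = np.random.default_rng(0)
mu = np.array([1.0, 0.5])
cov = np.array([[1.0, 0.5], [0.5, 1.0]])
true_samples = rng.multivariate_normal(mu, cov, size=20_000)
good = rng.multivariate_normal(mu, cov, size=20_000)        # same distribution
bad = rng.multivariate_normal(mu, np.eye(2), size=20_000)   # ignores the correlation

good_gap = moment_gap(true_samples, good)
bad_gap = moment_gap(true_samples, bad)
```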

What’s next

In subsequent posts we will delve into the mathematical details that we only briefly mentioned here. We will also learn how to implement more sophisticated VAEs.