What Is a Variational Autoencoder?


Variational Autoencoder: From Theory to Implementation

Autoencoders have been around for decades and are widely used for data compression, denoising, anomaly detection, clustering, image colorization, and many other applications across domains. The Variational Autoencoder (VAE) is a type of autoencoder that has achieved strong generative performance on many datasets, including MNIST, FashionMNIST, CIFAR-10, and CelebA. In this article, we will dive into VAE theory and implementation using PyTorch, one of the most popular deep learning libraries for building neural networks.

Autoencoder: The Basics

An autoencoder is a neural network that learns to encode an input signal into a lower-dimensional latent space and then decode it back into the original signal. It consists of two parts: the encoder, which maps the input to the latent space, and the decoder, which reconstructs the input from the latent representation. The encoder and decoder can be implemented as convolutional neural networks (CNNs), recurrent neural networks (RNNs), or any other type of neural network. The goal of the autoencoder is to minimize the reconstruction error between the input signal and the reconstructed signal. A typical convolutional design is outlined below; a minimal PyTorch sketch follows the two lists.

Encoder:
  • Convolutional neural network with multiple convolutional and pooling layers.
  • ReLU activation function in each layer.
  • Flatten the output of the last convolutional layer and connect it to the first dense layer.
  • Two dense layers with batch normalization and a ReLU activation function in each layer.
  • The last dense layer outputs the latent vector (in the VAE described later, it outputs a mean and a log variance instead).
Decoder:
  • Dense layer with batch normalization and ReLU activation function.
  • Reshape the output of the dense layer into a 3D tensor.
  • Transposed convolutional layers with multiple filters and strided upsampling.
  • ReLU activation function in each layer.
  • Sigmoid activation function in the last layer.
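
Below is a minimal sketch of what such a convolutional autoencoder could look like in PyTorch for 28x28 grayscale images. It is only meant to illustrate the structure listed above, not the model we train later: strided convolutions stand in for pooling, batch normalization is omitted for brevity, and the latent_dim value is an arbitrary choice.

```
import torch
import torch.nn as nn

class ConvAutoencoder(nn.Module):
    """Minimal convolutional autoencoder sketch for 28x28 grayscale images."""

    def __init__(self, latent_dim=32):
        super().__init__()
        # Encoder: strided convolutions, then a dense layer down to the latent vector
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),   # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 14x14 -> 7x7
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, latent_dim),
        )
        # Decoder: dense layer back up, then transposed convolutions to 28x28
        self.decoder_fc = nn.Linear(latent_dim, 32 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # 7x7 -> 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),   # 14x14 -> 28x28
            nn.Sigmoid(),  # pixel values in [0, 1]
        )

    def forward(self, x):
        z = self.encoder(x)                        # encode to the latent vector
        h = self.decoder_fc(z).view(-1, 32, 7, 7)  # reshape back into a 3D feature map
        return self.decoder(h)
```

Training such a model amounts to minimizing a reconstruction loss (for example mean squared error or binary cross-entropy) between the input image and the network's output.
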
Limitations of the Autoencoder

A standard autoencoder is limited in its ability to generate new data samples from the learned data distribution. It can only interpolate between existing data samples, meaning it produces variations of what it has already seen rather than genuinely novel samples. The latent space is also limited in flexibility: the autoencoder learns an implicit representation of the input data that may not be interpretable or controllable. For example, if we want to generate a new face with a specific hairstyle, it is hard to steer the latent space to produce that hairstyle.

Variational Autoencoder: The Solution

The Variational Autoencoder (VAE) addresses these limitations by learning a flexible and interpretable latent space. The VAE uses a probabilistic encoder to map the input signal to a distribution over the latent space, where each dimension of the latent space can capture a different aspect or feature of the input. The VAE also imposes a prior distribution on the latent space and learns an approximate posterior for each input, which allows it to sample new latent points and generate novel data samples.

VAE Theory

Let's assume that we have a dataset of images X and we want to learn a probabilistic model that can generate new images that are similar to X. We can model the generative process as follows:

  • p(z): prior probability distribution over the latent space z.
  • p(x|z): conditional probability distribution over the data x given the latent variable z, also called the likelihood function.

We want to learn the parameters of these distributions so that we can sample z from p(z) and then generate new samples x by passing the sampled z through the decoder network. The problem is that the marginal likelihood p(x) = ∫ p(x|z) p(z) dz is intractable to compute, because it requires integrating over every possible value of the latent variable z; as a consequence, the true posterior p(z|x) is intractable as well.

The solution is variational inference, which replaces the intractable log-likelihood with a tractable lower bound. This bound depends on a variational distribution q(z|x), produced by the encoder, that approximates the true posterior over the latent variables given the observed data. The variational lower bound, also called the evidence lower bound (ELBO), is defined as:

  • ELBO = E_q(z|x)[log p(x|z)] - KL(q(z|x) || p(z))

The first term is the expected log-likelihood of the input x under the decoder, i.e. a measure of how well the reconstruction x̂ matches x; maximizing it minimizes the reconstruction error. The second term is the Kullback-Leibler divergence between the variational distribution q(z|x) and the prior distribution p(z).
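
To see why the ELBO is indeed a lower bound on the log-likelihood, it helps to write out a standard identity from variational inference, which splits log p(x) into the ELBO plus a KL term measured against the true posterior:

```
\log p(x)
  = \mathbb{E}_{q(z|x)}\big[\log p(x|z)\big]
    - \mathrm{KL}\big(q(z|x)\,\|\,p(z)\big)
    + \mathrm{KL}\big(q(z|x)\,\|\,p(z|x)\big)
  = \mathrm{ELBO} + \mathrm{KL}\big(q(z|x)\,\|\,p(z|x)\big)
  \ge \mathrm{ELBO}
```

Because the KL divergence is never negative, maximizing the ELBO both raises a lower bound on log p(x) and pulls the approximate posterior q(z|x) toward the true posterior p(z|x).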

The goal is to maximize the ELBO with respect to the parameters of the encoder and decoder networks, which is equivalent to minimizing the following loss function:

  • L(x, x̂, z) = -ELBO = -E_q(z|x)[log p(x|z)] + KL(q(z|x) || p(z))

The loss function consists of two terms: the reconstruction loss, which measures how well the decoded signal x̂ matches the input signal x, and the regularization loss, which encourages the variational distribution q(z|x) to be close to the prior distribution p(z).

VAE Implementation in PyTorch

In the following sections, we will walk you through the implementation of a VAE using PyTorch. We will use the MNIST dataset, which contains grayscale images of handwritten digits from 0 to 9. The images are 28x28 pixels and are normalized to have pixel values between 0 and 1.

Step 1: Loading the Data

We will start by loading the MNIST dataset using the torchvision module in PyTorch:

```
import torch
from torchvision import datasets, transforms

# Define data transform
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load MNIST dataset
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('data', train=False, download=True, transform=transform)

# Define data loader
batch_size = 128
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)
```

The transforms.Compose() function chains a sequence of transformations to apply to the images in the dataset. In this case, we only apply the ToTensor() transformation, which converts the images to PyTorch tensors and scales the pixel values to the range [0, 1].

The datasets.MNIST() function loads the MNIST dataset and applies the specified transform to the images. We load both the training and testing sets and create data loaders to iterate over the batches of data during training and testing.
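
As a quick sanity check (not required for training), we can pull a single batch from the loader and confirm the tensor shape and value range:

```
# Inspect one batch: images should be (batch_size, 1, 28, 28) with values in [0, 1]
images, labels = next(iter(train_loader))
print(images.shape)                              # torch.Size([128, 1, 28, 28])
print(images.min().item(), images.max().item())  # roughly 0.0 and 1.0
```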

Step 2: Define the VAE Model

We will define the VAE model as a subclass of the nn.Module class in PyTorch. The VAE model consists of an encoder and decoder network, which are implemented as two fully connected neural networks. The encoder network takes an input image and produces the mean and log variance of the latent variable z. The decoder network takes a sampled latent variable z and produces the reconstructed image.

```
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        # Encoder network
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, latent_dim)
        self.fc22 = nn.Linear(400, latent_dim)
        # Decoder network
        self.fc3 = nn.Linear(latent_dim, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        # Flatten the input image
        x = x.view(-1, 784)
        # Pass the input through the encoder network
        x = F.relu(self.fc1(x))
        mean = self.fc21(x)
        logvar = self.fc22(x)
        return mean, logvar

    def reparametrize(self, mean, logvar):
        # Sample the latent variable from the variational distribution
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = eps * std + mean
        return z

    def decode(self, z):
        # Pass the latent variable through the decoder network
        z = F.relu(self.fc3(z))
        x = torch.sigmoid(self.fc4(z))
        # Reshape the output to be a 28x28 image
        x = x.view(-1, 1, 28, 28)
        return x

    def forward(self, x):
        # Encode the input image
        mean, logvar = self.encode(x)
        # Sample the latent variable from the variational distribution
        z = self.reparametrize(mean, logvar)
        # Decode the latent variable into an image
        x_reconstructed = self.decode(z)
        return x_reconstructed, mean, logvar
```

The VAE class takes a latent_dim argument, which specifies the number of dimensions of the latent variable z. The __init__() method defines the encoder and decoder as fully connected networks: fc1, fc21, and fc22 form the encoder, while fc3 and fc4 form the decoder.

The encode() function takes an input image x, flattens it, and passes it through the encoder network. The encoder produces two outputs, the mean and log variance of the latent variable z, which parameterize the variational distribution q(z|x) from which z is sampled.

The reparametrize() function takes the mean and log variance of the latent variable z and returns a sample z from the variational distribution q(z|x). The reparameterization trick (z = mean + std * eps, with eps drawn from a standard normal) is used so that gradients can be backpropagated through the sampling operation.

The decode() function takes a sampled latent variable z and passes it through the decoder network. The output of the final layer is passed through a sigmoid activation so that pixel values lie between 0 and 1, and is then reshaped into a 1x28x28 image.

The forward() function passes an input image x through the encoder, samples z with the reparameterization trick, and decodes it. It returns the reconstructed image x_reconstructed together with the mean and log variance of z, which are needed to compute the Kullback-Leibler divergence term in the loss function.
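
As a quick shape check (purely illustrative, using a random batch rather than real data), we can run an untrained model once and confirm that the outputs have the expected dimensions:

```
# Pass a random batch through an untrained VAE to verify the output shapes
model = VAE(latent_dim=20)
dummy = torch.rand(8, 1, 28, 28)    # fake batch of 8 "images" with values in [0, 1]
recon, mean, logvar = model(dummy)
print(recon.shape)    # torch.Size([8, 1, 28, 28])
print(mean.shape)     # torch.Size([8, 20])
print(logvar.shape)   # torch.Size([8, 20])
```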

Step 3: Define the Loss Function

We will define the VAE loss function as the sum of the reconstruction loss and the regularization loss:

```
def vae_loss(reconstructed_x, x, mean, logvar):
    # Calculate the reconstruction loss
    reconstruction_loss = F.binary_cross_entropy(reconstructed_x, x, reduction='sum')
    # Calculate the regularization loss
    regularization_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    # Combine the reconstruction and regularization loss
    loss = reconstruction_loss + regularization_loss
    return loss
```

The vae_loss() function takes the reconstructed image reconstructed_x, the original image x, and the mean and log variance of the latent variable z, and computes the reconstruction and regularization losses. The reconstruction loss is the binary cross-entropy between reconstructed_x and x, summed over all pixels. The regularization loss is the Kullback-Leibler divergence between the variational distribution q(z|x) and the prior p(z), which encourages the latent variable z to stay close to the prior and helps prevent overfitting. The total loss is the sum of the two terms.
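
The closed-form expression used for the regularization term follows from the KL divergence between a diagonal Gaussian q(z|x) = N(μ, σ²) and the standard normal prior p(z) = N(0, I):

```
\mathrm{KL}\big(q(z|x)\,\|\,p(z)\big)
  = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```

With mean = μ and logvar = log σ², this is exactly the -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp()) term computed in the code above.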

Step 4: Training the Model

We will train the VAE model using the stochastic gradient descent (SGD) optimizer with a fixed learning rate. We will use the MNIST test set as a validation set to monitor the model's performance during training.

```
# Define the VAE model
latent_dim = 20
model = VAE(latent_dim=latent_dim)

# Define the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Define the number of epochs
epochs = 50

# Move the model to the GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Training loop
for epoch in range(epochs):
    # Train the model on the training set
    model.train()
    train_loss = 0
    for i, (x, _) in enumerate(train_loader):
        # Move the data to the GPU if available
        x = x.to(device)
        # Forward pass
        x_reconstructed, mean, logvar = model(x)
        # Calculate the loss
        loss = vae_loss(x_reconstructed, x, mean, logvar)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Update the running loss
        train_loss += loss.item()

    # Evaluate the model on the validation set
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for x, _ in test_loader:
            # Move the data to the GPU if available
            x = x.to(device)
            # Forward pass
            x_reconstructed, mean, logvar = model(x)
            # Calculate the loss
            loss = vae_loss(x_reconstructed, x, mean, logvar)
            # Update the running loss
            val_loss += loss.item()

    # Print the training and validation loss of the epoch
    print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss/len(train_loader):.2f}, Val Loss: {val_loss/len(test_loader):.2f}')
```

The training script defines the VAE model, the optimizer, and the number of epochs, and moves the model to the GPU if one is available. In the training loop, we iterate over the batches of the training set, move each batch to the device, and run a forward pass through the VAE. We compute the loss, backpropagate the gradients, and let the optimizer update the model parameters. After each epoch, we evaluate the model on the validation set by running a forward pass and computing the loss, and we print the average training and validation loss.

At the end of training, we can use the VAE to generate new images by sampling from the prior distribution p(z), passing the sampled z through the decoder network, and visualizing the output:

```
import matplotlib.pyplot as plt
import numpy as np

# Generate a random latent variable z
z = torch.randn(1, latent_dim).to(device)

# Decode the latent variable into an image
model.eval()
with torch.no_grad():
    x = model.decode(z)

# Reshape the output to be a 28x28 image
x = x.view(28, 28).cpu().numpy()

# Plot the image
plt.imshow(x, cmap='gray')
plt.axis('off')
plt.show()
```
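
To get a broader view of what the model has learned, we can decode a whole batch of random latent vectors and tile the results into a grid. This is a small extension of the snippet above using torchvision.utils.make_grid; the 8x8 grid size is an arbitrary choice.

```
from torchvision.utils import make_grid

# Decode 64 random latent vectors and arrange the samples in an 8x8 grid
model.eval()
with torch.no_grad():
    z = torch.randn(64, latent_dim).to(device)
    samples = model.decode(z)               # shape: (64, 1, 28, 28)

grid = make_grid(samples.cpu(), nrow=8)     # 3-channel image tensor of shape (C, H, W)
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()
```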

Conclusion

The Variational Autoencoder (VAE) overcomes the limitations of the standard autoencoder by learning a flexible and interpretable latent space. It uses variational inference to optimize a tractable lower bound on the log-likelihood and learn the parameters of the probabilistic model. VAEs have achieved strong generative performance on many datasets, including MNIST, FashionMNIST, CIFAR-10, and CelebA. Implementing a VAE in PyTorch involved defining the model, the loss function, and the training script. One of the advantages of the VAE is its ability to generate novel data samples by sampling from the prior distribution p(z) and passing the sampled z through the decoder network. VAEs have many potential applications in image generation, data compression, anomaly detection, and other domains.