VQVAE Implementation in PyTorch with Generation Using an LSTM

This repository implements a VQVAE for MNIST and for a colored version of MNIST, and follows up with a simple LSTM for generating numbers from the discrete encoder outputs.
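
To make the core idea concrete, below is a minimal sketch of the vector-quantization step at the heart of a VQVAE. This is illustrative PyTorch, not the repo's actual module; the function name quantize, the tensor shapes, and the codebook layout are assumptions:

import torch

def quantize(z_e, codebook):
    # z_e:      (B, N, D) encoder output vectors
    # codebook: (K, D) learnable embedding vectors
    # Pairwise distances between each encoder vector and every codebook entry
    dist = torch.cdist(z_e, codebook.unsqueeze(0).expand(z_e.size(0), -1, -1))
    indices = dist.argmin(dim=-1)    # (B, N) discrete codes; these feed the LSTM later
    z_q = codebook[indices]          # (B, N, D) quantized vectors
    # Straight-through estimator: copy gradients from the decoder input to the encoder
    z_q = z_e + (z_q - z_e).detach()
    return z_q, indices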

VQVAE Explanation and Implementation Video


Quickstart

  • Create a new conda environment with Python 3.8, then run the commands below
  • git clone https://github.com/explainingai-code/VQVAE-Pytorch.git
  • cd VQVAE-Pytorch
  • pip install -r requirements.txt
  • To run a simple VQVAE with minimal code and understand the basics: python run_simple_vqvae.py
  • To experiment with the VQVAE and train/run inference with the LSTM, use the commands below, passing the desired configuration file as the config argument (see the example after this list)
  • python -m tools.train_vqvae for training the VQVAE
  • python -m tools.infer_vqvae for generating reconstructions and encoder outputs for LSTM training
  • python -m tools.train_lstm for training the minimal LSTM
  • python -m tools.generate_images for using the trained LSTM to generate some numbers
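
For example, a full pipeline run on black-and-white MNIST could look like the following (this assumes the scripts take the config path via a --config flag; the README implies a config argument but does not spell out its name):

python -m tools.train_vqvae --config config/vqvae_mnist.yaml
python -m tools.infer_vqvae --config config/vqvae_mnist.yaml
python -m tools.train_lstm --config config/vqvae_mnist.yaml
python -m tools.generate_images --config config/vqvae_mnist.yaml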

Configurations

  • config/vqvae_mnist.yaml - VQVAE for training on black-and-white MNIST images
  • config/vqvae_colored_mnist.yaml - VQVAE with more embedding vectors for training on colored MNIST images

Data preparation

To set up the dataset, follow https://github.com/explainingai-code/Pytorch-VAE#data-preparation

Verify the data directory has the following structure:

VQVAE-Pytorch/data/train/images/{0/1/.../9}
	*.png
VQVAE-Pytorch/data/test/images/{0/1/.../9}
	*.png
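
To sanity-check the layout programmatically, here is a minimal sketch (not part of the repo; run it from the VQVAE-Pytorch root so the relative paths match the structure above):

import os

for split in ('train', 'test'):
    for digit in range(10):
        d = os.path.join('data', split, 'images', str(digit))
        assert os.path.isdir(d), f'missing directory: {d}'
        pngs = [f for f in os.listdir(d) if f.endswith('.png')]
        print(f'{d}: {len(pngs)} png files')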

Output

Outputs will be saved according to the configuration present in the YAML files.

For every run, a folder named after the task_name key in the config will be created, and output_train_dir will be created inside it.

During VQVAE training, the following outputs will be saved:

  • Best model checkpoints (VQVAE and LSTM) in the task_name directory

During inference, the following outputs will be saved:

  • Reconstructions for a sample of the test set in task_name/output_train_dir/reconstruction.png
  • Encoder outputs on the train set for LSTM training in task_name/output_train_dir/mnist_encodings.pkl (see the loading sketch after this list)
  • LSTM generation output in task_name/output_train_dir/generation_results.png
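
The encodings pickle is what the LSTM trains on. A hedged sketch of loading it follows; the path pattern comes from the list above, but the exact contents and shape of the pickle are an assumption and may differ:

import pickle

# Replace task_name/output_train_dir with the values from your config
with open('task_name/output_train_dir/mnist_encodings.pkl', 'rb') as f:
    encodings = pickle.load(f)  # discrete codes from the trained VQVAE encoder
print(type(encodings))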

Sample Output for VQVAE

Running run_simple_vqvae should be very quick (as it is a very simple model) and give you the reconstructions below (input on a black background and reconstruction on a white background).

Running the default-config VQVAE for MNIST should give you the reconstructions below for both versions.

Sample generation output after just 10 epochs. Training the VQVAE and LSTM for longer, with more parameters (codebook size, codebook dimension, channels, LSTM hidden dimension, etc.), will give better results.

Citations

@misc{oord2018neural,
      title={Neural Discrete Representation Learning}, 
      author={Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
      year={2018},
      eprint={1711.00937},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
