Deep Multi-Branch Aggregation Network for Semantic Segmentation in PyTorch

haritsahm/pytorch-DMANet

Deep Multi-Branch Aggregation Network for Real-Time Semantic Segmentation in Street Scenes

PyTorch Lightning | Config: Hydra | Template | Paper

[Figure: DMA-Net architecture]

This is an implementation of DMA-Net in PyTorch. I built it to explore PyTorch Lightning and Hydra and to improve my programming skills. DMA-Net is a real-time semantic segmentation network for street scenes in self-driving cars.

Added Features

  1. D-Adaptation optimizers: learning-rate-free training for SGD, AdaGrad, and Adam, provided by facebookresearch/dadaptation. Enable it with:

    model.auto_lr=True model.lr=1.0
    
  2. Hyperparameter search: because the original authors' results are hard to reproduce, I added two variables, high_level_features and low_level_features, that set feature sizes in the model.

    • high_level_features: the input size of the CBR (Conv-BN-ReLU) block upmid_cbr, which follows the addition of sub-network 3 and sub-network 4 in the upscaling pipeline.

    • low_level_features: the input size of the CBR block uplow_cbr, which follows the addition of sub-network 2 and upmid_cbr in the upscaling pipeline.

    model.net.low_level_features=128 model.net.high_level_features=128
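A minimal sketch of how an `auto_lr` switch could choose between a regular PyTorch optimizer and its D-Adaptation counterpart. The `resolve_optimizer`/`build_optimizer` helpers and the `name` argument are hypothetical, not this repository's code; the class names `DAdaptSGD`, `DAdaptAdaGrad` and `DAdaptAdam` are the ones exported by the `dadaptation` package. D-Adaptation estimates the step size itself, which is why `lr` stays at 1.0.

```python
import importlib
import warnings
from typing import Any, Dict, Iterable, Tuple

# Hypothetical registry: optimizer name -> (regular class path, D-Adaptation class path).
_OPTIMIZERS: Dict[str, Tuple[str, str]] = {
    "sgd": ("torch.optim.SGD", "dadaptation.DAdaptSGD"),
    "adagrad": ("torch.optim.Adagrad", "dadaptation.DAdaptAdaGrad"),
    "adam": ("torch.optim.Adam", "dadaptation.DAdaptAdam"),
}


def resolve_optimizer(name: str, lr: float, auto_lr: bool) -> Tuple[str, Dict[str, Any]]:
    """Return the optimizer class path and constructor kwargs."""
    regular, dadapt = _OPTIMIZERS[name.lower()]
    if not auto_lr:
        return regular, {"lr": lr}
    # D-Adaptation estimates the step size on its own; lr only scales that
    # estimate, so the usual setting is lr=1.0.
    if lr != 1.0:
        warnings.warn(f"auto_lr=True usually expects lr=1.0, got {lr}")
    return dadapt, {"lr": lr}


def build_optimizer(params: Iterable, name: str, lr: float, auto_lr: bool):
    """Import the resolved class lazily, so dadaptation is only needed when used."""
    path, kwargs = resolve_optimizer(name, lr, auto_lr)
    module_name, class_name = path.rsplit(".", 1)
    optimizer_cls = getattr(importlib.import_module(module_name), class_name)
    return optimizer_cls(params, **kwargs)
```

For example, `build_optimizer(model.parameters(), "adam", lr=1.0, auto_lr=True)` would construct `dadaptation.DAdaptAdam(params, lr=1.0)`, while `auto_lr=False` falls back to `torch.optim.Adam`.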
    
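To make the two variables concrete, here is a hypothetical PyTorch sketch of the upscaling path described above. Only `upmid_cbr`, `uplow_cbr`, `high_level_features` and `low_level_features` come from this README; the sub-network channel counts, the 1x1 projection layers (added so the element-wise additions line up) and bilinear upsampling are assumptions, and the actual DMA-Net decoder may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CBR(nn.Sequential):
    """Conv3x3 -> BatchNorm -> ReLU."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )


class UpscalingPath(nn.Module):
    """Hypothetical decoder: (sub3 + sub4) -> upmid_cbr, then (sub2 + upmid) -> uplow_cbr."""

    def __init__(self, ch2: int, ch3: int, ch4: int,
                 high_level_features: int = 128, low_level_features: int = 128):
        super().__init__()
        # Assumed 1x1 projections so each addition sees matching channel counts.
        self.proj3 = nn.Conv2d(ch3, high_level_features, kernel_size=1)
        self.proj4 = nn.Conv2d(ch4, high_level_features, kernel_size=1)
        self.upmid_cbr = CBR(high_level_features, low_level_features)
        self.proj2 = nn.Conv2d(ch2, low_level_features, kernel_size=1)
        self.uplow_cbr = CBR(low_level_features, low_level_features)

    def forward(self, f2: torch.Tensor, f3: torch.Tensor, f4: torch.Tensor) -> torch.Tensor:
        # Sub-network 4 is the coarsest: resize it to sub-network 3's resolution, then add.
        f4 = F.interpolate(self.proj4(f4), size=f3.shape[-2:], mode="bilinear", align_corners=False)
        mid = self.upmid_cbr(self.proj3(f3) + f4)  # input width = high_level_features
        # Resize to sub-network 2's resolution, then add.
        mid = F.interpolate(mid, size=f2.shape[-2:], mode="bilinear", align_corners=False)
        return self.uplow_cbr(self.proj2(f2) + mid)  # input width = low_level_features
```

In this sketch, `high_level_features=128` and `low_level_features=256` mean `upmid_cbr` receives 128 channels and `uplow_cbr` receives 256.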

How to run

Install dependencies

# clone project
git clone https://github.com/haritsahm/pytorch-DMANet.git
cd pytorch-DMANet

# [OPTIONAL] create conda environment
conda create -n myenv python=3.10
conda activate myenv

# install pytorch according to instructions
# https://pytorch.org/get-started/

# install requirements
pip install -r requirements.txt
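After installing, a quick check that the main packages can be found may save a failed run later. This sketch uses only the standard library; the package list is an assumption based on the tools this README mentions, so adjust it to match requirements.txt (for example, newer Lightning-Hydra template versions import `lightning` rather than `pytorch_lightning`).

```python
import importlib.util
from typing import Iterable, List

# Assumed import names for the tools mentioned in this README; adjust to requirements.txt.
REQUIRED = ["torch", "pytorch_lightning", "hydra", "fiftyone", "dadaptation"]


def missing_packages(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]
```

Once installation has succeeded, `missing_packages(REQUIRED)` should return an empty list.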

Prepare dataset

Run the notebook (Fiftyone Sample) and follow its steps to prepare and visualize the dataset with FiftyOne.
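The folder name `cityscape_fo_image_segmentation` suggests an export in FiftyOne's image segmentation directory format, which stores images under `data/` and masks with matching base names under `labels/`. Assuming that layout (the notebook may organize splits differently), this sketch pairs every image with its mask and fails early when a mask is missing. `pair_images_and_masks` is a hypothetical helper, not part of this repository.

```python
from pathlib import Path
from typing import List, Tuple, Union

IMAGE_EXTS = {".png", ".jpg", ".jpeg"}


def _images(folder: Path) -> List[Path]:
    """Sorted image files in a folder, filtered by extension."""
    return sorted(p for p in folder.iterdir() if p.suffix.lower() in IMAGE_EXTS)


def pair_images_and_masks(dataset_dir: Union[str, Path]) -> List[Tuple[Path, Path]]:
    """Pair data/<name>.<ext> with labels/<name>.<ext> by base name (assumed layout)."""
    root = Path(dataset_dir)
    masks = {p.stem: p for p in _images(root / "labels")}
    pairs, missing = [], []
    for image in _images(root / "data"):
        mask = masks.get(image.stem)
        if mask is None:
            missing.append(image.name)
        else:
            pairs.append((image, mask))
    if missing:
        raise FileNotFoundError(f"{len(missing)} image(s) without a mask, e.g. {missing[0]}")
    return pairs
```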

Train Commands

1. Train with default configurations

# train on CPU
python src/train.py trainer=cpu paths.data_dir=data/cityscape_fo_image_segmentation

# train on GPU
python src/train.py trainer=gpu paths.data_dir=data/cityscape_fo_image_segmentation

# train with DDP (Distributed Data Parallel) (4 GPUs)
python src/train.py trainer=ddp trainer.devices=4 paths.data_dir=data/cityscape_fo_image_segmentation
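With `trainer=ddp trainer.devices=4`, Lightning starts one process per GPU and each process loads its own batches, so the number of samples per optimizer step is larger than the configured batch size, unless the datamodule divides the batch size by the world size (recent Lightning-Hydra template datamodules do; whether this repository's does is not stated here). A small sketch of the arithmetic under both conventions, ignoring gradient accumulation; `ddp_batch_sizes` is hypothetical.

```python
from typing import Tuple


def ddp_batch_sizes(batch_size: int, devices: int, num_nodes: int = 1,
                    divide_by_world_size: bool = False) -> Tuple[int, int]:
    """Return (per-process batch size, samples per optimizer step) under DDP."""
    world_size = devices * num_nodes
    if divide_by_world_size:
        # Convention where the configured batch_size is the global batch.
        if batch_size % world_size != 0:
            raise ValueError(f"batch_size={batch_size} is not divisible by world size {world_size}")
        per_process = batch_size // world_size
    else:
        # Plain DDP: every process loads batch_size samples per step.
        per_process = batch_size
    return per_process, per_process * world_size
```

For example, `datamodule.batch_size=64` on 4 GPUs gives 256 samples per step without division, or 16 per GPU and 64 per step with it.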

2. Train a model with an experiment configuration chosen from configs/experiment/

# train using the Cityscapes dataset
python src/train.py experiment=dmanet_cityscape paths.data_dir=data/cityscape_fo_image_segmentation

# train using the CamVid dataset
python src/train.py experiment=dmanet_camvid
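Conceptually, an experiment config is a set of values layered on top of the default training config: keys it sets win, and everything else is inherited. The sketch below models that with a plain recursive dictionary merge. Hydra actually composes configs through its defaults list and OmegaConf, and the values shown are made up for illustration rather than taken from `configs/experiment/`.

```python
import copy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict where values from `override` replace those in `base`, recursively."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Illustrative values only; they are not this repository's configs.
defaults = {"trainer": {"max_epochs": 10, "accelerator": "cpu"},
            "model": {"net": {"low_level_features": 64, "high_level_features": 64}}}
experiment = {"trainer": {"max_epochs": 100},
              "model": {"net": {"high_level_features": 128}}}
config = deep_merge(defaults, experiment)
```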

3. Override any parameter

python src/train.py experiment=dmanet_cityscape paths.data_dir=data/cityscape_fo_image_segmentation trainer.max_epochs=20 datamodule.batch_size=64 model.net.low_level_features=128 model.net.high_level_features=256
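Each `key=value` override follows a dotted path into the composed config and replaces one leaf. This simplified sketch mirrors that for value overrides such as `trainer.max_epochs=20`. `apply_overrides` is hypothetical and much simpler than Hydra's override grammar: it does not handle config-group selections like `experiment=...`, lists, or sweeps. Like Hydra, it rejects keys that are not already in the config (Hydra requires a `+` prefix to add new ones).

```python
import ast
import copy
from typing import Any, Dict, Iterable

_YAML_LITERALS = {"true": True, "false": False, "null": None}


def _parse_value(raw: str) -> Any:
    """Turn '20' into 20 and 'true' into True; leave paths and names as strings."""
    if raw.lower() in _YAML_LITERALS:
        return _YAML_LITERALS[raw.lower()]
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return raw


def apply_overrides(cfg: Dict[str, Any], overrides: Iterable[str]) -> Dict[str, Any]:
    """Apply 'a.b.c=value' overrides to a copy of a nested dict config."""
    cfg = copy.deepcopy(cfg)
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = cfg
        for part in parents:
            node = node[part]  # raises KeyError for unknown sections
        if leaf not in node:
            raise KeyError(f"{key} is not in the config")
        node[leaf] = _parse_value(raw)
    return cfg
```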

Read the full documentation on how to use PyTorch Lightning with Hydra.

TODO:

  • Train model using cloud instances
  • Validate and compare model metrics (Cityscapes and CamVid)
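Cityscapes and CamVid results are usually reported as mean Intersection over Union (mIoU). A NumPy sketch of computing it from a confusion matrix, assuming label 255 marks ignored pixels (as in the common Cityscapes train-ID setup); it is independent of whatever metric code this repository uses.

```python
import numpy as np


def confusion_matrix(pred, target, num_classes: int, ignore_index: int = 255) -> np.ndarray:
    """Rows are ground-truth classes, columns are predicted classes."""
    pred = np.asarray(pred).ravel()
    target = np.asarray(target).ravel()
    valid = target != ignore_index  # drop ignored pixels
    index = num_classes * target[valid].astype(np.int64) + pred[valid].astype(np.int64)
    return np.bincount(index, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def mean_iou(conf: np.ndarray):
    """Per-class IoU = TP / (TP + FP + FN); classes absent from both pred and target are NaN."""
    tp = np.diag(conf).astype(np.float64)
    union = conf.sum(axis=0) + conf.sum(axis=1) - tp
    iou = np.divide(tp, union, out=np.full_like(tp, np.nan), where=union > 0)
    return float(np.nanmean(iou)), iou
```

Summing the confusion matrices over the whole validation set before calling `mean_iou` gives dataset-level mIoU, which differs from averaging per-image scores.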