
[ICML 2022] Source code for "A Closer Look at Smoothness in Domain Adversarial Training"


val-iisc/SDAT


Smooth Domain Adversarial Training

Harsh Rangwani*, Sumukh K Aithal*, Mayank Mishra, Arihant Jain, R. Venkatesh Babu

This is the official PyTorch implementation of our ICML'22 paper: A Closer Look at Smoothness in Domain Adversarial Training. [Paper]


Introduction


In recent times, methods converging to smooth optima have shown improved generalization for supervised learning tasks such as classification. In this work, we analyze the effect of smoothness-enhancing formulations on domain adversarial training, whose objective combines a task loss (e.g., classification, regression) with adversarial terms. We find that converging to a smooth minimum with respect to (w.r.t.) the task loss stabilizes adversarial training, leading to better performance on the target domain. In contrast, our analysis shows that converging to a smooth minimum w.r.t. the adversarial loss leads to sub-optimal generalization on the target domain. Based on this analysis, we introduce the Smooth Domain Adversarial Training (SDAT) procedure, which effectively enhances the performance of existing domain adversarial methods for both classification and object detection tasks.
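One way to write the idea down is a min-max objective in which only the task loss is made sharpness-aware. The notation is chosen here for illustration rather than taken from the paper: θ denotes the feature extractor and classifier parameters, Φ the domain discriminator, d the adversarial term that Φ maximizes and θ minimizes, and ρ the SAM neighborhood radius (the --rho flag).

```latex
\min_{\theta}\;\max_{\Phi}\;\Big[\max_{\|\epsilon\|_2 \le \rho} \mathcal{L}_{\mathrm{task}}(\theta + \epsilon) \;+\; d(\theta, \Phi)\Big],
\qquad
\hat{\epsilon}(\theta) = \rho\,\frac{\nabla_{\theta}\mathcal{L}_{\mathrm{task}}(\theta)}{\big\|\nabla_{\theta}\mathcal{L}_{\mathrm{task}}(\theta)\big\|_2}
```

ε̂(θ) is the first-order approximation of the inner maximization used by non-adaptive SAM, which corresponds to adaptive=False in the snippet further down. In the update step, the full task + adversarial loss is then evaluated at θ + ε̂(θ), with ε̂ computed from the task loss alone. See the paper for the exact formulation.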

TLDR: Change just a few lines of code to improve your adversarial domain adaptation algorithm by converting it to its smooth variant.

Why use SDAT?

  • Can be combined with any DAT algorithm.
  • Easy to integrate with a few lines of code.
  • Leads to significant improvements in target-domain accuracy.

DAT Based Method w/ SDAT

We provide the details of the changes required to convert any DAT algorithm (e.g., CDAN, DANN, CDAN+MCC) to its Smooth DAT version.

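import torch
from sam import SAM  # SAM class from https://github.com/davda54/sam; adjust the import path to your copy

# optimizer refers to the Smooth (SAM) optimizer, which contains the parameters of the feature extractor and classifier.
optimizer = SAM(classifier.get_parameters(), torch.optim.SGD, rho=args.rho, adaptive=False,
                lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
# ad_optimizer refers to a standard SGD optimizer, which contains the parameters of the domain classifier
# (its hyperparameters are assumed here to match the main optimizer).
ad_optimizer = torch.optim.SGD(domain_classifier.parameters(), lr=args.lr, momentum=args.momentum,
                               weight_decay=args.weight_decay, nesterov=True)

optimizer.zero_grad()
ad_optimizer.zero_grad()

# Calculate task loss at the current weights w
class_prediction, feature = model(x)
task_loss = task_loss_fn(class_prediction, label)
task_loss.backward()

# Calculate ϵ̂(w) from the task-loss gradient, add it to the weights, and clear the gradients
optimizer.first_step(zero_grad=True)

# Calculate task loss and domain loss at the perturbed weights w + ϵ̂(w)
class_prediction, feature = model(x)
task_loss = task_loss_fn(class_prediction, label)
domain_loss = domain_classifier(feature)  # adversarial loss; gradient reversal is assumed to happen inside
loss = task_loss + domain_loss
loss.backward()

# Restore w and apply the gradient computed at w + ϵ̂(w) (Sharpness-Aware update).
# davda54's SAM.step() requires a closure, so the explicit second_step() is used instead.
optimizer.second_step(zero_grad=True)
# Update parameters of domain classifier
ad_optimizer.step()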
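The snippet above relies on names (model, classifier, domain_classifier, args) defined elsewhere in the repository. As a self-contained illustration, the sketch below runs one SDAT step end to end on random toy data. Everything in it is an assumption made for illustration, not the repository's code: a minimal non-adaptive SAM class modeled on the first_step/second_step structure of https://github.com/davda54/sam, a hand-written gradient reversal layer, tiny linear models standing in for the backbone, classifier head and domain discriminator, and a plain DANN-style binary domain loss instead of CDAN's conditioning.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SAM(torch.optim.Optimizer):
    """Minimal non-adaptive SAM wrapper (sketch modeled on github.com/davda54/sam)."""

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # Global L2 norm of the current (task-loss) gradients.
        grad_norm = torch.norm(torch.stack([p.grad.norm(p=2) for g in self.param_groups
                                            for p in g["params"] if p.grad is not None]), p=2)
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.detach().clone()
                p.add_(p.grad * scale)  # w -> w + eps_hat(w)
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                old_p = self.state[p].pop("old_p", None)
                if old_p is not None:
                    p.copy_(old_p)  # back to w before the actual update
        self.base_optimizer.step()  # uses gradients computed at w + eps_hat(w)
        if zero_grad:
            self.zero_grad()


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass, negated gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output


torch.manual_seed(0)
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())  # stand-in feature extractor
head = nn.Linear(16, 3)                                # stand-in classifier
discriminator = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))

optimizer = SAM(list(backbone.parameters()) + list(head.parameters()), torch.optim.SGD,
                rho=0.05, lr=0.01, momentum=0.9)
ad_optimizer = torch.optim.SGD(discriminator.parameters(), lr=0.01, momentum=0.9)


def domain_loss_fn(f_s, f_t):
    # Discriminator separates source (1) from target (0); gradient reversal
    # makes the backbone's gradient work against it.
    logits = discriminator(GradReverse.apply(torch.cat([f_s, f_t]))).squeeze(1)
    domains = torch.cat([torch.ones(len(f_s)), torch.zeros(len(f_t))])
    return F.binary_cross_entropy_with_logits(logits, domains)


def sdat_step(x_s, y_s, x_t):
    optimizer.zero_grad()
    ad_optimizer.zero_grad()
    # Pass 1: task loss only, used to compute eps_hat(w).
    F.cross_entropy(head(backbone(x_s)), y_s).backward()
    optimizer.first_step(zero_grad=True)
    # Pass 2: task + domain loss at the perturbed weights.
    f_s, f_t = backbone(x_s), backbone(x_t)
    loss = F.cross_entropy(head(f_s), y_s) + domain_loss_fn(f_s, f_t)
    loss.backward()
    optimizer.second_step(zero_grad=True)  # sharpness-aware update of backbone + head
    ad_optimizer.step()                    # plain SGD update of the discriminator
    return loss.item()


x_s, y_s, x_t = torch.randn(4, 8), torch.randint(0, 3, (4,)), torch.randn(4, 8)
before = [p.detach().clone() for p in backbone.parameters()]
disc_before = [p.detach().clone() for p in discriminator.parameters()]
loss = sdat_step(x_s, y_s, x_t)
print(f"loss after one SDAT step: {loss:.4f}")
```

The split into a SAM optimizer for the backbone and head and a plain SGD optimizer for the discriminator mirrors the two-optimizer setup described above; rho, learning rates and model sizes are arbitrary toy values.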

Getting started

  • Requirements

    • pytorch 1.9.1
    • torchvision 0.10.1
    • wandb 0.12.2
    • timm 0.5.5
    • prettytable 2.2.0
    • scikit-learn
  • Installation

git clone https://github.com/val-iisc/SDAT.git
cd SDAT
pip install -r requirements.txt

We use Weights and Biases (wandb) to track our experiments and results. To track your own experiments with wandb, create a new project under your account and change the project and entity arguments in wandb.init accordingly. Tracking is toggled with the --log_results flag, which the sample training commands below pass.

  • Datasets

    If a dataset is not already available, it is automatically downloaded to the data/ folder. The datasets used in the repository can also be downloaded from the following links:
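For the wandb setup described above, here is a minimal sketch of gating tracking on the flag. It assumes --log_results is a store_true switch (present means log), as its use in the sample training commands suggests. The project and entity strings are placeholders, and this is not the repository's own logging code.

```python
import argparse


def maybe_init_wandb(argv=None):
    """Start a wandb run only when --log_results is passed (assumed store_true switch)."""
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--log_name", default="run")
    parser.add_argument("--log_results", action="store_true")
    args, _ = parser.parse_known_args(argv)
    if not args.log_results:
        return None  # tracking disabled; wandb is not even imported
    import wandb  # imported lazily so runs without tracking do not need it

    # Placeholder project/entity: replace with your own wandb project and account.
    return wandb.init(project="my-sdat-project", entity="my-wandb-entity", name=args.log_name)


run = maybe_init_wandb(["--log_name", "Ar2Cl_cdan_mcc_sdat_vit"])
print("wandb run:", run)
```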

Training

We report results primarily for two domain adaptation methods: CDAN w/ SDAT and CDAN+MCC w/ SDAT. The training scripts can be found in the examples subdirectory.

Domain Adversarial Training (DAT)

To train with standard CDAN and CDAN+MCC, use cdan.py and cdan_mcc.py, respectively. A sample command to train CDAN+MCC with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below.

python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results

Smooth Domain Adversarial Training (SDAT)

To train using our proposed CDAN w/ SDAT and CDAN+MCC w/ SDAT, use the cdan_sdat.py and cdan_mcc_sdat.py files, respectively.

A sample script to run CDAN+MCC w/ SDAT with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below.

python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results

Additional commands to reproduce the results can be found in run_office_home.sh and run_visda.sh under examples.
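To see how the sample SDAT command generalizes across Office-Home, the sketch below builds one command per ordered (source, target) pair. It assumes Pr and Rw are the domain codes for Product and Real World (only Ar and Cl appear in the commands here) and reuses the sample's hyperparameters for every split, which may not match the settings in run_office_home.sh.

```python
import itertools

# Assumed Office-Home domain codes: Ar/Cl appear in the sample command; Pr/Rw are assumptions.
DOMAINS = ["Ar", "Cl", "Pr", "Rw"]


def sdat_commands(rho=0.02, lr=0.002, seed=0, gpu=0):
    """Build one CDAN+MCC w/ SDAT command per ordered (source, target) pair."""
    commands = []
    for s, t in itertools.permutations(DOMAINS, 2):
        commands.append(
            f"python cdan_mcc_sdat.py data/office-home -d OfficeHome -s {s} -t {t} "
            f"-a vit_base_patch16_224 --epochs 30 --seed {seed} -b 24 --no-pool "
            f"--log logs/cdan_mcc_sdat_vit/OfficeHome_{s}2{t} "
            f"--log_name {s}2{t}_cdan_mcc_sdat_vit --gpu {gpu} --rho {rho} --lr {lr} --log_results"
        )
    return commands


commands = sdat_commands()
for cmd in commands:
    print(cmd)
```

The first generated command is identical to the Ar to Cl sample above; the remaining 11 only swap the -s/-t codes and the log names.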

Results

The following table reports the accuracy across the various splits of the Office-Home and VisDA-2017 datasets using CDAN+MCC w/ SDAT with a ViT B-16 backbone. We also provide downloadable weights for the corresponding pretrained classifiers.

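| Dataset | Source | Target | Accuracy (%) | Checkpoints |
| --- | --- | --- | --- | --- |
| Office-Home | Art | Clipart | 70.8 | ckpt |
| Office-Home | Art | Product | 80.7 | ckpt |
| Office-Home | Art | Real World | 90.5 | ckpt |
| Office-Home | Clipart | Art | 85.2 | ckpt |
| Office-Home | Clipart | Product | 87.3 | ckpt |
| Office-Home | Clipart | Real World | 89.7 | ckpt |
| Office-Home | Product | Art | 84.1 | ckpt |
| Office-Home | Product | Clipart | 70.7 | ckpt |
| Office-Home | Product | Real World | 90.6 | ckpt |
| Office-Home | Real World | Art | 88.3 | ckpt |
| Office-Home | Real World | Clipart | 75.5 | ckpt |
| Office-Home | Real World | Product | 92.1 | ckpt |
| VisDA-2017 | Synthetic | Real | 89.8 | ckpt |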
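The table lists per-split numbers only. As a quick sanity check, the unweighted mean over the 12 Office-Home splits can be computed from the values above. This is plain arithmetic on the table, not a separately reported result.

```python
# Office-Home accuracies (%) for CDAN+MCC w/ SDAT (ViT B-16), copied from the table above.
office_home_acc = {
    ("Art", "Clipart"): 70.8,
    ("Art", "Product"): 80.7,
    ("Art", "Real World"): 90.5,
    ("Clipart", "Art"): 85.2,
    ("Clipart", "Product"): 87.3,
    ("Clipart", "Real World"): 89.7,
    ("Product", "Art"): 84.1,
    ("Product", "Clipart"): 70.7,
    ("Product", "Real World"): 90.6,
    ("Real World", "Art"): 88.3,
    ("Real World", "Clipart"): 75.5,
    ("Real World", "Product"): 92.1,
}

mean_acc = sum(office_home_acc.values()) / len(office_home_acc)
print(f"Office-Home mean over {len(office_home_acc)} splits: {mean_acc:.1f}")
# prints: Office-Home mean over 12 splits: 83.8
```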

Evaluation

To evaluate a classifier with pretrained weights, use eval.py under examples and set the --weight_path argument to the path of the weights to be evaluated.

A sample run to evaluate the pretrained ViT B-16 with CDAN+MCC w/ SDAT on Office-Home (with Art as the source domain and Clipart as the target domain) is given below.

python eval.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 -b 24 --no-pool --weight_path path_to_weight.pth --log_name Ar2Cl_cdan_mcc_sdat_vit_eval --gpu 0 --phase test
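eval.py loads the file passed via --weight_path. The checkpoint format is not documented here, so the sketch below simply assumes the .pth file holds a plain state_dict. It uses a stand-in nn.Linear model and a hypothetical helper name, load_weights; a real run would build the ViT B-16 classifier first.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_weights(model: nn.Module, weight_path: str) -> nn.Module:
    """Hypothetical helper: load a state_dict into `model` and switch to eval mode.

    Assumption: the .pth file holds a plain state_dict; the released
    checkpoints may be organised differently.
    """
    state_dict = torch.load(weight_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.eval()


# Usage with a stand-in model saved to a temporary path_to_weight.pth.
src = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "path_to_weight.pth")
torch.save(src.state_dict(), path)
restored = load_weights(nn.Linear(4, 2), path)
```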

A sample run to evaluate the pretrained ViT B-16 with CDAN+MCC w/ SDAT on VisDA-2017 (with Synthetic as the source domain and Real as the target domain) is given below.

python eval.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --per-class-eval --train-resizing cen.crop --weight_path path_to_weight.pth --log_name visda_cdan_mcc_sdat_vit_eval --gpu 0 --no-pool --phase test
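The VisDA-2017 command passes --per-class-eval. VisDA-2017 results are commonly reported as the mean of per-class accuracies over its 12 classes, which weights rare classes more heavily than overall accuracy does. Below is a minimal sketch of both metrics, assuming predicted and true labels are available as integer tensors. It is not the code in eval.py.

```python
import torch


def accuracy_metrics(preds: torch.Tensor, labels: torch.Tensor, num_classes: int):
    """Return (overall accuracy, mean per-class accuracy), both in percent."""
    correct = (preds == labels).float()
    overall = correct.mean().item() * 100
    per_class = [
        correct[labels == c].mean().item() * 100
        for c in range(num_classes)
        if (labels == c).any()  # skip classes absent from the evaluation set
    ]
    return overall, sum(per_class) / len(per_class)


# Toy example: class 1 is rare and half-misclassified.
labels = torch.tensor([0, 0, 0, 0, 1, 1])
preds = torch.tensor([0, 0, 0, 0, 0, 1])
overall, mean_per_class = accuracy_metrics(preds, labels, num_classes=2)
print(f"overall {overall:.1f}%, mean per-class {mean_per_class:.1f}%")
```

In the toy example the single mistake costs about 17 points of overall accuracy but 25 points of mean per-class accuracy, because it falls in the smaller class.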

Overview of the arguments

Generally, all scripts in the project take the following flags:

  • -a: Architecture of the backbone. (resnet50|vit_base_patch16_224)
  • -d: Dataset (OfficeHome|VisDA2017|DomainNet)
  • -s: Source Domain
  • -t: Target Domain
  • --epochs: Number of epochs to train for.
  • --no-pool: Use --no-pool for all experiments with a ViT backbone.
  • --log_name: Name of the run on wandb.
  • --gpu: GPU id to use.
  • --rho: Value of ρ in SDAT, i.e., the radius of the SAM perturbation neighborhood (applicable only to SDAT runs).
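The flags above can be mirrored in a small, hypothetical argparse sketch. The dest names and defaults are assumptions, and the actual scripts define more flags (for example --seed, -b and --lr from the sample commands). parse_known_args lets the sketch read a full sample command and report the flags it does not model, and allow_abbrev=False keeps --log from being read as an abbreviation of --log_name.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the documented flags; defaults are assumptions.
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("root", help="dataset root, e.g. data/office-home")
    parser.add_argument("-a", dest="arch", default="resnet50",
                        choices=["resnet50", "vit_base_patch16_224"])
    parser.add_argument("-d", dest="data", default="OfficeHome")
    parser.add_argument("-s", dest="source")
    parser.add_argument("-t", dest="target")
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--no-pool", action="store_true")
    parser.add_argument("--log_name", default="")
    parser.add_argument("--gpu", default="0")
    parser.add_argument("--rho", type=float, default=0.05)
    return parser


sample = ("data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 "
          "-b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_sdat_vit "
          "--gpu 0 --rho 0.02 --lr 0.002 --log_results")
args, unknown = build_parser().parse_known_args(sample.split())
print(args)
print("flags not covered by this sketch:", unknown)
```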

Acknowledgement

Our implementation is based on the Transfer Learning Library. We use the PyTorch implementation of SAM from https://github.com/davda54/sam.

Citation

If you find our paper or codebase useful, please consider citing us as:

@InProceedings{rangwani2022closer,
  title={A Closer Look at Smoothness in Domain Adversarial Training},
  author={Rangwani, Harsh and Aithal, Sumukh K and Mishra, Mayank and Jain, Arihant and Babu, R. Venkatesh},
  booktitle={Proceedings of the 39th International Conference on Machine Learning},
  year={2022}
}
