An official PyTorch implementation of the EACL 2024 short paper "Flow Matching for Conditional Text Generation in a Few Sampling Steps"

dongzhuoyao/flowseq

Flow Matching for Conditional Text Generation in a Few Sampling Steps (EACL 2024)

Vincent Tao Hu, Di Wu, Yuki M. Asano, Pascal Mettes, Basura Fernando, Björn Ommer, Cees G.M. Snoek

This repository contains the official implementation of the EACL 2024 paper "Flow Matching for Conditional Text Generation in a Few Sampling Steps".

Dataset

Download the datasets from Google Drive: https://drive.google.com/drive/folders/1sU8CcOJE_aaaKLijBNzr-4y1i1YoZet2?usp=drive_link

Run

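# Train on 4 GPUs with the default data config
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc-per-node=4 flow_train.py
# Train with the qg data config on 2 GPUs (physical GPUs 0 and 2)
CUDA_VISIBLE_DEVICES=0,2 torchrun --nnodes=1 --nproc-per-node=2 flow_train.py data=qg
# Train with the qg data config on a single GPU (physical GPU 6)
CUDA_VISIBLE_DEVICES=6 torchrun --nnodes=1 --nproc-per-node=1 flow_train.py data=qg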
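torchrun launches `--nproc-per-node` worker processes on the node and passes each worker its index in the `LOCAL_RANK` environment variable. `CUDA_VISIBLE_DEVICES` limits which GPUs those workers can see and renumbers the visible GPUs starting from 0. The sketch below shows how a local rank maps back to a physical GPU. The helper names are hypothetical and not part of this repo. The sketch also assumes that each worker selects the visible device whose index equals its local rank. That is the usual torchrun convention, but it has not been confirmed for `flow_train.py`.

```python
def parse_visible_devices(value: str) -> list[str]:
    """Split a CUDA_VISIBLE_DEVICES string such as "0,2" into device IDs."""
    return [d.strip() for d in value.split(",") if d.strip()]


def physical_gpu_for_rank(local_rank: int, visible: str, nproc_per_node: int) -> str:
    """Return the physical GPU ID that torchrun worker `local_rank` would use.

    Hypothetical helper: assumes each worker selects visible device index
    `local_rank`, the usual torchrun convention.
    """
    devices = parse_visible_devices(visible)
    if nproc_per_node > len(devices):
        raise ValueError(
            f"--nproc-per-node={nproc_per_node} but only {len(devices)} GPU(s) are visible"
        )
    if not 0 <= local_rank < nproc_per_node:
        raise ValueError(f"local_rank {local_rank} is outside [0, {nproc_per_node})")
    # CUDA renumbers the visible devices from 0, so visible index i is devices[i].
    return devices[local_rank]


# Second training command: CUDA_VISIBLE_DEVICES=0,2 with --nproc-per-node=2.
for rank in range(2):
    print(f"local rank {rank} -> physical GPU {physical_gpu_for_rank(rank, '0,2', 2)}")
```

For the 2-GPU command, this prints that local rank 0 runs on physical GPU 0 and local rank 1 runs on physical GPU 2. The `--nproc-per-node` value should therefore not exceed the number of IDs listed in `CUDA_VISIBLE_DEVICES`.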

Evaluation

Download the pretrained checkpoint from https://huggingface.co/taohu/flowseq/tree/main; more checkpoints are coming soon.

python flow_sample_eval_s2s.py data=qqp_acc data.eval.is_debug=0 data.eval.model_path='qqp_ema_0.9999_070000.pt' data.eval.candidate_num=1 data.eval.ode_stepnum=1
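The `data.eval.ode_stepnum` override presumably sets how many integration steps are used to solve the flow ODE dx/dt = v(x, t) from noise to data. As a rough illustration only, the sketch below uses a fixed-step Euler solver on toy one-dimensional velocity fields. Several details are assumptions, not the repo's actual sampler: the solver choice, the convention that t = 0 is noise and t = 1 is data, and the function names `euler_sample`, `straight_velocity` and `curved_velocity`.

```python
import math
from typing import Callable


def euler_sample(velocity: Callable[[float, float], float], x0: float, num_steps: int) -> float:
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with fixed-step Euler.

    num_steps plays the role assumed here for ode_stepnum.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be >= 1")
    dt = 1.0 / num_steps
    x = x0
    for i in range(num_steps):
        t = i * dt
        x = x + dt * velocity(x, t)  # one explicit Euler step
    return x


def straight_velocity(x: float, t: float) -> float:
    # Constant velocity: the trajectory from 0.3 to 2.0 is a straight line.
    return 2.0 - 0.3


def curved_velocity(x: float, t: float) -> float:
    # dx/dt = x has the exact solution x(1) = x(0) * e, a curved trajectory.
    return x


if __name__ == "__main__":
    for steps in (1, 4, 1000):
        straight = euler_sample(straight_velocity, 0.3, steps)
        curved = euler_sample(curved_velocity, 1.0, steps)
        print(f"{steps:>4} step(s): straight={straight:.4f} (exact 2.0000), "
              f"curved={curved:.4f} (exact {math.e:.4f})")
```

On the straight path, a single Euler step is already exact. On the curved path, one step returns 2.0 instead of e ≈ 2.718. The closer the learned trajectories are to straight lines, the less there is to gain from extra ODE steps.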

Environment Preparation

conda create -n flowseq python=3.10
conda activate flowseq
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
pip install torchdiffeq matplotlib h5py accelerate loguru blobfile ml_collections
pip install hydra-core wandb einops scikit-learn --upgrade
pip install transformers
pip install nltk bert_score datasets torchmetrics

Optional

pip install diffusers
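After setting up the environment, you can confirm that every package above is importable by looking up each module with the standard library's `importlib.util.find_spec`. This lookup does not import the module. The script below is a hypothetical helper, not part of the repo. Its package-to-module mapping comes from the install commands above; for example, `hydra-core` provides the `hydra` module and `scikit-learn` provides `sklearn`.

```python
import importlib.util

# Package name used in the install commands -> importable module name.
REQUIRED_MODULES = {
    "pytorch": "torch",
    "torchvision": "torchvision",
    "torchaudio": "torchaudio",
    "torchdiffeq": "torchdiffeq",
    "matplotlib": "matplotlib",
    "h5py": "h5py",
    "accelerate": "accelerate",
    "loguru": "loguru",
    "blobfile": "blobfile",
    "ml_collections": "ml_collections",
    "hydra-core": "hydra",
    "wandb": "wandb",
    "einops": "einops",
    "scikit-learn": "sklearn",
    "transformers": "transformers",
    "nltk": "nltk",
    "bert_score": "bert_score",
    "datasets": "datasets",
    "torchmetrics": "torchmetrics",
}
OPTIONAL_MODULES = {"diffusers": "diffusers"}


def missing_packages(modules: dict[str, str]) -> list[str]:
    """Return the package names whose top-level module cannot be found."""
    return [pkg for pkg, mod in modules.items() if importlib.util.find_spec(mod) is None]


if __name__ == "__main__":
    print("missing required:", missing_packages(REQUIRED_MODULES) or "none")
    print("missing optional:", missing_packages(OPTIONAL_MODULES) or "none")
```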

Common Issues

  • Inference results with more than one sampling step

    Our work mainly explores single-step sampling. The anchor loss encourages the model to recover the original data in a single step. Multi-step sampling was implemented in a zigzag manner following Consistency Models, but this codebase does not include that implementation yet.

  • Batch size

    If your GPU resources are limited, try decreasing the batch size to 128 to 256 and disabling gradient accumulation. In my experience, this still reaches fair performance.

  • Typical issues

    See these related issues reported in the DiffuSeq repository:
    https://github.com/Shark-NLP/DiffuSeq/issues/5
    https://github.com/Shark-NLP/DiffuSeq/issues/22
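The multi-step sampler is not released. The following is only a sketch of the general zigzag idea from Consistency Models, transplanted to a flow-matching setting under the following assumptions:

- A hypothetical one-step predictor `predict_x1(x_t, t)` returns an estimate of the clean sample.
- The path is the linear interpolation x_t = (1 - t) * noise + t * x1, with t = 0 at pure noise.
- Each extra step re-noises the current estimate to an intermediate time and then predicts again.

None of these names or schedules come from this repo.

```python
import random
from typing import Callable, Sequence

Vector = list[float]


def zigzag_sample(
    predict_x1: Callable[[Vector, float], Vector],
    dim: int,
    renoise_times: Sequence[float],
    rng: random.Random,
) -> Vector:
    """Consistency-model-style multi-step sampling on a linear flow path (sketch).

    Assumes x_t = (1 - t) * noise + t * x1, with t = 0 pure noise and t = 1 data.
    `predict_x1` is a hypothetical one-step model mapping (x_t, t) to a clean estimate.
    """
    noise = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    x1_hat = predict_x1(noise, 0.0)  # first step: jump from pure noise to data
    for t in renoise_times:
        if not 0.0 < t < 1.0:
            raise ValueError("re-noise times must lie strictly between 0 and 1")
        fresh = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        # "zig": move back along the path to time t using fresh noise
        x_t = [(1.0 - t) * z + t * x for z, x in zip(fresh, x1_hat)]
        # "zag": predict the clean sample again from this less noisy point
        x1_hat = predict_x1(x_t, t)
    return x1_hat
```

With an empty `renoise_times`, the sampler reduces to a single call of the predictor, which is the single-step setting this codebase targets. Each additional re-noise time adds one more model evaluation.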
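For reference, gradient accumulation sums the gradients of several micro-batches before taking one optimizer step. The effective batch size is therefore the per-GPU batch size × the number of GPUs × the number of accumulation steps. The pure-Python sketch below uses hypothetical names and a one-parameter least-squares model in place of the real network. It assumes equally sized micro-batches and a per-sample mean loss. Under those assumptions, averaging the micro-batch gradients reproduces the full-batch gradient. Disabling accumulation therefore mainly shrinks the effective batch size rather than changing how each gradient is computed.

```python
def effective_batch_size(per_gpu_batch: int, num_gpus: int, accum_steps: int) -> int:
    """Number of samples that contribute to one optimizer step."""
    return per_gpu_batch * num_gpus * accum_steps


def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w * x - y)^2) for a one-parameter linear model."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: list[float], ys: list[float], accum_steps: int) -> float:
    """Average the gradients of equally sized micro-batches, as accumulation does."""
    if len(xs) % accum_steps:
        raise ValueError("batch must split evenly into micro-batches")
    size = len(xs) // accum_steps
    total = 0.0
    for k in range(accum_steps):
        chunk = slice(k * size, (k + 1) * size)
        # Dividing by accum_steps mirrors scaling each micro-batch loss by 1/accum_steps.
        total += mse_grad(w, xs[chunk], ys[chunk]) / accum_steps
    return total


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0]
    ys = [2.0, 3.0, 5.0, 9.0]
    print(mse_grad(0.5, xs, ys), accumulated_grad(0.5, xs, ys, accum_steps=2))  # equal
    print(effective_batch_size(per_gpu_batch=64, num_gpus=4, accum_steps=1))
```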

Citation

If our paper or code helps your work, please cite:

@inproceedings{HuEACL2024,
  title = {Flow Matching for Conditional Text Generation in a Few Sampling Steps},
  author = {Vincent Tao Hu and Di Wu and Yuki M Asano and Pascal Mettes and Basura Fernando and Björn Ommer and Cees G M Snoek},
  year = {2024},
  date = {2024-03-27},
  booktitle = {EACL},
  tppubtype = {inproceedings}
}
