Diffusion-Generative-Flow-Samplers
PyTorch implementation of our ICLR 2024 paper "Diffusion Generative Flow Samplers: Improving learning signals through partial trajectory optimization"
Diffusion Generative Flow Samplers
Dinghuai Zhang, Ricky Tian Qi Chen, Cheng-Hao Liu, Aaron Courville, Yoshua Bengio.
We propose DGFS, a novel sampler for drawing samples in continuous spaces from given unnormalized densities, built on the stochastic optimal control 🤖 formulation and the probabilistic 🎲 GFlowNet framework.
target/ has the target distribution code.
gflownet/ contains the DGFS algorithm code.
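DGFS samples from unnormalized densities, so a target mainly has to provide an unnormalized log-density. As a hedged illustration of the kind of object that might live in target/, here is a plain-Python sketch of Neal's funnel, which we assume is what `target=funnel` refers to. The class name `Funnel`, the methods `log_density` and `sample`, and the defaults `dim=10` and `sigma=3.0` are hypothetical and are not the repository's actual API.

```python
import math
import random


class Funnel:
    """Hypothetical unnormalized target: Neal's funnel (assumed parameters).

    x[0] ~ N(0, sigma^2) and x[i] | x[0] ~ N(0, exp(x[0])) for i >= 1.
    The repository's target/ code may define it differently.
    """

    def __init__(self, dim=10, sigma=3.0):
        self.dim = dim
        self.sigma = sigma

    def log_density(self, x):
        """Log-density up to an additive constant (normalizer dropped)."""
        assert len(x) == self.dim
        v = x[0]
        # log N(v; 0, sigma^2) without the constant term
        lp = -0.5 * v * v / self.sigma ** 2
        # each remaining coordinate has variance exp(v):
        # log N(x_i; 0, e^v) = -0.5 * v - 0.5 * x_i^2 * e^(-v) + const
        for xi in x[1:]:
            lp += -0.5 * v - 0.5 * xi * xi * math.exp(-v)
        return lp

    def sample(self, rng=random):
        """Exact samples, useful only for evaluating a learned sampler."""
        v = rng.gauss(0.0, self.sigma)
        # standard deviation is exp(v / 2), i.e. variance exp(v)
        return [v] + [rng.gauss(0.0, math.exp(0.5 * v)) for _ in range(self.dim - 1)]
```

A sampler only ever sees `log_density`; the exact `sample` method exists because the funnel happens to be tractable, which makes it a convenient benchmark for checking sample quality.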
Examples
python -m gflownet.main target=gm dt=0.05
python -m gflownet.main target=funnel
python -m gflownet.main target=wells
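The `dt=0.05` override in the first example presumably sets the time step used to discretize the diffusion; this is an assumption, so check the Hydra config under gflownet/ for its actual meaning. As a minimal sketch of what such a step size controls, here is an Euler–Maruyama simulation of a controlled SDE dx = u(x, t) dt + sigma dW on [0, T]. The function name `euler_maruyama`, the `drift` callable (a stand-in for a learned control policy), `sigma=1.0`, and the horizon `T=1.0` are all hypothetical placeholders, not the repository's code.

```python
import math
import random


def euler_maruyama(x0, drift, sigma=1.0, dt=0.05, T=1.0, rng=random):
    """Simulate dx = drift(x, t) dt + sigma dW with a fixed step dt.

    Returns the full trajectory [x_0, x_1, ..., x_N] with N = round(T / dt).
    `drift` is a hypothetical stand-in for a learned control policy.
    """
    n_steps = round(T / dt)
    x = list(x0)
    traj = [list(x)]
    for k in range(n_steps):
        t = k * dt
        u = drift(x, t)
        # deterministic drift step plus a Gaussian increment with std sigma * sqrt(dt)
        x = [xi + ui * dt + sigma * math.sqrt(dt) * rng.gauss(0.0, 1.0)
             for xi, ui in zip(x, u)]
        traj.append(list(x))
    return traj
```

With `dt=0.05` and `T=1.0` each trajectory has 20 transitions; a smaller `dt` gives a finer discretization at the cost of more network evaluations per sample. Per the paper's title, DGFS learns from partial segments of such trajectories rather than only from complete ones.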
Dependencies
Apart from the commonly used torch, torchvision, numpy, scipy, and matplotlib, we use the following packages:
pip install hydra-core omegaconf submitit hydra-submitit-launcher
pip install wandb tqdm einops seaborn ipdb
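To confirm that the environment is complete before launching a run, a small check like the following can help. This is a hedged convenience sketch, not part of the repository; the mapping from pip names to import names (`hydra-core` imported as `hydra`, `hydra-submitit-launcher` installed under the `hydra_plugins` namespace) is our assumption.

```python
import importlib.util

# Import names for the packages listed above (assumed pip-to-import mapping).
REQUIRED = [
    "torch", "torchvision", "numpy", "scipy", "matplotlib",
    "hydra", "omegaconf", "submitit", "hydra_plugins",
    "wandb", "tqdm", "einops", "seaborn", "ipdb",
]


def missing_packages(names=REQUIRED):
    """Return the top-level modules that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("missing:", ", ".join(missing) if missing else "none")
```

`find_spec` only locates top-level modules without importing them, so the check is fast even when heavy packages such as torch are installed.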
