feat: gflownet integration in gt4sd
A first draft of a submodule `gt4sd.frameworks.gflownet`. The code is adapted from this implementation. We assume molecule generation training on QM9 as a task.
- `gt4sd.frameworks.gflownet.dataloader` - contains preprocessing steps, dataloader and sampler.
- `gt4sd.frameworks.gflownet.envs` - contains the action-state graph to iteratively build a molecule.
- `gt4sd.frameworks.gflownet.loss` - training strategies. Contains `trajectory_balance` as proposed in Malkin et al. and a `temporal_difference`-inspired loss as proposed in Bengio et al.
- `gt4sd.frameworks.gflownet.ml.model` - `graph_transformer` and `mxmnet`.
- `gt4sd.frameworks.gflownet.train.core` - lightning trainer.
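For reference, the `trajectory_balance` strategy corresponds to the trajectory-balance objective introduced by Malkin et al. The snippet below is only a conceptual sketch of that objective; the function name and tensor layout are illustrative and do not reflect the module's actual API.

```python
import torch


def trajectory_balance_loss(log_z, log_pf, log_pb, log_reward):
    """Conceptual trajectory-balance objective (Malkin et al.), for illustration only.

    log_z:      learned scalar estimate of log Z (the total flow).
    log_pf:     [batch, steps] forward per-step log-probabilities along each trajectory.
    log_pb:     [batch, steps] backward per-step log-probabilities (padding steps set to 0).
    log_reward: [batch] log R(x) of the terminal object.
    """
    forward = log_z + log_pf.sum(dim=-1)
    backward = log_reward + log_pb.sum(dim=-1)
    return (forward - backward).pow(2).mean()
```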
The implementation relies on PyTorch Lightning and the module/data_module abstraction.
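For readers unfamiliar with that abstraction, the toy sketch below shows the general LightningModule/LightningDataModule pattern the framework builds on; the class names and the linear model are placeholders, not the actual GFlowNet components.

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class ToyModule(pl.LightningModule):
    """Placeholder LightningModule illustrating the module abstraction."""

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.model(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


class ToyDataModule(pl.LightningDataModule):
    """Placeholder LightningDataModule illustrating the data_module abstraction."""

    def train_dataloader(self):
        x = torch.randn(64, 8)
        return DataLoader(TensorDataset(x, x.sum(-1, keepdim=True)), batch_size=16)


# pl.Trainer(max_epochs=1).fit(ToyModule(), datamodule=ToyDataModule())
```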
- [x] Interface for user. Simplify setup.
- [x] How to handle GPU-based dependencies like `torch-geometric`, `torch-scatter`, `torch-sparse`, `torch-cluster`.
- [x] Should we provide GFN for inference only or also for training/finetuning?
- [x] Bucket support for datasets.
Basic training example. Train a GFlowNet on QM9:
from examples.gflownet.dataset_qm9 import QM9Dataset
from examples.gflownet.task_qm9 import QM9GapTask
from gt4sd.frameworks.gflownet.train.core import train_gflownet_main
from gt4sd.frameworks.gflownet.envs.graph_building_env import GraphBuildingEnv
from gt4sd.frameworks.gflownet.envs.mol_building_env import MolBuildingEnvContext
def main():
    """Run basic GFN training on QM9."""
    hps = {"dataset": "qm9", "dataset_path": "/GFN/qm9.h5", "device": "cpu"}
    dataset = QM9Dataset(hps["dataset_path"], train=True, target="gap")
    environment = GraphBuildingEnv()
    context = MolBuildingEnvContext(["H", "C", "N", "F", "O"], num_cond_dim=32)
    train_gflownet_main(
        configuration=hps,
        dataset=dataset,
        environment=environment,
        context=context,
        _task=QM9GapTask,
    )


if __name__ == "__main__":
    main()
A first draft of a submodule `gt4sd.frameworks.gflownet`. The code is adapted from this implementation. We assume molecule generation training on QM9 as a task.
- `gt4sd.frameworks.gflownet.data` - contains preprocessing steps, dataloader and sampler.
- `gt4sd.frameworks.gflownet.envs` - contains the action-state graph to iteratively build a molecule.
- `gt4sd.frameworks.gflownet.loss` - training strategies. Contains `trajectory_balance` as proposed in Malkin et al. and a `temporal_difference`-inspired loss as proposed in Bengio et al.
- `gt4sd.frameworks.gflownet.model` - `graph_transformer` and `mxmnet`.
- `gt4sd.frameworks.gflownet.train` - trainers.
- `gt4sd.frameworks.gflownet.core` - interface for the user.

To Discuss:
- [ ] Interface for user. Simplify setup.
- [ ] Integrate TD-loss from this repo.
- [ ] How to handle GPU-based dependencies like `torch-geometric`, `torch-scatter`, `torch-sparse`, `torch-cluster`.
- [ ] Should we provide GFN for inference only or also for training/finetuning?
- [ ] Bucket support for datasets.
Basic training example. Train a GFlowNet on QM9:
import torch
from ruamel.yaml import YAML

from gt4sd.frameworks.gflownet.train.trainer_qm9 import QM9GapTrainer


def main():
    """Example of how this model can be run outside of Determined."""
    yaml = YAML(typ="safe", pure=True)
    config_file = "src/gt4sd/frameworks/gflownet/tasks/qm9/" + "qm9.yaml"
    with open(config_file, "r") as f:
        hps = yaml.load(f)
    trial = QM9GapTrainer(hps, torch.device("cpu"))
    trial.run()
Thanks @georgosgeorgos, amazing work. Some considerations:
Regarding the user interface, totally agree, we need some iteration/discussion there. The TD-loss is something we can ask @MJ10 for input on.
Regarding the GPU deps, I would follow the strategy we applied for torchdrug and others: delegate installation to conda recipes and, in case of problems, tackle compatibility issues one at a time.
I would consider starting from inference and then wrapping the different training tasks we support in a family of training_pipelines (similar to granular).
Regarding the datasets, I'm not sure we should support them in the core; maybe we can consider adding an example/notebook in the respective folders and providing a script to download the data there (the snippet you pasted is a perfect example).
We have a preliminary version of GFlowNet integrated into GT4SD.
The main logic is explained in docs/source/gt4sd_gfn_md.md.
An example of how to train on QM9 and define a task is provided in examples/gflownet/.
from gt4sd.frameworks.gflownet.arg_parser.parser import parse_arguments_from_config
from gt4sd.frameworks.gflownet.envs.graph_building_env import GraphBuildingEnv
from gt4sd.frameworks.gflownet.envs.mol_building_env import MolBuildingEnvContext
from gt4sd.frameworks.gflownet.tests.qm9 import QM9Dataset, QM9GapTask
from gt4sd.frameworks.gflownet.train.core import train_gflownet
def main():
    """Run basic GFN training on QM9."""
    configuration = {"dataset": "qm9", "dataset_path": "/GFN/qm9.h5", "device": "cpu"}
    # add user configuration
    configuration.update(vars(parse_arguments_from_config()))
    # build the environment and context
    environment = GraphBuildingEnv()
    context = MolBuildingEnvContext()
    # build the dataset
    dataset = QM9Dataset(configuration["dataset_path"], target="gap")
    # build the task
    task = QM9GapTask(
        configuration=configuration,
        dataset=dataset,
    )
    # train gflownet
    train_gflownet(
        configuration=configuration,
        dataset=dataset,
        environment=environment,
        context=context,
        task=task,
    )


if __name__ == "__main__":
    main()
One relevant issue we still have is an incompatibility between the SamplingIterator(IterableDataset) and the Lightning data module: the data loader gets stuck in a loop when yielding a new sample. A quick fix is to set num_workers=0. This issue is probably related to the Lightning version we pin for the library (<=1.3).
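To make the workaround concrete, here is a minimal stand-in (not the actual SamplingIterator) showing single-process loading of an IterableDataset with num_workers=0, which avoids the worker hang described above.

```python
from torch.utils.data import DataLoader, IterableDataset


class DummySamplingIterator(IterableDataset):
    """Stand-in for the real SamplingIterator, for illustration only."""

    def __iter__(self):
        for i in range(8):
            yield i


# num_workers=0 keeps loading in the main process and avoids the hang
loader = DataLoader(DummySamplingIterator(), batch_size=2, num_workers=0)
for batch in loader:
    print(batch)
```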
TODO:
- [x] add contribution
- [x] solve problem YAML and lightning
- [x] add example readme
- [ ] support multiple workers with iterator and lightning
- [x] GPU training
- [ ] TD-loss
- [x] fix mypy
Thanks! The changes look good. I had a few small questions and comments which I left. Another small question I had: why is the `models` folder nested within an `ml` folder? Can't it just be at the same level as `loss`, for example?
Hi. Thanks for the comments. `models` is nested in `ml` for consistency with other frameworks in the library (see gt4sd/frameworks/granular/ for example).
Regarding the requirements:
- Use `requirements.txt` and `requirements_ci.txt`.
- In `requirements_ci.txt` we force the CPU version of `pytorch_geometric` and related libraries.
- In `requirements.txt` we can install CPU or GPU versions based on the user system.
- It can still happen that the wrong version of `pytorch_geometric` will be installed (incompatible with the CUDA driver on the local machine).
- I would add a short note in the main README (troubleshooting section?) to point out that users have to follow the installation instructions here: https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html
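As a small aid for that README note, the local torch and CUDA versions determine which `pytorch_geometric` wheels are compatible; users can check them with the diagnostic snippet below (this is not part of the library, just a quick check to pair with the linked installation notes).

```python
import torch

# the correct torch-geometric/torch-scatter/torch-sparse wheels depend on these two values
print(f"torch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")  # None on CPU-only builds
```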