
feat: gflownet integration in gt4sd

georgosgeorgos opened this issue 3 years ago · 4 comments

A first draft of the submodule gt4sd.frameworks.gflownet. The code is adapted from this implementation. We assume molecule generation trained on QM9 as the reference task.

  • gt4sd.frameworks.gflownet.dataloader - contains preprocessing steps, dataloader and sampler.
  • gt4sd.frameworks.gflownet.envs - contains the action-state graph used to build a molecule iteratively.
  • gt4sd.frameworks.gflownet.loss - training strategies. Contains trajectory_balance as proposed in Malkin et al. and a temporal_difference-inspired loss as proposed in Bengio et al. (see the sketch after this list).
  • gt4sd.frameworks.gflownet.ml.model - graph_transformer and mxmnet.
  • gt4sd.frameworks.gflownet.train.core - lightning trainer.
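
For reference, a minimal sketch of the trajectory balance objective from Malkin et al.; the function name and tensor shapes are illustrative, not the submodule's actual API:

import torch


def trajectory_balance_loss(
    log_z: torch.Tensor,  # learned log-partition estimate, scalar
    fwd_logprobs: torch.Tensor,  # log P_F per step, shape (batch, n_steps)
    bwd_logprobs: torch.Tensor,  # log P_B per step, shape (batch, n_steps)
    log_reward: torch.Tensor,  # log R(x) per trajectory, shape (batch,)
) -> torch.Tensor:
    """Trajectory balance: log Z + sum log P_F should match log R(x) + sum log P_B."""
    lhs = log_z + fwd_logprobs.sum(dim=-1)
    rhs = log_reward + bwd_logprobs.sum(dim=-1)
    return (lhs - rhs).pow(2).mean()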

The implementation relies on PyTorch Lightning and its module/data_module abstraction.
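
For readers unfamiliar with that abstraction, a minimal self-contained sketch of how a module and a data module pair up under Lightning; the class names and toy data are placeholders, not the actual gt4sd classes:

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class GFlowNetModuleSketch(pl.LightningModule):
    """Placeholder module: a stand-in policy network plus a training step."""

    def __init__(self):
        super().__init__()
        self.policy = torch.nn.Linear(16, 4)  # stand-in for the graph transformer

    def training_step(self, batch, batch_idx):
        x, target = batch
        return torch.nn.functional.mse_loss(self.policy(x), target)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


class GFlowNetDataModuleSketch(pl.LightningDataModule):
    """Placeholder data module wrapping a toy dataloader."""

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 16), torch.randn(64, 4))
        return DataLoader(data, batch_size=8)


# pl.Trainer(max_epochs=1).fit(GFlowNetModuleSketch(), datamodule=GFlowNetDataModuleSketch())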

  • [x] Interface for the user; simplify setup.
  • [x] How to handle GPU-based dependencies like torch-geometric, torch-scatter, torch-sparse, torch-cluster.
  • [x] Should we provide GFN for inference only or also for training/finetuning?
  • [x] Bucket support for datasets.

Basic training example. Train a GFlowNet on QM9:


from examples.gflownet.dataset_qm9 import QM9Dataset
from examples.gflownet.task_qm9 import QM9GapTask
from gt4sd.frameworks.gflownet.train.core import train_gflownet_main
from gt4sd.frameworks.gflownet.envs.graph_building_env import GraphBuildingEnv
from gt4sd.frameworks.gflownet.envs.mol_building_env import MolBuildingEnvContext


def main():
    """Run basic GFN training on QM9."""

    hps = {"dataset": "qm9", "dataset_path": "/GFN/qm9.h5", "device": "cpu"}

    dataset = QM9Dataset(hps["dataset_path"], train=True, target="gap")
    environment = GraphBuildingEnv()
    context = MolBuildingEnvContext(["H", "C", "N", "F", "O"], num_cond_dim=32)

    train_gflownet_main(
        configuration=hps,
        dataset=dataset,
        environment=environment,
        context=context,
        _task=QM9GapTask,
    )


if __name__ == "__main__":
    main()

georgosgeorgos · Jul 28 '22 08:07

A first draft of the submodule gt4sd.frameworks.gflownet. The code is adapted from this implementation. We assume molecule generation trained on QM9 as the reference task.

  • gt4sd.frameworks.gflownet.data - contains preprocessing steps, dataloader and sampler.
  • gt4sd.frameworks.gflownet.envs - contains the action-state graph used to build a molecule iteratively.
  • gt4sd.frameworks.gflownet.loss - training strategies. Contains trajectory_balance as proposed in Malkin et al. and a temporal_difference-inspired loss as proposed in Bengio et al.
  • gt4sd.frameworks.gflownet.model - graph_transformer and mxmnet.
  • gt4sd.frameworks.gflownet.train - trainers.
  • gt4sd.frameworks.gflownet.core - interface for the user.

To Discuss:

  • [ ] Interface for the user; simplify setup.
  • [ ] Integrate TD-loss from this repo.
  • [ ] How to handle GPU-based dependencies like torch-geometric, torch-scatter, torch-sparse, torch-cluster.
  • [ ] Should we provide GFN for inference only or also for training/finetuning?
  • [ ] Bucket support for datasets.

Basic training example. Train a GFlowNet on QM9:

import torch
from ruamel.yaml import YAML

from gt4sd.frameworks.gflownet.train.trainer_qm9 import QM9GapTrainer


def main():
    """Example of how this model can be run outside of Determined."""
    yaml = YAML(typ="safe", pure=True)
    # load the hyperparameters for the QM9 gap task
    config_file = "src/gt4sd/frameworks/gflownet/tasks/qm9/qm9.yaml"
    with open(config_file, "r") as f:
        hps = yaml.load(f)

    trial = QM9GapTrainer(hps, torch.device("cpu"))
    trial.run()


if __name__ == "__main__":
    main()

Thanks @georgosgeorgos, amazing work. Some considerations:

Regarding the user interface, I totally agree, we need some iteration/discussion there. For the TD-loss we can ask @MJ10 for input.

Regarding the GPU dependencies, I would follow the strategy we applied for torchdrug and others: delegate installation to conda recipes and, in case of problems, tackle compatibility issues one at a time.

I would consider starting from inference and then wrapping the different training tasks we support in a family of training_pipelines (similar to granular).

Regarding the datasets, I'm not sure we should support them in the core; maybe we can consider adding an example/notebook in the respective folders and providing a script to download the data there (the snippet you pasted is a perfect example).
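
For illustration, a minimal sketch of what such a download script could look like; the URL is a placeholder, not an actual data source:

import os
import urllib.request

# Placeholder URL: substitute the actual QM9 HDF5 source when adding the script.
QM9_URL = "https://example.org/path/to/qm9.h5"


def download_qm9(target_dir: str = "data") -> str:
    """Download the QM9 HDF5 file if it is not already present."""
    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, "qm9.h5")
    if not os.path.exists(target_path):
        urllib.request.urlretrieve(QM9_URL, target_path)
    return target_path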

drugilsberg · Jul 28 '22 11:07

We have a preliminary version of GFlowNet integrated into GT4SD.

The main logic is explained in docs/source/gt4sd_gfn_md.md. An example of how to train on QM9 and define a task is provided in examples/gflownet/.

from gt4sd.frameworks.gflownet.arg_parser.parser import parse_arguments_from_config
from gt4sd.frameworks.gflownet.envs.graph_building_env import GraphBuildingEnv
from gt4sd.frameworks.gflownet.envs.mol_building_env import MolBuildingEnvContext
from gt4sd.frameworks.gflownet.tests.qm9 import QM9Dataset, QM9GapTask
from gt4sd.frameworks.gflownet.train.core import train_gflownet


def main():
    """Run basic GFN training on QM9."""

    configuration = {"dataset": "qm9", "dataset_path": "/GFN/qm9.h5", "device": "cpu"}
    # add user configuration
    configuration.update(vars(parse_arguments_from_config()))

    # build the environment and context
    environment = GraphBuildingEnv()
    context = MolBuildingEnvContext()
    # build the dataset
    dataset = QM9Dataset(configuration["dataset_path"], target="gap")
    # build the task
    task = QM9GapTask(
        configuration=configuration,
        dataset=dataset,
    )
    # train gflownet
    train_gflownet(
        configuration=configuration,
        dataset=dataset,
        environment=environment,
        context=context,
        task=task,
    )


if __name__ == "__main__":
    main()

One relevant issue we still have is an incompatibility between the SamplingIterator(IterableDataset) and the Lightning data module: the data loader gets stuck in a loop when yielding a new sample. A quick fix is to set num_workers=0. This issue is probably related to the Lightning version we pin for the library (<=1.3).
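
For context, a minimal sketch of the standard PyTorch recipe for making an IterableDataset safe with multiple workers, sharding the stream via get_worker_info; this is the generic pattern, not the actual SamplingIterator code:

import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedSamplerSketch(IterableDataset):
    """Toy iterable dataset that shards its stream across dataloader workers."""

    def __init__(self, num_samples: int = 16):
        self.num_samples = num_samples

    def __iter__(self):
        info = get_worker_info()
        # single-process loading yields everything; otherwise shard by worker id
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        for i in range(worker_id, self.num_samples, num_workers):
            yield torch.tensor(i)


# num_workers=0 is the current workaround; >0 requires sharding as above
loader = DataLoader(ShardedSamplerSketch(), batch_size=4, num_workers=0)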

TODO:

  • [x] add contribution
  • [x] solve YAML and Lightning problem
  • [x] add example readme
  • [ ] support multiple workers with iterator and lightning
  • [x] GPU training
  • [ ] TD-loss
  • [x] fix mypy

georgosgeorgos · Aug 11 '22 11:08

Thanks! The changes look good. I had a few small questions and comments, which I left. Another small question: why is the models folder nested within an ml folder? Can't it just be at the same level as loss, for example?

Hi, thanks for the comments. models is nested in ml for consistency with the other frameworks in the library (see gt4sd/frameworks/granular/ for example).

georgosgeorgos · Aug 18 '22 06:08

Regarding the requirements:

  • Use requirements.txt and requirements_ci.txt.
  • In requirements_ci.txt we force the CPU version of pytorch_geometric and related libraries.
  • In requirements.txt we can install CPU or GPU versions based on the user system.
    - It can still happen that the wrong version of pytorch_geometric is installed (incompatible with the CUDA driver on the local machine); see the sketch below.
    - I would add a short note in the main README (troubleshooting section?) pointing out that users have to follow the installation instructions here: https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html
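
As one possible mitigation, a sketch of how the matching pytorch_geometric wheel index could be derived from the installed torch build; the helper name is illustrative, and the URL scheme follows the pytorch_geometric installation notes:

import torch


def pyg_wheel_index() -> str:
    """Return the find-links URL matching the installed torch/CUDA combination."""
    torch_version = torch.__version__.split("+")[0]
    cuda = torch.version.cuda  # None on CPU-only builds
    suffix = "cpu" if cuda is None else f"cu{cuda.replace('.', '')}"
    return f"https://data.pyg.org/whl/torch-{torch_version}+{suffix}.html"


# usage (assumption): pip install torch-scatter torch-sparse -f <URL returned above>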

georgosgeorgos · Aug 18 '22 14:08