torchgfn

Add support for including Graphs as States of GFlowNet

Status: Open · opened by ashdtu 6 months ago · 0 comments

Describe your changes

Type: Feature

Add support for including graphs as states of a GFlowNet. A single graph is represented by the Data class of PyTorch Geometric, and a GraphStates object is represented by its Batch class, which handles batching of Data objects and efficient indexing into them. Unlike the existing States object, which requires padding with dummy states to batch trajectories of different lengths together, we aim to support trajectories by representing them as a nested Batch object (a Batch of Batch objects), so that each trajectory keeps its own length.
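To make the contrast concrete, here is a small dependency-free sketch of the two layouts. It is an illustration under stated assumptions, not the proposed implementation: ToyGraph is a hypothetical stand-in for torch_geometric's Data, plain Python lists stand in for Batch, and DUMMY, pad_trajectories and nested_trajectories are hypothetical names.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class ToyGraph:
    """Hypothetical stand-in for torch_geometric.data.Data: node count plus edge list."""
    num_nodes: int
    edges: Tuple[Tuple[int, int], ...] = ()


# Hypothetical dummy (sink) state, used only to pad short trajectories.
DUMMY = ToyGraph(num_nodes=0)


def pad_trajectories(trajs: List[List[ToyGraph]]) -> List[List[ToyGraph]]:
    """Padding layout: extend every trajectory with DUMMY up to the longest length."""
    max_len = max(len(t) for t in trajs)
    return [t + [DUMMY] * (max_len - len(t)) for t in trajs]


def nested_trajectories(trajs: List[List[ToyGraph]]) -> List[List[ToyGraph]]:
    """Nested layout ("batch of batches"): each trajectory keeps its own length."""
    return [list(t) for t in trajs]


# Two trajectories that grow a graph one node at a time (lengths 3 and 2).
t0 = [ToyGraph(1), ToyGraph(2, ((0, 1),)), ToyGraph(3, ((0, 1), (1, 2)))]
t1 = [ToyGraph(1), ToyGraph(2, ((0, 1),))]

padded = pad_trajectories([t0, t1])
nested = nested_trajectories([t0, t1])

print([len(t) for t in padded])                     # [3, 3]
print([len(t) for t in nested])                     # [3, 2]
print(sum(s == DUMMY for t in padded for s in t))   # 1 dummy state added
```

The nested layout never materializes the dummy state, which is the property the nested Batch representation is meant to provide for graphs of varying size.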

The existing Trajectories implementation is indexed along the dimensions (num time steps, num trajectories, state size). A nested Batch-of-Batch object representing state trajectories would naturally be indexed as (num trajectories, num time steps, state size), so it would need to flip the first two indexing dimensions internally in __getitem__() and __setitem__().
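The index flip can be sketched as follows, assuming storage is trajectory-major (data[traj][t]) while callers keep the time-major convention (view[t, traj]). NestedTrajectoryView and states_at are hypothetical names, and plain lists stand in for the nested Batch; this is not the torchgfn API.

```python
from typing import Any, List, Tuple


class NestedTrajectoryView:
    """Hypothetical sketch: stores data[traj][t] but is indexed as view[t, traj]."""

    def __init__(self, data: List[List[Any]]):
        # Outer list: trajectories; inner lists: time steps (lengths may differ).
        self._data = data

    def __getitem__(self, index: Tuple[int, int]) -> Any:
        t, traj = index
        # Flip (time, trajectory) into the storage order (trajectory, time).
        return self._data[traj][t]

    def __setitem__(self, index: Tuple[int, int], value: Any) -> None:
        t, traj = index
        self._data[traj][t] = value

    def states_at(self, t: int) -> List[Any]:
        """States at time step t (t >= 0); trajectories that already ended are skipped."""
        return [traj[t] for traj in self._data if t < len(traj)]


view = NestedTrajectoryView([["s0", "a1", "a2"], ["s0", "b1"]])
print(view[1, 0])          # a1 (time step 1 of trajectory 0)
print(view.states_at(2))   # ['a2'] (trajectory 1 has only 2 steps)
view[1, 1] = "b1_new"
print(view[1, 1])          # b1_new
```

Because nothing is padded, asking for a time step past the end of a trajectory (for example view[2, 1]) raises IndexError instead of returning a dummy state; a real implementation would need to decide how to expose that case.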

Issue Number

#153

Testing

  • [ ] Unit tests for helper functions of the GraphStates class: initialization, appending, batching, etc.
  • [ ] Compatibility checks with the Trajectories and Transition classes
  • [ ] Unit testing with FlowMatching: accessing GraphStates directly in the loss function
  • [ ] Unit testing with TrajectoryBalance: accessing GraphStates via the Trajectories class
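The first checklist item could take roughly the shape below. This is a sketch only: ToyGraphStates and ToyGraph are hypothetical stand-ins (not the proposed GraphStates or torch_geometric classes), chosen so the test structure runs without extra dependencies.

```python
import unittest
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class ToyGraph:
    """Hypothetical stand-in for torch_geometric.data.Data."""
    num_nodes: int
    edges: Tuple[Tuple[int, int], ...] = ()


class ToyGraphStates:
    """Hypothetical stand-in for the proposed GraphStates; not the torchgfn API."""

    def __init__(self, graphs: List[ToyGraph]):
        self.graphs = list(graphs)

    def __len__(self) -> int:
        return len(self.graphs)

    def __getitem__(self, idx):
        # Slices return a new batch; integers return a single graph.
        if isinstance(idx, slice):
            return ToyGraphStates(self.graphs[idx])
        return self.graphs[idx]

    def append(self, other: "ToyGraphStates") -> None:
        self.graphs.extend(other.graphs)

    @property
    def total_nodes(self) -> int:
        # Loosely mirrors a batch that stores all nodes of all graphs together.
        return sum(g.num_nodes for g in self.graphs)


class TestToyGraphStates(unittest.TestCase):
    def setUp(self):
        self.states = ToyGraphStates([ToyGraph(1), ToyGraph(2, ((0, 1),))])

    def test_initialization(self):
        self.assertEqual(len(self.states), 2)
        self.assertEqual(self.states.total_nodes, 3)

    def test_append(self):
        self.states.append(ToyGraphStates([ToyGraph(3, ((0, 1), (1, 2)))]))
        self.assertEqual(len(self.states), 3)
        self.assertEqual(self.states[2].num_nodes, 3)

    def test_slicing_returns_batch(self):
        sub = self.states[:1]
        self.assertIsInstance(sub, ToyGraphStates)
        self.assertEqual(len(sub), 1)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestToyGraphStates)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

The same three tests would carry over once the stand-ins are swapped for the real GraphStates class; the compatibility and loss-function items would additionally need Trajectories, FlowMatching and TrajectoryBalance fixtures.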

ashdtu · Aug 23 '24 09:08