torchgfn
Add support for including graphs as states of a GFlowNet
Describe your changes
Type: Feature
Add support for including graphs as states of a GFlowNet.
The graph structure is represented by the `Data` class of PyTorch Geometric. The `GraphStates` object is represented by the `Batch` class, which handles batching of `Data` objects and efficient indexing into them. The existing `States` object requires appending dummy states to batch trajectories of different lengths together. Instead, we aim to support trajectories by representing them as a nested `Batch` object.
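To illustrate what "batching of `Data` objects and their efficient indexing" involves, here is a minimal pure-Python sketch. It does not use PyTorch Geometric; `Graph` and `GraphBatch` are hypothetical stand-ins for `Data` and `Batch`. The idea shown, under that assumption, is that graphs are concatenated into one disjoint graph, edge endpoints are shifted by a per-graph node offset, and a `ptr` list of node offsets lets any single graph be recovered by index.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Graph:
    """Hypothetical stand-in for a torch_geometric `Data` object."""
    x: List[List[float]]               # node features, one row per node
    edge_index: List[Tuple[int, int]]  # (src, dst) pairs using local node ids


@dataclass
class GraphBatch:
    """Hypothetical stand-in for a torch_geometric `Batch`.

    All graphs are stored as one disjoint graph; `ptr` and `edge_ptr` record
    where each graph's nodes and edges start.
    """
    x: List[List[float]] = field(default_factory=list)
    edge_index: List[Tuple[int, int]] = field(default_factory=list)
    ptr: List[int] = field(default_factory=lambda: [0])
    edge_ptr: List[int] = field(default_factory=lambda: [0])

    @classmethod
    def from_graphs(cls, graphs: List[Graph]) -> "GraphBatch":
        batch = cls()
        for g in graphs:
            offset = batch.ptr[-1]
            batch.x.extend(g.x)
            # Shift edge endpoints so they point at this graph's node rows.
            batch.edge_index.extend((s + offset, d + offset) for s, d in g.edge_index)
            batch.ptr.append(offset + len(g.x))
            batch.edge_ptr.append(len(batch.edge_index))
        return batch

    def __len__(self) -> int:
        return len(self.ptr) - 1

    def __getitem__(self, i: int) -> Graph:
        # Normalize negative indices and reject out-of-range ones.
        if i < 0:
            i += len(self)
        if not 0 <= i < len(self):
            raise IndexError(f"graph index {i} out of range")
        lo, hi = self.ptr[i], self.ptr[i + 1]
        e_lo, e_hi = self.edge_ptr[i], self.edge_ptr[i + 1]
        # Undo the node offset so the returned graph uses local node ids again.
        edges = [(s - lo, d - lo) for s, d in self.edge_index[e_lo:e_hi]]
        return Graph(x=self.x[lo:hi], edge_index=edges)


g1 = Graph(x=[[1.0], [2.0]], edge_index=[(0, 1)])
g2 = Graph(x=[[3.0], [4.0], [5.0]], edge_index=[(0, 2), (2, 1)])
batch = GraphBatch.from_graphs([g1, g2])
print(batch.edge_index)  # [(0, 1), (2, 4), (4, 3)]
print(batch[1] == g2)    # True
```

Graphs of different sizes share one flat storage with no padding, which is the property that makes a `Batch`-based `GraphStates` attractive compared with padded `States`.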
The existing implementation of `Trajectories` supports indexing along the dimensions (number of time steps, number of trajectories, state size). A nested `Batch`-of-`Batch` object representing state trajectories would naturally support indexing of the form (number of trajectories, number of time steps, state size), so it would need to flip the first two indexing dimensions internally in `__getitem__()` and `__setitem__()`.
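The dimension flip described above can be sketched as follows. `NestedTrajectories` is a hypothetical container, not part of torchgfn: it stores states trajectory-major (one list per trajectory, no dummy padding states) but accepts keys in the existing `Trajectories` order, (time step, trajectory). Plain Python lists stand in for the nested `Batch` objects, and the single-index behaviour is an assumption about how ragged time steps could be handled.

```python
from typing import Any, List, Tuple, Union


class NestedTrajectories:
    """Hypothetical trajectory-major container with time-major indexing."""

    def __init__(self, trajectories: List[List[Any]]):
        # Each inner list holds one trajectory's states; lengths may differ.
        self._data = [list(traj) for traj in trajectories]

    @property
    def n_trajectories(self) -> int:
        return len(self._data)

    @property
    def max_length(self) -> int:
        return max((len(traj) for traj in self._data), default=0)

    def __getitem__(self, key: Union[int, Tuple[int, int]]) -> Any:
        if isinstance(key, int):
            # A single non-negative index selects a time step, as in the padded
            # layout; trajectories that are already finished are skipped
            # instead of contributing a dummy state.
            return [traj[key] for traj in self._data if key < len(traj)]
        t, i = key
        # Callers pass (time step, trajectory); storage is [trajectory][time step].
        return self._data[i][t]

    def __setitem__(self, key: Tuple[int, int], value: Any) -> None:
        t, i = key
        self._data[i][t] = value  # same flip as in __getitem__


trajs = NestedTrajectories([["s0", "a1", "a2"], ["s0", "b1"]])
print(trajs[1, 0])  # a1  (time step 1 of trajectory 0)
print(trajs[2])     # ['a2']  (trajectory 1 has no time step 2)
```

Keeping the flip inside the two dunder methods means code written against the (time step, trajectory) order would not need to know how the states are actually stored.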
Issue Number
#153
Testing
- [ ] Unit tests for helper functions of the `GraphStates` class: initialization, appending, batching, etc.
- [ ] Compatibility check with the `Trajectories` and `Transitions` classes
- [ ] Unit testing with `FlowMatching`: accessing `GraphStates` directly in the loss function
- [ ] Unit testing with `TrajectoryBalance`: accessing `GraphStates` via the `Trajectories` class