torchgfn
Trajectory Class Design
Description
- Indexing Order: The State class currently supports batching along the `Num Timesteps` dimension (i.e. batch size `(Timesteps,)`) and along the `Num Timesteps, Num Trajectories` dimensions (i.e. batch size `(Timesteps, Trajectory)`). Flipping the indexing order `(Timesteps, Trajectory)` --> `(Trajectory, Timesteps)` may be more user friendly, since conceptually Trajectories are a container over the State class. The decision to keep the former might have been motivated by keeping easy access to the sink/special states using the mask, but we can check whether it's worth flipping to the latter. At first glance, it seems like a lot of small changes would need to be made to accommodate it.
- The Trajectory class should be indifferent to the implementation of the State and Action classes. Currently, for example, the `__repr__()` method uses the following hardcoded implementation, which assumes State always has a tensor attribute. This fails, for example, on the GraphStates implementation. It needs to be kept generic and derived from the `__repr__()` methods of the states and actions.
```python
def __repr__(self) -> str:
    states = self.states.tensor.transpose(0, 1)
    assert states.ndim == 3
    trajectories_representation = ""
    for traj in states[:10]:
        one_traj_repr = []
        for step in traj:
            one_traj_repr.append(str(step.numpy()))
            if step.equal(self.env.s0 if self.is_backward else self.env.sf):
                break
        trajectories_representation += "-> ".join(one_traj_repr) + "\n"
    return (
        f"Trajectories(n_trajectories={self.n_trajectories}, max_length={self.max_length}, First 10 trajectories:"
        + f"states=\n{trajectories_representation}"
        # + f"actions=\n{self.actions.tensor.squeeze().transpose(0, 1)[:10].numpy()}, "
        + f"when_is_done={self.when_is_done[:10].numpy()})"
    )
```
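For discussion, here is a minimal sketch of what a container-agnostic `__repr__()` could look like. It is not torchgfn code: `TensorLikeStates`, `GraphLikeStates` and `TrajectoriesSketch` are hypothetical stand-ins written in plain Python, and the only assumption is that each state or action container defines its own `__repr__()`.

```python
from typing import Any, List


class TensorLikeStates:
    """Hypothetical stand-in for a tensor-backed states container."""

    def __init__(self, values: List[List[int]]) -> None:
        self.values = values

    def __repr__(self) -> str:
        return f"TensorLikeStates({self.values})"


class GraphLikeStates:
    """Hypothetical stand-in for a graph-backed container with no tensor attribute."""

    def __init__(self, n_nodes: List[int]) -> None:
        self.n_nodes = n_nodes

    def __repr__(self) -> str:
        return f"GraphLikeStates(n_nodes={self.n_nodes})"


class TrajectoriesSketch:
    """Hypothetical trajectories container that only relies on repr() of its contents."""

    def __init__(self, states: Any, actions: Any, when_is_done: List[int]) -> None:
        self.states = states
        self.actions = actions
        self.when_is_done = when_is_done

    def __repr__(self) -> str:
        # Delegate to the contained objects instead of reading `states.tensor`,
        # so any states/actions implementation with a __repr__ works.
        return (
            f"TrajectoriesSketch(n_trajectories={len(self.when_is_done)}, "
            f"states={self.states!r}, "
            f"actions={self.actions!r}, "
            f"when_is_done={self.when_is_done})"
        )


if __name__ == "__main__":
    print(TrajectoriesSketch(TensorLikeStates([[0, 1], [1, 1]]), ["a0", "a1"], [2, 2]))
    print(TrajectoriesSketch(GraphLikeStates([3, 5]), ["add_node", "exit"], [2, 4]))
```

With this shape, truncating the output to the first few trajectories would become the responsibility of each container's own `__repr__()`, which is one possible trade-off to discuss.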
- We can try to make the entry point for accessing States consistent. FlowMatching, for example, accesses the States objects directly, while other algorithms use the Trajectory class to access the State attribute. We should check whether we can provide a general template for users implementing their own loss functions, to keep things consistent. We could perhaps include a simple example showing how to address both cases using the Trajectory class itself.
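As a starting point for that template, the sketch below shows one way both kinds of loss could take the same trajectories object as input. Everything in it is hypothetical and written in plain Python: `ToyTrajectories`, `flat_states()`, `ToyLoss` and the two loss classes are illustrative names, not torchgfn APIs, and the arithmetic inside each loss is a placeholder rather than a real GFlowNet objective.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence

State = int  # hypothetical: a state is just an integer in this toy example


class ToyTrajectories:
    """Hypothetical container holding one list of states per trajectory."""

    def __init__(self, states: Sequence[Sequence[State]]) -> None:
        self.states = [list(traj) for traj in states]

    def flat_states(self) -> List[State]:
        # Single entry point for state-based losses: every state,
        # with the trajectory structure dropped.
        return [s for traj in self.states for s in traj]


class ToyLoss(ABC):
    """Hypothetical template: every loss receives the same trajectories object."""

    @abstractmethod
    def loss(self, trajectories: ToyTrajectories) -> float:
        ...


class StateBasedLoss(ToyLoss):
    """Stand-in for a FlowMatching-style loss that only needs individual states."""

    def loss(self, trajectories: ToyTrajectories) -> float:
        states = trajectories.flat_states()
        return sum(s * s for s in states) / len(states)  # placeholder math


class TrajectoryBasedLoss(ToyLoss):
    """Stand-in for a loss that needs whole trajectories."""

    def loss(self, trajectories: ToyTrajectories) -> float:
        per_traj = [sum(traj) for traj in trajectories.states]
        return sum(x * x for x in per_traj) / len(per_traj)  # placeholder math


if __name__ == "__main__":
    trajs = ToyTrajectories([[1, 2], [3]])
    for loss_fn in (StateBasedLoss(), TrajectoryBasedLoss()):
        print(type(loss_fn).__name__, loss_fn.loss(trajs))
```

The design choice illustrated here is that user-defined losses never reach into the States objects on their own; they either ask the trajectories container for a flattened view or iterate over it trajectory by trajectory.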