torchgfn
Explore ways of vectorizing the flow matching loss
Currently, the flow matching loss requires a loop over all possible actions (`for action_idx in range(self.env.n_actions - 1):`).
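To make the cost of that loop concrete, here is a minimal sketch of the loop-based pattern. It uses a hypothetical toy setup, not torchgfn's API: states are points on a `D`-dimensional grid, action `i < D` increments coordinate `i`, and the last action is assumed to be the exit action excluded by the `- 1`. The names `D`, `N_ACTIONS`, `backward_step`, `log_edge_flow` and `incoming_log_flows_loop` are all illustrative. For each non-exit action it performs one masked backward step and one model call, so the number of Python iterations and model calls grows linearly with the number of actions.

```python
import torch

# Hypothetical toy setup, NOT torchgfn's API: states are points on a D-dim
# grid, action i < D increments coordinate i, and action D is "exit"
# (assumed to be the action excluded by `n_actions - 1`).
D = 3
N_ACTIONS = D + 1
torch.manual_seed(0)
# Stand-in for a model mapping a state to one log edge flow per action.
log_edge_flow = torch.nn.Linear(D, N_ACTIONS)


def backward_step(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: undo action a by decrementing coordinate a."""
    parents = states.clone()
    parents[torch.arange(states.shape[0]), actions] -= 1
    return parents


def incoming_log_flows_loop(states: torch.Tensor) -> torch.Tensor:
    """Entry [i, a] = log F(parent_a(s_i) -> s_i), or -inf if no such parent."""
    batch = states.shape[0]
    out = torch.full((batch, N_ACTIONS - 1), float("-inf"))
    # One masked backward step and one model call per non-exit action.
    for action_idx in range(N_ACTIONS - 1):
        mask = states[:, action_idx] > 0  # parent exists only if coord > 0
        if not mask.any():
            continue
        valid = states[mask]
        actions = torch.full((valid.shape[0],), action_idx, dtype=torch.long)
        parents = backward_step(valid, actions)
        out[mask, action_idx] = log_edge_flow(parents.float())[:, action_idx]
    return out
```

In this toy, the log inflow of each state would be `torch.logsumexp(out, dim=1)`; the part that scales badly is the Python-level loop itself.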
This might be impractical if the number of actions blows up. We might want to explore ways of vectorizing that for loop.
One idea is to "repeat" the states and create a big actions tensor, so that all actions are handled in a single batched call instead of one call per action.
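A minimal sketch of the "repeat the states" idea, again under a hypothetical toy setup that is not torchgfn's API (grid states, action `i < D` increments coordinate `i`, last action assumed to be exit; `backward_step`, `log_edge_flow` and `incoming_log_flows_vectorized` are illustrative names). Each state is repeated once per non-exit action with `repeat_interleave`, an aligned actions tensor is built with `arange(...).repeat(...)`, invalid (state, action) pairs are masked out, and a single backward step plus a single model call replace the per-action loop.

```python
import torch

# Hypothetical toy setup, NOT torchgfn's API: states are points on a D-dim
# grid, action i < D increments coordinate i, action D is the (assumed) exit.
D = 3
N_ACTIONS = D + 1
torch.manual_seed(0)
log_edge_flow = torch.nn.Linear(D, N_ACTIONS)  # stand-in log edge flow model


def backward_step(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: undo action a by decrementing coordinate a."""
    parents = states.clone()
    parents[torch.arange(states.shape[0]), actions] -= 1
    return parents


def incoming_log_flows_vectorized(states: torch.Tensor) -> torch.Tensor:
    """Entry [i, a] = log F(parent_a(s_i) -> s_i), or -inf if no such parent."""
    batch, n = states.shape[0], N_ACTIONS - 1
    # "Repeat" the states: rows i*n .. i*n + n-1 are copies of state i.
    rep_states = states.repeat_interleave(n, dim=0)  # (batch*n, D)
    # Big actions tensor aligned with those rows: 0..n-1, 0..n-1, ...
    actions = torch.arange(n).repeat(batch)  # (batch*n,)
    # Keep only (state, action) pairs that actually have a parent.
    valid = rep_states[torch.arange(batch * n), actions] > 0
    out = torch.full((batch * n,), float("-inf"))
    # One backward step and one model call covering all actions at once.
    parents = backward_step(rep_states[valid], actions[valid])
    log_flows = log_edge_flow(parents.float())  # (n_valid, N_ACTIONS)
    out[valid] = log_flows.gather(1, actions[valid].unsqueeze(1)).squeeze(1)
    return out.view(batch, n)
```

The trade-off is memory: the repeated batch has `batch * (n_actions - 1)` rows, so for very large action spaces it may still need to be processed in chunks.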