torchgfn

Explore ways of vectorizing the flow matching loss

Open · saleml opened this issue 1 year ago · 0 comments

Currently, the flow matching loss iterates over actions in a Python loop, `for action_idx in range(self.env.n_actions - 1):`, i.e. once per non-exit action (the exit action being the last index, `n_actions - 1`).

This might become impractical if the number of actions blows up, since the loop body runs once per action. We might want to explore ways of vectorizing this loop.

One idea is to "repeat" the states once per action and build one large actions tensor, so that all (state, action) pairs are processed in a single batched call.
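A minimal sketch of the "repeat" idea, assuming a toy hypergrid-like environment rather than torchgfn's actual classes: `D`, `N_ACTIONS`, `log_flow_net`, and both functions below are hypothetical names for illustration only. Each function computes every state's log inflow (the log-sum of edge flows from its parents), first with a per-action loop like the current one, then by repeating each state `n_actions - 1` times with `repeat_interleave` and pairing it with a flattened `arange` of actions.

```python
import torch

# Hypothetical toy setup (NOT torchgfn's API): states are points in a D-dim
# hypergrid; action i < D increments coordinate i, action D is "exit".
D = 3
N_ACTIONS = D + 1  # last index is the exit action

torch.manual_seed(0)
# Hypothetical edge-flow estimator: maps a state to log F(s -> s') for every action.
log_flow_net = torch.nn.Linear(D, N_ACTIONS)


def incoming_log_flows_loop(states: torch.Tensor) -> torch.Tensor:
    """Log of total inflow per state, looping over non-exit actions."""
    batch = states.shape[0]
    log_terms = torch.full((batch, N_ACTIONS - 1), float("-inf"))
    for action_idx in range(N_ACTIONS - 1):
        valid = states[:, action_idx] > 0  # a parent exists only if coord > 0
        parents = states[valid].clone()
        parents[:, action_idx] -= 1  # undo the increment to get the parent
        log_f = log_flow_net(parents.float())[:, action_idx]
        log_terms[valid, action_idx] = log_f
    return torch.logsumexp(log_terms, dim=1)


def incoming_log_flows_vectorized(states: torch.Tensor) -> torch.Tensor:
    """Same quantity, computed by repeating states once per non-exit action."""
    batch = states.shape[0]
    n = N_ACTIONS - 1
    # Row k * n + a of the repeated batch pairs state k with action a.
    rep_states = states.repeat_interleave(n, dim=0)  # (batch * n, D)
    actions = torch.arange(n).repeat(batch)  # (batch * n,)
    rows = torch.arange(batch * n)
    valid = rep_states[rows, actions] > 0  # mask (state, action) pairs with no parent

    # Compute parents only for valid pairs, all in one batched call.
    log_terms = torch.full((batch * n,), float("-inf"))
    parents = rep_states[valid].clone()
    valid_actions = actions[valid]
    parents[torch.arange(parents.shape[0]), valid_actions] -= 1
    log_f = log_flow_net(parents.float()).gather(1, valid_actions.unsqueeze(1)).squeeze(1)
    log_terms[valid] = log_f
    # Reshape back to (batch, n) and sum flows over actions in log space.
    return torch.logsumexp(log_terms.view(batch, n), dim=1)
```

The vectorized version trades the loop for memory: the repeated batch has `batch_size * (n_actions - 1)` rows, so for very large action spaces it may need to be processed in chunks.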

saleml, Aug 02 '23 20:08