torchgfn
torchgfn copied to clipboard
Function to revert backward trajectories
In previous versions of the code, when actions were integers, we had the function below, which reverts backward trajectories (i.e. turns them into forward trajectories). It's not used anywhere in the codebase, but I remember using it for another project (probably GFN vs HVI). I just removed it (in an upcoming PR), and it would be nice to fix it so that it no longer assumes integer actions, and then have it back.
@staticmethod
def revert_backward_trajectories(trajectories: Trajectories) -> Trajectories:
    """Turn a batch of backward trajectories into forward trajectories.

    For every trajectory ``i`` the first ``when_is_done[i]`` actions and the
    first ``when_is_done[i] + 1`` states are flipped along the time axis, and
    the action ``env.n_actions - 1`` is written at time index
    ``when_is_done[i]``. Unused action slots are filled with ``-1`` and unused
    state slots with ``env.sf``.

    NOTE(review): this only works for integer-coded (discrete) actions stored
    as a ``(time, n_trajectories)`` tensor, and it presumes that
    ``env.n_actions - 1`` is the exit action — confirm before reusing it with
    other action representations (e.g. continuous GFN).

    Args:
        trajectories: The backward trajectories to revert; ``is_backward``
            must be true.

    Returns:
        A new ``Trajectories`` object with ``is_backward=False``, the flipped
        states and actions, ``when_is_done`` incremented by one, and
        ``log_probs`` passed through unchanged.

    Raises:
        AssertionError: If ``trajectories`` is not a backward batch.
        AttributeError: If ``trajectories.env.sf`` is ``None``.
    """
    assert trajectories.is_backward

    env = trajectories.env
    n_trajectories = len(trajectories)

    # One extra time step is needed for the exit action appended to each
    # trajectory. The padding row is built with ``new_full`` so that it shares
    # the dtype and device of the actions; a bare ``torch.full`` would use the
    # defaults and break ``torch.cat`` for actions living on another device.
    padding_row = trajectories.actions.new_full((1, n_trajectories), -1)
    new_actions = torch.cat(
        [torch.full_like(trajectories.actions, -1), padding_row], dim=0
    )

    # env.sf should never be None unless something went wrong during class
    # instantiation.
    if env.sf is None:
        raise AttributeError(
            "Something went wrong during the instantiation of environment {}".format(
                env
            )
        )

    # Start from a tensor filled with the sink state; the real states are
    # written over it below.
    new_states = env.sf.repeat(
        trajectories.when_is_done.max() + 1, n_trajectories, 1
    )
    new_when_is_done = trajectories.when_is_done + 1

    for i in range(n_trajectories):
        n_steps = trajectories.when_is_done[i]
        # The action written right after the reversed actions.
        new_actions[n_steps, i] = env.n_actions - 1
        # Reverse the time order of the actions actually taken.
        new_actions[:n_steps, i] = trajectories.actions[:n_steps, i].flip(0)
        # Reverse the visited states (one more state than actions).
        new_states[: n_steps + 1, i] = trajectories.states.tensor[
            : n_steps + 1, i
        ].flip(0)

    new_states = env.States(new_states)

    return Trajectories(
        env=env,
        states=new_states,
        actions=new_actions,
        log_probs=trajectories.log_probs,
        when_is_done=new_when_is_done,
        is_backward=False,
    )