torchgfn
Backward trajectory time steps in `intro_gfn_continuous_line.ipynb`
Thank you for this great library. In the tutorial notebook, when computing the log PB of a trajectory, we only consider the timestep range `[trajectory_length, 2)` (i.e. `range(trajectory_length, 2, -1)`). Why do we stop before timestep 2?
The code works, and performance seems to improve, if we instead use the range `[trajectory_length, 1)`, i.e. skip modelling only the backward step from time 1 to time 0.
```python
# Backward loop to compute logPB from an existing trajectory under the backward policy.
for t in range(trajectory_length, 2, -1):
    policy_dist = get_policy_dist(backward_model, trajectory[:, t, :])
    action = trajectory[:, t, 0] - trajectory[:, t - 1, 0]
    logPB += policy_dist.log_prob(action)
```
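To make the difference concrete, here is a small standalone sketch (not from the notebook; `trajectory_length = 5` is a hypothetical value) showing which timesteps each loop visits. Since the body at step `t` scores the transition from state `t` back to state `t - 1`, stopping before `t = 2` leaves the transition from state 2 to state 1 unmodelled:

```python
# Hypothetical trajectory length, for illustration only.
trajectory_length = 5

# The notebook's loop: visits t = 5, 4, 3, so the last transition
# scored is state 3 -> state 2. The step state 2 -> state 1 is skipped.
notebook_steps = list(range(trajectory_length, 2, -1))

# The proposed loop: also visits t = 2, scoring state 2 -> state 1.
proposed_steps = list(range(trajectory_length, 1, -1))

print(notebook_steps)  # [5, 4, 3]
print(proposed_steps)  # [5, 4, 3, 2]
```

So the proposed change adds exactly one more `log_prob` term per trajectory, covering the transition into state 1.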