torchgfn

Backward trajectory time steps in `intro_gfn_continuous_line.ipynb`

tsa87 opened this issue 7 months ago · 0 comments

Thank you for this great library. In the tutorial notebook, when computing the PB of a trajectory, we only consider the timestep range [trajectory_length, 2). Why do we stop at time step 2?

The code works, and performance seems to improve, if we instead use the timestep range [trajectory_length, 1), i.e. only the backward step from time 1 to time 0 is left unmodelled.

# Backward loop to compute logPB from the existing trajectory under the backward policy.
# range(trajectory_length, 2, -1) iterates t = trajectory_length, ..., 3, so the
# backward steps from t=2 to t=1 and from t=1 to t=0 are never scored.
for t in range(trajectory_length, 2, -1):
    # Backward policy distribution conditioned on the state at time t.
    policy_dist = get_policy_dist(backward_model, trajectory[:, t, :])
    # Displacement between the states at t-1 and t, i.e. the action being reversed.
    action = trajectory[:, t, 0] - trajectory[:, t - 1, 0]
    logPB += policy_dist.log_prob(action)
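
For reference, here is a minimal sketch of the variant described above, assuming the same get_policy_dist, backward_model, and trajectory layout as in the notebook; it scores every backward step except the final one from time 1 to time 0:

# Hypothetical variant (not the notebook's code): iterate t = trajectory_length, ..., 2,
# so only the backward step from t=1 to t=0 is left unmodelled.
for t in range(trajectory_length, 1, -1):
    policy_dist = get_policy_dist(backward_model, trajectory[:, t, :])
    action = trajectory[:, t, 0] - trajectory[:, t - 1, 0]
    logPB += policy_dist.log_prob(action)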

tsa87 · Jul 24 '24 20:07