
Sampling with backward policy recovers log_rewards of initial_states

Open alexandrelarouche opened this issue 7 months ago • 6 comments

Hi,

I am currently playing around with backward sampling (i.e., sampling from P_B) starting from some fixed terminating states. However, I encountered what I can only guess is a bug in the Sampler class's implementation of sample_trajectories. It seems that, currently, the log-rewards are computed only for the initial state when the provided P_B estimator has its is_backward attribute set to True.

I am specifically referring to the following block: https://github.com/GFNOrg/torchgfn/blob/b9b62006c7ccaba540af0cead2f34795f759ffd3/src/gfn/samplers.py#L233-L244

The fix is straightforward, I think: we need to set the log_rewards from the states provided as input whenever we sample backward. This assumes that the provided states are terminating states (which seems reasonable, but maybe there are use cases I am not aware of where this is not desirable).
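
Concretely, the change I have in mind is roughly the following. This is only a sketch: the real change would live inside Sampler.sample_trajectories, and the variable and argument names are my guesses based on the linked block, not the actual implementation.

```python
# Hypothetical helper illustrating the proposed behaviour; the real change
# would live inside Sampler.sample_trajectories, and the names below are
# guesses based on the linked block.
def compute_trajectories_log_rewards(env, estimator, input_states, last_states):
    if estimator.is_backward:
        # Backward sampling: the input states are assumed to be terminating
        # states, so their log-rewards are the trajectories' log-rewards.
        return env.log_reward(input_states)
    # Forward sampling: keep the current behaviour and compute log-rewards
    # at the final (terminating) states of the sampled trajectories.
    return env.log_reward(last_states)
```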

Actionable items from this issue:

  1. We need a test environment for backward sampling
  2. When estimator.is_backward is True, we must not compute the log_reward at the end of the trajectory. Instead, we should set trajectories_log_rewards to the log-rewards of the states that were provided as input to the sample_trajectories method.

I can implement both actionable items. However, I would like to clarify if the above assumption is a non-starter for the solution.

Edit: Actually, I just noticed there is a guard if statement for backward trajectories saying they are not supported, so maybe this warrants a bigger issue.

Edit 2: It turns out there is a reverse_backward_trajectories function available to translate backward trajectories into forward ones. This puts an extra burden on the user, however.

alexandrelarouche avatar May 14 '25 21:05 alexandrelarouche

Hi, @alexandrelarouche. Thank you for raising this issue.

I agree that we can improve the backward sampling and how we handle the log_rewards. I would like to think about this more deeply, so I would appreciate it if you could share your specific use case involving backward sampling.

Also, I want to hear your thoughts on our current approach; do you think we should allow learning directly from backward trajectories instead of converting them into forward ones using reverse_backward_trajectories?

hyeok9855 avatar May 28 '25 18:05 hyeok9855

Hey, thanks for being so open about discussing this.

My use case

Given a set of terminating states, I want to generate backward trajectories from those terminating states back to the initial state. That is, I want to fix the terminating state a trajectory ends in, take its assigned reward, and learn on a trajectory leading to that terminating state. This is similar to, but different from, Local Search GFN (and the like), where only partial backward trajectories are needed before going forward again. I need FULL backward trajectories, and I do not need to go forward again from them. We can view this as offline learning with a fixed dataset of terminating states.
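
For reference, the workflow I currently have to use looks roughly like this. It is a sketch only: env, pb_estimator (with is_backward=True), gflownet, and optimizer are assumed to be built as usual, and the `states=` argument name and the exact way reverse_backward_trajectories is exposed are guesses on my part.

```python
from gfn.samplers import Sampler


def train_on_fixed_terminating_states(env, pb_estimator, gflownet, optimizer, terminating_states):
    """One offline update from a fixed batch of terminating states (sketch)."""
    backward_sampler = Sampler(estimator=pb_estimator)

    # Sample full backward trajectories, starting from the fixed terminating states.
    backward_trajectories = backward_sampler.sample_trajectories(
        env,
        states=terminating_states,  # argument name is a guess
    )

    # Translate the backward trajectories into forward ones so the existing
    # losses can consume them (this is the extra burden on the user mentioned above).
    forward_trajectories = backward_trajectories.reverse_backward_trajectories()

    optimizer.zero_grad()
    loss = gflownet.loss(env, forward_trajectories)
    loss.backward()
    optimizer.step()
    return loss.item()
```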

My two cents

  1. Once #316 is merged, reverse_backward_trajectories can be viewed as an implementation quirk of the library, since it is functionally the same as natively supported backward trajectories (in terms of learning).
  2. Directly allowing backward trajs would be a "nice to have" in the current state of torchgfn, but is not an urgent item, as we can get the same result with reverse_backward_trajectories.
  3. However, supporting backward trajectories would be more elegant and would map more directly onto what the procedure actually is, imho, and it would avoid the need to look up relatively obscure / not obviously documented functions in the library.
  4. Properly implementing backward trajectories would boost the library's performance for this (admittedly niche) case, as currently we need to recompute forward log-probabilities of backward trajectories.
  5. If the maintainers decide not to focus on implementing backward trajectories right now, I think it would be necessary to add a documentation section on using backward trajectories that introduces the reverse_backward_trajectories function.

I think torchgfn can get away without implementing backward trajectories for the time being (if #316 is merged). However, I think it will need to be worked on eventually, as the current support is fairly unoptimized. Furthermore, I do not think properly supporting backward trajectories is a large amount of work: a lot of comments in the code already point to potential improvements, and they seem fairly obvious to implement.

alexandrelarouche avatar May 30 '25 18:05 alexandrelarouche

Thanks @alexandrelarouche. I totally agree with your opinions. I will discuss this with the team and try to address your concern as soon as possible. However, it may take some time (at least two weeks), as the entire team is currently focusing on another issue.

In the meantime, if you have any additional thoughts or suggestions, please feel free to share them. Your feedback is valuable, and I want to ensure we consider all angles as we move forward. I appreciate your patience, and I will keep you updated as we make progress on this issue. (Of course, any contributions are more than welcome!)

hyeok9855 avatar May 31 '25 05:05 hyeok9855

@alexandrelarouche I'm curious about this piece:

currently we need to recompute forward log-probabilities of backward trajectories

I probably misunderstand you, but this recomputation is needed normally, i.e., if you evaluate a forward trajectory under Pf, you would need to re-evaluate its reversed version under Pb.

It is rarely (if ever) the case that you would evaluate the reverse of a trajectory drawn from a policy under that same policy, at least with current practices.

I agree that more flexibility would, in general, be desirable here, and thank you for your roadmap, which seems very good.
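
To spell out why this evaluation is needed, take the trajectory balance objective as an example: for a complete trajectory $\tau = (s_0 \to \cdots \to s_n = x)$, it enforces

$$
\log Z + \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t) = \log R(x) + \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1}),
$$

so whichever policy the trajectory was actually sampled from, the other policy's log-probabilities still have to be evaluated on that same trajectory.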

josephdviviano avatar Jun 23 '25 18:06 josephdviviano

I've merged your PR https://github.com/GFNOrg/torchgfn/pull/316, which should resolve this issue, I believe.

josephdviviano avatar Jun 24 '25 13:06 josephdviviano


Pardon me, I have been away doing theory instead of code for the past few weeks!

I think we spoke about this over Slack, but I am writing it up here for posterity and so that others can contribute to it as well.

My proposal was the following: it would be easier for users if loss functions could receive backward trajectories seamlessly, instead of requiring calls to undocumented / hidden functions.

To achieve this, we could create and populate two attributes, log_pf and log_pb, on Trajectories. However, as @josephdviviano highlighted, this may pose computational challenges, as it requires a forward pass through BOTH the PF AND PB modules. This is problematic because needing log_pb is a rather niche case, so it would waste compute in most scenarios.

A middle ground would be an extra "control" argument in the Trajectories constructor to determine whether the trajectories should request both the PF and PB modules during trajectory sampling. This approach would be the most modular, but it would also imply adding a lot of if-else statements all over the place. It is still unclear to me what the right approach is, but the control argument currently seems best.
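
To illustrate the control-argument idea, here is a hypothetical sketch; none of these attribute or parameter names exist in torchgfn today.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TrajectoriesSketch:
    """Hypothetical stand-in for gfn.containers.Trajectories, only meant to
    illustrate the proposed control flags; not real torchgfn API."""

    states: torch.Tensor
    actions: torch.Tensor
    log_rewards: Optional[torch.Tensor] = None

    # Proposed control flags, decided at construction time: they tell the
    # sampler which log-probabilities it should populate.
    estimate_log_pf: bool = True
    estimate_log_pb: bool = False

    # Filled in by the sampler only when the corresponding flag is set, so
    # the extra forward pass through PF or PB is skipped by default.
    log_pf: Optional[torch.Tensor] = None
    log_pb: Optional[torch.Tensor] = None
```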

Edit: reformulated the proposal.

alexandrelarouche avatar Jul 01 '25 15:07 alexandrelarouche