torchgfn
`Transitions` container should then have some `estimator_outputs` attribute to avoid duplicate computation.
In `detailed_balance.py`, we have:
```python
if not self.off_policy:
    valid_log_pf_actions = transitions.log_probs
else:
    # Evaluate the log PF of the actions sampled off policy.
    # I suppose the Transitions container should then have some
    # estimator_outputs attribute as well, to avoid duplication here ?
    module_output = self.pf(states)  # TODO: Inefficient duplication.
    valid_log_pf_actions = self.pf.to_probability_distribution(
        states, module_output
    ).log_prob(actions.tensor)  # Actions sampled off policy.
```
We could avoid this second forward pass of `self.pf()` by storing the estimator outputs in the `Transitions` class.
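A minimal sketch of how the proposed attribute could work. All class and function names below are hypothetical illustrations, not the actual torchgfn API; the point is only that a cached output lets the loss skip the second forward pass:

```python
class Transitions:
    """Hypothetical container with the proposed estimator_outputs cache."""

    def __init__(self, states, actions, estimator_outputs=None):
        self.states = states
        self.actions = actions
        # Filled in by the sampler at sampling time, so the loss
        # never needs to re-run the estimator module.
        self.estimator_outputs = estimator_outputs


def pf_output_for(pf_module, transitions):
    """Reuse the cached output when available; otherwise fall back."""
    if transitions.estimator_outputs is not None:
        return transitions.estimator_outputs  # no duplicate forward pass
    return pf_module(transitions.states)


# Toy check: a counting "module" shows the second call is skipped.
calls = {"n": 0}

def pf_module(states):
    calls["n"] += 1
    return [s * 2.0 for s in states]

cached = pf_module([1, 2, 3])          # one call at sampling time
t = Transitions([1, 2, 3], [0, 1, 0], estimator_outputs=cached)
out = pf_output_for(pf_module, t)      # reuses the cache
print(calls["n"], out)                 # → 1 [2.0, 4.0, 6.0]
```

In the real loss, `module_output` would still be passed to `to_probability_distribution(...).log_prob(...)` exactly as in the snippet above; only its source changes.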
Ideally, both `Trajectories` and `Transitions` would be able to access the same estimator outputs in memory if there were ever a need to keep track of both.
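As a sketch of that sharing, a trajectories-to-transitions conversion could hand the cached outputs through by reference rather than copying them (again, all names here are hypothetical, not the actual torchgfn API):

```python
class Trajectories:
    """Hypothetical trajectory container (illustrative names only)."""

    def __init__(self, states, estimator_outputs):
        self.states = states
        self.estimator_outputs = estimator_outputs

    def to_transitions(self):
        # Pass the same object through rather than copying, so both
        # containers view one set of estimator outputs in memory.
        return SharedTransitions(self.states, self.estimator_outputs)


class SharedTransitions:
    """Hypothetical transitions container sharing the cached outputs."""

    def __init__(self, states, estimator_outputs):
        self.states = states
        self.estimator_outputs = estimator_outputs


outputs = [[0.1, 0.9], [0.4, 0.6]]  # pretend per-state pf outputs
traj = Trajectories(states=["s0", "s1"], estimator_outputs=outputs)
trans = traj.to_transitions()
# Both containers reference the identical object; no duplication.
print(trans.estimator_outputs is traj.estimator_outputs)  # → True
```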