`Transitions` container should then have some `estimator_outputs` attribute to avoid duplicate computation.

Open josephdviviano opened this issue 1 year ago • 0 comments

In detailed_balance.py, we have:

        if not self.off_policy:
            valid_log_pf_actions = transitions.log_probs
        else:
            # Evaluate the log PF of the actions sampled off policy.
            # I suppose the Transitions container should then have some
            # estimator_outputs attribute as well, to avoid duplication here ?
            module_output = self.pf(states)  # TODO: Inefficient duplication.
            valid_log_pf_actions = self.pf.to_probability_distribution(
                states, module_output
            ).log_prob(
                actions.tensor
            )  # Actions sampled off policy.

We could avoid this second forward pass of `self.pf()` by storing the estimator outputs in the `Transitions` class.
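As a rough illustration of the idea, here is a minimal, torch-free sketch of caching the estimator outputs on the container. All names in it (`TransitionsSketch`, `CountingEstimator`, `get_estimator_outputs`) are hypothetical stand-ins invented for this example, not torchgfn's actual classes or API; plain Python lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TransitionsSketch:
    """Hypothetical stand-in for a Transitions container (not torchgfn's class)."""

    states: List[float]
    actions: List[int]
    # Raw estimator outputs saved at sampling time; None if they were not stored.
    estimator_outputs: Optional[List[float]] = None


class CountingEstimator:
    """Toy estimator that records how many forward passes it performs."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, states: List[float]) -> List[float]:
        self.calls += 1
        return [2.0 * s for s in states]


def get_estimator_outputs(
    estimator: CountingEstimator, transitions: TransitionsSketch
) -> List[float]:
    """Reuse cached outputs when present; otherwise recompute and cache them."""
    if transitions.estimator_outputs is None:
        transitions.estimator_outputs = estimator(transitions.states)
    return transitions.estimator_outputs


pf = CountingEstimator()
states = [0.5, 1.0, 1.5]

# Sampling step: one forward pass, whose output is stored with the transitions.
transitions = TransitionsSketch(states, [0, 1, 0], estimator_outputs=pf(states))

# Loss step: the cached outputs are reused, so no second forward pass happens.
outputs = get_estimator_outputs(pf, transitions)
```

In the off-policy branch above, the loss would then build the distribution from the stored outputs and only fall back to calling `self.pf(states)` when nothing was stored.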

Ideally, both Trajectories and Transitions would be able to access the same estimator outputs in memory if there were ever a need to keep track of both.

josephdviviano, Feb 13 '24 16:02