
[BUG] `Transform._call` is not called from reset

Open · codingWhale13 opened this issue 1 year ago · 0 comments

Describe the bug

The docstring of `Transform._call()` says that it is called by both `TransformedEnv.step()` and `TransformedEnv.reset()`. However, resetting the transformed environment does not trigger `_call()`; only stepping does.
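The dispatch described above can be modelled without torchrl. The toy classes below (`ToyTransform`, `ToyTransformedEnv`) are hypothetical stand-ins, with plain dicts in place of TensorDicts. They only mirror the behaviour reported in this issue for torchrl 0.6.0: the default `_step` routes the next-state data through `_call`, while the default `_reset` returns `tensordict_reset` untouched. This is a sketch of the reported behaviour, not torchrl's implementation.

```python
class ToyTransform:
    """Hypothetical stand-in for torchrl's Transform (not the real class)."""

    def __init__(self):
        self.calls = 0  # how many times _call has run

    def _call(self, td: dict) -> dict:
        self.calls += 1
        return td

    def _step(self, td: dict, next_td: dict) -> dict:
        # Mirrors the reported behaviour: step output goes through _call.
        return self._call(next_td)

    def _reset(self, td: dict, td_reset: dict) -> dict:
        # Mirrors the 0.6.0 default quoted in the issue: no _call here.
        return td_reset


class ToyTransformedEnv:
    """Hypothetical stand-in for TransformedEnv wrapping one transform."""

    def __init__(self, transform: ToyTransform):
        self.transform = transform

    def reset(self) -> dict:
        td_reset = {"observation": 0.0}  # pretend base-env reset output
        return self.transform._reset({}, td_reset)

    def step(self, td: dict) -> dict:
        next_td = {"observation": td["observation"] + 1.0}  # pretend base-env step
        return self.transform._step(td, next_td)


t = ToyTransform()
env = ToyTransformedEnv(t)
td = env.reset()
print(t.calls)  # 0: reset bypassed _call
env.step(td)
print(t.calls)  # 1: step went through _call
```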

To Reproduce

```python
from tensordict import TensorDictBase
from torchrl.envs.transforms import Transform, TransformedEnv
from torchrl.envs import GymEnv

class PrintHiTransform(Transform):
    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        print("Hi")
        return tensordict

env = GymEnv("Pendulum-v1")
env = TransformedEnv(env, PrintHiTransform())

print("Calling env.reset()")
initial_state = env.reset()  # Does NOT print "Hi"

action = env.rand_action(initial_state)
print("Calling env.step()")
env.step(action)  # Prints "Hi" as desired
```

System info

Describe the characteristics of your environment:

  • Python version: 3.9
  • torchrl==0.6.0
```python
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
```

```
0.6.0 2.0.2 3.9.20 | packaged by conda-forge | (main, Sep 30 2024, 17:49:10) [GCC 13.3.0] linux
```

Possible fix

I would fix this by changing the default `Transform._reset()` so that it passes the reset output through `_call`:

```python
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
    """Resets a transform if it is stateful."""
    return self._call(tensordict_reset)  # Previously: return tensordict_reset
```

Maybe there is a good reason why this is not the case, but it seems inconsistent to me: why should `_step` call `_call` when `_reset` does not?
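Until the default changes, a transform subclass could opt in by overriding `_reset` itself with the body proposed above, since `_reset(tensordict, tensordict_reset)` is the hook quoted in this issue. The sketch below shows the effect on a hypothetical `ToyTransform` base (plain dicts stand in for TensorDicts, and the base only mirrors the 0.6.0 defaults as described here); it is not torchrl code. With the real library, the same override would go on the `Transform` subclass, e.g. `PrintHiTransform`.

```python
class ToyTransform:
    """Hypothetical base mirroring the reported 0.6.0 defaults (not torchrl)."""

    def __init__(self):
        self.calls = 0  # how many times _call has run

    def _call(self, td: dict) -> dict:
        self.calls += 1
        return td

    def _step(self, td: dict, next_td: dict) -> dict:
        return self._call(next_td)  # step path already uses _call

    def _reset(self, td: dict, td_reset: dict) -> dict:
        return td_reset  # reset path skips _call


class CallOnResetTransform(ToyTransform):
    """Opts in to the proposed behaviour: reset output also goes through _call."""

    def _reset(self, td: dict, td_reset: dict) -> dict:
        return self._call(td_reset)


for cls in (ToyTransform, CallOnResetTransform):
    t = cls()
    out = t._reset({}, {"observation": 0.0})  # simulate the reset hook
    t._step(out, {"observation": 1.0})        # simulate one step
    print(cls.__name__, t.calls)
# ToyTransform 1
# CallOnResetTransform 2
```

Overriding per subclass keeps the change local to transforms that need it, whereas changing the base default (as proposed) would affect every transform that relies on `_reset` being a no-op.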

codingWhale13, Nov 22 '24 10:11