[BUG] `Transform._call` is not called from reset
Describe the bug
The docstring of `Transform._call()` states that it is called by `TransformedEnv.step()` and `TransformedEnv.reset()`. However, resetting a transformed environment does not trigger `_call()`. Only stepping does.
To Reproduce
from tensordict import TensorDictBase
from torchrl.envs.transforms import Transform, TransformedEnv
from torchrl.envs import GymEnv


class PrintHiTransform(Transform):
    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        print("Hi")
        return tensordict


env = GymEnv("Pendulum-v1")
env = TransformedEnv(env, PrintHiTransform())

print("Calling env.reset()")
initial_state = env.reset()  # Does NOT print "Hi"
action = env.rand_action(initial_state)

print("Calling env.step()")
env.step(action)  # Prints "Hi" as desired
System info
Describe the characteristics of your environment:
- Python version: 3.9
- torchrl==0.6.0
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.6.0 2.0.2 3.9.20 | packaged by conda-forge | (main, Sep 30 2024, 17:49:10) [GCC 13.3.0] linux
Possible fix
I would fix this by changing the default `Transform._reset()` so that it passes the reset tensordict through `_call()`:
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
    """Resets a transform if it is stateful."""
    return self._call(tensordict_reset)  # Previously: return tensordict_reset
Maybe there is a good reason for the current behaviour, but it seems inconsistent to me: why should `_step` call `_call` while `_reset` does not?
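To make the asymmetry concrete, here is a torchrl-free toy model of the two hooks. `ToyTransform`, `PatchedToyTransform` and `calls_after` are hypothetical names, and plain dicts stand in for `TensorDictBase`. The sketch only mirrors the behaviour described in this issue (`_step` forwards to `_call`, while the default `_reset` returns `tensordict_reset` untouched); it is not torchrl's actual implementation.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for TensorDictBase in this toy model.
TensorDict = Dict[str, float]


class ToyTransform:
    """Hypothetical model of the Transform hooks discussed in this issue."""

    def __init__(self) -> None:
        self.calls: List[str] = []  # records every invocation of _call

    def _call(self, tensordict: TensorDict) -> TensorDict:
        self.calls.append("_call")
        return tensordict

    def _step(self, tensordict: TensorDict, next_tensordict: TensorDict) -> TensorDict:
        # Step path: delegates to _call, as observed in the reproduction.
        return self._call(next_tensordict)

    def _reset(self, tensordict: TensorDict, tensordict_reset: TensorDict) -> TensorDict:
        # Reset path: returns the reset data unchanged, so _call never runs.
        return tensordict_reset


class PatchedToyTransform(ToyTransform):
    def _reset(self, tensordict: TensorDict, tensordict_reset: TensorDict) -> TensorDict:
        # Proposed fix: route the reset output through _call as well.
        return self._call(tensordict_reset)


def calls_after(transform: ToyTransform) -> Tuple[List[str], List[str]]:
    """Return (calls after one reset, calls after that reset plus one step)."""
    transform._reset({}, {"obs": 0.0})
    after_reset = list(transform.calls)
    transform._step({"obs": 0.0}, {"obs": 1.0})
    return after_reset, list(transform.calls)


print(calls_after(ToyTransform()))         # ([], ['_call'])
print(calls_after(PatchedToyTransform()))  # (['_call'], ['_call', '_call'])
```

Until the default changes, overriding `_reset` in the user's own subclass (here `PrintHiTransform`) with the one-line body proposed above should give the same effect without patching the library.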