[BUG] SyncDataCollector Crashes when init_random_frames=0 with a policy that is NOT random

Open AlexandreBrown opened this issue 1 year ago • 2 comments

Describe the bug

When yielding from a SyncDataCollector that uses a standard Actor (not a random policy) and init_random_frames=0, it crashes.

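from torchrl.collectors import SyncDataCollector
from torchrl.modules import Actor

# `agent` and `train_env` are defined elsewhere in the training script.
policy = Actor(
    agent,
    in_keys=["your_key"],
    out_keys=["action"],
    spec=train_env.action_spec,
)
train_data_collector = SyncDataCollector(
    create_env_fn=train_env,
    policy=policy,
    init_random_frames=0,
    ...
)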

Iterating over the collector triggers the crash:

for data in tqdm(train_data_collector, "Env Data Collection"):
    ...  # the crash occurs before the first batch is yielded

To Reproduce

  1. Create a policy that is not a RandomPolicy (for example, an Actor).
  2. Create a SyncDataCollector with that policy and set init_random_frames=0.
  3. Iterate over the data collector.
  4. Observe the crash

Stack trace:

2024-11-04 12:04:33,606 [torchrl][INFO] check_env_specs succeeded!
2024-11-04 12:04:36.365 | INFO     | __main__:main:60 - Policy Device: cuda
2024-11-04 12:04:36.365 | INFO     | __main__:main:61 - Env Device: cpu
2024-11-04 12:04:36.365 | INFO     | __main__:main:62 - Storage Device: cpu
Env Data Collection:   0%|                                                                                                                                      | 0/1000000 [00:00<?, ?it/s]
Error executing job with overrides: ['env=dmc_reacher_hard', 'algo=sac_pixels']
Traceback (most recent call last):
  File "/home/user/Documents/SegDAC/./scripts/train_rl.py", line 119, in main
    trainer.train()
  File "/home/user/Documents/SegDAC/segdac_dev/src/segdac_dev/trainers/rl_trainer.py", line 40, in train
    for data in tqdm(self.train_data_collector, "Env Data Collection"):
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 247, in __iter__
    yield from self.iterator()
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1035, in iterator
    tensordict_out = self.rollout()
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/_utils.py", line 481, in unpack_rref_and_invoke_function
    return func(self, *args, **kwargs)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1166, in rollout
    env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/common.py", line 2862, in step_and_maybe_reset
    tensordict = self.step(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/common.py", line 1505, in step
    next_tensordict = self._step(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 783, in _step
    tensordict_in = self.transform.inv(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/tensordict/nn/common.py", line 314, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 357, in inv
    out = self._inv_call(clone(tensordict))
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 1084, in _inv_call
    tensordict = t._inv_call(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 3656, in _inv_call
    return super()._inv_call(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 342, in _inv_call
    raise KeyError(f"'{in_key}' not found in tensordict {tensordict}")
KeyError: "'action' not found in tensordict TensorDict(\n    fields={\n        collector: TensorDict(\n            fields={\n                traj_ids: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},\n            batch_size=torch.Size([]),\n            device=cpu,\n            is_shared=False),\n        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n        pixels: Tensor(shape=torch.Size([3, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),\n        pixels_transformed: Tensor(shape=torch.Size([3, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),\n        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),\n        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},\n    batch_size=torch.Size([]),\n    device=cpu,\n    is_shared=False)"

Expected behavior

Iterating over the collector should work when init_random_frames=0.

System info

Describe the characteristics of your environment:

  • How the library was installed (pip, source, ...): pip install torchrl==0.6.0
  • Python version: 3.10
  • Versions of other relevant libraries (output of pip list):
Package                   Version    Editable project location
------------------------- ---------- -------------------------------------------
absl-py                   2.1.0
antlr4-python3-runtime    4.9.3
attrs                     24.2.0
av                        13.1.0
certifi                   2024.8.30
charset-normalizer        3.4.0
cloudpickle               3.1.0
comet-ml                  3.47.1
configobj                 5.0.9
dm_control                1.0.24
dm-env                    1.6
dm-tree                   0.1.8
dulwich                   0.22.4
etils                     1.10.0
everett                   3.1.0
filelock                  3.16.1
fsspec                    2024.10.0
glfw                      2.7.0
hydra-core                1.3.2
idna                      3.10
importlib_resources       6.4.5
Jinja2                    3.1.4
jsonschema                4.23.0
jsonschema-specifications 2024.10.1
labmaze                   1.0.6
loguru                    0.7.2
lxml                      5.3.0
markdown-it-py            3.0.0
MarkupSafe                3.0.2
mdurl                     0.1.2
mpmath                    1.3.0
mujoco                    3.2.4
networkx                  3.4.2
numpy                     2.1.3
nvidia-cublas-cu12        12.4.5.8
nvidia-cuda-cupti-cu12    12.4.127
nvidia-cuda-nvrtc-cu12    12.4.127
nvidia-cuda-runtime-cu12  12.4.127
nvidia-cudnn-cu12         9.1.0.70
nvidia-cufft-cu12         11.2.1.3
nvidia-curand-cu12        10.3.5.147
nvidia-cusolver-cu12      11.6.1.9
nvidia-cusparse-cu12      12.3.1.170
nvidia-nccl-cu12          2.21.5
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu12          12.4.127
omegaconf                 2.3.0
orjson                    3.10.11
packaging                 24.1
pillow                    11.0.0
pip                       24.3.1
protobuf                  5.28.3
psutil                    6.1.0
Pygments                  2.18.0
PyOpenGL                  3.1.7
pyparsing                 3.2.0
python-box                6.1.0
PyYAML                    6.0.2
referencing               0.35.1
requests                  2.32.3
requests-toolbelt         1.0.0
rich                      13.9.4
rpds-py                   0.20.1
scipy                     1.14.1
semantic-version          2.10.0
sentry-sdk                2.18.0
setuptools                75.3.0
simplejson                3.19.3
sympy                     1.13.1
tensordict                0.6.1
torch                     2.5.1
torchaudio                2.5.1
torchrl                   0.6.0
torchvision               0.20.1
tqdm                      4.66.5
triton                    3.1.0
typing_extensions         4.12.2
urllib3                   2.2.3
wheel                     0.44.0
wrapt                     1.16.0
wurlitzer                 3.1.1
zipp                      3.20.2

Reason and Possible fixes

It seems that self._policy_output_keys, which is set in SyncDataCollector::_make_final_rollout, ends up as {} when init_random_frames=0, and this breaks the shuttle update in SyncDataCollector::rollout.
More precisely, the problem is in these lines of SyncDataCollector::rollout:

policy_output = self.policy(policy_input)
if self._shuttle is not policy_output:
    # ad-hoc update shuttle
    self._shuttle.update(
        policy_output, keys_to_update=self._policy_output_keys
    )

In my case, policy_output was a tensordict containing the action key, but because self._policy_output_keys is {}, the update copies no keys and self._shuttle never receives the action key. The environment step then fails with the KeyError shown in the stack trace: 'action' not found in tensordict.
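The mechanism described above can be illustrated with a minimal plain-Python sketch. It does not use torchrl or tensordict: `update_selected` and `env_step` are hypothetical stand-ins written for this illustration, plain dicts play the role of the shuttle and the policy output, and the sketch assumes that an empty `keys_to_update` means "copy nothing".

```python
from typing import Any, Dict, Iterable, Optional


def update_selected(
    target: Dict[str, Any],
    source: Dict[str, Any],
    keys_to_update: Optional[Iterable[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical stand-in for an update restricted to `keys_to_update`.

    None copies every key; an empty collection copies nothing (assumption).
    """
    if keys_to_update is None:
        target.update(source)
        return target
    for key in keys_to_update:
        if key in source:
            target[key] = source[key]
    return target


def env_step(shuttle: Dict[str, Any]) -> Any:
    """Hypothetical stand-in for an env step that requires an 'action' entry."""
    if "action" not in shuttle:
        raise KeyError(f"'action' not found in {sorted(shuttle)}")
    return shuttle["action"]


policy_output = {"pixels": [0, 0, 0], "action": [0.1]}

# Reported case: the set of policy output keys is empty, so nothing is copied.
shuttle = {"pixels": [0, 0, 0]}
update_selected(shuttle, policy_output, keys_to_update=set())
try:
    env_step(shuttle)
    crashed = False
except KeyError:
    crashed = True

# Expected case: the output keys include "action", so the step succeeds.
fixed_shuttle = {"pixels": [0, 0, 0]}
update_selected(fixed_shuttle, policy_output, keys_to_update={"action"})
action = env_step(fixed_shuttle)
```

In this sketch the first shuttle never gains an "action" entry, which mirrors the KeyError in the stack trace, while the second one does once "action" is listed among the keys to update.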

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

AlexandreBrown, Nov 04 '24 17:11

On it!

vmoens, Nov 04 '24 21:11

+1

jannessm, Dec 11 '24 12:12