[BUG] SyncDataCollector Crashes when init_random_frames=0 with a policy that is NOT random
Describe the bug
Yielding from a SyncDataCollector that uses a standard Actor (not a RandomPolicy) with init_random_frames=0 crashes.
```python
policy = Actor(
    agent,
    in_keys=["your_key"],
    out_keys=["action"],
    spec=train_env.action_spec,
)

train_data_collector = SyncDataCollector(
    create_env_fn=train_env,
    policy=policy,
    init_random_frames=0,
    ...
)
```
Yielding example that causes the crash:
```python
for data in tqdm(train_data_collector, "Env Data Collection"):
```
To Reproduce
- Create an actor that is not a RandomPolicy
- Create a SyncDataCollector with the actor and set init_random_frames=0
- Try to yield from the data collector (a minimal sketch is given below)
- Observe the crash
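A minimal, self-contained sketch of these steps. This is hypothetical: it uses a Gym Pendulum-v1 environment and a single linear layer instead of the dm_control pixel setup from my report, and frames_per_batch/total_frames are placeholder values. Without the pixel transforms the exact error message may differ from the trace below, but the underlying issue (the action key never reaching the collector's shuttle) is the same.

```python
import torch.nn as nn

from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.modules import Actor

env = GymEnv("Pendulum-v1")

# Deterministic (non-random) policy mapping observations to actions.
policy = Actor(
    nn.Linear(
        env.observation_spec["observation"].shape[-1],
        env.action_spec.shape[-1],
    ),
    in_keys=["observation"],
    out_keys=["action"],
    spec=env.action_spec,
)

collector = SyncDataCollector(
    create_env_fn=env,
    policy=policy,
    frames_per_batch=100,   # placeholder value
    total_frames=1_000,     # placeholder value
    init_random_frames=0,   # 0 instead of the default None: triggers the crash on torchrl 0.6.0
)

for data in collector:  # raises KeyError: 'action' not found in tensordict ...
    break
```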
Stack trace:
2024-11-04 12:04:33,606 [torchrl][INFO] check_env_specs succeeded!
2024-11-04 12:04:36.365 | INFO | __main__:main:60 - Policy Device: cuda
2024-11-04 12:04:36.365 | INFO | __main__:main:61 - Env Device: cpu
2024-11-04 12:04:36.365 | INFO | __main__:main:62 - Storage Device: cpu
Env Data Collection: 0%| | 0/1000000 [00:00<?, ?it/s]
Error executing job with overrides: ['env=dmc_reacher_hard', 'algo=sac_pixels']
Traceback (most recent call last):
File "/home/user/Documents/SegDAC/./scripts/train_rl.py", line 119, in main
trainer.train()
File "/home/user/Documents/SegDAC/segdac_dev/src/segdac_dev/trainers/rl_trainer.py", line 40, in train
for data in tqdm(self.train_data_collector, "Env Data Collection"):
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/tqdm/std.py", line 1181, in __iter__
for obj in iterable:
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 247, in __iter__
yield from self.iterator()
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1035, in iterator
tensordict_out = self.rollout()
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/_utils.py", line 481, in unpack_rref_and_invoke_function
return func(self, *args, **kwargs)
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1166, in rollout
env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/common.py", line 2862, in step_and_maybe_reset
tensordict = self.step(tensordict)
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/common.py", line 1505, in step
next_tensordict = self._step(tensordict)
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 783, in _step
tensordict_in = self.transform.inv(tensordict)
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/tensordict/nn/common.py", line 314, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 357, in inv
out = self._inv_call(clone(tensordict))
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 1084, in _inv_call
tensordict = t._inv_call(tensordict)
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 3656, in _inv_call
return super()._inv_call(tensordict)
File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 342, in _inv_call
raise KeyError(f"'{in_key}' not found in tensordict {tensordict}")
KeyError: "'action' not found in tensordict TensorDict(\n fields={\n collector: TensorDict(\n fields={\n traj_ids: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},\n batch_size=torch.Size([]),\n device=cpu,\n is_shared=False),\n done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n pixels: Tensor(shape=torch.Size([3, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),\n pixels_transformed: Tensor(shape=torch.Size([3, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),\n step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),\n terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},\n batch_size=torch.Size([]),\n device=cpu,\n is_shared=False)"
Expected behavior
We should be able to iterate over the data collector with init_random_frames=0.
System info
Describe the characteristics of your environment:
- How the library was installed: pip (pip install torchrl==0.6.0)
- Python version: 3.10
- Versions of any other relevant libraries: see the output of my pip list below
Package Version Editable project location
------------------------- ---------- -------------------------------------------
absl-py 2.1.0
antlr4-python3-runtime 4.9.3
attrs 24.2.0
av 13.1.0
certifi 2024.8.30
charset-normalizer 3.4.0
cloudpickle 3.1.0
comet-ml 3.47.1
configobj 5.0.9
dm_control 1.0.24
dm-env 1.6
dm-tree 0.1.8
dulwich 0.22.4
etils 1.10.0
everett 3.1.0
filelock 3.16.1
fsspec 2024.10.0
glfw 2.7.0
hydra-core 1.3.2
idna 3.10
importlib_resources 6.4.5
Jinja2 3.1.4
jsonschema 4.23.0
jsonschema-specifications 2024.10.1
labmaze 1.0.6
loguru 0.7.2
lxml 5.3.0
markdown-it-py 3.0.0
MarkupSafe 3.0.2
mdurl 0.1.2
mpmath 1.3.0
mujoco 3.2.4
networkx 3.4.2
numpy 2.1.3
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
omegaconf 2.3.0
orjson 3.10.11
packaging 24.1
pillow 11.0.0
pip 24.3.1
protobuf 5.28.3
psutil 6.1.0
Pygments 2.18.0
PyOpenGL 3.1.7
pyparsing 3.2.0
python-box 6.1.0
PyYAML 6.0.2
referencing 0.35.1
requests 2.32.3
requests-toolbelt 1.0.0
rich 13.9.4
rpds-py 0.20.1
scipy 1.14.1
semantic-version 2.10.0
sentry-sdk 2.18.0
setuptools 75.3.0
simplejson 3.19.3
sympy 1.13.1
tensordict 0.6.1
torch 2.5.1
torchaudio 2.5.1
torchrl 0.6.0
torchvision 0.20.1
tqdm 4.66.5
triton 3.1.0
typing_extensions 4.12.2
urllib3 2.2.3
wheel 0.44.0
wrapt 1.16.0
wurlitzer 3.1.1
zipp 3.20.2
Reason and Possible fixes
It seems that self._policy_output_keys, set in SyncDataCollector::_make_final_rollout, ends up as {} when init_random_frames=0, which causes unwanted behavior in SyncDataCollector::rollout.
More precisely, these lines from SyncDataCollector::rollout:
```python
policy_output = self.policy(policy_input)
if self._shuttle is not policy_output:
    # ad-hoc update shuttle
    self._shuttle.update(
        policy_output, keys_to_update=self._policy_output_keys
    )
```
In my case, policy_output was a TensorDict containing the action key, but since self._policy_output_keys is {}, self._shuttle is never updated with the action key. This causes the crash with the error KeyError: "'action' not found in tensordict".
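To illustrate the mechanism outside of torchrl, here is a standalone sketch. It assumes that TensorDict.update with an empty keys_to_update updates nothing, which is consistent with the behavior observed above:

```python
import torch
from tensordict import TensorDict

# Stand-in for the collector's shuttle: no "action" entry yet.
shuttle = TensorDict({"observation": torch.zeros(3)}, batch_size=[])

# Stand-in for the policy output, which does contain "action".
policy_output = TensorDict(
    {"observation": torch.zeros(3), "action": torch.ones(1)}, batch_size=[]
)

# Mimics self._policy_output_keys == {}: nothing is copied over.
shuttle.update(policy_output, keys_to_update=())
assert "action" not in shuttle.keys()  # -> later KeyError in the env step/transforms

# With the expected output keys, the shuttle gets the action as intended.
shuttle.update(policy_output, keys_to_update=("action",))
assert "action" in shuttle.keys()
```

One possible direction for a fix (untested) would be for the rollout to fall back to the policy's out_keys when self._policy_output_keys is empty.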
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)
On it!
+1