[BUG] SyncDataCollector Not Working With ParallelEnv when Built with Replay Buffer
Describe the bug
Creating a SyncDataCollector around a ParallelEnv with a replay buffer passed to its constructor causes a crash as soon as the collector is iterated over.
To Reproduce
- Create a ParallelEnv (I used num_workers=2 in my test):
from torchrl.envs import ParallelEnv

train_env = ParallelEnv(
    num_workers=int(cfg["env"]["num_workers"]), create_env_fn=create_env_fn
)
- Create a replay buffer whose storage has ndim=2:
from torchrl.data import LazyMemmapStorage, LazyTensorStorage, TensorDictReplayBuffer

storage_kwargs = {}
storage_kwargs["max_size"] = capacity
storage_kwargs["device"] = storage_device
# One storage dimension for time, plus one for the workers when there are several
storage_dim = 1
if cfg["env"]["num_workers"] > 1:
    storage_dim += 1
storage_kwargs["ndim"] = storage_dim
if "cpu" in storage_device.type:
    # LazyMemmapStorage is only supported on CPU
    replay_buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(**storage_kwargs),
        transform=transform,
    )
else:
    replay_buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(**storage_kwargs),
        transform=transform,
    )
- Create a SyncDataCollector and make sure to pass the replay buffer to its constructor:
from torchrl.collectors import SyncDataCollector

max_frames_per_traj = cfg["env"]["max_frames_per_traj"]
frames_per_batch = 128
data_collector = SyncDataCollector(
    create_env_fn=env,
    policy=policy,
    total_frames=data_collector_cfg["total_frames"],
    max_frames_per_traj=max_frames_per_traj,
    frames_per_batch=frames_per_batch,
    env_device=cfg["env"]["device"],
    storing_device=cfg["storage_device"],
    policy_device=cfg["policy_device"],
    exploration_type=exploration_type,
    init_random_frames=data_collector_cfg.get("init_random_frames", 0),
    postproc=None,
    replay_buffer=replay_buffer,
)
- Iterate over the data collector:
for _ in data_collector:
    pass
- Observe the crash:
Traceback (most recent call last):
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 71, in <module>
cli.main()
File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 501, in main
run()
File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 351, in run_file
runpy.run_path(target, run_name="__main__")
File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
_run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
exec(code, run_globals)
File "scripts/train_rl.py", line 125, in <module>
main()
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/main.py", line 94, in decorated_main
_run_hydra(
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
_run_app(
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
run_and_report(
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
raise ex
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
return func()
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
lambda: hydra.run(
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
_ = ret.return_value
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
raise self._return_value
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
ret.return_value = task_function(task_cfg)
File "scripts/train_rl.py", line 118, in main
trainer.train()
File "/home/mila/b/myuser/SegDAC/segdac_dev/src/segdac_dev/trainers/rl_trainer.py", line 41, in train
for _ in tqdm(self.train_data_collector, "Env Data Collection"):
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/tqdm/std.py", line 1181, in __iter__
for obj in iterable:
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 247, in __iter__
yield from self.iterator()
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1035, in iterator
tensordict_out = self.rollout()
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/_utils.py", line 481, in unpack_rref_and_invoke_function
return func(self, *args, **kwargs)
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1177, in rollout
self.replay_buffer.add(self._shuttle)
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 1202, in add
self._set_index_in_td(data, index)
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 1246, in _set_index_in_td
tensordict.set("index", expand_as_right(index, tensordict))
File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/tensordict/utils.py", line 370, in expand_as_right
raise RuntimeError(
RuntimeError: expand_as_right requires the destination tensor to have less dimensions than the input tensor, got tensor.ndimension()=2 and dest.ndimension()=1
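The final error comes from a dimension check: the index being written has 2 dimensions, while the tensordict passed to `replay_buffer.add` has only 1. Below is a minimal sketch of that check using plain shape tuples instead of tensors. The function `check_expand_as_right` is a hypothetical stand-in written for illustration, not the tensordict implementation, and the concrete index shape `(2, 2)` is an assumption (the traceback only tells us it is 2-D).

```python
def check_expand_as_right(tensor_shape, dest_shape):
    """Illustrative stand-in for the ndim check seen in the traceback.

    Works on shape tuples only. Returns the shape the expanded tensor
    would have, or raises if the destination has fewer dimensions.
    """
    if len(dest_shape) < len(tensor_shape):
        raise RuntimeError(
            "destination has fewer dimensions than the input: "
            f"tensor ndim={len(tensor_shape)}, dest ndim={len(dest_shape)}"
        )
    return tuple(dest_shape)


num_workers = 2
index_shape = (num_workers, 2)  # assumed 2-D index shape (hypothetical values)
shuttle_shape = (num_workers,)  # per-step data from a 2-worker ParallelEnv

try:
    check_expand_as_right(index_shape, shuttle_shape)
    failed = False
except RuntimeError as err:
    failed = True
    print(err)
```

With a 1-D index and a 2-D destination, the same check passes, which matches the case where the data has as many dimensions as the storage.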
Expected behavior
I would expect the same behavior as when the replay buffer is not passed to the sync data collector and is instead filled manually:
for data in tqdm(self.train_data_collector, "Env Data Collection"):
    self.replay_buffer.extend(data)
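One possible reading of why the manual path works while the built-in path fails is a difference in the shape of the data that reaches the buffer. The sketch below compares the two under stated assumptions: that a collector over a batched env yields batches shaped `[num_workers, ceil(frames_per_batch / num_workers)]`, and that the per-step data handed to `add` (the `_shuttle` in the traceback) only carries the worker dimension. The helper names are hypothetical.

```python
import math


def collector_batch_shape(num_workers, frames_per_batch):
    # Assumption: frames are split evenly across workers, rounding up.
    if num_workers > 1:
        return (num_workers, math.ceil(frames_per_batch / num_workers))
    return (frames_per_batch,)


def per_step_shape(num_workers):
    # Assumption: a single env step only has the worker dimension.
    return (num_workers,) if num_workers > 1 else ()


storage_ndim = 2  # as configured in the reproduction steps
batch = collector_batch_shape(num_workers=2, frames_per_batch=128)
step = per_step_shape(num_workers=2)

print(batch, len(batch) == storage_ndim)  # (2, 64) True
print(step, len(step) == storage_ndim)    # (2,) False
```

Under these assumptions, `extend(data)` receives data whose number of dimensions matches the storage `ndim`, whereas the per-step `add` does not.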
System info
Describe the characteristics of your environment:
- Describe how the library was installed (pip, source, ...): pip
- Python version: 3.10
- Versions of any other relevant libraries: pip list :
Package Version Editable project location
------------------------- ------------------------------------------------------------- ------------------------------------------
absl-py 2.1.0
antlr4-python3-runtime 4.9.3
asttokens 2.4.1
attrs 24.2.0
av 13.1.0
certifi 2024.8.30
charset-normalizer 3.4.0
click 8.1.7
clip 1.0
cloudpickle 3.1.0
coloredlogs 15.0.1
comet-ml 3.47.1
comm 0.2.2
configobj 5.0.9
contourpy 1.3.1
cycler 0.12.1
Cython 3.0.11
debugpy 1.8.9
decorator 5.1.1
diffusers 0.31.0
dm_control 1.0.25
dm-env 1.6
dm-tree 0.1.8
docker-pycreds 0.4.0
drqv2 1.0.0
dulwich 0.22.6
efficientvit 0.0.0
einops 0.8.0
etils 1.10.0
everett 3.1.0
exceptiongroup 1.2.2
executing 2.1.0
filelock 3.16.1
flatbuffers 24.3.25
fonttools 4.55.0
fsspec 2024.10.0
ftfy 6.3.1
gitdb 4.0.11
GitPython 3.1.43
glfw 2.8.0
huggingface-hub 0.26.2
humanfriendly 10.0
hydra-core 1.3.2
idna 3.10
igraph 0.11.8
imageio 2.36.0
importlib_metadata 8.5.0
importlib_resources 6.4.5
ipdb 0.13.13
ipykernel 6.29.5
ipython 8.29.0
jedi 0.19.2
Jinja2 3.1.4
jsonschema 4.23.0
jsonschema-specifications 2024.10.1
jupyter_client 8.6.3
jupyter_core 5.7.2
kiwisolver 1.4.7
labmaze 1.0.6
lazy_loader 0.4
lightning-utilities 0.11.9
loguru 0.7.2
lvis 0.5.3
lxml 5.3.0
markdown-it-py 3.0.0
MarkupSafe 3.0.2
matplotlib 3.9.2
matplotlib-inline 0.1.7
mdurl 0.1.2
mpmath 1.3.0
mujoco 3.2.5
nest-asyncio 1.6.0
networkx 3.4.2
numpy 2.1.3
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
omegaconf 2.3.0
onnx 1.17.0
onnxruntime 1.20.1
onnxsim 0.4.36
opencv-python 4.10.0.84
opencv-python-headless 4.10.0.84
orjson 3.10.12
packaging 24.2
pandas 2.2.3
parso 0.8.4
pexpect 4.9.0
pillow 11.0.0
pip 24.3.1
platformdirs 4.3.6
prompt_toolkit 3.0.48
protobuf 5.28.3
psutil 6.1.0
ptyprocess 0.7.0
pure_eval 0.2.3
py-cpuinfo 9.0.0
pycocotools 2.0.8
Pygments 2.18.0
PyOpenGL 3.1.7
PyOpenGL-accelerate 3.1.7
pyparsing 3.2.0
python-box 6.1.0
python-dateutil 2.9.0.post0
pytz 2024.2
PyYAML 6.0.2
pyzmq 26.2.0
referencing 0.35.1
regex 2024.11.6
requests 2.32.3
requests-toolbelt 1.0.0
rich 13.9.4
rpds-py 0.21.0
ruamel.yaml 0.18.6
ruamel.yaml.clib 0.2.12
safetensors 0.4.5
scikit-image 0.24.0
scipy 1.14.1
seaborn 0.13.2
XXX 0.0.1
XXX 0.0.1
segment_anything 1.0
semantic-version 2.10.0
sentry-sdk 2.19.0
setproctitle 1.3.4
setuptools 75.6.0
simplejson 3.19.3
six 1.16.0
smmap 5.0.1
stack-data 0.6.3
sympy 1.13.1
tensordict 0.6.0
texttable 1.7.0
tifffile 2024.9.20
timm 1.0.11
TinyNeuralNetwork 0.1.0.20241024123327+19e5f6dd0f6e391d3c3640cf46d28f47eb76d289
tokenizers 0.20.4
tomli 2.1.0
torch 2.5.0
torch-fidelity 0.3.0
torchaudio 2.5.0
torchmetrics 1.6.0
torchprofile 0.0.4
torchrl 0.6.0
torchvision 0.20.0
tornado 6.4.2
tqdm 4.66.5
traitlets 5.14.3
transformers 4.46.3
triton 3.1.0
typing_extensions 4.12.2
tzdata 2024.2
ultralytics 8.3.38
ultralytics-thop 2.0.12
urllib3 2.2.3
wandb 0.18.7
wcwidth 0.2.13
wheel 0.45.1
wrapt 1.17.0
wurlitzer 3.1.1
zipp 3.21.0
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)