
[BUG] DataCollectors fail when device is set to MPS

LCarmi opened this issue 1 year ago • 0 comments

Describe the bug

When running experiments on macOS with multiprocess-based sampling of trajectories, initialization of the data collectors fails if the collector device is set to `mps`.

To Reproduce

from torchrl.envs.libs.gym import GymEnv
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import MultiSyncDataCollector

if __name__ == "__main__":
    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
    collector = MultiSyncDataCollector(
        create_env_fn=[env_maker, env_maker],
        policy=policy,
        total_frames=2000,
        max_frames_per_traj=50,
        frames_per_batch=200,
        init_random_frames=-1,
        reset_at_each_iter=False,
        device="mps",
        storing_device="cpu",
        # cat_results="stack",
    )
    for i, data in enumerate(collector):
        if i == 2:
            print(data)
            break

This fails as follows:

Traceback (most recent call last):
  File "..././torchrl_test_mps_fail.py", line 9, in <module>
    collector = MultiSyncDataCollector(
  File ".../.venv/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1779, in __init__
    self._run_processes()
  File ".../.venv/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1976, in _run_processes
    proc.start()
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
  File ".../.venv/lib/python3.10/site-packages/torch/multiprocessing/reductions.py", line 607, in reduce_storage
    metadata = storage._share_filename_cpu_()
  File ".../.venv/lib/python3.10/site-packages/torch/storage.py", line 450, in wrapper
    return fn(self, *args, **kwargs)
  File ".../.venv/lib/python3.10/site-packages/torch/storage.py", line 529, in _share_filename_cpu_
    return super()._share_filename_cpu_(*args, **kwargs)
RuntimeError: _share_filename_: only available on CPU

System info

>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.7.2 2.2.4 3.10.16 (main, Dec  3 2024, 17:27:57) [Clang 16.0.6 ] darwin

Reason and Possible fixes

I suspect this issue boils down to:

  • limitations of the mps device, which does not play well with pickle-based sharing of parameters
  • limitations of torchrl, which assumes a spawn-based multiprocessing context
    • as opposed to a fork-based one; forcing fork via multiprocessing.set_start_method('fork') emits a warning and makes the collectors crash
    • a spawn context is imposed by torchrl: https://github.com/pytorch/rl/blob/619fec69c33966fc92cbb4527d6dad567e094752/torchrl/init.py#L38
  • the spawn multiprocessing context uses pickle to copy the state of the parent process into the newly spawned one
    • this is hinted at by a similar issue in DataLoader: https://github.com/pytorch/pytorch/issues/87688
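The spawn-vs-fork point above can be demonstrated without torch or torchrl at all. The sketch below uses a stand-in class of my own (`MPSLikeStorage`, not a real torch type) whose `__reduce__` raises the same message as the traceback, mimicking an MPS storage that cannot be shared; a spawn context trips over it inside `Process.start()` (the same call site as in the torchrl traceback), while a fork context, which inherits memory instead of pickling, does not:

```python
import multiprocessing as mp

class MPSLikeStorage:
    """Stand-in for an MPS-backed storage: pickling it raises, like
    torch's _share_filename_ does for non-CPU storages."""
    def __reduce__(self):
        raise RuntimeError("_share_filename_: only available on CPU")

def worker(payload):
    # The child never needs to do anything; the failure (if any)
    # happens in the parent while serializing the arguments.
    pass

if __name__ == "__main__":
    # spawn serializes the Process object, including its args, with
    # pickle before the child exists, so the error surfaces in the
    # parent at start().
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=worker, args=(MPSLikeStorage(),))
    try:
        p.start()
    except RuntimeError as e:
        print(f"spawn failed: {e}")

    # fork duplicates the parent's memory, so the same payload crosses
    # over without __reduce__ ever being called (POSIX only).
    ctx = mp.get_context("fork")
    p = ctx.Process(target=worker, args=(MPSLikeStorage(),))
    p.start()
    p.join()
    print("fork succeeded")
```

This matches the traceback: the failure occurs in `reduction.dump(process_obj, fp)` during `proc.start()`, before any worker runs, which is why the collector crashes at construction time rather than during sampling.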

Checklist

  • [X] I have checked that there is no similar issue in the repo
  • [X] I have read the documentation
  • [X] I have provided a minimal working example to reproduce the bug

LCarmi · Mar 18 '25 09:03