[BUG] MultiDiscreteTensorSpec .to("mps")
Describe the bug
There is a bug when an EnvBase whose spec contains a MultiDiscreteTensorSpec is moved to mps (I have not yet tested whether the issue persists with cuda).
To Reproduce
The example below should suffice to reproduce the error:
from torchrl.data import MultiDiscreteTensorSpec

# A spec created on the default (CPU) device fails when moved to MPS
nodes = MultiDiscreteTensorSpec((1, 8))
nodes.to("mps")
File "/Users/dtsaras/Documents/CS/rl/torchrl/envs/common.py", line 2812, in to
self.__dict__["_input_spec"] = self.input_spec.to(device).lock_()
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/dtsaras/Documents/CS/rl/torchrl/data/tensor_specs.py", line 3715, in to
kwargs[key] = value.to(dest)
^^^^^^^^^^^^^^
File "/Users/dtsaras/Documents/CS/rl/torchrl/data/tensor_specs.py", line 3715, in to
kwargs[key] = value.to(dest)
^^^^^^^^^^^^^^
File "/Users/dtsaras/Documents/CS/rl/torchrl/data/tensor_specs.py", line 2961, in to
return self.__class__(
^^^^^^^^^^^^^^^
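Since I have only been able to test on mps, a small sweep like the one below (just a sketch; the only torchrl call is the same constructor as in the repro) could tell whether cuda is affected as well:

import torch
from torchrl.data import MultiDiscreteTensorSpec

# Collect whichever accelerator backends are available on this machine
devices = ["cpu"]
if torch.cuda.is_available():
    devices.append("cuda")
if torch.backends.mps.is_available():
    devices.append("mps")

for device in devices:
    spec = MultiDiscreteTensorSpec((1, 8))
    try:
        spec.to(device)
        print(f"{device}: ok")
    except Exception as exc:
        print(f"{device}: {type(exc).__name__}: {exc}")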
Expected behavior
The spec (and any environment that contains it) should move to the accelerator device without raising.
System info
Library installed from a local source checkout (per the traceback paths); versions printed with:

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

Output: 2024.4.24 1.26.4 3.11.9 (main, Apr 19 2024, 11:43:47) [Clang 14.0.6 ] darwin
Reason and Possible fixes
The traceback points at the spec's to() method (tensor_specs.py, line 2961), where the spec is rebuilt with self.__class__(...); the failure presumably comes from one of the arguments passed to that constructor, but I have not pinned down the exact cause.
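A minimal sketch of the suspected pattern, based only on the traceback (the attribute name nvec and the constructor signature below are assumptions for illustration, not the actual torchrl code): when to() rebuilds the spec via self.__class__(...), every stored tensor has to be moved to the destination along with the device argument, otherwise the new instance ends up with mismatched devices.

import torch

class ToySpec:
    # Hypothetical stand-in for a spec that stores a tensor attribute; not torchrl code.
    def __init__(self, nvec, device="cpu"):
        self.device = torch.device(device)
        # If a caller forgets to move nvec, it can stay on CPU while self.device
        # already says "mps", and later consistency checks blow up.
        self.nvec = torch.as_tensor(nvec).to(self.device)

    def to(self, dest):
        # Rebuild on the destination device, moving every stored tensor explicitly.
        return self.__class__(self.nvec.to(dest), device=dest)

spec = ToySpec((1, 8))
if torch.backends.mps.is_available():
    print(spec.to("mps").nvec.device)  # expected: mps:0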
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)