[BUG] rollout with non-uniform tensors no longer supported
Describe the bug
There seems to be a bug when performing a rollout whose tensordicts do not all have the same shape. The rollout used to produce a LazyStackedTensorDict, but it currently fails to stack them. It's possible to work around the issue with the `set_lazy_legacy(True)` context manager.
I was using the versions below:
- TorchRL: git+ssh://git@github.com/pytorch/rl.git@7cd6a181020ba2d0ec420d7c3a92c8d689be6bd1#egg=torchrl
- TensorDict: tensordict-nightly==2024.3.20
To Reproduce
Steps to reproduce the behavior.
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the markdown code blocks for both code and stack traces.
```python
import torchrl
import tensordict

env = make_env()  # some env that produces sequential, non-uniformly shaped data

# this raises a stacking error (it works in older versions)
env.rollout()

# this works
with tensordict.utils.set_lazy_legacy(True):
    env.rollout()
```
```
Traceback (most recent call last):
  File "/Users/dtsaras/Documents/CS/circuit_gen/src/rl/mcts_policy.py", line 502, in simulate
    rollout = self.env.rollout(
  File "/opt/homebrew/Caskroom/miniconda/base/envs/c_gen/lib/python3.10/site-packages/torchrl/envs/common.py", line 2567, in rollout
    out_td = torch.stack(tensordicts, len(batch_size), out=out)
  File "/opt/homebrew/Caskroom/miniconda/base/envs/c_gen/lib/python3.10/site-packages/tensordict/base.py", line 388, in __torch_function__
    return TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniconda/base/envs/c_gen/lib/python3.10/site-packages/tensordict/_torch_func.py", line 490, in _stack
    raise RuntimeError(
RuntimeError: The shapes of the tensors to stack is incompatible.
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
```
Expected behavior
The rollout is expected to return a LazyStackedTensorDict.
System info
Describe the characteristic of your environment:
- Describe how the library was installed (pip, source, ...)
- Python version 3.10
- Versions of any other relevant libraries
```python
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
```
Additional context
Add any other context about the problem here.
Reason and Possible fixes
If you know or suspect the reason for this bug, paste the code lines and suggest modifications.
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [ ] I have provided a minimal working example to reproduce the bug (required) [there is no such environment to provide one]
I missed including the versions: 0.4.0 1.26.2 3.10.13 (main, Sep 11 2023, 08:16:02) [Clang 14.0.6] darwin
Yes, unfortunately we deprecated that stacking behaviour because it required loads of checks and also led to some surprising behaviours from time to time!
You can also change this by passing the `contiguous=False` argument to `rollout`.
Note to self: we should make it so that contiguous=False calls LazyStackedTensorDict.maybe_dense_stack to get the "best" stack possible
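For reference, here is a minimal sketch of that workaround (assuming an env like the one above; in the torchrl releases I've checked, the rollout keyword is spelled `return_contiguous`, which the comment above refers to as `contiguous`):

```python
import torchrl

env = make_env()  # the same env with non-uniformly shaped steps as above

# Ask rollout not to densely stack the per-step tensordicts.
# `max_steps=100` is purely illustrative.
rollout = env.rollout(max_steps=100, return_contiguous=False)
```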
Ok, thanks, that's helpful! Also, I wonder if we could create dense stacks of the non-uniformly shaped tensors by batching them as nested tensors instead?
Yes, we can do that; I don't think we have the util to stack things as nested tensors yet, though. The main issue we'll be facing is that nested tensors can only be stacked along dim 0, while rollout stacks along the last+1 dim, so it will only work for non-batched envs.
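To illustrate the dim-0 constraint, here is a standalone sketch (plain PyTorch, not torchrl code):

```python
import torch

# Nested tensors batch variable-length tensors along dim 0 only; the ragged
# entries cannot be stacked along a later dimension.
a = torch.randn(3, 5)
b = torch.randn(7, 5)
nt = torch.nested.nested_tensor([a, b])
print(nt.size(0))  # 2: the stacking dimension is forced to be dim 0
```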
I wonder, when we encounter such data, whether we could automatically pad it and return a padding mask along the dynamic dimensions. This could be the fallback when the contained tensors are nested tensors. If you think that's an option, I can look into implementing something like this after the NeurIPS deadline.
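A rough sketch of what such a padding fallback could look like (`pad_with_mask` is a hypothetical helper, not part of tensordict or torchrl):

```python
import torch
from torch.nn import functional as F

def pad_with_mask(tensors, pad_value=0.0):
    """Pad tensors that differ along their leading dim and return a validity mask."""
    max_len = max(t.shape[0] for t in tensors)
    padded, masks = [], []
    for t in tensors:
        extra = max_len - t.shape[0]
        # pad only the leading (dynamic) dimension
        pad = [0, 0] * (t.dim() - 1) + [0, extra]
        padded.append(F.pad(t, pad, value=pad_value))
        masks.append(torch.arange(max_len) < t.shape[0])
    return torch.stack(padded), torch.stack(masks)

data, mask = pad_with_mask([torch.randn(3, 5), torch.randn(7, 5)])
print(data.shape, mask.shape)  # torch.Size([2, 7, 5]) torch.Size([2, 7])
```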
Another option is to create the nested tensor along dim 0 but "lie" with tensordict and present it as a transposed version; we'd need to patch get, set, etc.
Closing this, as `contiguous=False` now does a best-intention attempt: it stacks what can be stacked and returns non-contiguous stacks everywhere else.
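For anyone landing here later, a small sketch of that "best intention" stacking (assuming tensordict's `LazyStackedTensorDict.maybe_dense_stack`, which the note-to-self above mentions):

```python
import torch
from tensordict import TensorDict, LazyStackedTensorDict

# Entries with matching shapes are stacked densely; here the "obs" shapes
# differ, so the result falls back to a lazy (non-contiguous) stack instead
# of raising.
td0 = TensorDict({"obs": torch.zeros(3)}, batch_size=[])
td1 = TensorDict({"obs": torch.zeros(4)}, batch_size=[])
out = LazyStackedTensorDict.maybe_dense_stack([td0, td1], dim=0)
print(type(out))  # LazyStackedTensorDict
```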