[BUG] It's not clear how to call an advantage module with batched envs and pixel observations.
Describe the bug
When you get a tensordict rollout out of a collector, with pixel observations of shape (N_envs, N_steps, C, H, W), and you want to apply an advantage module that starts with conv2d layers:
- directly applying the module will crash, with the `conv2d` layer complaining about the input size, e.g. `RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [2, 128, 4, 84, 84]`
- flattening the tensordict first with `rollout.reshape(-1)` so that it has shape `[B, C, H, W]` and then calling the advantage module will run, but it issues the warning `torchrl/objectives/value/advantages.py:99: UserWarning: Got a tensordict without a time-marked dimension, assuming time is along the last dimension.`, leaving you unsure whether the advantages were computed correctly.
So it's not clear how one should proceed.
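For reference, here is a minimal plain-PyTorch sketch of the first failure. It is not the MWE from the report: the toy `value_net` and the tensor sizes are made up for illustration, and no TorchRL code is involved. `nn.Conv2d` only accepts 3D or 4D inputs, so a 5D `[N_envs, N_steps, C, H, W]` tensor is rejected while a 4D slice goes through:

```python
import torch
import torch.nn as nn

# Toy stand-in for a value network whose first layer is a conv2d (sizes are arbitrary).
value_net = nn.Sequential(
    nn.Conv2d(4, 8, kernel_size=3, stride=2),  # 32x32 -> 15x15
    nn.ReLU(),
    nn.Flatten(),                              # flattens everything after dim 0
    nn.Linear(8 * 15 * 15, 1),
)

pixels = torch.randn(2, 16, 4, 32, 32)  # [N_envs, N_steps, C, H, W]

# Calling the network on the 5D rollout fails inside conv2d.
try:
    value_net(pixels)
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)  # e.g. Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [2, 16, 4, 32, 32]

# A 4D slice (one environment's steps, [N_steps, C, H, W]) is accepted.
print(value_net(pixels[0]).shape)  # torch.Size([16, 1])
```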
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)
Good point there! Regarding reshaping: you should reshape and `refine_names`. I believe the last dim will still be time-compliant (but you need to make sure you have truncated signals at the last time step of each environment's segment). Other than that, we could consider falling back on vmap / first-class dimensions whenever this situation is encountered. I will give it a look and ping you once it's on its way, as usual.
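A sketch of why the truncation signals matter when flattening. `gae_last_dim` is a hypothetical plain re-implementation of GAE with time along the last dim, written for this illustration and not TorchRL's code; it assumes the usual convention that `terminated` disables bootstrapping from the next value while `done` (terminated or truncated) cuts the backward recursion. Flattening `[*B, T]` into one long time dim reproduces the per-trajectory advantages only if the last step of every segment is marked as truncated:

```python
import torch


def gae_last_dim(reward, value, next_value, terminated, done, gamma=0.99, lmbda=0.95):
    """Plain GAE with time along the last dim (hypothetical helper, not TorchRL's code).

    terminated: true terminal transitions (no bootstrapping from next_value)
    done:       terminated OR truncated (cuts the backward recursion)
    """
    not_terminated = (~terminated).to(reward.dtype)
    not_done = (~done).to(reward.dtype)
    delta = reward + gamma * next_value * not_terminated - value
    running = torch.zeros_like(delta[..., -1])
    advantages = []
    for t in reversed(range(delta.shape[-1])):
        running = delta[..., t] + gamma * lmbda * not_done[..., t] * running
        advantages.append(running)
    return torch.stack(advantages[::-1], dim=-1)


def flat(x):
    return x.reshape(-1)


torch.manual_seed(0)
batch, T = (2, 3), 5                      # arbitrary leading dims *B, then time
reward, value, next_value = (torch.randn(*batch, T) for _ in range(3))
terminated = torch.zeros(*batch, T, dtype=torch.bool)
terminated[0, 1, 2] = True                # one episode ends mid-segment
done = terminated.clone()

# Reference: time is the last dim of [*B, T]; each leading index is its own trajectory.
adv_ref = gae_last_dim(reward, value, next_value, terminated, done)

# Flattened [*B, T] -> [prod(B) * T], with the last step of every segment marked as truncated.
truncated = torch.zeros_like(done)
truncated[..., -1] = True
adv_flat = gae_last_dim(flat(reward), flat(value), flat(next_value),
                        flat(terminated), flat(done | truncated)).reshape(*batch, T)

# Same flattening without the truncation markers: segments leak into each other.
adv_leaky = gae_last_dim(flat(reward), flat(value), flat(next_value),
                         flat(terminated), flat(done)).reshape(*batch, T)

print(torch.allclose(adv_flat, adv_ref))   # True
print(torch.allclose(adv_leaky, adv_ref))  # False
```

Without the markers, the recursion at the last step of one segment picks up the advantage from the first step of the following segment, which is exactly the kind of silent error the warning leaves you worried about.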
@vmoens in some cases the env data may have an arbitrary batch size (*B) before the time dimension.
Is the current approach, before we land something like https://github.com/pytorch-labs/tensordict/pull/525, to flatten all these dims into one, making sure to add terminations when doing so?
I don't think so; as I said in my answer, the proper approach should be to vmap over the leading dims up to the time dim. Wdyt?
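A minimal sketch of the vmap idea using plain `torch.vmap` (PyTorch 2.x). It only illustrates the concept and is not the mechanism from the tensordict PR; `discounted_return_1d` and `vmap_over_leading_dims` are hypothetical helpers. A function written for a single `[T]` trajectory is wrapped in one `vmap` per leading dim, so no flattening or extra truncation flags are needed:

```python
import torch


def discounted_return_1d(reward, done, gamma=0.99):
    """Reward-to-go for a single trajectory of shape [T] (hypothetical helper)."""
    assert reward.dim() == 1, "written for one trajectory only"
    not_done = (~done).to(reward.dtype)
    running = torch.zeros_like(reward[0])
    returns = []
    for t in reversed(range(reward.shape[0])):
        running = reward[t] + gamma * not_done[t] * running
        returns.append(running)
    return torch.stack(returns[::-1])


def vmap_over_leading_dims(fn, n_leading):
    """Wrap fn in one torch.vmap per leading batch dim, so fn only ever sees [T] tensors."""
    for _ in range(n_leading):
        fn = torch.vmap(fn)
    return fn


torch.manual_seed(0)
reward = torch.randn(2, 3, 7)        # [*B, T] with B = (2, 3)
done = torch.rand(2, 3, 7) < 0.2

batched = vmap_over_leading_dims(discounted_return_1d, n_leading=reward.dim() - 1)
returns = batched(reward, done)

# Reference: explicit Python loop over every leading index.
expected = torch.stack([
    torch.stack([discounted_return_1d(reward[i, j], done[i, j]) for j in range(3)])
    for i in range(2)
])
print(torch.allclose(returns, expected))  # True
```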
Somehow, in the PPO example, the advantage module is called on the rollout with its full batch shape https://github.com/pytorch/rl/blob/147de71d090d5705182bfabd24a99f3b2ee4cec9/examples/ppo/ppo.py#L103 and the conv2d layer doesn't complain.
https://github.com/pytorch/rl/blob/147de71d090d5705182bfabd24a99f3b2ee4cec9/examples/ppo/utils.py#L341
I also managed to reproduce this with TorchRL's ConvNet and MLP modules, and my advantage module now runs without reshaping.
I'm sending more details to compare the settings.
Okay, so TorchRL's ConvNet actually flattens the batch dims before running the forward pass and then unflattens them afterwards.
Maybe this could be made clearer to users, so that when designing custom models they know they have to do something similar.
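For custom models, something along these lines reproduces the flatten/unflatten behaviour. `FoldLeadingDims` is a hypothetical wrapper written for this example, not TorchRL's `ConvNet` implementation; it folds all leading dims into one batch dim, runs the wrapped module, and restores them on the output:

```python
import torch
import torch.nn as nn


class FoldLeadingDims(nn.Module):
    """Run a module expecting [N, C, H, W] on inputs of shape [*B, C, H, W] (hypothetical)."""

    def __init__(self, module: nn.Module, n_feature_dims: int = 3):
        super().__init__()
        self.module = module
        self.n_feature_dims = n_feature_dims  # trailing dims the module handles (C, H, W)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        split = x.dim() - self.n_feature_dims
        batch_shape = x.shape[:split]                     # *B (may be empty)
        flat = x.reshape(-1, *x.shape[split:])            # [prod(B), C, H, W]
        out = self.module(flat)                           # [prod(B), ...]
        return out.reshape(*batch_shape, *out.shape[1:])  # [*B, ...]


encoder = FoldLeadingDims(nn.Sequential(nn.Conv2d(4, 8, 3, stride=2), nn.ReLU(), nn.Flatten()))
pixels = torch.randn(2, 16, 4, 32, 32)  # [N_envs, N_steps, C, H, W]
print(encoder(pixels).shape)            # torch.Size([2, 16, 1800])
print(encoder(pixels[0, 0]).shape)      # torch.Size([1800]), an unbatched input also works
```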
Otherwise, vmapping would be the way to go. I'm just concerned about its memory requirements compared to flattening the tensordict.
@skandermoalla Looking back at this comment, I wonder: why would vmap have higher memory requirements?
I'm not very familiar with vmap, but does the memory taken by the model weights stay the same when you vmap it?
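A small sketch comparing the two options on a toy conv value network (the network and sizes are assumptions). As far as I understand `torch.vmap` semantics, the module's parameters are closed over rather than passed as vmapped inputs, so they are treated as unbatched and shared by every slice; only the input carries the extra dim. The snippet checks that both options give the same values; it does not measure memory:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
value_net = nn.Sequential(
    nn.Conv2d(4, 8, kernel_size=3, stride=2),  # 32x32 -> 15x15
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 15 * 15, 1),
)
pixels = torch.randn(2, 16, 4, 32, 32)  # [N_envs, N_steps, C, H, W]

# Option 1: fold [N_envs, N_steps] into a single batch dim, then unfold the output.
folded = value_net(pixels.reshape(-1, 4, 32, 32)).reshape(2, 16, 1)

# Option 2: vmap over the N_envs dim; inside, the network sees [N_steps, C, H, W].
# The parameters are not vmapped inputs, so vmap treats them as unbatched (shared).
vmapped = torch.vmap(value_net)(pixels)

print(vmapped.shape)                               # torch.Size([2, 16, 1])
print(torch.allclose(vmapped, folded, atol=1e-6))  # True
```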