Vectorized/GPU Sim Evaluation / ManiSkill 3 Port
- [x] Add GPU sim environments for bridge dataset (in ms3 repo: https://github.com/haosulab/ManiSkill/pull/536)
- [x] Modify Octo to perform batched inference
- [x] Modify Octo to accept torch tensors / DLPack-converted JAX arrays (see the DLPack sketch after this list)
- [x] Modify RT1-X to do batched inference
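For reference, the torch-to-JAX handoff is essentially a DLPack conversion so observations rendered by the GPU sim never leave the device. A minimal sketch of that pattern (function name, shapes, and device handling here are illustrative, not the exact code in this PR):

```python
# Illustrative sketch (not the exact code in this PR): moving a batch of
# torch tensors into JAX without a host round-trip, via DLPack.
import jax
import jax.dlpack
import torch
import torch.utils.dlpack


def torch_to_jax(batch: torch.Tensor) -> jax.Array:
    """Device-to-device conversion of a torch tensor to a JAX array via DLPack."""
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(batch.contiguous()))


# e.g. a batch of 72 RGB observations from the vectorized envs
device = "cuda" if torch.cuda.is_available() else "cpu"
obs = torch.zeros((72, 256, 256, 3), dtype=torch.uint8, device=device)
jax_obs = torch_to_jax(obs)
print(jax_obs.shape, jax_obs.dtype)  # (72, 256, 256, 3) uint8
```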
Closes #35 #36
Some notes on the changes:
- Removed tf image resize in favor of jax image resize, which permits vmapping and avoids an issue where tf allocates a lot of GPU memory when given a jax array instead of a numpy array (see the resize sketch after this list).
- 72 parallel envs plus model inference use about 18-20 GB of GPU memory. For envs with a maximum of 60 steps per episode, this completes 72 episodes of octo-small evaluation in about 1 minute vs. 8 minutes (depending on GPU), roughly an 8x speedup.
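For reference, the batched resize roughly looks like the sketch below, vmapping `jax.image.resize` over the batch dimension (shapes, dtype, and interpolation method are placeholders, not necessarily what the evaluation code uses):

```python
# Sketch of batched image resizing with jax.image.resize + vmap
# (illustrative shapes/method; not the exact call used in this PR).
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=(1,))
def resize_batch(images: jax.Array, size: tuple[int, int]) -> jax.Array:
    """Resize a (B, H, W, C) batch of images to (B, size[0], size[1], C)."""
    resize_one = lambda img: jax.image.resize(
        img, (size[0], size[1], img.shape[-1]), method="bilinear"
    )
    return jax.vmap(resize_one)(images)


obs = jnp.zeros((72, 480, 640, 3), dtype=jnp.float32)
resized = resize_batch(obs, (256, 256))
print(resized.shape)  # (72, 256, 256, 3)
```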