openpi
Disable sharding for stats computation
When the compute_norm_stats.py script runs on a machine with multiple GPUs, it fails with a sharding error: the batch has size 1, so it cannot be split across the devices.
You can work around this by prefixing the command with the environment variable CUDA_VISIBLE_DEVICES=1, but this PR makes it more convenient by passing a SingleDeviceSharding to the data loader constructor, so the stats computation runs on a single device.
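The failure comes from a divisibility requirement. Sharding an array's leading (batch) dimension across N devices only works when that dimension is divisible by N, and a batch of size 1 is not divisible by 2 or more. A single-device sharding has no such constraint. The sketch below illustrates this using only the public `jax.sharding.SingleDeviceSharding` class and `jax.device_put`. It does not reproduce the openpi data loader, whose constructor is not shown here. The helper names `can_shard_evenly`, `single_device_sharding` and `place_batch` are hypothetical.

```python
import jax
import numpy as np
from jax.sharding import Sharding, SingleDeviceSharding


def can_shard_evenly(batch_size: int, num_devices: int) -> bool:
    """Hypothetical helper: splitting dim 0 across devices needs divisibility."""
    return batch_size % num_devices == 0


def single_device_sharding() -> SingleDeviceSharding:
    """Hypothetical helper: keep the whole array on the first visible device."""
    return SingleDeviceSharding(jax.devices()[0])


def place_batch(batch: np.ndarray, sharding: Sharding) -> jax.Array:
    """Hypothetical helper: copy a host batch to device(s) per `sharding`."""
    return jax.device_put(batch, sharding)


batch = np.zeros((1, 8), dtype=np.float32)  # batch size 1, as in the stats script

# Batch-sharding over all devices is only valid when 1 % device_count == 0,
# i.e. when exactly one device is visible.
print("batch-shardable:", can_shard_evenly(batch.shape[0], jax.device_count()))

# A single-device sharding accepts the batch no matter how many GPUs are visible.
arr = place_batch(batch, single_device_sharding())
print(arr.shape, arr.sharding)
```

Pinning the stats computation to one device gives up data parallelism. That seems acceptable here because the batch size is 1 and there is nothing to split.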