openpi

Disable sharding for stats computation

Open tlpss opened this issue 8 months ago • 0 comments

When running the compute_norm_stats.py script on a device with multiple GPUs, you get a sharding error because the batch (with size=1) cannot be split across the devices.

You can work around this by prepending the environment variable CUDA_VISIBLE_DEVICES=1 to the command so that only one GPU is visible, but this PR makes it a bit more convenient by specifying SingleDeviceSharding in the dataloader constructor.
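For illustration, here is a minimal sketch of what placing a size-1 batch with SingleDeviceSharding looks like in plain JAX. It is not the code from this PR: the variable names and the batch shape are made up, and how the sharding is passed to the dataloader constructor in openpi is not shown here.

```python
import jax
import numpy as np
from jax.sharding import SingleDeviceSharding

# Pick one device explicitly; on a multi-GPU machine this is the first GPU.
device = jax.devices()[0]

# A sharding that places the whole array on that single device,
# so the batch dimension never has to be divided across devices.
sharding = SingleDeviceSharding(device)

# Hypothetical batch with batch size 1 (the shape is illustrative only).
batch = np.ones((1, 8), dtype=np.float32)

# Place the batch according to the single-device sharding.
placed = jax.device_put(batch, sharding)

print(placed.shape, placed.sharding)
```

Because the array lives entirely on one device, a batch size of 1 is fine no matter how many GPUs are visible to the process.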

tlpss, Apr 17 '25 07:04