cellxgene-census

Fix ml.ExperimentDataPipe for Distributed Training

Open JulesGM opened this issue 9 months ago • 11 comments

The current form of ml.ExperimentDataPipe breaks in distributed training when the samples can't be split between GPUs in a way that gives every GPU the same number of batches. The mainstream ways to train models are synchronous (like DDP), so some GPUs end up waiting at a barrier for the new epoch while others are still finishing the last batch of the previous epoch.

My / our interpretation is that the problem comes from chunks being created on the indices before the data is split between the GPUs. In this PR, the indices are first split between the GPUs, and the chunks are then created from each GPU's share. This limits the discrepancy in the number of samples between GPUs to at most 1 sample.
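To illustrate the ordering described above, here is a minimal sketch, not the PR's actual code. The function name `split_then_chunk` and the strided split are assumptions made for illustration; the PR may split the indices differently.

```python
from typing import List, Sequence


def split_then_chunk(
    indices: Sequence[int], world_size: int, rank: int, chunk_size: int
) -> List[List[int]]:
    """Hypothetical helper: give this rank its share first, then chunk it."""
    # Strided split (an assumption): rank r takes positions r, r + world_size, ...
    # With this split, per-rank sizes differ by at most one sample.
    rank_indices = list(indices[rank::world_size])
    # Chunking only sees this rank's ids, so it needs neither rank nor world_size.
    return [
        rank_indices[i : i + chunk_size]
        for i in range(0, len(rank_indices), chunk_size)
    ]


# 10 samples over 3 ranks: count the samples each rank ends up with.
sizes = [
    sum(len(c) for c in split_then_chunk(range(10), world_size=3, rank=r, chunk_size=2))
    for r in range(3)
]
print(sizes)  # [4, 3, 3]
```

Because the split happens before chunking, the chunk size no longer affects how evenly samples are spread across ranks.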

This would still break, however, with a batch size of 1, since some GPUs would then have one more batch than others. To compensate for this, we add a drop_last option, which drops at most world_size - 1 samples so that each GPU has the same amount of data. If drop_last is false, the GPUs that are missing one data point pad their data with their own first data point. This takes inspiration from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L83 and https://github.com/huggingface/accelerate/blob/v0.28.0/src/accelerate/data_loader.py#L345 (which do something similar, but for batch sizes).
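The two behaviors of drop_last could look roughly like the sketch below. It is written under assumptions: `equalize_rank_indices` is a hypothetical name, the strided split is illustrative, and the error raised for an empty rank is a choice made here, not something the PR specifies.

```python
import math
from typing import List, Sequence


def equalize_rank_indices(
    indices: Sequence[int], world_size: int, rank: int, drop_last: bool
) -> List[int]:
    """Hypothetical helper: make every rank hold the same number of samples."""
    # Illustrative strided split; sizes differ by at most one sample.
    rank_indices = list(indices[rank::world_size])
    if drop_last:
        # Truncate to the common floor; at most world_size - 1 samples are lost.
        target = len(indices) // world_size
        return rank_indices[:target]
    # Otherwise pad short ranks up to the common ceiling.
    target = math.ceil(len(indices) / world_size)
    if len(rank_indices) < target:
        if not rank_indices:
            # Assumption: a rank with no samples cannot pad from itself.
            raise ValueError("rank has no samples to pad with")
        # Pad with this rank's own first data point.
        rank_indices += [rank_indices[0]] * (target - len(rank_indices))
    return rank_indices


dropped = [equalize_rank_indices(range(10), 3, r, drop_last=True) for r in range(3)]
padded = [equalize_rank_indices(range(10), 3, r, drop_last=False) for r in range(3)]
print(dropped)  # [[0, 3, 6], [1, 4, 7], [2, 5, 8]]
print(padded)   # [[0, 3, 6, 9], [1, 4, 7, 1], [2, 5, 8, 2]]
```

In both modes every rank ends with the same sample count, so the number of batches per rank matches for any batch size.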

A side effect of this is that the partition computation function becomes independent of the world size and rank, as it only works on its current GPU's ids (which have been computed specifically for that GPU earlier in the code).

This change appears to preserve the behavior of shuffle: the per-GPU chunks are still shuffled, and the per-GPU data is still shuffled, so both the composition of the batches and their order remain fairly random.
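As a rough sketch of why shuffling survives the change, the two shuffle levels can be applied to one GPU's ids alone. The name `shuffled_order` and the use of a seeded `random.Random` are assumptions for illustration; the real pipeline's shuffle may group and mix chunks differently.

```python
import random
from typing import List, Sequence


def shuffled_order(rank_indices: Sequence[int], chunk_size: int, seed: int) -> List[int]:
    """Hypothetical helper: shuffle chunks, then shuffle rows inside each chunk."""
    rng = random.Random(seed)
    # Build chunks from this rank's ids only.
    chunks = [
        list(rank_indices[i : i + chunk_size])
        for i in range(0, len(rank_indices), chunk_size)
    ]
    rng.shuffle(chunks)  # randomize the order of the chunks
    order: List[int] = []
    for chunk in chunks:
        rng.shuffle(chunk)  # randomize the rows within a chunk
        order.extend(chunk)
    return order


rank_ids = [0, 3, 6, 9, 12, 15, 18]
order = shuffled_order(rank_ids, chunk_size=3, seed=0)
```

The result is a permutation of the rank's own ids, so shuffling never moves samples between GPUs and the per-GPU counts stay equal.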

JulesGM, Apr 29 '24 23:04