ray_lightning
Distributed training performance slowdown when resuming from a checkpoint.
I am using ray_lightning to distribute training across an 8-node Ray cluster with GPUs. Training performance slows down significantly (by a factor of 2-3) when resuming from a checkpoint. When I start a new training run it takes an average of 35 minutes per epoch, but when I restart training from a previous checkpoint it takes over 90 minutes per epoch. This behavior is very consistent. I am using the CLI to submit the job to the remote Ray cluster. To isolate the problem I also ran a multi-node distributed training using only PyTorch Lightning, and the issue did not occur there: resuming from a checkpoint took about the same time as a fresh training run.
I have the code here to reproduce the example, along with instructions to run it.
https://gist.github.com/subhashbylaiah/6b403339cfaf619c59af403f9740bf29
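For context, the training run is set up roughly like this (a minimal sketch, not the actual gist code; `model`, `dm`, and the checkpoint path are placeholders, and I am assuming ray_lightning 0.2.0's `RayPlugin` API with PyTorch Lightning 1.5):

```python
import pytorch_lightning as pl
from ray_lightning import RayPlugin  # ray_lightning 0.2.0 API

# One Ray worker per node, each with a GPU; fp16 trainer precision in both runs.
plugin = RayPlugin(num_workers=8, use_gpu=True)

trainer = pl.Trainer(
    max_epochs=20,
    precision=16,
    plugins=[plugin],
)

# Fresh run (~35 min/epoch); model/dm stand in for the gist's module and datamodule.
trainer.fit(model, datamodule=dm)

# Resumed run (>90 min/epoch with ray_lightning; no slowdown with plain PTL DDP):
# trainer.fit(model, datamodule=dm, ckpt_path="path/to/last.ckpt")
```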
From my analysis, as I have also noted in the reproduction instructions, the issue appears to be tied to the precision of the input images:
- When the input tensors are float16, performance slows down when resuming from a prior checkpoint (there is no issue when training from scratch).
- When the input tensors are float32, performance is good whether training from a checkpoint or from scratch.
Note that the Trainer precision is fp16 in both cases.
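A minimal sketch of the two input configurations compared above (the transform chain is an illustration, not the code from the gist):

```python
import torch
from torchvision import transforms

# Variant 1: inputs cast to float16 before they reach the model.
# Slow only when resuming from a checkpoint; fine when training from scratch.
transform_fp16 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.to(torch.float16)),
])

# Variant 2: inputs left as float32 (ToTensor already produces float32).
# Normal speed both from scratch and when resuming.
transform_fp32 = transforms.Compose([
    transforms.ToTensor(),
])

# In both variants the Trainer itself still runs with precision=16.
```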
Library versions
ray==1.12.1
ray_lightning==0.2.0
pytorch_lightning==1.5.10
pytorch==1.11.0
Hi @subhashbylaiah, I see the assumption here is that the model uses float16 as its datatype. In the DDP source code, torch.float is used (see the dtype reference at https://pytorch.org/docs/stable/tensors.html#data-types). The relevant code is here:
https://github.com/ray-project/ray_lightning/blob/6aed848f757a03c03166c1a9bddfeea5153e7b90/ray_lightning/ray_ddp.py#L377-L386
I am going to test this assumption and will keep you posted here.
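One quick way to check this (a hypothetical snippet, not ray_lightning internals; `MyLightningModule` and the checkpoint path are placeholders) is to compare the parameter dtypes of a freshly constructed model with the dtypes of the tensors stored in the checkpoint:

```python
import torch

# Dtypes of the parameters in a freshly constructed model (placeholder class name).
model = MyLightningModule()
print("model parameter dtypes:", {p.dtype for p in model.parameters()})

# Dtypes of the tensors stored in the checkpoint; PyTorch Lightning checkpoints
# keep the weights under the "state_dict" key.
ckpt = torch.load("path/to/last.ckpt", map_location="cpu")
print("checkpoint tensor dtypes:",
      {t.dtype for t in ckpt["state_dict"].values() if torch.is_tensor(t)})
```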
I do see this issue. On the other hand, when using plain DDP there is no extra memory overhead.

@amogkam, my current guess for this issue is as follows: the trainer is using the delayed GPU accelerator, and the checkpoint is a GPU checkpoint. When resuming, it loads the GPU checkpoint, and the slowdown might also be due to loading the GPU checkpoint onto the CPU and then moving it back to the GPU.
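A hypothetical way to check this guess (the checkpoint path is a placeholder) is to look at which device the checkpoint tensors come back on: torch.load restores each tensor to the device it was saved from unless map_location is given.

```python
import torch

# Default load: each tensor is restored to the device it was saved from (GPU here),
# so resuming may pay for deserialization plus extra device transfers.
ckpt = torch.load("path/to/last.ckpt")
print({t.device for t in ckpt["state_dict"].values() if torch.is_tensor(t)})

# Loading onto the CPU instead keeps the checkpoint off the GPU until the
# plugin/accelerator decides where the weights should live.
ckpt_cpu = torch.load("path/to/last.ckpt", map_location="cpu")
print({t.device for t in ckpt_cpu["state_dict"].values() if torch.is_tensor(t)})
```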
Thanks @JiahaoYao for checking on this issue. Can you please confirm whether you are able to reproduce it with the example code?