
[BUG] load_checkpoint should load directly to gpu

Open · stas00 opened this issue 3 years ago

Describe the bug

Currently, HF Transformers integration users can finetune a model and save a checkpoint with a given set of resources. However, resuming from that same checkpoint requires much more peak CPU memory - which can be huge for large models - and this prevents users from resuming their finetuning. (The current workaround is to add a huge swap file.)

To Reproduce

I reproduced it as part of this bug report: https://github.com/huggingface/transformers/issues/17258

The full reproduction steps are here: https://github.com/huggingface/transformers/issues/17258#issuecomment-1133492187

I also verified that torch.load doesn't load everything into CPU memory when map_location="cpu": https://github.com/huggingface/transformers/issues/17258#issuecomment-1133522602
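
For anyone who wants to repeat that check, here is a minimal sketch (assuming psutil is installed; the checkpoint path is a placeholder) that prints resident memory around a torch.load call:

```python
# Minimal sketch: watch resident memory around torch.load.
# The checkpoint path is a placeholder - point it at a real shard to test.
import os
import psutil
import torch

proc = psutil.Process(os.getpid())

def rss_gb() -> float:
    # Resident set size of this process, in GiB.
    return proc.memory_info().rss / 2**30

print(f"RSS before load: {rss_gb():.2f} GiB")
state = torch.load("path/to/checkpoint.pt", map_location="cpu")
print(f"RSS after load:  {rss_gb():.2f} GiB")
```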

and I tracked the issue down to deepspeed loading those potentially huge ZeRO checkpoints (70GB for gpt-j-6B) into CPU memory first:

https://github.com/microsoft/DeepSpeed/blob/5208eb73da5269034ded69c4dd7c4bff81df81e7/deepspeed/runtime/engine.py#L2748

Expected behavior

save_checkpoint and load_checkpoint should require approximately the same amount of memory. Loading should be lean and should not need more CPU memory than the size of the largest param or optimizer state, since torch.load copies params via the CPU.

With upcoming models like the 176B-parameter one, the current implementation just won't work, as it would require several TBs of CPU memory to load a ZeRO checkpoint.
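
To make the request concrete, here is a conceptual sketch (not DeepSpeed's actual loading code; the path is a placeholder) of the difference between materializing the checkpoint on CPU first and mapping it straight onto each rank's GPU:

```python
# Conceptual sketch only - not DeepSpeed's actual load_checkpoint code.
import os
import torch

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
device = torch.device("cuda", local_rank)

# Current pattern: every rank materializes the full (potentially huge)
# ZeRO checkpoint shard in CPU memory before anything reaches the GPU.
# state = torch.load("path/to/zero_checkpoint.pt", map_location="cpu")

# Requested pattern: deserialize straight onto this rank's GPU, so CPU
# memory is only needed transiently while tensors are copied over.
state = torch.load("path/to/zero_checkpoint.pt", map_location=device)
```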

@tjruwase, @jeffra

stas00 avatar May 21 '22 05:05 stas00

As this problem keeps recurring for HF Transformers users, in the meantime I shared a hack to stagger checkpoint loading for those who need it, here: https://github.com/huggingface/transformers/issues/17534#issuecomment-1151693075

If you're not using the HF Trainer, you can patch deepspeed's load_checkpoint directly using similar code - you just need to get the rank number the deepspeed way there, or read it from int(os.environ.get("LOCAL_RANK", "0")).
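
Roughly, the staggering idea looks like the sketch below (the wrapper is hypothetical, not an HF or DeepSpeed API - and note the caveat in the edit that follows):

```python
# Hypothetical sketch of staggered checkpoint loading: ranks take turns so
# that only one copy of the checkpoint sits in CPU memory at a time.
# NOTE: if load_checkpoint itself synchronizes across ranks internally,
# this pattern breaks (see the edit below).
import os
import torch.distributed as dist

def staggered_load(engine, load_dir):
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    for turn in range(world_size):
        if local_rank == turn:
            engine.load_checkpoint(load_dir)
        dist.barrier()  # let this rank finish before the next one starts
```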

Much later edit: this idea actually doesn't work because of the barrier calls, so staggering is not possible - the first process won't free up its CPU memory until all the other processes have loaded the checkpoint.

stas00 avatar Jun 09 '22 22:06 stas00

Any update?

I followed the suggestion in here to make a large swapfile, but the loading takes forever ...

desperadoola avatar Jul 04 '23 05:07 desperadoola

Changing 'pin_memory' to False and following #3629 solved the problem. Now we can resume training from a FALCON-40B checkpoint with 1TB of CPU memory.
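
For reference, a minimal sketch of where pin_memory lives in a ZeRO-3 offload config, shown here as a Python dict (all other fields omitted; whether you offload params, optimizer state, or both depends on your setup):

```python
# Sketch: the relevant ZeRO offload settings with pinned memory disabled.
# Everything else in the DeepSpeed config is omitted here.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": False},
        "offload_param": {"device": "cpu", "pin_memory": False},
    },
}
```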

desperadoola avatar Jul 04 '23 08:07 desperadoola