OLMo

Improve the CPU unsharder

Open dirkgr opened this issue 2 years ago • 1 comments

Our checkpointing mechanism works by running torch.save() on each rank. This creates, for each rank, a surprisingly large file that is a pickled state dict. A lot of the tensors in those state dicts are of type ShardedTensor, which is a tensor that's sharded across multiple ranks. If we ever want to change the number of GPUs we're running on, or just generally export the model in a way that's useful to other people, we need to create one single unsharded state dict.
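To make the mechanism concrete, here is a minimal stdlib sketch of per-rank checkpointing, with plain pickled lists standing in for `torch.save()` and `ShardedTensor` (the function name, file layout, and format are all made up for illustration):

```python
import os
import pickle
import tempfile

def save_rank_checkpoint(checkpoint_dir, rank, state_dict):
    # Hypothetical per-rank save: each rank writes only its own shards.
    path = os.path.join(checkpoint_dir, f"rank{rank}.pt")
    with open(path, "wb") as f:
        pickle.dump(state_dict, f)
    return path

# Simulate 4 ranks, each holding a contiguous shard of one parameter.
ckpt_dir = tempfile.mkdtemp()
full_weight = list(range(16))  # stand-in for the full, unsharded tensor
for rank in range(4):
    shard = full_weight[rank * 4:(rank + 1) * 4]
    save_rank_checkpoint(ckpt_dir, rank, {"model.weight": shard})
```

No single rank file contains the whole tensor, which is why reconstructing an unsharded state dict takes extra work.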

The default way of creating an unsharded checkpoint is to spin up a cluster with as many GPUs as we had when the model was training, have each rank torch.load() its file, and then save the model unsharded. The problem is that this is wasteful, and not always possible if we can't get as many GPUs as we originally had (for example, if LUMI is down). So we need a way to do this without GPUs, or at least with fewer of them.

The way we do this now is the CPU unsharder script at https://github.com/allenai/LLM/blob/main/scripts/unshard.py. It loads all the state dicts, one per rank, into a list, and then does exactly what torch does when saving the model unsharded, except that instead of calling torch.gather(), it looks up what it needs in the list of state dicts. There are two problems with this:
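Schematically, that approach looks like the sketch below (a simplified stdlib version, not the real unshard.py; real shards carry ShardedTensor metadata with offsets rather than plain lists, and list concatenation stands in for the torch.gather() substitution):

```python
import pickle

def cpu_unshard(rank_blobs):
    # Load every rank's pickled state dict up front (the slow part),
    # then stitch each key's shards back together in rank order --
    # the in-memory substitute for gathering across GPUs.
    state_dicts = [pickle.loads(blob) for blob in rank_blobs]
    unsharded = {}
    for key in state_dicts[0]:
        shards = [sd[key] for sd in state_dicts]
        unsharded[key] = [x for shard in shards for x in shard]
    return unsharded

# Two "ranks", each holding half of one parameter.
blobs = [pickle.dumps({"w": [0, 1]}), pickle.dumps({"w": [2, 3]})]
merged = cpu_unshard(blobs)  # {'w': [0, 1, 2, 3]}
```

Note that every rank's full state dict sits in memory at once, which is part of why this is slow and memory-hungry.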

  1. It is slow. It spends most of its time in torch.load() for each of the ranks. It's not clear to me why this takes so long, but we're not going to fix it at the torch level, so we'll have to work around it.
  2. It seems that this does not work with checkpoints that were created at MosaicML.

So the task here is this:

  1. Change the code so that this is fast. My suggestion: have one process that reads a single rank file and writes it back out as one file per key in the state dict. We can run this process in parallel, once per rank. Then write a second process that does what the existing script does (i.e., the torch.gather() substitution), except that instead of referring to a list of loaded state dicts, it refers to the files that the first set of processes wrote.
  2. Verify that this works on checkpoints from LUMI, and then also on checkpoints from Cirrascale, MosaicML, or Kempner.
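The two-stage plan in step 1 could be sketched like this (stdlib-only, with pickle standing in for torch.save()/torch.load(); all function names and the per-key file layout are assumptions, not the proposed implementation):

```python
import os
import pickle
import tempfile

def explode_rank_file(rank_path, out_dir, rank):
    # Stage 1 (run once per rank, in parallel processes): rewrite one
    # rank's state dict as one small file per key.
    with open(rank_path, "rb") as f:
        state_dict = pickle.load(f)
    for key, shard in state_dict.items():
        with open(os.path.join(out_dir, f"{key}.rank{rank}"), "wb") as f:
            pickle.dump(shard, f)

def unshard_from_key_files(out_dir, keys, world_size):
    # Stage 2: for each key, read only its per-rank files and
    # concatenate them, instead of holding every full state dict in
    # memory at once.
    result = {}
    for key in keys:
        shards = []
        for rank in range(world_size):
            with open(os.path.join(out_dir, f"{key}.rank{rank}"), "rb") as f:
                shards.append(pickle.load(f))
        result[key] = [x for shard in shards for x in shard]
    return result

# Tiny end-to-end run: two ranks, one key.
work = tempfile.mkdtemp()
for rank, shard in enumerate([[0, 1], [2, 3]]):
    rank_path = os.path.join(work, f"rank{rank}.pt")
    with open(rank_path, "wb") as f:
        pickle.dump({"w": shard}, f)
    explode_rank_file(rank_path, work, rank)

merged = unshard_from_key_files(work, ["w"], world_size=2)
```

Stage 1 is embarrassingly parallel across ranks, and stage 2 only ever needs one key's shards in memory at a time.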

dirkgr avatar Sep 26 '23 20:09 dirkgr

Shane and I found that we may just be able to run this on Python 3.12 without the GIL and it might magically be fast!
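Since the per-rank loads are independent, they could at least be dispatched concurrently; a stdlib sketch with a thread pool (pickle standing in for torch.load(); note this is only a real speedup if the GIL is off in a free-threaded build, or if the loader releases the GIL internally):

```python
import pickle
from concurrent.futures import ThreadPoolExecutor

def load_rank(blob):
    # Stand-in for torch.load() of one rank's checkpoint file.
    return pickle.loads(blob)

blobs = [pickle.dumps({"w": [r]}) for r in range(4)]

# Dispatch all four loads at once; results come back in rank order.
with ThreadPoolExecutor(max_workers=4) as pool:
    state_dicts = list(pool.map(load_rank, blobs))
```

A ProcessPoolExecutor would sidestep the GIL on any Python version, at the cost of re-serializing the loaded state dicts across process boundaries.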

dirkgr avatar Sep 26 '23 20:09 dirkgr

Marking the items prior to Feb 29th as "closed".

dumitrac avatar Apr 30 '24 21:04 dumitrac