torchtune

[RFC]: CPU Offloading

Open msaroufim opened this issue 2 years ago • 2 comments

Intro

Google Doc for easier commenting: https://docs.google.com/document/d/1LuJWq636hmF99NTlOe23nNZ9mAXXUg-880FCiQZ0wmI/edit

CPU offloading is an insurance policy against out-of-memory errors (OOMs). The idea is to send parts of a model to the CPU, since VRAM is the most precious resource: I can't easily buy another GPU.

In the extreme case you can keep the entire model on the CPU, but that will be painfully slow, so this proposal discusses the tradeoff in between. I will focus on CPU offloading in the single-GPU case (a nice suggestion by @rohan-varma) and will punt on discussing multiple GPUs for now, since that makes it clearer what CPU offloading is and what its tradeoffs are.

There's already a community of people who do this kind of work manually to run diffusion models in Google Colab, so this note mostly clarifies the space. In the diffusion world the work is quite a bit simpler, because diffusion models are pipelines of multiple nn.Modules: you swap the module you need into GPU VRAM just in time and swap it back out when it is done.
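Below is a minimal sketch of that just-in-time swapping. It assumes the pipeline is a plain Python list of nn.Module stages that stay on the CPU between uses. The helper name run_pipeline_just_in_time and the three Linear stages are hypothetical stand-ins, not the API of any diffusion library.

```python
import torch
import torch.nn as nn


def run_pipeline_just_in_time(stages, x, device):
    """Hypothetical helper: run CPU-resident `stages` one after another.

    Each stage is moved to `device` only while it runs and then moved back
    to CPU, so at most one stage occupies accelerator memory at a time.
    """
    x = x.to(device)
    for stage in stages:
        stage.to(device)       # swap the stage in just in time
        with torch.no_grad():  # inference only: nothing is kept for backward
            x = stage(x)
        stage.to("cpu")        # swap it back out to free accelerator memory
    return x


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Hypothetical stages standing in for e.g. text encoder / denoiser / decoder
    stages = [nn.Linear(16, 32), nn.Linear(32, 32), nn.Linear(32, 8)]
    out = run_pipeline_just_in_time(stages, torch.randn(4, 16), device)
    print(out.shape)
```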

Where to offload?

CPU offloading is a slightly ambiguous term, because we need to specify where exactly the data is offloaded to. There are two ways of going about this:

  1. to(device="cpu") or cpu() on individual tensors, which means offloading to RAM. This is what most people mean by CPU offloading
  2. NVMe, as in your SSD, which is a promising direction given the new 5th-generation NVMe drives

PyTorch already has a nice API for putting tensors on the CPU: to(). We will discuss the logic for when to do this in a later section. Depending on the specific algorithm, we might also choose to make the copies non-blocking with to(device="cpu", non_blocking=True), so that more computation can overlap with the data transfer. Note that a device-to-host copy only runs asynchronously when the destination CPU tensor is in pinned memory, and the host must synchronize before it reads the result.
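Here is a sketch of a non-blocking offload, assuming we want the host to keep working while a device-to-host copy is in flight. The helper offload_async is hypothetical. The calls it makes (torch.empty(..., pin_memory=True), Tensor.copy_(..., non_blocking=True), torch.cuda.Event) are standard PyTorch. The returned event must be synchronized before the host reads the copy.

```python
import torch


def offload_async(t: torch.Tensor):
    """Hypothetical helper: copy `t` to host memory without blocking the CPU.

    Returns (cpu_tensor, event). When `event` is not None, call
    event.synchronize() before reading cpu_tensor on the host.
    """
    if t.device.type != "cuda":
        return t.to("cpu"), None  # not a CUDA tensor: plain blocking move
    # A device-to-host copy can only be asynchronous when the destination
    # is pinned (page-locked) host memory.
    buf = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
    buf.copy_(t, non_blocking=True)
    done = torch.cuda.Event()
    done.record()  # marks the point on the current stream right after the copy
    return buf, done


if __name__ == "__main__":
    if torch.cuda.is_available():
        x = torch.randn(1024, 1024, device="cuda")
        cpu_copy, done = offload_async(x)
        y = x @ x  # the host enqueues this immediately instead of waiting for the copy
        done.synchronize()  # wait for the copy before touching cpu_copy
        print(cpu_copy.sum().item())
```

In real code the pinned buffers would be allocated once and reused, because pinning memory on every call is expensive.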

For NVMe there is no such API, but we can easily store tensors on disk as long as we make sure the path we're writing to is actually backed by an NVMe drive. This is easy to check, as in https://github.com/msaroufim/cpuoffload/blob/main/nvme.py, which uses a combination of df and lsblk. The core idea is:

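    def forward(self, x):
        x = self.layer1(x)
        # torch.save pickles layer1's weights; self.nvme_path (hypothetical) is a file on the NVMe drive
        torch.save(self.layer1.state_dict(), self.nvme_path)
        # Delete layer1, then hand its now-unused cached GPU blocks back to the driver
        del self.layer1
        torch.cuda.empty_cache()
        x = self.layer2(x)
        return x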
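Here is a more complete sketch of the same idea, which can also reload the layer. It assumes the offload path is already known to sit on an NVMe drive; the df/lsblk check is not reproduced. The helper names are hypothetical. After saving, the sketch swaps the module's tensors for storage-less meta tensors with Module.to_empty(), then rematerializes them on demand.

```python
import os
import tempfile

import torch
import torch.nn as nn


def offload_to_disk(module: nn.Module, path: str) -> None:
    """Hypothetical helper: write `module`'s weights to `path`, then drop its storage.

    to_empty(device="meta") replaces every parameter and buffer with a meta
    tensor that has no storage, so the original memory can be reclaimed.
    """
    torch.save(module.state_dict(), path)
    module.to_empty(device="meta")


def load_from_disk(module: nn.Module, path: str, device="cpu") -> None:
    """Hypothetical helper: re-materialize `module` on `device` from `path`."""
    module.to_empty(device=device)  # allocate uninitialized storage
    state = torch.load(path, map_location=device, weights_only=True)
    module.load_state_dict(state)   # fill it with the saved values
```

Usage: call offload_to_disk(layer, path) once a layer is no longer needed, and call load_from_disk(layer, path, "cuda") right before it runs again.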

What to offload?

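The current FSDP implementation of CPU offloading (CPUOffload(offload_params=True)) offloads parameters and gradients to the CPU, which means the optimizer state also lives on the CPU and the optimizer step runs there. But many more things can be offloaded: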

  1. Parameters: For example moving layers back and forth between CPU and GPU to free memory
  2. Gradients: I'm not sure how to make this work yet; I need to think more
  3. Optimizer: Optimizers like Adam keep two state tensors per parameter (2 × #params in memory), so we can offload their state to the CPU and perform the optimizer step there. This would require a fast vectorized CPU Adam implementation, like the one in DeepSpeed, loaded in a POC here: https://github.com/msaroufim/tinyoptimizer/tree/master/cpu_optimizer. An alternative, if we're willing to use torch.compile(), is to code-generate a vectorized Adam implementation
  4. Activations: Activations can be discarded during inference, but during training they must remain available for the backward pass. Offloading them therefore means copying them to the CPU after the forward pass and back before backward uses them. This is different from activation checkpointing, which discards activations and recomputes them during backward
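Below is a minimal sketch of optimizer offloading. It uses plain torch.optim.Adam as a stand-in for the fast vectorized CPU Adam described above. The wrapper class CPUOffloadedAdam is hypothetical. It keeps an fp32 copy of each trainable parameter, plus Adam's two state tensors, in CPU RAM, so only the working weights and gradients occupy the GPU.

```python
import torch
import torch.nn as nn


class CPUOffloadedAdam:
    """Hypothetical wrapper: fp32 master weights and Adam state live in CPU RAM.

    step() copies gradients down to the CPU, runs Adam on the CPU copies,
    and copies the updated weights back into the (possibly GPU) parameters.
    """

    def __init__(self, params, **adam_kwargs):
        self.params = [p for p in params if p.requires_grad]
        self.cpu_params = [
            p.detach().to("cpu", dtype=torch.float32, copy=True).requires_grad_(True)
            for p in self.params
        ]
        # Stand-in for a fast vectorized CPU Adam; its state is allocated on CPU
        self.opt = torch.optim.Adam(self.cpu_params, **adam_kwargs)

    @torch.no_grad()
    def step(self):
        for p, cp in zip(self.params, self.cpu_params):
            cp.grad = None if p.grad is None else p.grad.to("cpu", dtype=torch.float32)
        self.opt.step()
        for p, cp in zip(self.params, self.cpu_params):
            p.copy_(cp)  # copy_ casts back to p's dtype and device

    def zero_grad(self):
        for p in self.params:
            p.grad = None
```

Swapping in a vectorized CPU kernel would only change the self.opt line. The copy-down, step, copy-up structure stays the same.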

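For activations, PyTorch already ships a context manager, torch.autograd.graph.save_on_cpu. It moves every tensor that autograd saves for backward to the CPU during the forward pass, and copies it back when backward needs it. Here is a small usage sketch; the wrapper function and the toy loss are hypothetical.

```python
import torch
import torch.nn as nn
from torch.autograd.graph import save_on_cpu


def loss_with_offloaded_activations(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Tensors saved for backward inside this block are packed to CPU (pinned
    # when CUDA is available) and unpacked back to their device in backward.
    with save_on_cpu(pin_memory=torch.cuda.is_available()):
        return model(x).pow(2).mean()  # hypothetical toy loss
```

Calling .backward() on the returned loss works as usual. The extra cost is the host-device traffic for every saved activation.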
How to offload?

So far we've discussed internal details. But even assuming we had APIs to offload whatever we wanted to wherever we wanted, there are still many offloading algorithms we could potentially implement:

  1. Static offloading: The main idea here is that we decide a priori which state to offload. An API for this might take either an FQN (fully qualified name) such as layer1.linear.weight or a type such as nn.Linear, and then traverse the named parameters to selectively offload them

    import torch.nn as nn

    def offload_by_layer_type(model: nn.Module, layer_type: type, device="cpu"):
        # model.modules() already walks the whole module tree (including model
        # itself), so no manual recursion is needed
        for module in model.modules():
            if isinstance(module, layer_type):
                module.to(device)

Some variants might even offload only parts of a tensor, but we can think about those later.
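The FQN-based variant of static offloading could look like the sketch below, which moves individual parameters rather than whole modules. offload_by_fqn is a hypothetical helper built on nn.Module.get_parameter(). A layer whose weights were offloaded must be brought back before it runs, for example from a forward pre-hook.

```python
import torch
import torch.nn as nn


def offload_by_fqn(model: nn.Module, fqns, device="cpu"):
    """Hypothetical helper: move the parameters named in `fqns`
    (e.g. "layer1.linear.weight") to `device`.

    Assigning to .data keeps the same nn.Parameter object, so an optimizer
    that already holds a reference to it still sees the moved tensor.
    """
    for fqn in fqns:
        param = model.get_parameter(fqn)  # raises AttributeError for a bad name
        param.data = param.data.to(device)
```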

  2. Offload when OOM: The idea here is that we almost never want to offload unless we run the risk of an OOM. So before allocating a new tensor on the GPU, do a quick calculation to estimate the likelihood of an OOM, and if that risk is high, evict some of the tensors that come earlier in the computation graph

  3. Offload layer by layer: The idea here is to move tensors to the GPU when they're needed and deallocate them when they no longer are. Conceptually it's most similar to the offload-when-OOM approach, just with a more aggressive eviction policy

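    def forward(self, x):
        # Bring layer1 onto the GPU just before it is needed (x is already on the GPU)
        self.layer1.to(device="cuda")
        x = self.layer1(x)
        # Evict it as soon as it has been used
        self.layer1.to(device="cpu", non_blocking=True)
        self.layer2.to(device="cuda")
        x = self.layer2(x)
        self.layer2.to(device="cpu", non_blocking=True)
        return x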
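Here is a rough sketch of the offload-when-OOM policy. Instead of estimating an OOM probability, it assumes a fixed byte budget for parameters; in practice the budget could come from torch.cuda.mem_get_info(). It evicts the least recently used modules, which in a sequential forward pass are the ones earlier in the computation graph. The class and method names are hypothetical, and activation memory is ignored.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def module_bytes(module: nn.Module) -> int:
    """Bytes taken by a module's parameters (buffers and activations ignored)."""
    return sum(p.numel() * p.element_size() for p in module.parameters())


class OffloadOnPressure:
    """Hypothetical policy: keep modules on `device` until a byte budget would be
    exceeded, then evict the least recently used ones to `offload_device`."""

    def __init__(self, budget_bytes: int, device, offload_device="cpu"):
        self.budget_bytes = budget_bytes
        self.device = device
        self.offload_device = offload_device
        self.resident = OrderedDict()  # module -> bytes, least recently used first

    def used_bytes(self) -> int:
        return sum(self.resident.values())

    def ensure_resident(self, module: nn.Module) -> None:
        if module in self.resident:
            self.resident.move_to_end(module)  # recently used: evict it last
            return
        needed = module_bytes(module)
        # Evict the oldest residents until the new module fits in the budget
        while self.resident and self.used_bytes() + needed > self.budget_bytes:
            old, _ = self.resident.popitem(last=False)
            old.to(self.offload_device)
        module.to(self.device)
        self.resident[module] = needed

    def attach(self, model: nn.Module) -> list:
        """Call ensure_resident() before each child's forward via pre-hooks."""
        def pre_hook(mod, args):
            self.ensure_resident(mod)  # returning None leaves the inputs unchanged

        return [child.register_forward_pre_hook(pre_hook) for child in model.children()]
```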

For each policy we have a hard requirement that we cannot change the user's nn.Module code, so we will have to implement every policy using register_forward_hook() and register_backward_hook() (the latter is deprecated in favor of register_full_backward_hook()). There's a kink: hooks operate at the module level, not the layer level, but I feel like there must be an easy workaround.
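One possible workaround is to register hooks on each child module, because every layer is itself an nn.Module. The sketch below implements layer-by-layer offloading this way without touching the user's forward(). The helper name is hypothetical. It is inference-only, because backward would need the weights back on the GPU.

```python
import torch
import torch.nn as nn


def attach_layer_by_layer_offload(model: nn.Module, compute_device, offload_device="cpu"):
    """Hypothetical helper: park every child of `model` on `offload_device` and move
    it to `compute_device` only for the duration of its own forward call.

    Returns the hook handles so the policy can be removed with handle.remove().
    """

    def load(mod, args):
        mod.to(compute_device)  # pre-hook returns None, so the inputs are unchanged

    def evict(mod, args, output):
        mod.to(offload_device)  # forward hook returns None, so the output is unchanged

    handles = []
    for layer in model.children():
        layer.to(offload_device)
        handles.append(layer.register_forward_pre_hook(load))
        handles.append(layer.register_forward_hook(evict))
    return handles
```

For nested models, the same hooks could be attached to every nn.Linear found via model.modules() instead of only the direct children.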

Failure metrics

We need to measure three metrics:

  1. Peak GPU memory allocated: If this is higher than simply allocating the whole model on the GPU, the project has failed
  2. Throughput: If this is slower than just running the model on the CPU, the project has failed
  3. Largest model we can run on Mark's desktop (4090 with 24 GB of VRAM, 32-core CPU, 1 TB SSD): Mark should be able to run Llama 70B finetunes on his personal desktop
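Here is a minimal harness for the first two metrics, assuming the workload is a callable. torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() are real PyTorch APIs; the measure() helper and its return format are hypothetical. Running it once with the model fully on the GPU and once fully on the CPU gives the two baselines the failure criteria compare against.

```python
import time

import torch


def measure(fn, *args, device="cpu", iters=10):
    """Hypothetical helper: peak CUDA memory (bytes, or None off-GPU) and calls/second."""
    cuda = torch.device(device).type == "cuda"
    if cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if cuda:
        torch.cuda.synchronize()  # CUDA kernels are asynchronous; wait before stopping the clock
    elapsed = time.perf_counter() - start
    return {
        "peak_bytes": torch.cuda.max_memory_allocated() if cuda else None,
        "calls_per_s": iters / elapsed,
    }
```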

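To put the third target in perspective, here is a back-of-the-envelope calculation. It assumes a full finetune with bf16 weights and gradients, fp32 Adam state, and fp32 master weights. These precisions are assumptions, not decisions from this RFC, and activations are ignored.

```python
def full_finetune_footprint_gb(n_params, weight_bytes=2, grad_bytes=2,
                               adam_state_bytes=8, master_weight_bytes=4):
    """Hypothetical estimate of memory (decimal GB) for a full finetune with Adam.

    Assumed defaults: bf16 weights and grads (2 bytes each), fp32 Adam
    exp_avg + exp_avg_sq (2 x 4 bytes), fp32 master weights (4 bytes).
    """
    per_param = {
        "weights": weight_bytes,
        "grads": grad_bytes,
        "adam_state": adam_state_bytes,
        "master_weights": master_weight_bytes,
    }
    return {name: n_params * b / 1e9 for name, b in per_param.items()}


if __name__ == "__main__":
    parts = full_finetune_footprint_gb(70_000_000_000)
    print(parts)               # {'weights': 140.0, 'grads': 140.0, 'adam_state': 560.0, 'master_weights': 280.0}
    print(sum(parts.values())) # 1120.0
```

Under these assumptions the bf16 weights alone are several times larger than 24 GB of VRAM, and the roughly 1.1 TB total exceeds even the 1 TB SSD. The choice of what to offload, and at what precision, therefore matters for this target.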
Feedback from Joe: talk to Alban, Jeff, and Horace; they're thinking along similar lines about checkpointing policies and what the UX should look like. Also talk to Mikayla and Alban about how to load directly into GPU memory and avoid the extra copy from CPU. Driss: how do you prefetch? Andrew: use separate CUDA streams, manage memory, and make sure you don't prefetch too much, for example fetch one layer at a time.
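Here is a sketch of Andrew's suggestion: prefetch on a separate CUDA stream, one layer ahead. The helper is hypothetical and inference-only, and it falls back to synchronous moves when no GPU is present. Two details are easy to miss. First, the host-to-device copy only overlaps with compute if the CPU weights are in pinned memory. Second, tensors allocated on the copy stream but used on the compute stream need Tensor.record_stream(), so that the caching allocator does not reuse their memory too early.

```python
import itertools

import torch
import torch.nn as nn


def run_with_prefetch(layers, x, device):
    """Hypothetical helper: run CPU-resident `layers` in order on `device`,
    copying the next layer on a side stream while the current one computes."""
    use_cuda = torch.device(device).type == "cuda"
    copy_stream = torch.cuda.Stream() if use_cuda else None

    def fetch(layer):
        if not use_cuda:
            layer.to(device)
            return None
        with torch.cuda.stream(copy_stream):
            layer.to(device, non_blocking=True)  # overlaps only if CPU weights are pinned
            ready = torch.cuda.Event()
            ready.record(copy_stream)
        return ready

    x = x.to(device)
    pending = fetch(layers[0])
    for i, layer in enumerate(layers):
        if pending is not None:
            compute = torch.cuda.current_stream()
            compute.wait_event(pending)  # don't compute before the weights land
            for t in itertools.chain(layer.parameters(), layer.buffers()):
                t.record_stream(compute)  # allocated on copy_stream, used on compute
        # Prefetch at most one layer ahead to bound memory use
        pending = fetch(layers[i + 1]) if i + 1 < len(layers) else None
        with torch.no_grad():
            x = layer(x)
        layer.to("cpu")  # evict right after use
    return x
```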

Where would this code live

For now, since we're limiting ourselves to a single GPU, this code can live either as a utility in a personal repo or in torchtune, if we end up realizing that most of our customers are single-GPU customers. However, upstreaming it to PyTorch core would require figuring out how to make it compose with FSDP.

A crazier idea: maybe the logic should just live inside torch.load(), similar to what Mikayla is doing.

Less is saying someone (Paul Johnson) was working on this. Andrew: look at what fairscale is doing; they have an offloading API similar to DeepSpeed's. Less: CUDA unified memory? Jane: does that work? (Seems like not right now.) Andrew: numerical differences between CPU and GPU results are non-negligible, so beware. Jane: let's not do policy 2. Jane: should we do network filesystem reads? Mark: we can, but we're not prioritizing it right now. Less: look at Colossal-AI, which uses some smart heuristics to cache things on the fly. Andrew: it's monolithic and assumes control over everything; global memory management systems are hard to merge into PyTorch.

Specific work items

So great, what does a plan for this actually look like?

  • [x] Sign off on this RFC
  • [ ] MVP: parameter offloading + optimizer offloading with a fused CPU Adam, implementing static offloading on a single GPU with CPU/RAM offloading. Per Andrew, we can do this today; we are just missing the fast CPU Adam, which Mark will provide
  • [ ] Implement NVMe offloading
  • [ ] Implement more offloading policies with some generic abstract policy class
  • [ ] Scope out what's needed to support all of the above in FSDP

msaroufim, Feb 14 '24

@msaroufim This is awesome!!!

Is there any way you could convert to a PR so we can comment on it line-by-line? I know there's a Google Doc, but that's only open to internal people.

joecummings, Feb 15 '24

Thanks @joecummings! I'm already getting a bunch of feedback on the Google Doc, so I'd rather collect more feedback there if that's OK.

msaroufim, Feb 16 '24

@msaroufim Guessing this is no longer relevant?

kartikayk, Apr 21 '24

Yeah, we're addressing this in core directly.

msaroufim, Apr 21 '24