[RFC] Optimizer CPU offload from torchao for single GPU low memory config
The recent addition of optimizer CPU offload in torchao can be useful for single-GPU, low-memory configs.
https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload
In my brief testing https://github.com/pytorch/torchtune/compare/main...gau-nernst:torchtune:optim_offload, there is ~25% increase in tok/s. Wandb project: https://wandb.ai/gau-nernst/torchtune. My system: 4070Ti SUPER (16GB VRAM), Ryzen 5600, DDR4.
There is also a difference in how gradient memory is handled:
- For CPU offload, I use `offload_gradients=True` in `CPUOffloadOptimizer`, which frees gradients once the device-to-host transfer finishes.
- For paged Adam, it is done via `optimizer_in_bwd=True`.
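For context, the paged-Adam path is what torchtune's low-memory configs already expose, roughly like this (field names follow torchtune's `_component_` config convention; exact values are illustrative, not an exact recipe):

```yaml
# Paged Adam: bnb manages optimizer state in unified memory and steps on GPU.
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 2e-5
# Fuse the optim step into backward so each gradient is freed right after use.
optimizer_in_bwd: True
```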
Regarding memory usage, it's pretty strange: in nvidia-smi, the paged Adam run also occupies a lot of memory (near 16GB). Perhaps bnb manages its own unified memory, so PyTorch doesn't report it? Also, for RAM usage, htop reports 55.5GB for paged Adam and 64.1GB for offload Adam.
We probably need more testing. In particular:
- Different system configurations. CPU offload Adam can be dependent on RAM and CPU speed, since optim step is done on CPU. Paged Adam might be faster when there is more spare GPU memory, since paged Adam does optim step on GPU. The optimal batch size (to maximize tok/s) for each config might be different too.
- Memory spike behavior. For CPU offload Adam, I had to add `expandable_segments:True` to prevent OOM in the middle of training. Memory spike behavior might be unpredictable with CPU offload Adam, since it is not well tested. The spike might come from gradient offloading (ref: https://github.com/pytorch/ao/pull/584#discussion_r1704667190, not 100% sure). I haven't tested paged Adam without `expandable_segments:True` yet.
Regardless, I think adding an extra option for low memory single GPU training is beneficial, even if it is not well-tested yet.
cc @msaroufim
Profiling trace for CPU offload Adam, bs=4, Llama2-7B
From the screenshot:
- CPU Adam dominates time per step (>80%). A better CPU (more cores, AVX-512 support, and/or DDR5 RAM) would greatly improve performance.
- The GPU spends very little time on computation. Increasing batch size to 8 leads to OOM after 30 steps (training a 7B model with 16GB VRAM is really at the limit of OOM 😆).
- (More minor) During backward, the grad D2H transfer bottlenecks backward computation. The solution would be a larger batch size (so backward takes more time), which is not possible on my GPU 😢. Also not sure why there are gaps during backward.
For future extension: since CPU offload Adam already keeps a copy of params on CPU, we can extend this and implement:
- Params offloading (kinda reviving #385). During forward, prefetch the next layer, do computation with the current layer, then free the current layer. Do the same for backward. This incurs an extra 1x model-size host-to-device transfer during backward. Prefetch during forward and backward might be bandwidth-bound if the batch size is not large enough.
- Mixed-precision training. Keep only a low-precision copy of params on GPU (e.g. FP8, INT8) while maintaining a high-precision master copy of params on CPU (e.g. BF16). The optim step on CPU is done with the high-precision copy. This will probably have accuracy problems compared to full BF16.

Both of these approaches aim to reduce the memory footprint of params, so we can use a larger batch size (more work for the GPU while keeping the same amount of work for the CPU) to reduce the CPU Adam bottleneck.
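The layer-prefetch idea in the first bullet can be sketched generically. In this pure-Python sketch, `load` stands in for the host-to-device copy, `compute` for the layer forward, and `free` for releasing the current layer's weights; all names are hypothetical, and a real implementation would use a separate CUDA stream rather than a thread, but the overlap structure is the same:

```python
from concurrent.futures import ThreadPoolExecutor

def run_with_prefetch(layers, load, compute, free):
    # Double-buffered pipeline: while we compute with the current layer's
    # weights, the next layer's weights are being loaded in the background.
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = pool.submit(load, layers[0])  # start the first H2D copy
        outputs = []
        for i in range(len(layers)):
            weights = pending.result()  # wait for the prefetch to finish
            if i + 1 < len(layers):
                # Kick off the next layer's copy before computing.
                pending = pool.submit(load, layers[i + 1])
            outputs.append(compute(weights))
            free(weights)  # release the current layer's device copy
        return outputs
```

If the per-layer compute time exceeds the copy time (i.e. the batch is large enough), the H2D transfers are fully hidden behind computation.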
Update: proof-of-concept for mixed-precision training. Keep FP8 E4M3 params on GPU (except the embedding layer and LM head) and BF16 params on CPU; computation is still in BF16 (weights are upcast to BF16). Increasing batch size improves throughput. On a 4070Ti SUPER, tok/s matches a 4090 with paged Adam from the torchtune README. The accuracy issue will probably need extensive experiments and investigation.
Benchmarks with Phi3-mini 4B, bs=16: ~30% improvement.
This is great @gau-nernst! @msaroufim was just telling me about this, thanks for getting a prototype out with results so fast. I agree we should test it out a bit more, but I don't see any harm in supporting it as a flag for users to play with. The bit about slightly higher memory util compared to bnb is one area I'd want to understand better: have you observed cases where optimizer offload OOMs but bnb PagedAdam doesn't? Also cc @SalmanMohammadi who's been putting together a tutorial on the different memory and perf levers we can pull.
> have you observed cases where optimizer offload OOMs but bnb PagedAdam doesn't
From my limited testing, I haven't observed such cases. Seems like they both OOM at the same batch size.
Should I open a PR now, or you want to do your own testing with my branch first? My branch is just a quick hack together at the moment, will iron out the details when I create a PR.
@gau-nernst a PR would be great. I can do a bit of testing myself in parallel, but that way we can also frontload any potential design discussions for how we expose this in our recipes
`CPUOffloadOptimizer` has the following signature:

```python
class CPUOffloadOptimizer:
    def __init__(
        self,
        params,
        optimizer_class: Type[Optimizer],
        *,
        offload_gradients: bool = False,
        **kwargs,
    ) -> None:
        ...
```
Ideally `offload_gradients` should be exposed to the user too. The tricky part is that the config parser cannot parse `optimizer_class` from a string. I'm thinking of the following designs:

1. Add an `offload_optimizer` flag (similar to `optimizer_in_bwd`): kinda tricky to expose `offload_gradients`; adding an extra flag seems messy.
2. Make a custom CPU offload Adam/AdamW class in torchtune: users just need to replace `optimizer._component_`, and `offload_gradients` is naturally exposed. Downside is that we have to write a custom CPU offload optimizer class for each base optimizer.
3. Make a light wrapper around `CPUOffloadOptimizer` that parses the string into `optimizer_class` using the existing `_get_component_from_path()`. Users only need to replace `optimizer._component_` like above. Seems to be the cleanest solution.
The benefit of `offload_gradients=False` is that we can do gradient accumulation. I will test with Phi3-mini 4B whether that option can improve tok/s over `offload_gradients=True`. If not, maybe we don't need to expose `offload_gradients` at all.
Tested `offload_gradients=False` for Phi3-mini 4B on my machine. The speed is terrible, because batch size is now limited to 1, so training is bandwidth-bound. Anyway, I will go with approach 3 from my previous comment.
Another note: since `CPUOffloadOptimizer` is only available on the torchao main branch right now, should I wait until this feature makes it into the ao 0.5.0 release (I think about 1 month from now?), or is it OK to include a nightly feature from torchao?
Closing this issue as it is now possible through the TorchAO library.