[RFC] Single Device Full Fine-tune for Llama7B in < 16GB
Context
On a single device, our current Llama7B full fine-tune recipe either OOMs with the AdamW optimizer or takes > 55GB with SGD. Given the importance of single-device fine-tuning on "commodity" GPUs (this is a critical value prop of TorchTune), it's important for us to understand whether we can reduce the memory footprint for this scenario, and what impact these improvements have on model quality. In this RFC, I present:
- An analysis of why our current recipe makes sub-optimal use of GPU memory
- Techniques we can use to reduce the memory footprint
- Prototype recipe with all of the techniques implemented
- Loss curves showing that the model continues to learn even after these improvements
Overall, I show that we can full-finetune a Llama7B model in <16GB of memory with reasonable loss curves.
Note: Performing detailed evals of the models fine-tuned using this recipe is beyond the scope of this RFC. We can do this as a follow-up.
Baselines: Current State
As a baseline, I consider the following setting:
- batch_size=1 (minimum possible)
- optimizer=AdamW (known to produce higher quality models than SGD)
- enable_activation_checkpointing=False (has negligible footprint compared to weights, gradients and opt state - more details below)
- dtype=fp32 (mixed precision training uses more memory - more details below)
We can run this training using the following command:
tune --nnodes 1 --nproc_per_node 1 recipes/full_finetune.py \
--config recipes/configs/alpaca_llama2_full_finetune.yaml \
--override model_checkpoint=/home/kartikayk/cpts/llama2-7b-01242024 \
seed=30 \
tokenizer_checkpoint=/home/kartikayk/cpts/tokenizer.model \
epochs=3 \
batch_size=1 \
gradient_accumulation_steps=1 \
enable_activation_checkpointing=False \
enable_fsdp=False \
optimizer=AdamW
This particular run will OOM. To understand why, let's break down how the memory is used.
Model Weights in full precision: 4 * num_params = ~27GB
Gradients in full precision: 4 * num_params = ~27GB
AdamW State in full precision: 2 * 4 * num_params = ~54 GB
Total = ~108GB
Note: Activations are ignored since the peak memory is after bwd (during opt step)
when activations are no longer in memory. These are <1GB anyways.
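For reference, here's a back-of-the-envelope sketch of this arithmetic (the ~6.74B parameter count for Llama2-7B is an assumption; this is illustrative, not recipe code):

NUM_PARAMS = 6.74e9                         # assumed parameter count for Llama2-7B
GB = 1e9                                    # the numbers above use decimal GB

weights_fp32 = 4 * NUM_PARAMS               # 4 bytes per param              -> ~27GB
grads_fp32 = 4 * NUM_PARAMS                 # one fp32 gradient per param    -> ~27GB
adamw_state_fp32 = 2 * 4 * NUM_PARAMS       # exp_avg + exp_avg_sq in fp32   -> ~54GB

total = weights_fp32 + grads_fp32 + adamw_state_fp32
print(f"total: {total / GB:.0f} GB")        # ~108GB, well beyond a single 80GB GPU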
One way around this OOM is to take the quality hit and use SGD. Since SGD doesn't keep around expensive state, we can reduce peak memory by ~54GB. The only change in the above config is
...
optimizer=SGD
To see what's happening, let's first take a look at the memory stats and then the memory snapshot using the memory visualizer.
Model is initialized.
After Model Setup :
Memory Allocated: 27.02 GB
Memory Reserved: 27.85 GB
Peak Memory: 27.83 GB
Tokenizer is initialized from file.
Optimizer is initialized.
Loss is initialized.
Dataset and Sampler are initialized.
| 0/52002 [00:00<?, ?it/s]
After Model Fwd :
Memory Allocated: 28.58 GB
Memory Reserved: 28.66 GB
Peak Memory: 28.60 GB
After Model Bwd :
Memory Allocated: 54.01 GB
Memory Reserved: 55.03 GB
Peak Memory: 54.01 GB
After Optim Zero :
Memory Allocated: 27.06 GB
Memory Reserved: 55.03 GB
Peak Memory: 54.01 GB
As expected, the peak memory is ~54GB. Following is the memory snapshot for this run.
So how can we improve upon this? First let's take a closer look at mixed precision training.
Detour: Mixed Precision Training
There's a common misconception that mixed precision training (e.g. bf16) helps reduce the memory footprint. This is not true, for the following reasons:
- Autocast creates two copies of the model weights in memory: one in half precision and one in full precision - this increases memory by 1.5x
- Forward activations saved for gradient computation are in half precision, but the activations are a very small % of the memory (depending on sequence length and batch size, among other factors)
- Gradients are computed in half precision but converted to full precision for the weight updates - no memory savings
- Optimizer states are in full precision - no memory savings
We can enable mixed precision training by adding the following to the config above:
...
dtype=bf16
...
Checking the gradients with pdb confirms that they are indeed in fp32:
(Pdb) self._model.output.weight.grad.dtype
torch.float32
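The same behavior is easy to reproduce outside the recipe. Below is a minimal, generic autocast loop (a sketch with a stand-in linear layer, not the actual torchtune recipe): parameters stay fp32, compute inside the autocast region runs in bf16, and the resulting gradients come out in fp32.

import torch

model = torch.nn.Linear(1024, 1024, device="cuda")   # stand-in for the real model
x = torch.randn(8, 1024, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(x).float().mean()                   # the matmul runs in bf16 inside the region

loss.backward()
print(next(model.parameters()).dtype)                # torch.float32 -- weights stay in full precision
print(next(model.parameters()).grad.dtype)           # torch.float32 -- matches the pdb check above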
Let's take a look at the memory statistics:
Model is initialized.
After Model Setup :
Memory Allocated: 27.02 GB
Memory Reserved: 27.85 GB
Peak Memory: 27.83 GB
Tokenizer is initialized from file.
Optimizer is initialized.
Loss is initialized.
Dataset and Sampler are initialized.
| 0/52002 [00:00<?, ?it/s]
After Model Fwd :
Memory Allocated: 41.32 GB
Memory Reserved: 41.35 GB
Peak Memory: 41.33 GB
After Model Bwd :
Memory Allocated: 54.01 GB
Memory Reserved: 61.98 GB
Peak Memory: 54.01 GB
After Optim Zero :
Memory Allocated: 27.06 GB
Memory Reserved: 61.98 GB
Peak Memory: 54.01 GB
Two observations:
- We have two copies of the model weights during the forward pass, resulting in ~41GB (27 + 13.5)
- Gradients are in fp32 - peak memory is the same as full precision training
Let's take a look at the snapshot for this run. Indeed the memory during forward is higher owing to the extra copy of the model weights. These are kept in memory for gradient computation.
Optimization 1: True BF16 Training
To get the true benefits of bf16, we need "true" half-precision training, where the model weights, gradients and optimizer states are all in bf16. There is some evidence that this setup is stable for LLMs.
The memory breakdown for true bf16 training should include:
Model Weights: 2 * num_params = ~13.5GB
Gradients: 2 * num_params = ~13.5GB
Opt State: None (using SGD)
Total: ~27GB
Note: Activations are ignored since the peak memory is after bwd (during opt step) when activations are no longer in memory. These are <1GB anyways.
Let's look at the memory stats to see if this is indeed the case:
Model is initialized.
After Model Setup :
Memory Allocated: 13.51 GB
Memory Reserved: 28.14 GB
Peak Memory: 27.83 GB
Tokenizer is initialized from file.
Optimizer is initialized.
Loss is initialized.
Dataset and Sampler are initialized.
| 0/52002 [00:00<?, ?it/s]
After Model Fwd :
Memory Allocated: 14.39 GB
Memory Reserved: 28.14 GB
Peak Memory: 27.83 GB
After Model Bwd :
Memory Allocated: 27.02 GB
Memory Reserved: 28.14 GB
Peak Memory: 27.83 GB
After Optim Zero :
Memory Allocated: 13.55 GB
Memory Reserved: 28.14 GB
Peak Memory: 27.83 GB
This looks good, except that the peak memory has now shifted to model setup! This is because we load the weights in fp32 on device and then convert to bf16. Let's change this to load the weights on CPU, convert to bf16, and then move them to device. This increases model init time from 0.5 seconds to 13 seconds, but it's worth taking this small hit for the additional savings.
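A minimal sketch of this load path (the module and checkpoint path below are placeholders, not the recipe's actual code):

import torch
from torch import nn

model = nn.Linear(4096, 4096)                            # stand-in; in the recipe this would be the full Llama2-7B model, built on CPU in fp32
state_dict = torch.load("/path/to/checkpoint.pt", map_location="cpu")   # load the fp32 weights onto CPU
model.load_state_dict(state_dict)

model = model.to(dtype=torch.bfloat16)                   # convert on CPU so the fp32 copy never touches the GPU
model = model.to(device="cuda")                          # only the bf16 weights (~13.5GB for 7B params) are moved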
After Model Init :
Memory Allocated: 13.51 GB
Memory Reserved: 13.51 GB
Peak Memory: 13.51 GB
Model is initialized.
After Model Setup :
Memory Allocated: 13.51 GB
Memory Reserved: 13.51 GB
Peak Memory: 13.51 GB
Tokenizer is initialized from file.
Optimizer is initialized.
Loss is initialized.
Dataset and Sampler are initialized.
| 0/52002 [00:00<?, ?it/s]
After Model Fwd :
Memory Allocated: 14.41 GB
Memory Reserved: 14.45 GB
Peak Memory: 14.42 GB
After Model Bwd :
Memory Allocated: 27.02 GB
Memory Reserved: 27.94 GB
Peak Memory: 27.02 GB
After Optim Zero :
Memory Allocated: 13.55 GB
Memory Reserved: 27.94 GB
Peak Memory: 27.02 GB
The memory savings can be seen in the snapshot as well (compare the y-axes).
Now that we've reduced the memory footprint, it's time to bring back AdamW and see if we can run the config which was previously OOM-ing. Following is the command used to launch this training:
tune --nnodes 1 --nproc_per_node 1 recipes/full_finetune_single_device.py \
--config recipes/configs/alpaca_llama2_full_finetune.yaml \
--override model_checkpoint=/home/kartikayk/cpts/llama2-7b-01242024 \
seed=30 \
tokenizer_checkpoint=/home/kartikayk/cpts/tokenizer.model \
epochs=3 \
batch_size=1 \
gradient_accumulation_steps=1 \
enable_activation_checkpointing=False \
enable_fsdp=False \
max_steps_per_epoch=2 \
epochs=1 \
optimizer=AdamW
Before we look at the memory stats, let's compute the expected values:
Model Weights: 2 * num_params = ~13.5GB
Gradients: 2 * num_params = ~13.5GB
Opt State: 2 * 2 * num_params = ~27GB
Total: ~54GB (same memory footprint as baseline SGD!)
Note: Activations are ignored since the peak memory is after bwd (during opt step) when activations are no longer in memory. These are <1GB anyways.
Now let's look at the memory stats:
Model is initialized.
After Model Setup :
Memory Allocated: 13.51 GB
Memory Reserved: 13.51 GB
Peak Memory: 13.51 GB
Tokenizer is initialized from file.
Optimizer is initialized.
Loss is initialized.
Dataset and Sampler are initialized.
| 0/52002 [00:00<?, ?it/s]
After Model Fwd :
Memory Allocated: 14.41 GB
Memory Reserved: 14.45 GB
Peak Memory: 14.42 GB
After Model Bwd :
Memory Allocated: 27.02 GB
Memory Reserved: 27.94 GB
Peak Memory: 27.02 GB
After Optim Zero :
Memory Allocated: 40.50 GB
Memory Reserved: 68.38 GB
Peak Memory: 67.45 GB
Note: The optimizer state only kicks in after the first call to optimizer.step() is made (see snapshot below).
Most of the numbers match up, but there's an additional 13.5GB which is mysteriously equal to the size of model weights in half precision. Let's take a look at the snapshot to see where this is coming from.
Digging through the AdamW documentation, the extra memory seems to be related to the foreach flag which is a performance lever enabled by default:
foreach (bool, optional) – whether foreach implementation of optimizer is used.
If unspecified by the user (so foreach is None), we will try to use foreach over the for-loop
implementation on CUDA, since it is usually significantly more performant.
Note that the foreach implementation uses ~ sizeof(params) more peak memory than the
for-loop version due to the intermediates being a tensorlist vs just one tensor.
If memory is prohibitive, batch fewer parameters through the optimizer at a time or
switch this flag to False (default: None)
After disabling this flag, things look more in line with our expectations. Namely, peak memory is ~54GB WITH AdamW.
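For reference, foreach is a documented argument to torch.optim.AdamW, so opting out is a one-line change (trading the foreach speedup for ~sizeof(params) less peak memory during the step):

optimizer = torch.optim.AdamW(self._model.parameters(), lr=lr, foreach=False)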
Optimization 2: 8-Bit Optimizer
Given the above breakdown, one way to reduce the memory footprint of AdamW is to use the 8-bit version. This reduces the AdamW state by ~13.5GB (instead of 2 bytes per param for each optimizer state, we use 1 byte per param), bringing the total from ~54GB down to ~40.5GB.
To do this, I use AdamW8bit from BitsAndBytes, which is a single-line code change:

import bitsandbytes as bnb

optimizer = bnb.optim.AdamW8bit(trainable_params, lr=lr)
Let's look at the memory stats to see if this matches up with our expectation:
Memory Stats after fwd:
Memory Allocated: 27.29 GB
Memory Reserved: 41.07 GB
Peak Memory: 40.73 GB
1|3|Loss: 1.4059289693832397: 0%| | 2/52002 [00:03<18:31:29, 1.28s/it]
Memory Stats after bwd:
Memory Allocated: 40.73 GB
Memory Reserved: 41.07 GB
Peak Memory: 40.73 GB
Memory Stats opt zero:
Memory Allocated: 27.25 GB
Memory Reserved: 41.07 GB
Peak Memory: 40.73 GB
Caveat: This introduces a new dependency in the code base. We might have a native PyTorch implementation of 8-bit AdamW, but this version from BnB has been thoroughly tested by the community.
Optimization 3: Fusing Optimizer with Backward to remove Gradients
The other optimization we can look at is to get rid of the gradients which take up ~13.5GB of memory. We can do this by fusing the optimizer with the backward so that weight updates happen as the relevant gradient is computed. This can be done using the following code change:
# Create one optimizer instance per parameter so each can be stepped independently
optimizer_dict = {p: torch.optim.AdamW([p], lr=lr) for p in self._model.parameters()}

# Step and zero a parameter's optimizer as soon as its gradient is accumulated,
# so we never need to hold gradients for the full model at once
def optimizer_hook(parameter) -> None:
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

for p in self._model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)
The memory stats show that the peak memory is as expected (note that there's no longer a separate optimizer step to log):
Memory Stats after fwd:
Memory Allocated: 40.54 GB
Memory Reserved: 42.74 GB
Peak Memory: 41.02 GB
1|2|Loss: 1.1041574478149414: 0%| | 1/52002 [00:03<41:08:58, 2.85s/it]
Memory Stats after bwd:
Memory Allocated: 40.50 GB
Memory Reserved: 42.75 GB
Peak Memory: 41.05 GB
Caveat: IIUC, fusing the optimizer with the backward makes it impossible to accumulate gradients to simulate larger batch sizes. But this trade-off might be worth it for the ability to run training at all (see loss curves below).
Optimization 4: Combine 1, 2 and 3
By using the 8-bit optimizer and removing the need for preserving gradients, we can further reduce the peak memory to ~27GB. Let's take a look at the memory stats:
Memory Stats after fwd:
Memory Allocated: 27.30 GB
Memory Reserved: 27.85 GB
Peak Memory: 27.51 GB
1|2|Loss: 1.1036440134048462: 0%| | 1/52002 [00:03<53:58:40, 3.74s/it]
Memory Stats after bwd:
Memory Allocated: 27.25 GB
Memory Reserved: 27.85 GB
Peak Memory: 27.55 GB
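Under the hood, combining optimizations 2 and 3 just swaps the per-parameter optimizer in the backward hook for the 8-bit one. A sketch, assuming bitsandbytes is installed:

import bitsandbytes as bnb

# One 8-bit AdamW per parameter, stepped as soon as that parameter's gradient
# is accumulated: full-model gradients are never held, and the optimizer state
# that does persist is stored in 8 bits.
optimizer_dict = {p: bnb.optim.AdamW8bit([p], lr=lr) for p in self._model.parameters()}

def optimizer_hook(parameter) -> None:
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

for p in self._model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)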
We can run this with the prototype script using the following command:
tune --nnodes 1 --nproc_per_node 1 recipes/full_finetune_single_device.py \
--config recipes/configs/alpaca_llama2_full_finetune.yaml \
--override model_checkpoint=/home/kartikayk/cpts/llama2-7b-01242024 \
seed=30 tokenizer_checkpoint=/home/kartikayk/cpts/tokenizer.model \
epochs=3 \
batch_size=1 \
gradient_accumulation_steps=1 \
enable_activation_checkpointing=True \
enable_fsdp=False \
epochs=1 \
optimizer_in_bwd=True \
optimizer=AdamW8Bit
Putting it all together
To replicate the loss curves we've seen thus far from lit-gpt and LoRA, I set batch_size=8. Following is the loss curve; the final loss value is ~0.44, with a peak memory of ~35GB. With a little more tuning of the hyperparams, I believe we can bring the peak memory down further. This will require access to evals to better understand the trajectory of model training.
If we replace AdamW8Bit with SGD, we further reduce the peak memory to ~14.4GB. The following logs from WandB show this setting with the loss, peak memory and allocated memory (logged after the forward pass).
Thanks for an amazing analysis and for writing up the entire process so clearly! A few comments and questions (some may be pretty basic, I need to read through this a couple more times to make sure I understand everything in more detail).
(1) Any thoughts on using AdamW8bit vs SGD in your final, most memory-efficient training? I notice that you describe SGD as more of an aside in the final model even though the memory usage is further reduced. Is this for any particular reason (e.g. model performs worse)?
(2) I would love to run some evals on these, if the numbers are good there's no reason we shouldn't land some version of these changes. Even if they aren't, we can ablate along some of these dimensions to see which optimizations impact final model performance more than others and document it.
(3) (Maybe obvious) Given that full fine-tune memory > LoRA fine-tune memory, I assume a similar set of optimizations should allow us to run LoRA fine-tuning on a single device? Need to think about it a bit, but I wonder what the memory would look like there. (E.g. I think fusing optimizer with backward will actually not do as much in that case.)
(4) On initialization: seems like you are able to load weights on CPU and cast directly to the desired dtype without any extra memory. In that case, why do we even need meta device? Or do we just need it for larger models where CPU will OOM on loading model weights?
(5) Probably a dumb q: but when exactly does register_post_accumulate_grad_hook fire? I am a bit confused about this memory-saving trick... if we are doing backprop, do we not need to hold onto the grad of a given parameter to calculate the grad of all parameters that are upstream of it in the computational graph?
I am sure I will have more questions, these are just some from a first pass 😃
Thanks @ebsmothers for taking a close look!
- SGD is absolutely an option we should evaluate. Loss definitely seems to be going down (see the curve), but it's hard to say without running some evals/generation. Let me make this more prominent and move it to the "Putting it all together" section.
- Agreed! I think we should definitely provide this recipe and let users play around with it to see if the setting works for their use case. I would love to better understand how full finetuning with these options compares to PEFT. That's the next TODO.
- That's a good observation. I think true bf16 should help reduce the memory footprint by half across the board. The 8-bit optimizer should help as well, though fusing might not be that effective since gradients for LoRA aren't going to take up a lot of memory.
- I'm still trying to fully figure out meta device and its interaction with FSDP for multi-device training. On single device, I'm not sure of the value, though one thing to note is that loading on CPU and then moving to GPU causes a hit in model init time (in this case 14 seconds vs 0.5 seconds). I think this hit gets bigger as the model gets larger.
- According to the docs, this hook is called after all gradients for a tensor have been accumulated.
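For anyone else wondering about the firing order, here's a tiny standalone demo (a sketch, not recipe code): the hook for a parameter runs during backward, right after that parameter's .grad has been fully accumulated. Upstream gradients are computed from saved activations rather than from the parameter's .grad, which is why it's safe to step and zero it immediately.

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))

for name, p in model.named_parameters():
    # Fires once per parameter, as soon as its grad is fully accumulated in backward.
    p.register_post_accumulate_grad_hook(lambda param, name=name: print(f"grad ready: {name}"))

model(torch.randn(2, 4)).sum().backward()
# Prints roughly in reverse order of the forward pass (layer 1's params before layer 0's).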
Note: Activations are ignored since the peak memory is after bwd (during opt step) when activations are no longer in memory. These are <1GB anyways.
This is dependent on the context length though not sure if future Llama releases will put a dent in this assumption
This looks good, except that the peak memory has now shifted to model setup! This is because we load the weights in fp32 on device and then convert to bf16. Let's change this to load the weights on CPU, convert to bf16, and then move them to device. This increases model init time from 0.5 seconds to 13 seconds, but it's worth taking this small hit for the additional savings.
I don't think this is true for the weights on HuggingFace so there is no need for users to do a bf16 conversion
Caveat: This introduces a new dependency in the code base. We might have a native PyTorch implementation of 8-bit AdamW, but this version from BnB has been thoroughly tested by the community.
We don't have an 8-bit Adam in core; Less has been working on a 4-bit Adam though.
Caveat: IIUC, fusing the optimizer with backward makes it impossible to accumulate gradients to simulate larger batch sizes. But this might be worth the ability to run training (see loss curves below).
Very cool trick! It's worth the tradeoff
the final loss value ~0.44, with a peak memory of ~35GB. With a little more tuning of the hyper-params, I believe we can bring the peak memory down.
Could you elaborate on which hyper params you had in mind? This is where I was thinking offloading might be a great insurance
Also, a comment on SGD: it's a good way to sidestep the optimizer state discussion, but many customers will not find it a satisfying answer, so IMO we should supply both an efficient Adam version and SGD for the most memory-constrained users.
Great writeup!
A couple of questions/comments:
dtype=bf16
Why does this do AMP bf16 instead of bf16?
Digging through the AdamW documentation, the extra memory seems to be related to the foreach flag which is a performance lever enabled by default:
Interesting, I wonder if it makes sense to make this more granular, i.e. being able to set some max parameter size to trade off memory and performance more directly.
@gchanan Thank you!
Why does this do AMP bf16 instead of bf16?
Do you mind elaborating? I didn't fully follow.
Interesting, I wonder if it makes sense to make this more granular, i.e. being able to set some max parameter size to trade off memory and performance more directly
That would be really helpful. I haven't quantified the perf hit this is causing, but I'd imagine it would be large if this optimization was enabled by default?
Hi @kartikayk! I was curious to learn more about the activation memory for the workload.
- What was the sequence length?
- Is this representative of other fine-tuning workloads, or, given less memory pressure, would users want to increase batch size or sequence length?
@awgu good questions!
What was the sequence length?
Let me go and look at the alpaca stats for this, but the sequence length is definitely not close to the max (4096). We are working on data utilities like sample packing where we can pack multiple instances to make training faster, but again the use of this will depend on the available memory.
Is this representative of other fine-tuning workloads
This is definitely not representative. In an ideal scenario, we would increase batch size (using gradient accumulation if we don't have enough memory). But in low-memory regimes, we'll have to figure out how we can enable training without compromising model quality too much (for some definition of too much). This will require plenty of experimentation. Any ideas are always welcome!
@kartikayk thanks for going through this route! Was a clear read and the memory diagrams were helpful.
Others have mentioned some of these, and I wanted to push further on what your next steps are:
a. The activations being so small does not seem to be the typical use case, so I'm curious what the "recommendation" would be if the peak were in the forward with activations. For example, let's say after applying all the techniques you mentioned, you can now increase bs=2, and the new peak is in activations. Do you foresee a followup journey with LoRA/activation checkpointing?
b. I'm curious how the configs are set for torchtune--it looks like you're running the model on one line with a bunch of CLI args--is that the industry standard interface for tweaking things like optimizer/qlora rank/etc?
c. If optim-in-bwd prohibits grad accumulation, which then hinders training, exploring low-bit optimizers may be the better alternative.
Regarding Greg's comment about being able to scale down the "mysterious memory" bump due to foreach, I had discussed some API options with Rohan that would use our param_groups to bucket params.
@janeyx99 Thank you for taking a look at this! Great questions!
To recap some offline discussions, the intent of this RFC was a few things:
- Figure out the theoretical limit for full-finetuning a Llama7B on single device and in the process have a deeper understanding of the memory requirements for this setup. This will inform default configs that we expose in TorchTune, recommendations we make to users etc
- Figure out the importance and priority for taking on BnB as a dependency for some of the optimizers which are popular within the community (8-bit AdamW for example)
With this context, a few thoughts for your questions.
The activations being so small does not seem to be the typical use case, so I'm curious what the "recommendation" would be if the peak were in the forward with activations
I'd decouple the recommendations that cut across fine-tuning methods from recommendations specific to full finetuning. For single-device fine-tuning on memory-constrained hardware, QLoRA and LoRA would definitely be the recommendation, depending on a variety of yet-to-be-determined factors such as the memory footprint of QLoRA and LoRA on a single device, the training-time implications of the Q in QLoRA, performance considerations on different tasks/datasets, etc. All of this is WIP and I'll have more to say in ~2.5 weeks.
There are scenarios, though, where updating all of the params of the model is going to be important, and so we need to figure out the right combination of techniques we should be using and how to convey the memory-quality-performance trade-off to our users. With this in mind, I'd definitely prioritize methods which enable larger batch sizes (e.g. gradient accumulation, low-precision optimizers and activation checkpointing). I'm also unsure about the complexities of some of the other solutions. For example, I'm unsure about how optimizer-state checkpointing will work for opt-in-bwd. I know there are some discussions around using distributed checkpointing for this, but that seems unintuitive for a single-device scenario.
I'm curious how the configs are set for torchtune
We actually have an ongoing discussion around the config system [details] that can be helpful to read through. Generally, we expect a majority of our hobbyist users to be using TorchTune through our config system. I used CLI args only because I didn't pipe everything through for this prototype. You can also find the configs we currently expose to users here. Though these need some updating.
If optim-in-bwd prohibits grad accumulation which then hinders training, exploring lowbit optimizers may be the better alternative.
This sounds great to me. I'd love to chat more about the work happening in this direction. Will setup some time for an in-person discussion (a few things got in the way of setting this up early in the week!).
Let me know if this meaningfully answers your questions! Happy to share more information and/or update my own mental model based on your suggestions.
This is a great writeup, thanks a lot! I'm curious whether the fusing of optimizer with backward pass runs into any issues with FSDP.
@kiddyboots216 Theoretically FSDP should be composable with optim-in-bwd. This is not currently supported, but it is something we're working on with per-param FSDP - cc @awgu @weifengpy for more details.
@kartikayk are there any pending discussions on this RFC, or can we close this?