LoRA with quantization: effect of `micro_batch_size` on memory footprint
> [!IMPORTANT]
> These are just quick tests with a single model on a single graphics card, so take them with a grain of salt. Nevertheless, the issue is worth discussing in my opinion.
Hi there 👋
I'm still not confident that quantization works properly, so I decided to run quick tests with a small model just to see how much memory we can save with and without quantization.
During the experiments I noticed somewhat odd behavior: with a smaller `micro_batch_size` (1 or 2) the savings are larger than with the default `micro_batch_size` of 4. My tentative explanation is that the savings from quantizing the weights are a fixed constant, while activation memory grows with the batch size and eventually dominates the footprint (though I'm still doubtful). What I cannot explain is that with a `micro_batch_size` larger than the default of 4, the memory footprint can even exceed the footprint without quantization.
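To make that intuition concrete, here is a rough back-of-envelope sketch. The constants below are approximate Pythia-70m values and the activation term is deliberately simplified (it ignores optimizer states, gradients, and attention buffers), so it will not reproduce the table's absolute numbers, only the trend:

```python
# Rough back-of-envelope estimate: fixed weight savings vs. activations
# that grow linearly with the micro batch size. All constants are
# approximations, not measurements.
N_PARAMS = 70e6                          # Pythia-70m parameter count
SEQ_LEN, HIDDEN, LAYERS = 2048, 512, 6   # approximate Pythia-70m config

def estimated_gb(micro_batch_size: int, quantized: bool) -> float:
    weight_bytes = N_PARAMS * (0.5 if quantized else 2.0)  # nf4 (~4 bit) vs fp16
    # Simplified: one fp16 activation tensor per layer, linear in batch size
    activation_bytes = micro_batch_size * SEQ_LEN * HIDDEN * LAYERS * 2.0
    return (weight_bytes + activation_bytes) / 1e9

for mbs in (1, 4, 16):
    print(mbs, f"{estimated_gb(mbs, quantized=True):.3f} GB",
          f"{estimated_gb(mbs, quantized=False):.3f} GB")
```

This only explains why the *relative* savings shrink as the batch grows; it does not explain the crossover where quantized runs use *more* memory. For that, some per-batch overhead (e.g. temporary dequantization buffers in bitsandbytes) would have to come into play, which is an assumption worth checking in the profiling step below.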
All runs used Pythia-70m, precision `16-mixed`, quantization `bnb.nf4`, and default parameters apart from `micro_batch_size`.
| Micro batch size | CUDA allocated (GB) | CUDA allocated, quantized (GB) | nvidia-smi (GB) | nvidia-smi, quantized (GB) | Card |
|---|---|---|---|---|---|
| 1 | 1.37 | 0.80 | 1.63 | 0.96 | T4 |
| 2 | 1.92 | 1.42 | 2.04 | 1.73 | T4 |
| 4 | 3.04 | 2.69 | 3.93 | 3.26 | T4 |
| 8 | 5.29 | 5.23 | 7.93 | 7.78 | T4 |
| 16 | 9.79 | 10.32 | 11.304 | 10.407 | T4 |
| 32 | 18.77 | 20.49 | 19.217 | 20.62 | A10G |
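For reference, a minimal sketch of how such numbers can be collected, assuming the "CUDA allocated" column corresponds to PyTorch's peak-allocation counter (nvidia-smi reports more because it also counts the CUDA context and the caching allocator's reserved memory):

```python
import torch

# Reset the peak counters before the run, read them afterwards
torch.cuda.reset_peak_memory_stats()
# ... run the LoRA fine-tuning loop here ...
allocated_gb = torch.cuda.max_memory_allocated() / 1e9  # the "CUDA allocated" column
reserved_gb = torch.cuda.max_memory_reserved() / 1e9    # closer to what nvidia-smi shows
print(f"allocated: {allocated_gb:.2f} GB, reserved: {reserved_gb:.2f} GB")
```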
I don't know who will be assigned to this task, so here is a list of steps that I would take:
- [ ] Sanity checks with different models/precisions/graphics cards
- [ ] Comparison with the Hugging Face implementation of QLoRA
- [ ] Memory profiling with the PyTorch Profiler (see the sketch after this list)
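For the last step, a minimal sketch of what the profiling could look like; `training_step` is a hypothetical stand-in for one LoRA fine-tuning iteration:

```python
import torch
from torch.profiler import profile, ProfilerActivity

def training_step():
    # Hypothetical stand-in for one LoRA fine-tuning iteration
    w = torch.randn(512, 512, device="cuda", requires_grad=True)
    x = torch.randn(4, 2048, 512, device="cuda")
    (x @ w).sum().backward()

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True,
) as prof:
    training_step()

# Sort operators by their own CUDA memory usage to see what dominates the footprint
print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))
```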
If no one thinks this is urgent (or that it's even an issue), I'll work on it after I finish #461; maybe I'll start with the memory profiling.
Interesting, thanks for the analysis @Andrei-Aksionov. It's quite weird that QLoRA becomes worse for large micro batch sizes.
I think this may be related to #477, where a similar problem occurs with longer context sizes. It's worth investigating further, IMHO. We should maybe test a non-Lit-GPT implementation to see whether it's a Lit-GPT-specific issue or a bitsandbytes issue.
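For that comparison, a minimal sketch of loading the same model through Hugging Face `transformers` + `peft` + `bitsandbytes` with nf4; the LoRA hyperparameters here are placeholders, not matched to the Lit-GPT defaults:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# nf4 quantization config, matching the bnb.nf4 setting used above
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",
    quantization_config=bnb_config,
    device_map={"": 0},
)
# "query_key_value" is the fused attention projection in Pythia's GPT-NeoX blocks;
# r and lora_alpha are placeholder values
lora_config = LoraConfig(
    r=8, lora_alpha=16, target_modules=["query_key_value"], task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```

Running the same micro batch sizes through this setup and recording `torch.cuda.max_memory_allocated()` should tell us whether the crossover reproduces outside Lit-GPT.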
I think this link will help: https://github.com/RahulSChand/gpu_poor/issues/1