mlx-examples
Huge memory usage when fine-tuning
When fine-tuning Mistral 7B with 4-bit quantization (QLoRA), I'm seeing huge memory usage (160 GB of VRAM).
Parameters used:

- --batch-size 1
- --lora-layers 16
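For reference, this is roughly the full invocation (model and data paths are placeholders, and the script name and remaining flags follow the lora example's README as I understand it):

```shell
# QLoRA fine-tuning run that triggers the high memory usage
python lora.py --model <path-to-4bit-mistral-7b> \
    --train \
    --data <path-to-data> \
    --batch-size 1 \
    --lora-layers 16
```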
The dataset is composed of around 1200 entries. No entry is longer than 7500 tokens, some are significantly shorter, most are right in the middle.
If I reduce --lora-layers to 4, the memory usage peaks at around 30 GB, but the end result is of very poor quality.
Is this sort of memory usage expected?
Sort of expected, yes:

- The MLX memory cache will hold on to a lot of memory if you have a lot available on your machine.
- The GPU is pretty greedy; it will allocate well ahead of what it actually needs.
So you likely don't actually need 160GB to run with your parameters efficiently.
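If you want to see how much of that 160 GB is real, you can log active vs. cached memory around a training step. A quick sketch, assuming a recent MLX where these helpers live under `mx.metal` (newer releases also expose some of them directly on `mx`):

```python
# Quick check of how much memory is active working memory vs. MLX's buffer
# cache vs. the peak the allocator has ever hit.
import mlx.core as mx

def report_memory(tag: str = "") -> None:
    gb = 1024**3
    print(
        f"[{tag}] active: {mx.metal.get_active_memory() / gb:.1f} GB  "
        f"peak: {mx.metal.get_peak_memory() / gb:.1f} GB  "
        f"cache: {mx.metal.get_cache_memory() / gb:.1f} GB"
    )

# Call report_memory("after step") inside the training loop, and optionally
# cap the buffer cache (in bytes) if you want memory returned to the OS sooner:
# mx.metal.set_cache_limit(8 * 1024**3)
```

If the "active" number is far below 160 GB, most of what you're seeing is cache plus greedy allocation rather than memory the run actually needs.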
A couple of comments to improve memory use, though:
- 7500 tokens is a lot. I would consider breaking the long entries into two or three sub-paragraphs and then using a slightly larger batch size (see the sketch after this list). It will be a lot faster and likely work just as well, since attention scales quadratically with the sequence length.
- We added gradient checkpointing, which in theory reduces the memory requirements by nearly an order of magnitude at the expense of some compute. It isn't helping as much as we expected, so we are still ironing out some kinks with it, but it does help.
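On the first point, a minimal sketch of splitting long entries before writing the training set. The file names and the `{"text": ...}` layout follow the lora example's jsonl format; the whitespace split is only a rough stand-in for counting real tokens with your tokenizer:

```python
# Split long entries into shorter chunks so each training example stays well
# under the original 7500-token length, then write them out as train.jsonl.
import json

def split_entry(text: str, max_words: int = 2500) -> list[str]:
    # ~2500 words is a crude proxy for the token count; use your tokenizer
    # for exact lengths, and prefer splitting at paragraph boundaries.
    words = text.split()
    return [
        " ".join(words[i : i + max_words])
        for i in range(0, len(words), max_words)
    ]

with open("data_raw.jsonl") as f_in, open("train.jsonl", "w") as f_out:
    for line in f_in:
        entry = json.loads(line)
        for chunk in split_entry(entry["text"]):
            f_out.write(json.dumps({"text": chunk}) + "\n")
```

Shorter sequences also let you raise --batch-size a bit, which usually improves throughput without increasing peak memory much.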