
Some improvements to LoRA

Open · awni opened this issue · 3 comments

  • [x] Sort dataset prior to batching for more consistent lengths (see the sketch after this list)
  • [x] Compile non-MoE models
  • [x] Add checkpointing as an option: --grad-checkpoint
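
As a rough illustration of the first item, here is a minimal sketch of length-sorted batching (a hypothetical helper, not the actual mlx_lm code). Sorting by tokenized length means each batch contains similarly sized sequences, so less padding is needed per batch:

```python
# Hypothetical helper: group examples into batches of similar length.
def length_sorted_batches(tokenized, batch_size):
    # tokenized: list of token-id lists
    order = sorted(range(len(tokenized)), key=lambda i: len(tokenized[i]))
    for start in range(0, len(order), batch_size):
        yield [tokenized[i] for i in order[start : start + batch_size]]
```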

Compile Benchmarks

Decent gain from both sorting and compile. We need a longer run to average over for a proper comparison, but you can see that sorting helps overall, as does compile:

Original (no sorting, no compile)

Iter 10: Train loss 2.053, Learning Rate 1.000e-05, It/sec 1.312, Tokens/sec 521.974, Trained Tokens 3978
Iter 20: Train loss 1.450, Learning Rate 1.000e-05, It/sec 1.234, Tokens/sec 494.780, Trained Tokens 7989
Iter 30: Train loss 1.319, Learning Rate 1.000e-05, It/sec 1.196, Tokens/sec 487.183, Trained Tokens 12063
Iter 40: Train loss 1.240, Learning Rate 1.000e-05, It/sec 1.279, Tokens/sec 490.877, Trained Tokens 15900
Iter 50: Train loss 1.172, Learning Rate 1.000e-05, It/sec 1.274, Tokens/sec 501.700, Trained Tokens 19837
Iter 60: Train loss 1.062, Learning Rate 1.000e-05, It/sec 1.284, Tokens/sec 503.399, Trained Tokens 23757

Just Sorting (no compile)

Iter 10: Train loss 2.063, Learning Rate 1.000e-05, It/sec 1.446, Tokens/sec 578.409, Trained Tokens 3999, Peak mem 17.383 GB
Iter 20: Train loss 1.639, Learning Rate 1.000e-05, It/sec 1.325, Tokens/sec 534.407, Trained Tokens 8032, Peak mem 17.383 GB
Iter 30: Train loss 1.364, Learning Rate 1.000e-05, It/sec 1.374, Tokens/sec 528.128, Trained Tokens 11877, Peak mem 17.383 GB
Iter 40: Train loss 1.253, Learning Rate 1.000e-05, It/sec 1.391, Tokens/sec 529.224, Trained Tokens 15682, Peak mem 17.383 GB
Iter 50: Train loss 1.087, Learning Rate 1.000e-05, It/sec 1.425, Tokens/sec 527.453, Trained Tokens 19383, Peak mem 17.383 GB
Iter 60: Train loss 1.165, Learning Rate 1.000e-05, It/sec 1.354, Tokens/sec 514.346, Trained Tokens 23181, Peak mem 17.383 GB

With Compile + Sorting

Iter 10: Train loss 2.061, Learning Rate 1.000e-05, It/sec 1.462, Tokens/sec 584.663, Trained Tokens 3999, Peak mem 17.408 GB
Iter 20: Train loss 1.651, Learning Rate 1.000e-05, It/sec 1.366, Tokens/sec 551.017, Trained Tokens 8032, Peak mem 17.423 GB
Iter 30: Train loss 1.368, Learning Rate 1.000e-05, It/sec 1.403, Tokens/sec 539.275, Trained Tokens 11877, Peak mem 17.423 GB
Iter 40: Train loss 1.257, Learning Rate 1.000e-05, It/sec 1.398, Tokens/sec 531.795, Trained Tokens 15682, Peak mem 17.423 GB
Iter 50: Train loss 1.089, Learning Rate 1.000e-05, It/sec 1.488, Tokens/sec 550.811, Trained Tokens 19383, Peak mem 17.423 GB
Iter 60: Train loss 1.168, Learning Rate 1.000e-05, It/sec 1.410, Tokens/sec 535.432, Trained Tokens 23181, Peak mem 17.423 GB
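
For context, compiling the training step follows MLX's usual pattern of passing the model and optimizer state to mx.compile. A self-contained sketch with a toy model and loss (my own stand-ins, not the mlx_lm trainer):

```python
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Toy stand-in for the LoRA-wrapped transformer.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = optim.Adam(learning_rate=1e-5)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y, reduction="mean")

loss_and_grad = nn.value_and_grad(model, loss_fn)

# Compile needs to track the model and optimizer state that the step mutates.
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)
    return loss

x, y = mx.random.normal((8, 16)), mx.random.normal((8, 1))
loss = step(x, y)
mx.eval(state)  # force the lazy computation
```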

Checkpoint Benchmarks

TL;DR: checkpointing reduces memory nicely with a large batch size and many LoRA layers; it is especially noticeable with QLoRA, where the model itself occupies less memory.
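
As a minimal sketch of the idea behind --grad-checkpoint, assuming mx.checkpoint is the rematerialization transform (the actual option wraps each transformer layer rather than a toy block like this):

```python
import mlx.core as mx
import mlx.nn as nn

def block(x, w):
    # stand-in for one transformer layer's forward pass
    return nn.gelu(x @ w)

# With mx.checkpoint the block's intermediates are not stored for the
# backward pass; they are recomputed when the gradient is taken,
# trading extra compute for lower peak memory.
ckpt_block = mx.checkpoint(block)

x = mx.random.normal((4, 64))
w = mx.random.normal((64, 64))

grad_w = mx.grad(lambda w: ckpt_block(x, w).sum())(w)
mx.eval(grad_w)
```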

Regular LoRA with the command:

python -m mlx_lm.lora --model mistralai/Mistral-7B-v0.1 --train --data ../lora/data --lora-layers 32 --grad-checkpoint --batch-size 8
Peak memory:
  No checkpoint: 32.078 GB
  Checkpoint:    19.795 GB

QLoRA with the command:

python -m mlx_lm.lora --model mlx-community/NeuralBeagle14-7B-4bit-mlx --train --data ../lora/data --lora-layers 32 --batch-size 8
Peak memory:
  No checkpoint: 20.695 GB
  Checkpoint:     8.199 GB

awni · Mar 04 '24

I will wait for this to land and then adopt it here https://github.com/ml-explore/mlx/pull/788

awni · Mar 05 '24

When fine-tuning, I only added printing of the peak memory (mx.metal.get_peak_memory() / 2**30) without making any other changes. Everything was normal at the beginning of the run, but after running for a while the values became clearly anomalous (14.809 GB -> 17179869184.000 GB; the latter is exactly 2^64 bytes expressed in GB). The printed output is as follows:

...
Iter 160: Train loss 1.165, Learning Rate 1.000e-05, It/sec 0.279, Tokens/sec 267.007, Trained Tokens 152357, peak_memory 14.809 GB
Iter 170: Train loss 1.123, Learning Rate 1.000e-05, It/sec 0.291, Tokens/sec 270.279, Trained Tokens 161649, peak_memory 14.809 GB
Iter 180: Train loss 1.077, Learning Rate 1.000e-05, It/sec 0.239, Tokens/sec 234.937, Trained Tokens 171491, peak_memory 14.809 GB
Iter 190: Train loss 1.042, Learning Rate 1.000e-05, It/sec 0.291, Tokens/sec 272.583, Trained Tokens 180845, peak_memory 17179869183.987 GB
Iter 200: Train loss 1.058, Learning Rate 1.000e-05, It/sec 0.275, Tokens/sec 267.219, Trained Tokens 190570, peak_memory 17179869183.997 GB
Iter 200: Val loss 1.017, Val took 60.756s
Iter 200: Saved adapter weights to checkpoints/200_adapters.npz.
Iter 210: Train loss 1.002, Learning Rate 1.000e-05, It/sec 0.252, Tokens/sec 242.204, Trained Tokens 200185, peak_memory 17179869184.000 GB
Iter 220: Train loss 1.012, Learning Rate 1.000e-05, It/sec 0.253, Tokens/sec 249.107, Trained Tokens 210023, peak_memory 17179869184.000 GB
Iter 230: Train loss 0.994, Learning Rate 1.000e-05, It/sec 0.255, Tokens/sec 248.113, Trained Tokens 219751, peak_memory 17179869184.000 GB
...

I'm not sure whether this is caused by my not having adopted all of the code changes from the current PR. If that's the case, this issue can be ignored.
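
For reference, the logging being described amounts to something like this (a sketch of the idea, not the exact training-loop code):

```python
import mlx.core as mx

# printed once per reporting interval inside the training loop
peak_gb = mx.metal.get_peak_memory() / 2**30
print(f"peak_memory {peak_gb:.3f} GB")
```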

madroidmaq · Mar 06 '24

@madroidmaq I think that is explained by a race condition we recently fixed. It should already be fixed on main in MLX.

awni · Mar 06 '24