
LoRA training really slow

gladjoyhub opened this issue 1 year ago · 9 comments

Training runs at around 1 token per second, and each evaluation takes 30 minutes. There are long lines (around 2000 words) in my custom train.jsonl. Please help, give me a clue. Thanks!

log:

Iter 1: Val loss 2.612, Val took 1822.826s
Iter 10: Train loss 3.608, It/sec 0.015, Tokens/sec 1.425
Iter 20: Train loss 3.508, It/sec 0.013, Tokens/sec 1.522
Iter 30: Train loss 3.524, It/sec 0.014, Tokens/sec 1.024
Iter 40: Train loss 3.352, It/sec 0.014, Tokens/sec 1.383
Iter 50: Train loss 3.298, It/sec 0.014, Tokens/sec 1.062
Iter 60: Train loss 3.141, It/sec 0.013, Tokens/sec 1.233
Iter 70: Train loss 2.889, It/sec 0.013, Tokens/sec 1.389
Iter 80: Train loss 2.708, It/sec 0.013, Tokens/sec 0.986
Iter 90: Train loss 2.332, It/sec 0.013, Tokens/sec 1.236
Iter 100: Train loss 2.168, It/sec 0.013, Tokens/sec 1.405
Iter 110: Train loss 2.099, It/sec 0.013, Tokens/sec 1.363
Iter 120: Train loss 1.986, It/sec 0.013, Tokens/sec 3.075
Iter 130: Train loss 1.606, It/sec 0.013, Tokens/sec 1.458
Iter 140: Train loss 1.749, It/sec 0.013, Tokens/sec 1.634
Iter 150: Train loss 1.804, It/sec 0.012, Tokens/sec 2.393
Iter 160: Train loss 1.583, It/sec 0.013, Tokens/sec 1.546
Iter 170: Train loss 1.434, It/sec 0.013, Tokens/sec 1.145
Iter 180: Train loss 1.559, It/sec 0.013, Tokens/sec 1.303
Iter 190: Train loss 1.774, It/sec 0.013, Tokens/sec 1.381
Iter 200: Train loss 1.555, It/sec 0.013, Tokens/sec 1.169

gladjoyhub avatar Dec 31 '23 23:12 gladjoyhub

It’s probably using too much memory. Read the section in the readme on how to reduce memory use.

If it's still super slow, your machine may not have enough memory, and you could try a smaller model.
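For example, something along these lines, using lora.py's batch size and LoRA layer flags (the model path here is a placeholder):

python3 lora.py --model <path-to-model> --train --batch-size 1 --lora-layers 4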

awni avatar Jan 01 '24 00:01 awni

I've trained LoRA on an M2 Pro with 32GB. Note the batch size to keep it from running out of memory.

python3 lora.py --model converted_model2mlx --train --iters 1000 --batch-size 2

Iter 410: Train loss 1.090, It/sec 0.035, Tokens/sec 6.726
Iter 420: Train loss 1.194, It/sec 0.034, Tokens/sec 6.420
Iter 430: Train loss 1.175, It/sec 0.034, Tokens/sec 6.624
Iter 440: Train loss 1.018, It/sec 0.035, Tokens/sec 7.203
Iter 450: Train loss 1.084, It/sec 0.033, Tokens/sec 6.521

bigsnarfdude avatar Jan 01 '24 00:01 bigsnarfdude

Now I am using the default train.jsonl and speeds are around 200 tokens/sec, which is good, but at Iter 130 the speed drops to 1.4 tokens/sec. asitop reports RAM usage: 13.5/32.0GB, swap: 22.3/24.0GB, so memory should be enough. What's the cause? Please help! Thanks!

(venv-metal) joy@Juns-Mac-Studio lora % python3 lora.py --model /Users/joy/Downloads/mistral-7B-v0.1 --adapter-file /Users/joy/Downloads/mistral_lora_mlx/adapter_model.npz --train --iters 600 --batch-size 1 --lora-layers 4

Loading pretrained model
Total parameters 7242.158M
Trainable parameters 0.426M
Loading datasets
Training
Iter 1: Val loss 2.265, Val took 11.405s
Iter 10: Train loss 2.453, It/sec 2.397, Tokens/sec 209.030
Iter 20: Train loss 2.220, It/sec 1.963, Tokens/sec 196.531
Iter 30: Train loss 2.086, It/sec 1.867, Tokens/sec 198.981
Iter 40: Train loss 2.236, It/sec 1.892, Tokens/sec 198.473
Iter 50: Train loss 1.982, It/sec 1.902, Tokens/sec 189.601
Iter 60: Train loss 1.986, It/sec 2.120, Tokens/sec 195.705
Iter 70: Train loss 1.936, It/sec 1.885, Tokens/sec 202.311
Iter 80: Train loss 1.747, It/sec 1.918, Tokens/sec 197.142
Iter 90: Train loss 1.777, It/sec 1.630, Tokens/sec 180.719
Iter 100: Train loss 1.854, It/sec 1.943, Tokens/sec 186.681
Iter 110: Train loss 1.811, It/sec 0.250, Tokens/sec 27.399
Iter 120: Train loss 1.546, It/sec 0.240, Tokens/sec 22.059
Iter 130: Train loss 1.717, It/sec 0.015, Tokens/sec 1.405

gladjoyhub avatar Jan 01 '24 23:01 gladjoyhub

Wow, that's odd. Did you happen to do anything on your computer around the time it slowed down? Is it possible the GPU was being used by something else?

Also, you are using a lot of swap memory (22GB!), which is pretty strange, especially given that it looks like you have plenty of DRAM.

Is the slow down reproducible for you?

awni avatar Jan 02 '24 02:01 awni

No, I didn't do anything else, and GPU usage was under 50%. The swap was probably dragging it down, but if there's plenty of DRAM, why would it swap at all?

One suspicion is that I didn't set up an independent conda environment, but used an existing PyCharm virtual env and installed MLX via the PyCharm GUI.

Update: solved. I installed a fresh miniconda env with Python 3.11 (previously 3.9), and now it's consistently around 200 tokens/sec, with swap staying low at 7.5/9GB.
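For anyone else hitting this, the fresh environment was roughly the standard conda/pip steps below; which requirements file you install depends on the example you run (the lora one is assumed here):

conda create -n mlx python=3.11
conda activate mlx
pip install mlx
pip install -r lora/requirements.txt  # from the mlx-examples repo, if running the lora example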

gladjoyhub avatar Jan 02 '24 06:01 gladjoyhub

RAM usage: 27.0/32.0GB, swap: 0.3/2.0GB. I never saw any jumps in swap after the initial model load. (I don't have PyCharm; I only use the terminal and vim.) I can't recreate the behaviour.

Iter 550: Train loss 1.243, It/sec 1.717, Tokens/sec 164.101
Iter 560: Train loss 1.454, It/sec 1.670, Tokens/sec 164.164
Iter 570: Train loss 1.377, It/sec 1.575, Tokens/sec 163.180
Iter 580: Train loss 1.258, It/sec 1.540, Tokens/sec 165.814
Iter 590: Train loss 1.483, It/sec 1.762, Tokens/sec 164.224
Iter 600: Train loss 1.254, It/sec 1.612, Tokens/sec 160.715
Iter 600: Val loss 1.350, Val took 13.449s

bigsnarfdude avatar Jan 02 '24 17:01 bigsnarfdude

Thanks for the input @bigsnarfdude

awni avatar Jan 02 '24 17:01 awni

One observation with custom data: make sure each text is less than 2048 tokens long in the train, test, and valid jsonl files. I wrote a quick program to check and clean up the lines that exceed this.

import lora
import json

# Load the model so we can use the same tokenizer that lora.py uses for training.
model, tokenizer = lora.load_model('./mistral_mlx_q')

def clean(input_file, output_file):
    """Keep only the lines whose text fits comfortably under the 2048-token limit."""
    with open(input_file) as f, open(output_file, 'w') as out:
        for line in f:
            obj = json.loads(line)
            tokens = len(tokenizer.encode(obj["text"], eos=True))
            if tokens < 2000:  # a small margin below 2048
                out.write(line)
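Running it over each split would look something like this (the output file names are just placeholders):

for name in ["train", "valid", "test"]:
    clean(f"{name}.jsonl", f"{name}_clean.jsonl")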

gavi avatar Jan 07 '24 03:01 gavi

@gavi that is a very good point, and thanks for the script. One could even decrease the maximum length to 1024 or 512 tokens for even more memory savings, depending on the use case.

awni avatar Jan 07 '24 04:01 awni