NaN loss when training Llama-3.2-Vision
Issue
I keep getting NaN loss when training Llama-3.2-Vision.
I tried:
- gradient clipping
- lower learning rate
- higher batch size, LoRA rank, and alpha
None of these helped.
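Gradient clipping only rescales gradients after they have been computed, so it cannot repair a step once a non-finite value has appeared. A minimal plain-Python sketch of clipping by global norm illustrates this. The thread does not say which clipping variant was used, and `clip_by_global_norm` is a hypothetical helper operating on a flat list of floats, not an MLX API.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Hypothetical helper (not an MLX API): rescale a flat list of gradient
    values so their global L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm


# Finite gradients are rescaled as expected: norm 5.0 becomes 1.0.
clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm, clipped)

# A single inf makes the norm inf and the scale 0.0, and inf * 0.0 is NaN,
# so clipping turns an overflow into a NaN instead of fixing it.
clipped_bad, norm_bad = clip_by_global_norm([1.0, math.inf], max_norm=1.0)
print(norm_bad, clipped_bad)
```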
Steps to reproduce:
- Install the pc/llama3.2-vision branch:
pip install -U git+https://github.com/Blaizzy/mlx-vlm.git@pc/llama3.2-vision
- Add these two lines (31-32) to lora.py to limit the dataset:
dataset = load_dataset(args.dataset, split=args.split+"[:20%]")
dataset = dataset.rename_columns({"image": "images", "conversations": "messages"})
- Quantize the model (optional):
python -m mlx_vlm.convert --hf-path unsloth/Llama-3.2-11B-Vision-Instruct -q --mlx-path Llama-3.2-11B-Vision-Instruct-4bit
- Start training:
python -m mlx_vlm.lora --model-path Llama-3.2-11B-Vision-Instruct-4bit --dataset 5CD-AI/Viet-ShareGPT-4o-Text-VQA --split Viet_OCR_VQA --steps 100 --learning-rate 5e-6 --lora-rank 16 --lora-alpha 16
cc: @awni
There are a couple of things you should change in your Llama implementation in general:
- Use nn.RMSNorm instead of rolling your own.
- Use nn.RoPE instead of rolling your own.
These will both be (much) faster and numerically more stable. The NaNs are being introduced by overflow in your RMSNorm implementation. Typically, whenever you accumulate a lot of numbers, you need to accumulate the result in higher precision (in your case, the mean). nn.RMSNorm does this implicitly, without the need to cast between mx.float32 and mx.float16.
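For reference, RMSNorm over a hidden dimension of size d computes the expression below, where the mean of squares is the accumulation being discussed and w is the learned weight. The worked numbers are illustrative assumptions (d = 4096, every x_j = 5), not values taken from the model; 65504 is the largest finite float16 value.

```latex
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}}\; w_i

% Illustrative numbers: d = 4096, every x_j = 5
\sum_{j=1}^{4096} x_j^2 = 4096 \cdot 25 = 102400 > 65504 = \text{largest finite float16}
```

Every individual square and the final mean (25) fit easily in float16; only the intermediate sum does not. A float16 accumulator that has to hold the full sum therefore overflows to +inf whatever the summation order, while a float32 accumulator gives a mean of 25 and an RMS of 5.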
I'd double-check that most of your model files use nn.RMSNorm or nn.LayerNorm where possible, and the same for RoPE. Inference especially will be much faster.
Thanks a lot!
Yes, I was using a custom RMSNorm. I changed it to nn.RMSNorm and it's 3.25x faster 🚀.
As for RoPE, I was already using nn.RoPE, since no changes were needed and it's easier to integrate with the cache.
> The NaNs are being introduced by overflow in your RMSNorm implementation. Typically, whenever you accumulate a lot of numbers, you need to accumulate the result in higher precision (in your case, the mean). nn.RMSNorm does this implicitly, without the need to cast between mx.float32 and mx.float16.
How did you check this?
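For illustration only (this is not how the overflow was diagnosed in this thread), the effect can be reproduced without a GPU by emulating float16 rounding in plain Python. The sketch assumes made-up activations and uses the standard struct module's half-precision "e" format to round every intermediate result to float16, then compares that against accumulating in Python's float64. The helper names are hypothetical. Real GPU kernels sum in a different order, but they hit the same float16 range limit.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value, mapping overflow to +/-inf like IEEE."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct raises instead of returning inf for out-of-range values.
        return math.copysign(math.inf, x)


def mean_of_squares_fp16(xs):
    """Naive accumulation: every intermediate result is rounded to float16."""
    acc = 0.0
    for x in xs:
        acc = to_fp16(acc + to_fp16(x * x))
    return to_fp16(acc / len(xs))


def mean_of_squares_wide(xs):
    """Accumulate in higher precision (Python floats are float64)."""
    return sum(x * x for x in xs) / len(xs)


def rms_norm(xs, mean_sq, eps=1e-5):
    """RMSNorm without the learned weight, given a precomputed mean of squares."""
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [x * scale for x in xs]


# Hypothetical activations: each square (25) fits in float16, the sum (102400) does not.
hidden = [5.0] * 4096

naive = mean_of_squares_fp16(hidden)
wide = mean_of_squares_wide(hidden)
print(naive, wide)                           # inf vs 25.0
print(rms_norm(hidden, naive)[0])            # 0.0: the output silently collapses
print(rms_norm(hidden, wide)[0])             # close to 1.0
print(math.inf * 0.0, math.inf - math.inf)   # nan nan: how an inf turns into NaN downstream
```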
@awni I made the recommended changes, but I can't get training to run on my machine (M3 Max, 96GB).
It crashes with a segmentation fault after processing 3 samples, even with a batch size of 1:
{'Epoch': 0, 'Step': 0, 'Loss': '1.5820'}
3%|█▍ | 3/100 [00:09<05:03, 3.13s/it, Epoch=0, Step=0, Loss=1.5820]
zsh: segmentation fault python -m mlx_vlm.lora --model-path Llama-3.2-11B-Vision-Instruct-4bit
/opt/homebrew/Caskroom/miniconda/base/envs/mlx_code/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
Could you please try it on your M2 Ultra and see if the NaN loss persists?
It's running on my M1 Max (32GB) with this command:
python -m mlx_vlm.lora --model-path Llama-3.2-11B-Vision-Instruct-4bit --dataset 5CD-AI/Viet-ShareGPT-4o-Text-VQA --split Viet_OCR_VQA --steps 100 --learning-rate 5e-6 --lora-rank 16 --lora-alpha 16
and the modifications to the dataset you posted above. So far it has processed 11 steps without problems (I modified the logging to print every step):
{'Epoch': 0, 'Step': 0, 'Loss': '1.5796'}
{'Epoch': 0, 'Step': 1, 'Loss': '1.8235'}
{'Epoch': 0, 'Step': 2, 'Loss': '1.9262'}
{'Epoch': 0, 'Step': 3, 'Loss': '1.5627'}
{'Epoch': 0, 'Step': 4, 'Loss': '1.5274'}
{'Epoch': 0, 'Step': 5, 'Loss': '1.7451'}
{'Epoch': 0, 'Step': 6, 'Loss': '1.9609'}
{'Epoch': 0, 'Step': 7, 'Loss': '0.9124'}
{'Epoch': 0, 'Step': 8, 'Loss': '1.7157'}
{'Epoch': 0, 'Step': 9, 'Loss': '1.6776'}
{'Epoch': 0, 'Step': 10, 'Loss': '1.8323'}
{'Epoch': 0, 'Step': 11, 'Loss': '1.4830'}
However, you should not be getting a segfault. That isn't good. Which version of MLX are you running? Is anything else different in your setup?
Also, I notice the GPU utilization is pretty poor, which is also not good. It should be close to 100% during training, so there must be a bottleneck somewhere that needs fixing.
Thanks!
Wow, that's really weird.
Here is my setup:
prince_canuma@MacBook-Pro-3 ~ % pip list | grep mlx
fastmlx 0.2.1
mlx 0.18.0
mlx-embeddings 0.0.1 /Users/prince_canuma/Documents/Projects/LLMs/mlx-embeddings
mlx-lm 0.19.0
mlx-vlm 0.1.0 /Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm
> Also, I notice the GPU utilization is pretty poor, which is also not good. It should be close to 100% during training, so there must be a bottleneck somewhere that needs fixing.
I suspect the dataset loading function. I know it's not the best, but I considered that an optimization for the next release; this one already took long enough.
https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/trainer/trainer.py#L58
Could you try upgrading to the latest MLX (0.18.1), and MLX LM (0.19.1) if it's used here, just to make sure we didn't already fix something? I think this PR may be related: https://github.com/ml-explore/mlx/pull/1452
Also, remind me: what's your machine and OS?
> I suspect the dataset loading function. I know it's not the best, but I considered that an optimization for the next release; this one already took long enough.
Data loading is often the issue. And yes, the next release is quite reasonable. I'm just letting you know in case you didn't notice it.
Upgrading to v0.18.1 fixed it! 🚀
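Since moving to MLX 0.18.1 resolved the segfault here, a training script could check the installed version up front. A minimal standard-library sketch, assuming plain numeric version strings such as "0.18.0"; `mlx_is_recent_enough` and `MIN_MLX` are hypothetical names, and the 0.18.1 threshold is simply the version reported in this thread.

```python
from importlib.metadata import PackageNotFoundError, version

# Threshold taken from this thread: MLX 0.18.1 fixed the segfault (hypothetical constant name).
MIN_MLX = (0, 18, 1)


def parse_version(text):
    """'0.18.1' -> (0, 18, 1). Assumes plain numeric components (no 'rc', 'dev', ...)."""
    return tuple(int(part) for part in text.split(".")[:3])


def mlx_is_recent_enough(installed=None, minimum=MIN_MLX):
    """True/False for the given or installed mlx version; None if mlx is not installed."""
    if installed is None:
        try:
            installed = version("mlx")
        except PackageNotFoundError:
            return None
    return parse_version(installed) >= minimum


print(mlx_is_recent_enough("0.18.0"))  # False
print(mlx_is_recent_enough("0.18.1"))  # True
print(mlx_is_recent_enough())          # None when mlx is not installed
```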
> Data loading is often the issue. And yes, the next release is quite reasonable. I'm just letting you know in case you didn't notice it.
Thank you! Do you have any tips specific to MLX?
When I started getting the error, I figured it could be the data loading, so I made some rough initial optimizations, like using a generator, deleting the batch after processing, and using the Metal clear cache command.
> Also, remind me: what's your machine and OS?
MacBook Pro 14-inch
Chip: M3 Max
URAM: 96GB
OS: Sonoma 14.5
> Thank you! Do you have any tips specific to MLX?
First, verify that data loading is in fact the issue. I would do that by using the same batch over and over instead of loading a new one each step, and making sure the GPU utilization is close to 100%.
If data loading is the problem, then look into what's actually slow. Is it the IO itself or some preprocessing step?
- If you preload the dataset into RAM, it probably isn't the IO.
- Do you do the preprocessing in MLX? If not, maybe try doing that so it runs fast on the GPU.
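The check suggested above (reuse one batch so that loading drops out of the measurement) can be sketched as a small timing harness. This is a plain-Python sketch under assumptions: `load_batch` and `train_step` are hypothetical stand-ins for the real loader and training step, and with MLX the step would also need to force evaluation (for example with `mx.eval`) because computation is lazy. Watching GPU utilization in a system monitor remains the more direct check.

```python
import itertools
import time


def time_steps(train_step, batches, steps):
    """Average wall-clock seconds per call of train_step over `steps` batches."""
    it = iter(batches)
    start = time.perf_counter()
    for _ in range(steps):
        train_step(next(it))
    return (time.perf_counter() - start) / steps


def loading_overhead(load_batch, train_step, steps=10):
    """Return (seconds/step loading fresh batches, seconds/step reusing one batch)."""
    fresh = (load_batch() for _ in itertools.count())  # loaded inside the timed loop
    fixed = itertools.repeat(load_batch())              # loaded once, before timing
    return time_steps(train_step, fresh, steps), time_steps(train_step, fixed, steps)


# Hypothetical stand-ins: replace with the real loader and (evaluated) training step.
def fake_load_batch():
    time.sleep(0.02)  # pretend loading + preprocessing takes 20 ms
    return [0.0] * 8


def fake_train_step(batch):
    time.sleep(0.01)  # pretend the training step takes 10 ms
    return sum(batch)


with_loading, without_loading = loading_overhead(fake_load_batch, fake_train_step, steps=5)
print(f"fresh batches: {with_loading * 1e3:.1f} ms/step, same batch: {without_loading * 1e3:.1f} ms/step")
```

If the two numbers are close, loading is not the bottleneck; if the reused-batch number is much lower, the time is going into loading or preprocessing.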
> and using the Metal clear cache command.
I wouldn't manually clear the cache unless you have a really good reason. That will typically just slow everything down.
Awesome, thanks!
> If you preload the dataset into RAM, it probably isn't the IO.
> Do you do the preprocessing in MLX? If not, maybe try doing that so it runs fast on the GPU.
I preload/prefetch the batch before running it.
Then the bottleneck is probably the HF processor I use here to prepare the inputs.
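If the HF processor is the bottleneck, one general way to hide its cost (an illustration, not what mlx-vlm does) is to prepare upcoming batches in a background thread while the current step runs. A standard-library sketch; `prefetch`, `fake_preprocess`, and the queue depth are hypothetical, and whether a thread actually helps depends on how much of the preprocessing releases Python's GIL, so it is worth measuring first.

```python
import queue
import threading


def prefetch(items, preprocess, depth=2):
    """Hypothetical helper: yield preprocess(item) for each item in order.

    Preprocessing runs in a daemon thread; finished results wait in a bounded
    queue (at most `depth` of them) so memory use stays limited.
    """
    q = queue.Queue(maxsize=depth)

    def worker():
        try:
            for item in items:
                q.put(("batch", preprocess(item)))  # blocks while the queue is full
            q.put(("done", None))
        except Exception as exc:  # forward errors to the consumer
            q.put(("error", exc))

    # Daemon thread: if the consumer stops early, it will not keep the process alive.
    threading.Thread(target=worker, daemon=True).start()
    while True:
        kind, payload = q.get()
        if kind == "done":
            return
        if kind == "error":
            raise payload
        yield payload


# Usage with a stand-in for the real processor call (hypothetical).
def fake_preprocess(x):
    return x * 2


print(list(prefetch(range(5), fake_preprocess)))  # [0, 2, 4, 6, 8]
```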
> I would do that by using the same batch over and over instead of loading a new one each step, and making sure the GPU utilization is close to 100%.
Could you elaborate? I didn't quite get it.