bitsandbytes

Fine-tuning NLP models with int8 optimizers: is stable embedding needed?

lessw2020 opened this issue 2 years ago · 1 comment

Thanks for the great work on optimizer quantization! I'm trying to fine-tune a T5 model with 8-bit Adam, but the validation loss is significantly worse (about 10x) than with BFloat16 or FP32 optimizer states. Training is stable in the sense that the loss improves steadily, but it starts so far behind that this isn't practical.
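For readers less familiar with what "8-bit optimizer states" means: the optimizer's state tensors are stored as 8-bit codes plus scaling factors instead of 32-bit floats. The sketch below is a simplified illustration of block-wise quantization, assuming linear absmax scaling to int8 codes in [-127, 127] and a tiny block size of 4. bitsandbytes itself uses block-wise quantization with a non-linear dynamic quantization map, so this is not the library's implementation, only the core trade-off: a large value in a block forces small values in the same block to round to zero.

```python
def quantize_blockwise(values, block_size=4):
    """Linear absmax int8 quantization with one scale per block (simplified)."""
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # guard against all-zero blocks
        scales.append(absmax)
        # Map each value linearly to an integer code in [-127, 127].
        codes.extend(round(v / absmax * 127) for v in block)
    return codes, scales


def dequantize_blockwise(codes, scales, block_size=4):
    """Map int8 codes back to floats using each block's stored scale."""
    return [c / 127 * scales[i // block_size] for i, c in enumerate(codes)]


# Hypothetical optimizer-state values: the first block contains one large outlier.
state = [1e-4, 2e-4, -3e-4, 5e-1,
         1e-4, 2e-4, -3e-4, 4e-4]
codes, scales = quantize_blockwise(state)
recovered = dequantize_blockwise(codes, scales)
for original, approx in zip(state, recovered):
    print(f"{original: .1e} -> {approx: .3e}")
```

Because the scale is per block, the outlier only costs precision inside its own block: the small values next to it are quantized to zero, while the second block, which has no outlier, keeps every value within 1% of the original.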

I'm wondering whether we therefore need to use the stable embedding layer for T5 fine-tuning. If so, how do we add it without hurting the already trained embeddings?

Or is the stable embedding designed solely for the train-from-scratch scenario, with this high loss caused by other factors (e.g., T5 having been trained in BFloat16 rather than FP32)?
Thanks for any insights!

lessw2020, Aug 16 '22 16:08
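The stable embedding layer in bitsandbytes (`bnb.nn.StableEmbedding`) combines Xavier uniform initialization, layer normalization of the looked-up vectors, and 32-bit optimizer states for the embedding weights. The plain-Python sketch below reproduces only the layer-norm step, assuming no learnable gain or bias, `eps=1e-5`, and a tiny hypothetical "pretrained" table. It illustrates the concern in the question: even if the pretrained table is loaded unchanged, the added normalization changes the vectors that the rest of the already trained network receives.

```python
import math


def layer_norm(vec, eps=1e-5):
    """Normalize one vector to zero mean and (near) unit variance; no gain/bias."""
    n = len(vec)
    mean = sum(vec) / n
    var = sum((x - mean) ** 2 for x in vec) / n
    return [(x - mean) / math.sqrt(var + eps) for x in vec]


def stable_embedding_lookup(table, token_ids, eps=1e-5):
    """Look up rows of an embedding table, then layer-normalize each row."""
    return [layer_norm(table[t], eps) for t in token_ids]


# Hypothetical pretrained table with small-magnitude rows.
pretrained = [
    [0.02, -0.01, 0.03, 0.00],
    [0.10, 0.12, 0.08, 0.11],
]
token_ids = (0, 1)
plain = [pretrained[t] for t in token_ids]               # what the trained model saw
normed = stable_embedding_lookup(pretrained, token_ids)  # what it would see with the norm
print(plain[0])
print([round(x, 3) for x in normed[0]])
```

Of the three components, only the 32-bit optimizer state leaves the forward pass untouched; the normalization and the re-initialization both change what a pretrained model computes.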

Training a large Vision Transformer (500M parameters) from scratch worked, so this issue seems specific to NLP embeddings. As a fix for fine-tuning, I'll try isolating the embedding layer and keeping it in BF16.

lessw2020, Aug 19 '22 03:08
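Isolating the embedding layer starts with splitting the model's parameters into two groups that can then be kept at different precisions or handed to different optimizers. A minimal plain-Python sketch of that partitioning step, assuming hypothetical parameter names loosely modeled on Hugging Face T5 naming; real names differ from model to model, so the predicate has to be adapted for each one.

```python
from typing import Callable, List, Tuple, TypeVar

P = TypeVar("P")


def split_param_groups(
    named_params: List[Tuple[str, P]],
    is_embedding: Callable[[str], bool],
) -> Tuple[List[Tuple[str, P]], List[Tuple[str, P]]]:
    """Partition (name, param) pairs into embedding and non-embedding groups."""
    embedding, rest = [], []
    for name, param in named_params:
        (embedding if is_embedding(name) else rest).append((name, param))
    return embedding, rest


# Hypothetical names and placeholder params; in PyTorch these pairs would
# come from model.named_parameters().
named = [
    ("shared.weight", "E"),
    ("encoder.block.0.layer.0.SelfAttention.q.weight", "Wq"),
    ("decoder.final_layer_norm.weight", "g"),
]
embedding, rest = split_param_groups(named, lambda n: "shared" in n or "embed" in n)
print([n for n, _ in embedding])
print([n for n, _ in rest])
```

In PyTorch, the two groups could then go to separate optimizers, for example a standard optimizer for the embedding and an 8-bit one for everything else. bitsandbytes also documents a `GlobalOptimManager` override that keeps selected parameters' optimizer states in 32-bit, which is another way to single out the embedding.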

Quick update: I see very similar behaviour on T5 when running with BF16 and stochastic rounding, so the embeddings of an already trained T5 seem to be extremely sensitive. Things do work fine with BF16 plus Kahan summation, but that gives relatively much higher precision. I'm closing this for now and focusing on BF16 with Kahan summation, since I think extracting the embeddings for every different NLP model would be too much of a special case.

lessw2020, Aug 21 '22 03:08
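To make the difference between these update rules concrete, here is a plain-Python sketch that emulates BF16 by keeping the top 16 bits of a float32. BF16 keeps 8 significant bits, so representable values near 1.0 are 2^-7 ≈ 0.0078 apart, and an update much smaller than that can simply vanish. The sketch applies 1,000 updates of 1e-4 to a weight of 1.0 with three rules: plain round-to-nearest, stochastic rounding, and Kahan-compensated summation. The step size, step count, and the choice to keep the compensation buffer itself in BF16 are assumptions for illustration; this is not the optimizer code used in this thread.

```python
import random
import struct


def f32_bits(x: float) -> int:
    """Bit pattern of x rounded to float32, as an unsigned 32-bit int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(b: int) -> float:
    return struct.unpack("<f", struct.pack("<I", b))[0]


def to_bf16_nearest(x: float) -> float:
    """Round to BF16 (top 16 bits of float32) with round-to-nearest-even.
    Finite, in-range values only; NaN/inf are not handled."""
    b = f32_bits(x)
    b += 0x7FFF + ((b >> 16) & 1)
    return bits_to_f32(((b >> 16) << 16) & 0xFFFFFFFF)


def to_bf16_stochastic(x: float, rng: random.Random) -> float:
    """Round to BF16, rounding up with probability equal to the dropped fraction."""
    b = f32_bits(x) + rng.getrandbits(16)
    return bits_to_f32(((b >> 16) << 16) & 0xFFFFFFFF)


def run(w0: float, update: float, steps: int, seed: int = 0):
    rng = random.Random(seed)
    w_nearest = w_sr = w_kahan = to_bf16_nearest(w0)
    comp = 0.0  # Kahan compensation buffer (assumed BF16 here)
    for _ in range(steps):
        # 1) Plain BF16: the small update is rounded away every step.
        w_nearest = to_bf16_nearest(w_nearest + update)
        # 2) Stochastic rounding: correct only in expectation.
        w_sr = to_bf16_stochastic(w_sr + update, rng)
        # 3) Kahan: accumulate the pending update, apply what BF16 can hold,
        #    and carry the rounded-off remainder into the next step.
        comp = to_bf16_nearest(comp + update)
        new_w = to_bf16_nearest(w_kahan + comp)
        comp = to_bf16_nearest(comp + (w_kahan - new_w))
        w_kahan = new_w
    return w_nearest, w_sr, w_kahan, comp


w_nearest, w_sr, w_kahan, comp = run(1.0, 1e-4, 1000)
print("exact        :", 1.0 + 1000 * 1e-4)
print("round-nearest:", w_nearest)  # stays at 1.0
print("stochastic   :", w_sr)
print("kahan        :", w_kahan, "(pending:", comp, ")")
```

Round-to-nearest never moves the weight, because each 1e-4 step is less than half the BF16 spacing. Stochastic rounding lands near 1.1 only on average, so any single run is noisy, while the Kahan buffer carries the rounded-off remainder forward and keeps `w_kahan + comp` close to the exact sum. This is a toy illustration of the precision gap, not a reproduction of the T5 result.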