paxml

JAX + TPU: abnormal loss when training a model with AQT int8

Open • Lisennlp opened this issue 11 months ago • 0 comments

I used the aqt_einsum function in my code to quantize only the QK score (the query-key attention logits) and then trained the model. However, after a certain number of steps (around 200, for example), the loss decreases very slowly, and the curve differs considerably from the one I get when training in bfloat16. Am I missing something? For example, does the backward pass need any additional handling? P.S. I am training on a TPU v5p-8 with jax==0.4.23.
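To make the question concrete, here is a plain-JAX sketch of what "int8 QK score" means in the forward pass and one common choice for the backward pass. It is not the AQT aqt_einsum API and may not match how AQT calibrates or handles gradients. The names quantize_int8 and int8_qk_scores are hypothetical, the symmetric per-tensor scaling is an assumption, and the backward rule is a straight-through estimator (quantization treated as identity).

```python
import jax
import jax.numpy as jnp


def quantize_int8(x):
    # Assumed scheme: symmetric per-tensor quantization, max |x| maps to 127.
    scale = jnp.max(jnp.abs(x)) / 127.0
    scale = jnp.where(scale == 0, 1.0, scale)  # avoid dividing by zero
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale


@jax.custom_vjp
def int8_qk_scores(q, k):
    """Attention logits q @ k^T computed from int8-quantized q and k."""
    qi, sq = quantize_int8(q)
    ki, sk = quantize_int8(k)
    # Widen to int32 so products are accumulated in int32, as an int8
    # matmul kernel would do internally; then rescale back to float.
    acc = jnp.einsum("bqd,bkd->bqk", qi.astype(jnp.int32), ki.astype(jnp.int32))
    return (acc.astype(jnp.float32) * (sq * sk)).astype(q.dtype)


def _int8_qk_fwd(q, k):
    # Save the unquantized inputs for the backward pass.
    return int8_qk_scores(q, k), (q, k)


def _int8_qk_bwd(res, g):
    q, k = res
    # Straight-through estimator: use the float gradients of q @ k^T,
    # ignoring the (zero almost everywhere) derivative of round().
    dq = jnp.einsum("bqk,bkd->bqd", g, k).astype(q.dtype)
    dk = jnp.einsum("bqk,bqd->bkd", g, q).astype(k.dtype)
    return dq, dk


int8_qk_scores.defvjp(_int8_qk_fwd, _int8_qk_bwd)
```

Without a custom rule like this, differentiating through jnp.round would give zero gradients to q and k. Comparing these logits against a bfloat16 einsum on real activations is one way to check whether the forward quantization error alone is large enough to explain the loss gap.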

Alternatively, is there an example of AQT int8 training in Pax?

Lisennlp, Mar 04 '24 09:03