paxml
JAX + TPU: abnormal training loss with AQT int8
I used the aqt_einsum function in the code to quantize only the qk score, and then trained the model. However, I found that the loss dropped very slowly after training for a certain number of steps (e.g. 200), which is quite different from the loss curve trained in bfloat16. Am I missing something? For example, does the backward pass need some additional handling? P.S.: I am training on jax==0.4.23 and a TPU v5p-8.
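For reference, here is a minimal NumPy sketch of what I mean by "quantizing only the qk score": symmetric per-tensor int8 fake quantization applied to q and k before the attention-score einsum. This is illustrative only, not the actual aqt_einsum API, and it shows just the forward path (AQT additionally handles the backward pass via custom gradients):

```python
import numpy as np

def fake_quant_int8(x):
    # Symmetric per-tensor int8 fake quantization:
    # scale so the max magnitude maps to 127, round, clip, dequantize.
    scale = np.max(np.abs(x)) / 127.0
    if scale == 0.0:
        return x
    q = np.clip(np.round(x / scale), -127, 127)
    return q * scale

rng = np.random.default_rng(0)
# q, k: [batch, heads, seq_len, head_dim]
q = rng.standard_normal((2, 4, 8, 16)).astype(np.float32)
k = rng.standard_normal((2, 4, 8, 16)).astype(np.float32)

# int8-quantized qk score (the einsum that aqt_einsum replaces):
scores = np.einsum("bhqd,bhkd->bhqk", fake_quant_int8(q), fake_quant_int8(k))

# full-precision reference, for comparing the quantization error:
scores_ref = np.einsum("bhqd,bhkd->bhqk", q, k)
err = np.max(np.abs(scores - scores_ref))
```

The forward error here is small, so the slow loss drop is probably not about the forward math itself, which is why I am asking about the backward pass.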
In other words, is there a training example for AQT int8 in Pax?