Numeric stability w/ AMP
Hello, a contributor recently added EfficientViT to timm, so I explored the model before merging... I found that it could not train in mixed precision without the loss immediately going NaN. The problem appears to be the q/k/v matmuls and the division:
- to train, it seems necessary to at least force q/k/v to float32 and ensure the matmuls run in float32 by disabling autocast, as per my mods https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py#L262 (see the sketch after this list)
- to eval stably with the pretrained weights and AMP or float16/bfloat16, the eps for the division needs to be increased from 1e-15 to ~1e-5
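For reference, a minimal sketch of that workaround, not the exact timm code: it assumes the usual (B, heads, N, dim) q/k/v layout of ReLU linear attention, disables autocast around the attention math, and casts q/k/v to float32 before the matmuls and the division.

```python
import torch
import torch.nn.functional as F

def relu_linear_attention_fp32(q, k, v, eps=1e-15):
    """Sketch of ReLU linear attention with the matmuls and the final
    division forced to float32; shapes assumed to be (B, heads, N, dim)."""
    with torch.autocast(device_type=q.device.type, enabled=False):
        # Leave fp16/bf16 before any accumulation happens.
        q = F.relu(q.float())
        k = F.relu(k.float())
        v = v.float()

        # Shared key-value summary: (B, heads, dim, dim).
        kv = k.transpose(-2, -1) @ v
        # Numerator and denominator of the linear-attention normalization.
        num = q @ kv                                              # (B, heads, N, dim)
        den = q @ k.sum(dim=-2, keepdim=True).transpose(-2, -1)   # (B, heads, N, 1)
        out = num / (den + eps)
    return out
```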
Have you observed anything similar, or thought of any approaches to improve this?
Hello Ross,
Thank you for sharing your findings!
I have similar findings: the q/k/v matmuls and the division need to be in float32 during training to avoid NaN loss, and we currently do not have a good remedy for this. Given that the q/k/v matmuls and the division are lightweight, your current approach is an excellent workaround. We will certainly look into this further and keep you updated once we identify an effective solution.
Regarding evaluation stability, I am not sure whether changing the eps to 1e-5 will hurt accuracy. If possible, I think keeping the division in float32 during testing is a better solution, since its computation cost is negligible (see the sketch below).
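For concreteness, one way to do that (a minimal sketch with a hypothetical helper, not timm/EfficientViT API) is to promote only the normalization to float32, so the original eps=1e-15 no longer underflows to zero in float16/bfloat16, and then cast back to the ambient dtype:

```python
import torch

def eval_safe_divide(num, den, eps=1e-15):
    """Hypothetical helper: perform only the division in float32 so the tiny
    eps still registers, then return to the autocast dtype for the rest of
    the network."""
    out = num.float() / (den.float() + eps)
    return out.to(num.dtype)
```

This keeps the matmuls in the reduced-precision path while avoiding the divide-by-near-zero that the larger eps was papering over.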
Thank you, Han