recurrent-interface-network-pytorch

mixed precision training

Open nicolas-dufour opened this issue 1 year ago • 13 comments

Hey, I've worked on reimplementing RIN based on the authors' repo and this repo, but I cannot manage to make it work with mixed precision, and I see you do make use of mixed precision here. When naively switching to bfloat16 or float16, my model gets stuck in a weird state: (image: left, bfloat16; right, float32)
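
For reference, a minimal sketch of the kind of naive switch described above, assuming a plain PyTorch training step (the `model`/`opt` names are placeholders, not this repo's API):

```python
import torch

def train_step(model, batch, opt, dtype=torch.bfloat16):
    # run the forward pass and loss in low precision via autocast
    opt.zero_grad()
    with torch.autocast(device_type="cuda", dtype=dtype):
        loss = model(batch)
    loss.backward()  # bf16 needs no GradScaler; fp16 normally does
    opt.step()
    return loss.detach()
```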

Did you encounter such issues in your implementation? If so, do you have any pointers to make it work?

Thanks!

nicolas-dufour avatar Sep 04 '23 09:09 nicolas-dufour

@nicolas-dufour Hey Nicolas, there was an issue with the way I set up mixed precision in accelerate

in case you were using the same logic, it should be fixed in 0.7.6

lucidrains avatar Sep 04 '23 13:09 lucidrains

@nicolas-dufour thanks for sharing those float32 results! 😄

lucidrains avatar Sep 04 '23 13:09 lucidrains

Hey @lucidrains, I'm using PL (PyTorch Lightning) instead of accelerate, but the behaviour should be the same. After further investigation, it seems genuinely difficult to leverage mixed precision for this network. If I force float32 for the time and class embedding computations and for the qkv linear projections of the cross-attention layers, I do get better convergence, but training remains very unstable and blows up mid-training. Image quality is also clearly subpar compared to fp32 throughout training.
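
A minimal sketch of what forcing float32 for selected submodules can look like under autocast (the wrapper and the wrapped attribute names are illustrative, not this repo's code):

```python
import torch
import torch.nn as nn

class Float32Island(nn.Module):
    """Run the wrapped module in float32, even inside an autocast region."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.autocast(device_type="cuda", enabled=False):
            args = tuple(a.float() if torch.is_tensor(a) else a for a in args)
            kwargs = {k: v.float() if torch.is_tensor(v) else v for k, v in kwargs.items()}
            return self.module(*args, **kwargs)

# illustrative usage: wrap the time/class embedding MLPs and the
# cross-attention qkv projections so they stay in fp32
# attn.to_q = Float32Island(attn.to_q)
```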

For now I'm trying to stabilize it without making architecture changes, but I'm not sure that's possible without forcing too much of the architecture into fp32.

Will update here if I make any progress.

nicolas-dufour avatar Sep 05 '23 15:09 nicolas-dufour

@nicolas-dufour if you figure out the cause do submit a PR

i'll try to stabilize it later this week once i get back on my deep learning machine

lucidrains avatar Sep 05 '23 15:09 lucidrains

@nicolas-dufour decided to offer a way to turn off the linear attention, in case that is the source of the instability you are experiencing

lucidrains avatar Sep 08 '23 17:09 lucidrains

@lucidrains thanks! Sadly it's still pretty unstable. The authors said on their official repo that they didn't use mixed precision, so maybe this architecture cannot be trained in mixed precision and needs a major redesign to be stable =(

nicolas-dufour avatar Sep 11 '23 15:09 nicolas-dufour

@nicolas-dufour yes indeed 😢 i'll also share that i tried a similar architecture for some contract work (different domain) a while back and experienced the same

lucidrains avatar Sep 11 '23 15:09 lucidrains

@nicolas-dufour did you try the qk norm by any chance?
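
For context, a minimal sketch of one common QK-norm variant (cosine-similarity attention: l2-normalize queries and keys before the dot product, with a fixed or learned temperature); not necessarily how this repo implements it:

```python
import torch
import torch.nn.functional as F

def qk_normed_attention(q, k, v, scale=10.0):
    # l2-normalize queries and keys so attention logits stay bounded,
    # then use a temperature instead of the usual 1/sqrt(d) scaling
    q = F.normalize(q, dim=-1)
    k = F.normalize(k, dim=-1)
    sim = (q @ k.transpose(-2, -1)) * scale
    attn = sim.softmax(dim=-1)
    return attn @ v
```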

lucidrains avatar Sep 11 '23 15:09 lucidrains

@lucidrains no, thanks for the pointer, I'll try it out.

nicolas-dufour avatar Sep 11 '23 16:09 nicolas-dufour

@lucidrains Tried the qk_norm but it didn't change a thing. I also tried adding a LayerNorm for the "from" tokens of the cross attention, but training is still subpar and doesn't converge.

nicolas-dufour avatar Sep 13 '23 12:09 nicolas-dufour

@nicolas-dufour bummer! this may be a caveat of this architecture then

thank you for running the experiments and sharing this!

lucidrains avatar Sep 13 '23 14:09 lucidrains

I ran into the same issue. @nicolas-dufour what happens if you remove the torch.no_grad() around the self-conditioning inference pass while using mixed precision?
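
A minimal sketch of that workaround: run the first (self-conditioning) pass without torch.no_grad() and detach its output instead. The function and model signature here are hypothetical, simplified from the usual self-conditioning recipe:

```python
import torch

def forward_with_self_cond(model, x_t, t, p_self_cond=0.5):
    # the first pass produces the self-conditioning estimate; detach() stops
    # gradients through it while keeping the autocast/grad machinery consistent,
    # instead of wrapping the pass in torch.no_grad()
    self_cond = None
    if torch.rand(()).item() < p_self_cond:
        self_cond = model(x_t, t, self_cond=None).detach()
    return model(x_t, t, self_cond=self_cond)
```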

LeeDoYup avatar Jan 30 '24 07:01 LeeDoYup

That is indeed the only way I found. It's a bug where parameters that receive no gradient during a mixed-precision step never receive gradients again.

This issue is discussed here https://github.com/pytorch/pytorch/issues/112583

To make mixed precision work, another key component was to catch NaNs and Infs in the gradients and skip those batches.
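
A minimal sketch of that guard, assuming a bf16 setup without a GradScaler (with fp16, GradScaler.step already skips the update when it finds Infs after unscaling):

```python
import torch

def optimizer_step_with_nan_guard(model, loss, opt):
    # backprop, then skip the optimizer update if any gradient is non-finite
    opt.zero_grad()
    loss.backward()
    grads_finite = all(
        torch.isfinite(p.grad).all()
        for p in model.parameters()
        if p.grad is not None
    )
    if grads_finite:
        opt.step()
    return grads_finite  # False means this batch was skipped
```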

nicolas-dufour avatar Jan 30 '24 09:01 nicolas-dufour