
AdEMAMix8bit NaNs when loading from state_dict

Open · darius-lam opened this issue 4 months ago · 1 comment

System Info

I'm running a standard training loop where I save the optimizer state_dict using opt.state_dict(). Upon loading it with opt.load_state_dict() to resume training, the model immediately NaNs after the first backprop step.

This only occurs with the AdEMAMix optimizer:

bnb.optim.AdEMAMix8bit(model.parameters(), lr=lr, t_alpha=T, t_beta3=T)

AdamW and other optimizers load their state dict perfectly fine. Any ideas?

Reproduction

```python
opt = bnb.optim.AdEMAMix8bit(model.parameters())
# run training loop
torch.save(opt.state_dict(), "dt.pt")

# try resuming opt from state_dict later
opt.load_state_dict(torch.load("dt.pt"))
# run training loop again
```
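For context, here is a fuller, self-contained version of the reproduction. The toy Linear model, batch shape, seed, and step count are placeholder assumptions of mine, not part of the original report; only the AdEMAMix8bit optimizer and the save/load round-trip come from the issue.

```python
# Hypothetical self-contained repro; the toy model, batch shape,
# and step count are placeholder assumptions.
import torch
import bitsandbytes as bnb

torch.manual_seed(0)
model = torch.nn.Linear(64, 64).cuda()  # bnb 8-bit optimizers require CUDA

def run_steps(opt, n=10):
    for _ in range(n):
        x = torch.randn(8, 64, device="cuda")
        loss = model(x).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()
        print(loss.item())

opt = bnb.optim.AdEMAMix8bit(model.parameters())
run_steps(opt)                              # trains normally
torch.save(opt.state_dict(), "dt.pt")

# resume with a fresh optimizer restored from the checkpoint
opt = bnb.optim.AdEMAMix8bit(model.parameters())
opt.load_state_dict(torch.load("dt.pt"))
run_steps(opt)                              # loss reportedly NaNs after the first step
```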

Expected behavior

The optimizer should resume training without producing NaNs.
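One way to narrow this down (a diagnostic sketch of mine, not from the report) is to inspect the restored state before taking a step. Since AdEMAMix warms up alpha and beta3 over t_alpha/t_beta3 steps using the stored step counter, a step count or quantization buffer that doesn't round-trip through the state_dict correctly would be a plausible culprit.

```python
# Hedged diagnostic sketch: scan a saved optimizer state_dict for
# non-finite tensors and print the per-parameter step counters.
import torch

ckpt = torch.load("dt.pt")
for pid, st in ckpt["state"].items():
    for key, val in st.items():
        if torch.is_tensor(val) and val.is_floating_point():
            if not torch.isfinite(val).all():
                print(f"param {pid}: non-finite values in '{key}'")
        elif key == "step":
            print(f"param {pid}: step = {val}")
```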

darius-lam · Oct 02 '24 18:10