bitsandbytes
AdEMAMix8bit produces NaN after loading from state_dict
System Info
I'm running a standard training loop in which I save the optimizer state with `opt.state_dict()`. When I restore it with `opt.load_state_dict()` to resume training, the model produces NaNs immediately after the first backward pass and optimizer step.
This only happens with the AdEMAMix optimizer:
bnb.optim.AdEMAMix8bit(model.parameters(), lr=lr, t_alpha=T, t_beta3=T)
AdamW and other optimizers resume from a loaded state_dict without any problems. Any ideas?
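One way to narrow this down is to check whether the NaNs are already present in the optimizer state right after `load_state_dict()`, or only appear during the first step. The helper below is a hypothetical diagnostic sketch (`nonfinite_entries` is not a bitsandbytes or PyTorch API). It recursively walks a nested dict/list structure such as an optimizer `state_dict()` and returns the path of every floating-point tensor or float that contains NaN or Inf. Integer tensors, such as any uint8 quantized state, are skipped because they cannot hold NaN.

```python
import math

import torch


def nonfinite_entries(obj, path="state_dict"):
    """Hypothetical helper: recursively walk a (possibly nested) state_dict
    and return the paths of all tensors/floats that contain NaN or Inf."""
    bad = []
    if isinstance(obj, torch.Tensor):
        # Only floating-point tensors can hold NaN/Inf; integer/bool
        # tensors (step counters, quantized buffers) are skipped.
        if obj.is_floating_point() and not torch.isfinite(obj).all():
            bad.append(path)
    elif isinstance(obj, float):
        if not math.isfinite(obj):
            bad.append(path)
    elif isinstance(obj, dict):
        for key, value in obj.items():
            bad.extend(nonfinite_entries(value, f"{path}[{key!r}]"))
    elif isinstance(obj, (list, tuple)):
        for i, value in enumerate(obj):
            bad.extend(nonfinite_entries(value, f"{path}[{i}]"))
    return bad
```

Calling it on `opt.state_dict()` once before saving and again right after loading would show whether the checkpoint itself carries non-finite values, or whether they only appear once the first resumed step runs.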
Reproduction
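```python
import torch
import bitsandbytes as bnb

opt = bnb.optim.AdEMAMix8bit(model.parameters())
# run training loop
torch.save(opt.state_dict(), "dt.pt")

# later: try resuming opt from the saved state_dict
opt.load_state_dict(torch.load("dt.pt"))  # load the dict first, not the path
# run training loop again -> NaNs after the first step
```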
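A self-contained version of the save/resume round trip might look like the sketch below. It trains a toy model for a few steps, saves model and optimizer state together, rebuilds both from scratch, loads the checkpoint, and returns the loss after resuming. The toy model, the hypothetical `make_optimizer` factory, and the step counts are illustrative assumptions. The sketch defaults to `torch.optim.AdamW` on CPU; to exercise the reported case, the factory would be swapped for `bnb.optim.AdEMAMix8bit`, which may require the model to be on a GPU.

```python
import io

import torch


def make_model():
    torch.manual_seed(0)  # identical initial weights on every rebuild
    return torch.nn.Sequential(
        torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1)
    )


def make_optimizer(params):
    # Hypothetical factory: swap in bnb.optim.AdEMAMix8bit(params, ...)
    # to test the reported case (assumption: may need CUDA parameters).
    return torch.optim.AdamW(params, lr=1e-3)


def train_steps(model, opt, data, target, n):
    """Run n optimizer steps and return the loss of the last one."""
    loss = None
    for _ in range(n):
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(data), target)
        loss.backward()
        opt.step()
    return loss.item()


def resume_round_trip(n_before=5, n_after=2):
    torch.manual_seed(1)
    data, target = torch.randn(64, 16), torch.randn(64, 1)

    model = make_model()
    opt = make_optimizer(model.parameters())
    train_steps(model, opt, data, target, n_before)

    # Save model and optimizer state together, as a checkpoint would.
    buf = io.BytesIO()
    torch.save({"model": model.state_dict(), "opt": opt.state_dict()}, buf)
    buf.seek(0)

    # Rebuild from scratch, then restore. load_state_dict() expects the
    # loaded dict, not a file path.
    ckpt = torch.load(buf)
    model2 = make_model()
    model2.load_state_dict(ckpt["model"])
    opt2 = make_optimizer(model2.parameters())
    opt2.load_state_dict(ckpt["opt"])

    # With n_after >= 2, the returned loss reflects weights that were
    # updated by at least one step taken from the restored state.
    return train_steps(model2, opt2, data, target, n_after)
```

If this stays finite with AdamW but returns `nan` once the factory builds `AdEMAMix8bit`, that would point at the optimizer's state restore rather than the training loop itself.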
Expected behavior
The optimizer should resume training from the loaded state without producing NaNs.