torchdistx
AnyPrecision optimizer dynamic casting
Describe the bug:
In `AnyPrecisionAdamW.step()`, the variance update

```python
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
```

fails because the dtypes of `exp_avg_sq` and `grad` differ: the variance state is kept in `bfloat16` by default, while `grad` stays in the parameters' `float32`, and the in-place operation cannot dynamically cast between them.
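The failure can be reproduced at the tensor level, independent of the optimizer. Below is a minimal sketch (assuming a `float32` gradient and a `bfloat16` variance state, matching the defaults) that should hit the same assert on CPU:

```python
import torch

beta2 = 0.999
grad = torch.randn(4)                              # float32, as produced by backward() on float32 params
exp_avg_sq = torch.zeros(4, dtype=torch.bfloat16)  # variance state stored in bfloat16

# In-place addcmul_ with mismatched input/output dtypes requires dynamic casting,
# which the CPU loop kernel asserts against.
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
```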
Describe how to reproduce:
```python
from torchdistx.optimizers import AnyPrecisionAdamW

# uses default hyperparameters such as momentum=float32 and variance=bfloat16
optim = AnyPrecisionAdamW(model.parameters())  # model: any float32 nn.Module with grads populated
optim.step()
```
Error log:

```
RuntimeError: !needs_dynamic_casting<func_t>::check(iter) INTERNAL ASSERT FAILED at "../aten/src/ATen/native/cpu/Loops.h":347, please report a bug to PyTorch.
```
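A possible workaround (a sketch only, not the library's actual fix) is to cast the gradient to the state's dtype before the in-place update, at the cost of an extra copy:

```python
# Hypothetical workaround: bring grad into the variance state's dtype so the
# in-place addcmul_ needs no dynamic casting.
grad_v = grad.to(dtype=exp_avg_sq.dtype)
exp_avg_sq.mul_(beta2).addcmul_(grad_v, grad_v, value=1 - beta2)
```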