[low-bit optim] Fix load state dict when device is different
In `optim.load_state_dict(state_dict)`, if the optimizer state dtype differs from the state_dict dtype, `aten._to_copy.default` is called. This PR simply implements this op and adds appropriate tests.
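
For context, here is a minimal sketch of what such a handler can look like inside a tensor subclass's `__torch_dispatch__`. The `QuantizedOptimState` class and its `codes`/`scale` fields are hypothetical stand-ins for the low-bit optimizer state subclasses, not the actual torchao code:

```python
import torch

aten = torch.ops.aten


class QuantizedOptimState(torch.Tensor):
    """Hypothetical low-bit optimizer state: uint8 codes plus a per-block scale."""

    @staticmethod
    def __new__(cls, codes, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, codes.shape, device=codes.device, dtype=torch.float32
        )

    def __init__(self, codes, scale):
        self.codes = codes
        self.scale = scale

    def __repr__(self):
        return f"{self.__class__.__name__}(shape={tuple(self.shape)}, device={self.codes.device})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func is aten._to_copy.default:
            # Hit by optim.load_state_dict() when the loaded state must be moved:
            # copy the inner tensors and re-wrap. dtype is ignored here because
            # the payload stays quantized.
            src = args[0]
            device = kwargs.get("device", src.codes.device)
            return cls(src.codes.to(device=device), src.scale.to(device=device))
        if func is aten.detach.default:
            src = args[0]
            return cls(src.codes, src.scale)
        raise NotImplementedError(f"{cls.__name__} does not support {func}")
```

With a handler like this in place, the `.to()` call issued inside `load_state_dict()` routes through the `aten._to_copy.default` branch on PyTorch 2.4+.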
Update: In PyTorch pre-2.4, calling `.to(device, dtype)` will not dispatch `aten._to_copy.default` when the dtype is the same but the device is different. Thus, I have to manually override the `.to()` method instead. This is only done for PyTorch pre-2.4. FP8 is not affected since FP8 CUDA requires PyTorch 2.4 anyway. We can remove this hack once we drop 2.3 support.
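
As an illustration of that workaround (not the exact torchao code), the `.to()` override can be patched onto the subclass only on pre-2.4 builds. The version check and `QuantizedOptimState` below are assumptions carried over from the sketch above:

```python
import torch

# Crude version check for illustration only.
_TORCH_PRE_2_4 = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) < (2, 4)

if _TORCH_PRE_2_4:

    def _to(self, *args, **kwargs):
        # Resolve device/dtype/non_blocking the same way Tensor.to() does.
        device, _dtype, non_blocking, _memory_format = torch._C._nn._parse_to(*args, **kwargs)
        # Pre-2.4 skips aten._to_copy.default when only the device changes,
        # so copy the inner tensors by hand and re-wrap them.
        return QuantizedOptimState(
            self.codes.to(device=device, non_blocking=non_blocking),
            self.scale.to(device=device, non_blocking=non_blocking),
        )

    QuantizedOptimState.to = _to  # only overridden on PyTorch < 2.4
```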