
[low-bit optim] Fix load state dict when device is different

Open gau-nernst opened this issue 4 months ago • 1 comments

In optim.load_state_dict(state_dict), if the optimizer's dtype differs from the state_dict's dtype, aten._to_copy.default is called. This PR implements this op and adds appropriate tests.

Update: In PyTorch versions before 2.4, calling .to(device, dtype) does not dispatch aten._to_copy.default when the dtype is the same but the device is different. Thus, I have to manually override the .to() method instead. This is only done for PyTorch pre-2.4. FP8 is not affected, since FP8 on CUDA requires PyTorch 2.4 anyway. We can remove this hack once we drop 2.3 support.
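To illustrate the "only for PyTorch pre-2.4" gate, here is a minimal sketch of a version check in plain Python. The helper names `parse_major_minor` and `needs_to_override` are hypothetical and not taken from this PR. The sketch assumes the version string starts with `major.minor` (as in `2.3.1` or `2.4.0+cu121`) and treats any 2.4 pre-release as 2.4.

```python
import re
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from a version string such as '2.3.1' or '2.4.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def needs_to_override(version: str) -> bool:
    """Hypothetical helper: True when the manual .to() workaround would apply (pre-2.4)."""
    # Compare as integer tuples so that '2.10' is not mistaken for being older than '2.4'.
    return parse_major_minor(version) < (2, 4)
```

In practice one would presumably pass `torch.__version__` to such a helper and define the `.to()` override only when it returns `True`, which keeps the workaround easy to delete once 2.3 support is dropped.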

gau-nernst, Oct 05 '24 02:10