MS-AMP
[#168 fix] add context manager to fake `ScalingTensor`/`ScalingParameter`'s `__class__` as `torch.Tensor`
Description
See #168. This is the most non-invasive fix I could come up with. Thanks to @aliencaocao for the idea.
Minor Revision
- Adds `msamp.common.tensor.tensor.pretend_scaling_is_torch`, which can be used to fix `GradScaler().step()`.
This is a non-breaking change: behaviour is unchanged unless the caller explicitly enters `with pretend_scaling_is_torch():`.
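Below is a minimal sketch of the general technique. It is not the code from this PR. `Tensor` and `ScalingTensor` are hypothetical stand-ins for `torch.Tensor` and MS-AMP's `ScalingTensor`, and keeping the flag in a `contextvars.ContextVar` is an assumption. The sketch relies on a documented CPython behaviour: when `type(obj)` does not match, `isinstance()` falls back to the object's `__class__` attribute. Overriding `__class__` with a flag-gated property therefore changes what `isinstance` reports, while `type()` still returns the real class.

```python
import contextlib
import contextvars


class Tensor:
    """Hypothetical stand-in for torch.Tensor."""


# Assumption: the "pretend" flag lives in a ContextVar so it is scoped to the
# current thread / asyncio task and can be restored exactly on exit.
_pretend = contextvars.ContextVar("pretend_scaling_is_torch", default=False)


@contextlib.contextmanager
def pretend_scaling_is_torch():
    """While active, ScalingTensor instances report Tensor as their __class__."""
    token = _pretend.set(True)
    try:
        yield
    finally:
        # Restore whatever value was active before entering (supports nesting).
        _pretend.reset(token)


class ScalingTensor:
    """Hypothetical stand-in for MS-AMP's ScalingTensor (not a Tensor subclass)."""

    def _get_class(self):
        # isinstance() consults obj.__class__ when type(obj) does not match,
        # so returning Tensor here makes isinstance(obj, Tensor) succeed.
        if _pretend.get():
            return Tensor
        return type(self)

    # A read-only property shadows object.__class__ for instances.
    __class__ = property(_get_class)


# Usage: behaviour only changes inside the context manager.
t = ScalingTensor()
print(isinstance(t, Tensor))          # False
with pretend_scaling_is_torch():
    print(isinstance(t, Tensor))      # True
    print(type(t) is ScalingTensor)   # True: type() is not fooled
print(isinstance(t, Tensor))          # False
```

Because `ContextVar.set`/`reset` restores the previous value on exit, nested `with` blocks and code running in other threads or tasks never see a leaked flag. A plain module-level boolean with save/restore would behave the same in single-threaded code but is not thread-safe.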