TransformerEngine
[PyTorch] Fixed bug with loading calibrated weights
I encountered a bug related to calibration of scaling factors and loading the model.
When I calibrate the scaling factors of a model with weights in bf16, the model parameters are stored in bf16 and the scaling factors live in the FP8 metadata, which is dumped into the "*._extra_state" entries of state_dict() (a quick way to see where they end up is sketched below). Suppose I want to load these weights into a model initialized within the fp8_model_init() context. Then, in load_state_dict():
- Tensors are copied before the extra state. When a bf16 source tensor is copied into an fp8 destination tensor, the scaling factor is updated using the old fp8_meta.
- Only afterwards is _extra_state copied, without any impact on the already-copied tensors.
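For context, a quick, illustrative way to inspect where the calibrated scaling factors are serialized; this assumes a calibrated te.Linear module named f, as in the reproduction further down:

for key in f.state_dict():
    print(key)
# The weights and biases appear as ordinary dense tensors, while the FP8
# metadata (including the calibrated scaling factors) is serialized inside
# the "*._extra_state" entries.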
In the forward() of many modules we can see a call to the following function:
@torch.no_grad()
def reset_fp8_meta_scale_inv(self) -> None:
    """Replace FP8 meta tensor scale-inverse with cached value

    The FP8 meta tensor scale_inv entry corresponding to this
    tensor is replaced with the scale_inv value used to construct
    the tensor.
    """
    if self._fp8_meta is None:
        return
    fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
        forward=self._fp8_meta_forward,
    )
    scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
    scale_inv.view(1).copy_(self._scale_inv.view(1))
For example, in layernorm_mlp.py's forward():
if primary_weights_in_fp8:
    # Weights are already in FP8
    fc1_weight.reset_fp8_meta_scale_inv()
    fc2_weight.reset_fp8_meta_scale_inv()
    fc1_weight_fp8 = fc1_weight
    fc2_weight_fp8 = fc2_weight
    fc1_weight_t_fp8 = None
    fc2_weight_t_fp8 = None
This overrides the module's fp8_meta scale_inv with the value cached in the parameter tensor, which is clearly wrong here, because that cached value was computed from the stale metadata during the copy.
The result of this bug is demonstrated in the code below:
import torch
import torch.optim as optim

import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.common.recipe import Format, DelayedScaling

f = te.Linear(16, 16)

#######################################################################################
# I want to obtain a TE module with scaling factors not equal to one.
#######################################################################################
optimizer = optim.Adam(f.parameters(), lr=0.001)
sample_data = torch.rand((16, 16)).cuda() * 100000
sample_labels = torch.rand((16, 16)).cuda()
criterion = torch.nn.MSELoss()

with te.fp8_autocast(enabled=True):
    for _ in range(1000):
        optimizer.zero_grad()                    # zero out the gradients
        output = f(sample_data)                  # forward pass
        loss = criterion(output, sample_labels)  # compute the loss
        loss.backward()                          # compute the gradients
        optimizer.step()                         # update the weights

# calibration
with te.fp8_autocast(enabled=False, calibrating=True):
    for _ in range(1000):
        output = f(sample_data)                  # forward pass

#######################################################################################
# I initialize model g with FP8 parameters only and copy the parameters from f.
# The models should return the same values, but they do not,
# because g ends up with the wrong scaling factor.
#######################################################################################
with te.fp8_model_init():
    g = te.Linear(16, 16)

g.load_state_dict(f.state_dict())
# This loads _extra_state after loading the tensors, so the tensors are loaded with the wrong scale.

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
torch.manual_seed(1234)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    print(str(f(sample_data) - g(sample_data))[:100])
    # Should print all zeros, but does not.

g.load_state_dict(f.state_dict())
# Since _extra_state is already loaded, the tensors are now scaled correctly.

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    print(str(f(sample_data) - g(sample_data))[:100])
    # This prints all zeros.
I propose a fix: override _load_from_state_dict() so that _extra_state is loaded before, not after, the tensors.
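A minimal sketch of what such an override could look like; the mixin name is hypothetical and the "_extra_state" key handling is an assumption for illustration, not the final patch:

class _LoadExtraStateFirstMixin:
    """Hypothetical sketch of the proposed load order."""

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Restore the FP8 metadata first, so the calibrated scaling factors
        # are in place before any bf16 -> fp8 tensor copies happen.
        extra_state_key = prefix + "_extra_state"
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        # Then fall back to the default loading logic; note that this calls
        # set_extra_state a second time with the same value.
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)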
I also have one question: should I add a unit test as well?
@pggPL Yes, please add the unit test.
In the proposed implementation set_extra_state would be called twice. Maybe we should remove the extra state from the dict before calling the parent function?
Unfortunately, removing the extra_state key from the dictionary results in an error when load_state_dict(..., strict=True) is used. I see no way of doing that without breaking other things.
I modified the existing test instead of adding a new one. That test wasn't passing for me, because Test_TE_Export had no implementation of the abstract method get_fp8_weights_scratchpad(), so I added a dummy method (roughly sketched below).
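Something along these lines; the exact signature and return value are assumptions based on how other TE modules implement this method, not a quote of the actual change:

def get_fp8_weights_scratchpad(self, is_first_microbatch):
    # Dummy implementation: this test module does not need any FP8 weight
    # scratchpad buffers.
    return [None, None]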
Also, this test was never run in CI, so we should enable that as well by adding a call to it in the qa/L0_pytorch_unittest/test.sh file.
/te-ci pytorch
@timmoon10 could you please take a look at how the test has been modified now?
@pggPL, could you revert the files' permissions?
/te-ci pytorch
@pggPL I think your last commit doesn't have a sign-off, so the DCO check complains. Could you fix that?
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch