
[PyTorch] Fixed bug with loading calibrated weights


I encountered a bug related to the calibration of scaling factors and loading of the model.

When I calibrate the scaling factors on a model with weights in bf16, the model parameters are stored in bf16 and the scaling factors are dumped into the "*._extra_state" entries of state_dict(). Suppose I want to load these weights into a model initialized within the fp8_model_init() context. Then, in load_state_dict():

  1. Tensors are copied before the extra state. When a bf16 source tensor is copied into an fp8 destination tensor, the scaling factor is updated using the old fp8_metadata.
  2. Then _extra_state is copied, without any effect on the already-copied tensors.
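
(A quick check, not from the original report: the serialized FP8 metadata indeed ends up in a "*._extra_state" entry of state_dict(), next to the regular parameters.)

    import transformer_engine.pytorch as te

    layer = te.Linear(16, 16)
    # Typical keys: 'weight', 'bias' and '_extra_state' (the serialized FP8 metadata);
    # the exact contents of _extra_state depend on the TE version.
    print(list(layer.state_dict().keys()))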

In the forward() of many modules we can see a call to the following function:

    @torch.no_grad()
    def reset_fp8_meta_scale_inv(self) -> None:
        """Replace FP8 meta tensor scale-inverse with cached value

        The FP8 meta tensor scale_inv entry corresponding to this
        tensor is replaced with the scale_inv value used to construct
        the tensor.

        """
        if self._fp8_meta is None:
            return
        fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
            forward=self._fp8_meta_forward,
        )
        scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
        scale_inv.view(1).copy_(self._scale_inv.view(1))

for example in layernorm_mlp.py forward:


            if primary_weights_in_fp8:
                # Weights are already in FP8
                fc1_weight.reset_fp8_meta_scale_inv()
                fc2_weight.reset_fp8_meta_scale_inv()
                fc1_weight_fp8 = fc1_weight
                fc2_weight_fp8 = fc2_weight
                fc1_weight_t_fp8 = None
                fc2_weight_t_fp8 = None

which overrides the fp8_meta scale_inv in the module with the value cached in the parameter tensor; in this scenario that is the stale value, which is clearly wrong.

The result of this bug is demonstrated in the code below:

import torch
import torch.optim as optim

import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.common.recipe import Format, DelayedScaling


f = te.Linear(16, 16)

#######################################################################################
#         I want to obtain a TE module with scaling factors not equal to one.         #
#######################################################################################

optimizer = optim.Adam(f.parameters(), lr=0.001)
sample_data = torch.rand((16, 16)).cuda() * 100000
sample_labels = torch.rand((16, 16)).cuda()

criterion = torch.nn.MSELoss()

with te.fp8_autocast(enabled=True):
    for _ in range(1000):
        optimizer.zero_grad()  # zero the gradients
        output = f(sample_data)  # forward pass
        loss = criterion(output, sample_labels)  # compute the loss
        loss.backward()  # compute the gradients
        optimizer.step()  # update the weights


# calibration
with te.fp8_autocast(enabled=False, calibrating=True):
    for _ in range(1000):
        output = f(sample_data)  # forward pass


#######################################################################################
#    I initialize model g with fp8 parameters only and copy the parameters from f.   #
#    The models should return the same values, but they do not,                      #
#    because g has a wrong scaling factor.                                            #
#######################################################################################

with te.fp8_model_init():
    g = te.Linear(16, 16)

g.load_state_dict(f.state_dict())
# This loads _extra_state after loading the tensors, so the tensors are loaded with the wrong scale.

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
torch.manual_seed(1234)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    print(str(f(sample_data) - g(sample_data))[:100])
    # Should return 0, but does not.


g.load_state_dict(f.state_dict())
# Since _extra_state is already loaded, the tensors are now scaled correctly.

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    print(str(f(sample_data) - g(sample_data))[:100])
    # This returns 0.

I propose a fix - override _load_from_state_dict so that _extra_state is loaded before, not after, the tensors.
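
A minimal sketch of that idea (not the actual patch; the "_extra_state" suffix follows PyTorch's convention for get_extra_state / set_extra_state, and the parent call will invoke set_extra_state a second time, as noted below):

    # Sketch only: an override on the TE base module that applies _extra_state
    # before the parameters are copied.
    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        extra_state_key = prefix + "_extra_state"  # PyTorch's extra-state key suffix
        if extra_state_key in state_dict:
            # Restore the FP8 metadata (amax history, scale, scale_inv) first ...
            self.set_extra_state(state_dict[extra_state_key])
        # ... so that the parameter copies below already see the calibrated scales.
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)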

I also have one question - should I also add a unit test?

pggPL avatar Apr 11 '24 19:04 pggPL

@pggPL Yes, please add the unit test.

In the proposed implementation set_extra_state would be called twice - maybe we should remove the extra state from the dict before calling the parent function?

ptrendx avatar Apr 15 '24 17:04 ptrendx

Unfortunately, removing the extra_state key from the dictionary will result in an error when calling load_state_dict(..., strict=True). I see no way of doing that without breaking other things.
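
(For illustration, not from the thread: with strict=True, PyTorch reports a missing "*._extra_state" key for any module that overrides set_extra_state, so stripping those entries makes load_state_dict raise.)

    # Sketch only: dropping the extra-state entries before loading breaks strict mode,
    # reusing f and g from the reproduction above.
    state = f.state_dict()
    state = {k: v for k, v in state.items() if not k.endswith("_extra_state")}
    g.load_state_dict(state, strict=True)
    # RuntimeError: Missing key(s) in state_dict: "_extra_state"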

I modified the existing test instead of adding a new one. This test wasn't passing for me because there was no implementation of the abstract method get_fp8_weights_scratchpad() in Test_TE_Export, so I added a dummy implementation.

pggPL avatar May 03 '24 17:05 pggPL

Also, this test was never run in the CI, so we should enable that as well by adding a call to this test in the qa/L0_pytorch_unittest/test.sh file.

sudhakarsingh27 avatar May 08 '24 17:05 sudhakarsingh27

/te-ci pytorch

sudhakarsingh27 avatar May 13 '24 21:05 sudhakarsingh27

@timmoon10 could you please take a look at how the test has been modified now?

sudhakarsingh27 avatar May 13 '24 23:05 sudhakarsingh27

@pggPL, could you revert the files' permissions?

sudhakarsingh27 avatar May 15 '24 18:05 sudhakarsingh27

/te-ci pytorch

sudhakarsingh27 avatar May 15 '24 20:05 sudhakarsingh27

@pggPL I think your last commit doesn't have a sign-off, so the DCO check complains. Could you fix that?

sudhakarsingh27 avatar May 16 '24 01:05 sudhakarsingh27

/te-ci pytorch

phu0ngng avatar May 16 '24 16:05 phu0ngng

/te-ci pytorch

sudhakarsingh27 avatar May 16 '24 17:05 sudhakarsingh27

/te-ci pytorch

timmoon10 avatar May 17 '24 16:05 timmoon10