TransformerEngine

Additional states

Open yongyanrao opened this issue 1 year ago • 4 comments

We noticed some additional states for each module, e.g.,

transformer.seq_layers.0.layer.self_attention.layernorm_qkv._extra_state
transformer.seq_layers.0.layer.self_attention.proj._extra_state
transformer.seq_layers.0.layer.layernorm_mlp._extra_state

These states are empty byte strings (b''). We suspect they are related to FP8. How should we handle them? Should we remove them explicitly, or is there a dedicated method for dealing with them?

yongyanrao, Oct 05 '23 15:10
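To see which _extra_state entries a state dict carries and whether they are empty like the b'' values reported above, a small inspection helper can scan the keys. This is a minimal sketch, not a Transformer Engine API: it assumes each value is either bytes or an io.BytesIO buffer, and the names is_extra_state_key, payload_size and find_extra_states are hypothetical.

```python
import io
from typing import Any, Dict, Mapping


def is_extra_state_key(key: str) -> bool:
    """True for 'module.sub._extra_state' and for a top-level '_extra_state' key."""
    return key.rsplit(".", 1)[-1] == "_extra_state"


def payload_size(value: Any) -> int:
    """Number of serialized bytes in an extra-state value.

    Assumption: values are bytes-like or io.BytesIO; any other type is
    reported as -1 (size unknown) rather than guessed.
    """
    if isinstance(value, (bytes, bytearray)):
        return len(value)
    if isinstance(value, io.BytesIO):
        return len(value.getvalue())
    return -1


def find_extra_states(state_dict: Mapping[str, Any]) -> Dict[str, int]:
    """Map every _extra_state key in state_dict to its payload size in bytes."""
    return {k: payload_size(v) for k, v in state_dict.items() if is_extra_state_key(k)}
```

With a real model, the empty entries could be listed with [k for k, n in find_extra_states(model.state_dict()).items() if n == 0].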

Why do you want to remove them? Those states are handled internally by Transformer Engine if FP8 is used.

ptrendx, Oct 05 '23 17:10

I observe the same behavior when training without FP8, and I believe these states cause problems when loading checkpoints into the model, especially checkpoints that contain no "_extra_state" entries. Is there a way to disable or exclude these fields when training without FP8, given that they are all empty?

Teng-xu, Oct 05 '23 20:10

@Teng-xu @yongyanrao These extra states are indeed part of the additional information needed to checkpoint FP8 training. They can be removed explicitly, but the simplest approach is to load the checkpoint with strict=False when calling PyTorch's load_state_dict method.

ksivaman, Jan 08 '24 07:01
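The two options in the reply above (explicit removal, or strict=False) can be sketched as follows. PyTorch's Module.load_state_dict returns a named tuple with missing_keys and unexpected_keys fields; the checking helper assumes you pass those two lists in. The helper names strip_extra_states and only_extra_state_mismatches are hypothetical. The check matters because strict=False ignores every mismatch, not only the _extra_state ones.

```python
from collections import OrderedDict
from typing import Any, Iterable, Mapping


def is_extra_state_key(key: str) -> bool:
    """True for 'module.sub._extra_state' and for a top-level '_extra_state' key."""
    return key.rsplit(".", 1)[-1] == "_extra_state"


def strip_extra_states(state_dict: Mapping[str, Any]) -> "OrderedDict[str, Any]":
    """Copy of state_dict with every _extra_state entry removed.

    Torch state dicts carry a _metadata attribute (per-module versions);
    it is copied over so version-dependent loading still sees it.
    """
    out = OrderedDict((k, v) for k, v in state_dict.items() if not is_extra_state_key(k))
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        out._metadata = metadata
    return out


def only_extra_state_mismatches(missing_keys: Iterable[str],
                                unexpected_keys: Iterable[str]) -> bool:
    """True if every key reported by a strict=False load is an _extra_state key."""
    return all(is_extra_state_key(k) for k in [*missing_keys, *unexpected_keys])
```

With a real model this would be result = model.load_state_dict(checkpoint, strict=False), followed by only_extra_state_mismatches(result.missing_keys, result.unexpected_keys). Stripping helps when a checkpoint contains _extra_state entries that the target model does not expect; when the model expects them but the checkpoint lacks them, as described earlier in the thread, strict=False is the relevant option.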

You can read _extra_state with code like the following instead of calling state.read(); it deserializes the buffer so that the contents of _extra_state become visible.

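import io
import torch

if isinstance(state, io.BytesIO):
    state.seek(0)  # rewind: a previous read() leaves the cursor at the end
    state = torch.load(state, weights_only=False)  # trusted checkpoints only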

zte-tcb, Apr 07 '24 08:04
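The snippet above assumes state is an io.BytesIO, but the values reported in this issue were empty b'' strings, which contain nothing to deserialize. A more defensive variant is sketched below under the assumption that a stored extra state is raw bytes, an io.BytesIO, or something already deserialized (returned unchanged). decode_extra_state is a hypothetical name. Recent PyTorch releases default torch.load to weights_only=True, which can reject pickled non-tensor objects; weights_only=False lifts that restriction and should only be used on trusted checkpoints.

```python
import io
from typing import Any


def decode_extra_state(state: Any) -> Any:
    """Deserialize an _extra_state value, or return None if its payload is empty."""
    if isinstance(state, (bytes, bytearray)):
        state = io.BytesIO(bytes(state))  # wrap raw bytes so torch.load can read them
    if not isinstance(state, io.BytesIO):
        return state  # assumption: already deserialized, so returned unchanged
    if len(state.getvalue()) == 0:
        return None  # empty payload such as b'': nothing to deserialize
    state.seek(0)  # a previous read() may have left the cursor at the end
    import torch  # imported here so empty payloads can be inspected without torch

    # weights_only=False unpickles arbitrary objects: use on trusted files only.
    return torch.load(state, weights_only=False)
```

Returning None for empty payloads keeps a loop over all _extra_state keys from failing on modules that never ran in FP8.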