TransformerEngine

Support for NVFP4 training

Open · nostalgiaa-a opened this issue 1 week ago • 4 comments

I can run NVFP4 linear layers from a simple standalone script, but NVFP4 training through Megatron-LM fails because TE's `replace_raw_data` function does not support `NVFP4Tensor`.
Do you plan to add this support, along with anything else that NVFP4 training in Megatron-LM would need?

The function is defined in `TransformerEngine/transformer_engine/pytorch/tensor/utils.py`:

```python
def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
    r"""Change a quantized tensor's data buffer while preserving values

    This function modifies only the address space of the underlying
    raw data and does not alter any other tensor attributes or values.

    This may be used for custom buffer allocations, e.g. packing
    multiple parameter tensors together into a single contiguous
    buffer for ZeRO-2.

    """
    if isinstance(tensor, Float8Tensor):
        old_raw_data = tensor._data
        assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match"
        new_raw_data.detach().copy_(old_raw_data)
        tensor._data = new_raw_data
        del old_raw_data
    elif isinstance(tensor, Float8BlockwiseQTensor):
        old_raw_data = tensor._rowwise_data
        assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match"
        new_raw_data.detach().copy_(old_raw_data)
        tensor._rowwise_data = new_raw_data
        del old_raw_data
    elif isinstance(tensor, MXFP8Tensor):
        raise NotImplementedError("replace_raw_data for MXFP8Tensor is not supported yet")
    else:
        raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet")
```
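To make the docstring's ZeRO-2 use case concrete, here is a minimal, framework-free sketch of the same copy-then-rebind pattern the `Float8Tensor` branch follows. It is not TE's or Megatron-LM's code: `ToyQuantizedTensor`, `toy_replace_raw_data`, and `pack_into_contiguous_buffer` are hypothetical names, and `memoryview` slices of one `bytearray` stand in for views into a shared `torch` buffer.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyQuantizedTensor:
    """Hypothetical stand-in for a quantized tensor: raw payload bytes plus a scale."""
    _data: memoryview  # raw quantized payload
    scale: float       # "other attribute" that must not change


def toy_replace_raw_data(tensor: ToyQuantizedTensor, new_raw_data: memoryview) -> None:
    """Copy the payload into new_raw_data, then point the tensor at it."""
    old_raw_data = tensor._data
    # Analogues of the dtype check in the excerpt, plus an explicit size check.
    assert old_raw_data.format == new_raw_data.format, "The data types of raw data don't match"
    assert old_raw_data.nbytes == new_raw_data.nbytes, "The sizes of raw data don't match"
    new_raw_data[:] = old_raw_data  # preserve values in the new storage
    tensor._data = new_raw_data     # rebind; scale is untouched


def pack_into_contiguous_buffer(tensors: List[ToyQuantizedTensor]) -> bytearray:
    """Give each tensor a slice of one shared buffer (hypothetical ZeRO-style packing)."""
    buffer = bytearray(sum(t._data.nbytes for t in tensors))
    view = memoryview(buffer)
    offset = 0
    for t in tensors:
        n = t._data.nbytes
        toy_replace_raw_data(t, view[offset:offset + n])
        offset += n
    return buffer
```

After packing, each tensor's payload aliases a slice of the shared buffer, so anything that writes the buffer updates the tensors in place, while attributes such as `scale` stay as they were.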

The resulting error:

```
  File "/data/xusimin/TE2/Megatron-LM/megatron/core/fp8_utils.py", line 149, in _modify_underlying_storage_impl
    replace_raw_data(fp8_tensor, new_raw_data)
  File "/data/xusimin/TE2/Megatron-LM/megatron/core/fp8_utils.py", line 390, in modify_underlying_storage
    _modify_underlying_storage_impl(tensor, new_raw_data)
  File "/data/xusimin/TE2/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 743, in __init__
    modify_underlying_storage(param, new_param_data)
  File "/data/xusimin/TE2/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py", line 221, in _allocate_buffers_for_parameters
            self.ddp_config,
            param_dtype,
            grad_dtype,
            params,
            data_parallel_group,
            self.bucket_size,
            param_to_name,
            gradient_scaling_factor,
            param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)],
            self.ddp_config.nccl_ub,
            pg_collection,
        )
    )
  File "/data/xusimin/TE2/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py", line 312, in __init__
        dense_params, self.intra_dp_cp_group, gradient_scaling_factor=gradient_scaling_factor
    )
  File "/data/xusimin/TE2/Megatron-LM/megatron/training/training.py", line 994, in get_model
            config=config,
            ddp_config=ddp_config,
            module=model_chunk,
            # Turn off bucketing for model_chunk 2 onwards, since communication for these
            # model chunks is overlapped with compute anyway.
            disable_bucketing=(model_chunk_idx > 0)
            or args.overlap_param_gather_with_optimizer_step,
        )
        for (model_chunk_idx, model_chunk) in enumerate(model)
  File "/data/xusimin/TE2/Megatron-LM/megatron/training/training.py", line 1094, in setup_model_and_optimizer
    model = get_model(model_provider_func, model_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/xusimin/TE2/Megatron-LM/megatron/training/training.py", line 666, in pretrain
        model_provider, model_type, checkpointing_context=checkpointing_context
    )
  File "/data/xusimin/TE2/Megatron-LM/pretrain_gpt.py", line 236, in <module>
        train_valid_test_datasets_provider,
        partial(model_provider, gpt_builder),
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
        extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
        store=store,
    )
ValueError: replace_raw_data for <class 'transformer_engine.pytorch.tensor.nvfp4_tensor.NVFP4Tensor'> is not supported yet
```
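The `ValueError` comes from the final `else` branch of `replace_raw_data`: the function picks a branch with `isinstance` checks, and `NVFP4Tensor` matches none of them (an `MXFP8Tensor` would hit the `NotImplementedError` instead). A toy, table-driven version of that dispatch reproduces the failure. The classes, the `RAW_DATA_ATTR` registry, and `toy_dispatch_replace` are hypothetical and only mimic the structure of the TE function, not its real data layout.

```python
from typing import Dict


class SupportedToyTensor:
    """Hypothetical stand-in for a class that has a branch in the dispatch."""

    def __init__(self, payload: bytearray) -> None:
        self._data = memoryview(payload)


class UnsupportedToyTensor:
    """Hypothetical stand-in for a class with no branch (NVFP4Tensor's situation)."""

    def __init__(self, payload: bytearray) -> None:
        self._data = memoryview(payload)


# Hypothetical registry: tensor class -> name of the attribute holding its raw bytes.
RAW_DATA_ATTR: Dict[type, str] = {SupportedToyTensor: "_data"}


def toy_dispatch_replace(tensor: object, new_raw_data: memoryview) -> None:
    """Copy-and-rebind for registered classes; unknown classes raise ValueError."""
    for cls, attr in RAW_DATA_ATTR.items():
        if isinstance(tensor, cls):  # isinstance, like the TE branches
            break
    else:
        # Same message shape as the final `else` branch in replace_raw_data.
        raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet")
    old_raw_data = getattr(tensor, attr)
    assert old_raw_data.nbytes == new_raw_data.nbytes, "The sizes of raw data don't match"
    new_raw_data[:] = old_raw_data
    setattr(tensor, attr, new_raw_data)
```

Registering a class is only the dispatch half of a fix. Whether `NVFP4Tensor` can be handled by copying a single buffer, as the two Float8 branches do, depends on how TE lays out its data, which this sketch does not model.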

nostalgiaa-a, Nov 06 '25 08:11