fp8 quantization with FSDP2 error

happynear opened this issue 9 months ago

When I fp8-quantize a model and then shard it with FSDP2, it reports an error:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/generate.py", line 461, in <module>
[rank1]:     generate(args)
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/generate.py", line 375, in generate
[rank1]:     wan_i2v = wan.WanI2V(
[rank1]:               ^^^^^^^^^^^
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/wan/image2video.py", line 218, in __init__
[rank1]:     self.model = shard_dit_fn(self.model, param_dtype=torch.float8_e4m3fn)
[rank1]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/wan/distributed/fsdp.py", line 112, in shard_dit_model
[rank1]:     fully_shard_with_ignore_param(block, mesh=pm.get_dp_with_cp_mesh(), reshard_after_forward=True, mp_policy=mixed_fsdp2, ignored_params=ignored_states_set)
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank1]:     updated = func(inp_module, *args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/teams/algo-teams/shared/pytorch_distributed_examples/src/tu_pth_dist/fsdp_compat.py", line 200, in fully_shard_with_ignore_param
[rank1]:     state._fsdp_param_group = FSDPParamGroup(
[rank1]:                               ^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 132, in __init__
[rank1]:     FSDPParam(
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 239, in __init__
[rank1]:     self._init_sharded_param(param, device, shard_placement_fn)
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 368, in _init_sharded_param
[rank1]:     chunks = _chunk_with_empty(param_data, shard_world_size, dim=shard_dim)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py", line 124, in _chunk_with_empty
[rank1]:     chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank1]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 436, in _dispatch__torch_function__
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 455, in _dispatch__torch_dispatch__
[rank1]:     raise NotImplementedError(
[rank1]: NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'int'>), kwarg_types={}

I can see that there is no aten.split implementation in https://github.com/pytorch/ao/blob/ab3792e3d91e04f85992a659c1664a6a1a6d733c/torchao/quantization/linear_activation_quantized_tensor.py. Could anyone provide one?
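
For context, the failing setup is roughly the following (a minimal sketch, not the actual generate.py code; the float8_dynamic_activation_float8_weight config and the build_dit_model / model.blocks names are illustrative assumptions):

import torch
from torch.distributed.fsdp import fully_shard
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

model = build_dit_model()  # hypothetical constructor for the DiT

# After this, each nn.Linear weight is a LinearActivationQuantizedTensor
# (wrapping an AffineQuantizedTensor that holds the fp8 weight data).
quantize_(model, float8_dynamic_activation_float8_weight())

# FSDP2 shards each parameter via torch.chunk, which dispatches
# aten.split.Tensor on the subclass and raises the NotImplementedError above.
for block in model.blocks:
    fully_shard(block, reshard_after_forward=True)
fully_shard(model)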

happynear avatar Mar 20 '25 04:03 happynear

I tried to implement the split function myself:

# Added inside torchao/quantization/linear_activation_quantized_tensor.py,
# where implements, aten, and return_and_correct_aliasing are in scope.
@implements(aten.split.Tensor)
def _(func, types, args, kwargs):
    # Split the inner weight tensor, then re-wrap each chunk as a
    # LinearActivationQuantizedTensor with the same activation-quant settings.
    new_values = func(args[0].original_weight_tensor, *args[1:], **kwargs)

    out = [
        LinearActivationQuantizedTensor(
            value,
            args[0].input_quant_func,
            args[0].quant_kwargs,
        )
        for value in new_values
    ]
    # aten.split returns a list of views, so fix up aliasing once on the
    # whole output rather than per chunk.
    return return_and_correct_aliasing(func, args, kwargs, out)
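
The handler can be exercised without spinning up FSDP, since torch.chunk lowers to the same aten.split.Tensor (a hypothetical check; fp8 quantization assumes a recent CUDA GPU, and the sizes are arbitrary):

import torch
import torch.nn as nn
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

lin = nn.Linear(64, 64, dtype=torch.bfloat16, device="cuda")
quantize_(lin, float8_dynamic_activation_float8_weight())
# torch.chunk is what FSDP2 calls internally; it dispatches to aten.split.Tensor.
chunks = torch.chunk(lin.weight, 4, dim=0)
print([type(c).__name__ for c in chunks])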

Another error is reported:

[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py", line 124, in _chunk_with_empty
[rank0]:     chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 436, in _dispatch__torch_function__
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 451, in _dispatch__torch_dispatch__
[rank0]:     return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 412, in wrapper
[rank0]:     return func(f, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 220, in _
[rank0]:     out_tensor = func(tensor.original_weight_tensor, *args[1:], **kwargs)
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 455, in _dispatch__torch_dispatch__
[rank0]:     raise NotImplementedError(
[rank0]: NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>, <class 'int'>), kwarg_types={}

I'm not sure why the LinearActivationQuantizedTensor becomes an AffineQuantizedTensor here — presumably because original_weight_tensor is itself an AffineQuantizedTensor, so the delegated split recurses into it. When I look into dtypes/affine_quantized_tensor.py, I can't find where to add a split function. Any suggestions?
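
If slicing already works for the tensor's layout, one possible shape for the missing op is to express aten.split in terms of aten.slice. A hedged sketch, not existing torchao code — both the assumption that aten.slice.Tensor is implemented for the layout in question and the int-split_size-only handling need checking:

import torch
from torchao.dtypes import AffineQuantizedTensor

aten = torch.ops.aten

@AffineQuantizedTensor.implements(aten.split.Tensor)
def _(func, types, args, kwargs):
    self, split_size = args[0], args[1]
    dim = args[2] if len(args) > 2 else kwargs.get("dim", 0)
    size = self.shape[dim]
    # torch.chunk (what FSDP2 calls) lowers to aten.split with an int
    # split_size, so only that overload is handled here. Each chunk is
    # produced via aten.slice, assuming the layout implements it.
    return [
        aten.slice.Tensor(self, dim, start, min(start + split_size, size))
        for start in range(0, size, split_size)
    ]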

happynear avatar Mar 20 '25 05:03 happynear

cc @jerryzh168

vkuzo avatar Mar 20 '25 12:03 vkuzo

I also met this problem... Many thanks!

My problem is that when I use torchao to quantize the Wan2.1 model, it is incompatible with FSDP.

Andy0422 avatar Apr 22 '25 10:04 Andy0422

IIUC FSDP is for training; what is the use case for applying it to post-training methods via the quantize_ API?

also cc @drisspg

jerryzh168 avatar Apr 29 '25 20:04 jerryzh168