FP8 quantization with FSDP2 error
When I quantize a model to FP8 with torchao and then shard it using FSDP2, the following error is raised:
[rank1]: Traceback (most recent call last):
[rank1]: File "/mnt/teams/algo-teams/shared/code/wanx-inference/generate.py", line 461, in <module>
[rank1]: generate(args)
[rank1]: File "/mnt/teams/algo-teams/shared/code/wanx-inference/generate.py", line 375, in generate
[rank1]: wan_i2v = wan.WanI2V(
[rank1]: ^^^^^^^^^^^
[rank1]: File "/mnt/teams/algo-teams/shared/code/wanx-inference/wan/image2video.py", line 218, in __init__
[rank1]: self.model = shard_dit_fn(self.model, param_dtype=torch.float8_e4m3fn)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/mnt/teams/algo-teams/shared/code/wanx-inference/wan/distributed/fsdp.py", line 112, in shard_dit_model
[rank1]: fully_shard_with_ignore_param(block, mesh=pm.get_dp_with_cp_mesh(), reshard_after_forward=True, mp_policy=mixed_fsdp2, ignored_params=ignored_states_set)
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank1]: updated = func(inp_module, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/mnt/teams/algo-teams/shared/pytorch_distributed_examples/src/tu_pth_dist/fsdp_compat.py", line 200, in fully_shard_with_ignore_param
[rank1]: state._fsdp_param_group = FSDPParamGroup(
[rank1]: ^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 132, in __init__
[rank1]: FSDPParam(
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 239, in __init__
[rank1]: self._init_sharded_param(param, device, shard_placement_fn)
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 368, in _init_sharded_param
[rank1]: chunks = _chunk_with_empty(param_data, shard_world_size, dim=shard_dim)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py", line 124, in _chunk_with_empty
[rank1]: chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 436, in _dispatch__torch_function__
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 455, in _dispatch__torch_dispatch__
[rank1]: raise NotImplementedError(
[rank1]: NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'int'>), kwarg_types={}
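For reference, here is a stripped-down sketch of the failing pattern (simplified from my actual code: plain nn.Linear layers stand in for the DiT blocks, and I'm assuming the stock float8_dynamic_activation_float8_weight config; run under torchrun with at least two GPUs):

import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

torch.distributed.init_process_group("nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

# Stand-in for the DiT blocks (the real model is wan.WanModel).
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()

# After quantize_, each Linear weight is a LinearActivationQuantizedTensor
# wrapping an AffineQuantizedTensor (FP8 data + scale).
quantize_(model, float8_dynamic_activation_float8_weight())

# fully_shard chunks each parameter across the mesh, which dispatches
# aten.split.Tensor on the quantized subclass and raises the error above.
for block in model:
    fully_shard(block, reshard_after_forward=True)
fully_shard(model)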
I can see that there is no aten.split implementation in https://github.com/pytorch/ao/blob/ab3792e3d91e04f85992a659c1664a6a1a6d733c/torchao/quantization/linear_activation_quantized_tensor.py . Could anyone provide one?
I tried to implement the split function myself:
@implements(aten.split.Tensor)
def _(func, types, args, kwargs):
    # Split the inner (quantized) weight tensor, then re-wrap each chunk
    # as a LinearActivationQuantizedTensor with the same quant function.
    new_values = func(args[0].original_weight_tensor, *args[1:], **kwargs)

    def make_new_tensor(value):
        out = LinearActivationQuantizedTensor(
            value,
            args[0].input_quant_func,
            args[0].quant_kwargs,
        )
        return return_and_correct_aliasing(func, args, kwargs, out)

    return list(map(make_new_tensor, new_values))
With this implementation in place, another error is reported:
[rank0]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py", line 124, in _chunk_with_empty
[rank0]: chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 436, in _dispatch__torch_function__
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 451, in _dispatch__torch_dispatch__
[rank0]: return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 412, in wrapper
[rank0]: return func(f, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 220, in _
[rank0]: out_tensor = func(tensor.original_weight_tensor, *args[1:], **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
[rank0]: return self._op(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 455, in _dispatch__torch_dispatch__
[rank0]: raise NotImplementedError(
[rank0]: NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>, <class 'int'>), kwarg_types={}
I'm not sure why the LinearActivationQuantizedTensor becomes an AffineQuantizedTensor. Looking into dtypes/affine_quantized_tensor.py, I can't find anywhere to add a split implementation either.
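From the second traceback it looks like the generic fallback at linear_activation_quantized_tensor.py line 220 just forwards the op to the wrapped weight (out_tensor = func(tensor.original_weight_tensor, ...)), so the inner tensor needs its own aten.split as well. A quick check of the nesting confirms this (sketch, assuming the same FP8 config as above and a CUDA device):

import torch.nn as nn
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

lin = nn.Linear(64, 64).cuda()
quantize_(lin, float8_dynamic_activation_float8_weight())

# The FP8 weight is a wrapper around a wrapper: the split registered on the
# outer LinearActivationQuantizedTensor forwards to this inner tensor, which
# has no aten.split implementation of its own.
print(type(lin.weight))                         # LinearActivationQuantizedTensor
print(type(lin.weight.original_weight_tensor))  # AffineQuantizedTensor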
Any suggestions?
cc @jerryzh168
I also met this problem... Many thanks!
My problem is that when I use torchao to quantize the Wan2.1 model, it is incompatible with FSDP.
IIUC FSDP is for training; what is the use case for applying it on top of post-training quantization with the quantize_ API?
also cc @drisspg