[Quantization + FSDP] Support `quantize_()` for DTensor
While trying out INT8 mixed precision pretraining (#748) with torchtitan, I came across an issue: if the model is FSDP-sharded, `quantize_()` won't work. The fix would be adding extra logic to handle DTensor, similar to what FP8 does:
https://github.com/pytorch/ao/blob/f5703b07acc683653556d04ef970709ba47dba10/torchao/float8/float8_tensor.py#L161-L183
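For reference, a minimal sketch of what that DTensor handling could look like. This is not torchao's actual implementation; `quantize_local` is a hypothetical placeholder for whatever per-tensor quantization routine `quantize_()` applies:

```python
import torch
from torch.distributed.tensor import DTensor

def quantize_local(t: torch.Tensor) -> torch.Tensor:
    # Toy symmetric int8 quantization, for illustration only.
    scale = t.abs().amax().clamp(min=1e-12) / 127.0
    return torch.round(t / scale).to(torch.int8)

def quantize_param(param: torch.Tensor) -> torch.Tensor:
    if isinstance(param, DTensor):
        # Quantize the local shard, then re-wrap it so FSDP still sees
        # a DTensor with the same device mesh and placements.
        local_q = quantize_local(param.to_local())
        return DTensor.from_local(
            local_q, param.device_mesh, param.placements, run_check=False
        )
    return quantize_local(param)
```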
Yeah, this came up in some discussions with inference providers like SGLang as well.
@jerryzh168 @kwen2501 is this addressed now with the quantize + distributed inference composability work?
This is not addressed yet. I think this is a training use case, which we can explore in 2025 H1 together with @vkuzo. We do need a guide on how DTensor composes with quantization in both inference and training use cases.
@jerryzh168 is there any update on this issue? Cheers!
See some examples in https://github.com/pytorch/ao/blob/main/test/float8/test_fsdp.py. We'll be using the `quantize_` API everywhere, but maybe not yet for `convert_to_float8_training` at https://github.com/pytorch/ao/blob/137b0795acb3282ce622948b1537e20914186eea/test/float8/test_fsdp.py#L88; cc @vkuzo @danielvegamyhre on the plan to move to the `quantize_` API there. What is your use case, @Andy0422?
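For context, the flow at that line looks roughly like the following. This is a sketch assuming torchao's `convert_to_float8_training` and FSDP2's `fully_shard` (exported from `torch.distributed.fsdp` in recent PyTorch); see the linked test file for the real setup:

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torchao.float8 import convert_to_float8_training

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
convert_to_float8_training(model)  # swaps nn.Linear -> Float8Linear in place
for layer in model:
    fully_shard(layer)  # shard each layer's (float8-converted) params
fully_shard(model)
```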
My use case is inference. I'd like to use the `quantize_` API, e.g. with `Int8DynamicActivationInt8WeightConfig`, to quantize the weights first, then shard the weights and scales with FSDP, hopefully using int8 allreduce.
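Something like the following sketch, using torchao's `quantize_` and FSDP2's `fully_shard`. Whether the quantized weight and scale tensors actually compose with DTensor sharding here is exactly what this issue asks about, so treat this as aspirational rather than working code:

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig

model = nn.Sequential(nn.Linear(1024, 1024)).cuda()
quantize_(model, Int8DynamicActivationInt8WeightConfig())  # int8 weights + scales
fully_shard(model)  # then shard the already-quantized params across ranks
```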