TransformerEngine
`Float8Quantizer::create_tensor` calculates `scale_inv` instead of creating an empty buffer
https://github.com/NVIDIA/TransformerEngine/blob/b39397c541292f336c5964dd1661d80c08dc4c78/transformer_engine/pytorch/csrc/extensions/quantizer.cpp#L112
This introduces overhead. For example, in `fused_multi_quantize`, the reciprocal kernels (along with their launch overhead) account for most of the overall time.
There was an optimization that updates the FP8 scale-inverse in kernels with FP8 output (https://github.com/NVIDIA/TransformerEngine/pull/1083). Why did we change it?
cc @timmoon10
Logically, `scale` is part of the recipe and `scale_inv` is part of the data:

- You can change `scale` at any time, for whatever reason, to any value. It is only relevant at the exact moment you are doing an FP8 cast.
- Changing `scale_inv` destroys FP8 data, just as if you overwrote the FP8 data itself. Also, consider that `scale_inv` generalizes to a tensor in MXFP8 and other block-scaling formats.
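The recipe/data split above can be illustrated with a small pure-Python sketch (hypothetical names, not TE code; `round` stands in for FP8 rounding):

```python
# Hypothetical illustration of the scale / scale_inv roles; not TE code.
def fp8_cast(values, scale):
    """Quantize: `scale` is only consulted at cast time (part of the recipe)."""
    data = [round(v * scale) for v in values]   # stand-in for FP8 rounding
    scale_inv = 1.0 / scale                     # recorded alongside the data
    return data, scale_inv

def fp8_dequant(data, scale_inv):
    """Dequantize: `scale_inv` is all that is needed to interpret the data."""
    return [d * scale_inv for d in data]

values = [0.5, 1.25, 2.0]
data, scale_inv = fp8_cast(values, scale=4.0)

# Changing `scale` after the cast is harmless: it is never consulted again.
# Changing `scale_inv` corrupts the stored tensor:
good = fp8_dequant(data, scale_inv)        # recovers ~values
bad = fp8_dequant(data, scale_inv * 2.0)   # every element doubled
```

Overwriting `scale_inv` after the cast is exactly as destructive as overwriting the quantized payload itself, which is why it belongs with the data.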
Before https://github.com/NVIDIA/TransformerEngine/pull/1083, we needed messy logic for the scale update after the forward pass: update scale_inv together with scale but also make a copy so that we don't destroy cached tensors for the backward pass. This added the CPU overhead of an extra kernel launch (cudaMemcpy) and allocating temporary PyTorch tensors.
After https://github.com/NVIDIA/TransformerEngine/pull/1083, we can decouple scale and scale_inv: populate scale_inv when performing the FP8 cast and update scale after the forward and backward passes. In most cases, this adds very little overhead to existing kernels (existing FP8 kernels output to amax, so just output to scale_inv as well). There are a few kernels that don't output scale_inv (e.g. cuBLAS) so we need a separate scale-inverse kernel, but that is no worse than the extra cudaMemcpy in the old implementation.
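The decoupled flow can be sketched in pure Python (illustrative names only; the real kernels are CUDA/C++, and `fp8_max` here is just the E4M3 max value used as an example recipe):

```python
# Hypothetical sketch of the decoupled update flow; names are illustrative.
class Fp8Tensor:
    def __init__(self, data, scale_inv):
        self.data = data            # quantized payload
        self.scale_inv = scale_inv  # written by the cast kernel, then frozen

def cast_kernel(values, scale):
    """The FP8 cast writes data, amax, and scale_inv in one pass."""
    data = [round(v * scale) for v in values]
    amax = max(abs(v) for v in values)
    return Fp8Tensor(data, 1.0 / scale), amax

def update_scale(scale, amax, fp8_max=448.0):
    """Recipe step after fwd/bwd: adjust scale only; cached tensors are safe."""
    return fp8_max / amax if amax > 0 else scale

scale = 4.0
t, amax = cast_kernel([0.5, 1.25, 2.0], scale)
scale = update_scale(scale, amax)  # t.scale_inv is untouched -> no copy needed
```

Because `update_scale` never touches `t.scale_inv`, tensors cached for the backward pass stay valid without the extra copy the pre-#1083 logic required.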
In this case, I think the problem is that the API and behavior of Quantizer::create_tensor are wonky. The rowwise_data arg doesn't make sense in a generic base class function, since its interpretation depends on the concrete class (and also you need multiple tensors to represent MXFP8). If the user wants to pass in a raw UINT8 buffer, they should also be required to pass in the scale_inv since there's no reason to expect that it matches the quantizer's scale value. Float8Quantizer::create_tensor should be changed so it allocates an uninitialized UINT8 buffer and an uninitialized FP32 scale-inv, and then it'll be the responsibility of the FP8 cast kernel to populate the scale-inv with the correct value.
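The proposed contract could look roughly like the following pure-Python sketch (the real function lives in C++ in quantizer.cpp; these names and signatures are illustrative, not the actual API):

```python
# Hypothetical sketch of the proposed allocation contract; not the real C++ API.
def create_tensor(numel):
    """Allocate uninitialized buffers only; no reciprocal kernel is launched."""
    data = bytearray(numel)  # stand-in for an empty UINT8 buffer
    scale_inv = [None]       # stand-in for an uninitialized FP32 scale-inv
    return data, scale_inv

def fp8_cast_kernel(values, scale, data, scale_inv):
    """The cast kernel owns populating scale_inv with the value it used."""
    for i, v in enumerate(values):
        data[i] = round(v * scale) & 0xFF
    scale_inv[0] = 1.0 / scale
```

The key point is that `create_tensor` does no arithmetic at all: `scale_inv` only becomes meaningful once a cast kernel has actually written data with a particular scale.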
@timmoon10 I totally understand what #1083 was doing. My question is why we changed away from it in TE 2.0, for example, calculating scale_inv in Float8Quantizer::create_tensor instead of in the kernels that output fp8 data.
> Float8Quantizer::create_tensor should be changed so it allocates an uninitialized UINT8 buffer and an uninitialized FP32 scale-inv, and then it'll be the responsibility of the FP8 cast kernel to populate the scale-inv with the correct value.
Totally agree.
I'd say that this is just a bug - we should remove this `at::reciprocal`. The rowwise data, I think, was added as a way to pass a preallocated buffer (which we need e.g. for UserBuffers). We will need to extend that to maybe pass a dictionary of buffers, since e.g. to have UB support for MXFP8 we may need to do the same for the scale_inv. @timmoon10?