Convert OCP FP8 model to FNUZ model inside MIGraphX
Problem Description
Parsing OCP FP8 Model
This would require MIGraphX to expose the E4M3FN data type in the IR; currently only the E4M3FNUZ type is exposed. It is probably not much work to expose E4M3FN in the IR, but it may require additional testing.
OCP FP8 to FNUZ FP8:
The conversion is as simple as this (wrapped here in an illustrative helper; the actual implementation is linked in the work plan below):
#include <cstdint>

uint8_t fp8_ocp_to_fnuz(uint8_t ocp_f8_data)
{
    uint8_t hip_f8_data;
    if(ocp_f8_data == 0x80)               // negative zero maps to the single FNUZ zero
        hip_f8_data = 0x0;
    else if((ocp_f8_data & 0x7f) == 0x7f) // either NaN encoding (0x7f or 0xff)
        hip_f8_data = 0x80;               // map to the single FNUZ NaN
    else
        hip_f8_data = ocp_f8_data;        // any other number keeps its bit encoding
    return hip_f8_data;
}
By simply keeping the bit encoding, every number is automatically halved when reinterpreted as FNUZ (E4M3FNUZ uses an exponent bias of 8 instead of 7). One consequence is that FNUZ would never see its maximum value of 240.0f: the OCP FP8 maximum of 448 gets mapped to 224 in FNUZ. Based on the frameworks team's experiments, this is assumed to be acceptable.
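As a quick sanity check, here is a small standalone sketch (not MIGraphX code) that decodes the same bit pattern under both biases; it only handles normal values, which is enough to show why identical bits represent half the value in FNUZ:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode a normal (non-zero, non-NaN) E4M3 value for a given exponent bias.
// E4M3FN (OCP) uses bias 7; E4M3FNUZ uses bias 8.
static double decode_e4m3(uint8_t bits, int bias)
{
    int sign     = (bits >> 7) & 0x1;
    int exponent = (bits >> 3) & 0xf;
    int mantissa = bits & 0x7;
    double frac  = 1.0 + mantissa / 8.0; // implicit leading 1 for normal values
    return (sign != 0 ? -1.0 : 1.0) * frac * std::pow(2.0, exponent - bias);
}

int main()
{
    uint8_t max_ocp = 0x7e; // largest finite E4M3FN value
    std::printf("as OCP  (bias 7): %g\n", decode_e4m3(max_ocp, 7)); // 448
    std::printf("as FNUZ (bias 8): %g\n", decode_e4m3(max_ocp, 8)); // 224
    return 0;
}
```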
QDQ Ops
If the model uses any QuantizeLinear or DequantizeLinear ops, their scales and zero points would have been calculated during training assuming the OCP FP8 data type. Converting an OCP FP8 QDQ op to FNUZ FP8 should be easy: the scale should be multiplied by 2 for QuantizeLinear and divided by 2 for DequantizeLinear. Zero points, being compile-time constants, can be converted to the E4M3FNUZ type at compile time.
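As a worked example of the QuantizeLinear side (the numbers below are illustrative, not from any real model): doubling the scale yields the halved FNUZ value, which carries the same bit encoding as the original OCP value, so it stays consistent with the bit-preserving conversion above.

```cpp
#include <cmath>
#include <cstdio>

int main()
{
    float x      = 6.0f;
    float s_ocp  = 0.5f;          // scale calibrated for E4M3FN
    float s_fnuz = 2.0f * s_ocp;  // proposed adjusted scale for E4M3FNUZ
    float q_ocp  = std::round(x / s_ocp);  // 12.0 -> encoded as 0x54 in E4M3FN
    float q_fnuz = std::round(x / s_fnuz); // 6.0  -> encoded as 0x54 in E4M3FNUZ
    std::printf("q_ocp=%g q_fnuz=%g (same bit pattern, half the value)\n", q_ocp, q_fnuz);
    return 0;
}
```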
GEMMs
GEMMs would be doing MAC (multiply-and-accumulate) in FNUZ types, and both inputs to the GEMM are assumed to be in the FNUZ type. Because each input carries a scaling factor of 2 (its value is halved), every product term, and therefore the accumulated result, is off by a factor of 4.
The GEMM output would be FP32, and those FP32 values need to be multiplied by 4. Inside MIGraphX we can set alpha = 4, since GEMM computes D = alpha * A * B + beta * C.
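A toy dot product illustrating the factor of 4 (purely illustrative numbers): halving both inputs quarters the accumulated result, so scaling the FP32 output by 4 (alpha = 4) recovers the original value.

```cpp
#include <cstdio>

int main()
{
    float a[3] = {2.0f, 4.0f, 8.0f};
    float b[3] = {1.0f, 0.5f, 0.25f};
    float full = 0.0f, halved = 0.0f;
    for(int i = 0; i < 3; ++i)
    {
        full   += a[i] * b[i];                   // original OCP-scale values
        halved += (a[i] / 2.0f) * (b[i] / 2.0f); // same bits seen as FNUZ values
    }
    std::printf("full=%g halved=%g alpha*halved=%g\n", full, halved, 4.0f * halved);
    return 0;
}
```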
Convolutions
Convolutions are also just MAC. Activations would be in the FNUZ type, most probably coming from a QuantizeLinear operation; weights can be converted from OCP FP8 to FNUZ at compile time. The output of the convolution is also FP32, and each output value needs to be multiplied by 4 for the same reason as in the GEMM case.
Normalization Layers
It is assumed that all the normalization and non-linear layers (for example ReLU, GeLU, BatchNorm, LayerNorm) would run in FP32. MIGraphX would need to make sure this is the case during the conversion pass. For SoftMax, I think it shouldn't really matter what the data type is. Comments?
Parameters:
What if any of the parameters are in OCP FP8?
Proposed solution: this may require a HIP kernel for the conversion, or the conversion to FNUZ could happen on the host. Need to think about it more.
Return value:
What if the return value is OCP FP8? This is most probably an unlikely case, but if it happens we would probably need a HIP kernel to convert back to OCP FP8 before returning, or the conversion could happen on the host. Need to make a decision on that.
Work plan for MIGraphX:
- [ ] Expose the E4M3FN type in the IR.
- [ ] Add logic to convert E4M3FN values to E4M3FNUZ values. This would most probably be just a compile-time operation. Logic already exists at https://github.com/ROCm/msft_amd_ai_operators/blob/58fed181d72e5680c2cc5c5b068ff26f993cbc22/operators/conversion/hip/ocp_float8.h#L44
  - Most probably won't require a HIP kernel or GPU implementation inside MIGraphX.
  - This should be added as an MIGraphX operation, or it could be a specialization of the "convert" operation if no HIP kernel is required.
  - Need to add tests for this conversion logic (see the sketch after this list).
- [ ] Add a conversion pass that looks for OCP FP8 inside the model, adds scaling factors for GEMMs/convolutions etc., and converts the graph to FNUZ FP8.
  - This can handle GEMMs/convolutions/QDQ at first and assert that none of the other ops are in OCP FP8.
  - Handle OCP FP8 parameter and return values.
  - May require some changes inside lowering to add conversion logic for OCP FP8 parameters to FNUZ parameters on the host.
- [ ] Add verify tests with OCP FP8.
- [ ] Do end-to-end model testing for accuracy. Add an OCP FP8 model to CI.
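For the "tests for this conversion logic" item, here is a minimal standalone sketch of the cases worth covering; the helper name is hypothetical, and a real MIGraphX test would go through whatever operation or convert specialization gets added:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper implementing the bitwise E4M3FN -> E4M3FNUZ mapping from above.
static uint8_t fp8_ocp_to_fnuz(uint8_t x)
{
    if(x == 0x80)
        return 0x00; // negative zero -> single zero
    if((x & 0x7f) == 0x7f)
        return 0x80; // either NaN encoding -> single FNUZ NaN
    return x;        // everything else keeps its bits (value is halved as FNUZ)
}

int main()
{
    assert(fp8_ocp_to_fnuz(0x80) == 0x00); // -0 maps to 0
    assert(fp8_ocp_to_fnuz(0x7f) == 0x80); // +NaN maps to FNUZ NaN
    assert(fp8_ocp_to_fnuz(0xff) == 0x80); // -NaN maps to FNUZ NaN
    assert(fp8_ocp_to_fnuz(0x00) == 0x00); // +0 unchanged
    assert(fp8_ocp_to_fnuz(0x7e) == 0x7e); // max OCP (448) keeps its bits -> 224 as FNUZ
    return 0;
}
```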
Operating System
22.04
CPU
Any
GPU
AMD Instinct MI300
Other
No response
ROCm Version
ROCm 6.0.0
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response
@pfultz2 @hgaspar @krzysz00 can you guys review and comment on whether this is correct?
> It is probably not much work to expose E4M3FN in the IR, but it may require additional testing.
What do you mean by expose? Are we adding this as a new data type?
> What do you mean by expose? Are we adding this as a new data type?
Yes, e.g. here https://github.com/ROCm/AMDMIGraphX/blob/3c0c782617ecd3554a2d17a9145b7bc015592a24/src/include/migraphx/shape.hpp#L65
We already have a reference implementation of the E4M3FN type and some tests at: https://github.com/ROCm/AMDMIGraphX/blob/develop/test/fp8e4m3fn.cpp
I am thinking we will need a special operator for this, because we want to convert at runtime (for input parameters) and at compile time (for literals). We also need a conversion back to OCP FP8 if the return value is OCP FP8.
We would add a pass (say, adjust_fp8) that would do the conversions automatically based on the hardware support. It would work similarly to quantize_fp16 (maybe we could look into making this reusable) by inserting the special conversions back and forth and then removing the nested converts. It would probably need an extra step to insert additional scaling for the QDQs.
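A very rough sketch of what such a pass could look like, assuming MIGraphX internals like make_op/insert_instruction/replace_instruction and the existing fp8e4m3fnuz_type shape enum; the pass name, the set of handled ops, and the omitted type checks and scaling compensation are all placeholders:

```cpp
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>
#include <vector>

// Sketch only: insert a convert to fnuz in front of every input of selected ops and
// rely on a later simplification to remove nested/redundant converts. A real pass
// would first check that the input actually is fp8, and would also insert the scaling
// compensation for the halved values (and the QDQ scale adjustment).
static void adjust_fp8(migraphx::module& m)
{
    for(auto ins : migraphx::iterator_for(m))
    {
        if(ins->name() != "dot" and ins->name() != "convolution")
            continue;
        std::vector<migraphx::instruction_ref> new_inputs;
        for(auto input : ins->inputs())
        {
            new_inputs.push_back(m.insert_instruction(
                ins,
                migraphx::make_op("convert",
                                  {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}),
                input));
        }
        m.replace_instruction(ins, ins->get_operator(), new_inputs);
    }
}
```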
Adding the HIP kernels shouldn't be too difficult since this is just a pointwise operator. We would just add a function to do the conversion and then have the operator generate such a function when generating the pointwise code. We will probably need to add an fp8e4m3fnuz type on the HIP side, but this could be just a storage type (so no need to implement arithmetic operators and other features).
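For the storage-only type mentioned above, something as minimal as this could suffice (name and placement are placeholders, not an existing MIGraphX type):

```cpp
#include <cstdint>

// Storage-only fp8e4m3fnuz for the HIP side: it just holds the byte, with no
// arithmetic operators; the generated pointwise code does the bit manipulation.
struct fp8e4m3fnuz_storage
{
    uint8_t data;
};
```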
There are two possibilities for the parameters and return values.
- Parameters and/or return values are in OCP FP8. In that case we would require the special operator.
- The model has FP32 input parameters and return values and internally uses QDQ ops to convert to OCP FP8. For example, I have a MobileNet FP8 model from harris; it has FP32 inputs and outputs and uses QDQ internally.
I see the second possibility as the more likely case, so I think we can de-prioritize adding the special operator. Perhaps @hgaspar can provide more thoughts on this.