AMDMIGraphX

Convert OCP FP8 model to FNUZ model inside MIGraphX

Open umangyadav opened this issue 1 year ago • 5 comments

Problem Description

Parsing OCP FP8 Model

Parsing an OCP FP8 model would require MIGraphX to expose the E4M3FN data type in the IR. Currently, only the E4M3FNUZ type is exposed. Exposing E4M3FN in the IR is probably not much work, but it may require additional testing.

OCP FP8 to FNUZ FP8:

The conversion is as simple as this:

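```cpp
#include <cstdint>

// Map an OCP E4M3FN bit pattern to an E4M3FNUZ bit pattern.
std::uint8_t ocp_to_fnuz(std::uint8_t ocp_f8_data)
{
    std::uint8_t hip_f8_data;
    if(ocp_f8_data == 0x80) // negative zero: map to the single FNUZ zero
        hip_f8_data = 0x00;
    else if((ocp_f8_data & 0x7f) == 0x7f) // either OCP NaN (0x7f or 0xff)
        hip_f8_data = 0x80; // map to the single FNUZ NaN
    else
        hip_f8_data = ocp_f8_data; // any other number: keep the bit encoding
    return hip_f8_data;
}
```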

Keeping the bit encoding automatically halves every value when it is read as FNUZ, because E4M3FNUZ uses an exponent bias of 8 while OCP E4M3FN uses 7. As a result, FNUZ would never see its max value of 240.0f: that encoding (0x7f) is a NaN in OCP, and the max OCP FP8 value of 448 (0x7e) gets mapped to 224 in FNUZ. Based on the frameworks team's experiments, this is assumed to be okay.
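To see why keeping the bit pattern halves the value, the sketch below decodes the same E4M3 byte under both exponent biases (7 for E4M3FN, 8 for E4M3FNUZ). The function names are hypothetical helpers for illustration, not MIGraphX APIs.

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Decode the finite E4M3 layout (1 sign, 4 exponent, 3 mantissa bits) with a given exponent bias.
float decode_e4m3(std::uint8_t bits, int bias)
{
    const int sign     = (bits >> 7) & 0x1;
    const int exponent = (bits >> 3) & 0xf;
    const int mantissa = bits & 0x7;
    const float magnitude =
        exponent == 0
            ? std::ldexp(mantissa / 8.0f, 1 - bias)                // subnormal
            : std::ldexp(1.0f + mantissa / 8.0f, exponent - bias); // normal
    return sign != 0 ? -magnitude : magnitude;
}

// OCP E4M3FN: bias 7, NaN is S.1111.111 (0x7f or 0xff), no infinities.
float decode_e4m3fn(std::uint8_t bits)
{
    if((bits & 0x7f) == 0x7f)
        return std::numeric_limits<float>::quiet_NaN();
    return decode_e4m3(bits, 7);
}

// E4M3FNUZ: bias 8, a single NaN at 0x80, no negative zero, no infinities.
float decode_e4m3fnuz(std::uint8_t bits)
{
    if(bits == 0x80)
        return std::numeric_limits<float>::quiet_NaN();
    return decode_e4m3(bits, 8);
}
```

For example, `decode_e4m3fn(0x7e)` gives 448 while `decode_e4m3fnuz(0x7e)` gives 224, and `decode_e4m3fnuz(0x7f)` gives 240, an encoding the conversion never produces because 0x7f is NaN in OCP.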

QDQ Ops

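If the model uses any QuantizeLinear or DequantizeLinear ops, the scales and zero points of those operations would have been calculated during training assuming the OCP FP8 data type. Converting OCP FP8 QDQ ops to FNUZ FP8 should be easy. Because the same bit pattern represents half the value in FNUZ, the scale should be multiplied by 2 for QuantizeLinear (so the quantized FNUZ value is half the OCP value and keeps the same bits), and also multiplied by 2 for DequantizeLinear (to restore the factor of 2 lost when the FP8 input is read as FNUZ). Zero points are compile-time constants, so they can be converted to E4M3FNUZ at compile time.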
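The scale adjustment can be checked with the ONNX formulas, QuantizeLinear `y = x / scale + zero_point` and DequantizeLinear `x = (y - zero_point) * scale`, with rounding and saturation omitted so the arithmetic is exact. In this minimal sketch (all names hypothetical), the FNUZ quantized value comes out as exactly half of the OCP one, and dequantizing it with the same doubled scale recovers the original input.

```cpp
// Hypothetical container for the parameters of one QuantizeLinear/DequantizeLinear op.
struct qdq_params
{
    float scale;
    float zero_point; // numeric value of the FP8 zero point
};

// ONNX-style formulas without rounding or saturation, to check the algebra only.
float quantize(float x, qdq_params p) { return x / p.scale + p.zero_point; }
float dequantize(float q, qdq_params p) { return (q - p.zero_point) * p.scale; }

// Adjust OCP-calibrated parameters for FNUZ, where the same bits mean half the value.
qdq_params adjust_for_fnuz(qdq_params ocp)
{
    // Doubling the scale halves the quantized value: x / (2s) + zp / 2 == (x / s + zp) / 2.
    // Reinterpreting the zero point's bits as FNUZ halves its value automatically.
    return {ocp.scale * 2.0f, ocp.zero_point / 2.0f};
}
```

With `scale = 0.5` and `zero_point = 2`, an input of 3 quantizes to 8 in OCP and to 4 in FNUZ (same bits), and both dequantize back to 3.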

GEMMs

GEMMs would perform MAC (multiply-and-accumulate) in FNUZ types; both GEMM inputs are assumed to be in FNUZ. Because each input carries a factor of 1/2, the product carries a factor of 1/4:

$$A_{\text{fnuz}} B_{\text{fnuz}} = \frac{A_{\text{ocp}}}{2} \cdot \frac{B_{\text{ocp}}}{2} = \frac{1}{4} A_{\text{ocp}} B_{\text{ocp}}$$

The GEMM output would be FP32, so the FP32 values need to be multiplied by 4. Inside MIGraphX, we can set alpha = 4, since the GEMM computes $D = \alpha A B + \beta C$.
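A naive FP32-accumulating GEMM is enough to check the alpha = 4 argument. The sketch below (hypothetical helper names) halves both inputs to model the bit-preserving conversion at the value level and shows that alpha = 4 recovers the OCP result. It assumes C, if present, is already FP32 and needs no rescaling, so beta is left unchanged.

```cpp
#include <cstddef>
#include <vector>

// Naive row-major GEMM following D = alpha * A * B + beta * C,
// with A of size m x k, B of size k x n, and C, D of size m x n. Accumulation is in FP32.
std::vector<float> gemm(const std::vector<float>& a, const std::vector<float>& b,
                        const std::vector<float>& c, std::size_t m, std::size_t k,
                        std::size_t n, float alpha, float beta)
{
    std::vector<float> d(m * n);
    for(std::size_t i = 0; i < m; ++i)
    {
        for(std::size_t j = 0; j < n; ++j)
        {
            float acc = 0.0f;
            for(std::size_t p = 0; p < k; ++p)
                acc += a[i * k + p] * b[p * n + j];
            d[i * n + j] = alpha * acc + beta * c[i * n + j];
        }
    }
    return d;
}

// Models the bit-preserving OCP -> FNUZ conversion at the value level: every element is halved.
std::vector<float> as_fnuz_values(std::vector<float> v)
{
    for(auto& x : v)
        x /= 2.0f;
    return v;
}
```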

Convolutions

Convolutions are also just MACs. Activations would be in FNUZ type and would most likely come from a QuantizeLinear operation. Weights can be converted from OCP FP8 to FNUZ at compile time. The convolution output is also FP32, so each output value needs to be multiplied by 4.
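The same factor-of-4 correction applies to convolutions. In this hypothetical 1-D sketch, the input stands in for a QuantizeLinear output and the weights for a literal converted at compile time; both are halved, and scaling each FP32 output by 4 restores the OCP result.

```cpp
#include <cstddef>
#include <vector>

// 1-D "valid" convolution (cross-correlation, as in ONNX Conv) with FP32 accumulation.
// output_scale multiplies every output value, e.g. 4.0f after an OCP -> FNUZ conversion.
std::vector<float> conv1d_valid(const std::vector<float>& input,
                                const std::vector<float>& weights, float output_scale)
{
    std::vector<float> output;
    if(weights.empty() || input.size() < weights.size())
        return output;
    for(std::size_t i = 0; i + weights.size() <= input.size(); ++i)
    {
        float acc = 0.0f;
        for(std::size_t j = 0; j < weights.size(); ++j)
            acc += input[i + j] * weights[j];
        output.push_back(output_scale * acc);
    }
    return output;
}
```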

Normalization Layers

It is assumed that all normalization and non-linear layers, for example ReLU, GeLU, BatchNorm, and LayerNorm, would run in FP32. MIGraphX would need to make sure this is the case during the conversion pass. For SoftMax, I think it shouldn't really matter what the data type is. Comments?

Parameters:

What if any of the parameters are in OCP FP8?
Proposed solution: this may require a HIP kernel for the conversion, or the conversion to FNUZ could happen on the host. This needs more thought.
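If the conversion happens on the host, it is just a per-byte pass over the parameter buffer before upload. This sketch reuses the bit mapping from the OCP-to-FNUZ snippet earlier in the issue; `convert_parameter_on_host` is a hypothetical name, and copying the result to device memory is left out.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Same bit mapping as the OCP -> FNUZ snippet earlier in this issue.
std::uint8_t ocp_to_fnuz_bits(std::uint8_t ocp)
{
    if(ocp == 0x80)          // negative zero -> the single FNUZ zero
        return 0x00;
    if((ocp & 0x7f) == 0x7f) // either OCP NaN -> the single FNUZ NaN
        return 0x80;
    return ocp;              // all other values keep their bits (the value is halved)
}

// Hypothetical host-side conversion of an OCP FP8 parameter buffer before it is sent to the GPU.
std::vector<std::uint8_t> convert_parameter_on_host(const std::vector<std::uint8_t>& ocp_buffer)
{
    std::vector<std::uint8_t> fnuz_buffer(ocp_buffer.size());
    std::transform(ocp_buffer.begin(), ocp_buffer.end(), fnuz_buffer.begin(), ocp_to_fnuz_bits);
    return fnuz_buffer;
}
```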

Return value:

What if the return value is OCP FP8? This case is probably unlikely, but if it happens, we would probably need a HIP kernel to convert back to OCP FP8 before returning, or the conversion could happen on the host. A decision is needed on that.
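If the return value does need to be OCP FP8, the reverse mapping is also small. This sketch assumes the FNUZ result still carries the halved scale, so ordinary bit patterns are kept as-is. FNUZ ±240 (0x7f/0xff) would mean ±480 in OCP, beyond its max of 448; saturating to ±448 is an assumed choice here, and returning NaN would be equally defensible. The function name is hypothetical.

```cpp
#include <cstdint>

// Hypothetical FNUZ -> OCP conversion for an OCP FP8 return value.
std::uint8_t fnuz_to_ocp_bits(std::uint8_t fnuz)
{
    if(fnuz == 0x80)          // the single FNUZ NaN -> an OCP NaN
        return 0x7f;
    if((fnuz & 0x7f) == 0x7f) // +/-240 in FNUZ would be +/-480 in OCP: out of range
        return static_cast<std::uint8_t>((fnuz & 0x80) | 0x7e); // assumption: saturate to +/-448
    return fnuz;              // every other pattern keeps its bits
}
```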

Work plan for MIGraphX:

  • [ ] Expose E4M3FN type into IR.

  • [ ] Add logic to convert E4M3FN values to E4M3FNUZ values. This would most likely be just a compile-time operation. The logic already exists at https://github.com/ROCm/msft_amd_ai_operators/blob/58fed181d72e5680c2cc5c5b068ff26f993cbc22/operators/conversion/hip/ocp_float8.h#L44

    • This most likely won't require a HIP kernel or GPU implementation inside MIGraphX.
    • This should be added as an MIGraphX operation, or it could be a specialization of the "convert" operation if no HIP kernel is required.
    • Need to add tests for this conversion logic.
  • [ ] Add a conversion pass that looks for OCP FP8 inside the model, adds scaling factors for GEMMs, convolutions, etc., and converts the graph to FNUZ FP8.

    • Initially, this can handle GEMMs, convolutions, and QDQ ops and assert that no other ops are in OCP FP8.
    • Handle OCP FP8 parameters and return values.
    • This may require some changes in lowering to add host-side conversion logic from OCP FP8 parameters to FNUZ parameters.
  • [ ] Add verify tests with OCP FP8

  • [ ] Do end-to-end model testing for accuracy. Add OCP FP8 model to CI.
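The conversion pass in the work plan above can be sketched over a toy IR. Everything here (the `instruction` struct, the op-name strings, and `convert_ocp_to_fnuz`) is hypothetical and is not the MIGraphX API; it only shows the intended rewrite: multiply GEMM and convolution outputs by 4, double QDQ scales (per the ONNX QuantizeLinear/DequantizeLinear formulas), and reject any other op that still uses OCP FP8.

```cpp
#include <stdexcept>
#include <string>
#include <vector>

enum class fp_type { fp32, fp8e4m3fn, fp8e4m3fnuz };

// Toy instruction: an op name, the FP8 type it works on, and two adjustable factors.
struct instruction
{
    std::string op;
    fp_type type     = fp_type::fp32;
    float scale      = 1.0f; // QDQ scale (quantizelinear / dequantizelinear)
    float out_factor = 1.0f; // multiplier on the FP32 output (alpha for dot, post-scale for convolution)
};

void convert_ocp_to_fnuz(std::vector<instruction>& program)
{
    for(auto& ins : program)
    {
        if(ins.type != fp_type::fp8e4m3fn)
            continue; // only instructions that use OCP FP8 need rewriting
        if(ins.op == "dot" || ins.op == "convolution")
            ins.out_factor *= 4.0f; // both inputs are halved, so the product is quartered
        else if(ins.op == "quantizelinear" || ins.op == "dequantizelinear")
            ins.scale *= 2.0f; // the same bits represent half the value in FNUZ
        else
            throw std::runtime_error("OCP FP8 not supported for op: " + ins.op);
        ins.type = fp_type::fp8e4m3fnuz;
    }
}
```

Throwing on unsupported ops mirrors the plan to handle GEMMs, convolutions, and QDQ first and assert on everything else.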

Operating System

22.04

CPU

Any

GPU

AMD Instinct MI300

ROCm Version

ROCm 6.0.0

umangyadav, Feb 02 '24 16:02

@pfultz2 @hgaspar @krzysz00, can you review this and comment on whether it is correct?

umangyadav, Feb 02 '24 16:02

> It is probably not a big work to expose E4M3FN to IR but may require additional testing.

What do you mean by expose? Are we adding this as a new data type?

pfultz2, Feb 02 '24 16:02

> What do you mean by expose? Are we adding this as a new data type?

Yes, for example here: https://github.com/ROCm/AMDMIGraphX/blob/3c0c782617ecd3554a2d17a9145b7bc015592a24/src/include/migraphx/shape.hpp#L65

We already have a reference implementation of the E4M3FN type and some tests at: https://github.com/ROCm/AMDMIGraphX/blob/develop/test/fp8e4m3fn.cpp

umangyadav, Feb 02 '24 17:02

I think we will need a special operator for this, because we want to convert both at runtime (for input parameters) and at compile time (for literals). We also need a conversion back to OCP FP8 if the return value is OCP FP8.

We would add a pass (say, adjust_fp8) that does the conversions automatically based on hardware support. It would work similarly to quantize_fp16 (maybe we could look into making that reusable): insert the special conversions back and forth, then remove the nested converts. It would probably need an extra step to insert additional scaling for QDQs.
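The insert-then-simplify pattern described above can be illustrated on a toy linear program. The op names and both functions are hypothetical placeholders, not MIGraphX code; a real graph pass would also have to check that the intermediate value has no other users before dropping a convert pair.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Toy linear program: each entry is an op name, applied in order.
using program = std::vector<std::string>;

// Step 1: wrap every op that must run in FNUZ with converts to and from OCP.
program insert_converts(const program& p, const std::vector<std::string>& fnuz_ops)
{
    program out;
    for(const auto& op : p)
    {
        const bool needs_fnuz = std::find(fnuz_ops.begin(), fnuz_ops.end(), op) != fnuz_ops.end();
        if(needs_fnuz)
            out.push_back("convert_to_fnuz");
        out.push_back(op);
        if(needs_fnuz)
            out.push_back("convert_to_ocp");
    }
    return out;
}

// Step 2: drop back-to-back convert_to_ocp -> convert_to_fnuz pairs, which cancel out.
program remove_nested_converts(const program& p)
{
    program out;
    for(const auto& op : p)
    {
        if(op == "convert_to_fnuz" && !out.empty() && out.back() == "convert_to_ocp")
            out.pop_back(); // the pair is an identity, so neither convert is kept
        else
            out.push_back(op);
    }
    return out;
}
```

For `{"op_a", "op_b", "op_c"}` with `op_a` and `op_b` needing FNUZ, the result is a single `convert_to_fnuz` before `op_a` and a single `convert_to_ocp` after `op_b`.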

Adding the HIP kernels shouldn't be too difficult, since this is just a pointwise operator. We would just add a function to do the conversion and then have the operator generate such a function when generating the pointwise code. We will probably need to add an fp8e4m3fnuz type on the HIP side, but it could be just a storage type (so there is no need to implement arithmetic operators and other features).
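A storage-only type plus a per-element conversion function is roughly what the pointwise path would need. The sketch below is plain host C++ with hypothetical names, not MIGraphX's generated HIP code: in a HIP build the conversion function would presumably need to be callable from device code, and `pointwise` here only stands in for a generated kernel.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Storage-only FP8 types: they just carry the byte; no arithmetic operators are defined.
struct fp8e4m3fn_storage
{
    std::uint8_t data;
};
struct fp8e4m3fnuz_storage
{
    std::uint8_t data;
};
static_assert(sizeof(fp8e4m3fnuz_storage) == 1, "storage type should stay one byte");

// Element-wise conversion using the OCP -> FNUZ bit mapping from this issue.
fp8e4m3fnuz_storage to_fnuz(fp8e4m3fn_storage x)
{
    if(x.data == 0x80)          // negative zero -> the single FNUZ zero
        return {0x00};
    if((x.data & 0x7f) == 0x7f) // either OCP NaN -> the single FNUZ NaN
        return {0x80};
    return {x.data};            // keep the bits for every other value
}

// Host stand-in for a generated pointwise kernel: apply f to every element.
template <class In, class Out, class F>
void pointwise(const std::vector<In>& in, std::vector<Out>& out, F f)
{
    out.resize(in.size());
    std::transform(in.begin(), in.end(), out.begin(), f);
}
```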

pfultz2, Feb 02 '24 17:02

There are two possibilities for the parameters and return values.

  1. Parameters and/or return values are in OCP FP8. In that case, we would require a special operator.
  2. The model has input parameters and return values in FP32 and internally uses QDQ ops to convert to OCP FP8. For example, I have a MobileNet FP8 model from harris that has FP32 inputs and outputs and uses QDQ internally.

I see the second possibility as the more likely case. Therefore, I think we can deprioritize adding the special operator. Perhaps @hgaspar can provide more thoughts on this.

umangyadav, Feb 02 '24 17:02