AMDMIGraphX
Add optional fp16 rmsnorm conversion pass to fix fp16 accuracy
Fixes https://github.com/ROCm/AMDMIGraphX/issues/2556. RMSNorm is used in LLMs like Llama2. Currently the fp16 version can overflow during normalization. This change tries to address it by converting the normalization to fp32.
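For context, here is a minimal numpy sketch of the failure mode (an illustration only, not MIGraphX code; the `rmsnorm` helper below is a hypothetical reference implementation): squaring the inputs inside an fp16 reduction overflows, while performing the normalization in fp32 gives the expected result.

```python
import numpy as np

def rmsnorm(x, eps=1e-5, compute_dtype=np.float16):
    """Reference RMSNorm: x / sqrt(mean(x^2) + eps), computed in compute_dtype."""
    xc = x.astype(compute_dtype)
    ms = np.mean(np.square(xc), axis=-1, keepdims=True)  # mean of squares
    return (xc / np.sqrt(ms + eps)).astype(x.dtype)

x = np.full((1, 8), 300.0, dtype=np.float16)  # 300^2 = 90000 > 65504 (fp16 max)
print(rmsnorm(x, compute_dtype=np.float16))   # mean of squares overflows to inf, output collapses to 0
print(rmsnorm(x, compute_dtype=np.float32))   # ~1.0 everywhere, as expected
```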
Codecov Report
Attention: Patch coverage is 97.14286% with 2 lines in your changes missing coverage. Please review.
Project coverage is 91.46%. Comparing base (a2752d8) to head (53d5a9d). Report is 672 commits behind head on develop.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/include/migraphx/rewrite_rmsnorm.hpp | 0.00% | 1 Missing :warning: |
| src/rewrite_rmsnorm.cpp | 97.67% | 1 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## develop #2687 +/- ##
===========================================
+ Coverage 91.41% 91.46% +0.04%
===========================================
Files 465 468 +3
Lines 17478 17548 +70
===========================================
+ Hits 15978 16050 +72
+ Misses 1500 1498 -2
:umbrella: View full report in Codecov by Sentry.
I don't think this is exactly the right approach. I would prefer we update the rewrite_reduce pass to lower reduce_mean to reduce_sum (for all cases). And then we can have a pass to help increase precision by finding patterns like x^2 / n and rewriting them to (x / sqrt(n))^2.
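To make the suggestion concrete, a minimal numpy sketch of that precision rewrite (an illustration only, not the MIGraphX pass): mean(x^2) = sum(x^2) / n can be computed as sum((x / sqrt(n))^2), which divides before squaring so the fp16 intermediates stay smaller.

```python
import numpy as np

x = np.zeros(64, dtype=np.float16)
x[0] = 300.0                                  # 300^2 = 90000 does not fit in fp16 (max 65504)
n = np.float16(x.size)

naive   = np.sum(np.square(x)) / n            # intermediate x^2 overflows -> inf
rewrite = np.sum(np.square(x / np.sqrt(n)))   # (300 / 8)^2 = 1406.25 stays in range

print(naive, rewrite)                         # inf vs 1406.25 (the true mean of squares)
```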
| Test | Batch | Rate new e49865 | Rate old 239ef7 | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 2,834.97 | 2,802.24 | 1.17% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | 6,499.80 | 6,495.21 | 0.07% | :white_check_mark: |
| torchvision-densenet121 | 32 | 2,081.56 | 2,052.94 | 1.39% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | 3,661.19 | 3,655.58 | 0.15% | :white_check_mark: |
| cadene-inceptionv4 | 16 | 722.36 | 709.13 | 1.87% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 691.85 | 649.66 | 6.49% | :high_brightness: |
| bert-mrpc-onnx | 8 | 812.17 | 812.10 | 0.01% | :white_check_mark: |
| bert-mrpc-tf | 1 | 387.40 | 387.66 | -0.07% | :white_check_mark: |
| pytorch-examples-wlang-gru | 1 | 302.47 | 303.42 | -0.31% | :white_check_mark: |
| pytorch-examples-wlang-lstm | 1 | 315.50 | 313.10 | 0.77% | :white_check_mark: |
| torchvision-resnet50_1 | 1 | 606.12 | 605.03 | 0.18% | :white_check_mark: |
| cadene-dpn92_1 | 1 | 415.94 | 415.64 | 0.07% | :white_check_mark: |
| cadene-resnext101_1 | 1 | 330.61 | 330.94 | -0.10% | :white_check_mark: |
| onnx-taau-downsample | 1 | 305.61 | 305.08 | 0.18% | :white_check_mark: |
| dlrm-criteoterabyte | 1 | 21.56 | 21.60 | -0.21% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 40.59 | 40.61 | -0.06% | :white_check_mark: |
| agentmodel | 1 | 6,037.04 | 6,009.62 | 0.46% | :white_check_mark: |
| unet_fp16 | 2 | 54.77 | 54.80 | -0.05% | :white_check_mark: |
| resnet50v1_fp16 | 1 | 937.67 | 941.44 | -0.40% | :white_check_mark: |
| bert_base_cased_fp16 | 64 | 924.37 | 924.62 | -0.03% | :white_check_mark: |
| bert_large_uncased_fp16 | 32 | 290.46 | 290.41 | 0.02% | :white_check_mark: |
| bert_large_fp16 | 1 | 171.70 | 171.66 | 0.02% | :white_check_mark: |
| distilgpt2_fp16 | 16 | 1,513.05 | 1,513.78 | -0.05% | :white_check_mark: |
Check results before merge :high_brightness:
@pfultz2 I agree that a general approach for the rewrite could help with more cases. I was just not sure that it should be enabled for all occurrences. Also, we need to be careful not to introduce any regressions, since other matchers can depend on that specific structure, so they will probably need to be extended as well.
I don't think this is exactly the right approach. I would prefer we update the rewrite_reduce pass to lower reduce_mean to reduce_sum (for all cases).
I don't see any rewrite_reduce pass. There are some rewrites scattered throughout the code (e.g. instance_norm, compile_gen, layernorm).
The idea would be to have a dedicated rewrite_reduce instead of the specific rewrite_rmsnorm?
Also, this math rewrite is only a partial solution, since larger numbers can still cause problems. I would suggest having a dedicated rmsnorm fp16-to-fp32 convert optional pass for those cases.
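To illustrate the limitation (plain numpy, not the actual passes): once the inputs themselves are large, dividing by sqrt(n) before squaring no longer helps, while promoting the reduction to fp32, which is what the optional conversion in this PR is aimed at, still gives a finite result.

```python
import numpy as np

x = np.zeros(64, dtype=np.float16)
x[0] = 25000.0                                                   # well within fp16 range
n = np.float16(x.size)

rewrite_fp16 = np.sum(np.square(x / np.sqrt(n)))                 # (25000 / 8)^2 ~ 9.8e6 -> inf
promoted     = np.sum(np.square(x.astype(np.float32))) / x.size  # computed in fp32, stays finite

print(rewrite_fp16, promoted)                                    # inf vs ~9.8e6
```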
since other matchers can depend on that specific structure, so they will probably need to be extended as well.
This same rewriting is needed for layernorm. Right now we use a custom kernel that does it, but I want to deprecate that in favor of our reduction fusion pipeline.
Also, this pass can run after prefuse_ops so it shouldn't affect the current layernorm fusion.
I don't see any rewrite_reduce pass.
The pass is added in #2673, which hasn't been merged yet. It runs after prefuse_ops so the flash attention can still fuse with a softmax.
Also, this math rewrite is only a partial solution, since larger numbers can still cause problems. I would suggest having a dedicated rmsnorm fp16-to-fp32 convert optional pass for those cases.
I think with fast math we can use fp16, but when it's disabled we should promote it to fp32.
@pfultz2 Updated to the rewrite_reduce pass.
I think with fast math we can use fp16, but when it's disabled we should promote it to fp32.
The fp16->fp32 conversion is still behind a flag. I have to check the fast math part first.
And then we can have a pass to help increase precision by finding patterns like x^2 / n and rewriting them to (x / sqrt(n))^2
Also, this PR does not include that yet.
@pfultz2 Sorry about rewrite_reduce, I misunderstood that part. I removed it and will move it to a separate PR.
Converted it to draft. It will be updated after https://github.com/ROCm/AMDMIGraphX/issues/2710 is added, to use this logic instead.