AMDMIGraphX

Add optional fp16 rmsnorm conversion pass to fix fp16 accuracy

Open attila-dusnoki-htec opened this issue 1 year ago • 10 comments

Fixes https://github.com/ROCm/AMDMIGraphX/issues/2556. RMSNorm is used in LLMs like Llama2. Currently the fp16 version can overflow during normalization. This change tries to address that by converting the normalization to fp32.
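As a rough illustration (a NumPy sketch, not MIGraphX code): squaring moderately large fp16 activations exceeds the fp16 maximum (~65504), so the normalization collapses, while the same computation carried out in fp32 stays stable.

```python
import numpy as np

def rmsnorm(x, eps=1e-6):
    # reference RMSNorm: x / sqrt(mean(x^2) + eps)
    return x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)

x = np.array([[300.0, -280.0, 150.0]], dtype=np.float16)

print(rmsnorm(x))                     # fp16: x^2 overflows to inf, output collapses to 0
print(rmsnorm(x.astype(np.float32)))  # fp32: ~[[1.19, -1.11, 0.59]]
```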

attila-dusnoki-htec, Jan 25 '24 16:01

Codecov Report

Attention: Patch coverage is 97.14286% with 2 lines in your changes missing coverage. Please review.

Project coverage is 91.46%. Comparing base (a2752d8) to head (53d5a9d). Report is 672 commits behind head on develop.

Files with missing lines | Patch % | Lines
src/include/migraphx/rewrite_rmsnorm.hpp | 0.00% | 1 Missing :warning:
src/rewrite_rmsnorm.cpp | 97.67% | 1 Missing :warning:
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #2687      +/-   ##
===========================================
+ Coverage    91.41%   91.46%   +0.04%     
===========================================
  Files          465      468       +3     
  Lines        17478    17548      +70     
===========================================
+ Hits         15978    16050      +72     
+ Misses        1500     1498       -2     

:umbrella: View full report in Codecov by Sentry.

codecov[bot], Jan 25 '24 17:01

I don't think this is exactly the right approach. I would prefer we update the rewrite_reduce pass to lower reduce_mean to reduce_sum (for all cases). And then we can have a pass to help increase precision by finding patterns like x^2 / n and rewriting them to (x / sqrt(n))^2.
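For reference, that precision rewrite relies on the identity

$$\frac{1}{n}\sum_{i=1}^{n} x_i^2 = \sum_{i=1}^{n}\left(\frac{x_i}{\sqrt{n}}\right)^2,$$

which divides each term by n before summing, so the intermediate values stay roughly n times smaller and are less likely to overflow fp16.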

pfultz2, Jan 25 '24 18:01

Test | Batch | Rate new (e49865) | Rate old (239ef7) | Diff | Compare
torchvision-resnet50 64 2,834.97 2,802.24 1.17% :white_check_mark:
torchvision-resnet50_fp16 64 6,499.80 6,495.21 0.07% :white_check_mark:
torchvision-densenet121 32 2,081.56 2,052.94 1.39% :white_check_mark:
torchvision-densenet121_fp16 32 3,661.19 3,655.58 0.15% :white_check_mark:
cadene-inceptionv4 16 722.36 709.13 1.87% :white_check_mark:
cadene-resnext64x4 16 691.85 649.66 6.49% :high_brightness:
bert-mrpc-onnx 8 812.17 812.10 0.01% :white_check_mark:
bert-mrpc-tf 1 387.40 387.66 -0.07% :white_check_mark:
pytorch-examples-wlang-gru 1 302.47 303.42 -0.31% :white_check_mark:
pytorch-examples-wlang-lstm 1 315.50 313.10 0.77% :white_check_mark:
torchvision-resnet50_1 1 606.12 605.03 0.18% :white_check_mark:
cadene-dpn92_1 1 415.94 415.64 0.07% :white_check_mark:
cadene-resnext101_1 1 330.61 330.94 -0.10% :white_check_mark:
onnx-taau-downsample 1 305.61 305.08 0.18% :white_check_mark:
dlrm-criteoterabyte 1 21.56 21.60 -0.21% :white_check_mark:
dlrm-criteoterabyte_fp16 1 40.59 40.61 -0.06% :white_check_mark:
agentmodel 1 6,037.04 6,009.62 0.46% :white_check_mark:
unet_fp16 2 54.77 54.80 -0.05% :white_check_mark:
resnet50v1_fp16 1 937.67 941.44 -0.40% :white_check_mark:
bert_base_cased_fp16 64 924.37 924.62 -0.03% :white_check_mark:
bert_large_uncased_fp16 32 290.46 290.41 0.02% :white_check_mark:
bert_large_fp16 1 171.70 171.66 0.02% :white_check_mark:
distilgpt2_fp16 16 1,513.05 1,513.78 -0.05% :white_check_mark:

Check results before merge :high_brightness:

migraphx-bot, Jan 25 '24 18:01


     :white_check_mark: bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
     :white_check_mark: bert-mrpc-tf: PASSED: MIGraphX meets tolerance
     :white_check_mark: pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
     :white_check_mark: pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
     :white_check_mark: torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: cadene-dpn92_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: cadene-resnext101_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
     :white_check_mark: agentmodel: PASSED: MIGraphX meets tolerance
     :white_check_mark: unet: PASSED: MIGraphX meets tolerance
     :white_check_mark: resnet50v1: PASSED: MIGraphX meets tolerance
     :white_check_mark: bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
     :white_check_mark: bert_large_uncased_fp16: PASSED: MIGraphX meets tolerance
     :white_check_mark: bert_large: PASSED: MIGraphX meets tolerance
     :white_check_mark: yolov5s: PASSED: MIGraphX meets tolerance
     :white_check_mark: tinyllama: PASSED: MIGraphX meets tolerance
     :white_check_mark: vicuna-fastchat: PASSED: MIGraphX meets tolerance
     :white_check_mark: whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
     :white_check_mark: whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
     :white_check_mark: distilgpt2_fp16: PASSED: MIGraphX meets tolerance

migraphx-bot, Jan 25 '24 18:01

@pfultz2 I agree that a general approach for the rewrite could help with more cases. I was just not sure that it should be enabled for all occurrences. Also, we need to be careful not to introduce any regressions, since other matchers can depend on that specific structure, so they will probably need to be extended as well.

I don't think this is exactly the right approach. I would prefer we update the rewrite_reduce pass to lower reduce_mean to reduce_sum (for all cases).

I don't see any rewrite_reduce pass. There are some rewrites scattered throughout the code (e.g. instance_norm, compile_gen, layernorm). The idea would be to have a dedicated rewrite_reduce pass instead of the specific rewrite_rmsnorm?


Also, this math rewrite is only a partial solution, since larger numbers can still cause problems. I would suggest having a dedicated optional rmsnorm fp16-to-fp32 convert pass for those cases.
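A quick NumPy sketch of that limitation (illustrative only, not the actual pass): the (x / sqrt(n))^2 form survives inputs where the naive x^2 / n overflows, but once the mean of squares itself no longer fits in fp16, only computing in fp32 helps.

```python
import numpy as np

def mean_sq_naive(x):
    # mean(x^2): squares first, then reduces
    return np.mean(np.square(x), axis=-1)

def mean_sq_rewritten(x):
    # sum((x / sqrt(n))^2): scale before squaring, so intermediates are ~n times smaller
    scale = np.asarray(np.sqrt(x.shape[-1]), dtype=x.dtype)
    return np.sum(np.square(x / scale), axis=-1)

x = np.zeros((1, 1024), dtype=np.float16)
x[0, 0] = 300.0                  # one large activation
print(mean_sq_naive(x))          # inf: 300^2 already exceeds the fp16 max (~65504)
print(mean_sq_rewritten(x))      # ~87.9: matches the fp32 value 300^2 / 1024

x[0, 0] = 20000.0                # larger still: the true mean of squares (~3.9e5)
print(mean_sq_rewritten(x))      # inf: does not fit in fp16 either, so fp32 is needed
```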

attila-dusnoki-htec, Jan 26 '24 12:01

since other matchers can depend on that specific structure, so they will probably need to be extended as well.

The same rewriting is needed for layernorm. Right now we use a custom kernel that does it, but I want to deprecate that in favor of our reduction fusion pipeline.

Also, this pass can run after prefuse_ops so it shouldn't affect the current layernorm fusion.

I don't see any rewrite_reduce pass.

The pass is added in #2673, which hasn't been merged yet. It runs after prefuse_ops so the flash attention can still fuse with a softmax.

pfultz2, Jan 26 '24 15:01

Also, this math rewrite is only a partial solution, since larger numbers can still cause problems. I would suggest having a dedicated optional rmsnorm fp16-to-fp32 convert pass for those cases.

I think with fast math we can use fp16, but when it's disabled we should promote it to fp32.

pfultz2, Jan 27 '24 00:01

@pfultz2 Updated to use the rewrite_reduce pass.

I think with fast math we can use fp16, but when it's disabled we should promote it to fp32.

The fp16->fp32 conversion is still behind a flag. I have to check the fast math part first.

And then we can have a pass to help increase precision by finding patterns like x^2 / n and rewriting them to (x / sqrt(n))^2

Also, this PR does not include that yet.

attila-dusnoki-htec, Jan 29 '24 17:01

@pfultz2 Sorry about rewrite_reduce, I misunderstood that part. I removed it and will move it to a separate PR.

attila-dusnoki-htec, Jan 30 '24 08:01

Converted it to a draft. It will be updated after https://github.com/ROCm/AMDMIGraphX/issues/2710 is added, to use that logic instead.

attila-dusnoki-htec, Feb 01 '24 09:02