AMDMIGraphX
Implement reference op for rotary embedding
Motivation
Add a separate reference implementation of rotary embedding for use with GQA and Sparse Attention. The reference op comes with an op builder, and callers should use the builder rather than the reference op directly. The intent is to later implement rotary embedding through operator composition behind that builder.
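For orientation, the core computation of a rotary embedding can be sketched as below. This is a minimal standalone sketch and not the code added by this PR: `rotate_head`, its parameters, and the `interleaved` flag are hypothetical names chosen for illustration, and the cosine and sine values are assumed to be precomputed for the token's position.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper for illustration only; not the MIGraphX op or builder.
// Rotates one attention-head vector in place.
//   x        : head_size elements (head_size assumed even)
//   cos_row  : head_size / 2 cosines for this token's position (assumed precomputed)
//   sin_row  : head_size / 2 sines for this token's position (assumed precomputed)
//   interleaved : true  -> pairs are (x[2i], x[2i+1])
//                 false -> pairs are (x[i], x[i + head_size/2])
inline void rotate_head(std::vector<float>& x,
                        const std::vector<float>& cos_row,
                        const std::vector<float>& sin_row,
                        bool interleaved)
{
    const std::size_t half = x.size() / 2;
    for(std::size_t i = 0; i < half; ++i)
    {
        const std::size_t a = interleaved ? 2 * i : i;
        const std::size_t b = interleaved ? 2 * i + 1 : i + half;
        const float xa      = x[a];
        const float xb      = x[b];
        // Standard 2-D rotation of the pair (xa, xb) by the angle for index i.
        x[a] = xa * cos_row[i] - xb * sin_row[i];
        x[b] = xa * sin_row[i] + xb * cos_row[i];
    }
}
```

Each pair of elements is rotated by a position-dependent angle, so a composition of elementwise multiplies, adds, and slices could in principle express the same result. Whether the op in this PR supports both pairing layouts is not stated here.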
Technical Details
Changelog Category
- [ ] Added: New functionality.
- [x] Changed: Changes to existing functionality.
- [ ] Removed: Functionality or support that has been removed. (Compared to a previous release)
- [ ] Optimized: Component performance that has been optimized or improved.
- [ ] Resolved Issues: Known issues from a previous version that have been resolved.
- [ ] Not Applicable: This PR is not to be included in the changelog.

| Test | Batch | Rate new 79c234 | Rate old 38fdc6 | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,157.45 | 3,173.83 | -0.52% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | 6,590.55 | 6,613.50 | -0.35% | :white_check_mark: |
| torchvision-densenet121 | 32 | 2,437.43 | 2,445.12 | -0.31% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | 4,116.15 | 4,132.06 | -0.38% | :white_check_mark: |
| torchvision-inceptionv3 | 32 | 1,665.57 | 1,673.47 | -0.47% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | 2,589.12 | 2,596.39 | -0.28% | :white_check_mark: |
| cadene-inceptionv4 | 16 | 794.28 | 797.27 | -0.38% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 802.11 | 806.10 | -0.49% | :white_check_mark: |
| slim-mobilenet | 64 | 8,197.09 | 8,232.04 | -0.42% | :white_check_mark: |
| slim-nasnetalarge | 64 | 221.72 | 222.86 | -0.51% | :white_check_mark: |
| slim-resnet50v2 | 64 | 3,295.32 | 3,305.41 | -0.31% | :white_check_mark: |
| bert-mrpc-onnx | 8 | 1,132.24 | 1,144.06 | -1.03% | :white_check_mark: |
| bert-mrpc-tf | 1 | 486.43 | 486.42 | 0.00% | :white_check_mark: |
| pytorch-examples-wlang-gru | 1 | 317.25 | 309.90 | 2.37% | :white_check_mark: |
| pytorch-examples-wlang-lstm | 1 | 450.92 | 387.21 | 16.45% | :high_brightness: |
| torchvision-resnet50_1 | 1 | 807.28 | 745.35 | 8.31% | :high_brightness: |
| cadene-dpn92_1 | 1 | 436.44 | 428.94 | 1.75% | :white_check_mark: |
| cadene-resnext101_1 | 1 | 368.60 | 369.63 | -0.28% | :white_check_mark: |
| onnx-taau-downsample | 1 | 398.07 | 399.22 | -0.29% | :white_check_mark: |
| dlrm-criteoterabyte | 1 | 31.92 | 32.03 | -0.36% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 51.00 | 51.11 | -0.21% | :white_check_mark: |
| agentmodel | 1 | 9,826.08 | 9,659.41 | 1.73% | :white_check_mark: |
| unet_fp16 | 2 | 59.03 | 59.19 | -0.28% | :white_check_mark: |
| resnet50v1_fp16 | 1 | 993.33 | 991.39 | 0.20% | :white_check_mark: |
| resnet50v1_int8 | 1 | 992.10 | 971.32 | 2.14% | :white_check_mark: |
| bert_base_cased_fp16 | 64 | 1,099.32 | 1,104.24 | -0.45% | :white_check_mark: |
| bert_large_uncased_fp16 | 32 | 343.81 | 345.64 | -0.53% | :white_check_mark: |
| bert_large_fp16 | 1 | 197.82 | 198.02 | -0.10% | :white_check_mark: |
| distilgpt2_fp16 | 16 | 2,076.27 | 2,085.17 | -0.43% | :white_check_mark: |
| yolov5s | 1 | 587.50 | 588.82 | -0.23% | :white_check_mark: |
| tinyllama | 1 | 43.80 | 43.95 | -0.35% | :white_check_mark: |
| vicuna-fastchat | 1 | 45.04 | 45.27 | -0.51% | :white_check_mark: |
| whisper-tiny-encoder | 1 | 410.03 | 410.98 | -0.23% | :white_check_mark: |
| whisper-tiny-decoder | 1 | 414.22 | 415.37 | -0.28% | :white_check_mark: |
| llama2_7b | 1 | 19.11 | 19.15 | -0.24% | :white_check_mark: |
| qwen1.5-7b | 1 | 23.43 | 23.53 | -0.42% | :white_check_mark: |
| phi3-3.8b | 1 | 26.57 | 26.70 | -0.48% | :white_check_mark: |
| mask-rcnn | 1 | 12.15 | 12.24 | -0.72% | :white_check_mark: |
| llama3-8b | 1 | 21.65 | 21.74 | -0.40% | :white_check_mark: |
| whisper-large-encoder | 1 | 10.17 | 10.22 | -0.51% | :white_check_mark: |
| whisper-large-decoder | 1 | 98.90 | 99.83 | -0.94% | :white_check_mark: |
| mistral-7b | 1 | 23.65 | 23.74 | -0.35% | :white_check_mark: |
| FLUX.1-schnell | 1 | 726.34 | 721.05 | 0.73% | :white_check_mark: |
| nan | nan | nan | nan | nan% | :x: |
This build is not recommended to merge :red_circle:
:x:bert-mrpc-tf: ERROR - check error output
2025-09-24 09:25:13.803844: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 359, in <module>
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 306, in main
graph = load_tf_graph(model_name)
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 300, in load_tf_graph
graph_def.ParseFromString(f.read())
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/lib/io/file_io.py", line 116, in read
self._preread_check()
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/lib/io/file_io.py", line 77, in _preread_check
self._read_buf = _pywrap_file_io.BufferedInputStream(
tensorflow.python.framework.errors_impl.UnimplementedError: File system scheme '[local]' not implemented (file: '/new-saved-models/tf-misc/bert_mrpc1.pb')
:red_circle:bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
:red_circle:mask-rcnn: FAILED: MIGraphX is not within tolerance - check verbose output
Codecov Report
:x: Patch coverage is 86.31579% with 13 lines in your changes missing coverage. Please review.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/op/builder/rotary_embedding.cpp | 0.00% | 12 Missing :warning: |
| src/include/migraphx/op/rotary_embedding.hpp | 98.80% | 1 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## develop #4315 +/- ##
===========================================
+ Coverage 92.23% 92.24% +0.02%
===========================================
Files 557 562 +5
Lines 25924 26453 +529
===========================================
+ Hits 23909 24401 +492
- Misses 2015 2052 +37
| Files with missing lines | Coverage Δ | |
|---|---|---|
| src/include/migraphx/op/rotary_embedding.hpp | 98.80% <98.80%> (ø) | |
| src/op/builder/rotary_embedding.cpp | 0.00% <0.00%> (ø) | |