MIOpen
Implement MatrixBandPart
- Added MatrixBandPart.
- Added driver test and gtest for MatrixBandPart.
- New API is guarded by the `MIOPEN_BETA_API` macro.
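
For reviewers unfamiliar with the operation, here is a minimal reference sketch in plain Python. It assumes the TensorFlow-style `matrix_band_part` semantics (an element at row `m`, column `n` is kept when it lies within `num_lower` sub-diagonals and `num_upper` super-diagonals, and a negative bound keeps the whole triangle on that side). Whether the MIOpen kernel treats negative bounds the same way is an assumption here, and the function name and list-of-lists layout are illustrative only, not the MIOpen API.

```python
def matrix_band_part(batch, num_lower, num_upper):
    """Zero everything outside the central band of each matrix in `batch`.

    Assumption: TensorFlow-style semantics, where a negative bound means
    "keep the entire triangle on that side". `batch` is a list of 2-D
    matrices given as lists of rows.
    """
    result = []
    for matrix in batch:
        banded = []
        for m, row in enumerate(matrix):
            banded.append([
                value
                if (num_lower < 0 or m - n <= num_lower)
                and (num_upper < 0 or n - m <= num_upper)
                else 0
                for n, value in enumerate(row)
            ])
        result.append(banded)
    return result


# Keep one sub-diagonal and the whole upper triangle of a single 4x4 matrix.
example = [[
    [0, 1, 2, 3],
    [-1, 0, 1, 2],
    [-2, -1, 0, 1],
    [-3, -2, -1, 0],
]]
banded = matrix_band_part(example, num_lower=1, num_upper=-1)
```

Because the operation is an element-wise mask, the backward pass under these assumptions applies the same mask to the incoming gradient.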

Average over all cases for MatrixBandPart:
| Type | Forward | Backward |
|---|---|---|
| float16 | 1.78 | 1.76 |
| float32 | 1.72 | 1.69 |
| bfloat16 | 1.71 | 1.68 |
FWD - FP16
| op_name | dtype | size | num_lower | num_upper | max_diag | direction | ROCm kernel avg | MIOpen kernel avg | ROCm / MIOpen |
|---|---|---|---|---|---|---|---|---|---|
| MatrixBandPart | float16 | [2 8 8] | 8 | 7 | 7 | fwd | 7760 | 6594 | 1.176827419 |
| MatrixBandPart | float16 | [2 8 8] | -7 | 7 | 7 | fwd | 13744 | 6505 | 2.11283628 |
| MatrixBandPart | float16 | [4 8 8] | 8 | 7 | 7 | fwd | 13952 | 6434 | 2.16847995 |
| MatrixBandPart | float16 | [4 8 8] | 4 | 7 | 7 | fwd | 13776 | 6666 | 2.066606661 |
| MatrixBandPart | float16 | [4 8 8] | 0 | 7 | 7 | fwd | 8848 | 6363 | 1.390539054 |
| MatrixBandPart | float16 | [4 8 8] | -4 | 7 | 7 | fwd | 8432 | 6719 | 1.254948653 |
| MatrixBandPart | float16 | [4 8 8] | -7 | 7 | 7 | fwd | 8336 | 6807 | 1.224621713 |
| MatrixBandPart | float16 | [8 8 8] | 8 | 7 | 7 | fwd | 8432 | 6452 | 1.306881587 |
| MatrixBandPart | float16 | [8 8 8] | 4 | 7 | 7 | fwd | 8432 | 6115 | 1.378904334 |
| MatrixBandPart | float16 | [8 8 8] | 3 | 7 | 7 | fwd | 8496 | 6737 | 1.261095443 |
| MatrixBandPart | float16 | [16 8 8] | 0 | 7 | 7 | fwd | 8672 | 6097 | 1.422338855 |
| MatrixBandPart | float16 | [16 8 8] | -1 | 7 | 7 | fwd | 8688 | 5990 | 1.450417362 |
| MatrixBandPart | float16 | [16 8 8] | -2 | 7 | 7 | fwd | 8816 | 6328 | 1.393173198 |
| MatrixBandPart | float16 | [16 8 8] | -3 | 7 | 7 | fwd | 8736 | 6346 | 1.376615191 |
| MatrixBandPart | float16 | [16 8 8] | -4 | 7 | 7 | fwd | 8704 | 6008 | 1.44873502 |
| MatrixBandPart | float16 | [32 8 8] | -7 | 7 | 7 | fwd | 8960 | 6417 | 1.396291102 |
| MatrixBandPart | float16 | [64 8 8] | 8 | 7 | 7 | fwd | 9504 | 7039 | 1.350191789 |
| MatrixBandPart | float16 | [64 8 8] | 4 | 7 | 7 | fwd | 9456 | 7092 | 1.333333333 |
| MatrixBandPart | float16 | [128 8 8] | 1 | 7 | 7 | fwd | 9408 | 7163 | 1.313416166 |
| MatrixBandPart | float16 | [256 8 8] | -3 | 7 | 7 | fwd | 9600 | 7217 | 1.330192601 |
| MatrixBandPart | float16 | [256 8 8] | -4 | 7 | 7 | fwd | 9552 | 7128 | 1.34006734 |
| MatrixBandPart | float16 | [512 8 8] | -3 | 7 | 7 | fwd | 9696 | 7252 | 1.33701048 |
| MatrixBandPart | float16 | [1024 8 8] | -3 | 7 | 7 | fwd | 9440 | 7306 | 1.292088694 |
| MatrixBandPart | float16 | [1024 8 8] | -4 | 7 | 7 | fwd | 9488 | 7341 | 1.292466966 |
| MatrixBandPart | float16 | [1024 8 8] | -7 | 7 | 7 | fwd | 9552 | 7484 | 1.276322822 |
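
The last column of each per-case table is the ROCm kernel average divided by the MIOpen kernel average. The short sketch below reproduces it for the first two FP16 forward rows. The function name is hypothetical, and it assumes the two "kernel avg" columns are durations in the same (unstated) unit, so a ratio above 1 means the MIOpen kernel took less time.

```python
def rocm_over_miopen(rocm_kernel_avg, miopen_kernel_avg):
    """Return the ROCm / MIOpen ratio used in the last table column.

    Assumption: both arguments are average kernel durations in the same unit.
    """
    return rocm_kernel_avg / miopen_kernel_avg


# (ROCm kernel avg, MIOpen kernel avg) from the first two FP16 forward rows.
rows = [(7760, 6594), (13744, 6505)]
ratios = [rocm_over_miopen(rocm, miopen) for rocm, miopen in rows]
```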
FWD - FP32
| op_name | dtype | size | num_lower | num_upper | max_diag | direction | ROCm kernel avg | MIOpen kernel avg | ROCm / MIOpen |
|---|---|---|---|---|---|---|---|---|---|
| MatrixBandPart | float32 | [2 2 2] | 2 | 1 | 1 | fwd | 12144 | 6453 | 1.881915388 |
| MatrixBandPart | float32 | [2 2 2] | 1 | 1 | 1 | fwd | 11968 | 6489 | 1.84435198 |
| MatrixBandPart | float32 | [4 2 2] | 0 | 1 | 1 | fwd | 7296 | 6364 | 1.146448774 |
| MatrixBandPart | float32 | [4 2 2] | -1 | 1 | 1 | fwd | 6688 | 6471 | 1.03353423 |
| MatrixBandPart | float32 | [8 2 2] | 2 | 1 | 1 | fwd | 6720 | 6436 | 1.044126787 |
| MatrixBandPart | float32 | [8 2 2] | 1 | 1 | 1 | fwd | 6800 | 6702 | 1.014622501 |
| MatrixBandPart | float32 | [16 2 2] | 0 | 1 | 1 | fwd | 6784 | 6472 | 1.048207664 |
| MatrixBandPart | float32 | [16 2 2] | -1 | 1 | 1 | fwd | 6800 | 6578 | 1.03374886 |
| MatrixBandPart | float32 | [32 2 2] | 2 | 1 | 1 | fwd | 7184 | 6969 | 1.030850911 |
| MatrixBandPart | float32 | [64 2 2] | -1 | 1 | 1 | fwd | 7344 | 7023 | 1.045706963 |
| MatrixBandPart | float32 | [128 2 2] | 2 | 1 | 1 | fwd | 12016 | 6969 | 1.724207203 |
| MatrixBandPart | float32 | [256 2 2] | -1 | 1 | 1 | fwd | 7664 | 6524 | 1.174739424 |
| MatrixBandPart | float32 | [512 2 2] | 2 | 1 | 1 | fwd | 8128 | 7218 | 1.126073705 |
| MatrixBandPart | float32 | [1024 2 2] | 1 | 1 | 1 | fwd | 8128 | 7235 | 1.123427782 |
| MatrixBandPart | float32 | [1024 2 2] | 0 | 1 | 1 | fwd | 7984 | 7129 | 1.119932669 |
| MatrixBandPart | float32 | [1024 2 2] | -1 | 1 | 1 | fwd | 8000 | 7236 | 1.105583195 |
| MatrixBandPart | float32 | [2 4 4] | 4 | 3 | 3 | fwd | 6736 | 6684 | 1.007779773 |
| MatrixBandPart | float32 | [2 4 4] | 2 | 3 | 3 | fwd | 6752 | 6809 | 0.991628727 |
| MatrixBandPart | float32 | [2 4 4] | 1 | 3 | 3 | fwd | 6720 | 6791 | 0.989544986 |
| MatrixBandPart | float32 | [4 4 4] | -2 | 3 | 3 | fwd | 6832 | 6560 | 1.041463415 |
| MatrixBandPart | float32 | [64 4 4] | 2 | 3 | 3 | fwd | 7632 | 6613 | 1.154090428 |
| MatrixBandPart | float32 | [64 4 4] | 1 | 3 | 3 | fwd | 7616 | 6400 | 1.19 |
| MatrixBandPart | float32 | [256 4 4] | 0 | 3 | 3 | fwd | 8224 | 7307 | 1.1254961 |
| MatrixBandPart | float32 | [256 4 4] | -1 | 3 | 3 | fwd | 8048 | 7289 | 1.10412951 |
| MatrixBandPart | float32 | [512 4 4] | 4 | 3 | 3 | fwd | 8016 | 7129 | 1.124421377 |
| MatrixBandPart | float32 | [1024 4 4] | -1 | 3 | 3 | fwd | 8576 | 7324 | 1.170944839 |
| MatrixBandPart | float32 | [1024 4 4] | -2 | 3 | 3 | fwd | 8560 | 7342 | 1.165894852 |
| MatrixBandPart | float32 | [1024 4 4] | -3 | 3 | 3 | fwd | 8560 | 7253 | 1.180201296 |
FWD - BFP16
| op_name | dtype | size | num_lower | num_upper | max_diag | direction | ROCm kernel avg | MIOpen kernel avg | ROCm / MIOpen |
|---|---|---|---|---|---|---|---|---|---|
| MatrixBandPart | bfloat16 | [2 8 8] | -4 | 7 | 7 | fwd | 7440 | 6630 | 1.122171946 |
| MatrixBandPart | bfloat16 | [2 8 8] | -7 | 7 | 7 | fwd | 7376 | 6648 | 1.109506619 |
| MatrixBandPart | bfloat16 | [4 8 8] | 8 | 7 | 7 | fwd | 7440 | 6915 | 1.075921909 |
| MatrixBandPart | bfloat16 | [4 8 8] | 4 | 7 | 7 | fwd | 7536 | 6702 | 1.124440466 |
| MatrixBandPart | bfloat16 | [4 8 8] | 3 | 7 | 7 | fwd | 7568 | 6648 | 1.138387485 |
| MatrixBandPart | bfloat16 | [4 8 8] | -4 | 7 | 7 | fwd | 7536 | 6542 | 1.151941302 |
| MatrixBandPart | bfloat16 | [4 8 8] | -7 | 7 | 7 | fwd | 7552 | 6844 | 1.103448276 |
| MatrixBandPart | bfloat16 | [8 8 8] | 8 | 7 | 7 | fwd | 7696 | 6151 | 1.25117867 |
| MatrixBandPart | bfloat16 | [32 8 8] | -7 | 7 | 7 | fwd | 7808 | 6222 | 1.254901961 |
| MatrixBandPart | bfloat16 | [64 8 8] | 8 | 7 | 7 | fwd | 8096 | 7128 | 1.135802469 |
| MatrixBandPart | bfloat16 | [64 8 8] | 4 | 7 | 7 | fwd | 8192 | 6613 | 1.238772116 |
| MatrixBandPart | bfloat16 | [256 8 8] | 0 | 7 | 7 | fwd | 8352 | 7075 | 1.1804947 |
| MatrixBandPart | bfloat16 | [256 8 8] | -1 | 7 | 7 | fwd | 8368 | 6933 | 1.206981105 |
| MatrixBandPart | bfloat16 | [512 8 8] | -2 | 7 | 7 | fwd | 8400 | 7271 | 1.155274378 |
| MatrixBandPart | bfloat16 | [512 8 8] | -3 | 7 | 7 | fwd | 8432 | 7217 | 1.168352501 |
| MatrixBandPart | bfloat16 | [512 8 8] | -4 | 7 | 7 | fwd | 8320 | 6933 | 1.200057695 |
| MatrixBandPart | bfloat16 | [512 8 8] | -7 | 7 | 7 | fwd | 8256 | 7253 | 1.138287605 |
| MatrixBandPart | bfloat16 | [1024 8 8] | 8 | 7 | 7 | fwd | 8464 | 7431 | 1.139012246 |
| MatrixBandPart | bfloat16 | [1024 8 8] | 4 | 7 | 7 | fwd | 8672 | 7448 | 1.16433942 |
| MatrixBandPart | bfloat16 | [1024 8 8] | 3 | 7 | 7 | fwd | 8576 | 7199 | 1.191276566 |
BWD - FP16
| op_name | dtype | size | num_lower | num_upper | max_diag | direction | ROCm kernel avg | MIOpen kernel avg | ROCm / MIOpen |
|---|---|---|---|---|---|---|---|---|---|
| MatrixBandPart | float16 | [2 2 2] | 0 | 1 | 1 | bwd | 10416 | 6471 | 1.609643023 |
| MatrixBandPart | float16 | [2 2 2] | -1 | 1 | 1 | bwd | 7776 | 6204 | 1.253384913 |
| MatrixBandPart | float16 | [4 2 2] | 2 | 1 | 1 | bwd | 7680 | 5973 | 1.285786037 |
| MatrixBandPart | float16 | [4 2 2] | 1 | 1 | 1 | bwd | 7712 | 6098 | 1.264676943 |
| MatrixBandPart | float16 | [8 2 2] | -1 | 1 | 1 | bwd | 7776 | 6080 | 1.278947368 |
| MatrixBandPart | float16 | [16 2 2] | 2 | 1 | 1 | bwd | 7824 | 6204 | 1.261121857 |
| MatrixBandPart | float16 | [16 2 2] | 1 | 1 | 1 | bwd | 7824 | 6186 | 1.264791465 |
| MatrixBandPart | float16 | [32 2 2] | -1 | 1 | 1 | bwd | 7792 | 6275 | 1.241752988 |
| MatrixBandPart | float16 | [64 2 2] | 2 | 1 | 1 | bwd | 7920 | 6648 | 1.19133574 |
| MatrixBandPart | float16 | [64 2 2] | -1 | 1 | 1 | bwd | 7840 | 7146 | 1.097117268 |
| MatrixBandPart | float16 | [128 2 2] | 2 | 1 | 1 | bwd | 7824 | 6097 | 1.283254059 |
| MatrixBandPart | float16 | [128 2 2] | 1 | 1 | 1 | bwd | 7744 | 6133 | 1.262677319 |
| MatrixBandPart | float16 | [256 2 2] | 0 | 1 | 1 | bwd | 7984 | 5866 | 1.361063757 |
| MatrixBandPart | float16 | [256 2 2] | -1 | 1 | 1 | bwd | 7984 | 5991 | 1.332665665 |
| MatrixBandPart | float16 | [512 2 2] | 2 | 1 | 1 | bwd | 8144 | 6169 | 1.320149133 |
| MatrixBandPart | float16 | [512 2 2] | -1 | 1 | 1 | bwd | 8096 | 6258 | 1.293704059 |
BWD - FP32
| op_name | dtype | size | num_lower | num_upper | max_diag | direction | ROCm kernel avg | MIOpen kernel avg | ROCm / MIOpen |
|---|---|---|---|---|---|---|---|---|---|
| MatrixBandPart | float32 | [2 2 2] | 2 | 1 | 1 | bwd | 10336 | 6382 | 1.619554998 |
| MatrixBandPart | float32 | [2 2 2] | 1 | 1 | 1 | bwd | 9408 | 6614 | 1.422437254 |
| MatrixBandPart | float32 | [4 2 2] | 0 | 1 | 1 | bwd | 6944 | 6507 | 1.067158445 |
| MatrixBandPart | float32 | [4 2 2] | -1 | 1 | 1 | bwd | 6928 | 6507 | 1.064699554 |
| MatrixBandPart | float32 | [8 2 2] | 2 | 1 | 1 | bwd | 6928 | 6560 | 1.056097561 |
| MatrixBandPart | float32 | [16 2 2] | 0 | 1 | 1 | bwd | 6960 | 6791 | 1.024885878 |
| MatrixBandPart | float32 | [16 2 2] | -1 | 1 | 1 | bwd | 6944 | 6613 | 1.050052926 |
| MatrixBandPart | float32 | [32 2 2] | 2 | 1 | 1 | bwd | 7072 | 6969 | 1.014779739 |
| MatrixBandPart | float32 | [64 2 2] | -1 | 1 | 1 | bwd | 7392 | 7235 | 1.021700069 |
| MatrixBandPart | float32 | [128 2 2] | 2 | 1 | 1 | bwd | 9392 | 6525 | 1.439386973 |
| MatrixBandPart | float32 | [256 2 2] | 0 | 1 | 1 | bwd | 7664 | 6809 | 1.1255691 |
| MatrixBandPart | float32 | [256 2 2] | -1 | 1 | 1 | bwd | 7600 | 6347 | 1.197416102 |
| MatrixBandPart | float32 | [512 2 2] | 2 | 1 | 1 | bwd | 7984 | 7271 | 1.098060789 |
| MatrixBandPart | float32 | [1024 2 2] | 0 | 1 | 1 | bwd | 8064 | 7218 | 1.117206983 |
| MatrixBandPart | float32 | [1024 2 2] | -1 | 1 | 1 | bwd | 8144 | 7307 | 1.114547694 |
BWD - BFP16
| op_name | dtype | size | num_lower | num_upper | max_diag | direction | ROCm kernel avg | MIOpen kernel avg | ROCm / MIOpen |
|---|---|---|---|---|---|---|---|---|---|
| MatrixBandPart | bfloat16 | [2 16 16] | 16 | 15 | 15 | bwd | 7232 | 6417 | 1.127006389 |
| MatrixBandPart | bfloat16 | [2 16 16] | 9 | 15 | 15 | bwd | 7120 | 6524 | 1.091354997 |
| MatrixBandPart | bfloat16 | [4 16 16] | 0 | 15 | 15 | bwd | 7552 | 6328 | 1.193426043 |
| MatrixBandPart | bfloat16 | [4 16 16] | -2 | 15 | 15 | bwd | 7568 | 6595 | 1.147536012 |
| MatrixBandPart | bfloat16 | [4 16 16] | -4 | 15 | 15 | bwd | 7504 | 6257 | 1.199296788 |
| MatrixBandPart | bfloat16 | [8 16 16] | 0 | 15 | 15 | bwd | 7792 | 6435 | 1.210878011 |
| MatrixBandPart | bfloat16 | [8 16 16] | -2 | 15 | 15 | bwd | 7776 | 6168 | 1.260700389 |
| MatrixBandPart | bfloat16 | [16 16 16] | 4 | 15 | 15 | bwd | 8256 | 6879 | 1.200174444 |
| MatrixBandPart | bfloat16 | [16 16 16] | 2 | 15 | 15 | bwd | 8144 | 6773 | 1.202421379 |
| MatrixBandPart | bfloat16 | [32 16 16] | 0 | 15 | 15 | bwd | 8192 | 6986 | 1.172630976 |
| MatrixBandPart | bfloat16 | [64 16 16] | 0 | 15 | 15 | bwd | 8176 | 6844 | 1.194623027 |
| MatrixBandPart | bfloat16 | [128 16 16] | 2 | 15 | 15 | bwd | 8480 | 7360 | 1.152173913 |
| MatrixBandPart | bfloat16 | [256 16 16] | 4 | 15 | 15 | bwd | 8800 | 7324 | 1.201529219 |
| MatrixBandPart | bfloat16 | [256 16 16] | 2 | 15 | 15 | bwd | 8816 | 7093 | 1.242915551 |
| MatrixBandPart | bfloat16 | [512 16 16] | 16 | 15 | 15 | bwd | 8704 | 7484 | 1.163014431 |
| MatrixBandPart | bfloat16 | [512 16 16] | 9 | 15 | 15 | bwd | 8720 | 7626 | 1.143456596 |
| MatrixBandPart | bfloat16 | [512 16 16] | 6 | 15 | 15 | bwd | 9168 | 6986 | 1.312338964 |
| MatrixBandPart | bfloat16 | [1024 16 16] | -9 | 15 | 15 | bwd | 10064 | 9600 | 1.048333333 |
| MatrixBandPart | bfloat16 | [1024 16 16] | -15 | 15 | 15 | bwd | 10080 | 9528 | 1.057934509 |