Implement PReLU backward
- Added PReLU backward operation and kernels.
- Added driver test and gtest for PReLU backward operation.
- New API is guarded by MIOPEN_BETA_API macro.
- Performance comparison against ROCm PyTorch:
float16
| op_name | dtype | size | num_param | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|
| PReLU | float16 | [512 64 112 112] | 1 | arcface | bwd | 99312671 | 12277100 | 8.11 |
| PReLU | float16 | [512 64 56 56] | 1 | arcface | bwd | 26715752 | 3073780 | 8.79 |
| PReLU | float16 | [512 128 56 56] | 1 | arcface | bwd | 52103300 | 6140360 | 8.53 |
| PReLU | float16 | [512 128 28 28] | 1 | arcface | bwd | 13279513 | 1545730 | 8.73 |
| PReLU | float16 | [512 256 28 28] | 1 | arcface | bwd | 26771242 | 3072230 | 8.79 |
| PReLU | float16 | [512 256 14 14] | 1 | arcface | bwd | 6605507 | 786651 | 8.68 |
| PReLU | float16 | [512 512 14 14] | 1 | arcface | bwd | 13283229 | 1550010 | 8.71 |
| PReLU | float16 | [512 512 7 7] | 1 | arcface | bwd | 3289272 | 403211 | 8.72 |
| PReLU | float16 | [512 64 112 112] | 64 | arcface | bwd | 107275183 | 46587900 | 2.31 |
| PReLU | float16 | [512 64 56 56] | 64 | arcface | bwd | 30255041 | 15273200 | 2.00 |
| PReLU | float16 | [512 128 56 56] | 128 | arcface | bwd | 56547437 | 13889100 | 4.09 |
| PReLU | float16 | [512 128 28 28] | 128 | arcface | bwd | 15626705 | 3868870 | 4.10 |
| PReLU | float16 | [512 256 28 28] | 256 | arcface | bwd | 30232264 | 5028100 | 6.07 |
| PReLU | float16 | [512 256 14 14] | 256 | arcface | bwd | 7826620 | 1531590 | 5.26 |
| PReLU | float16 | [512 512 14 14] | 512 | arcface | bwd | 15628058 | 2114630 | 7.50 |
| PReLU | float16 | [512 512 7 7] | 512 | arcface | bwd | 3877272 | 682737 | 6.01 |
float32
| op_name | dtype | size | num_param | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|
| PReLU | float32 | [512 64 112 112] | 1 | arcface | bwd | 103155863 | 12389000 | 8.35 |
| PReLU | float32 | [512 64 56 56] | 1 | arcface | bwd | 27476400 | 3102070 | 8.95 |
| PReLU | float32 | [512 128 56 56] | 1 | arcface | bwd | 52891035 | 6194870 | 8.58 |
| PReLU | float32 | [512 128 28 28] | 1 | arcface | bwd | 14441185 | 1560250 | 9.43 |
| PReLU | float32 | [512 256 28 28] | 1 | arcface | bwd | 27334300 | 3102970 | 8.90 |
| PReLU | float32 | [512 256 14 14] | 1 | arcface | bwd | 7261653 | 791718 | 9.51 |
| PReLU | float32 | [512 512 14 14] | 1 | arcface | bwd | 14419330 | 1559770 | 9.42 |
| PReLU | float32 | [512 512 7 7] | 1 | arcface | bwd | 3608816 | 406560 | 9.54 |
| PReLU | float32 | [512 64 112 112] | 64 | arcface | bwd | 109970712 | 46718600 | 2.36 |
| PReLU | float32 | [512 64 56 56] | 64 | arcface | bwd | 30409798 | 13333300 | 2.30 |
| PReLU | float32 | [512 128 56 56] | 128 | arcface | bwd | 57683256 | 13940500 | 4.16 |
| PReLU | float32 | [512 128 28 28] | 128 | arcface | bwd | 16256649 | 3878210 | 4.27 |
| PReLU | float32 | [512 256 28 28] | 256 | arcface | bwd | 30886236 | 5124050 | 6.09 |
| PReLU | float32 | [512 256 14 14] | 256 | arcface | bwd | 8320915 | 1522140 | 5.66 |
| PReLU | float32 | [512 512 14 14] | 512 | arcface | bwd | 16164181 | 2137560 | 7.70 |
| PReLU | float32 | [512 512 7 7] | 512 | arcface | bwd | 4095063 | 685281 | 6.41 |
bfloat16
| op_name | dtype | size | num_param | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|
| PReLU | bfloat16 | [512 64 112 112] | 1 | arcface | bwd | 99413514 | 12540100 | 7.95 |
| PReLU | bfloat16 | [512 64 56 56] | 1 | arcface | bwd | 26839289 | 3139650 | 8.63 |
| PReLU | bfloat16 | [512 128 56 56] | 1 | arcface | bwd | 52355515 | 6275980 | 8.38 |
| PReLU | bfloat16 | [512 128 28 28] | 1 | arcface | bwd | 13353346 | 1579090 | 8.60 |
| PReLU | bfloat16 | [512 256 28 28] | 1 | arcface | bwd | 26856779 | 3140230 | 8.64 |
| PReLU | bfloat16 | [512 256 14 14] | 1 | arcface | bwd | 6645672 | 801566 | 8.56 |
| PReLU | bfloat16 | [512 512 14 14] | 1 | arcface | bwd | 13434983 | 1577450 | 8.66 |
| PReLU | bfloat16 | [512 512 7 7] | 1 | arcface | bwd | 3310021 | 410969 | 8.57 |
| PReLU | bfloat16 | [512 64 112 112] | 64 | arcface | bwd | 106599289 | 46867700 | 2.28 |
| PReLU | bfloat16 | [512 64 56 56] | 64 | arcface | bwd | 30100583 | 11790800 | 2.57 |
| PReLU | bfloat16 | [512 128 56 56] | 128 | arcface | bwd | 55829383 | 14050900 | 3.99 |
| PReLU | bfloat16 | [512 128 28 28] | 128 | arcface | bwd | 15548941 | 3913070 | 4.03 |
| PReLU | bfloat16 | [512 256 28 28] | 256 | arcface | bwd | 29955610 | 5126710 | 5.89 |
| PReLU | bfloat16 | [512 256 14 14] | 256 | arcface | bwd | 7793750 | 1546550 | 5.17 |
| PReLU | bfloat16 | [512 512 14 14] | 512 | arcface | bwd | 15556630 | 2171380 | 7.27 |
| PReLU | bfloat16 | [512 512 7 7] | 512 | arcface | bwd | 3862986 | 692056 | 5.91 |
- Average over all cases:
| type | average |
|---|---|
| float16 | 7.13 |
| float32 | 7.83 |
| bfloat16 | 7.32 |
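
For reference, the gradient that a PReLU backward pass computes can be sketched in plain Python. This is a minimal sketch and not the MIOpen kernel: the function name `prelu_backward_ref`, the flat NCHW buffer layout, and the argument names are assumptions made for illustration. It uses the standard PReLU gradient (pass `dy` through where `x > 0`, otherwise scale it by the slope and accumulate `x * dy` into the slope gradient) and covers both `num_param` cases from the tables above: one shared slope, or one slope per channel.

```python
def prelu_backward_ref(x, dy, weight, n, c, spatial):
    """Reference PReLU backward over flat NCHW buffers (hypothetical helper).

    x, dy   : flat lists of length n * c * spatial (input and output gradient)
    weight  : list with 1 entry (shared slope) or c entries (per-channel slope)
    returns : (dx, dweight) as lists
    """
    if len(x) != n * c * spatial or len(dy) != len(x):
        raise ValueError("x and dy must both have n * c * spatial elements")
    if len(weight) not in (1, c):
        raise ValueError("weight must have 1 or c entries")

    dx = [0.0] * len(x)
    dweight = [0.0] * len(weight)
    for i, (xi, gi) in enumerate(zip(x, dy)):
        channel = (i // spatial) % c
        # With a single shared slope, every element maps to parameter 0.
        p = channel if len(weight) > 1 else 0
        if xi > 0:
            dx[i] = gi                  # identity branch
        else:
            dx[i] = weight[p] * gi      # negative branch: scale by slope
            dweight[p] += xi * gi       # reduction into the slope gradient
    return dx, dweight


# Tiny example: n=1, c=2, spatial=2.
x = [1.0, -2.0, -3.0, 4.0]
dy = [1.0, 1.0, 2.0, 2.0]
print(prelu_backward_ref(x, dy, [0.5, 0.25], 1, 2, 2))  # per-channel slopes
print(prelu_backward_ref(x, dy, [0.5], 1, 2, 2))        # one shared slope
```

The slope gradient is a reduction over every element that shares a parameter, so the `num_param = 1` and `num_param = C` cases reduce over different numbers of elements.
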
@junliume Could you help me check why "Window Build" and "Jenkins - Fp32 Hip Debug gfx90a" are failing?
I think these are the issues causing the CI to fail:
- Window Build
- Jenkins - Fp32 Hip Debug gfx90a
@sgundabo Thank you very much.
@junliume The GitHub Actions checks have passed. @CAHEK7 Could you review this PR?
May I ask which version of Torch and which graphics card were used for this test?
@yanjianglu I used an MI250 GPU and PyTorch from the rocm/pytorch image. I believe the image tag was rocm6.1.2_ubuntu22.04_py3.10_pytorch_release-2.1.2.