Implement PReLU backward

Open long10024070 opened this issue 1 year ago • 3 comments

  • Added PReLU backward operation and kernels.
  • Added driver test and gtest for PReLU backward operation.
  • New API is guarded by MIOPEN_BETA_API macro.
  • Performance compared to ROCm PyTorch:
float16

| op_name | dtype | size | num_param | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|
| PReLU | float16 | [512 64 112 112] | 1 | arcface | bwd | 99312671 | 12277100 | 8.11 |
| PReLU | float16 | [512 64 56 56] | 1 | arcface | bwd | 26715752 | 3073780 | 8.79 |
| PReLU | float16 | [512 128 56 56] | 1 | arcface | bwd | 52103300 | 6140360 | 8.53 |
| PReLU | float16 | [512 128 28 28] | 1 | arcface | bwd | 13279513 | 1545730 | 8.73 |
| PReLU | float16 | [512 256 28 28] | 1 | arcface | bwd | 26771242 | 3072230 | 8.79 |
| PReLU | float16 | [512 256 14 14] | 1 | arcface | bwd | 6605507 | 786651 | 8.68 |
| PReLU | float16 | [512 512 14 14] | 1 | arcface | bwd | 13283229 | 1550010 | 8.71 |
| PReLU | float16 | [512 512 7 7] | 1 | arcface | bwd | 3289272 | 403211 | 8.72 |
| PReLU | float16 | [512 64 112 112] | 64 | arcface | bwd | 107275183 | 46587900 | 2.31 |
| PReLU | float16 | [512 64 56 56] | 64 | arcface | bwd | 30255041 | 15273200 | 2.00 |
| PReLU | float16 | [512 128 56 56] | 128 | arcface | bwd | 56547437 | 13889100 | 4.09 |
| PReLU | float16 | [512 128 28 28] | 128 | arcface | bwd | 15626705 | 3868870 | 4.10 |
| PReLU | float16 | [512 256 28 28] | 256 | arcface | bwd | 30232264 | 5028100 | 6.07 |
| PReLU | float16 | [512 256 14 14] | 256 | arcface | bwd | 7826620 | 1531590 | 5.26 |
| PReLU | float16 | [512 512 14 14] | 512 | arcface | bwd | 15628058 | 2114630 | 7.50 |
| PReLU | float16 | [512 512 7 7] | 512 | arcface | bwd | 3877272 | 682737 | 6.01 |

float32

| op_name | dtype | size | num_param | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|
| PReLU | float32 | [512 64 112 112] | 1 | arcface | bwd | 103155863 | 12389000 | 8.35 |
| PReLU | float32 | [512 64 56 56] | 1 | arcface | bwd | 27476400 | 3102070 | 8.95 |
| PReLU | float32 | [512 128 56 56] | 1 | arcface | bwd | 52891035 | 6194870 | 8.58 |
| PReLU | float32 | [512 128 28 28] | 1 | arcface | bwd | 14441185 | 1560250 | 9.43 |
| PReLU | float32 | [512 256 28 28] | 1 | arcface | bwd | 27334300 | 3102970 | 8.90 |
| PReLU | float32 | [512 256 14 14] | 1 | arcface | bwd | 7261653 | 791718 | 9.51 |
| PReLU | float32 | [512 512 14 14] | 1 | arcface | bwd | 14419330 | 1559770 | 9.42 |
| PReLU | float32 | [512 512 7 7] | 1 | arcface | bwd | 3608816 | 406560 | 9.54 |
| PReLU | float32 | [512 64 112 112] | 64 | arcface | bwd | 109970712 | 46718600 | 2.36 |
| PReLU | float32 | [512 64 56 56] | 64 | arcface | bwd | 30409798 | 13333300 | 2.30 |
| PReLU | float32 | [512 128 56 56] | 128 | arcface | bwd | 57683256 | 13940500 | 4.16 |
| PReLU | float32 | [512 128 28 28] | 128 | arcface | bwd | 16256649 | 3878210 | 4.27 |
| PReLU | float32 | [512 256 28 28] | 256 | arcface | bwd | 30886236 | 5124050 | 6.09 |
| PReLU | float32 | [512 256 14 14] | 256 | arcface | bwd | 8320915 | 1522140 | 5.66 |
| PReLU | float32 | [512 512 14 14] | 512 | arcface | bwd | 16164181 | 2137560 | 7.70 |
| PReLU | float32 | [512 512 7 7] | 512 | arcface | bwd | 4095063 | 685281 | 6.41 |

bfloat16

| op_name | dtype | size | num_param | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|
| PReLU | bfloat16 | [512 64 112 112] | 1 | arcface | bwd | 99413514 | 12540100 | 7.95 |
| PReLU | bfloat16 | [512 64 56 56] | 1 | arcface | bwd | 26839289 | 3139650 | 8.63 |
| PReLU | bfloat16 | [512 128 56 56] | 1 | arcface | bwd | 52355515 | 6275980 | 8.38 |
| PReLU | bfloat16 | [512 128 28 28] | 1 | arcface | bwd | 13353346 | 1579090 | 8.60 |
| PReLU | bfloat16 | [512 256 28 28] | 1 | arcface | bwd | 26856779 | 3140230 | 8.64 |
| PReLU | bfloat16 | [512 256 14 14] | 1 | arcface | bwd | 6645672 | 801566 | 8.56 |
| PReLU | bfloat16 | [512 512 14 14] | 1 | arcface | bwd | 13434983 | 1577450 | 8.66 |
| PReLU | bfloat16 | [512 512 7 7] | 1 | arcface | bwd | 3310021 | 410969 | 8.57 |
| PReLU | bfloat16 | [512 64 112 112] | 64 | arcface | bwd | 106599289 | 46867700 | 2.28 |
| PReLU | bfloat16 | [512 64 56 56] | 64 | arcface | bwd | 30100583 | 11790800 | 2.57 |
| PReLU | bfloat16 | [512 128 56 56] | 128 | arcface | bwd | 55829383 | 14050900 | 3.99 |
| PReLU | bfloat16 | [512 128 28 28] | 128 | arcface | bwd | 15548941 | 3913070 | 4.03 |
| PReLU | bfloat16 | [512 256 28 28] | 256 | arcface | bwd | 29955610 | 5126710 | 5.89 |
| PReLU | bfloat16 | [512 256 14 14] | 256 | arcface | bwd | 7793750 | 1546550 | 5.17 |
| PReLU | bfloat16 | [512 512 14 14] | 512 | arcface | bwd | 15556630 | 2171380 | 7.27 |
| PReLU | bfloat16 | [512 512 7 7] | 512 | arcface | bwd | 3862986 | 692056 | 5.91 |

  • Average over all cases:

| type | average |
|---|---|
| float16 | 7.13 |
| float32 | 7.83 |
| bfloat16 | 7.32 |
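
For readers unfamiliar with the operation being benchmarked, here is a minimal pure-Python sketch of what a PReLU backward pass computes. It is not the MIOpen kernel from this PR. It only illustrates the math under these assumptions: the forward pass is `y = x` for `x > 0` and `y = weight * x` otherwise, inputs are nested lists shaped `[N][C][L]` (spatial dimensions flattened), and `weight` has either one shared value or one value per channel, matching the `num_param` column above. The function name `prelu_backward` is hypothetical.

```python
def prelu_backward(x, dy, weight):
    """Reference PReLU backward (illustrative sketch, not the MIOpen kernel).

    x, dy  : nested lists shaped [N][C][L] (input and output gradient)
    weight : list with 1 element (shared) or C elements (per channel)
    Returns (dx, dw): the input gradient and the weight gradient.
    """
    num_param = len(weight)
    dx = []
    dw = [0.0] * num_param
    for sample_x, sample_dy in zip(x, dy):
        dx_sample = []
        for c, (chan_x, chan_dy) in enumerate(zip(sample_x, sample_dy)):
            # One shared slope, or one slope per channel.
            p = c if num_param > 1 else 0
            w = weight[p]
            dx_chan = []
            for xv, gv in zip(chan_x, chan_dy):
                if xv > 0:
                    # Positive side: identity, gradient passes through.
                    dx_chan.append(gv)
                else:
                    # Non-positive side: slope is `w`; the weight gradient
                    # accumulates x * dy over every element that uses it.
                    dx_chan.append(w * gv)
                    dw[p] += xv * gv
            dx_sample.append(dx_chan)
        dx.append(dx_sample)
    return dx, dw


x = [[[1.0, -2.0], [-3.0, 4.0]]]   # N=1, C=2, L=2
dy = [[[0.5, 0.5], [1.0, 1.0]]]
dx_shared, dw_shared = prelu_backward(x, dy, [0.25])        # num_param = 1
dx_chan, dw_chan = prelu_backward(x, dy, [0.25, 0.5])       # num_param = C
```

With a single shared parameter the weight gradient is a reduction over the entire tensor, while the per-channel case reduces each channel separately. On a GPU these are different reduction patterns, which may be one reason the two `num_param` groups in the tables behave differently, though the PR does not state this.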

long10024070, Jul 25 '24 08:07

@junliume Could you help me find out why "Window Build" and "Jenkins - Fp32 Hip Debug gfx90a" are failing?

long10024070, Aug 01 '24 08:08

@junliume Could you help me find out why "Window Build" and "Jenkins - Fp32 Hip Debug gfx90a" are failing?

I think these are the issues causing the CI to fail.

Window Build: [screenshot: windows_build]

Jenkins - Fp32 Hip Debug gfx90a: [screenshot: gfx90a_failure]

sgundabo, Aug 01 '24 21:08

@sgundabo Thank you very much.

long10024070, Aug 02 '24 10:08

@junliume The GitHub Actions checks have passed. @CAHEK7 Could you review this PR?

long10024070, Aug 12 '24 02:08

May I ask which version of Torch and which graphics card were used for this test?

yanjianglu, Jun 04 '25 11:06

May I ask which version of Torch and which graphics card were used for this test?

@yanjianglu I used an MI250 GPU and PyTorch from the rocm/pytorch image. I believe the image tag was rocm6.1.2_ubuntu22.04_py3.10_pytorch_release-2.1.2.

long10024070, Jun 04 '25 18:06