
cpu: aarch64: extend brgemm conv to sve_128

Open · jondea opened this pull request 7 months ago

Naively extends brgemm convolution to sve_128, giving ~2x speedups over gemm:ref in most cases. In practice, this PR only speeds up BWD_D. When oneDNN is built with Arm Compute Library (ACL), FWD_D performance is unchanged because ACL is higher in the implementation list.
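By default oneDNN walks a primitive's implementation list in priority order and uses the first entry that accepts the problem. That is why an ACL entry placed above brgemm conv shadows it for FWD_D. The following Python sketch shows that first-match rule. The `Impl` class, the entry names and the `supports` predicates are hypothetical stand-ins, not oneDNN's real API.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Impl:
    name: str
    # Hypothetical stand-in for an implementation's init() check:
    # takes (direction, isa) and says whether this impl accepts the problem.
    supports: Callable[[str, str], bool]


def select_impl(impls: List[Impl], direction: str, isa: str) -> Optional[str]:
    """Return the first implementation, in list order, that accepts the problem."""
    for impl in impls:
        if impl.supports(direction, isa):
            return impl.name
    return None


# Hypothetical priority list: ACL first, then brgemm conv, then the reference fallback.
IMPLS = [
    Impl("acl", lambda d, isa: d == "FWD_D"),
    Impl("brg_conv", lambda d, isa: isa in ("sve_128", "sve_256", "sve_512")),
    Impl("gemm:ref", lambda d, isa: True),
]

if __name__ == "__main__":
    for direction in ("FWD_D", "BWD_D"):
        print(direction, "->", select_impl(IMPLS, direction, "sve_128"))
```

Removing the `acl` entry, which models a build without ACL, lets FWD_D fall through to `brg_conv`. That is the configuration used for the benchmarks below.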

There are almost certainly further optimizations to be made. In particular, the heuristics need tweaking. For now we skip any brg_blocking_t with oc_block < 16, because those blockings always seem slower even when est_eff suggests otherwise.
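The skip rule works like this: score every candidate blocking, discard any with `oc_block < 16`, and only then pick the best `est_eff`. The Python sketch below illustrates it. Its `Blocking` class only borrows the field names `oc_block` and `est_eff` from the description and is not oneDNN's `brg_blocking_t`. For scale, a 128-bit SVE vector holds four fp32 values, so for fp32 data `oc_block = 16` spans four vector registers.

```python
from dataclasses import dataclass
from typing import Iterable, Optional

MIN_OC_BLOCK = 16  # blockings with a smaller oc_block are skipped outright


@dataclass(frozen=True)
class Blocking:
    oc_block: int   # output channels per block
    est_eff: float  # heuristic efficiency estimate, higher is better


def pick_blocking(candidates: Iterable[Blocking]) -> Optional[Blocking]:
    """Return the highest-est_eff candidate among those with oc_block >= 16."""
    allowed = [b for b in candidates if b.oc_block >= MIN_OC_BLOCK]
    return max(allowed, key=lambda b: b.est_eff, default=None)


if __name__ == "__main__":
    cands = [Blocking(8, 0.95), Blocking(16, 0.80), Blocking(32, 0.85)]
    # The oc_block=8 candidate has the best estimate but is filtered out first.
    print(pick_blocking(cands))
```

Filtering before the `max` keeps the estimate intact for the remaining candidates. The override only removes blockings that measurements showed to be slow.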

The benchmarks below were built without ACL, to show the speedups over gemm:ref for FWD_D. Shapes are taken from the PyTorch ResNet50 training example.

Time ratio brg conv / gemm:ref (<1 means brg conv is faster)
dir   desc                        threads:    1     8     16    32    64
BWD_D mb128ic1024ih14oc256oh14kh1ph0         .37   .40   .47   .50   .50
      mb128ic1024ih14oc512oh14kh1ph0         .37   .41   .48   .52   .51
      mb128ic128ih28oc128oh28kh3ph1          .55   .60   .69   .67   .65
      mb128ic128ih28oc512oh28kh1ph0          .54   .59   .69   .69   .68
      mb128ic2048ih7oc512oh7kh1ph0           .39   .42   .50   .52   .52
      mb128ic256ih14oc1024oh14kh1ph0         .43   .47   .55   .54   .54
      mb128ic256ih14oc256oh14kh3ph1          .45   .49   .58   .58   .58
      mb128ic256ih56oc128oh56kh1ph0          .50   .54   .62   .60   .57
      mb128ic256ih56oc64oh56kh1ph0           .50   .54   .62   .60   .60
      mb128ic512ih28oc128oh28kh1ph0          .59   .64   .73   .70   .68
      mb128ic512ih28oc256oh28kh1ph0          .60   .66   .77   .74   .71
      mb128ic512ih7oc2048oh7kh1ph0           .49   .54   .63   .62   .61
      mb128ic512ih7oc512oh7kh3ph1            .50   .51   .60   .60   .61
      mb128ic64ih56oc256oh56kh1ph0           .35   .38   .44   .45   .47
      mb128ic64ih56oc64oh56kh3ph1            .51   .55   .63   .60   .57
FWD_D mb128ic1024ih14oc2048oh7kh1sh2ph0      .55   .61   .72   .75   .78
      mb128ic1024ih14oc256oh14kh1ph0         .45   .49   .58   .61   .60
      mb128ic1024ih14oc512oh14kh1ph0         .51   .56   .65   .67   .67
      mb128ic128ih28oc128oh28kh3ph1          .19   .21   .24   .27   .29
      mb128ic128ih28oc512oh28kh1ph0          .65   .70   .79   .77   .76
      mb128ic128ih56oc128oh28kh3sh2ph1       .17   .19   .22   .26   .28
      mb128ic2048ih7oc512oh7kh1ph0           .52   .58   .67   .68   .69
      mb128ic256ih14oc1024oh14kh1ph0         .51   .55   .65   .67   .67
      mb128ic256ih14oc256oh14kh3ph1          .41   .45   .55   .59   .60
      mb128ic256ih28oc256oh14kh3sh2ph1       .44   .49   .59   .64   .64
      mb128ic256ih56oc128oh56kh1ph0          .53   .58   .68   .67   .67
      mb128ic256ih56oc512oh28kh1sh2ph0       .24   .26   .31   .35   .36
      mb128ic256ih56oc64oh56kh1ph0           .34   .37   .43   .44   .46
      mb128ic3ih224oc64oh112kh7sh2ph3        .32   .34   .39   .42   .42
      mb128ic512ih14oc512oh7kh3sh2ph1        .55   .61   .73   .76   .78
      mb128ic512ih28oc1024oh14kh1sh2ph0      .53   .58   .68   .70   .70
      mb128ic512ih28oc128oh28kh1ph0          .55   .60   .70   .70   .69
      mb128ic512ih28oc256oh28kh1ph0          .55   .60   .71   .70   .70
      mb128ic512ih7oc2048oh7kh1ph0           .54   .59   .69   .68   .68
      mb128ic512ih7oc512oh7kh3ph1            .50   .47   .55   .60   .62
      mb128ic64ih56oc256oh56kh1ph0           .52   .56   .65   .64   .63
      mb128ic64ih56oc64oh56kh3ph1            .26   .28   .33   .36   .37
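The `desc` strings use benchdnn's convolution problem notation:

- `mb`: minibatch
- `ic`/`oc`: input/output channels
- `ih`/`oh`: input/output height
- `kh`: kernel height
- `sh`: stride, which defaults to 1 when omitted
- `ph`: padding

Each cell is a time ratio, so the speedup of brg conv over gemm:ref is its reciprocal. The Python sketch below is a reading aid, not part of benchdnn. It decodes a descriptor, checks `oh` against the standard output-size formula (assuming symmetric padding and no dilation), and converts a ratio into a speedup.

```python
import re
from typing import Dict

_FIELD = re.compile(r"([a-z]+)(\d+)")


def parse_desc(desc: str) -> Dict[str, int]:
    """Split a descriptor like 'mb128ic64ih56oc64oh56kh3ph1' into {name: value}."""
    fields = {name: int(value) for name, value in _FIELD.findall(desc)}
    fields.setdefault("sh", 1)  # stride is 1 when not given
    return fields


def expected_oh(f: Dict[str, int]) -> int:
    """Output height, assuming symmetric padding and no dilation."""
    return (f["ih"] + 2 * f["ph"] - f["kh"]) // f["sh"] + 1


def speedup(ratio: float) -> float:
    """Convert a brg conv / gemm:ref time ratio into a speedup factor."""
    return 1.0 / ratio


if __name__ == "__main__":
    for desc, ratio in [("mb128ic3ih224oc64oh112kh7sh2ph3", 0.32),
                        ("mb128ic128ih28oc128oh28kh3ph1", 0.19)]:
        f = parse_desc(desc)
        assert expected_oh(f) == f["oh"]
        print(f"{desc}: oh={f['oh']}, speedup {speedup(ratio):.1f}x")
```

Across the table the ratios range from .17 to .79, i.e. speedups of roughly 1.3x to 5.9x. The "~2x" figure therefore describes the typical shape rather than every row.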

Checklist

General

  • [x] Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit? (except known failures)
  • [x] Have you formatted the code using clang-format?

Performance improvements

  • [x] Have you submitted performance data that demonstrates performance improvements?

(jondea, May 29 '25)

Would appreciate a review from @uxlfoundation/onednn-arch, thanks.

(Sqvid, Jun 13 '25)

Just a heads-up: this is now giving incorrect results for u8/s8 brgemm, e.g. ./tests/benchdnn/benchdnn --brgemm --dt=u8:s8:u8 13x192:192x32_n"int8:no_tail:21" (from test_benchdnn_modeC_brgemm_ci_cpu), which gives the expected results with sve_256 but not with sve_128. On the other hand, bf16 brgemm from #3731 seems to work as expected. With some extra minor changes, bf16 brgconv on sve_128 also worked in all the cases I tested.

(michalowski-arm, Aug 19 '25)
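For context on the int8 failure reported above, here is what the test computes. If `13x192:192x32` is read as an MxK by KxN problem, the case is M=13, K=192, N=32 with u8 source, s8 weights and u8 destination. The naive Python reference below covers that data-type combination. It ignores the scales, zero points and post-ops that benchdnn can also apply, accumulates in integers, and saturates to u8 on store. The function names are illustrative only.

```python
import random
from typing import List

Matrix = List[List[int]]


def saturate_u8(x: int) -> int:
    """Clamp an integer accumulator into the u8 range [0, 255]."""
    return max(0, min(255, x))


def ref_gemm_u8s8u8(a: Matrix, b: Matrix) -> Matrix:
    """C = saturate_u8(A @ B) with A: MxK u8 values and B: KxN s8 values."""
    m, k, n = len(a), len(b), len(b[0])
    assert all(len(row) == k for row in a), "inner dimensions must match"
    c = [[0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            # For K=192 the largest |acc| is 192*255*128, well within int32.
            acc = 0
            for p in range(k):
                acc += a[i][p] * b[p][j]
            c[i][j] = saturate_u8(acc)
    return c


if __name__ == "__main__":
    M, K, N = 13, 192, 32  # shape of the failing benchdnn case
    rng = random.Random(0)
    a = [[rng.randint(0, 255) for _ in range(K)] for _ in range(M)]
    b = [[rng.randint(-128, 127) for _ in range(N)] for _ in range(K)]
    c = ref_gemm_u8s8u8(a, b)
    print(len(c), len(c[0]), min(map(min, c)), max(map(max, c)))
```

Saturation on store is why a wrong accumulator often shows up as a result pinned at 0 or 255, rather than as a small numerical drift.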

We've run into some issues with bf16 conv with this change as is (after including the stubs in cpu_conv_list.cpp), following this PR: https://github.com/uxlfoundation/oneDNN/pull/3731. We can raise a follow-up patch on top of this. cc: @michalowski-arm

(aditew01, Aug 21 '25)

Good spot, both. I will sort this out one way or another and get it in.

(jondea, Aug 21 '25)

This is not necessarily a blocker for merging. BF16 conv has to be enabled explicitly in cpu_conv_list.cpp, so we can fix the related issues and push that change separately.

(aditew01, Aug 26 '25)

Thanks @michalowski-arm, int8 brgemm now passes. Thankfully it was a very simple change.

(jondea, Aug 27 '25)

@Sqvid I've run some end-to-end benchmarks. The impact is minimal, though sometimes negative, and we have work in progress to mitigate that. On balance, it is worth getting this in so that we can build on top of it.

(jondea, Aug 27 '25)

@jondea Removed the block. You can merge when you're ready.

(Sqvid, Aug 27 '25)