AITemplate
Add StreamK support for GEMM universal
This PR adds StreamK support for GEMM operations (except group GEMM).
Because the required workspace size is hard to bound for a given range of shapes (and a max-shape calculation doesn't always work), we use a fixed workspace size for GEMM when StreamK is enabled, which should cover most day-to-day use cases.
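To illustrate why the workspace is hard to bound, here is a minimal sketch of a simplified "basic" Stream-K split, assuming a flattened iteration space (output tiles times k-iterations per tile) divided evenly across `avail_sms` CTAs. The tile sizes, the function name `streamk_split`, and the one-fp32-tile-per-partial-segment workspace sizing are all hypothetical; this is not the CUTLASS or AITemplate heuristic, which also mixes in data-parallel tiles.

```python
import math


def streamk_split(m, n, k, avail_sms, tile_m=128, tile_n=128, tile_k=32):
    """Simplified basic Stream-K model (hypothetical, for illustration only).

    The flattened MAC-loop iteration space (tiles x k-iterations) is split
    evenly across `avail_sms` CTAs. Any CTA segment that covers only part of
    a tile produces a partial sum that must be stashed in workspace.
    Returns (num_ctas, partial_segments, workspace_bytes).
    """
    tiles = math.ceil(m / tile_m) * math.ceil(n / tile_n)
    iters_per_tile = math.ceil(k / tile_k)
    total_iters = tiles * iters_per_tile
    num_ctas = max(1, min(avail_sms, total_iters))

    partial_segments = 0
    for cta in range(num_ctas):
        # Contiguous [start, end) range of the flattened iteration space
        start = cta * total_iters // num_ctas
        end = (cta + 1) * total_iters // num_ctas
        pos = start
        while pos < end:
            tile = pos // iters_per_tile
            tile_begin = tile * iters_per_tile
            tile_stop = tile_begin + iters_per_tile
            seg_end = min(tile_stop, end)
            # Segment does not cover the whole tile -> partial accumulator
            if pos != tile_begin or seg_end != tile_stop:
                partial_segments += 1
            pos = seg_end

    # Hypothetical sizing: one fp32 accumulator tile per partial segment
    workspace_bytes = partial_segments * tile_m * tile_n * 4
    return num_ctas, partial_segments, workspace_bytes


if __name__ == "__main__":
    for sms in (1, 3, 4, 7):
        print(sms, streamk_split(256, 256, 1024, sms))
```

Even in this toy model, the workspace for a fixed (256, 256, 1024) problem changes non-monotonically with `avail_sms` (zero partial segments at 1 and 4 CTAs, several at 3), which is consistent with choosing a fixed workspace rather than trying to cap it per shape.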
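It provides modest speedups for several smaller gemm_rrr shapes (average time change of -0.72% across the shapes below), while gemm_rcr_bias is slower on average (+4.87%).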
Single-kernel benchmarks:
Time change (lower is better):
Op: gemm_rrr Shape: (256, 128, 32) Time change: -4.44% (best avail_sms=64)
Op: gemm_rrr Shape: (256, 256, 128) Time change: -4.46% (best avail_sms=16)
Op: gemm_rrr Shape: (256, 256, 256) Time change: -3.63% (best avail_sms=32)
Op: gemm_rrr Shape: (256, 256, 512) Time change: -3.90% (best avail_sms=4)
Op: gemm_rrr Shape: (256, 256, 1024) Time change: -3.44% (best avail_sms=1)
Op: gemm_rrr Shape: (256, 256, 2048) Time change: +3.59% (best avail_sms=32)
Op: gemm_rrr Shape: (256, 256, 4096) Time change: -2.80% (best avail_sms=1)
Op: gemm_rrr Shape: (256, 256, 8192) Time change: -5.43% (best avail_sms=4)
Op: gemm_rrr Shape: (256, 256, 16384) Time change: -6.02% (best avail_sms=16)
Op: gemm_rrr Shape: (512, 256, 64) Time change: -3.27% (best avail_sms=8)
Op: gemm_rrr Shape: (512, 512, 128) Time change: -5.83% (best avail_sms=8)
Op: gemm_rrr Shape: (512, 512, 256) Time change: -0.82% (best avail_sms=-1)
Op: gemm_rrr Shape: (512, 512, 512) Time change: -0.20% (best avail_sms=2)
Op: gemm_rrr Shape: (512, 512, 1024) Time change: +10.01% (best avail_sms=4)
Op: gemm_rrr Shape: (512, 512, 2048) Time change: -3.62% (best avail_sms=16)
Op: gemm_rrr Shape: (512, 512, 4096) Time change: -6.56% (best avail_sms=1)
Op: gemm_rrr Shape: (512, 512, 8192) Time change: +5.86% (best avail_sms=2)
Op: gemm_rrr Shape: (512, 512, 16384) Time change: +3.35% (best avail_sms=1)
Op: gemm_rrr Shape: (1024, 512, 128) Time change: -5.36% (best avail_sms=8)
Op: gemm_rrr Shape: (1024, 1024, 128) Time change: -2.16% (best avail_sms=2)
Op: gemm_rrr Shape: (1024, 1024, 256) Time change: -4.60% (best avail_sms=32)
Op: gemm_rrr Shape: (1024, 1024, 512) Time change: +9.23% (best avail_sms=2)
Op: gemm_rrr Shape: (1024, 1024, 1024) Time change: -2.59% (best avail_sms=16)
Op: gemm_rrr Shape: (1024, 1024, 2048) Time change: -3.12% (best avail_sms=8)
Op: gemm_rrr Shape: (1024, 1024, 4096) Time change: +4.59% (best avail_sms=2)
Op: gemm_rrr Shape: (1024, 1024, 8192) Time change: +8.05% (best avail_sms=2)
Op: gemm_rrr Shape: (1024, 1024, 16384) Time change: +4.47% (best avail_sms=-1)
Op: gemm_rrr Shape: (2048, 1024, 256) Time change: +2.88% (best avail_sms=16)
Op: gemm_rrr Avg Time change: -0.72%
Op: gemm_rcr_bias Shape: (256, 128, 32) Time change: +5.04% (best avail_sms=64)
Op: gemm_rcr_bias Shape: (256, 256, 128) Time change: +6.39% (best avail_sms=1)
Op: gemm_rcr_bias Shape: (256, 256, 256) Time change: +5.60% (best avail_sms=-1)
Op: gemm_rcr_bias Shape: (256, 256, 512) Time change: +5.45% (best avail_sms=16)
Op: gemm_rcr_bias Shape: (256, 256, 1024) Time change: +4.57% (best avail_sms=64)
Op: gemm_rcr_bias Shape: (256, 256, 2048) Time change: +7.79% (best avail_sms=4)
Op: gemm_rcr_bias Shape: (256, 256, 4096) Time change: +8.03% (best avail_sms=16)
Op: gemm_rcr_bias Shape: (256, 256, 8192) Time change: -1.52% (best avail_sms=32)
Op: gemm_rcr_bias Shape: (256, 256, 16384) Time change: +2.85% (best avail_sms=1)
Op: gemm_rcr_bias Shape: (512, 256, 64) Time change: +4.50% (best avail_sms=1)
Op: gemm_rcr_bias Shape: (512, 512, 128) Time change: +2.96% (best avail_sms=16)
Op: gemm_rcr_bias Shape: (512, 512, 256) Time change: +3.07% (best avail_sms=1)
Op: gemm_rcr_bias Shape: (512, 512, 512) Time change: +4.21% (best avail_sms=4)
Op: gemm_rcr_bias Shape: (512, 512, 1024) Time change: +4.93% (best avail_sms=8)
Op: gemm_rcr_bias Shape: (512, 512, 2048) Time change: +4.67% (best avail_sms=8)
Op: gemm_rcr_bias Shape: (512, 512, 4096) Time change: -1.10% (best avail_sms=8)
Op: gemm_rcr_bias Shape: (512, 512, 8192) Time change: +6.28% (best avail_sms=8)
Op: gemm_rcr_bias Shape: (512, 512, 16384) Time change: +8.10% (best avail_sms=1)
Op: gemm_rcr_bias Shape: (1024, 512, 128) Time change: +4.32% (best avail_sms=16)
Op: gemm_rcr_bias Shape: (1024, 1024, 128) Time change: +2.14% (best avail_sms=8)
Op: gemm_rcr_bias Shape: (1024, 1024, 256) Time change: +3.75% (best avail_sms=1)
Op: gemm_rcr_bias Shape: (1024, 1024, 512) Time change: +6.92% (best avail_sms=1)
Op: gemm_rcr_bias Shape: (1024, 1024, 1024) Time change: +5.67% (best avail_sms=64)
Op: gemm_rcr_bias Shape: (1024, 1024, 2048) Time change: +0.30% (best avail_sms=16)
Op: gemm_rcr_bias Shape: (1024, 1024, 4096) Time change: +9.72% (best avail_sms=64)
Op: gemm_rcr_bias Shape: (1024, 1024, 8192) Time change: +10.09% (best avail_sms=2)
Op: gemm_rcr_bias Shape: (1024, 1024, 16384) Time change: +5.73% (best avail_sms=32)
Op: gemm_rcr_bias Shape: (2048, 1024, 256) Time change: +5.93% (best avail_sms=2)
Op: gemm_rcr_bias Avg Time change: +4.87%
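The "Avg Time change" lines appear to be the plain arithmetic mean of the per-shape changes: recomputing them from the rows above gives -0.72% for gemm_rrr and +4.87% for gemm_rcr_bias. A small parser like the sketch below (the `summarize` helper is hypothetical, written only against the log format shown here) can recompute these means and count how many shapes got faster. As background, in CUTLASS's StreamK GEMM kernel `avail_sms=-1` is documented as defaulting to the device's SM count and `avail_sms=1` as classic data-parallel scheduling.

```python
import re
from collections import defaultdict

# Matches per-shape rows such as:
# "Op: gemm_rrr Shape: (256, 128, 32) Time change: -4.44% (best avail_sms=64)"
# The "Avg Time change" summary rows have no Shape and are skipped.
LINE_RE = re.compile(
    r"Op: (\w+) Shape: \((\d+), (\d+), (\d+)\) Time change: ([+-]?\d+\.\d+)% "
    r"\(best avail_sms=(-?\d+)\)"
)


def summarize(log: str):
    """Return {op: (mean_change_pct, num_faster, num_shapes)} for a benchmark log."""
    changes = defaultdict(list)
    for line in log.splitlines():
        match = LINE_RE.search(line)
        if match:
            changes[match.group(1)].append(float(match.group(5)))
    return {
        op: (sum(vals) / len(vals), sum(v < 0 for v in vals), len(vals))
        for op, vals in changes.items()
    }


if __name__ == "__main__":
    sample = (
        "Op: gemm_rrr Shape: (256, 128, 32) Time change: -4.44% (best avail_sms=64)\n"
        "Op: gemm_rrr Shape: (512, 512, 1024) Time change: +10.01% (best avail_sms=4)\n"
    )
    print(summarize(sample))
```

Applied to the gemm_rrr rows above, this reports 19 of 28 shapes faster; for gemm_rcr_bias, only 2 of 28.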