[QUESTION] Why is `Warp`'s tile-based matmul much slower than Torch's?
I have tried Warp 1.6.0 and compared the performance of Warp's tile-based matmul against torch.matmul. Warp appears to be much slower. I am wondering why, and whether it will be possible to close this performance gap in the future.
Times below are totals in milliseconds over 1000 iterations:

TILE_M   TILE_N   TILE_K   BLOCK   Warp Time (ms)   Torch Time (ms)   Warp/Torch
64       64       64       256     981.684936       363.559419       2.700
64       64       64       512     1121.447108      363.559419       3.085
64       64       64       1024    1146.522702      363.559419       3.154
64       64       128      256     1436.224992      363.559419       3.950
64       64       128      512     1050.912843      363.559419       2.891
64       64       128      1024    1039.730605      363.559419       2.860
64       128      64       256     1321.610127      363.559419       3.635
64       128      64       512     1231.751565      363.559419       3.388
64       128      64       1024    1123.240676      363.559419       3.090
Thanks.
Hi @chaoming0625, there are various improvements on the way to close the performance gap between cuBLAS and cuBLASDx. Could you please share complete details about your benchmark so that we can understand this comparison better?
- GPU
- Memory clock, SM clock
- Data type
- Matrix size
- Benchmark script
Thanks. Here is the information:
- GPU: NVIDIA GeForce RTX 3080 Ti Laptop GPU
- Memory Clock: 16 Gbps (effective)
- SM Clock (Shader/Graphics Clock): 1.365 GHz (base) and up to 1.7 GHz (boost), sm_86
- Data type: float32
- Matrix size: 2048 x 2048, 2048 x 2048
- Benchmark script:
```python
from itertools import product

import torch as tc

import warp as wp

tc.backends.cuda.matmul.allow_tf32 = False  # Disable TF32 for matrix multiplications
tc.backends.cudnn.allow_tf32 = False  # Disable TF32 for cuDNN operations

wp.init()
wp.clear_kernel_cache()
wp.set_module_options({"fast_math": True, "enable_backward": False})


def create_mlp_kernel(m, n, k):
    TILE_M = m
    TILE_N = n
    TILE_K = k

    @wp.kernel
    def mlp(x: wp.array2d(dtype=float), weights_wp: wp.array2d(dtype=float), n_k: int, output: wp.array2d(dtype=float)):
        i_m, i_n = wp.tid()
        sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
        for count in range(n_k):
            feat = wp.tile_load(x, shape=(TILE_M, TILE_K), offset=(i_m * TILE_M, count * TILE_K))
            weight = wp.tile_load(weights_wp, shape=(TILE_K, TILE_N), offset=(count * TILE_K, i_n * TILE_N))
            wp.tile_matmul(feat, weight, sum)
        # the output tile is TILE_M x TILE_N, so the column offset uses TILE_N
        wp.tile_store(output, sum, offset=(i_m * TILE_M, i_n * TILE_N))

    return mlp


def benchmark_torch(A, B, warm_up, iterations):
    # warm-up
    for _ in range(warm_up):
        tc.matmul(A, B)

    timers = {}
    tc.cuda.synchronize()
    with wp.ScopedTimer("torch", print=False, dict=timers, synchronize=True):
        for _ in range(iterations):
            tc.matmul(A, B)
        tc.cuda.synchronize()
    return timers["torch"][0]


def benchmark_warp(A, B, config, warm_up, iterations):
    TILE_M, TILE_N, TILE_K, BLOCK_DIM = config
    mlp = create_mlp_kernel(TILE_M, TILE_N, TILE_K)

    M = A.shape[0]
    N = B.shape[1]
    K = A.shape[1]
    output = wp.zeros((M, N), dtype=float)

    # warm-up
    for _ in range(warm_up):
        wp.launch_tiled(
            kernel=mlp, dim=[M // TILE_M, N // TILE_N], inputs=[A, B, K // TILE_K, output], block_dim=BLOCK_DIM
        )

    # # check output
    # if warm_up > 0:
    #     assert np.allclose(output.numpy(), A.numpy() @ B.numpy(), atol=1e-3, rtol=1e-3)

    # benchmark
    timers = {}
    with wp.ScopedTimer("warp", print=False, dict=timers, synchronize=True):
        for _ in range(iterations):
            wp.launch_tiled(
                kernel=mlp, dim=[M // TILE_M, N // TILE_N], inputs=[A, B, K // TILE_K, output], block_dim=BLOCK_DIM
            )
    return timers["warp"][0]


# tile_m = [8, 16, 32, 64]
# tile_n = [8, 16, 32, 64]
# tile_k = [8, 16, 64]
tile_m = [64, 128]
tile_n = [64, 128]
tile_k = [64, 128]
block = [256, 512, 1024]

M = 1024 * 2
N = 1024 * 2
K = 1024 * 2
A = tc.randn(M, K).cuda()
B = tc.randn(K, N).cuda()
print(A.dtype)

iterations = 1000
warm_up = 10

time_torch = benchmark_torch(A, B, warm_up, iterations)
print(f"Torch: {time_torch}")

configs = list(product(tile_m, tile_n, tile_k, block))
wp.config.quiet = True

# header
print(
    "{:<{}} {:<{}} {:<{}} {:<{}} {:<{}} {:<{}} {:<{}}".format(
        "TILE_M", 12, "TILE_N", 12, "TILE_K", 12, "BLOCK", 12, "Warp Time", 12, "Torch Time", 12, "Relative", 12
    )
)
for c in configs:
    time_warp = benchmark_warp(wp.from_torch(A), wp.from_torch(B), c, warm_up, iterations)
    print(
        "{:<{}} {:<{}} {:<{}} {:<{}} {:<{}} {:<{}} {:<{}}".format(
            c[0], 12, c[1], 12, c[2], 12, c[3], 12, time_warp, 12, time_torch, 12, time_warp / time_torch, 12
        )
    )
```
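For reference when reading the tables in this thread: wp.ScopedTimer reports milliseconds, and the script times the whole loop, so each value above is a total over `iterations` launches. A small helper along these lines (not part of the posted script, just a sketch) converts a reading into per-iteration time and achieved TFLOP/s for an M x N x K float32 GEMM:

```python
# Sketch: convert a total ScopedTimer reading (ms over `iterations` launches)
# into per-iteration time and achieved TFLOP/s for an M x N x K GEMM.
def gemm_rate(total_ms, iterations, M, N, K):
    ms_per_iter = total_ms / iterations
    flops = 2.0 * M * N * K                      # one multiply-add counted as 2 FLOPs
    tflops = flops / (ms_per_iter * 1e-3) / 1e12
    return ms_per_iter, tflops

# Example: gemm_rate(981.684936, 1000, 2048, 2048, 2048) for the first Warp row above.
```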
Thanks! Will come back to this thread when we have an update on performance.
- There was a performance regression present at the time you were running the benchmark; it was fixed before 1.6.0 was published (42812b58fa592b2a73e6ea238bdbc4853b9a782b).
- I also made a minor update to the benchmark script in https://github.com/NVIDIA/warp/blob/main/warp/examples/benchmarks/benchmark_gemm.py
I wasn't able to run with the settings you provided on my system (RTX 3090): a block size of 1024 didn't work, and eventually I got an error about not enough shared memory being available:
tile_m = [64, 128]
tile_n = [64, 128]
tile_k = [64, 128]
block = [128, 256, 512]
M = 2 * 1024
N = 2 * 1024
K = 2 * 1024
Results:
M=2048, N=2048, K=2048
Torch: 0.787639±0.053 ms
TILE_M TILE_N TILE_K BLOCK Time (ms) Std dev (ms) Warp/Torch
-------------------------------------------------------------------------------
64 64 64 128 1.34301 0.026 1.7051
64 64 64 256 1.68695 0.0025 2.14178
64 64 64 512 2.51692 0.11 3.19553
64 64 128 128 2.15439 0.10 2.73526
64 64 128 256 1.95791 0.0023 2.48579
64 64 128 512 2.18375 0.069 2.77253
64 128 64 128 2.05197 0.0021 2.60522
64 128 64 256 1.93148 0.0020 2.45224
64 128 64 512 2.00088 0.079 2.54035
Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure create_gemm_kernel__locals__gemm_b16767b7_cuda_kernel_forward kernel for 131072 bytes, but maximum available is 101376
Failed with TILE_M=64, TILE_N=128, TILE_K=128, BLOCK_DIM=128
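For reference, a rough estimate of the tile storage per block reproduces the 131072 bytes in the warning above, assuming the two input tiles and the accumulator tile are all staged in shared memory (an assumption about Warp's tile implementation rather than something the message states):

```python
# Rough per-block shared-memory estimate for the tiled GEMM kernel, assuming the
# A tile, B tile, and accumulator tile are all kept in shared memory (float32).
def tile_shared_bytes(tile_m, tile_n, tile_k, elem_bytes=4):
    a_tile = tile_m * tile_k
    b_tile = tile_k * tile_n
    acc = tile_m * tile_n
    return (a_tile + b_tile + acc) * elem_bytes

print(tile_shared_bytes(64, 128, 128))  # 131072, above the 101376-byte limit reported above
```

Under that assumption, the TILE_M=64, TILE_N=128, TILE_K=128 configuration simply does not fit within the per-block shared-memory limit on this device, regardless of BLOCK_DIM.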
Will update this thread when we get more perf improvements added.
Thank you so much. This is exciting!
Updating this to an enhancement request. We should include benchmarks comparing frameworks in CI/CD.
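One possible shape for such a check, purely as a sketch: a pytest test that reuses the benchmark_torch / benchmark_warp helpers from the script earlier in this thread (assumed importable from a hypothetical benchmark_gemm module) and fails CI when the Warp/Torch ratio regresses past a placeholder threshold:

```python
# Sketch of a CI benchmark gate; the module name, tile configuration, and
# 4x threshold are placeholders, not an agreed design.
import pytest
import torch
import warp as wp

from benchmark_gemm import benchmark_torch, benchmark_warp  # hypothetical module


def test_tile_gemm_vs_torch():
    if not torch.cuda.is_available():
        pytest.skip("CUDA device required")
    wp.init()

    M = N = K = 2048
    A = torch.randn(M, K, device="cuda")
    B = torch.randn(K, N, device="cuda")

    torch_ms = benchmark_torch(A, B, warm_up=10, iterations=100)
    warp_ms = benchmark_warp(wp.from_torch(A), wp.from_torch(B), (64, 64, 64, 256), warm_up=10, iterations=100)

    # Fail if Warp's tile GEMM falls too far behind cuBLAS (placeholder threshold).
    assert warp_ms / torch_ms < 4.0
```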