[QUESTION] Why is `Warp`'s tile-based matmul much slower than torch's?

chaoming0625 opened this issue 11 months ago · 6 comments

I have tried Warp 1.6.0 and compared its tile-based matmul against torch's matmul. Warp's performance appears to be much slower. I am wondering why, and whether this performance discrepancy can be resolved in the future.


TILE_M   TILE_N   TILE_K   BLOCK   Warp Time (ms, 1000 iters)   Torch Time (ms, 1000 iters)   Warp/Torch
64       64       64       256     981.684936                   363.559419                    2.70
64       64       64       512     1121.447108                  363.559419                    3.08
64       64       64       1024    1146.522702                  363.559419                    3.15
64       64       128      256     1436.224992                  363.559419                    3.95
64       64       128      512     1050.912843                  363.559419                    2.89
64       64       128      1024    1039.730605                  363.559419                    2.86
64       128      64       256     1321.610127                  363.559419                    3.64
64       128      64       512     1231.751565                  363.559419                    3.39
64       128      64       1024    1123.240676                  363.559419                    3.09

Thanks.

chaoming0625 · Jan 27 '25

Hi @chaoming0625, there are various improvements on the way to close the performance gap between cuBLAS and cuBLASDx. Could you please share complete details about your benchmark so that we can understand this comparison better? (See the sketch after this list for one way to collect the device details.)

  • GPU
  • Memory clock, SM clock
  • Data type
  • Matrix size
  • Benchmark script
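
A minimal sketch for collecting those device details, assuming nvidia-smi is on the PATH and a CUDA-enabled PyTorch build is installed:

import subprocess
import torch as tc

print(tc.cuda.get_device_name(0))        # GPU model
print(tc.version.cuda, tc.__version__)   # CUDA / PyTorch versions
# maximum memory and SM clocks as reported by the driver
print(subprocess.check_output(
    ["nvidia-smi", "--query-gpu=clocks.max.memory,clocks.max.sm", "--format=csv"]
).decode())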

shi-eric · Jan 27 '25

Thanks. Here is the information:

  • GPU: NVIDIA GeForce RTX 3080 Ti Laptop GPU
  • Memory Clock: 16 Gbps (effective)
  • SM Clock (Shader/Graphics Clock): 1.365 GHz (base) and up to 1.7 GHz (boost), sm_86
  • Data type: float32
  • Matrix size: 2048 x 2048, 2048 x 2048
  • Benchmark script:

from itertools import product

import torch as tc

import warp as wp

tc.backends.cuda.matmul.allow_tf32 = False  # Disable TF32 for matrix multiplications
tc.backends.cudnn.allow_tf32 = False  # Disable TF32 for cuDNN operations

wp.init()
wp.clear_kernel_cache()
wp.set_module_options({"fast_math": True, "enable_backward": False})  # enable fast math, skip backward (adjoint) kernel generation


def create_mlp_kernel(m, n, k):
    TILE_M = m
    TILE_N = n
    TILE_K = k

    @wp.kernel
    def mlp(x: wp.array2d(dtype=float), weights_wp: wp.array2d(dtype=float), n_k: int, output: wp.array2d(dtype=float)):
        i_m, i_n = wp.tid()  # 2D tile index over the output grid
        sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
        # accumulate the output tile over the K dimension, one TILE_K slice at a time
        for count in range(n_k):
            feat = wp.tile_load(x, shape=(TILE_M, TILE_K), offset=(i_m * TILE_M, count * TILE_K))
            weight = wp.tile_load(weights_wp, shape=(TILE_K, TILE_N), offset=(count * TILE_K, i_n * TILE_N))
            wp.tile_matmul(feat, weight, sum)

        # store the finished TILE_M x TILE_N tile (column offset scales by TILE_N)
        wp.tile_store(output, sum, offset=(i_m * TILE_M, i_n * TILE_N))

    return mlp


def benchmark_torch(A, B, warm_up, iterations):
    # warm-up
    for _ in range(warm_up):
        tc.matmul(A, B)

    timers = {}
    tc.cuda.synchronize()

    with wp.ScopedTimer("torch", print=False, dict=timers, synchronize=True):
        for _ in range(iterations):
            tc.matmul(A, B)

        tc.cuda.synchronize()

    return timers["torch"][0]


def benchmark_warp(A, B, config, warm_up, iterations):
    TILE_M = config[0]
    TILE_N = config[1]
    TILE_K = config[2]
    BLOCK_DIM = config[3]

    mlp = create_mlp_kernel(TILE_M, TILE_N, TILE_K)

    M = A.shape[0]
    N = B.shape[1]
    K = A.shape[1]

    output = wp.zeros((M, N), dtype=float)

    # warm-up
    for _ in range(warm_up):
        wp.launch_tiled(
            kernel=mlp, dim=[M // TILE_M, N // TILE_N], inputs=[A, B, K // TILE_K, output], block_dim=BLOCK_DIM
        )

    # # check output
    # if warm_up > 0:
    #     assert np.allclose(output.numpy(), A.numpy() @ B.numpy(), atol=1e-3, rtol=1e-3)

    # benchmark
    timers = {}
    with wp.ScopedTimer("warp", print=False, dict=timers, synchronize=True):
        for _ in range(iterations):
            wp.launch_tiled(
                kernel=mlp, dim=[M // TILE_M, N // TILE_N], inputs=[A, B, K // TILE_K, output], block_dim=BLOCK_DIM
            )

    return timers["warp"][0]


# tile_m = [8, 16, 32, 64]
# tile_n = [8, 16, 32, 64]
# tile_k = [8, 16, 64]

tile_m = [64, 128]
tile_n = [64, 128]
tile_k = [64, 128]
block = [256, 512, 1024]

# tile_m = [ 128]
# tile_n = [ 64]
# tile_k = [64, 128]
# block = [256, 512, ]

M = 1024 * 2
N = 1024 * 2
K = 1024 * 2

A = tc.randn(M, K).cuda()
B = tc.randn(K, N).cuda()

print(A.dtype)

iterations = 1000
warm_up = 10

time_torch = benchmark_torch(A, B, warm_up, iterations)
print(f"Torch: {time_torch}")

configs = list(product(tile_m, tile_n, tile_k, block))

wp.config.quiet = True

# header
print(
    "{:<{}} {:<{}} {:<{}} {:<{}} {:<{}} {:<{}} {:<{}}".format(
        "TILE_M", 12, "TILE_N", 12, "TILE_K", 12, "BLOCK", 12, "Warp Time", 12, "Torch Time", 12, "Relative", 12
    )
)
for c in configs:
    time_warp = benchmark_warp(wp.from_torch(A), wp.from_torch(B), c, warm_up, iterations)
    print(
        "{:<{}} {:<{}} {:<{}} {:<{}} {:<{}} {:<{}} {:<{}}".format(
            c[0], 12, c[1], 12, c[2], 12, c[3], 12, time_warp, 12, time_torch, 12, time_warp / time_torch, 12
        )
    )


chaoming0625 · Jan 28 '25

Thanks! Will come back to this thread when we have an update on performance.

shi-eric · Jan 29 '25

  • There was a performance regression at the time you were running the benchmark; it was fixed before 1.6.0 was published (42812b58fa592b2a73e6ea238bdbc4853b9a782b).
  • I also made a minor update to the benchmark script in https://github.com/NVIDIA/warp/blob/main/warp/examples/benchmarks/benchmark_gemm.py

I wasn't able to run with the settings you provided on my system (RTX 3090): a block size of 1024 didn't work, and eventually I got an error about not enough shared memory being available:

    tile_m = [64, 128]
    tile_n = [64, 128]
    tile_k = [64, 128]
    block = [128, 256, 512]

    M = 2 * 1024
    N = 2 * 1024
    K = 2 * 1024

Results:

M=2048, N=2048, K=2048
Torch: 0.787639±0.053 ms
TILE_M   TILE_N   TILE_K   BLOCK    Time (ms)  Std dev (ms)   Warp/Torch  
-------------------------------------------------------------------------------
64       64       64       128      1.34301    0.026          1.7051      
64       64       64       256      1.68695    0.0025         2.14178     
64       64       64       512      2.51692    0.11           3.19553     
64       64       128      128      2.15439    0.10           2.73526     
64       64       128      256      1.95791    0.0023         2.48579     
64       64       128      512      2.18375    0.069          2.77253     
64       128      64       128      2.05197    0.0021         2.60522     
64       128      64       256      1.93148    0.0020         2.45224     
64       128      64       512      2.00088    0.079          2.54035     
Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure create_gemm_kernel__locals__gemm_b16767b7_cuda_kernel_forward kernel for 131072 bytes, but maximum available is 101376
Failed with TILE_M=64, TILE_N=128, TILE_K=128, BLOCK_DIM=128
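
For context, a back-of-the-envelope estimate of the shared memory this config needs, assuming the A tile, B tile, and the fp32 accumulator are all staged in shared memory (an assumption; Warp may keep some tiles in registers):

TILE_M, TILE_N, TILE_K = 64, 128, 128
bytes_per_elem = 4  # float32
a_tile = TILE_M * TILE_K * bytes_per_elem  # 32768 bytes
b_tile = TILE_K * TILE_N * bytes_per_elem  # 65536 bytes
c_tile = TILE_M * TILE_N * bytes_per_elem  # 32768 bytes
print(a_tile + b_tile + c_tile)  # 131072 bytes, matching the kernel's request in the
                                 # warning above and exceeding the 101376 bytes available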

Will update this thread when we have more perf improvements to share.

shi-eric · Feb 18 '25

Thank you so much. This is exciting!

chaoming0625 · Feb 19 '25

Updating this to an enhancement request. We should include benchmarks comparing frameworks in CI/CD.

daedalus5 · Oct 06 '25
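
For reference, a rough sketch of what such a CI regression check could look like, reusing the benchmark_torch / benchmark_warp helpers from the script earlier in this thread (the test name, config, and threshold below are hypothetical and not part of the Warp repo):

import torch as tc
import warp as wp

# assumes benchmark_torch / benchmark_warp from the script above are in scope
# and wp.init() has already been called
def test_tile_matmul_vs_torch():
    # hypothetical gate: fail CI if Warp's tile matmul is more than 3x slower
    # than torch for a fixed 2048^3 fp32 GEMM (threshold chosen arbitrarily)
    M = N = K = 2048
    A = tc.randn(M, K).cuda()
    B = tc.randn(K, N).cuda()
    torch_ms = benchmark_torch(A, B, warm_up=10, iterations=100)
    warp_ms = benchmark_warp(wp.from_torch(A), wp.from_torch(B), (64, 64, 64, 256), warm_up=10, iterations=100)
    assert warp_ms / torch_ms < 3.0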