FBGEMM
Make CUTLASS rowwise fp8 faster
Summary: Tell CUTLASS to write the output in column-major order (somehow this is faster) and transpose the inputs so that the end result is the same.
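To make the layout trick concrete, here is a small pure-Python sketch. It is not FBGEMM or CUTLASS code, and the helpers `rowwise_scaled_gemm`, `flatten_row_major`, and `flatten_col_major` are hypothetical. It assumes the rowwise kernel computes C = diag(scale_a) · A · Bᵀ · diag(scale_b), with A of shape m×k and B of shape n×k (as the `f8f8bf16_rowwise(a, b, scale_a, scale_b)` call in the benchmark script suggests), and reads "transposing the inputs" as swapping the two operands and their scales. The swapped product is Cᵀ, and storing Cᵀ in column-major order produces exactly the same memory as storing C in row-major order, so callers see no difference.

```python
# Illustrative sketch only: shows that C^T = B @ A^T stored column-major has the
# same memory layout as C = A @ B^T stored row-major, including rowwise scales.

def rowwise_scaled_gemm(x, y, sx, sy):
    """out[i][j] = sx[i] * sy[j] * sum_k x[i][k] * y[j][k]  (x: r x K, y: c x K)."""
    return [
        [sx[i] * sy[j] * sum(xv * yv for xv, yv in zip(x[i], y[j])) for j in range(len(y))]
        for i in range(len(x))
    ]

def flatten_row_major(mat):
    # Row 0 first, then row 1, ...
    return [v for row in mat for v in row]

def flatten_col_major(mat):
    # Column 0 first, then column 1, ...
    rows, cols = len(mat), len(mat[0])
    return [mat[i][j] for j in range(cols) for i in range(rows)]

# Small integer inputs so the comparison is exact.
A = [[1, 2, 3], [4, 5, 6]]                        # m=2, k=3
B = [[7, 8, 9], [1, 0, 2], [3, 1, 4], [2, 2, 2]]  # n=4, k=3
scale_a = [2, 3]
scale_b = [1, 5, 7, 11]

# Original formulation: C (m x n), stored row-major.
C = rowwise_scaled_gemm(A, B, scale_a, scale_b)
row_major_C = flatten_row_major(C)

# Swapped operands and scales: C^T (n x m), stored column-major.
Ct = rowwise_scaled_gemm(B, A, scale_b, scale_a)
col_major_Ct = flatten_col_major(Ct)

print(row_major_C == col_major_Ct)  # prints: True
```

The same reasoning holds for any shapes: element (i, j) of C sits at offset i·n + j in a row-major m×n buffer, and element (j, i) of Cᵀ sits at the same offset in a column-major n×m buffer.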
Here are the benchmark results for the sizes I was interested in (GPU kernel times in microseconds, as reported by the PyTorch profiler over 1000 iterations):
| m | n | k | Kernel | min | max | median | mean ± std |
|---|---|---|---|---|---|---|---|
| 4176 | 12288 | 4096 | cuBLAS | 260.414 | 316.191 | 310.302 | 306.742 ± 11.8098 |
| 4176 | 12288 | 4096 | FBGEMM, before | 359.87 | 379.134 | 373.949 | 372.939 ± 3.81074 |
| 4176 | 12288 | 4096 | FBGEMM, now | 268.318 | 320.51 | 317.47 | 313.336 ± 11.3555 |
| 4176 | 4096 | 12288 | cuBLAS | 269.567 | 334.398 | 330.399 | 325.656 ± 13.7557 |
| 4176 | 4096 | 12288 | FBGEMM, before | 327.551 | 365.47 | 361.535 | 359.815 ± 4.8227 |
| 4176 | 4096 | 12288 | FBGEMM, now | 318.815 | 381.853 | 363.486 | 361.386 ± 11.5182 |
| 3456 | 6144 | 4096 | cuBLAS | 112.512 | 133.856 | 131.935 | 128.974 ± 6.02215 |
| 3456 | 6144 | 4096 | FBGEMM, before | 157.632 | 165.151 | 161.407 | 161.407 ± 1.22362 |
| 3456 | 6144 | 4096 | FBGEMM, now | 113.727 | 135.775 | 132.063 | 129.611 ± 5.33297 |
| 3456 | 4096 | 6144 | cuBLAS | 117.023 | 142.847 | 138.687 | 134.55 ± 8.34265 |
| 3456 | 4096 | 6144 | FBGEMM, before | 154.11 | 163.232 | 160.383 | 159.458 ± 2.28465 |
| 3456 | 4096 | 6144 | FBGEMM, now | 119.904 | 143.198 | 140.447 | 137.979 ± 5.44675 |
| 3456 | 4096 | 4096 | cuBLAS | 78.271 | 91.583 | 89.279 | 86.9895 ± 4.35227 |
| 3456 | 4096 | 4096 | FBGEMM, before | 102.655 | 108.255 | 106.079 | 105.866 ± 1.09653 |
| 3456 | 4096 | 4096 | FBGEMM, now | 78.912 | 94.559 | 91.583 | 88.8953 ± 4.85437 |
| 3456 | 12288 | 4096 | cuBLAS | 218.591 | 262.335 | 255.646 | 252.38 ± 10.1951 |
| 3456 | 12288 | 4096 | FBGEMM, before | 302.783 | 319.998 | 313.662 | 313.112 ± 3.32206 |
| 3456 | 12288 | 4096 | FBGEMM, now | 226.654 | 270.59 | 264.734 | 260.988 ± 8.79772 |
| 3456 | 4096 | 12288 | cuBLAS | 249.406 | 297.022 | 285.151 | 283.93 ± 6.82518 |
| 3456 | 4096 | 12288 | FBGEMM, before | 305.982 | 346.558 | 338.015 | 335.894 ± 7.76891 |
| 3456 | 4096 | 12288 | FBGEMM, now | 246.75 | 287.87 | 282.91 | 280.271 ± 8.47942 |
| 4176 | 6144 | 4096 | cuBLAS | 133.151 | 160.224 | 156.543 | 153.465 ± 7.28038 |
| 4176 | 6144 | 4096 | FBGEMM, before | 187.071 | 194.399 | 191.967 | 191.66 ± 1.28559 |
| 4176 | 6144 | 4096 | FBGEMM, now | 135.295 | 163.742 | 158.719 | 155.969 ± 7.42937 |
| 4176 | 4096 | 6144 | cuBLAS | 138.367 | 171.231 | 167.487 | 164.666 ± 8.29714 |
| 4176 | 4096 | 6144 | FBGEMM, before | 165.407 | 187.135 | 183.968 | 182.002 ± 4.73605 |
| 4176 | 4096 | 6144 | FBGEMM, now | 164.638 | 185.471 | 180.542 | 178.779 ± 4.95891 |
https://pxl.cl/566wB
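One way to read the results is to compare median times per shape. The sketch below is illustrative: the `medians` dictionary is the median column copied from the results above, and the derived quantities (speedup over the old kernel, time relative to cuBLAS, and throughput counting 2·m·n·k FLOPs per GEMM) are our own summary, not numbers reported in this PR.

```python
# Median kernel times in microseconds, copied from the results above:
# (m, n, k) -> (cuBLAS, FBGEMM before, FBGEMM now)
medians = {
    (4176, 12288, 4096): (310.302, 373.949, 317.47),
    (4176, 4096, 12288): (330.399, 361.535, 363.486),
    (3456, 6144, 4096): (131.935, 161.407, 132.063),
    (3456, 4096, 6144): (138.687, 160.383, 140.447),
    (3456, 4096, 4096): (89.279, 106.079, 91.583),
    (3456, 12288, 4096): (255.646, 313.662, 264.734),
    (3456, 4096, 12288): (285.151, 338.015, 282.91),
    (4176, 6144, 4096): (156.543, 191.967, 158.719),
    (4176, 4096, 6144): (167.487, 183.968, 180.542),
}

results = {}
for (m, n, k), (cublas, before, now) in medians.items():
    speedup = before / now       # > 1: the new kernel is faster than the old one
    vs_cublas = now / cublas     # > 1: the new kernel is still slower than cuBLAS
    # One multiply-add = 2 FLOPs; convert microseconds to seconds, then FLOP/s to TFLOP/s.
    tflops = 2 * m * n * k / (now * 1e-6) / 1e12
    results[(m, n, k)] = (speedup, vs_cublas, tflops)
    print(f"{m=} {n=} {k=}: {speedup:.2f}x vs. before, {vs_cublas:.2f}x cuBLAS time, {tflops:.0f} TFLOP/s")
```

For the first shape this works out to roughly a 1.18× median speedup over the previous kernel; for m=4176 n=4096 k=12288 the new median is marginally higher than before.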
This is the code I used to get the above numbers:
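```python
import torch
# Assumption: in open-source fbgemm_gpu builds, this import registers torch.ops.fbgemm.f8f8bf16_rowwise.
import fbgemm_gpu.experimental.gen_ai  # noqa: F401

# (m, n, k) shapes from the results above; the original `mnks` definition was not shown.
mnks = [
    (4176, 12288, 4096), (4176, 4096, 12288), (3456, 6144, 4096),
    (3456, 4096, 6144), (3456, 4096, 4096), (3456, 12288, 4096),
    (3456, 4096, 12288), (4176, 6144, 4096), (4176, 4096, 6144),
]

for m, n, k in mnks:
    print(f"{m=} {n=} {k=}")
    a = torch.randn((m, k), device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn((n, k), device="cuda").to(torch.float8_e4m3fn)
    scale_a = torch.randn((m,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((n,), device="cuda", dtype=torch.float32)

    # Warm-up: cuBLAS via torch._scaled_mm (single per-tensor scale) and the FBGEMM rowwise kernel.
    torch._scaled_mm(a, b.t(), scale_a=scale_a[0], scale_b=scale_b[0], out_dtype=torch.bfloat16, use_fast_accum=True)
    torch.ops.fbgemm.f8f8bf16_rowwise(a, b, scale_a, scale_b, use_fast_accum=True)

    # Profile 1000 back-to-back calls of each kernel on the GPU.
    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        for _ in range(1000):
            torch._scaled_mm(a, b.t(), scale_a=scale_a[0], scale_b=scale_b[0], out_dtype=torch.bfloat16, use_fast_accum=True)
            torch.ops.fbgemm.f8f8bf16_rowwise(a, b, scale_a, scale_b, use_fast_accum=True)

    # Group GPU time by kernel name and print summary statistics.
    stats = {}
    for event in prof.events():
        if event.cuda_time > 0:
            stats.setdefault(event.key, []).append(event.cuda_time)
    for key, times in stats.items():
        times.sort()
        t = torch.tensor(times)
        std, mean = torch.std_mean(t)
        print(f"{key[:100]}: min={times[0]:g}, max={times[-1]:g}, median={times[len(times)//2]:g}, mean={mean:g} +/- {std:g}")
```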
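The summary line in the script takes `times[len(times)//2]` as the median and uses `torch.std_mean` for the spread. The standard-library sketch below (the `summarize` helper is hypothetical, not part of the script) reproduces the same statistics and makes two conventions explicit: with an even number of samples, such as one event per call over 1000 iterations, the reported median is the upper of the two middle values, and the ± figure is the sample standard deviation (n − 1 denominator), which is the default for `torch.std_mean`.

```python
import statistics

def summarize(times):
    """Mirror the benchmark's summary line using only the standard library."""
    times = sorted(times)
    return {
        "min": times[0],
        "max": times[-1],
        # Same indexing as the script: upper median when len(times) is even.
        "median": times[len(times) // 2],
        "mean": statistics.fmean(times),
        # Sample standard deviation (n - 1), matching torch.std_mean's default.
        "std": statistics.stdev(times),
    }

s = summarize([4.0, 1.0, 3.0, 2.0])
print(f"min={s['min']:g}, max={s['max']:g}, median={s['median']:g}, mean={s['mean']:g} +/- {s['std']:g}")
# prints: min=1, max=4, median=3, mean=2.5 +/- 1.29099
```

Note that `statistics.median` would report 2.5 for the same four samples; the script's indexing choice only matters when the two middle values differ noticeably.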
I ran the benchmarks on devgpu002.eag5, which has 700 W, 80 GB H100 GPUs.
Differential Revision: D58821928
Deploy Preview for pytorch-fbgemm-docs ready!
| Name | Link |
|---|---|
| Latest commit | d4dfa64c9267b8c8db3f24c9b6bed779c3cbd68c |
| Latest deploy log | https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/66755ee109af390008bc8500 |
| Deploy Preview | https://deploy-preview-2764--pytorch-fbgemm-docs.netlify.app |
This pull request was exported from Phabricator.