3xTF32 precision issues
(cc @lezcano)
I did some experimenting with 3xTF32 precision in Triton. I took the matmul tutorial code, changed it to use the 3xTF32 calculation, and compared its precision to that of the 3xTF32 implementation in CUTLASS (the matmul tutorial code is also changed to match, to the extent possible, the generated CUTLASS code regarding the size of tiles processed by thread blocks/warps). Below is the script I used for experimenting; it prints the MSE error of the Triton and CUTLASS calculations, measured against a full-F32 calculation in PyTorch (which, as an additional reference check, I verified matches NumPy results produced on the CPU). To run the script, the PyTorch and nvidia-cutlass packages need to be installed, as well as the CUDA SDK.
The script used for testing
import pandas as pd
import torch
import torch.utils.benchmark as benchmark
import triton
import triton.language as tl

import cutlass

dtype = torch.float32
device = "cuda"

loss = torch.nn.MSELoss()


def cutlass_mm(a, b):
    # 3xTF32 GEMM through the CUTLASS Python interface.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    m, n = a.shape[0], b.shape[1]
    d = torch.empty((m, n), dtype=a.dtype, device=a.device)
    plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
    plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32
    alpha = 1
    beta = 0
    plan.run(a, b, d, d, alpha, beta, print_module=False)
    return d


@triton.jit
def triton_mm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator, input_precision="tf32x3")
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(tl.float32)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def triton_mm(a, b):
    # 3xTF32 GEMM through the matmul-tutorial-style Triton kernel above.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 32
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    triton_mm_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=8,
        num_warps=8,
        num_stages=3,
    )
    return c


torch.manual_seed(1234)

# Precision comparison: MSE of each 3xTF32 result vs. a full-F32 PyTorch matmul.
dims = []
triton_3xtf32_loss = []
cutlass_3xtf32_loss = []
for m in range(256, 4096, 128):
    n = k = m
    a = torch.randn((m, k), dtype=dtype, device=device)
    b = torch.randn((k, n), dtype=dtype, device=device)
    # Compute the reference result in full F32 (TF32 temporarily disabled).
    allow_tf32_saved = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    d_ref = torch.mm(a, b)
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32_saved
    d_triton_3xtf32 = triton_mm(a, b)
    d_cutlass_3xtf32 = cutlass_mm(a, b)
    dims.append(m)
    triton_3xtf32_loss.append(loss(d_triton_3xtf32, d_ref).item())
    cutlass_3xtf32_loss.append(loss(d_cutlass_3xtf32, d_ref).item())
df = pd.DataFrame(
    {
        "dims": dims,
        "Triton 3xTF32 loss": triton_3xtf32_loss,
        "CUTLASS 3xTF32 loss": cutlass_3xtf32_loss,
    }
)
print(df)
print()

# Latency comparison.
results = []
label = "Triton 3xTF32 vs. CUTLASS 3xTF32 latency"
for m in range(256, 4096, 128):
    n = k = m
    sub_label = f"m = n = k = {m:5d}"
    a = torch.randn((m, k), dtype=dtype, device=device)
    b = torch.randn((k, n), dtype=dtype, device=device)
    measurement = benchmark.Timer(
        stmt="mm(a, b)",
        globals={
            "mm": triton_mm,
            "a": a,
            "b": b,
        },
        label=label,
        sub_label=sub_label,
        description="Triton 3xTF32",
    ).blocked_autorange()
    results.append(measurement)
    measurement = benchmark.Timer(
        stmt="mm(a, b)",
        globals={
            "mm": cutlass_mm,
            "a": a,
            "b": b,
        },
        label=label,
        sub_label=sub_label,
        description="CUTLASS",
    ).blocked_autorange()
    results.append(measurement)
compare = benchmark.Compare(results)
compare.print()
The script first prints the MSE errors vs. the reference result, and then the latencies (all runs were on an A100):
Test script output for vanilla Triton build
dims Triton 3xTF32 loss CUTLASS 3xTF32 loss
0 256 1.366855e-09 5.235101e-11
1 384 4.742662e-09 8.836381e-11
2 512 1.157405e-08 1.270737e-10
3 640 2.254077e-08 1.644706e-10
4 768 3.873695e-08 2.048905e-10
5 896 6.212847e-08 2.524800e-10
6 1024 9.253924e-08 2.843547e-10
7 1152 1.318329e-07 3.507732e-10
8 1280 1.823635e-07 7.997096e-10
9 1408 2.423697e-07 4.624160e-10
10 1536 3.152084e-07 5.258877e-10
11 1664 3.999571e-07 5.849541e-10
12 1792 5.002328e-07 6.518351e-10
13 1920 6.167757e-07 1.014158e-09
14 2048 7.500014e-07 1.800559e-09
15 2176 8.983116e-07 2.005555e-09
16 2304 1.064476e-06 2.212916e-09
17 2432 1.255128e-06 2.445486e-09
18 2560 1.461378e-06 2.680297e-09
19 2688 1.688605e-06 2.921828e-09
20 2816 1.943802e-06 3.181862e-09
21 2944 2.224484e-06 3.454009e-09
22 3072 2.519756e-06 3.732411e-09
23 3200 2.850649e-06 4.019436e-09
24 3328 3.207230e-06 4.322690e-09
25 3456 3.598114e-06 4.644620e-09
26 3584 4.016068e-06 4.967569e-09
27 3712 4.458372e-06 5.296403e-09
28 3840 4.932218e-06 5.642412e-09
29 3968 5.452913e-06 6.006925e-09
[----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----]
| Triton 3xTF32 | CUTLASS
1 threads: ------------------------------------------
m = n = k = 256 | 525.7 | 1059.2
m = n = k = 384 | 526.6 | 1098.4
m = n = k = 512 | 1047.8 | 1385.1
m = n = k = 640 | 1049.2 | 1606.8
m = n = k = 768 | 1050.2 | 1601.3
m = n = k = 896 | 1565.3 | 1700.0
m = n = k = 1024 | 1571.0 | 1712.2
m = n = k = 1152 | 1572.6 | 1912.5
m = n = k = 1280 | 1573.4 | 1907.2
m = n = k = 1408 | 2092.4 | 2248.6
m = n = k = 1536 | 2094.0 | 2260.1
m = n = k = 1664 | 2095.1 | 2242.8
m = n = k = 1792 | 2612.2 | 2580.8
m = n = k = 1920 | 2615.5 | 2611.6
m = n = k = 2048 | 2617.1 | 2582.8
m = n = k = 2176 | 2618.3 | 2696.4
m = n = k = 2304 | 3136.8 | 2903.1
m = n = k = 2432 | 3139.2 | 2915.2
m = n = k = 2560 | 3144.3 | 2915.3
m = n = k = 2688 | 3649.2 | 3270.4
m = n = k = 2816 | 3660.1 | 3241.2
m = n = k = 2944 | 3661.4 | 3331.5
m = n = k = 3072 | 3664.0 | 3048.8
m = n = k = 3200 | 4180.4 | 3379.3
m = n = k = 3328 | 4182.6 | 3395.0
m = n = k = 3456 | 4184.5 | 3384.3
m = n = k = 3584 | 4690.9 | 3712.1
m = n = k = 3712 | 4707.7 | 3921.7
m = n = k = 3840 | 4706.4 | 3919.7
m = n = k = 3968 | 4708.1 | 3707.0
Times are in microseconds (us).
So the Triton code was producing a much less precise result than CUTLASS. I first tried changing the F32->TF32 rounding in Triton, applied here, to match what CUTLASS does. Here are some alternatives to replace this line with:
round_to_zero rounding
"and.b32 $0, $1, 0xffffe000;",
round_to_nearest rounding, i.e. cvt.rn.tf32.f32 for A100
"{\n"
".reg .b32 mantissa_bit;\n"
".reg .b32 sticky_bit;\n"
".reg .b32 round_bit;\n"
".reg .pred flag;\n"
"and.b32 mantissa_bit, $1, 1 << 13;\n"
"setp.ne.b32 flag, mantissa_bit, 0;\n"
"and.b32 sticky_bit, $1, (1 << 12) - 1;\n"
"setp.ne.or.b32 flag, sticky_bit, 0, flag;\n"
"and.b32 round_bit, $1, 1 << 12;\n"
"setp.ne.and.b32 flag, round_bit, 0, flag;\n"
"mov.b32 $0, $1;\n"
"@flag add.u32 $0, $0, 1 << 13;\n"
"and.b32 $0, $0, ~0x1fff;\n"
"}\n",
However, the precision of the Triton results did not improve. I then verified that individual tl.dot() results match the ones that CUTLASS produces (when the F32->TF32 rounding is matched between the two), and further, by comparing the PTX code generated by Triton and by CUTLASS, I realized that CUTLASS does some of the summing of mma.sync results on regular CUDA cores. So I came up with the following Triton patch to improve the precision:
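As a side note, both ingredients (the F32->TF32 rounding variants above and the big-part/small-part 3xTF32 decomposition) can be emulated on the host with NumPy. The sketch below is only an illustration: the helper names are mine, NaN/Inf special cases are ignored, and the products and sums run in F32 on the CPU, so the absolute errors will not match the GPU numbers in this issue.

import numpy as np

def tf32_round_to_zero(x):
    # Drop the 13 low mantissa bits (same as the single and.b32 variant above).
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

def tf32_round_to_nearest(x):
    # Round to nearest, ties to even (same logic as the longer PTX variant above).
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> np.uint32(13)) & np.uint32(1)        # lowest kept mantissa bit
    round_bit = (bits >> np.uint32(12)) & np.uint32(1)  # highest dropped bit
    sticky = bits & np.uint32(0xFFF)                    # remaining dropped bits
    round_up = (round_bit != 0) & ((lsb != 0) | (sticky != 0))
    bits = np.where(round_up, bits + np.uint32(1 << 13), bits)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal((512, 512), dtype=np.float32)
b = rng.standard_normal((512, 512), dtype=np.float32)

# 3xTF32 decomposition: a big TF32 part plus a TF32-rounded remainder per operand.
a_big, b_big = tf32_round_to_nearest(a), tf32_round_to_nearest(b)
a_small, b_small = tf32_round_to_nearest(a - a_big), tf32_round_to_nearest(b - b_big)

ref = a.astype(np.float64) @ b.astype(np.float64)
tf32_1x = a_big @ b_big                                            # plain TF32
tf32_3x = (a_small @ b_big) + (a_big @ b_small) + (a_big @ b_big)  # 3xTF32, partial products summed in F32

print("1xTF32 MSE:", np.mean((tf32_1x - ref) ** 2))
print("3xTF32 MSE:", np.mean((tf32_3x - ref) ** 2))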
Patch 1
diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
index f701634d..cdc6a5b6 100644
--- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
@@ -45,6 +45,15 @@ public:
ArrayRef<Value>{value})
.getResult()[0];
};
+ auto zero_like = [&](Value c) -> Value {
+ return rewriter.create<SplatOp>(
+ dotOp->getLoc(), c.getType(),
+ rewriter.create<arith::ConstantOp>(dotOp->getLoc(),
+ rewriter.getF32FloatAttr(0)));
+ };
+ auto add = [&](Value a, Value b) -> Value {
+ return rewriter.create<arith::AddFOp>(dotOp.getLoc(), a, b);
+ };
auto sub = [&](Value a, Value b) -> Value {
return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
};
@@ -55,16 +64,22 @@ public:
};
auto aBig = f32ToTF32(dotOp.getA());
- auto aSmall = sub(dotOp.getA(), aBig);
+ auto aSmall = f32ToTF32(sub(dotOp.getA(), aBig));
auto bBig = f32ToTF32(dotOp.getB());
- auto bSmall = sub(dotOp.getB(), bBig);
+ auto bSmall = f32ToTF32(sub(dotOp.getB(), bBig));
+
+ auto zero = zero_like(dotOp.getC());
+
+ auto dot1 = dot(aSmall, bBig, zero);
+ auto dot2 = dot(aBig, bSmall, zero);
+ auto dot3 = dot(aBig, bBig, zero);
- auto dot1 = dot(aSmall, bBig, dotOp.getC());
- auto dot2 = dot(aBig, bSmall, dot1);
- auto dot3 = dot(aBig, bBig, dot2);
+ auto sum1 = add(dot1, dot2);
+ auto sum2 = add(sum1, dot3);
+ auto sum3 = add(sum2, dotOp.getC());
- rewriter.replaceOp(dotOp, dot3);
+ rewriter.replaceOp(dotOp, sum3);
return success();
}
};
Test script output for Triton build with patch 1 above
dims Triton 3xTF32 loss CUTLASS 3xTF32 loss
0 256 9.973545e-12 5.235101e-11
1 384 2.394514e-11 8.836381e-11
2 512 3.826964e-11 1.270737e-10
3 640 5.501000e-11 1.644706e-10
4 768 7.417375e-11 2.048905e-10
5 896 9.785578e-11 2.524800e-10
6 1024 1.070584e-10 2.843547e-10
7 1152 1.521124e-10 3.507732e-10
8 1280 5.776671e-10 7.997096e-10
9 1408 2.180633e-10 4.624160e-10
10 1536 2.567984e-10 5.258877e-10
11 1664 2.960129e-10 5.849541e-10
12 1792 3.404120e-10 6.518351e-10
13 1920 6.804905e-10 1.014158e-09
14 2048 1.444185e-09 1.800559e-09
15 2176 1.624221e-09 2.005555e-09
16 2304 1.812578e-09 2.212916e-09
17 2432 2.018174e-09 2.445486e-09
18 2560 2.231645e-09 2.680297e-09
19 2688 2.451995e-09 2.921828e-09
20 2816 2.689251e-09 3.181862e-09
21 2944 2.937319e-09 3.454009e-09
22 3072 3.194022e-09 3.732411e-09
23 3200 3.460622e-09 4.019436e-09
24 3328 3.738194e-09 4.322690e-09
25 3456 4.038275e-09 4.644620e-09
26 3584 4.337728e-09 4.967569e-09
27 3712 4.644053e-09 5.296403e-09
28 3840 4.968562e-09 5.642412e-09
29 3968 5.309023e-09 6.006925e-09
[----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----]
| Triton 3xTF32 | CUTLASS
1 threads: ------------------------------------------
m = n = k = 256 | 1.3 | 1.1
m = n = k = 384 | 1.3 | 1.1
m = n = k = 512 | 2.5 | 1.4
m = n = k = 640 | 2.5 | 1.4
m = n = k = 768 | 2.6 | 1.4
m = n = k = 896 | 3.8 | 1.7
m = n = k = 1024 | 3.8 | 1.7
m = n = k = 1152 | 3.8 | 1.9
m = n = k = 1280 | 3.9 | 1.9
m = n = k = 1408 | 5.1 | 2.3
m = n = k = 1536 | 5.1 | 2.1
m = n = k = 1664 | 5.1 | 2.0
m = n = k = 1792 | 6.4 | 2.4
m = n = k = 1920 | 6.4 | 2.6
m = n = k = 2048 | 6.4 | 2.6
m = n = k = 2176 | 6.6 | 2.6
m = n = k = 2304 | 7.7 | 2.7
m = n = k = 2432 | 7.7 | 2.9
m = n = k = 2560 | 7.8 | 2.7
m = n = k = 2688 | 9.0 | 3.2
m = n = k = 2816 | 9.0 | 3.3
m = n = k = 2944 | 9.0 | 3.3
m = n = k = 3072 | 9.4 | 3.3
m = n = k = 3200 | 10.3 | 3.4
m = n = k = 3328 | 10.3 | 3.4
m = n = k = 3456 | 10.6 | 3.4
m = n = k = 3584 | 11.6 | 3.7
m = n = k = 3712 | 11.6 | 3.7
m = n = k = 3840 | 11.6 | 3.7
m = n = k = 3968 | 12.1 | 3.7
Times are in milliseconds (ms).
So the precision improved dramatically, but at the cost of performance. Similar alternatives are possible, though; here is one:
Patch 2
diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
index f701634d..9071d429 100644
--- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
@@ -45,6 +45,15 @@ public:
ArrayRef<Value>{value})
.getResult()[0];
};
+ auto zero_like = [&](Value c) -> Value {
+ return rewriter.create<SplatOp>(
+ dotOp->getLoc(), c.getType(),
+ rewriter.create<arith::ConstantOp>(dotOp->getLoc(),
+ rewriter.getF32FloatAttr(0)));
+ };
+ auto add = [&](Value a, Value b) -> Value {
+ return rewriter.create<arith::AddFOp>(dotOp.getLoc(), a, b);
+ };
auto sub = [&](Value a, Value b) -> Value {
return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
};
@@ -60,11 +69,15 @@ public:
auto bBig = f32ToTF32(dotOp.getB());
auto bSmall = sub(dotOp.getB(), bBig);
- auto dot1 = dot(aSmall, bBig, dotOp.getC());
+ auto zero = zero_like(dotOp.getC());
+
+ auto dot1 = dot(aSmall, bBig, zero);
auto dot2 = dot(aBig, bSmall, dot1);
auto dot3 = dot(aBig, bBig, dot2);
- rewriter.replaceOp(dotOp, dot3);
+ auto sum = add(dot3, dotOp.getC());
+
+ rewriter.replaceOp(dotOp, sum);
return success();
}
};
Test script output for Triton build with patch 2 above
dims Triton 3xTF32 loss CUTLASS 3xTF32 loss
0 256 9.949744e-12 5.235101e-11
1 384 2.407365e-11 8.836381e-11
2 512 3.835959e-11 1.270737e-10
3 640 5.498505e-11 1.644706e-10
4 768 7.436918e-11 2.048905e-10
5 896 9.789199e-11 2.524800e-10
6 1024 1.072674e-10 2.843547e-10
7 1152 1.520337e-10 3.507732e-10
8 1280 5.775638e-10 7.997096e-10
9 1408 2.184144e-10 4.624160e-10
10 1536 2.571353e-10 5.258877e-10
11 1664 2.963491e-10 5.849541e-10
12 1792 3.402902e-10 6.518351e-10
13 1920 6.804675e-10 1.014158e-09
14 2048 1.443346e-09 1.800559e-09
15 2176 1.625424e-09 2.005555e-09
16 2304 1.813113e-09 2.212916e-09
17 2432 2.018629e-09 2.445486e-09
18 2560 2.232485e-09 2.680297e-09
19 2688 2.452671e-09 2.921828e-09
20 2816 2.689190e-09 3.181862e-09
21 2944 2.937780e-09 3.454009e-09
22 3072 3.193837e-09 3.732411e-09
23 3200 3.460724e-09 4.019436e-09
24 3328 3.738940e-09 4.322690e-09
25 3456 4.038074e-09 4.644620e-09
26 3584 4.338085e-09 4.967569e-09
27 3712 4.644735e-09 5.296403e-09
28 3840 4.969717e-09 5.642412e-09
29 3968 5.309353e-09 6.006925e-09
[----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----]
| Triton 3xTF32 | CUTLASS
1 threads: ------------------------------------------
m = n = k = 256 | 701.4 | 1058.7
m = n = k = 384 | 704.7 | 1103.7
m = n = k = 512 | 1392.3 | 1394.9
m = n = k = 640 | 1393.9 | 1387.5
m = n = k = 768 | 1395.9 | 1389.7
m = n = k = 896 | 2077.6 | 1739.9
m = n = k = 1024 | 2088.4 | 1730.4
m = n = k = 1152 | 2100.7 | 1737.3
m = n = k = 1280 | 2094.9 | 1759.5
m = n = k = 1408 | 2790.6 | 2258.8
m = n = k = 1536 | 2786.3 | 2332.9
m = n = k = 1664 | 2788.9 | 2251.7
m = n = k = 1792 | 3470.9 | 2618.0
m = n = k = 1920 | 3479.4 | 2596.3
m = n = k = 2048 | 3480.7 | 2407.4
m = n = k = 2176 | 3498.1 | 2541.2
m = n = k = 2304 | 4177.2 | 2941.1
m = n = k = 2432 | 4177.3 | 2765.4
m = n = k = 2560 | 4180.9 | 2932.8
m = n = k = 2688 | 4864.3 | 3100.1
m = n = k = 2816 | 4871.8 | 3039.4
m = n = k = 2944 | 4873.1 | 3240.2
m = n = k = 3072 | 4875.7 | 3060.8
m = n = k = 3200 | 5580.7 | 3638.5
m = n = k = 3328 | 5573.4 | 3442.0
m = n = k = 3456 | 5572.4 | 3583.3
m = n = k = 3584 | 6259.6 | 3902.5
m = n = k = 3712 | 6263.7 | 3909.3
m = n = k = 3840 | 6263.8 | 3721.8
m = n = k = 3968 | 6268.2 | 3941.8
Times are in microseconds (us).
There are further directions for investigation here, including why CUTLASS is consistently faster than Triton for the larger shapes (although this should also be tested with other thread block configurations). But my main reason for creating this issue is to ask: given the recent discussion about regular TF32 (not 3xTF32) precision, as well as previous discussions on the same topic (including this one, which lists other issues where this was asked about), would it be considered worthwhile to add optional arguments to tl.dot that let users explicitly specify rounding? That would cover the F32->TF32 conversion when TF32 precision is used, and somehow also the 3xTF32 case (perhaps the F32->TF32 conversion for both the "big" and "small" parts of the original F32 operands, as CUTLASS does, and also how the three dot results are summed together with the C operand).
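Concretely, I imagine something along these lines (a purely hypothetical sketch; none of these extra keyword arguments exist in tl.dot today, and the names are made up just to illustrate the idea):

acc = tl.dot(
    a, b, acc,
    input_precision="tf32x3",
    tf32_rounding="round_to_nearest",       # hypothetical: F32->TF32 rounding of the "big" parts
    tf32x3_small_rounding="round_to_zero",  # hypothetical: rounding of the "small" remainders
    tf32x3_reduction="f32",                 # hypothetical: how the three dot results and C are summed
)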
Some additional notes: it's trivial to extend the test script above to experiment with Triton's IEEE precision, and I found the results for that case somewhat strange: for small inputs the Triton IEEE precision was worse than CUTLASS 3xTF32, but for large inputs the error was 0, so that may be worth looking into too. It may also be worth introducing a 4xTF32 option - CUTLASS has one, and it's easy to extend the 3xTF32 code in Triton to support it.
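For the IEEE experiment, the only change needed in the test script is the input_precision argument of the tl.dot call in the kernel:

accumulator = tl.dot(a, b, accumulator, input_precision="ieee")

A 4xTF32 variant would then presumably just add the fourth partial product (aSmall times bSmall) to the sum in the F32DotTC lowering.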
Now that I see the order of ops
auto dot1 = dot(aSmall, bBig, dotOp.getC());
auto dot2 = dot(aBig, bSmall, dot1);
auto dot3 = dot(aBig, bBig, dot2);
I realised that this order is not great: we first add the correction terms to C and only then the big matmul. I think it might be possible to have our cake and eat it too by doing aBig * bBig + getC first and then adding the two small contributions.
That would be like this, but it doesn't improve the precision:
Patch 3
diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
index f701634d..f08b5e77 100644
--- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
@@ -60,9 +60,9 @@ public:
auto bBig = f32ToTF32(dotOp.getB());
auto bSmall = sub(dotOp.getB(), bBig);
- auto dot1 = dot(aSmall, bBig, dotOp.getC());
- auto dot2 = dot(aBig, bSmall, dot1);
- auto dot3 = dot(aBig, bBig, dot2);
+ auto dot1 = dot(aBig, bBig, dotOp.getC());
+ auto dot2 = dot(aSmall, bBig, dot1);
+ auto dot3 = dot(aBig, bSmall, dot2);
rewriter.replaceOp(dotOp, dot3);
return success();
Test script output for Triton build with patch 3 above
dims Triton 3xTF32 loss CUTLASS 3xTF32 loss
0 256 1.748709e-09 5.235101e-11
1 384 5.589588e-09 8.836381e-11
2 512 1.309468e-08 1.270737e-10
3 640 2.486918e-08 1.644706e-10
4 768 4.208273e-08 2.048905e-10
5 896 6.670907e-08 2.524800e-10
6 1024 9.850878e-08 2.843547e-10
7 1152 1.393552e-07 3.507732e-10
8 1280 1.916842e-07 7.997096e-10
9 1408 2.536156e-07 4.624160e-10
10 1536 3.285114e-07 5.258877e-10
11 1664 4.155119e-07 5.849541e-10
12 1792 5.182922e-07 6.518351e-10
13 1920 6.375814e-07 1.014158e-09
14 2048 7.736186e-07 1.800559e-09
15 2176 9.249484e-07 2.005555e-09
16 2304 1.094300e-06 2.212916e-09
17 2432 1.288391e-06 2.445486e-09
18 2560 1.498034e-06 2.680297e-09
19 2688 1.728995e-06 2.921828e-09
20 2816 1.988098e-06 3.181862e-09
21 2944 2.273059e-06 3.454009e-09
22 3072 2.572578e-06 3.732411e-09
23 3200 2.907947e-06 4.019436e-09
24 3328 3.269294e-06 4.322690e-09
25 3456 3.665231e-06 4.644620e-09
26 3584 4.088140e-06 4.967569e-09
27 3712 4.535651e-06 5.296403e-09
28 3840 5.014916e-06 5.642412e-09
29 3968 5.541221e-06 6.006925e-09
[----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----]
| Triton 3xTF32 | CUTLASS
1 threads: ------------------------------------------
m = n = k = 256 | 438.2 | 1072.9
m = n = k = 384 | 439.4 | 1118.7
m = n = k = 512 | 873.5 | 1408.8
m = n = k = 640 | 874.1 | 1415.5
m = n = k = 768 | 875.4 | 1411.5
m = n = k = 896 | 1306.0 | 1751.2
m = n = k = 1024 | 1308.7 | 1752.6
m = n = k = 1152 | 1309.4 | 1795.5
m = n = k = 1280 | 1312.6 | 2258.4
m = n = k = 1408 | 1742.7 | 2220.1
m = n = k = 1536 | 1745.9 | 2260.9
m = n = k = 1664 | 1745.2 | 2093.1
m = n = k = 1792 | 2172.8 | 2395.6
m = n = k = 1920 | 2177.0 | 2410.4
m = n = k = 2048 | 2179.3 | 2513.4
m = n = k = 2176 | 2183.9 | 2613.5
m = n = k = 2304 | 2610.6 | 2924.5
m = n = k = 2432 | 2612.5 | 2724.9
m = n = k = 2560 | 2613.9 | 2734.8
m = n = k = 2688 | 3041.2 | 3310.8
m = n = k = 2816 | 3045.8 | 3243.0
m = n = k = 2944 | 3047.4 | 3244.8
m = n = k = 3072 | 3049.1 | 3251.0
m = n = k = 3200 | 3478.0 | 3588.4
m = n = k = 3328 | 3480.3 | 3609.2
m = n = k = 3456 | 3482.3 | 3571.3
m = n = k = 3584 | 3912.0 | 3909.3
m = n = k = 3712 | 3913.6 | 3907.6
m = n = k = 3840 | 3915.8 | 3934.3
m = n = k = 3968 | 3918.0 | 3932.7
Times are in microseconds (us).
Closing the issue, for lack of interest.
Sorry, I completely forgot about this one.
We would certainly accept a fix for this, if the perf hit is not too bad, as this mode is advertised as being a reasonably precise mode.
@alexsamardzic Thanks for opening this issue! Out of these suggestions I really like Patch 2; please could you go ahead and create a PR, ping me, and I'll approve it for you?
Thanks, the PR is here: #4934.
Resolved by #4934.