[QST] Setting different stages gives different accuracy
I ran the TF32 GEMM example. Setting a different number of stages (1 vs. 4) gives different accuracy. Why?
Which TF32 example do you use? TF32 is a new Ampere feature; it is supposed to use the multi-stage mainloop, which requires >= 3 stages.
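For reference, in example 14 the stage count is a compile-time template parameter on cutlass::gemm::device::Gemm. A trimmed sketch (the surrounding type aliases follow the example; treat the exact parameter order and values here as illustrative rather than authoritative):

// Number of pipeline stages in the mainloop. On Ampere, stages >= 3 selects
// the multi-stage (cp.async) mainloop; fewer stages falls back to the
// legacy two-stage pipeline.
constexpr int NumStages = 4;

using Gemm = cutlass::gemm::device::Gemm<
    ElementInputA, LayoutInputA,
    ElementInputB, LayoutInputB,
    ElementOutput, LayoutOutput,
    ElementAccumulator,
    MMAOp, SmArch,
    ShapeMMAThreadBlock, ShapeMMAWarp, ShapeMMAOp,
    EpilogueOp, SwizzleThreadBlock,
    NumStages>;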
I ran example 14_ampere_tf32_tensorop_gemm, but with a custom data initialization method. With stages set to 4, I set K = 128, 256, 512, 1024. The GPU result becomes more accurate compared with the CPU (double) result as K increases. But theoretically, the accuracy should get lower and lower as K increases.
I used CUDA 11.3 on an A100. @hwu36
It is impossible for TF32 to have higher accuracy than FP64: TF32 keeps only a 10-bit mantissa versus FP64's 52 bits. It is likely that your code has some bugs.
You can take a look at https://github.com/NVIDIA/cutlass/tree/master/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm. It compares the accuracy of 1xTF32 (what you are testing), 3xTF32, and FP32 against the FP64 baseline.
I just ran a GEMM (MxK by KxN) in TF32 and compared the relative error against FP64. As K increases, the relative error decreases. That is unreasonable.
I made the change below to example 27 to measure the accuracy of 1xTF32 vs. FP64:
diff --git a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu
index 06559637..fe1e7bb7 100644
--- a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu
+++ b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu
@@ -115,18 +115,20 @@ struct Result {
// Methods
//
static void print_csv_header() {
- std::cout << "M,N,K,Runtime(ms),GFLOPS,3xTF32_vs_FP64,1xTF32_vs_FP64,FP32_vs_FP64" << std::endl;
+// std::cout << "M,N,K,Runtime(ms),GFLOPS,3xTF32_vs_FP64,1xTF32_vs_FP64,FP32_vs_FP64" << std::endl;
+ std::cout << "M,N,K,1xTF32_vs_FP64" << std::endl;
}
void print_csv_row() {
std::cout << m << ","
<< n << ","
<< k << ","
- << runtime_ms << ","
- << gflops << ","
- << l2_norm_3xtf32_vs_fp64 << ","
+// << runtime_ms << ","
+// << gflops << ","
+// << l2_norm_3xtf32_vs_fp64 << ","
<< l2_norm_1xtf32_vs_fp64 << ","
- << l2_norm_fp32_vs_fp64 << std::endl;
+// << l2_norm_fp32_vs_fp64
+ << std::endl;
}
};
@@ -150,7 +152,7 @@ struct Options {
Options():
help(false),
- problem_size({3456, 4096, 4096}),
+ problem_size({1024, 1024, 1024}),
iterations(20),
seed(1),
alpha(1),
@@ -442,6 +444,13 @@ bool run(Options &options) {
// Split K dimension into 1 partitions
int split_k_slices = 1;
+ // Result structure
+ Result result;
+ result.m = problem_size.m();
+ result.n = problem_size.n();
+ result.k = problem_size.k();
+
+#if 0
////////////////////////////////////////////////////////////////////////////////
/// 3. Run 3xTF32 kernel within a profiling loop
////////////////////////////////////////////////////////////////////////////////
@@ -472,8 +481,6 @@ bool run(Options &options) {
status_3xtf32 = gemm_op_3xTF32.initialize(arguments_3xtf32, workspace_3xtf32.get());
CUTLASS_CHECK(status_3xtf32);
- // Result structure
- Result result;
//
// Construct events
@@ -533,9 +540,6 @@ bool run(Options &options) {
}
// Compute average runtime and GFLOPs.
- result.m = problem_size.m();
- result.n = problem_size.n();
- result.k = problem_size.k();
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
@@ -545,7 +549,7 @@ bool run(Options &options) {
}
tensor_d_3xTF32.sync_host();
-
+#endif
////////////////////////////////////////////////////////////////////////////////
/// 4. Run TF32 kernel without profiling loop
////////////////////////////////////////////////////////////////////////////////
@@ -607,7 +611,7 @@ bool run(Options &options) {
////////////////////////////////////////////////////////////////////////////////
// Run reference kernel (F32)
////////////////////////////////////////////////////////////////////////////////
-
+#if 0
// Create instantiation for device reference gemm kernel
Gemm_F32 gemm_f32;
@@ -625,17 +629,17 @@ bool run(Options &options) {
// Copy output data from CUTLASS and reference kernel to host for comparison
tensor_d_F32.sync_host();
-
+#endif
////////////////////////////////////////////////////////////////////////////////
/////// Compute l2 norms
////////////////////////////////////////////////////////////////////////////////
// l2 norm 3xTF32 vs F64
- cutlass::HostTensor<double, LayoutOutput> tensor_d_3xTF32_in_F64(problem_size.mn());
- cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view());
-
- result.l2_norm_3xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
- tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view());
+// cutlass::HostTensor<double, LayoutOutput> tensor_d_3xTF32_in_F64(problem_size.mn());
+// cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view());
+//
+// result.l2_norm_3xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
+// tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view());
// l2 norm 1xTF32 vs F64
cutlass::HostTensor<double, LayoutOutput> tensor_d_1xTF32_in_F64(problem_size.mn());
@@ -645,11 +649,11 @@ bool run(Options &options) {
tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view());
// l2 norm F32 vs F64
- cutlass::HostTensor<double, LayoutOutput> tensor_d_F32_in_F64(problem_size.mn());
- cutlass::reference::host::TensorCopy(tensor_d_F32_in_F64.host_view(), tensor_d_F32.host_view());
-
- result.l2_norm_fp32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
- tensor_d_F32_in_F64.host_view(), tensor_d_F64.host_view());
+// cutlass::HostTensor<double, LayoutOutput> tensor_d_F32_in_F64(problem_size.mn());
+// cutlass::reference::host::TensorCopy(tensor_d_F32_in_F64.host_view(), tensor_d_F32.host_view());
+//
+// result.l2_norm_fp32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
+// tensor_d_F32_in_F64.host_view(), tensor_d_F64.host_view());
results.push_back(result);
@@ -665,9 +669,10 @@ bool run(Options &options) {
std::cout << "Normalized L2 norm of" << std::endl;
std::cout.precision(8);
std::cout << std::scientific
- << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl
+// << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl
<< " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl
- << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl;
+// << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl
+;
return true;
}
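For context, the "Normalized L2 norm" printed below comes from cutlass::reference::host::TensorRelativeErrorMetric. As I read the reference code, it amounts to

$$ \mathrm{err} = \frac{\lVert D_{\text{1xTF32}} - D_{\text{FP64}} \rVert_2}{\lVert D_{\text{FP64}} \rVert_2} $$

i.e., the error norm is divided by the norm of the FP64 reference, so it is a relative metric over the whole output matrix, not a per-element bound.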
And here is the accuracy:
[haichengw@ipp1-0234 build]$ ./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm --benchmark
Gemm problem size: 1024 x 1024 x 4
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.62205812e-04
Gemm problem size: 1024 x 1024 x 8
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.61951569e-04
Gemm problem size: 1024 x 1024 x 16
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.60477998e-04
Gemm problem size: 1024 x 1024 x 32
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.60553667e-04
Gemm problem size: 1024 x 1024 x 64
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.60396326e-04
Gemm problem size: 1024 x 1024 x 128
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.61174139e-04
Gemm problem size: 1024 x 1024 x 256
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.60776329e-04
Gemm problem size: 1024 x 1024 x 512
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.61082619e-04
Gemm problem size: 1024 x 1024 x 1024
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.61014961e-04
Gemm problem size: 1024 x 1024 x 2048
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.61070991e-04
Gemm problem size: 1024 x 1024 x 4096
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.60942982e-04
Gemm problem size: 1024 x 1024 x 8192
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.62191428e-04
Gemm problem size: 1024 x 1024 x 16384
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.65854988e-04
Gemm problem size: 1024 x 1024 x 32768
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 2.80166157e-04
Gemm problem size: 1024 x 1024 x 65536
Runtime: 0.0000 ms
GFLOPs: 0.00
Normalized L2 norm of
- 1xTF32 error with FP64 reference : 3.31786530e-04
CSV results
M,N,K,1xTF32_vs_FP64
1024,1024,4,2.62205812e-04,
1024,1024,8,2.61951569e-04,
1024,1024,16,2.60477998e-04,
1024,1024,32,2.60553667e-04,
1024,1024,64,2.60396326e-04,
1024,1024,128,2.61174139e-04,
1024,1024,256,2.60776329e-04,
1024,1024,512,2.61082619e-04,
1024,1024,1024,2.61014961e-04,
1024,1024,2048,2.61070991e-04,
1024,1024,4096,2.60942982e-04,
1024,1024,8192,2.62191428e-04,
1024,1024,16384,2.65854988e-04,
1024,1024,32768,2.80166157e-04,
1024,1024,65536,3.31786530e-04,
The relative error is quite stable.
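There is also a plausible reason the metric stays flat: with zero-mean random inputs, both the accumulated rounding error and the reference norm grow with K, so their ratio stays roughly constant; only at very large K does FP32 accumulation error become visible (consistent with the slight rise at K >= 16384 above). Here is a minimal host-side sketch that reproduces the shape of the experiment. This is my own illustration, not CUTLASS code; tf32_round is a hypothetical helper that emulates TF32 by mantissa truncation (real tensor cores round to nearest even):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <random>

// Emulate TF32 by truncating an FP32 mantissa from 23 to 10 bits.
float tf32_round(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;  // keep sign, 8-bit exponent, top 10 mantissa bits
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

int main() {
  std::mt19937 rng(1);
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  int const kNumOutputs = 1024;  // dot products per K, like one output column
  for (int k = 4; k <= 65536; k *= 4) {
    double err2 = 0.0, ref2 = 0.0;
    for (int n = 0; n < kNumOutputs; ++n) {
      double acc_ref = 0.0;    // FP64 reference accumulator
      float  acc_tf32 = 0.0f;  // TF32 inputs, FP32 accumulation (as on tensor cores)
      for (int i = 0; i < k; ++i) {
        float a = dist(rng), b = dist(rng);
        acc_ref  += double(a) * double(b);
        acc_tf32 += tf32_round(a) * tf32_round(b);
      }
      double diff = double(acc_tf32) - acc_ref;
      err2 += diff * diff;
      ref2 += acc_ref * acc_ref;
    }
    std::printf("K=%6d  normalized L2 error = %.6e\n", k, std::sqrt(err2 / ref2));
  }
  return 0;
}

Under these assumptions the printed error should stay on the same order of magnitude across K until FP32 accumulation error starts to dominate.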
@yuxgis
I just ran a GEMM (MxK by KxN) in TF32 and compared the relative error against FP64. As K increases, the relative error decreases. That is unreasonable.
Are you sure you are not running 3xTF32? It has the characteristic of improving accuracy as K increases.
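For concreteness, the 3xTF32 trick from example 27, as I paraphrase it (reusing the hypothetical tf32_round helper from the sketch above): each FP32 operand is split into a "big" TF32 part plus a "small" TF32 correction, and three TF32 products are summed; only the small*small term is dropped.

float a_big   = tf32_round(a);
float a_small = tf32_round(a - a_big);  // correction term
float b_big   = tf32_round(b);
float b_small = tf32_round(b - b_big);
// Three tensor-core-precision multiplies approximate the FP32 product;
// the dropped a_small * b_small term is on the order of 2^-22 of the
// product's magnitude.
float prod = a_small * b_big + a_big * b_small + a_big * b_big;

In a 1xTF32 run only a_big * b_big is computed, so if you accidentally instantiated the 3xTF32 kernel you would see noticeably better accuracy.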
This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
@yuxgis did you figure out your issue?