[BUG] [CUDA] Blas tests failing on B200
python python/tests/test_blas.py -v
This produces a bunch of failures:
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) ...
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 2, 1), shape_b=(1, 1, 1), transpose='nn') ... FAIL
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 2, 1), shape_b=(1, 1, 1), transpose='nt') ... FAIL
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 2), shape_b=(1, 1, 1), transpose='tn') ... FAIL
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 2), shape_b=(1, 1, 1), transpose='tt') ... FAIL
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 1), shape_b=(1, 1, 2), transpose='nn') ... FAIL
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 1), shape_b=(1, 2, 1), transpose='nt') ... FAIL
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 1), shape_b=(1, 1, 2), transpose='tn') ... FAIL
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 1), shape_b=(1, 2, 1), transpose='tt') ... FAIL
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(3, 23, 3), shape_b=(3, 3, 457), transpose='nn') ... FAIL
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(3, 23, 3), shape_b=(3, 457, 3), transpose='nt') ... FAIL
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(3, 3, 23), shape_b=(3, 3, 457), transpose='tn') ... FAIL
test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(3, 3, 23), shape_b=(3, 457, 3), transpose='tt') ... FAIL
test_matmul_unaligned (__main__.TestBlas.test_matmul_unaligned) ...
test_matmul_unaligned (__main__.TestBlas.test_matmul_unaligned) (dtype='float16', shape_a=(129, 129), shape_b=(129, 129)) ... FAIL
test_matmul_unaligned (__main__.TestBlas.test_matmul_unaligned) (dtype='float16', shape_a=(130, 130), shape_b=(130, 130)) ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) ...
test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(1, 1), shape_vec=(1, 1), mat_first=False, dtype='float32') ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(1, 1), shape_vec=(1, 1), mat_first=True, dtype='float32') ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(1, 2), shape_vec=(1, 1), mat_first=False, dtype='float32') ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(2, 1), shape_vec=(1, 1), mat_first=True, dtype='float32') ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(1, 3), shape_vec=(1, 1), mat_first=False, dtype='float32') ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(3, 1), shape_vec=(1, 1), mat_first=True, dtype='float32') ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(2, 1), shape_vec=(1, 2), mat_first=False, dtype='float32') ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(1, 2), shape_vec=(2, 1), mat_first=True, dtype='float32') ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(2, 2), shape_vec=(1, 2), mat_first=False, dtype='float32') ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(2, 2), shape_vec=(2, 1), mat_first=True, dtype='float32') ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(2, 3), shape_vec=(1, 2), mat_first=False, dtype='float32') ... FAIL
This seems to be related to the graph exec update. If you disable CUDA graphs (MLX_USE_CUDA_GRAPHS=0), the tests pass. They also pass if a fresh graphExec is instantiated for each graph instead of updating the existing one, as sketched below.
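For reference, a minimal sketch of that second workaround (my illustration, not MLX's actual code), reusing the CHECK_CUDA macro and the graph/graphExec variables from the repro program below. The else-branch that calls cudaGraphExecUpdate would instead destroy the stale exec and re-instantiate from the newly captured graph:

// Hypothetical workaround sketch: skip cudaGraphExecUpdate entirely and
// rebuild the exec from each freshly captured graph.
if (graphExec != nullptr) {
  CHECK_CUDA(cudaGraphExecDestroy(graphExec));
  graphExec = nullptr;
}
CHECK_CUDA(cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0));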
Below is a standalone C++ program that reproduces the error. On a B200, instead of the expected output, I get:
y[0..3] = {3.7793, 3.79102, 3.80078, 3.81055}
Graph exec update success!
y[0..3] = {0, 0, 0, 0}
// Build:
//
// nvcc -O2 -std=c++17 gemv.cpp -lcublasLt -lcublas
//
// Run:
// ./a.out
//
// Should print:
// y[0..3] = {3.7793, 3.79102, 3.80078, 3.81055}
// y[0..3] = {3.90039, 3.90039, 3.90039, 3.90039}
#include <cstdio>
#include <cstdlib>
#include <vector>

#include <cublasLt.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#define CHECK_CUDA(x) do { cudaError_t e=(x); if(e!=cudaSuccess){ \
    fprintf(stderr,"CUDA error %s:%d: %s\n",__FILE__,__LINE__,cudaGetErrorString(e)); std::exit(1);} } while(0)

#define CHECK_CUBLAS(x) do { cublasStatus_t s=(x); if(s!=CUBLAS_STATUS_SUCCESS){ \
    fprintf(stderr,"cuBLAS error %s:%d: %d\n",__FILE__,__LINE__,int(s)); std::exit(1);} } while(0)
void gemv(int M, int K) {
  const float alpha = 1.0f, beta = 0.0f;

  // Host init (row-major A, contiguous x, y)
  std::vector<__half> hA(M*K), hx(K), hy(M, 0.f);
  for (int i = 0; i < M*K; ++i) hA[i] = __float2half((i % 13) * 0.01f); // simple data
  for (int i = 0; i < K; ++i) hx[i] = __float2half(1.0f);

  __half *dA=nullptr, *dx=nullptr, *dy=nullptr;
  CHECK_CUDA(cudaMalloc(&dA, M*K*sizeof(__half)));
  CHECK_CUDA(cudaMalloc(&dx, K*sizeof(__half)));
  CHECK_CUDA(cudaMalloc(&dy, M*sizeof(__half)));
  CHECK_CUDA(cudaMemcpy(dA, hA.data(), M*K*sizeof(__half), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(dx, hx.data(), K*sizeof(__half), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemset(dy, 0, M*sizeof(__half)));
  // cuBLASLt setup
  cublasLtHandle_t lt;
  CHECK_CUBLAS(cublasLtCreate(&lt));

  cublasLtMatmulDesc_t opDesc;
  CHECK_CUBLAS(cublasLtMatmulDescCreate(&opDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));

  // No transposes: A[MxK] * B[KxN]
  cublasOperation_t op = CUBLAS_OP_N;
  CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(opDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op, sizeof(op)));
  CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(opDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op, sizeof(op)));

  // Matrix layouts (row-major for all)
  cublasLtMatrixLayout_t Adesc, Bdesc, Cdesc, Ddesc;
  CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_16F, /*rows=*/M, /*cols=*/K, /*ld=*/K));
  CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_16F, /*rows=*/K, /*cols=*/1, /*ld=*/1));
  CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_16F, /*rows=*/M, /*cols=*/1, /*ld=*/1));
  CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, CUDA_R_16F, /*rows=*/M, /*cols=*/1, /*ld=*/1));

  cublasLtOrder_t row = CUBLASLT_ORDER_ROW;
  CHECK_CUBLAS(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row, sizeof(row)));
  CHECK_CUBLAS(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row, sizeof(row)));
  CHECK_CUBLAS(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row, sizeof(row)));
  CHECK_CUBLAS(cublasLtMatrixLayoutSetAttribute(Ddesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row, sizeof(row)));

  // Heuristic algo (no workspace for brevity)
  cublasLtMatmulPreference_t pref;
  CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&pref));
  size_t wsSize = 0;
  CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &wsSize, sizeof(wsSize)));

  cublasLtMatmulHeuristicResult_t heur;
  int returned = 0;
  CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(lt, opDesc, Adesc, Bdesc, Cdesc, Ddesc, pref, 1, &heur, &returned));
  if (returned == 0) { fprintf(stderr, "No cuBLASLt heuristic found.\n"); return; }
  // Stream & CUDA Graph capture
  cudaStream_t stream;
  CHECK_CUDA(cudaStreamCreate(&stream));
  CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));

  // Matmul: y = A * x
  CHECK_CUBLAS(cublasLtMatmul(
      lt, opDesc,
      &alpha,
      dA, Adesc,
      dx, Bdesc,
      &beta,
      dy, Cdesc,
      dy, Ddesc,
      &heur.algo,
      /*workspace=*/nullptr, /*workspaceSize=*/0,
      stream));

  cudaGraph_t graph;
  CHECK_CUDA(cudaStreamEndCapture(stream, &graph));
  // Instantiate an exec on the first call; update it in place on later calls.
  // The second gemv() call exercises the cudaGraphExecUpdate path that
  // misbehaves on B200.
  static cudaGraphExec_t graphExec = nullptr;
  if (graphExec == nullptr) {
    cudaGraphDebugDotPrint(graph, "graph64.dot", 0);
    CHECK_CUDA(cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0));
  } else {
    cudaGraphDebugDotPrint(graph, "graph65.dot", 0);
    cudaGraphExecUpdateResultInfo info;
    cudaGraphExecUpdate(graphExec, graph, &info);
    if (info.result != cudaGraphExecUpdateSuccess) {
      fprintf(stderr, "Graph exec update failed\n");
      return;
    } else {
      printf("Graph exec update success!\n");
    }
  }
  // Launch captured graph (can be replayed many times)
  CHECK_CUDA(cudaGraphLaunch(graphExec, stream));
  CHECK_CUDA(cudaStreamSynchronize(stream));

  // Fetch result
  CHECK_CUDA(cudaMemcpy(hy.data(), dy, M*sizeof(__half), cudaMemcpyDeviceToHost));
  auto h2f = [](__half h) { return __half2float(h); };
  printf("y[0..3] = {%g, %g, %g, %g}\n", h2f(hy[0]), h2f(hy[1]), h2f(hy[2]), h2f(hy[3]));
}
int main() {
  gemv(64, 64);
  gemv(65, 64);
  return 0;
}
Closed in #2813