mlx icon indicating copy to clipboard operation
mlx copied to clipboard

[BUG] [CUDA] Blas tests failing on B200

Open awni opened this issue 1 month ago • 2 comments

python python/tests/test_blas.py -v

A bunch of failures:

test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) ... 
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 2, 1), shape_b=(1, 1, 1), transpose='nn') ... FAIL
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 2, 1), shape_b=(1, 1, 1), transpose='nt') ... FAIL
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 2), shape_b=(1, 1, 1), transpose='tn') ... FAIL
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 2), shape_b=(1, 1, 1), transpose='tt') ... FAIL
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 1), shape_b=(1, 1, 2), transpose='nn') ... FAIL
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 1), shape_b=(1, 2, 1), transpose='nt') ... FAIL
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 1), shape_b=(1, 1, 2), transpose='tn') ... FAIL
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(1, 1, 1), shape_b=(1, 2, 1), transpose='tt') ... FAIL
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(3, 23, 3), shape_b=(3, 3, 457), transpose='nn') ... FAIL
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(3, 23, 3), shape_b=(3, 457, 3), transpose='nt') ... FAIL
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(3, 3, 23), shape_b=(3, 3, 457), transpose='tn') ... FAIL
  test_matmul_shapes (__main__.TestBlas.test_matmul_shapes) (dtype='float32', shape_a=(3, 3, 23), shape_b=(3, 457, 3), transpose='tt') ... FAIL
test_matmul_unaligned (__main__.TestBlas.test_matmul_unaligned) ... 
  test_matmul_unaligned (__main__.TestBlas.test_matmul_unaligned) (dtype='float16', shape_a=(129, 129), shape_b=(129, 129)) ... FAIL
  test_matmul_unaligned (__main__.TestBlas.test_matmul_unaligned) (dtype='float16', shape_a=(130, 130), shape_b=(130, 130)) ... FAIL
test_matrix_vector (__main__.TestBlas.test_matrix_vector) ... 
  test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(1, 1), shape_vec=(1, 1), mat_first=False, dtype='float32') ... FAIL
  test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(1, 1), shape_vec=(1, 1), mat_first=True, dtype='float32') ... FAIL
  test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(1, 2), shape_vec=(1, 1), mat_first=False, dtype='float32') ... FAIL
  test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(2, 1), shape_vec=(1, 1), mat_first=True, dtype='float32') ... FAIL
  test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(1, 3), shape_vec=(1, 1), mat_first=False, dtype='float32') ... FAIL
  test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(3, 1), shape_vec=(1, 1), mat_first=True, dtype='float32') ... FAIL
  test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(2, 1), shape_vec=(1, 2), mat_first=False, dtype='float32') ... FAIL
  test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(1, 2), shape_vec=(2, 1), mat_first=True, dtype='float32') ... FAIL
  test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(2, 2), shape_vec=(1, 2), mat_first=False, dtype='float32') ... FAIL
  test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(2, 2), shape_vec=(2, 1), mat_first=True, dtype='float32') ... FAIL
  test_matrix_vector (__main__.TestBlas.test_matrix_vector) (shape_mat=(2, 3), shape_vec=(1, 2), mat_first=False, dtype='float32') ... FAIL

awni avatar Nov 10 '25 23:11 awni

This seems to be related to the graph exec update (`cudaGraphExecUpdate`). The tests pass when CUDA graphs are disabled with `MLX_USE_CUDA_GRAPHS=0`. They also pass if a fresh graphExec is instantiated for each graph instead of updating the existing one.

awni avatar Nov 11 '25 00:11 awni

Here is a standalone C++ program that reproduces the error. On B200, instead of the expected output, I get:

y[0..3] = {3.7793, 3.79102, 3.80078, 3.81055}
Graph exec update success!
y[0..3] = {0, 0, 0, 0}
// Build:
//
// nvcc -O2 -std=c++17 gemv.cpp -lcublasLt -lcublas
//
// Run:
// ./a.out
//
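// Should print (K is 64 in both calls, so rows 0..3 of A hold identical data
// and the second result should match the first up to fp16 rounding):
// y[0..3] = {3.7793, 3.79102, 3.80078, 3.81055}
// Graph exec update success!
// y[0..3] = {3.7793, 3.79102, 3.80078, 3.81055}
// (An all-3.90039 second line presumably corresponds to gemv(65, 65), where
// every row sums to ~3.9.)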

#include <cstdio>
#include <cstdlib>
#include <vector>

#include <cublasLt.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Evaluate a CUDA runtime call once; on any status other than cudaSuccess,
// print file/line plus the runtime's error string and terminate with code 1.
#define CHECK_CUDA(x)                                                     \
  do {                                                                    \
    const cudaError_t check_cuda_rc_ = (x);                               \
    if (check_cuda_rc_ != cudaSuccess) {                                  \
      fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__,       \
              cudaGetErrorString(check_cuda_rc_));                        \
      std::exit(1);                                                       \
    }                                                                     \
  } while (0)

// Same contract for cuBLAS / cuBLASLt calls; the numeric status code is
// printed because there is no error-string lookup used here.
#define CHECK_CUBLAS(x)                                                   \
  do {                                                                    \
    const cublasStatus_t check_cublas_rc_ = (x);                          \
    if (check_cublas_rc_ != CUBLAS_STATUS_SUCCESS) {                      \
      fprintf(stderr, "cuBLAS error %s:%d: %d\n", __FILE__, __LINE__,     \
              int(check_cublas_rc_));                                     \
      std::exit(1);                                                       \
    }                                                                     \
  } while (0)

void gemv(int M, int K) {
  const float alpha = 1.0f, beta = 0.0f;

  // Host init (row-major A, contiguous x, y)
  std::vector<__half> hA(M*K), hx(K), hy(M, 0.f);
  for (int i = 0; i < M*K; ++i) hA[i] = __float2half((i % 13) * 0.01f);  // simple data
  for (int i = 0; i < K;   ++i) hx[i] = __float2half(1.0f);

  __half *dA=nullptr, *dx=nullptr, *dy=nullptr;
  CHECK_CUDA(cudaMalloc(&dA, M*K*sizeof(__half)));
  CHECK_CUDA(cudaMalloc(&dx, K*sizeof(__half)));
  CHECK_CUDA(cudaMalloc(&dy, M*sizeof(__half)));
  CHECK_CUDA(cudaMemcpy(dA, hA.data(), M*K*sizeof(__half), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(dx, hx.data(), K*sizeof(__half),   cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemset(dy, 0, M*sizeof(__half)));

  // cuBLASLt setup
  cublasLtHandle_t lt;
  CHECK_CUBLAS(cublasLtCreate(&lt));

  cublasLtMatmulDesc_t opDesc;
  CHECK_CUBLAS(cublasLtMatmulDescCreate(&opDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
  // No transposes: A[MxK] * B[KxN]
  cublasOperation_t op = CUBLAS_OP_N;
  CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(opDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op, sizeof(op)));
  CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(opDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op, sizeof(op)));

  // Matrix layouts (row-major for all)
  cublasLtMatrixLayout_t Adesc, Bdesc, Cdesc, Ddesc;
  CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_16F, /*rows=*/M, /*cols=*/K, /*ld=*/K));
  CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_16F, /*rows=*/K, /*cols=*/1, /*ld=*/1));
  CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_16F, /*rows=*/M, /*cols=*/1, /*ld=*/1));
  CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, CUDA_R_16F, /*rows=*/M, /*cols=*/1, /*ld=*/1));

  cublasLtOrder_t row = CUBLASLT_ORDER_ROW;
  CHECK_CUBLAS(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row, sizeof(row)));
  CHECK_CUBLAS(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row, sizeof(row)));
  CHECK_CUBLAS(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row, sizeof(row)));
  CHECK_CUBLAS(cublasLtMatrixLayoutSetAttribute(Ddesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row, sizeof(row)));

  // Heuristic algo (no workspace for brevity)
  cublasLtMatmulPreference_t pref;
  CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&pref));
  size_t wsSize = 0;
  CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &wsSize, sizeof(wsSize)));
  cublasLtMatmulHeuristicResult_t heur;
  int returned = 0;
  CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(lt, opDesc, Adesc, Bdesc, Cdesc, Ddesc, pref, 1, &heur, &returned));
  if (returned == 0) { fprintf(stderr, "No cuBLASLt heuristic found.\n"); return; }

  // Stream & CUDA Graph capture
  cudaStream_t stream;
  CHECK_CUDA(cudaStreamCreate(&stream));
  CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));

  // Matmul: y = A * x
  CHECK_CUBLAS(cublasLtMatmul(
      lt, opDesc,
      &alpha,
      dA, Adesc,
      dx, Bdesc,
      &beta,
      dy, Cdesc,
      dy, Ddesc,
      &heur.algo,
      /*workspace=*/nullptr, /*workspaceSize=*/0,
      stream));

  cudaGraph_t graph;
  CHECK_CUDA(cudaStreamEndCapture(stream, &graph));

  static cudaGraphExec_t graphExec = nullptr;
  if (graphExec == nullptr) {
    cudaGraphDebugDotPrint(graph, "graph64.dot", 0);
    CHECK_CUDA(cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0));
  } else {
    cudaGraphDebugDotPrint(graph, "graph65.dot", 0);
    cudaGraphExecUpdateResult update_result;
    cudaGraphExecUpdateResultInfo info;
    cudaGraphExecUpdate(graphExec, graph, &info);
    update_result = info.result;
    if (update_result != cudaGraphExecUpdateSuccess) {
      fprintf(stderr, "Graph exec update failed\n");
      return;
    } else {
      printf("Graph exec update success!\n");
    }
  }

  // Launch captured graph (can be replayed many times)
  CHECK_CUDA(cudaGraphLaunch(graphExec, stream));
  CHECK_CUDA(cudaStreamSynchronize(stream));

  // Fetch result
  CHECK_CUDA(cudaMemcpy(hy.data(), dy, M*sizeof(__half), cudaMemcpyDeviceToHost));
  auto h2f=[](__half h){ return __half2float(h); };
  printf("y[0..3] = {%g, %g, %g, %g}\n", h2f(hy[0]), h2f(hy[1]), h2f(hy[2]), h2f(hy[3]));
}

// Run the repro twice. The first shape instantiates the CUDA graph. The second
// shape has one extra row and goes through cudaGraphExecUpdate on the same
// executable graph, which is the path reported to return zeros on B200.
int main() {
  const int shapes[][2] = {{64, 64}, {65, 64}};
  for (const auto& mk : shapes) {
    gemv(mk[0], mk[1]);
  }
  return 0;
}

awni avatar Nov 11 '25 14:11 awni

Closed in #2813

awni avatar Nov 22 '25 23:11 awni