
[JAX] Collective GEMM custom op with `nvte_cublas_gemm` (no comm. overlap)

denera opened this pull request 1 year ago • 2 comments

Description

Implements both old-style and new FFI-based XLA custom calls in C++, and the corresponding JAX primitive including custom partitioning rules.

Custom partitioning rules for a LHS:([B,] M, K) x RHS:([B,] K, N) = OUT:([B,] M, N) batched matmul operation, where [B] is the batch dimension (see the sketch after this list):

  • Preserve the partitioning of the [B] dimension for all operands.
  • Always all-gather LHS along the M dimension.
  • Error out if RHS is partitioned in both K and N dimensions.
  • Force the K dimension of LHS to match the partitioning of the K dimension of RHS.
  • If the K dimension is partitioned but the M dimension is not, jax.lax.psum (all-reduce) the output over the TP mesh resource.
  • If both the M and K dimensions are partitioned, jax.lax.psum_scatter (reduce-scatter) the output over the TP mesh resource.
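
As a rough sketch, the last two reduction rules correspond to the following collectives. This is only an illustration of the semantics, not the actual partitioning code in this PR; the `tp` axis name, mesh, and shapes are stand-ins for the real TP mesh resource:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Illustrative 1-D tensor-parallel mesh; "tp" stands in for TE's TP mesh resource.
mesh = Mesh(jax.devices(), axis_names=("tp",))

def _gemm_allreduce(lhs, rhs):
    # K sharded, M replicated: local partial GEMM, then all-reduce over the TP axis.
    return jax.lax.psum(jnp.einsum("mk,kn->mn", lhs, rhs), axis_name="tp")

def _gemm_reduce_scatter(lhs, rhs):
    # Both M and K sharded: local partial GEMM, then reduce-scatter the output along M.
    return jax.lax.psum_scatter(
        jnp.einsum("mk,kn->mn", lhs, rhs),
        axis_name="tp", scatter_dimension=0, tiled=True)

m, k, n = 128, 256, 64
lhs = jnp.ones((m, k), jnp.bfloat16)
rhs = jnp.ones((k, n), jnp.bfloat16)

# K dimension sharded on both operands -> all-reduced, replicated output.
out = shard_map(_gemm_allreduce, mesh=mesh,
                in_specs=(P(None, "tp"), P("tp", None)),
                out_specs=P(None, None))(lhs, rhs)
```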

In practice, the RHS matrix (typically the weight tensor) should be allocated with transposed contracting dimensions ([B,] N, K) for optimal GEMM heuristics in cuBLASLt. This layout is also mandatory for FP8 inputs.
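
For illustration only (the custom op's API is not shown here): contracting a row-major (M, K) activation against a weight stored as (N, K) means contracting the last dimension of both operands, i.e. an NT GEMM, which is the layout the heuristics above favor:

```python
import jax.numpy as jnp
from jax import lax

m, k, n = 128, 256, 64
x = jnp.ones((m, k), jnp.bfloat16)   # LHS activations, row-major (M, K)
w = jnp.ones((n, k), jnp.bfloat16)   # RHS weight stored transposed as (N, K)

# Contract K (axis 1 of both operands); equivalent to x @ w.T.
y = lax.dot_general(x, w, dimension_numbers=(((1,), (1,)), ((), ())))
assert y.shape == (m, n)
```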

This PR does NOT update fused ops or Flax/Praxis modules to use the new GEMM custom op over the existing XLA pattern matching approach.

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [ ] Code refactor

Changes

  • [x] Added XLA custom calls for nvte_cublas_gemm.
  • [x] Added JAX primitive for the new XLA custom call.
  • [x] Added new serial unit test.
  • [ ] Add distributed unit test.

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [ ] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [ ] I have made corresponding changes to the documentation
  • [ ] My changes generate no new warnings
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes

denera · Nov 02 '24 02:11

Why? Normal JAX behavior is to do some gathering.

nouiz · Nov 04 '24 16:11

It seems that the batch size is currently not handled in the C++ code. Since JAX uses row-major storage for tensors by default, the batch dimension should probably be combined with the M dimension for LHS or the N dimension for RHS?
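
For example (just illustrating the flattening idea, not the actual C++ implementation):

```python
import jax.numpy as jnp

b, m, k, n = 4, 128, 256, 64
lhs = jnp.ones((b, m, k), jnp.bfloat16)
rhs = jnp.ones((k, n), jnp.bfloat16)

# Row-major (B, M, K) is contiguous as (B*M, K), so the batch can be folded
# into M and a single 2-D GEMM covers all batches when RHS is shared.
out = (lhs.reshape(b * m, k) @ rhs).reshape(b, m, n)
```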

huanghua1994 · Nov 04 '24 22:11

Closing in favor of #1846

denera · Jun 03 '25 20:06