
SIMD vector matrix multiplication?

Open · jafioti opened this issue on Dec 15, 2023 · 3 comments

I was reading through the matmul kernels, and I noticed the beginning of the vecmat kernel looked like this:

static METAL_FUNC void run(
      const device T* mat,
      const device T* in_vec,
      device T* out_vec, 
      const constant int& in_vec_size [[buffer(3)]],
      const constant int& out_vec_size [[buffer(4)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {

    // Appease compiler 
    (void)simd_gid;
    (void)simd_lid;

and noticed that simdgroup instructions were not used anywhere in the kernel. Are there plans to use SIMD here? Can anyone point me to a good vector matrix multiplication metal kernel that uses simdgroup ops?

jafioti · Dec 15 '23 01:12

Can anyone point me to a good vector matrix multiplication metal kernel that uses simdgroup ops?

Check this https://github.com/tinygrad/tinygrad/blob/master/extra/gemm/metal_matmul.py

cyrusmsk · Dec 25 '23 23:12

Can anyone point me to a good vector matrix multiplication metal kernel that uses simdgroup ops?

Here's a good gemm kernel by @philipturner: https://github.com/philipturner/metal-flash-attention/blob/main/Sources/GEMM.metal

altaic · Dec 26 '23 04:12

Matrix-vector kernels (GEMV) are trivial to write. I wrote a 4-bit quantized matrix times 16-bit vector kernel in a few hours and released it on llama.cpp, where it got a lot of attention. Such kernels are bound by memory bandwidth and would not benefit from SIMD-group matmul or SIMD-group async copy. A SIMD broadcast is no better than reading the same value from L1D.

The only SIMD-group feature worth using here is SIMD-group reductions (highly recommended), to sum elementwise across the threads that hold pieces of the same accumulator. It is a technique for reducing overhead, which can be non-trivial and worthwhile to optimize. In many applications, these little overheads are what bring the entire application down when nobody is paying critical attention.
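
As a rough sketch of that reduction pattern (this is not the MLX or llama.cpp kernel; the kernel name, buffer layout, float precision, and the assumption of a 32-wide SIMD-group are all illustrative):

#include <metal_stdlib>
using namespace metal;

// Illustrative GEMV: one SIMD-group per output row of a row-major
// [out_size x in_size] matrix. Each lane accumulates a strided partial
// dot product; a single simd_sum() then reduces the 32 partial sums
// without touching threadgroup memory.
kernel void gemv_simd_reduce(
    const device float* mat [[buffer(0)]],
    const device float* in_vec [[buffer(1)]],
    device float* out_vec [[buffer(2)]],
    const constant int& in_size [[buffer(3)]],
    const constant int& out_size [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]],
    uint simds_per_tg [[simdgroups_per_threadgroup]]) {

  // Which output row this SIMD-group owns (uniform across its 32 lanes).
  uint row = tid.x * simds_per_tg + simd_gid;
  if (row >= uint(out_size)) return;

  const device float* row_ptr = mat + row * in_size;

  // Each lane walks the row with a stride of 32 (assumed SIMD width).
  float acc = 0.0f;
  for (uint i = simd_lid; i < uint(in_size); i += 32) {
    acc += row_ptr[i] * in_vec[i];
  }

  // SIMD-group reduction across the lanes that share this accumulator.
  float total = simd_sum(acc);
  if (simd_lid == 0) {
    out_vec[row] = total;
  }
}

The only SIMD-group feature used is simd_sum at the end; everything else is plain strided loads, which is consistent with the kernel being memory-bandwidth bound.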

MFA is about matrix-matrix (GEMM) kernels, which are much more challenging and a critical part of many AI training applications. They are also used in AI inference for convolutions and MHA. Getting GEMM fully optimized took a few weeks of full-time work, orders of magnitude more than GEMV.
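
For contrast, a bare-bones sketch of the SIMD-group matrix path in a GEMM (this is not the MFA kernel; the names are illustrative, the matrices are assumed row-major float with dimensions that are multiples of 8, and the threadgroup tiling, async copies, and register blocking a real kernel needs are omitted):

#include <metal_stdlib>
using namespace metal;

// Illustrative GEMM: each 32-thread threadgroup (one SIMD-group) computes
// a single 8x8 tile of C = A * B using the simdgroup_matrix intrinsics.
// A is [M x K], B is [K x N], C is [M x N], all row-major.
kernel void gemm_simdgroup_8x8(
    const device float* A [[buffer(0)]],
    const device float* B [[buffer(1)]],
    device float* C [[buffer(2)]],
    const constant uint& N [[buffer(3)]],
    const constant uint& K [[buffer(4)]],
    uint2 tile [[threadgroup_position_in_grid]]) {

  // Top-left corner of this SIMD-group's 8x8 output tile.
  uint row = tile.y * 8;
  uint col = tile.x * 8;

  simdgroup_float8x8 acc = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);

  // March along K eight elements at a time, accumulating 8x8 products.
  for (uint k = 0; k < K; k += 8) {
    simdgroup_float8x8 a, b;
    simdgroup_load(a, A + row * K + k, K);   // 8x8 tile of A
    simdgroup_load(b, B + k * N + col, N);   // 8x8 tile of B
    simdgroup_multiply_accumulate(acc, a, b, acc);
  }

  simdgroup_store(acc, C + row * N + col, N);
}

Dispatched as an (N/8, M/8) grid of 32-thread threadgroups; the point is simply that the simdgroup_load / simdgroup_multiply_accumulate path is what GEMM gains over GEMV.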

philipturner · Dec 26 '23 04:12