NNlib.jl
Batched multiplication support for ndims > 3
Currently NNlib.batched_mul works with arrays of up to 3 dimensions. It would be nice if this could be upgraded to function similarly to numpy.matmul or torch.matmul - this would help in a lot of models, especially some of the attention-based ones I'm working on 😅
The first question is what rules it should obey. batched_mul is always matrix-matrix multiplication, and does this:
julia> rand(3,4,7) ⊠ rand(4,5,7) |> size
(3, 5, 7)
julia> rand(3,4,1) ⊠ rand(4,5,7) |> size
(3, 5, 7)
julia> rand(3,4,7) ⊠ rand(4,5,1) |> size
(3, 5, 7)
julia> rand(3,4,7) ⊠ rand(4,5) |> size # same, as size(B,3)==1
(3, 5, 7)
The easy extension would be this one, but should anything else be allowed?
rand(3,4,7,9) ⊠ rand(4,5,7,9) |> size # (3, 5, 7, 9)
I think the 3D cases match what CUBLAS.gemm_strided_batched! handles. But more exotic things could be done with a loop of course.
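For that exactly-matching case the whole thing collapses onto the existing 3D method with a reshape, something like this sketch (the function name is made up, not NNlib API):

using NNlib

# Sketch: only the case where the trailing batch dims of A and B match exactly.
function batched_mul_reshaped(A::AbstractArray{T,4}, B::AbstractArray{T,4}) where {T}
    size(A)[3:4] == size(B)[3:4] || throw(DimensionMismatch("batch dims must match"))
    nbatch = size(A, 3) * size(A, 4)
    A3 = reshape(A, size(A, 1), size(A, 2), nbatch)   # flatten the batch dims
    B3 = reshape(B, size(B, 1), size(B, 2), nbatch)
    C3 = batched_mul(A3, B3)                          # existing 3D code path
    return reshape(C3, size(C3, 1), size(C3, 2), size(A, 3), size(A, 4))
end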
This is what torch and numpy have to say:
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
which is exactly what you've suggested (except that Julia has the matrices in the first two dimensions). I think this is a straightforward enough way to go about it: other frameworks do it the same way, so it is expected behaviour in a sense.
The rule you suggest is, I think, this:
rand(3,4,7,9) ⊠ rand(4,5,7,9) |> size # (3, 5, 7, 9) # trivial reshape
rand(3,4,7,1) ⊠ rand(4,5,7,9) |> size # (3, 5, 7, 9)
rand(3,4,1,9) ⊠ rand(4,5,7,9) |> size # (3, 5, 7, 9)
rand(3,4,1,9) ⊠ rand(4,5,7,1) |> size # (3, 5, 7, 9)
rand(3,4,7,9,11) ⊠ rand(4,5,7,9,11) |> size # (3, 5, 7, 9, 11) # trivial reshape
rand(3,4,1,9,1) ⊠ rand(4,5,7,1,11) |> size # (3, 5, 7, 9, 11)
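In other words the batch dimensions follow Julia's usual broadcasting of size-1 dims. A sketch of the resulting output-size rule (illustrative only, not NNlib code):

function batched_mul_output_size(A::AbstractArray, B::AbstractArray)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("inner matrix dims must match"))
    nbatch = max(ndims(A), ndims(B)) - 2
    batch = ntuple(nbatch) do d
        a, b = size(A, d + 2), size(B, d + 2)   # size(A, d) == 1 beyond ndims(A)
        a == b || a == 1 || b == 1 || throw(DimensionMismatch("batch dim $(d + 2)"))
        max(a, b)
    end
    return (size(A, 1), size(B, 2), batch...)
end

batched_mul_output_size(rand(3,4,1,9,1), rand(4,5,7,1,11))  # (3, 5, 7, 9, 11)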
Notice that there are two levels of difficulty in the examples above: all but the last still have a regular stride across batches. Whether that's sufficient for gemm_strided_batched! I don't recall:
julia> rand(3,4,1,9) |> strides # batch stride 12
(1, 3, 12, 12)
julia> rand(3,4,7,1) |> strides # batch stride 12 only
(1, 3, 12, 84)
julia> rand(3,4,7,9,11) |> strides # batch stride 12 only
(1, 3, 12, 84, 756)
julia> rand(3,4,7,1,11) |> strides # irregular batch stride, 12 and 84
(1, 3, 12, 84, 84)
The completely general case is of course not so hard to write as a loop over gemm!, which is all the CPU implementation is anyway. Probably you can hack the broadcast machinery to do the indexing for you.
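As a rough sketch of that fallback (not the actual NNlib implementation), one can broadcast the batch indices by hand and call mul! on 2D views, which ends up in gemm! for BLAS element types:

using LinearAlgebra: mul!

function batched_mul_nd(A::AbstractArray{T}, B::AbstractArray{T}) where {T}
    nbatch = max(ndims(A), ndims(B)) - 2
    batchA = ntuple(d -> size(A, d + 2), nbatch)
    batchB = ntuple(d -> size(B, d + 2), nbatch)
    batchC = max.(batchA, batchB)         # assumes the sizes are broadcast-compatible
    C = similar(A, size(A, 1), size(B, 2), batchC...)
    for I in CartesianIndices(batchC)
        IA = min.(Tuple(I), batchA)       # size-1 batch dims of A are reused (broadcast)
        IB = min.(Tuple(I), batchB)
        mul!(view(C, :, :, Tuple(I)...), view(A, :, :, IA...), view(B, :, :, IB...))
    end
    return C
end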
But on the GPU, at least for ndims=3, the fused strided_batched routine is much quicker than a loop. So ndims>3 cases which can't be written as one call probably want to be written as a loop over gemm_strided_batched! calls, not over gemm! calls?
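A chunked version of that idea would loop only over the batch dims beyond the third and do one fused 3D multiplication per iteration, something like this sketch (assuming batched_mul on contiguous 3D views hits the fast path; the out-of-place call here allocates per chunk, which an in-place version would avoid):

using NNlib: batched_mul

function batched_mul_chunked(A::AbstractArray{T}, B::AbstractArray{T}) where {T}
    nouter = max(ndims(A), ndims(B)) - 3
    outerA = ntuple(d -> size(A, d + 3), nouter)
    outerB = ntuple(d -> size(B, d + 3), nouter)
    outerC = max.(outerA, outerB)
    C = similar(A, size(A, 1), size(B, 2), max(size(A, 3), size(B, 3)), outerC...)
    for I in CartesianIndices(outerC)
        IA, IB = min.(Tuple(I), outerA), min.(Tuple(I), outerB)
        # one fused 3D batched multiplication per outer-batch index; size-1
        # broadcasting within the third dim is left to batched_mul itself
        C[:, :, :, Tuple(I)...] = batched_mul(view(A, :, :, :, IA...), view(B, :, :, :, IB...))
    end
    return C
end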
Edit: gemm_strided_batched! is here:
https://github.com/JuliaGPU/CUDA.jl/blob/f81cdf7484842889f5adfcaeb60436f5ebfb513a/lib/cublas/wrappers.jl#L1036
But there is also gemm_batched! here:
https://github.com/JuliaGPU/CUDA.jl/blob/f81cdf7484842889f5adfcaeb60436f5ebfb513a/lib/cublas/wrappers.jl#L974
Possibly using that for cases which don't fit the strided_batched pattern is better than inventing something? Some discussion of these options here: https://developer.nvidia.com/blog/cublas-strided-batched-matrix-multiply/
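If gemm_batched! does accept a vector of matrices for each argument (worth checking against the wrappers linked above - I'm assuming a ('N', 'N', alpha, A::Vector, B::Vector, beta, C::Vector) form), the irregular cases could be handled by collecting 2D views, roughly:

using CUDA
using CUDA: CUBLAS

# Sketch only; 4D case, with size-1 batch dims of A and B broadcast by clamping the index.
function batched_mul4_batched!(C::CuArray{T,4}, A::CuArray{T,4}, B::CuArray{T,4}) where {T}
    Cs = [view(C, :, :, i, j) for i in axes(C, 3), j in axes(C, 4)]
    As = [view(A, :, :, min(i, size(A, 3)), min(j, size(A, 4))) for i in axes(C, 3), j in axes(C, 4)]
    Bs = [view(B, :, :, min(i, size(B, 3)), min(j, size(B, 4))) for i in axes(C, 3), j in axes(C, 4)]
    CUBLAS.gemm_batched!('N', 'N', one(T), vec(As), vec(Bs), zero(T), vec(Cs))
    return C
end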