CUDA.jl

Add support for arbitrary group sizes in `gemm_grouped_batched!`

Open: lpawela opened this pull request (5 comments)

Currently `group_size` is hard-coded to all ones, so every matrix forms its own group. This PR adds support for arbitrary group sizes. It also changes the input types from `A::Vector{<:StridedCuMatrix{T}}` to `A::Vector{<:Vector{<:StridedCuMatrix{T}}}`.

lpawela, Apr 18 '24
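To make the grouped semantics concrete, here is a minimal CPU reference sketch. It assumes (this is not taken from the PR) that group `g` shares `transA[g]`, `transB[g]`, `alpha[g]` and `beta[g]`, and that `group_size[g]` is simply `length(A[g])`. The names `grouped_gemm_reference!` and `apply_trans` are hypothetical and are not part of CUDA.jl; the sketch uses plain `Array`s and `LinearAlgebra.mul!`, so it runs without a GPU and could serve as a check for GPU results.

```julia
using LinearAlgebra

# Hypothetical helper: map a BLAS-style trans flag onto a lazy wrapper.
function apply_trans(flag::Char, M::AbstractMatrix)
    flag == 'N' && return M
    flag == 'T' && return transpose(M)
    flag == 'C' && return adjoint(M)
    throw(ArgumentError("invalid trans flag '$flag'"))
end

# Hypothetical CPU reference (not CUDA.jl API). For every group g and every i in it:
#   C[g][i] = alpha[g] * op(A[g][i]) * op(B[g][i]) + beta[g] * C[g][i]
function grouped_gemm_reference!(transA::Vector{Char}, transB::Vector{Char},
                                 alpha::Vector{T},
                                 A::Vector{<:Vector{<:AbstractMatrix{T}}},
                                 B::Vector{<:Vector{<:AbstractMatrix{T}}},
                                 beta::Vector{T},
                                 C::Vector{<:Vector{<:AbstractMatrix{T}}}) where {T}
    ngroups = length(A)
    all(length(v) == ngroups for v in (transA, transB, alpha, B, beta, C)) ||
        throw(DimensionMismatch("per-group arguments must all have length $ngroups"))
    for g in 1:ngroups
        # The group size is implied by the length of the inner vectors.
        length(A[g]) == length(B[g]) == length(C[g]) ||
            throw(DimensionMismatch("group $g has mismatched batch sizes"))
        for i in eachindex(A[g])
            # 5-argument mul! computes C = op(A) * op(B) * alpha + C * beta in place.
            mul!(C[g][i], apply_trans(transA[g], A[g][i]),
                 apply_trans(transB[g], B[g][i]), alpha[g], beta[g])
        end
    end
    return C
end
```

In this layout, a group of size one per matrix reproduces the previous hard-coded behavior.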

This is a breaking change, so I'll let the original author review.

maleadt, May 03 '24

Is it possible to keep the initial function with `A::Vector{<:StridedCuMatrix{T}}`? That way you avoid a breaking change.

The function with `A::Vector{<:Vector{<:StridedCuMatrix{T}}}` as input is useful, but users must work out which blocks have the same size. It should be an extension of the previous function, not a replacement.

amontoison, May 03 '24
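The bookkeeping the comment above leaves to users (finding which blocks share a size) could look roughly like the following. Both helpers are hypothetical and not part of CUDA.jl: `group_by_shape` buckets independent problems by `(m, k, n)`, assuming no transposition, and `singleton_groups` shows how the old flat input maps onto the nested layout with every group of size one.

```julia
# Hypothetical helper (not CUDA.jl API): bucket independent GEMM problems so that
# all problems in a group share the same (m, k, n), assuming no transposition.
function group_by_shape(A::Vector{<:AbstractMatrix}, B::Vector{<:AbstractMatrix},
                        C::Vector{<:AbstractMatrix})
    length(A) == length(B) == length(C) ||
        throw(DimensionMismatch("A, B and C must have the same length"))
    buckets = Dict{NTuple{3,Int},Vector{Int}}()  # (m, k, n) => problem indices
    order = NTuple{3,Int}[]                       # shapes in first-seen order
    for i in eachindex(A)
        key = (size(A[i], 1), size(A[i], 2), size(B[i], 2))
        if !haskey(buckets, key)
            buckets[key] = Int[]
            push!(order, key)
        end
        push!(buckets[key], i)
    end
    idx = [buckets[k] for k in order]
    # Indexing with an index vector returns the matrices of one group.
    return [A[g] for g in idx], [B[g] for g in idx], [C[g] for g in idx]
end

# Hypothetical helper: the flat interface corresponds to group_size = ones,
# i.e. each matrix wrapped in its own one-element group.
singleton_groups(X::Vector{<:AbstractMatrix}) = [[x] for x in X]
```

The matrices are not copied; each group holds references to the caller's arrays, so results written into the grouped `C` land in the original outputs.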

I can always restore the older version. But doesn't that just duplicate the behavior of `gemm_batched!` with some extra steps? In this case, switching would be relatively simple: just change some function names.

lpawela, May 03 '24

Why not rely on the multiple dispatch of `gemm_batched!`?

amontoison, May 03 '24
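A minimal sketch of the multiple-dispatch approach: two methods of one generic function, where the flat method wraps each matrix into a one-element group and forwards to the nested method. `batched_gemm_sketch!` is a hypothetical CPU stand-in (no transposes, `LinearAlgebra.mul!` instead of cuBLAS), not CUDA.jl's actual `gemm_grouped_batched!` or `gemm_batched!`.

```julia
using LinearAlgebra

# Nested method (hypothetical): each inner vector is one group sharing alpha[g], beta[g].
function batched_gemm_sketch!(alpha::Vector{T},
                              A::Vector{<:Vector{<:AbstractMatrix{T}}},
                              B::Vector{<:Vector{<:AbstractMatrix{T}}},
                              beta::Vector{T},
                              C::Vector{<:Vector{<:AbstractMatrix{T}}}) where {T}
    for g in eachindex(A), i in eachindex(A[g])
        mul!(C[g][i], A[g][i], B[g][i], alpha[g], beta[g])  # C = A*B*alpha + C*beta
    end
    return C
end

# Flat method (hypothetical, mirrors the pre-PR interface): every matrix is its own
# group (group_size = ones), so wrap each one and dispatch to the nested method.
function batched_gemm_sketch!(alpha::Vector{T}, A::Vector{<:AbstractMatrix{T}},
                              B::Vector{<:AbstractMatrix{T}}, beta::Vector{T},
                              C::Vector{<:AbstractMatrix{T}}) where {T}
    wrap(X) = [[x] for x in X]
    batched_gemm_sketch!(alpha, wrap(A), wrap(B), beta, wrap(C))
    return C  # the wrapped groups alias C's matrices, which were updated in place
end
```

Because a `Vector` is never an `AbstractMatrix`, the two signatures cannot both match the same arguments, so Julia selects one method without ambiguity and the flat call keeps working unchanged.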

I restored the previous implementation alongside mine. @maleadt, is this sufficient, or should I overload `gemm_batched!` instead?

lpawela, Jun 26 '24

Seems fine to me; I'll let @amontoison give the final OK though.

maleadt, Jul 05 '24

LGTM!

amontoison, Jul 05 '24

Weirdly, the Enzyme.jl tests only seem to fail on this PR, even though I don't think `gemm_grouped_batched!` is used anywhere.

maleadt, Jul 08 '24

Strange: with Enzyme.jl v0.12.22, the `test/cuda.jl` tests pass for me with this PR:

```
(Enzyme) pkg> st
Project Enzyme v0.12.22
Status `~/lib/Enzyme.jl/Project.toml`
  [fa961155] CEnum v0.5.0
  [052768ef] CUDA v5.4.2 `../CUDA.jl`
  [f151be2c] EnzymeCore v0.7.6
  [61eb1bfa] GPUCompiler v0.26.7
  [929cbde3] LLVM v8.0.0
  [d8793406] ObjectFile v0.4.1
  [21216c6a] Preferences v1.4.3
  [7cc45869] Enzyme_jll v0.0.133+0
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [de0858da] Printf
  [9a3f8284] Random
```

On the main branch this also passes. The entire test suite also passes on v0.12.22, but on the main branch it fails with:

```
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
Test Summary: | Pass  Total   Time
DiffTest      |   44     44  41.7s
2.3
2.3
Test Summary: | Pass  Total  Time
IO            |    4      4  4.4s
Test Summary: |Time
hmlstm        | None  0.1s
Test Summary:  |Time
No speculation | None  0.3s
ERROR: Package Enzyme errored during testing (received signal: 11)
```

lpawela, Jul 10 '24

@wsmoses, can you look into this CI failure? It's pretty inscrutable to me.

maleadt, Jul 12 '24

Oh, I think you just got unlucky. That was fixed in Enzyme almost immediately afterwards.

wsmoses, Jul 12 '24