
Grouped GEMM execution not possible with HW

Open cassanof opened this issue 10 months ago • 2 comments

When running the grouped GEMM implementation together with expert parallelism, I am faced with the following error:

[rank5]:   File "/env/lib/python3.11/site-packages/megablocks-0.8.0.dev0-py3.11-linux-x86_64.egg/megablocks/layers/glu.py", line 255, in forward
[rank5]:     x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
[rank5]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/env/lib/python3.11/site-packages/grouped_gemm-0.1.6-py3.11-linux-x86_64.egg/grouped_gemm/ops.py", line 33, in gmm
[rank5]:     return GroupedGemm.apply(a, b, batch_sizes, trans_b)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/env/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
[rank5]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/env/lib/python3.11/site-packages/grouped_gemm-0.1.6-py3.11-linux-x86_64.egg/grouped_gemm/ops.py", line 11, in forward
[rank5]:     return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/env/lib/python3.11/site-packages/grouped_gemm-0.1.6-py3.11-linux-x86_64.egg/grouped_gemm/backend.py", line 27, in gmm
[rank5]:     backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
[rank5]: RuntimeError: Grouped GEMM execution not possible with HW

This only happens when you combine the two; using either one alone works fine. The setup here is 8xH100.
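For context, this is roughly how I'm enabling the two together (just a sketch of my config, not a verified repro; the Arguments field names are from the megablocks version I have installed and may differ in yours):

```python
# Rough sketch only; argument names may vary across megablocks versions.
import torch.distributed as dist
from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import dMoE

args = Arguments(
    hidden_size=4096,
    ffn_hidden_size=14336,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_type="glu",                     # matches the glu.py frame in the traceback
    mlp_impl="grouped",                 # grouped GEMM path
    moe_expert_model_parallelism=True,  # expert parallelism
    expert_parallel_group=dist.group.WORLD,
)
layer = dMoE(args).cuda()
```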

cassanof avatar Jan 03 '25 10:01 cassanof

Hm... both work for me using the LLMFoundry integration.

I would start tracing back from here: https://github.com/tgale96/grouped_gemm/blob/ebeae0bb3ded459886309b2a30410deb16937af4/csrc/grouped_gemm.cu#L250-L253. It would probably also help to log shapes, CUDA version, etc., and share them here.
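Something like this right before the `gg.ops.gmm` call in glu.py would capture most of that (a sketch; `log_gmm_inputs` is just a name I made up):

```python
import torch

def log_gmm_inputs(x, w1, batch_sizes, rank):
    # Print shapes/dtypes/devices of the grouped GEMM inputs plus CUDA/GPU info.
    print(
        f"[rank{rank}] x={tuple(x.shape)} {x.dtype} {x.device} | "
        f"w1={tuple(w1.shape)} {w1.dtype} | "
        f"batch_sizes={batch_sizes.tolist()} (device={batch_sizes.device}) | "
        f"cuda={torch.version.cuda} gpu={torch.cuda.get_device_name()}"
    )
```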

mvpatel2000 avatar Jan 03 '25 17:01 mvpatel2000

Hmm, found the issue. For some reason the kernel errors out if you forward all zeros. I forward all zeros on the first batch to do some custom tracing.
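If it helps anyone hitting the same thing: my guess is that with an all-zero input the router sends every token to a single expert, so the other experts end up with zero-sized groups. A standalone check like the one below (untested sketch, shapes made up) could confirm whether a zero entry in batch_sizes alone is enough to trigger the same RuntimeError:

```python
import torch
import grouped_gemm as gg

tokens, k, n, num_experts = 128, 1024, 4096, 4
x = torch.randn(tokens, k, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(num_experts, n, k, dtype=torch.bfloat16, device="cuda")
# All tokens routed to expert 0; the remaining experts get zero rows,
# mimicking what an all-zero forward might produce after routing.
batch_sizes = torch.tensor([tokens, 0, 0, 0], dtype=torch.int64)  # CPU tensor

try:
    out = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
    print("ok:", out.shape)
except RuntimeError as e:
    print("failed:", e)
```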

cassanof avatar Jan 04 '25 05:01 cassanof