NNlib.jl
add Metal extension for batched_mul
Closes #581
PR Checklist
- [x] Tests are added
- [ ] Documentation, if applicable
https://github.com/JuliaGPU/Metal.jl/issues/381
Thanks, I hadn't seen that.
I got a wrong answer in this test on CI (with tiny arrays, though), but didn't investigate further:
https://github.com/FluxML/NNlib.jl/pull/614/files#diff-df0d2a37225f09d22727651479dc1cd59f2b8358f4eb1e2be98c9b04e215be86R31-R34
IIRC, the bug might also happen on tiny arrays if the call is within a sequence of calls. It's really hard to detect, though.