Add rules for BLAS.dot, BLAS.dotc, and BLAS.dotu
In an attempt to learn Enzyme's rule system and speed up AD of BLAS calls, this PR adds a rule for BLAS.dot (and BLAS.dotc and BLAS.dotu). On an input of length 10,000, this is 6x faster than the fallback in forward mode and 60x faster than the fallback in reverse mode.
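For reference, the math these rules encode is small. Here is a plain-Julia sketch of the tangent and pullback for the real dot product (my illustration of the underlying calculus, not the PR's actual EnzymeRules code):
using LinearAlgebra
x, y = randn(4), randn(4)
∂x, ∂y = randn(4), randn(4)  # input tangents
# forward mode: the tangent of s = dot(x, y) is linear in the input tangents
ṡ = dot(∂x, y) + dot(x, ∂y)
# reverse mode: pullbacks for a cotangent s̄ of the output
s̄ = randn()
x̄ = s̄ .* y  # since ∂s/∂xᵢ = yᵢ for real dot
ȳ = s̄ .* x  # since ∂s/∂yᵢ = xᵢ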
Codecov Report
Patch coverage: 2.98% and project coverage change: -0.38% :warning:
Comparison is base (e96c0c5) 78.88% compared to head (cf9ae74) 78.50%.
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #739      +/-   ##
==========================================
- Coverage   78.88%   78.50%   -0.38%
==========================================
  Files          18       19       +1
  Lines        8118     8185      +67
==========================================
+ Hits         6404     6426      +22
- Misses       1714     1759      +45
| Impacted Files | Coverage Δ |
|---|---|
| src/Enzyme.jl | 87.06% <ø> (ø) |
| src/rules/LinearAlgebra/blas.jl | 2.98% <2.98%> (ø) |
How does Enzyme treat complex numbers? e.g. if I wanted to also support BLAS.dotc, would the Duplicateds all store complex derivative vectors, and would the Active store a complex derivative?
How does Enzyme treat complex numbers? e.g. if I wanted to also support BLAS.dotc, would the Duplicateds all store complex derivative vectors, and would the Active store a complex derivative?
yup
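Concretely, a sketch of what that means for forward mode (an illustration of the math, not code from the PR): since BLAS.dotc(x, y) computes sum(conj.(x) .* y), its tangent with complex shadows ∂x and ∂y is itself a pair of dotc calls:
using LinearAlgebra
x, y = randn(ComplexF64, 4), randn(ComplexF64, 4)
∂x, ∂y = randn(ComplexF64, 4), randn(ComplexF64, 4)  # complex shadows
# s = dotc(x, y) = Σᵢ conj(xᵢ) yᵢ, so the tangent is linear in both shadows
ṡ = BLAS.dotc(∂x, y) + BLAS.dotc(x, ∂y)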
Hi @sethaxen, great to see some work on BLAS support in Enzyme from another approach! I worked on it last year, but didn't get around to merging it back then. Namely, I do have tablegen implementations for handling:
- asum
- axpy
- copy
- dot
- scal

but I am lacking the tests for it. Do you have a couple of tests for dot that I could use? Then I can finish my refactoring work here and re-use your tests once that's done. It would be great if we end up having support at all three levels (bitcode, Enzyme proper, and the Julia level), so we can compare them.
Do you have a couple of tests for dot that I could use?
Sure, I just added some tests for dot, dotu, and dotc. Currently the dot tests pass on my machine, but the complex ones segfault, and it's not clear why. (I'm still a bit confused about how rules with complex numbers are supposed to be defined; see #744.) All tests pass on the fallbacks.
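For anyone following along, the dot tests are roughly of this shape (a simplified sketch, assuming the autodiff calling convention used later in this thread, not the PR's exact test code):
using Test, Enzyme, LinearAlgebra
x, y = randn(8), randn(8)
∂x, ∂y = zero(x), zero(y)
# reverse mode with the default seed of 1 for the Active return
autodiff(Reverse, (x, y) -> BLAS.dot(x, y), Active, Duplicated(x, ∂x), Duplicated(y, ∂y))
@test ∂x ≈ y  # ∂s/∂x = y
@test ∂y ≈ x  # ∂s/∂y = x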
It would be great if we end up having support at all three levels (bitcode, Enzyme proper, and the Julia level), so we can compare them.
Yeah, it'd be nice if we could make that comparison when this PR is finished, before I tackle BLAS rules for the other functions you listed.
I just built Enzyme.jl on top of Enzyme with Blas-Tblgen, removing the Julia code that calls the BLAS fallback. It seems like the tablegen dot implementation passes the existing Julia tests and also the Enzyme LLVM-IR tests. I need to clean it up a bit, because I have multiple BLAS functions in my PR and not all of them are that well tested, but I hope to merge it soon.
It seems like the tablegen dot implementation passes the existing Julia tests and also the Enzyme LLVM-IR tests. I need to clean it up a bit, because I have multiple BLAS functions in my PR and not all of them are that well tested, but I hope to merge it soon.
@ZuseZ4 I'm interested to hear more about this tablegen approach. It seems to produce extremely terse expressions of the rules, which would in some ways be preferable to the approach here. But do you have to define both forward- and reverse-mode rules? And I'm a little confused by the existing rules. e.g. axpy should, I believe, have the forward-mode rule
axpy!(∂a, X, ∂Y)  # ∂Y += ∂a * X
axpy!(a, ∂X, ∂Y)  # ∂Y += a * ∂X
and the reverse-mode rule
∂a = dot(X, ∂Y)         # adjoint of the scalar a
axpy!(conj(a), ∂Y, ∂X)  # ∂X += conj(a) * ∂Y
scal!(0, ∂Y)            # ∂Y .= 0
But the tablegen rule looks like neither of these and even uses asum: https://github.com/EnzymeAD/Enzyme/blob/d679b813efb21bf038d240b8780eaaef3b08b3d8/enzyme/Enzyme/targets/BlasDerivatives.td#L80-L87
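For what it's worth, the forward-mode rule above checks out numerically against the linearization of y ← a*x + y; a quick sketch:
using LinearAlgebra
a, ∂a = randn(), randn()
X, ∂X = randn(4), randn(4)
∂Y = randn(4)                        # tangent of Y
expected = ∂a .* X .+ a .* ∂X .+ ∂Y  # tangent of y_new = a*X + Y
BLAS.axpy!(∂a, X, ∂Y)                # ∂Y += ∂a * X
BLAS.axpy!(a, ∂X, ∂Y)                # ∂Y += a * ∂X
@assert ∂Y ≈ expected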
Hi @sethaxen. So the general approach is to declare reverse-mode rules; forward mode should be able to be handled using the primal function itself. I agree that tablegen is "nicer" than the Julia level for now, but this is just due to the fact that I wrote both sides: the BLAS rules and the new tablegen.cpp code which parses these BLAS rules. The Julia rules are more generic and not fine-tuned for BLAS.
The incorrect reverse BLAS rules are mostly due to a design decision in my last approach. If, say, dot uses axpy for the reverse pass, tablegen checks the number and types of the axpy arguments based on the axpy rules, which in turn required me to declare and implement all BLAS functions used in the reverse axpy pass. That recursively caused a typical "big bang" approach where I needed to handle multiple rules at once. Since I didn't have the tests back then, I just ended up using placeholder (incorrect) reverse rules to make it compile and thus never merged it. I do have a better solution for it now; I'll push more of it on Monday.
Since I didn't have the tests back then, I just ended up using placeholder (incorrect) reverse rules to make it compile and thus never merged it. I do have a better solution for it now; I'll push more of it on Monday.
Ah, that makes sense. It'll be interesting to see how you handle asum, whose reverse-mode rule isn't computable just with BLAS functions.
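Concretely, for real x, asum(x) = Σᵢ |xᵢ|, so the pullback needs an elementwise sign; a sketch:
x = randn(4)
s̄ = randn()        # cotangent of s = BLAS.asum(x)
x̄ = s̄ .* sign.(x)  # elementwise, not expressible as a BLAS level-1 call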
@wsmoses the only failing tests are due to #778. Otherwise this should be ready for review.
It seems the BLAS fallback warnings are raised even if the BLAS fallbacks are not being hit?
Yeah, the BLAS fallback injection is done prior to any custom rules, so it will always occur. However, if a custom rule is hit, it will use that implementation rather than the injected fallback.
Ah, that makes sense. It'll be interesting to see how you handle asum, whose reverse-mode rule isn't computable just with BLAS functions.
That is indeed a bit annoying. We do have Blas-Tablegen and Instruction-Tablegen. The second one could handle this, but we will probably keep developing it further, while Blas-Tablegen can hopefully stay unchanged in the future. So I might actually copy over a bit of the logic, so that we don't introduce a dependency there. dot is, by the way, ready to be merged, so I'm now trying to get some more level-1 functions ready.
All tests seem to pass for me. Some of the CI jobs seem not to be picked up for some reason, but as far as I can tell, the same tests pass here that pass on main.
Here's an updated benchmark. The main takeaways are that the forward-mode rules give a 2.5-7.3x speed-up vs the versions hit by main, and the reverse-mode rules give a 13-60x speed-up when no tape is needed and a 5.7-8.7x speed-up when a tape is needed. In the latter case, the trade-off is that these rules allocate a tape for all entries that contribute to the output, which in a case like the one benchmarked here (where subsequent operations mutate only a single entry of each array) is wasteful.
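Conceptually, the tape just caches the primal inputs before they can be mutated; a sketch of the idea (not the rule's actual tape layout):
x, y = randn(3), randn(3)
∂x, ∂y = zero(x), zero(y)
s̄ = 1.0
tape = (copy(x), copy(y))  # augmented forward pass caches the primal inputs
x[1] = 0; y[1] = 0         # mutation between the forward and reverse passes
xsaved, ysaved = tape      # reverse pass pulls back through dot via the cache
∂x .+= s̄ .* ysaved
∂y .+= s̄ .* xsaved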
using BenchmarkTools, Enzyme, LinearAlgebra, Random
Random.seed!(42)
n = 1_000
x = randn(n)
y = randn(n)
∂x = randn(eltype(x), size(x))
∂y = randn(eltype(y), size(y))
incx = incy = 1
# version that triggers tape allocation: the inputs are mutated after the
# BLAS call, so the reverse pass must cache them
function f_overwrite!(f, n, x, incx, y, incy)
    s = f(n, x, incx, y, incy)
    x[1] = 0
    y[1] = 0
    return s
end
## BLAS.dot
@btime autodiff(
Forward, $(BLAS.dot), $n, $(Duplicated(x, ∂x)), $incx, $(Duplicated(y, ∂y)), $incy,
)
# main: 731.130 ns (0 allocations: 0 bytes)
# here: 100.365 ns (2 allocations: 64 bytes)
# no tape needed
@btime autodiff(
ReverseWithPrimal, $(BLAS.dot), Active, $n, Dx, $incx, Dy, $incy,
) setup=(Dx=Duplicated(x, copy(∂x)); Dy=Duplicated(y, copy(∂y)))
# main: 11.139 μs (2 allocations: 32 bytes)
# here: 192.827 ns (7 allocations: 208 bytes)
# tape needed
@btime autodiff(
ReverseWithPrimal, $(f_overwrite!), Active, $(BLAS.dot), $n, Dx, $incx, Dy, $incy,
) setup=(Dx=Duplicated(copy(x), copy(∂x)); Dy=Duplicated(copy(y), copy(∂y)))
# main: 11.624 μs (2 allocations: 32 bytes)
# here: 1.342 μs (9 allocations: 16.08 KiB)
T = ComplexF64
x = randn(T, n)
y = randn(T, n)
∂x = randn(eltype(x), size(x))
∂y = randn(eltype(y), size(y))
incx = incy = 1
## BLAS.dotu
@btime autodiff(
Forward, $(BLAS.dotu), $n, $(Duplicated(x, ∂x)), $incx, $(Duplicated(y, ∂y)), $incy,
)
# main: 1.599 μs (0 allocations: 0 bytes)
# here: 650.117 ns (2 allocations: 64 bytes)
# no tape needed
@btime autodiff(
ReverseWithPrimal, $(BLAS.dotu), Active, $n, Dx, $incx, Dy, $incy,
) setup=(Dx=Duplicated(x, copy(∂x)); Dy=Duplicated(y, copy(∂y)))
# main: 21.493 μs (2 allocations: 48 bytes)
# here: 1.706 μs (8 allocations: 256 bytes)
# tape needed
@btime autodiff(
ReverseWithPrimal, $(f_overwrite!), Active, $(BLAS.dotu), $n, Dx, $incx, Dy, $incy,
) setup=(Dx=Duplicated(copy(x), copy(∂x)); Dy=Duplicated(copy(y), copy(∂y)))
# main: 22.231 μs (2 allocations: 48 bytes)
# here: 3.885 μs (10 allocations: 31.75 KiB)
## BLAS.dotc
@btime autodiff(
Forward, $(BLAS.dotc), $n, $(Duplicated(x, ∂x)), $incx, $(Duplicated(y, ∂y)), $incy,
)
# main: 1.584 μs (0 allocations: 0 bytes)
# here: 647.920 ns (2 allocations: 64 bytes)
# no tape needed
@btime autodiff(
ReverseWithPrimal, $(BLAS.dotc), Active, $n, Dx, $incx, Dy, $incy,
) setup=(Dx=Duplicated(x, copy(∂x)); Dy=Duplicated(y, copy(∂y)))
# main: 21.220 μs (2 allocations: 48 bytes)
# here: 758.145 ns (8 allocations: 256 bytes)
# tape needed
@btime autodiff(
ReverseWithPrimal, $(f_overwrite!), Active, $(BLAS.dotc), $n, Dx, $incx, Dy, $incy,
) setup=(Dx=Duplicated(copy(x), copy(∂x)); Dy=Duplicated(copy(y), copy(∂y)))
# main: 22.251 μs (2 allocations: 48 bytes)
# here: 2.977 μs (10 allocations: 31.75 KiB)
@ZuseZ4 it would be interesting to see how the tablegen versions compare.