KernelAbstractions.jl
Enzyme gradient always seems to stall
Whatever we try, @VasylHafych and I always run into a stall when trying to autodiff kernels with Enzyme.
The following stalls at evt = dmatmul(...) (example taken from https://github.com/JuliaGPU/KernelAbstractions.jl/blob/master/lib/KernelGradients/test/matmul.jl):
using KernelAbstractions, Enzyme, Test
# From KernelGradients.jl:
function Enzyme.autodiff(kernel::KernelAbstractions.Kernel{<:Any, <:Any, <:Any, Fun}) where Fun
    f = kernel.f
    function df(ctx, args...)
        Enzyme.autodiff_deferred(f::Fun, Enzyme.Const, ctx, args...)
    end
    similar(kernel, df)
end
@kernel function matmul_kernel!(a, b, c)
    i, j = @index(Global, NTuple)
    # creating a temporary sum variable for matrix multiplication
    tmp_sum = zero(eltype(c))
    for k = 1:size(a)[2]
        @inbounds tmp_sum += a[i,k] * b[k, j]
    end
    c[i,j] = tmp_sum
end
ArrayT = Array
a = ArrayT(rand(128, 256))
b = ArrayT(rand(256, 128))
c = ArrayT(zeros(128, 128))
dev = CPU()
matmul = matmul_kernel!(dev, (32, 32))
wait(matmul(a, b, c, ndrange=size(c)))
@test c ≈ a*b
dmatmul = Enzyme.autodiff(matmul)
da = similar(a)
da .= 0
db = similar(b)
db .= 0
dc = similar(c)
dc .= 1
c .= 0
compare_dc = copy(dc)
evt = dmatmul(Duplicated(a, da), Duplicated(b, db), Duplicated(c, dc), ndrange=size(c))
# wait(evt)
# @test da ≈ compare_dc * b'
# @test db ≈ a' * compare_dc
We've tried KernelAbstractions v0.8.0 with the current Enzyme main branch, as well as KA v0.7 with KernelGradients and the last Enzyme release, etc.
Ah sorry you are in intermediate version hell. The latest Enzyme release doesn't work. You either need Enzyme ?0.7 or Enzyme#main.
But it should not hang... It should complain loudly
It definitely just hangs without output, also with Enzyme#main.
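To make a silent hang like this easier to report, one option is to run the suspect call in a separate task and poll it with a timeout. The following is only a sketch: `run_with_timeout` is a hypothetical helper, not part of KernelAbstractions or Enzyme, and it assumes the stalled call does not block the Julia scheduler. If the stall happens inside compilation on a single thread, the timeout may never fire, so starting Julia with more than one thread is probably needed.

```julia
# Hypothetical helper (not provided by KernelAbstractions or Enzyme):
# run `f` in a separate task and give up waiting after `seconds`.
function run_with_timeout(f, seconds::Real)
    task = Threads.@spawn f()
    # `timedwait` polls the callback until it returns true or the timeout expires.
    status = timedwait(() -> istaskdone(task), Float64(seconds))
    if status === :ok
        # `fetch` rethrows (as a TaskFailedException) if `f` threw an error.
        return (:ok, fetch(task))
    else
        # The task is NOT cancelled; it keeps running in the background.
        return (:timed_out, nothing)
    end
end

# Possible use with the reproducer above (names taken from that script):
# run_with_timeout(60) do
#     dmatmul(Duplicated(a, da), Duplicated(b, db), Duplicated(c, dc), ndrange=size(c))
# end
```

A `:timed_out` result at least distinguishes "still running" from "finished without output", which may help narrow down where the stall occurs.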
Ok I just tagged KernelGradients 0.1.1 that supports Enzyme 0.10, maybe try again?
Still stalls (tried with GPU) with both:
(@v1.8) pkg> add Enzyme KernelAbstractions KernelGradients
and
(@v1.8) pkg> add Enzyme#main KernelAbstractions KernelGradients
(@v1.8) pkg> st
Status `/tmp/jl_hXzLRO/Project.toml`
[7da242da] Enzyme v0.10.0 `https://github.com/EnzymeAD/Enzyme.jl.git#main`
[63c18a36] KernelAbstractions v0.8.1
[e5faadeb] KernelGradients v0.1.1
(same on Julia v1.7).
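Since the outcome seems to depend on which package versions get resolved together, it may help to print them from inside the session that stalls. A minimal sketch, assuming the three package names from the status output above; `resolved_versions` is a hypothetical helper built on `Pkg.dependencies()`:

```julia
using Pkg

# Hypothetical helper: look up the resolved version of each named package in
# the active environment. Names that are not found map to `nothing`.
function resolved_versions(names::Vector{String})
    found = Dict{String, Union{VersionNumber, Nothing}}(n => nothing for n in names)
    # `Pkg.dependencies()` returns a Dict of UUID => PackageInfo for the manifest.
    for info in values(Pkg.dependencies())
        if haskey(found, info.name)
            found[info.name] = info.version
        end
    end
    return found
end

for (name, ver) in resolved_versions(["Enzyme", "KernelAbstractions", "KernelGradients"])
    println(name, " => ", ver === nothing ? "not resolved" : ver)
end
```

Note that this reports only the version number; whether Enzyme was added from a release or from `#main` still has to be read from `pkg> st`.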