Enzyme.jl
Enzyme.jl copied to clipboard
Support MultiRequest
MPI.jl provides a useful `MultiRequest` API that stores raw `MPI.API.MPI_Request` handles under the hood.
using MPI
using Enzyme
# Initialize MPI before any communication call and use the world communicator
# for the example below.
MPI.Init()
comm = MPI.COMM_WORLD
"""
    ring(token, comm)

Pass a value once around the ranks of `comm` using the two slots of an
`MPI.MultiRequest`, and return the value held in the local buffer afterwards.

Rank 0 sends its `token` first and then receives from the last rank; every
other rank first receives from its predecessor and then forwards the buffer to
its successor (wrapping around with `mod`).
"""
function ring(token, comm)
    me = MPI.Comm_rank(comm)
    nranks = MPI.Comm_size(comm)
    is_root = me == 0
    # Rank 0 closes the ring by receiving from the last rank; everyone else
    # receives from the rank directly before it.
    upstream = is_root ? nranks - 1 : me - 1
    downstream = mod(me + 1, nranks)

    # Slot 1 tracks the send, slot 2 tracks the receive.
    requests = MPI.MultiRequest(2)
    # Alternative container noted in the original snippet:
    # MPI.UnsafeMultiRequest(2)
    box = Ref(token)

    # Non-root ranks must hold the incoming value before forwarding it.
    if !is_root
        MPI.Irecv!(box, comm, requests[2]; source = upstream)
        MPI.Wait(requests[2])
    end

    MPI.Isend(box, comm, requests[1]; dest = downstream)

    # The root receives only after its own send has been posted.
    if is_root
        MPI.Irecv!(box, comm, requests[2]; source = upstream)
        MPI.Wait(requests[2])
    end

    MPI.Wait(requests[1])
    return box[]
end
# Only rank 0 starts with a real token; every other rank starts with NaN, so
# the primal check below passes only if the value is actually received.
token = MPI.Comm_rank(comm) == 0 ? 1.0 : NaN
# NOTE(review): `@test` comes from the Test stdlib, but no `using Test` is
# visible in this snippet — confirm it is loaded in the real test file.
@test ring(token, comm) == 1.0
# Forward-mode differentiation of `ring`: the token is passed as
# Duplicated(1.0, 1.0) and `comm` is marked Const.
autodiff(Forward, ring, Duplicated(1.0, 1.0), Const(comm))
Currently fails with:
julia: /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:307: void AdjointGenerator::handleMPI(llvm::CallInst&, llvm::Function*, llvm::StringRef): Assertion `!gutils->isConstantValue(call.getOperand(6))' failed.
[14280] signal (6.-6): Aborted
in expression starting at /home/vchuravy/src/Enzyme/test/integration/MPI/multi_request.jl:98
unknown function (ip: 0x7f16e8a9894c)
gsignal at /usr/lib/libc.so.6 (unknown line)
abort at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x7f16e8a254e2)
handleMPI at /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:307
handleKnownCallDerivatives at /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:2254
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:6405
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:5062
recursivelyHandleSubfunction at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:4984
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:6608
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:5062
EnzymeCreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:661
EnzymeCreateForwardDiff at /home/vchuravy/src/Enzyme/src/api.jl:342
unknown function (ip: 0x7f16e1732298)
_jl_invoke at /cache/build/builder-amdci5-7/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-7/julialang/julia-release-1-dot-10/src/gf.c:3077
macro expansion at /home/vchuravy/src/Enzyme/src/compiler.jl:2687 [inlined]
macro expansion at /home/vchuravy/.julia/packages/LLVM/iza6e/src/base.jl:97 [inlined]
enzyme! at /home/vchuravy/src/Enzyme/src/compiler.jl:2512
A MultiRequest looks like this:
julia> reqs = MPI.MultiRequest(4)
4-element MPI.MultiRequest:
null request
null request
null request
null request
julia> dump(reqs)
MPI.MultiRequest
vals: Array{Int32}((4,)) Int32[738197504, 738197504, 738197504, 738197504]
buffers: Array{Any}((4,))
1: Nothing nothing
2: Nothing nothing
3: Nothing nothing
4: Nothing nothing
The crux is that the raw `MPI.API.MPI_Request` is just an `Int32`:
julia> MPI.API.MPI_Request
Int32
#2747 adds a realistic test
There is a related issue with code that doesn't use `MultiRequest`:
# Initialize MPI before any communication call and use the world communicator
# for the example below.
MPI.Init()
comm = MPI.COMM_WORLD
"""
    ring_no_mr(token, comm)

Pass a value once around the ranks of `comm`, tracking the nonblocking
operations in a locally allocated `Vector{MPI.Request}` instead of an
`MPI.MultiRequest`, and return the value held in the local buffer afterwards.

Rank 0 sends its `token` first and then receives from the last rank; every
other rank first receives from its predecessor and then forwards the buffer to
its successor (wrapping around with `mod`).
"""
function ring_no_mr(token, comm)
    me = MPI.Comm_rank(comm)
    nranks = MPI.Comm_size(comm)
    is_root = me == 0
    # Rank 0 closes the ring by receiving from the last rank; everyone else
    # receives from the rank directly before it.
    upstream = is_root ? nranks - 1 : me - 1
    downstream = mod(me + 1, nranks)

    # Slot 1 tracks the send, slot 2 tracks the receive. Both slots start as
    # the null request. The surrounding report identifies this `fill!` as the
    # statement Enzyme rejects, so it is kept as an explicit separate call.
    requests = Vector{MPI.Request}(undef, 2)
    fill!(requests, MPI.REQUEST_NULL)
    box = Ref(token)

    # Non-root ranks must hold the incoming value before forwarding it.
    if !is_root
        requests[2] = MPI.Irecv!(box, comm; source = upstream)
        MPI.Wait(requests[2])
    end

    requests[1] = MPI.Isend(box, comm; dest = downstream)

    # The root receives only after its own send has been posted.
    if is_root
        requests[2] = MPI.Irecv!(box, comm; source = upstream)
        MPI.Wait(requests[2])
    end

    MPI.Wait(requests[1])
    return box[]
end
# Only rank 0 starts with a real token; every other rank starts with NaN, so
# the primal check below passes only if the value is actually received.
token = MPI.Comm_rank(comm) == 0 ? 1.0 : NaN
# NOTE(review): `@test` comes from the Test stdlib, but no `using Test` is
# visible in this snippet — confirm it is loaded in the real test file.
@test ring_no_mr(token, comm) == 1.0
# Forward-mode differentiation of `ring_no_mr`: the token is passed as
# Duplicated(1.0, 1.0) and `comm` is marked Const; the result is printed.
@show autodiff(Forward, ring_no_mr, Duplicated(1.0, 1.0), Const(comm))
This fails because of the `fill!(reqs, MPI.REQUEST_NULL)` call:
vchuravy@loki ~/s/E/t/i/MPI (vc/multi_request) [SIGINT]> julia +1.10 --project=. multi_request.jl
ERROR: LoadError: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Failure within method: MethodInstance for ring_no_mr(::Float64, ::MPI.Comm)
Hint: catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.
If you have Cthulu.jl loaded you can also use `code_typed(err; interactive = true)` to interactively introspect the code.
Mismatched activity for: store atomic {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140155490996080 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspace(13)* %81 release, align 8, !dbg !103, !tbaa !240, !alias.scope !35, !noalias !51 const val: {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140155490996080 to {}*) to {} addrspace(10)*)
value=MPI.Request: null request of type MPI.Request
llvalue={} addrspace(10)* addrspacecast ({}* inttoptr (i64 140155490996080 to {}*) to {} addrspace(10)*)
Stacktrace:
[1] setindex!
@ ./array.jl:1021
[2] fill!
@ ./array.jl:395
[3] ring_no_mr
@ ~/src/Enzyme/test/integration/MPI/multi_request.jl:13
Stacktrace:
[1] setindex!
@ ./array.jl:1021 [inlined]
[2] fill!
@ ./array.jl:395 [inlined]
[3] ring_no_mr
@ ~/src/Enzyme/test/integration/MPI/multi_request.jl:13 [inlined]
[4] fwddiffejulia_ring_no_mr_155wrap
@ ~/src/Enzyme/test/integration/MPI/multi_request.jl:0
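Thankfully, the following variant, which receives `reqs` as an argument and passes it to `autodiff` as `Duplicated`: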
"""
    ring_no_mr_shadow_req(token, comm, reqs)

Pass a value once around the ranks of `comm`, storing the nonblocking
operations in the caller-supplied container `reqs` (index 1 for the send,
index 2 for the receive), and return the value held in the local buffer
afterwards.

Rank 0 sends its `token` first and then receives from the last rank; every
other rank first receives from its predecessor and then forwards the buffer to
its successor (wrapping around with `mod`). Each slot of `reqs` is assigned
before it is read, so the caller does not need to initialize it.
"""
function ring_no_mr_shadow_req(token, comm, reqs)
    me = MPI.Comm_rank(comm)
    nranks = MPI.Comm_size(comm)
    is_root = me == 0
    # Rank 0 closes the ring by receiving from the last rank; everyone else
    # receives from the rank directly before it.
    upstream = is_root ? nranks - 1 : me - 1
    downstream = mod(me + 1, nranks)
    box = Ref(token)

    # Non-root ranks must hold the incoming value before forwarding it.
    if !is_root
        reqs[2] = MPI.Irecv!(box, comm; source = upstream)
        MPI.Wait(reqs[2])
    end

    reqs[1] = MPI.Isend(box, comm; dest = downstream)

    # The root receives only after its own send has been posted.
    if is_root
        reqs[2] = MPI.Irecv!(box, comm; source = upstream)
        MPI.Wait(reqs[2])
    end

    MPI.Wait(reqs[1])
    return box[]
end
# Only rank 0 starts with a real token; every other rank starts with NaN, so
# the primal check below passes only if the value is actually received.
token = MPI.Comm_rank(comm) == 0 ? 1.0 : NaN
# Uninitialized request storage for the primal run; the function assigns each
# slot before reading it.
reqs = Vector{MPI.Request}(undef, 2)
# NOTE(review): `@test` comes from the Test stdlib, but no `using Test` is
# visible in this snippet — confirm it is loaded in the real test file.
@test ring_no_mr_shadow_req(token, comm, reqs) == 1.0
# Fresh request vectors for the differentiated run: `reqs` is the primal and
# `dreqs` the shadow handed to Enzyme via Duplicated.
reqs = Vector{MPI.Request}(undef, 2)
dreqs = Vector{MPI.Request}(undef, 2)
@show autodiff(Forward, ring_no_mr_shadow_req, Duplicated(1.0, 1.0), Const(comm), Duplicated(reqs, dreqs))
Does work.