Enzyme.jl

Support MultiRequest

Open: vchuravy opened this issue 1 month ago • 2 comments

MPI.jl has a useful MultiRequest API that stores the raw MPI.API.MPI_Request handles under the hood. Enzyme does not support it yet; the following ring exchange reproduces the problem:

using MPI
using Enzyme
using Test


MPI.Init()
comm = MPI.COMM_WORLD


function ring(token, comm)
    rank = MPI.Comm_rank(comm)
    N = MPI.Comm_size(comm)
    reqs = MPI.MultiRequest(2)
    # reqs = MPI.UnsafeMultiRequest(2)

    buf = Ref(token)

    if rank != 0
        MPI.Irecv!(buf, comm, reqs[2]; source = rank - 1)  
        MPI.Wait(reqs[2])
    end

    MPI.Isend(buf, comm, reqs[1]; dest = mod(rank + 1, N))

    if rank == 0
        MPI.Irecv!(buf, comm, reqs[2]; source = N - 1)
        MPI.Wait(reqs[2])
    end 
    MPI.Wait(reqs[1])

    return buf[]
end

token = MPI.Comm_rank(comm) == 0 ? 1.0 : NaN
@test ring(token, comm) == 1.0

autodiff(Forward, ring, Duplicated(1.0, 1.0), Const(comm))

Currently fails with:

julia: /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:307: void AdjointGenerator::handleMPI(llvm::CallInst&, llvm::Function*, llvm::StringRef): Assertion `!gutils->isConstantValue(call.getOperand(6))' failed.

[14280] signal (6.-6): Aborted
in expression starting at /home/vchuravy/src/Enzyme/test/integration/MPI/multi_request.jl:98
unknown function (ip: 0x7f16e8a9894c)
gsignal at /usr/lib/libc.so.6 (unknown line)
abort at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x7f16e8a254e2)
handleMPI at /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:307
handleKnownCallDerivatives at /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:2254
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:6405
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:5062
recursivelyHandleSubfunction at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:4984
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:6608
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:5062
EnzymeCreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:661
EnzymeCreateForwardDiff at /home/vchuravy/src/Enzyme/src/api.jl:342
unknown function (ip: 0x7f16e1732298)
_jl_invoke at /cache/build/builder-amdci5-7/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-7/julialang/julia-release-1-dot-10/src/gf.c:3077
macro expansion at /home/vchuravy/src/Enzyme/src/compiler.jl:2687 [inlined]
macro expansion at /home/vchuravy/.julia/packages/LLVM/iza6e/src/base.jl:97 [inlined]
enzyme! at /home/vchuravy/src/Enzyme/src/compiler.jl:2512

A MultiRequest looks like this:

julia> reqs = MPI.MultiRequest(4)
4-element MPI.MultiRequest:
 null request
 null request
 null request
 null request

julia> dump(reqs)
MPI.MultiRequest
  vals: Array{Int32}((4,)) Int32[738197504, 738197504, 738197504, 738197504]
  buffers: Array{Any}((4,))
    1: Nothing nothing
    2: Nothing nothing
    3: Nothing nothing
    4: Nothing nothing

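The crux is that the raw MPI.API.MPI_Request is just an Int32. This build uses the MPICH ABI, where MPI_REQUEST_NULL is 0x2c000000 (738197504, the value in the dump above). The failed assertion checks operand 6 of the MPI call, which for MPI_Isend and MPI_Irecv is the MPI_Request* output argument, so Enzyme is treating the request storage, a plain Int32 array, as constant: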

julia> MPI.API.MPI_Request
Int32
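To check which representation a particular MPI.jl build uses, the handle type can be inspected directly. This is a minimal sketch, assuming MPI.jl is installed and that MultiRequest keeps its handles in the vals field shown by dump above. The value 0x2c000000 is specific to MPICH-ABI builds; under other ABIs MPI_Request need not be an Int32 at all.

```julia
using MPI
MPI.Init()

reqs = MPI.MultiRequest(2)

# A MultiRequest keeps its raw handles in the `vals` field (see the dump above),
# i.e. a plain Vector of MPI.API.MPI_Request values.
@show typeof(reqs.vals)

# On an MPICH-ABI build MPI.API.MPI_Request is Int32, and every slot starts out
# as MPICH's MPI_REQUEST_NULL (0x2c000000 == 738197504).
@show MPI.API.MPI_Request
@show all(==(0x2c000000), reqs.vals)
```

On an MPICH-ABI build this should report Vector{Int32}, Int32 and true, matching the dump above.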

#2747 adds a realistic test

vchuravy, Nov 07 '25 18:11

There is a related issue with code that doesn't use MultiRequest:


MPI.Init()
comm = MPI.COMM_WORLD

function ring_no_mr(token, comm)
    rank = MPI.Comm_rank(comm)
    N = MPI.Comm_size(comm)
    reqs = Vector{MPI.Request}(undef, 2)
    fill!(reqs, MPI.REQUEST_NULL)

    buf = Ref(token)

    if rank != 0
        reqs[2] = MPI.Irecv!(buf, comm; source = rank - 1)  
        MPI.Wait(reqs[2])
    end

    reqs[1] = MPI.Isend(buf, comm; dest = mod(rank + 1, N))

    if rank == 0
        reqs[2] = MPI.Irecv!(buf, comm; source = N - 1)
        MPI.Wait(reqs[2])
    end 
    MPI.Wait(reqs[1])

    return buf[]
end

token = MPI.Comm_rank(comm) == 0 ? 1.0 : NaN
@test ring_no_mr(token, comm) == 1.0

@show autodiff(Forward, ring_no_mr, Duplicated(1.0, 1.0), Const(comm))

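This fails because of the fill!(reqs, MPI.REQUEST_NULL): it stores the constant MPI.REQUEST_NULL object into reqs, which Enzyme treats as differentiable, and Enzyme rejects the mismatched activity: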

vchuravy@loki ~/s/E/t/i/MPI (vc/multi_request) [SIGINT]> julia +1.10 --project=. multi_request.jl
ERROR: LoadError: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
 Failure within method: MethodInstance for ring_no_mr(::Float64, ::MPI.Comm)
Hint: catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.
If you have Cthulu.jl loaded you can also use `code_typed(err; interactive = true)` to interactively introspect the code.
Mismatched activity for:   store atomic {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140155490996080 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspace(13)* %81 release, align 8, !dbg !103, !tbaa !240, !alias.scope !35, !noalias !51 const val: {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140155490996080 to {}*) to {} addrspace(10)*)
 value=MPI.Request: null request of type MPI.Request
 llvalue={} addrspace(10)* addrspacecast ({}* inttoptr (i64 140155490996080 to {}*) to {} addrspace(10)*)

Stacktrace:
 [1] setindex!
   @ ./array.jl:1021
 [2] fill!
   @ ./array.jl:395
 [3] ring_no_mr
   @ ~/src/Enzyme/test/integration/MPI/multi_request.jl:13

Stacktrace:
  [1] setindex!
    @ ./array.jl:1021 [inlined]
  [2] fill!
    @ ./array.jl:395 [inlined]
  [3] ring_no_mr
    @ ~/src/Enzyme/test/integration/MPI/multi_request.jl:13 [inlined]
  [4] fwddiffejulia_ring_no_mr_155wrap
    @ ~/src/Enzyme/test/integration/MPI/multi_request.jl:0

vchuravy, Nov 07 '25 18:11

Thankfully:

function ring_no_mr_shadow_req(token, comm, reqs)
    rank = MPI.Comm_rank(comm)
    N = MPI.Comm_size(comm)

    buf = Ref(token)

    if rank != 0
        reqs[2] = MPI.Irecv!(buf, comm; source = rank - 1)  
        MPI.Wait(reqs[2])
    end

    reqs[1] = MPI.Isend(buf, comm; dest = mod(rank + 1, N))

    if rank == 0
        reqs[2] = MPI.Irecv!(buf, comm; source = N - 1)
        MPI.Wait(reqs[2])
    end 
    MPI.Wait(reqs[1])

    return buf[]
end

token = MPI.Comm_rank(comm) == 0 ? 1.0 : NaN
reqs = Vector{MPI.Request}(undef, 2)
@test ring_no_mr_shadow_req(token, comm, reqs) == 1.0

reqs = Vector{MPI.Request}(undef, 2)
dreqs = Vector{MPI.Request}(undef, 2)
@show autodiff(Forward, ring_no_mr_shadow_req, Duplicated(1.0, 1.0), Const(comm), Duplicated(reqs, dreqs))

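Does work. The difference is that the request vector is passed in from outside with an explicit shadow (Duplicated(reqs, dreqs)) instead of being allocated and filled with MPI.REQUEST_NULL inside the function.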

vchuravy, Nov 07 '25 18:11