Enzyme.jl

Crash on `OMEinsum.einsum!`

Open mofeing opened this issue 9 months ago • 20 comments

I'm running into a huge error when trying to run the code below. I'm not sure whether I'm missing a rule or this is an Enzyme bug.

error.txt

using Enzyme
using OMEinsum

# Reproducer: three chained in-place OMEinsum contractions, differentiated with Enzyme below.
# ssa1..ssa4 are the 2-D inputs, ssa5/ssa6 are 2-D buffers that receive the intermediate
# results, and ssa7 is a 0-dim buffer that receives the final full contraction.
# The trailing `true, false` arguments sit in the positions named `sx, sy` in the
# custom_einsum! rewrite later in this thread; presumably they scale the contraction result
# and the previous contents of the output respectively — TODO confirm against OMEinsum.
function f(ssa1::Array{Float64, 2}, ssa2::Array{Float64, 2}, ssa3::Array{Float64, 2}, ssa4::Array{Float64, 2}, ssa5::Array{Float64, 2}, ssa6::Array{Float64, 2}, ssa7::Array{Float64, 0})
	# ssa5 <- contraction of ssa1, ssa2 over :C ((:A,:C),(:C,:B) -> (:A,:B)), i.e. a matrix product.
	einsum!(((:A, :C), (:C, :B)), (:A, :B), (ssa1, ssa2), ssa5, true, false, (OMEinsum.get_size_dict)(((:A, :C), (:C, :B)), (ssa1, ssa2)))
	# ssa6 <- the same contraction pattern applied to ssa3, ssa4.
	einsum!(((:A, :C), (:C, :B)), (:A, :B), (ssa3, ssa4), ssa6, true, false, (OMEinsum.get_size_dict)(((:A, :C), (:C, :B)), (ssa3, ssa4)))
	# ssa7 <- full contraction of ssa5 and ssa6 over both indices (empty output index tuple).
	einsum!(((:A, :B), (:A, :B)), (), (ssa5, ssa6), ssa7, true, false, (OMEinsum.get_size_dict)(((:A, :B), (:A, :B)), (ssa5, ssa6)))
	# Unwrap the 0-dim result: a function differentiated with an `Active` return must return a scalar.
	return only(ssa7)
end

# Four random 2x2 inputs, two 2x2 scratch buffers for the intermediate products,
# and a 0-dim output buffer.
x = [rand(2,2) for _ in 1:4]
tmp = [rand(2,2), rand(2,2)]
y = fill(0.0)
# Shadow storage for reverse mode: one zero array per argument, with matching shapes.
∇ = [zero.(x)..., zero.(tmp)..., zero(y)]

# Primal call first, as a sanity check that f runs without AD.
f(x..., tmp..., y)

# Reverse-mode AD: each array argument is paired with its shadow in ∇ via Duplicated;
# the scalar return is marked Active.
autodiff(Reverse, f, Active, Duplicated.([x..., tmp..., y], ∇)...)

mofeing avatar May 06 '24 22:05 mofeing

hey @mofeing this generates a ton of code and is going to be rather hard to debug as is. Can you reproduce this in a more minimal example?

wsmoses avatar May 07 '24 17:05 wsmoses

I can also replicate this with TensorOperations. Both of these "einsum" packages should be calling BLAS for these simple examples.

using Enzyme
using TensorOperations

# The same computation as the OMEinsum reproducer, expressed with TensorOperations:
# two 2x2 matrix products followed by a full contraction down to a 0-dim result.
# NOTE(review): `:N` is presumably the "no conjugation" flag of tensorcontract — confirm.
function f(ssa1::Array{Float64, 2}, ssa2::Array{Float64, 2}, ssa3::Array{Float64, 2}, ssa4::Array{Float64, 2})
	# Label 2 is shared by both operands and absent from the output (1,3), so it is contracted.
	ssa5 = tensorcontract((1,3), ssa1, (1,2), :N, ssa2, (2,3), :N)
	ssa6 = tensorcontract((1,3), ssa3, (1,2), :N, ssa4, (2,3), :N)
	# Empty output labels: contract over both labels, giving a 0-dim array.
	ssa7 = tensorcontract((), ssa5, (1,2), :N, ssa6, (1,2), :N)
	# Scalar return so that the `Active` return activity is valid.
	return only(ssa7)
end

# Four random 2x2 inputs and a zero shadow array for each of them.
x = [rand(2,2) for _ in 1:4]
∇ = zero.(x)

# Primal sanity check.
f(x...)

# Reverse-mode AD with every input Duplicated by its shadow.
autodiff(Reverse, f, Active, Duplicated.(x, ∇)...)

mofeing avatar May 07 '24 17:05 mofeing

I'm still going to need this simpler; it generates hundreds of thousands of lines of IR — which I can look through, but a smaller example would be much more time-efficient.

wsmoses avatar May 07 '24 17:05 wsmoses

hey @mofeing this generates a ton of code and is going to be rather hard to debug as is. Can you reproduce this in a more minimal example?

The smallest I can do is just taking 2 vectors and doing a dot product.

using Enzyme
using OMEinsum

# Smallest reproducer: the dot product of two vectors via einsum!.
# Index :A appears on both inputs and not in the (empty) output, so it is summed over;
# the result is written into the 0-dim buffer ssa3.
function f(ssa1::Array{Float64, 1}, ssa2::Array{Float64, 1}, ssa3::Array{Float64, 0})
	einsum!(((:A,), (:A,)), (), (ssa1, ssa2), ssa3, true, false, (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2)))
	# Scalar return for the `Active` return activity.
	return only(ssa3)
end

# Two random length-2 vectors and a 0-dim output buffer.
x = [rand(2) for _ in 1:2]
y = fill(0.0)
# Zero shadows matching each argument's shape.
∇ = [zero.(x)..., zero(y)]

# Primal sanity check.
f(x..., y)

# Reverse-mode AD over all three array arguments.
autodiff(Reverse, f, Active, Duplicated.([x..., y], ∇)...)

EDIT: Fixed some typos.

mofeing avatar May 07 '24 17:05 mofeing

would you be able to inline the definitions/macros/etc from einsum? (and possibly simplify)

wsmoses avatar May 07 '24 17:05 wsmoses

mmm I managed to skip one layer of OMEinsum and simplify this case to the following:

using Enzyme
using OMEinsum

# Pre-built binary contraction rule: label 'j' on both inputs and no output labels,
# i.e. the dot-product case.
# NOTE(review): `rule` is a non-const global, so `h` reads it without a compiler-known type.
rule = OMEinsum.SimpleBinaryRule{('j',), ('j',), ()}()

# Calls OMEinsum's binary kernel directly, skipping the einsum! layer.
# There is no explicit `return`, so `h` returns whatever binary_einsum! returns. The
# resulting error message reports the 0-dim output array (not a scalar) as the returned value.
function h(ssa1, ssa2, ssa3)
    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
end

# Two random length-2 vectors and a 0-dim output buffer.
x = [rand(2), rand(2)]
y = fill(0.0)
# Zero shadows matching each argument's shape.
∇ = map(zero, [x...,y])

# Primal sanity check.
h(x..., y)

# Differentiate `h`, the function defined for this reproducer (the reported stack trace
# shows `typeof(h)`). Because `h` returns an array rather than a scalar, an `Active`
# return triggers Enzyme's "Cannot add one in place to immutable value" error.
autodiff(Reverse, h, Active, Duplicated.([x..., y], ∇)...)

And now I'm getting this error:

Warning: Using fallback BLAS replacements for (["dsymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler [~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59](https://file+.vscode-resource.vscode-cdn.net/Users/mofeing/Developer/k-local-gradient-descent/notebooks/~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59)
Enzyme Mutability Error: Cannot add one in place to immutable value fill(0.0)

Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] add_one_in_place
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5329 [inlined]
 [3] augmented_julia_h_9637wrap
   @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0
 [4] macro expansion
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5656 [inlined]
 [5] enzyme_call
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5334 [inlined]
 [6] AugmentedForwardThunk
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5223 [inlined]
 [7] autodiff
   @ ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:235 [inlined]
 [8] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(h), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Array{Float64, 0}})
   @ Enzyme ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:303
 [9] top-level scope
   @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2

mofeing avatar May 07 '24 18:05 mofeing

So that is a different issue, any luck minimizing the previous error?

wsmoses avatar May 07 '24 19:05 wsmoses

But you should probably replace fill(0.0) with something like zeros(2) to work around the latter issue for now, if need be.

wsmoses avatar May 07 '24 19:05 wsmoses

Ah, no. This is the only simplification I could do. It's all or nothing 🥲.

Also, I don't call fill(0.0) inside h... I made a small modification (now it's doing an element-wise multiplication of 2 vectors) and it's strange:

using Enzyme
using OMEinsum

# Binary rule with label 'l' on both inputs and on the output: element-wise product of
# the two vectors, written into the output vector.
rule = OMEinsum.SimpleBinaryRule{('l',), ('l',), ('l',)}()

# Calls OMEinsum's binary kernel directly. There is no explicit `return`, so `h` returns
# whatever binary_einsum! returns; the resulting error reports the output vector as the value.
function h(ssa1, ssa2, ssa3)
    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
end

# Two random length-2 inputs and a length-2 output buffer for the element-wise product.
x = [rand(2), rand(2)]
y = zeros(2)
# Zero shadows matching each argument's shape.
∇ = map(zero, [x...,y])

# Primal sanity check.
h(x..., y)

# Differentiate `h` (the reported stack trace shows `typeof(h)`). `h` returns an array,
# so an `Active` return is invalid here and Enzyme raises a mutability error.
autodiff(Reverse, h, Active, Duplicated.([x..., y], ∇)...)

which returns

Enzyme Mutability Error: Cannot add one in place to immutable value [0.0, 0.0]

Stacktrace:
 [1] error
   @ ./error.jl:35
 [2] add_one_in_place
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5329 [inlined]
 [3] augmented_julia_h_13355wrap
   @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0
 [4] macro expansion
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5656 [inlined]
 [5] enzyme_call
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5334 [inlined]
 [6] AugmentedForwardThunk
   @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5223 [inlined]
 [7] autodiff
   @ ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:235 [inlined]
 [8] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(h), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}})
   @ Enzyme ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:303
 [9] top-level scope
   @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2

I think I'm gonna need to add manually some rules for Enzyme and TensorOperations.

mofeing avatar May 07 '24 20:05 mofeing

So the last error implies you are returning an array not a scalar

On Tue, May 7, 2024 at 4:15 PM Sergio Sánchez Ramírez < @.***> wrote:

Ah, no. This is the only simplification I could do. It's all or nothing 🥲.

Also, I don't call fill(0.0) inside h... I did a small modification (now it's doing a element-wise multiplication of 2 vectors) and it's strange:

using Enzyme
using OMEinsum

# Element-wise product rule: label 'l' on both inputs and on the output.
rule = OMEinsum.SimpleBinaryRule{('l',), ('l',), ('l',)}()

function h(ssa1, ssa2, ssa3)
    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
end

x = [rand(2), rand(2)]
y = zeros(2)
∇ = map(zero, [x...,y])

h(x..., y)

# Differentiate `h`, the function defined above.
autodiff(Reverse, h, Active, Duplicated.([x..., y], ∇)...)

which returns

Enzyme Mutability Error: Cannot add one in place to immutable value [0.0, 0.0]

Stacktrace: [1] error @ ./error.jl:35 [2] add_one_in_place @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5329 [inlined] [3] augmented_julia_h_13355wrap @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0 [4] macro expansion @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5656 [inlined] [5] enzyme_call @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5334 [inlined] [6] AugmentedForwardThunk @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5223 [inlined] [7] autodiff @ ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:235 [inlined] [8] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(h), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}) @ Enzyme ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:303 [9] top-level scope @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2

I think I'm gonna need to add manually some rules for Enzyme and TensorOperations.

— Reply to this email directly, view it on GitHub https://github.com/EnzymeAD/Enzyme.jl/issues/1416#issuecomment-2099224257, or unsubscribe https://github.com/notifications/unsubscribe-auth/AAJTUXCNGEJSHEFHJNJM77TZBEY45AVCNFSM6AAAAABHJ3YSESVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDAOJZGIZDIMRVG4 . You are receiving this because you commented.Message ID: <EnzymeAD/Enzyme. @.***>

wsmoses avatar May 07 '24 20:05 wsmoses

Yes, in this last example is returning an array, and in the previous example a scalar.

But in both cases I get Enzyme Mutability Error: Cannot add one in place to immutable value.

mofeing avatar May 07 '24 20:05 mofeing

You cannot return an array when the return is marked active, you must return a scalar, so you should do only(...) for the last program

wsmoses avatar May 07 '24 20:05 wsmoses

Ah okay, I just found the limitation in autodiff(::Reverse)

If I add a only to h,

# Same kernel call as before, but now returns a scalar: `only` unwraps the single element
# of ssa3 (and throws if ssa3 does not hold exactly one element), which makes the
# `Active` return activity valid.
# NOTE(review): `rule` is a global defined earlier in the session, not in this snippet.
function h(ssa1, ssa2, ssa3)
    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
    return only(ssa3)
end

# NOTE(review): ssa1, ssa2, ssa3 are not defined in this snippet. Judging by the printed
# gradient in the thread, they are two length-2 vectors and a 0-dim array — confirm.
∇ = map(zero, [ssa1,ssa2,ssa3])
autodiff(Reverse, h, Active, Duplicated.([ssa1,ssa2,ssa3], ∇)...)

Then, I get:

julia> ∇
3-element Vector{Array{Float64}}:
 [0.26640132425362695, 0.906387409016582]
 [0.6451817492001096, 0.9773755209533468]
 fill(0.0)

mofeing avatar May 07 '24 20:05 mofeing

So I guess the problem is in between einsum! and binary_einsum!, which is the code in... https://github.com/under-Peter/OMEinsum.jl/blob/327cf355c746e9f646c5beee74dcd2c11aa90240/src/einsum.jl#L99-L117

mofeing avatar May 07 '24 20:05 mofeing

The mutability error is not a significant issue (it usually means you returned an array rather than scalar like here). The other issue with the long trace is the one I need a MWE for to fix

wsmoses avatar May 07 '24 20:05 wsmoses

@mofeing I merged a jll bump which fixes some phi node issues. Can you see if this persists?

If not, we should close.

wsmoses avatar May 11 '24 17:05 wsmoses

nothing seems to have changed 🥲

But I think I might be on the way to having an MWE (this issue might not be the only source of error). I noticed that the einsum! function uses @debug, and I can get Enzyme to run indefinitely by making it print a variable.

this works correctly

using OMEinsum
using Enzyme

# Two random length-2 vectors and a 0-dim zero array as the output buffer,
# unpacked into names matching u's parameters.
x = [rand(2) for _ in 1:2]
y = zeros()
ssa1, ssa2, ssa3 = x..., y

# Variant reported as working: the @debug statement logs a constant message only.
# NOTE(review): i1, i2, iyb are not defined in this snippet; presumably they are globals
# left over from the session (e.g. results of OMEinsum.analyze_binary) — confirm.
function u(ssa1, ssa2, ssa3)
    rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()

    @debug "asdf"

    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
    return only(ssa3)
end

this runs indefinitely

using OMEinsum
using Enzyme

# Same inputs as the working variant: two length-2 vectors and a 0-dim output buffer.
x = [rand(2) for _ in 1:2]
y = zeros()
ssa1, ssa2, ssa3 = x..., y

# Variant reported to make Enzyme run indefinitely: the only difference from the working
# version is that @debug also logs the `rule` variable.
# NOTE(review): i1, i2, iyb are not defined in this snippet; presumably they are session
# globals — confirm.
function u(ssa1, ssa2, ssa3)
    rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()

    @debug "asdf" rule

    OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
    return only(ssa3)
end

mofeing avatar May 11 '24 19:05 mofeing

that's a useful issue to minimize/understand, but separate from the one you found. If you can minimize either individual issue, we can fix that issue.

wsmoses avatar May 12 '24 20:05 wsmoses

I've updated to the latest Enzyme (v0.12.6), and the error now looks a little different (no more big explosions!). For the function f below,

# Dot-product reproducer re-run on Enzyme v0.12.6: sums ssa1 .* ssa2 over :A into the
# 0-dim buffer ssa3 and returns that scalar.
function f(ssa1::Array{Float64, 1}, ssa2::Array{Float64, 1}, ssa3::Array{Float64, 0})
	einsum!(((:A,), (:A,)), (), (ssa1, ssa2), ssa3, true, false, (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2)))
	return only(ssa3)
end

the error is the following:

Error on `autodiff(Reverse, f, ...)`
julia> ssa1, ssa2, ssa3 = rand(2), rand(2), zeros()

julia> ∇ = map(zero, [ssa1, ssa2, ssa3])

julia> autodiff(Reverse, f, Active, Duplicated.([ssa1, ssa2, ssa3], ∇)...)
Enzyme execution failed.
Mismatched activity for:   store {} addrspace(10)* %.fca.0.0.1.0.0.extract5, {} addrspace(10)* addrspace(10)* %.fca.0.0.1.0.0.gep6, align 8, !dbg !72, !noalias !80 const val:   %.fca.0.0.1.0.0.extract5 = extractvalue [2 x [1 x {} addrspace(10)*]] %0, 0, 0, !dbg !72
Type tree: {[-1]:Pointer}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now

Stacktrace:
 [1] ntuple
   @ ./ntuple.jl:49
 [2] copy
   @ ./broadcast.jl:1118
 [3] materialize
   @ ./broadcast.jl:903
 [4] einsum!
   @ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:100
 [5] einsum!
   @ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:0


Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:1325
  [2] ntuple
    @ ./ntuple.jl:49 [inlined]
  [3] copy
    @ ./broadcast.jl:1118 [inlined]
  [4] materialize
    @ ./broadcast.jl:903 [inlined]
  [5] einsum!
    @ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:100 [inlined]
  [6] einsum!
    @ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:0 [inlined]
  [7] augmented_julia_einsum__4300_inner_1wrap
    @ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:0
  [8] macro expansion
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5719 [inlined]
  [9] enzyme_call
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5385 [inlined]
 [10] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5275 [inlined]
 [11] runtime_generic_augfwd(activity::Type{Val{(false, false, false, true, true, false, false, false)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true, false, true, true, true)}, RT::Val{@NamedTuple{1, 2, 3}}, f::typeof(einsum!), df::Nothing, primal_1::Tuple{Tuple{Symbol}, Tuple{Symbol}}, shadow_1_1::Nothing, primal_2::Tuple{}, shadow_2_1::Nothing, primal_3::Tuple{Vector{Float64}, Vector{Float64}}, shadow_3_1::Tuple{Vector{Float64}, Vector{Float64}}, primal_4::Array{Float64, 0}, shadow_4_1::Array{Float64, 0}, primal_5::Bool, shadow_5_1::Nothing, primal_6::Bool, shadow_6_1::Nothing, primal_7::Dict{Symbol, Int64}, shadow_7_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/rules/jitrules.jl:179
 [12] f
    @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:5 [inlined]
 [13] diffejulia_f_2479wrap
    @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0
 [14] macro expansion
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5719 [inlined]
 [15] enzyme_call
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5385 [inlined]
 [16] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5264 [inlined]
 [17] autodiff
    @ ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:291 [inlined]
 [18] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(f), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Array{Float64, 0}})
    @ Enzyme ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:303
 [19] top-level scope
    @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2

I've taken the source code of the OMEinsum.einsum! method that I'm using and removed the @debug expressions and an if branch that is not taken (only the else part is run). The result is the u method below:

# Hand-inlined body of the einsum! path used for the dot product (((:A,), (:A,)) -> ()),
# with the @debug statements and the untaken `if` branch removed.
# Per the thread, declaring OMEinsum.analyze_binary inactive makes this function
# differentiate successfully.
function u(ssa1, ssa2, ssa3)
    # Label type of the index tuples.
    LT = Symbol
    ixs = ((:A,), (:A,))
    iy = ()

    # Index tuples converted via OMEinsum._collect for the analysis below.
    iyv = OMEinsum._collect(LT,iy)
    ix1v,ix2v = OMEinsum._collect.(Ref(LT), ixs)

    size_dict = (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2))

    # `cy` and `s3` are unpacked but unused here; they only feed the removed branch.
    c1, c2, cy, s1, s2, s3, i1, i2, iyb = OMEinsum.analyze_binary(ix1v, ix2v, iyv, size_dict)
    # The rule's type parameters are built from runtime values, so its concrete type is
    # only known at run time.
    rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()

    xs1 = OMEinsum.simplifyto(ix1v, c1, ssa1, size_dict)
    xs2 = OMEinsum.simplifyto(ix2v, c2, ssa2, size_dict)
    x1_ = OMEinsum.safe_reshape(xs1, s1)
    x2_ = OMEinsum.safe_reshape(xs2, s2)

    OMEinsum.binary_einsum!(rule, x1_, x2_, ssa3, true, false)
    # Scalar return for the `Active` return activity.
    return only(ssa3)
end

This u method fails similarly to f (though not identically — compare, for example, the Type tree):

Error on `autodiff(Reverse, u, ...)`
julia> autodiff(Reverse, u, Active, Duplicated.([ssa1, ssa2, ssa3], ∇)...)
Enzyme execution failed.
Mismatched activity for:   %unbox24.fca.4.load.pn = phi {} addrspace(10)* [ %unbox24.fca.4.load, %L60 ], [ %unbox31.unpack61, %L64 ] const val:   %unbox24.fca.4.load = load {} addrspace(10)*, {} addrspace(10)** %unbox24.fca.4.gep, align 8, !dbg !136
Type tree: {}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now

Stacktrace:
 [1] u
   @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:12


Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:1325
  [2] iterate
    @ ./range.jl:897 [inlined]
  [3] copyto!
    @ ./abstractarray.jl:942 [inlined]
  [4] _collect
    @ ./array.jl:696 [inlined]
  [5] collect
    @ ./array.jl:694 [inlined]
  [6] #91
    @ ./none:0 [inlined]
  [7] iterate
    @ ./generator.jl:47 [inlined]
  [8] collect
    @ ./array.jl:834 [inlined]
  [9] get_size_dict!
    @ ~/.julia/packages/OMEinsum/zZBsQ/src/interfaces.jl:61 [inlined]
 [10] get_size_dict
    @ ~/.julia/packages/OMEinsum/zZBsQ/src/interfaces.jl:100 [inlined]
 [11] get_size_dict
    @ ~/.julia/packages/OMEinsum/zZBsQ/src/interfaces.jl:99 [inlined]
 [12] u
    @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:9 [inlined]
 [13] diffejulia_u_5824wrap
    @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0
 [14] macro expansion
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5719 [inlined]
 [15] enzyme_call
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5385 [inlined]
 [16] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5264 [inlined]
 [17] autodiff
    @ ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:291 [inlined]
 [18] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(u), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Array{Float64, 0}})
    @ Enzyme ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:303
 [19] top-level scope
    @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:26

By setting OMEinsum.analyze_binary to be inactive, u works but f continues to give the same error:

julia> EnzymeRules.inactive(::typeof(OMEinsum.analyze_binary), args...) = nothing

julia> ∇ = map(zero, [ssa1,ssa2,ssa3])

julia> autodiff(Reverse, u, Active, Duplicated.([ssa1,ssa2,ssa3], ∇)...)
┌ Warning: Using fallback BLAS replacements for (["dsymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59

((nothing, nothing, nothing),)

julia> ∇
3-element Vector{Array{Float64}}:
 [0.056936369463587044, 0.29215353698534263]
 [0.4591629765253664, 0.672870606539631]
 fill(0.0)

mofeing avatar May 14 '24 15:05 mofeing

This is not working either...

# Copy of the binary einsum! path as a standalone function, keeping einsum!'s argument
# list (index specs, operand tuple, output, sx, sy, size dictionary). Only the branch
# taken in the dot-product case is active; the other branch is kept as commented-out
# code for reference.
# `LT` is the key type of `size_dict` and is reused as the label type.
function custom_einsum!(ixs, iy, @nospecialize(xs::NTuple{2, Any}), @nospecialize(y), sx, sy, size_dict::Dict{LT}) where LT
    iyv = OMEinsum._collect(LT,iy)
    ix1v, ix2v = OMEinsum._collect.(Ref(LT), ixs)

    x1, x2 = xs
    c1, c2, cy, s1, s2, s3, i1, i2, iyb = OMEinsum.analyze_binary(ix1v, ix2v, iyv, size_dict)
    rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()
    xs1 = OMEinsum.simplifyto(ix1v, c1, x1, size_dict)
    xs2 = OMEinsum.simplifyto(ix2v, c2, x2, size_dict)
    x1_ = OMEinsum.safe_reshape(xs1, s1)
    x2_ = OMEinsum.safe_reshape(xs2, s2)

    # if cy != iyv
    #     y_ = similar(y, (s3...,))
    #     y_ = reshape(OMEinsum.binary_einsum!(rule, x1_, x2_, y_, true, false), [size_dict[x] for x in cy]...)
    #     return custom_einsum!((cy,), iyv, (y_,), y, sx, sy, size_dict)
    # else
        # Unlike the inlined `u`, this forwards sx/sy and writes through a reshape of y to s3.
        OMEinsum.binary_einsum!(rule, x1_, x2_, OMEinsum.safe_reshape(y, s3), sx, sy)
        return y
    # end
end

# Dot-product reproducer routed through custom_einsum! instead of OMEinsum.einsum!,
# to test whether the failure depends on how einsum! is called.
function custom_f(ssa1::Array{Float64, 1}, ssa2::Array{Float64, 1}, ssa3::Array{Float64, 0})
	custom_einsum!(((:A,), (:A,)), (), (ssa1, ssa2), ssa3, true, false, (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2)))
	# Scalar return for the `Active` return activity.
	return only(ssa3)
end

# Two random length-2 input vectors and a 0-dim output buffer.
x = [rand(2) for _ in 1:2]
y = zeros()

# Primal sanity check.
custom_f(x..., y)

# Differentiate with respect to the same arrays defined in this snippet, so it runs
# standalone instead of depending on ssa1/ssa2/ssa3 from an earlier session.
∇ = map(zero, [x..., y])
autodiff(Reverse, custom_f, Active, Duplicated.([x..., y], ∇)...)

Could it be that the problem is around the argument passing of einsum!? Or maybe the u example works because there is some kind of constant folding/propagation happening before LLVM?

mofeing avatar May 14 '24 15:05 mofeing

@mofeing with a bunch of fixes now landed, how does this work presently?

wsmoses avatar May 24 '24 04:05 wsmoses