Enzyme.jl
Enzyme.jl copied to clipboard
Crash on `OMEinsum.einsum!`
I'm running into a huge error when trying to run the code below. I'm not sure whether I'm missing a rule or this is an Enzyme bug.
using Enzyme
using OMEinsum
# Reproducer: three chained OMEinsum.einsum! calls.
#   ssa5 <- contraction of ssa1, ssa2 with labels (:A,:C),(:C,:B) -> (:A,:B)
#           (a matrix product)
#   ssa6 <- the same contraction of ssa3, ssa4
#   ssa7 <- ssa5, ssa6 contracted over both labels -> () (0-dim output)
# ssa5 and ssa6 are preallocated buffers for the intermediates; einsum! writes
# into its 4th argument.
# `true, false` are einsum!'s scaling arguments (named sx, sy in custom_einsum!
# further down). Presumably they follow BLAS-style y = sx*result + sy*y;
# TODO confirm against the OMEinsum docs.
# The last argument is the label => size dictionary (a Dict{Symbol, Int64}, per
# the later stacktrace).
function f(ssa1::Array{Float64, 2}, ssa2::Array{Float64, 2}, ssa3::Array{Float64, 2}, ssa4::Array{Float64, 2}, ssa5::Array{Float64, 2}, ssa6::Array{Float64, 2}, ssa7::Array{Float64, 0})
einsum!(((:A, :C), (:C, :B)), (:A, :B), (ssa1, ssa2), ssa5, true, false, (OMEinsum.get_size_dict)(((:A, :C), (:C, :B)), (ssa1, ssa2)))
einsum!(((:A, :C), (:C, :B)), (:A, :B), (ssa3, ssa4), ssa6, true, false, (OMEinsum.get_size_dict)(((:A, :C), (:C, :B)), (ssa3, ssa4)))
einsum!(((:A, :B), (:A, :B)), (), (ssa5, ssa6), ssa7, true, false, (OMEinsum.get_size_dict)(((:A, :B), (:A, :B)), (ssa5, ssa6)))
# Unwrap the 0-dim result: an `Active` return annotation requires a scalar.
return only(ssa7)
end
# Four random 2x2 inputs, two 2x2 buffers for the intermediates, and a 0-dim
# output array.
x = [rand(2,2) for _ in 1:4]
tmp = [rand(2,2), rand(2,2)]
y = fill(0.0)
# Shadow storage: one zeroed array per argument of `f`, in argument order.
∇ = [zero.(x)..., zero.(tmp)..., zero(y)]
# Plain (primal) call; presumably a check that `f` runs without AD.
f(x..., tmp..., y)
# Reverse-mode AD with the scalar return marked Active and every array argument
# Duplicated with its matching shadow from ∇. This is the call that crashes.
autodiff(Reverse, f, Active, Duplicated.([x..., tmp..., y], ∇)...)
hey @mofeing this generates a ton of code and is going to be rather hard to debug as is. Can you reproduce this in a more minimal example?
I can also replicate this with TensorOperations. Both of these "einsum" packages should be calling BLAS for these simple examples.
using Enzyme
using TensorOperations
# The same computation as the OMEinsum reproducer, written with TensorOperations:
# two matrix products, then a full contraction to a 0-dim result.
# tensorcontract(IC, A, IA, conjA, B, IB, conjB): the index tuples label each
# operand's axes, and labels shared by A and B are summed.
# `:N` presumably means "no conjugation"; TODO confirm for the TensorOperations
# version in use.
# Unlike the einsum! version, each call returns a new array.
function f(ssa1::Array{Float64, 2}, ssa2::Array{Float64, 2}, ssa3::Array{Float64, 2}, ssa4::Array{Float64, 2})
ssa5 = tensorcontract((1,3), ssa1, (1,2), :N, ssa2, (2,3), :N)
ssa6 = tensorcontract((1,3), ssa3, (1,2), :N, ssa4, (2,3), :N)
# Empty output index tuple: contract every index, giving a 0-dim array.
ssa7 = tensorcontract((), ssa5, (1,2), :N, ssa6, (1,2), :N)
return only(ssa7)
end
# Four random 2x2 inputs and one zeroed shadow per input.
x = [rand(2,2) for _ in 1:4]
∇ = zero.(x)
f(x...)
# Reverse-mode AD over all four inputs; reported to fail like the OMEinsum case.
autodiff(Reverse, f, Active, Duplicated.(x, ∇)...)
I'm still going to need this simpler: it generates hundreds of thousands of lines of IR. I can look through that, but a smaller example would be much more time-efficient.
> hey @mofeing this generates a ton of code and is going to be rather hard to debug as is. Can you reproduce this in a more minimal example?
The smallest example I can produce just takes two vectors and computes their dot product.
using Enzyme
using OMEinsum
# Minimal reproducer: the dot product of two vectors via OMEinsum.einsum!.
# Both inputs carry label :A and the output has no labels, so the result is
# written into the 0-dim array ssa3.
function f(ssa1::Array{Float64, 1}, ssa2::Array{Float64, 1}, ssa3::Array{Float64, 0})
einsum!(((:A,), (:A,)), (), (ssa1, ssa2), ssa3, true, false, (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2)))
# Unwrap to a scalar so the `Active` return annotation is valid.
return only(ssa3)
end
# Two random length-2 vectors and a 0-dim output array.
x = [rand(2) for _ in 1:2]
y = fill(0.0)
# One zeroed shadow per argument, in argument order.
∇ = [zero.(x)..., zero(y)]
f(x..., y)
autodiff(Reverse, f, Active, Duplicated.([x..., y], ∇)...)
EDIT: Fixed some typos.
would you be able to inline the definitions/macros/etc from einsum? (and possibly simplify)
Hmm, I managed to skip one layer of OMEinsum and simplify this case to the following:
using Enzyme
using OMEinsum
# Binary contraction rule passed directly to binary_einsum!, skipping the outer
# einsum! layer. Both operands carry the single label 'j' and the output has no
# labels: the dot-product case written into a 0-dim array.
rule = OMEinsum.SimpleBinaryRule{('j',), ('j',), ()}()
# Thin wrapper differentiated by Enzyme. It writes the contraction into ssa3;
# `true, false` are forwarded as binary_einsum!'s two scaling flags (sx, sy in
# custom_einsum! further down).
# It returns binary_einsum!'s return value directly, with no `only` unwrap.
function h(ssa1, ssa2, ssa3)
OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
end
# Inputs: two random length-2 vectors and a 0-dim output array.
x = [rand(2), rand(2)]
y = fill(0.0)
# One zeroed shadow per argument of `h`, in argument order.
∇ = map(zero, [x...,y])
h(x..., y)
# Differentiate `h`, the function defined in this snippet. The `typeof(h)` in
# the resulting stacktrace confirms that `h` is what was actually run.
# `h` returns whatever binary_einsum! returns (presumably the output array, not
# a scalar). Per the maintainer's reply, returning an array under `Active` is
# what triggers Enzyme's mutability error.
autodiff(Reverse, h, Active, Duplicated.([x..., y], ∇)...)
And now I'm getting this error:
Warning: Using fallback BLAS replacements for (["dsymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler [~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59](https://file+.vscode-resource.vscode-cdn.net/Users/mofeing/Developer/k-local-gradient-descent/notebooks/~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59)
Enzyme Mutability Error: Cannot add one in place to immutable value fill(0.0)
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] add_one_in_place
@ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5329 [inlined]
[3] augmented_julia_h_9637wrap
@ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0
[4] macro expansion
@ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5656 [inlined]
[5] enzyme_call
@ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5334 [inlined]
[6] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5223 [inlined]
[7] autodiff
@ ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:235 [inlined]
[8] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(h), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Array{Float64, 0}})
@ Enzyme ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:303
[9] top-level scope
@ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2
So that is a different issue, any luck minimizing the previous error?
But you should probably replace `fill(0)` with something like `zeros(2)` to work around the latter issue for now, if need be.
Ah, no. This is the only simplification I could do. It's all or nothing 🥲.
Also, I don't call `fill(0.0)` inside `h`... I made a small modification (now it's doing an element-wise multiplication of 2 vectors), and it's strange:
using Enzyme
using OMEinsum
# Element-wise product rule: both operands and the output carry the single
# label 'l', so the result is a vector the same length as the inputs.
rule = OMEinsum.SimpleBinaryRule{('l',), ('l',), ('l',)}()
# Thin wrapper that writes the element-wise product into ssa3.
# It returns binary_einsum!'s return value rather than a scalar. Per the
# maintainer's reply, that array return under `Active` is what triggers the
# mutability error.
function h(ssa1, ssa2, ssa3)
OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
end
# Inputs: two random length-2 vectors and a length-2 output vector.
x = [rand(2), rand(2)]
y = zeros(2)
# One zeroed shadow per argument of `h`, in argument order.
∇ = map(zero, [x...,y])
h(x..., y)
# Differentiate `h`, the function defined in this snippet. The `typeof(h)` in
# the resulting stacktrace confirms that `h` is what was actually run.
autodiff(Reverse, h, Active, Duplicated.([x..., y], ∇)...)
which returns
Enzyme Mutability Error: Cannot add one in place to immutable value [0.0, 0.0]
Stacktrace:
[1] error
@ ./error.jl:35
[2] add_one_in_place
@ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5329 [inlined]
[3] augmented_julia_h_13355wrap
@ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0
[4] macro expansion
@ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5656 [inlined]
[5] enzyme_call
@ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5334 [inlined]
[6] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5223 [inlined]
[7] autodiff
@ ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:235 [inlined]
[8] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(h), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}})
@ Enzyme ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:303
[9] top-level scope
@ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2
I think I'm going to need to manually add some rules for Enzyme and TensorOperations.
So the last error implies you are returning an array, not a scalar.
On Tue, May 7, 2024 at 4:15 PM Sergio Sánchez Ramírez < @.***> wrote:
Ah, no. This is the only simplification I could do. It's all or nothing 🥲.
Also, I don't call fill(0.0) inside h... I did a small modification (now it's doing a element-wise multiplication of 2 vectors) and it's strange:
using Enzymeusing OMEinsum
rule = OMEinsum.SimpleBinaryRule{('l',), ('l',), ('l',)}() function h(ssa1, ssa2, ssa3) OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)end
x = [rand(2), rand(2)] y = zeros(2) ∇ = map(zero, [x...,y]) h(x..., y) autodiff(Reverse, f, Active, Duplicated.([x..., y], ∇)...)
which returns
Enzyme Mutability Error: Cannot add one in place to immutable value [0.0, 0.0]
Stacktrace: [1] error @ ./error.jl:35 [2] add_one_in_place @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5329 [inlined] [3] augmented_julia_h_13355wrap @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0 [4] macro expansion @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5656 [inlined] [5] enzyme_call @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5334 [inlined] [6] AugmentedForwardThunk @ ~/.julia/packages/Enzyme/iGAtf/src/compiler.jl:5223 [inlined] [7] autodiff @ ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:235 [inlined] [8] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(h), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}) @ Enzyme ~/.julia/packages/Enzyme/iGAtf/src/Enzyme.jl:303 [9] top-level scope @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2
I think I'm gonna need to add manually some rules for Enzyme and TensorOperations.
— Reply to this email directly, view it on GitHub https://github.com/EnzymeAD/Enzyme.jl/issues/1416#issuecomment-2099224257, or unsubscribe https://github.com/notifications/unsubscribe-auth/AAJTUXCNGEJSHEFHJNJM77TZBEY45AVCNFSM6AAAAABHJ3YSESVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDAOJZGIZDIMRVG4 . You are receiving this because you commented.Message ID: <EnzymeAD/Enzyme. @.***>
Yes, this last example returns an array, while the previous example returns a scalar. But in both cases I get `Enzyme Mutability Error: Cannot add one in place to immutable value`.
You cannot return an array when the return is marked `Active`; you must return a scalar, so you should use `only(...)` in the last program.
Ah okay, I just found that limitation in `autodiff(::Reverse)`. If I add an `only` to `h`,
# Revised wrapper: after the in-place contraction, return only(ssa3) (a scalar)
# so the `Active` return annotation is valid.
# `rule` is the global binary rule defined in an earlier snippet.
function h(ssa1, ssa2, ssa3)
OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
return only(ssa3)
end
# NOTE(review): ssa1, ssa2, ssa3 are not defined in this snippet. They are
# presumably two length-2 vectors and a 0-dim array left over from the session,
# matching the printed ∇ below. TODO confirm.
∇ = map(zero, [ssa1,ssa2,ssa3])
autodiff(Reverse, h, Active, Duplicated.([ssa1,ssa2,ssa3], ∇)...)
Then, I get:
julia> ∇
3-element Vector{Array{Float64}}:
[0.26640132425362695, 0.906387409016582]
[0.6451817492001096, 0.9773755209533468]
fill(0.0)
So I guess the problem lies somewhere between `einsum!` and `binary_einsum!`, which is the code in...
https://github.com/under-Peter/OMEinsum.jl/blob/327cf355c746e9f646c5beee74dcd2c11aa90240/src/einsum.jl#L99-L117
The mutability error is not a significant issue (it usually means you returned an array rather than a scalar, as here). The other issue, with the long trace, is the one I need an MWE for to fix.
@mofeing I merged a jll bump which fixes some phi node issues. Can you see if this persists?
If not, we should close.
nothing seems to have changed 🥲
But I think I might be on the way to having an MWE (this issue might not be the only source of error). I noticed that the `einsum!` function uses `@debug`, and I can get Enzyme to run indefinitely by making it print a variable.
This works correctly:
using OMEinsum
using Enzyme
# Two random length-2 vectors and a 0-dim output array (zeros() with no
# dimensions), bound to the global names ssa1, ssa2, ssa3.
x = [rand(2) for _ in 1:2]
y = zeros()
ssa1, ssa2, ssa3 = x..., y
# Variant of the wrapper with a @debug statement that logs no variables.
# Reported to differentiate correctly.
# NOTE(review): i1, i2 and iyb are not defined inside u. They are presumably
# globals left over from an earlier OMEinsum.analyze_binary call in the session.
# TODO confirm.
function u(ssa1, ssa2, ssa3)
rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()
@debug "asdf"
OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
return only(ssa3)
end
This runs indefinitely:
using OMEinsum
using Enzyme
# Same inputs as the previous snippet: two length-2 vectors and a 0-dim output
# array, bound to ssa1, ssa2, ssa3.
x = [rand(2) for _ in 1:2]
y = zeros()
ssa1, ssa2, ssa3 = x..., y
# Identical to the working variant, except that @debug also logs the `rule`
# value. Reported to make Enzyme run indefinitely.
# NOTE(review): i1, i2 and iyb are presumably session globals, as in the
# previous snippet. TODO confirm.
function u(ssa1, ssa2, ssa3)
rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()
@debug "asdf" rule
OMEinsum.binary_einsum!(rule, ssa1, ssa2, ssa3, true, false)
return only(ssa3)
end
That's a useful issue to minimize and understand, but it is separate from the one you found. If you can minimize either issue individually, we can fix that issue.
I've updated to the latest Enzyme (v0.12.6), and the error seems a little bit different (no more big explosions!). For the function `f` below,
# The dot-product reproducer re-run on Enzyme v0.12.6. The error becomes the
# "Mismatched activity" failure shown below.
# Both inputs carry label :A and the output has none, so the result goes into
# the 0-dim array ssa3.
function f(ssa1::Array{Float64, 1}, ssa2::Array{Float64, 1}, ssa3::Array{Float64, 0})
einsum!(((:A,), (:A,)), (), (ssa1, ssa2), ssa3, true, false, (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2)))
return only(ssa3)
end
the error is the following:
Error on `autodiff(Reverse, f, ...)`
julia> ssa1, ssa2, ssa3 = rand(2), rand(2), zeros()
julia> ∇ = map(zero, [ssa1, ssa2, ssa3])
julia> autodiff(Reverse, f, Active, Duplicated.([ssa1, ssa2, ssa3], ∇)...)
Enzyme execution failed.
Mismatched activity for: store {} addrspace(10)* %.fca.0.0.1.0.0.extract5, {} addrspace(10)* addrspace(10)* %.fca.0.0.1.0.0.gep6, align 8, !dbg !72, !noalias !80 const val: %.fca.0.0.1.0.0.extract5 = extractvalue [2 x [1 x {} addrspace(10)*]] %0, 0, 0, !dbg !72
Type tree: {[-1]:Pointer}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now
Stacktrace:
[1] ntuple
@ ./ntuple.jl:49
[2] copy
@ ./broadcast.jl:1118
[3] materialize
@ ./broadcast.jl:903
[4] einsum!
@ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:100
[5] einsum!
@ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:0
Stacktrace:
[1] throwerr(cstr::Cstring)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:1325
[2] ntuple
@ ./ntuple.jl:49 [inlined]
[3] copy
@ ./broadcast.jl:1118 [inlined]
[4] materialize
@ ./broadcast.jl:903 [inlined]
[5] einsum!
@ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:100 [inlined]
[6] einsum!
@ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:0 [inlined]
[7] augmented_julia_einsum__4300_inner_1wrap
@ ~/.julia/packages/OMEinsum/zZBsQ/src/einsum.jl:0
[8] macro expansion
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5719 [inlined]
[9] enzyme_call
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5385 [inlined]
[10] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5275 [inlined]
[11] runtime_generic_augfwd(activity::Type{Val{(false, false, false, true, true, false, false, false)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true, false, true, true, true)}, RT::Val{@NamedTuple{1, 2, 3}}, f::typeof(einsum!), df::Nothing, primal_1::Tuple{Tuple{Symbol}, Tuple{Symbol}}, shadow_1_1::Nothing, primal_2::Tuple{}, shadow_2_1::Nothing, primal_3::Tuple{Vector{Float64}, Vector{Float64}}, shadow_3_1::Tuple{Vector{Float64}, Vector{Float64}}, primal_4::Array{Float64, 0}, shadow_4_1::Array{Float64, 0}, primal_5::Bool, shadow_5_1::Nothing, primal_6::Bool, shadow_6_1::Nothing, primal_7::Dict{Symbol, Int64}, shadow_7_1::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/rules/jitrules.jl:179
[12] f
@ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:5 [inlined]
[13] diffejulia_f_2479wrap
@ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0
[14] macro expansion
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5719 [inlined]
[15] enzyme_call
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5385 [inlined]
[16] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5264 [inlined]
[17] autodiff
@ ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:291 [inlined]
[18] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(f), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Array{Float64, 0}})
@ Enzyme ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:303
[19] top-level scope
@ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:2
I've taken the source code of the `OMEinsum.einsum!` method that I'm using and removed the `@debug` expressions and an `if` branch that is not taken (only the `else` part is run). The result is the `u` method below:
# A manually inlined copy of the OMEinsum.einsum! path used by `f` in the
# dot-product case. The @debug statements and the untaken `if` branch are
# removed, so only the `else` path remains.
function u(ssa1, ssa2, ssa3)
# Element type for the collected index labels.
LT = Symbol
# Hard-coded labels: both inputs use :A and the output has none (a full
# contraction).
ixs = ((:A,), (:A,))
iy = ()
# Convert the label tuples with OMEinsum's internal helper. Ref(LT) keeps LT
# from being broadcast over.
iyv = OMEinsum._collect(LT,iy)
ix1v,ix2v = OMEinsum._collect.(Ref(LT), ixs)
# The reported activity error is traced into get_size_dict, called here.
size_dict = (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2))
# Plan the binary contraction. cy and s3 are unused on this path.
# Marking analyze_binary inactive reportedly makes u differentiate.
c1, c2, cy, s1, s2, s3, i1, i2, iyb = OMEinsum.analyze_binary(ix1v, ix2v, iyv, size_dict)
# The type parameters come from runtime values, so the concrete rule type is
# presumably only known at runtime.
rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()
xs1 = OMEinsum.simplifyto(ix1v, c1, ssa1, size_dict)
xs2 = OMEinsum.simplifyto(ix2v, c2, ssa2, size_dict)
# Reshape the simplified operands to the sizes planned by analyze_binary.
x1_ = OMEinsum.safe_reshape(xs1, s1)
x2_ = OMEinsum.safe_reshape(xs2, s2)
# NOTE(review): unlike custom_einsum!, ssa3 is passed as-is rather than through
# safe_reshape(ssa3, s3).
OMEinsum.binary_einsum!(rule, x1_, x2_, ssa3, true, false)
return only(ssa3)
end
This `u` method fails similarly to `f`, but not identically (compare, for example, the `Type tree`):
Error on `autodiff(Reverse, u, ...)`
julia> autodiff(Reverse, u, Active, Duplicated.([ssa1, ssa2, ssa3], ∇)...)
Enzyme execution failed.
Mismatched activity for: %unbox24.fca.4.load.pn = phi {} addrspace(10)* [ %unbox24.fca.4.load, %L60 ], [ %unbox31.unpack61, %L64 ] const val: %unbox24.fca.4.load = load {} addrspace(10)*, {} addrspace(10)** %unbox24.fca.4.gep, align 8, !dbg !136
Type tree: {}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now
Stacktrace:
[1] u
@ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:12
Stacktrace:
[1] throwerr(cstr::Cstring)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:1325
[2] iterate
@ ./range.jl:897 [inlined]
[3] copyto!
@ ./abstractarray.jl:942 [inlined]
[4] _collect
@ ./array.jl:696 [inlined]
[5] collect
@ ./array.jl:694 [inlined]
[6] #91
@ ./none:0 [inlined]
[7] iterate
@ ./generator.jl:47 [inlined]
[8] collect
@ ./array.jl:834 [inlined]
[9] get_size_dict!
@ ~/.julia/packages/OMEinsum/zZBsQ/src/interfaces.jl:61 [inlined]
[10] get_size_dict
@ ~/.julia/packages/OMEinsum/zZBsQ/src/interfaces.jl:100 [inlined]
[11] get_size_dict
@ ~/.julia/packages/OMEinsum/zZBsQ/src/interfaces.jl:99 [inlined]
[12] u
@ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:9 [inlined]
[13] diffejulia_u_5824wrap
@ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:0
[14] macro expansion
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5719 [inlined]
[15] enzyme_call
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5385 [inlined]
[16] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5264 [inlined]
[17] autodiff
@ ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:291 [inlined]
[18] autodiff(::ReverseMode{false, FFIABI, false}, ::typeof(u), ::Type{Active}, ::Duplicated{Vector{Float64}}, ::Duplicated{Vector{Float64}}, ::Duplicated{Array{Float64, 0}})
@ Enzyme ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:303
[19] top-level scope
@ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum.ipynb:26
After marking `OMEinsum.analyze_binary` as inactive, `u` works, but `f` still gives the same error:
julia> EnzymeRules.inactive(::typeof(OMEinsum.analyze_binary), args...) = nothing
julia> ∇ = map(zero, [ssa1,ssa2,ssa3])
julia> autodiff(Reverse, u, Active, Duplicated.([ssa1,ssa2,ssa3], ∇)...)
┌ Warning: Using fallback BLAS replacements for (["dsymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59
((nothing, nothing, nothing),)
julia> ∇
3-element Vector{Array{Float64}}:
[0.056936369463587044, 0.29215353698534263]
[0.4591629765253664, 0.672870606539631]
fill(0.0)
This is not working either...
# A copy of OMEinsum's two-operand einsum! path with the `cy != iyv` branch
# commented out, so only the direct binary_einsum! call runs.
# ixs: label tuples of the two inputs; iy: output labels; xs: the two input
# arrays; y: the output array; sx, sy: scaling flags forwarded to
# binary_einsum!; size_dict: label => size mapping with key type LT.
# Returns y.
# NOTE(review): results reach y only if safe_reshape(y, s3) shares memory with
# y. TODO confirm.
function custom_einsum!(ixs, iy, @nospecialize(xs::NTuple{2, Any}), @nospecialize(y), sx, sy, size_dict::Dict{LT}) where LT
iyv = OMEinsum._collect(LT,iy)
ix1v, ix2v = OMEinsum._collect.(Ref(LT), ixs)
x1, x2 = xs
c1, c2, cy, s1, s2, s3, i1, i2, iyb = OMEinsum.analyze_binary(ix1v, ix2v, iyv, size_dict)
rule = OMEinsum.SimpleBinaryRule{(i1...,), (i2...,), (iyb...,)}()
xs1 = OMEinsum.simplifyto(ix1v, c1, x1, size_dict)
xs2 = OMEinsum.simplifyto(ix2v, c2, x2, size_dict)
x1_ = OMEinsum.safe_reshape(xs1, s1)
x2_ = OMEinsum.safe_reshape(xs2, s2)
# The omitted branch, kept for reference: it runs when the planned output
# labels cy differ from the requested ones.
# if cy != iyv
# y_ = similar(y, (s3...,))
# y_ = reshape(OMEinsum.binary_einsum!(rule, x1_, x2_, y_, true, false), [size_dict[x] for x in cy]...)
# return custom_einsum!((cy,), iyv, (y_,), y, sx, sy, size_dict)
# else
OMEinsum.binary_einsum!(rule, x1_, x2_, OMEinsum.safe_reshape(y, s3), sx, sy)
return y
# end
end
# The dot-product reproducer `f`, calling custom_einsum! instead of einsum!.
# It writes into the 0-dim ssa3 and returns its single element.
function custom_f(ssa1::Array{Float64, 1}, ssa2::Array{Float64, 1}, ssa3::Array{Float64, 0})
custom_einsum!(((:A,), (:A,)), (), (ssa1, ssa2), ssa3, true, false, (OMEinsum.get_size_dict)(((:A,), (:A,)), (ssa1, ssa2)))
return only(ssa3)
end
# Self-contained driver for custom_f: two random length-2 vectors and a 0-dim
# output array.
x = [rand(2) for _ in 1:2]
y = zeros()
custom_f(x..., y)
# Differentiate with respect to the arrays defined above. This keeps the
# snippet runnable on its own, with no dependence on ssa1/ssa2/ssa3 globals
# from an earlier cell. Shadows are listed in argument order.
∇ = map(zero, [x..., y])
autodiff(Reverse, custom_f, Active, Duplicated.([x..., y], ∇)...)
Could the problem be in how arguments are passed to `einsum!`? Or maybe the `u` example works because some kind of constant folding/propagation happens before LLVM?
@mofeing with a bunch of fixes now landed, how does this work presently?