OMEinsum.jl

Customize Enzyme rules
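The following is an example of an unsuccessful attempt: the gradient of the return value is not propagated backward properly. In the reverse rule, `dret` arrives as a type rather than a value, so accessing `dret.dval` fails with the error shown at the end of the session.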
julia> using Enzyme, Enzyme.EnzymeRules, OMEinsum
julia> function EnzymeRules.augmented_primal(
                config::EnzymeRules.ConfigWidth{1},
                func::Const{typeof(einsum)}, ::Type{<:Duplicated}, 
                code::Const, xs::Duplicated, size_dict)
           # Custom forward (augmented primal) rule for `einsum`, restricted to
           # batch width 1 (`ConfigWidth{1}`) and a `Duplicated` return.
           # Returns the primal result, the shadow of the return value and a
           # tape for the reverse rule, bundled in `EnzymeRules.AugmentedReturn`.
           # NOTE(review): `size_dict` carries no annotation type in the
           # signature but is read through `.val` below, so it presumably
           # arrives wrapped by Enzyme (e.g. as `Const`) — confirm.
           @info("In custom augmented primal rule.")
           # Compute primal
           if EnzymeRules.needs_primal(config)
               # Primal value: einsum(code, xs, size_dict) on unwrapped arguments.
               primal = func.val(code.val, xs.val, size_dict.val); 
                        # Zero-initialized shadow of the return value.
                        # NOTE(review): this shadow is not stored in `tape`, so
                        # the reverse rule has no handle on the output gradient
                        # accumulated in it; Enzyme's custom-rule docs suggest
                        # saving it in the tape for `Duplicated` returns — confirm.
                        shadow=zero(primal)
           else
               # NOTE(review): the shadow is only allocated when the primal is
               # requested; `EnzymeRules.needs_shadow(config)` may be the more
               # appropriate check for the shadow — confirm.
               primal, shadow = nothing, nothing
           end
           # Save x in tape if x will be overwritten
           # Logs the per-argument overwritten flags (four were printed in the
           # session output).
           @info EnzymeRules.overwritten(config)
           # Index 3 presumably corresponds to `xs` (the flags appear to follow
           # the order func, code, xs, size_dict) — confirm.
           if EnzymeRules.overwritten(config)[3]
               # NOTE(review): `xs.val` is a tuple of arrays in the example call;
               # confirm that `copy` on it really snapshots the arrays —
               # `map(copy, xs.val)` would copy each array explicitly.
               tape = copy(xs.val)
           else
               tape = nothing
           end
           return EnzymeRules.AugmentedReturn(primal, shadow, tape)
       end
julia> function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1},
               func::Const{typeof(einsum)}, dret::Type{<:Duplicated}, tape,
               code::Const,
               xs::Duplicated, size_dict)
   # Custom reverse rule for `einsum`: accumulates the gradient of every input
   # array into the matching entry of `xs.dval`.
   # Why this attempt fails: `dret` is annotated `Type{<:Duplicated}`, i.e. it
   # is a type rather than a value, so `dret.dval` in the loop below raises
   # `type DataType has no field dval` (the error at the end of the session).
   # NOTE(review): the output gradient presumably has to be read from the
   # return shadow created in the augmented primal rule (e.g. handed over via
   # `tape`) — confirm against the EnzymeRules documentation.
   @info """In custom reverse rule: $config.
I was expecting `drect` to be an object rather than a type!!!!"""
   # Use the saved copy when `xs` was flagged as overwritten; otherwise the
   # current input values.
   xval = EnzymeRules.overwritten(config)[3] ? tape : xs.val
   # For each input i, add (in place) the gradient computed by `einsum_grad`
   # from the contraction's input/output labels and the conjugated output
   # gradient.
   for i=1:length(xs.val)
       xs.dval[i] .+= OMEinsum.einsum_grad(OMEinsum.getixs(code.val),
             xval, OMEinsum.getiy(code.val), size_dict.val, conj(dret.dval), i)
   end
   # NOTE(review): returns an empty tuple; confirm this matches what
   # `EnzymeRules.reverse` is expected to return for these arguments.
   return ()
end
julia> x = randn(3, 3);  # random 3×3 input matrix
julia> gx = zero(x);  # gradient buffer for x; the reverse rule adds into it with `.+=`
julia> autodiff(ReverseWithPrimal, x->sum(einsum(ein"ii->i", x, Dict('i'=>3))),  # sum of the diagonal
                                     Duplicated((x,), (gx,)))  # closure's `x` is the tuple `(x,)`, shadow `(gx,)`
[ Info: In custom augmented primal rule.
[ Info: (false, false, false, true)
[ Info: In custom reverse rule: ConfigWidth{1, true, true, (false, 
              false, false, true)}(). 
    I was expecting `drect` to be an object rather than a type!!!!
ERROR: type DataType has no field dval