TaylorDiff.jl
Zygote compatibility does not work for Julia 1.10+
I raised this as an issue in the Zygote.jl repository, but it may belong here:
Zygote fails to use the `rrule`s defined by TaylorDiff when run with Julia 1.10+.
In Julia 1.9.4:
import Zygote
import TaylorDiff
TaylorDiff.derivative(x -> sum(x .^ 2), [1.0, 2.0, 3.0], [0.0, 1.0, 0.0], :1) # works
Zygote.withgradient([1.0, 2.0, 3.0]) do x
TaylorDiff.derivative(x -> sum(x .^ 2), x, [0.0, 1.0, 0.0], :1)
end # works, returning (val = 4.0, grad = ([0.0, 2.0, 0.0],))
In Julia 1.10+:
import Zygote
import TaylorDiff
TaylorDiff.derivative(x -> sum(x .^ 2), [1.0, 2.0, 3.0], [0.0, 1.0, 0.0], :1) # works
Zygote.withgradient([1.0, 2.0, 3.0]) do x
TaylorDiff.derivative(x -> sum(x .^ 2), x, [0.0, 1.0, 0.0], :1)
end # doesn't work
The `Zygote.withgradient` call (the last statement) fails with the following error:
ERROR: Need an adjoint for constructor TaylorDiff.TaylorScalar{Float64, 2}. Gradient is of type TaylorDiff.TaylorScalar{Float64, 2}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] (::Zygote.Jnew{TaylorDiff.TaylorScalar{Float64, 2}, Nothing, false})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:330
[3] (::Zygote.var"#2210#back#313"{Zygote.Jnew{…}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[4] TaylorScalar
@ ~/.julia/packages/TaylorDiff/zNnz2/src/scalar.jl:17 [inlined]
[5] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[6] TaylorScalar
@ ~/.julia/packages/TaylorDiff/zNnz2/src/scalar.jl:22 [inlined]
[7] macro expansion
@ ~/.julia/packages/TaylorDiff/zNnz2/src/primitive.jl:143 [inlined]
[8] ^
@ ~/.julia/packages/TaylorDiff/zNnz2/src/primitive.jl:128 [inlined]
[9] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[10] literal_pow
@ ./intfuncs.jl:351 [inlined]
[11] (::Zygote.var"#1368#1374")(::Tuple{…}, ȳ₁::TaylorDiff.TaylorScalar{…})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/broadcast.jl:218
[12] #4
@ ./generator.jl:36 [inlined]
[13] iterate(g::Base.Generator, s::Vararg{Any})
@ Base ./generator.jl:47 [inlined]
[14] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{…}}, Base.var"#4#5"{Zygote.var"#1368#1374"}})
@ Base ./array.jl:834
[15] map
@ ./abstractarray.jl:3406 [inlined]
[16] (::Zygote.var"#∇broadcasted#1373"{…})(ȳ::FillArrays.Fill{…})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/broadcast.jl:218
[17] #4117#back
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
[18] #291
@ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
[19] #2169#back
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
[20] broadcasted
@ ./broadcast.jl:1347 [inlined]
[21] #8
@ ./REPL[4]:2 [inlined]
[22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[23] derivative
@ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:37 [inlined]
[24] derivative
@ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:23 [inlined]
[25] #7
@ ./REPL[4]:2 [inlined]
[26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[27] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
[28] withgradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:0
[29] top-level scope
@ REPL[4]:1