Tullio.jl
`√` (`\sqrt`) is not detected when taking the gradient
Hey,
just a minor issue here:
```julia
julia> using Zygote, Tullio
julia> f(x) = @tullio r = √(x[i]^2)
f (generic function with 1 method)
julia> h(x) = @tullio r = sqrt(x[i]^2)
h (generic function with 1 method)
julia> x = randn((5));
julia> f(x) ≈ h(x)
true
julia> gradient(f, x)
ERROR: no gradient definition here!
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] (::Tullio.var"#215#216"{Tullio.Eval{var"#ℳ𝒶𝓀ℯ#8"{var"#𝒜𝒸𝓉!#7"}, Nothing}, Tuple{Vector{Float64}}, Float64})(Δ::Float64)
   @ Tullio ~/.julia/packages/Tullio/IHd6P/src/grad/zygote.jl:7
 [3] (::Tullio.var"#74#back#217"{Tullio.var"#215#216"{Tullio.Eval{var"#ℳ𝒶𝓀ℯ#8"{var"#𝒜𝒸𝓉!#7"}, Nothing}, Tuple{Vector{Float64}}, Float64}})(Δ::Float64)
   @ Tullio ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] Pullback
   @ ./REPL[9]:1 [inlined]
 [5] (::typeof(∂(f)))(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [6] (::Zygote.var"#41#42"{typeof(∂(f))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [7] gradient(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [8] top-level scope
   @ REPL[11]:1
julia> gradient(h, x)
([-1.0, 1.0, 1.0, 1.0, 1.0],)
```
Thanks,
Felix
That's an interesting quirk. The derivative rules being used come from DiffRules.jl, which ForwardDiff.jl uses to define methods, and hence ForwardDiff evaluates the Symbols. But Tullio only sees what you write, so it doesn't find a rule for `:√`, even though the two names refer to the same function:
```julia
julia> (√) === sqrt
true
julia> :√ == :sqrt
false
```
It would be easy to add special cases, of course, but I wonder if there's a good way to standardise them.
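One possible way to standardise the names, sketched here under assumptions: before looking up a rule, resolve each call head against `Base` and replace it with the function's `nameof`, since `√` is bound to `sqrt` and so `nameof(√) === :sqrt`. The helpers `canonical_name` and `normalise_calls` are hypothetical and are not part of Tullio.jl's or DiffRules.jl's API; where such a pass would hook into Tullio's macro is not shown.

```julia
# Hypothetical helpers (not part of Tullio.jl or DiffRules.jl): rewrite operator
# aliases such as √ to the canonical function name before looking up a rule.

# Map a call-head Symbol to the name of the Base function it is bound to.
# Assumes the symbol refers to Base's binding; a user-defined √ would be misread.
function canonical_name(sym::Symbol)
    if isdefined(Base, sym)
        f = getfield(Base, sym)
        # √ is the same object as sqrt, so nameof(√) returns :sqrt
        f isa Function && return nameof(f)
    end
    return sym  # not a Base function: leave the name untouched
end

# Leaves (Symbols, numbers, ...) are returned unchanged.
normalise_calls(ex) = ex

# Recursively walk an expression, renaming the function in every :call node.
function normalise_calls(ex::Expr)
    args = Any[normalise_calls(a) for a in ex.args]
    if ex.head === :call && !isempty(args) && args[1] isa Symbol
        args[1] = canonical_name(args[1])
    end
    return Expr(ex.head, args...)
end

# Gives the same Expr as :(sqrt(x[i]^2))
normalise_calls(:(√(x[i]^2)))
```

Resolving through `nameof` rather than a hand-written table would also cover other aliases, e.g. `∛` (bound to `cbrt`), without further special cases; the trade-off is that it trusts each name to mean the `Base` binding.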