`\sqrt` is not correctly detected for gradient

Open • roflmaostc opened this issue 4 years ago • 1 comment

Hey,

just a minor issue here:

```julia
julia> using Zygote, Tullio

julia> f(x) = @tullio r = √(x[i]^2)
f (generic function with 1 method)

julia> h(x) = @tullio r = sqrt(x[i]^2)
h (generic function with 1 method)

julia> x = randn((5));

julia> f(x) ≈ h(x)
true

julia> gradient(f, x)
ERROR: no gradient definition here!
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] (::Tullio.var"#215#216"{Tullio.Eval{var"#ℳ𝒶𝓀ℯ#8"{var"#𝒜𝒸𝓉!#7"}, Nothing}, Tuple{Vector{Float64}}, Float64})(Δ::Float64)
   @ Tullio ~/.julia/packages/Tullio/IHd6P/src/grad/zygote.jl:7
 [3] (::Tullio.var"#74#back#217"{Tullio.var"#215#216"{Tullio.Eval{var"#ℳ𝒶𝓀ℯ#8"{var"#𝒜𝒸𝓉!#7"}, Nothing}, Tuple{Vector{Float64}}, Float64}})(Δ::Float64)
   @ Tullio ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] Pullback
   @ ./REPL[9]:1 [inlined]
 [5] (::typeof(∂(f)))(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [6] (::Zygote.var"#41#42"{typeof(∂(f))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [7] gradient(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [8] top-level scope
   @ REPL[11]:1

julia> gradient(h, x)
([-1.0, 1.0, 1.0, 1.0, 1.0],)
```

Thanks,

Felix

roflmaostc, May 20 '21 20:05

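That's an interesting quirk. The derivative rules being used come from DiffRules.jl, which ForwardDiff.jl uses to define methods. In doing so, ForwardDiff evaluates the Symbols, so the spelling doesn't matter there: a method defined for `sqrt` is also a method of `√`. Tullio, however, just sees what you write, so it looks for a rule under the Symbol `:√`, doesn't find one, and fails: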

```julia
julia> (√) === sqrt
true

julia> :√ == :sqrt
false
```
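To make the distinction concrete, here is a minimal sketch. It assumes a toy rule table keyed on Symbols, loosely in the spirit of DiffRules.jl but not its actual API; `TOY_RULES`, `derivative_by_symbol` and `derivative_by_function` are hypothetical names. A purely syntactic lookup misses `:√`, while resolving the Symbol to a function object first (roughly what evaluating the Symbols amounts to) treats `√` and `sqrt` alike.

```julia
# Hypothetical toy rule table keyed on function *names* (Symbols).
# Loosely modelled on the idea behind DiffRules.jl; this is NOT its API.
const TOY_RULES = Dict{Symbol,Function}(
    :sqrt => x -> inv(2 * sqrt(x)),  # d/dx sqrt(x) = 1 / (2 sqrt(x))
    :exp  => x -> exp(x),            # d/dx exp(x)  = exp(x)
)

# Syntactic lookup, like matching the expression exactly as written:
# succeeds only if the Symbol is literally a key of the table.
derivative_by_symbol(f::Symbol) = get(TOY_RULES, f, nothing)

# Semantic lookup: resolve the Symbol to the function it names in Base,
# then compare function objects, so aliases such as √ === sqrt match.
# (Assumes `f` names something reachable as Base.f; otherwise this throws.)
function derivative_by_function(f::Symbol)
    fun = getproperty(Base, f)
    for (name, rule) in TOY_RULES
        getproperty(Base, name) === fun && return rule
    end
    return nothing
end
```

With this, `derivative_by_symbol(:√)` returns `nothing`, whereas `derivative_by_function(:√)(4.0)` returns `0.25`, the same as the rule for `sqrt`.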

It would be easy to add special cases of course, but I wonder if there's a good way to standardise them?
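One possible way to standardise, sketched under the assumption that rules stay keyed on Symbols: rewrite known aliases to a single canonical name before the rule lookup. `ALIASES` and `canonicalise` are hypothetical helpers, not part of Tullio.jl.

```julia
# Hypothetical alias table: alternative spelling => canonical name.
# Only √ is listed here; other aliases would be added the same way.
const ALIASES = Dict{Symbol,Symbol}(:√ => :sqrt)

# Rename aliased Symbols, recursing through nested expressions.
# Note: this renames every matching Symbol, not only those in call position.
canonicalise(s::Symbol) = get(ALIASES, s, s)
canonicalise(ex::Expr) = Expr(ex.head, map(canonicalise, ex.args)...)
canonicalise(x) = x  # numbers and other leaves pass through unchanged
```

For example, `canonicalise(:(r = √(x[i]^2)))` returns `:(r = sqrt(x[i]^2))`, which a Symbol-keyed rule table would then recognise. The trade-off is that the alias list has to be maintained by hand.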

mcabbott, May 20 '21 21:05