Tullio.jl
Use nograd keyword for functions too
julia> using Tullio, Tracker, ForwardDiff
julia> plus10(x) = x + 10; # a function unknown to DiffRules
julia> Tracker.gradient(rand(3)) do x
    @tullio res := x[k] * plus10(k) verbose=true
end
┌ Warning: symbolic gradient failed
│ err = "no diffrule found for function plus10(_)."
└ @ Tullio ~/.julia/dev/Tullio/src/macro.jl:1260
ERROR: no gradient definition here!
julia> Tracker.gradient(rand(3)) do x
    @tullio res := x[k] * plus10(k) grad=Dual # old solution: fall back to ForwardDiff dual numbers
end
([11.0, 12.0, 13.0] (tracked),)
julia> Tracker.gradient(rand(3)) do x
    @tullio res := x[k] * plus10(k) nograd=plus10 # new solution: don't differentiate through plus10
end
([11.0, 12.0, 13.0] (tracked),)
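For reference, the gradient both working calls return can be checked without Tullio. The contraction is res = Σ_k x[k] * plus10(k), and plus10 only ever sees the index k, never x, so ∂res/∂x[k] = plus10(k) = k + 10, i.e. [11, 12, 13]. This is presumably the situation `nograd=plus10` is meant for: no derivative of plus10 is ever needed. The sketch below is plain Julia with no packages; `contract`, `grad_contract` and `fd_grad` are hypothetical helper names for illustration, not part of Tullio.

```julia
# Hand-written equivalent of `@tullio res := x[k] * plus10(k)` (hypothetical helpers)
plus10(x) = x + 10

# res = Σ_k x[k] * plus10(k); plus10 is applied to the index k, not to x
contract(x) = sum(x[k] * plus10(k) for k in eachindex(x))

# Since plus10(k) does not depend on x, ∂res/∂x[k] is just plus10(k)
grad_contract(x) = [plus10(k) for k in eachindex(x)]

# Central finite-difference gradient, as an independent check
function fd_grad(f, x; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

x = rand(3)
grad_contract(x)   # [11, 12, 13]
fd_grad(contract, x)
```

Because res is linear in x, the central difference agrees with the analytic gradient up to floating-point rounding.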
This still needs a test, but #57 is currently reorganising the test suite.