Yota.jl
Basic example doesn't work with Pluto notebooks
The basic example fails in a Pluto notebook, whereas it works normally in the REPL.
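Thanks for reporting this! It turns out that in Pluto the anonymous function is a closure that captures + and sum, so it is traced differently than in a normal REPL: the tracer has to record getfield calls to read the captured functions (marked with arrows below). In the REPL: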
julia> Yota.Umlaut.trace(x -> sum(x .+ 1), [1.0, 2.0, 3.0])
(9.0, Tape{BaseCtx}
inp %1::var"#101#102"
inp %2::Vector{Float64}
%3 = broadcasted(+, %2, 1)::Broadcasted{}
%4 = materialize(%3)::Vector{Float64}
%5 = sum(%4)::Float64
)
In Pluto:
(9.0,
Tape{Umlaut.BaseCtx}
inp %1::var"#3#4"{typeof(+), typeof(sum)}
inp %2::Vector{Float64}
%3 = getfield(%1, :sum)::typeof(sum) # <---
%4 = getfield(%1, :+)::typeof(+) # <---
%5 = broadcasted(%4, %2, 1)::Broadcasted{}
%6 = materialize(%5)::Vector{Float64}
%7 = %3(%6)::Float64)
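To see where the getfield calls come from, here is a minimal sketch in plain Julia (no Yota or Pluto required). It assumes Pluto's rewriting produces a closure equivalent to this one, and the helper make_f is hypothetical. A closure stores each captured variable as a field named after it, so reading the captured function means a getfield call:

```julia
# Hypothetical helper: the argument `sum` is captured by the returned closure,
# similar to how the Pluto trace shows `sum` (and `+`) stored inside var"#3#4"
make_f(sum) = x -> sum(x .+ 1)

f = make_f(sum)                          # Base.sum becomes the captured value

println(fieldnames(typeof(f)))           # (:sum,)  captured variables become fields
println(getfield(f, :sum) === Base.sum)  # true
println(f([1.0, 2.0, 3.0]))              # 9.0
```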
One workaround is to define the missing rrule for getfield:
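using ChainRulesCore
# No gradient flows through getfield: fine for captured functions like + and sum, but it would drop gradients of differentiable fields
ChainRulesCore.rrule(::typeof(getfield), x, f::Symbol) = (getfield(x, f), dy -> (NoTangent(), NoTangent(), NoTangent()))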
After this, the example gives the correct result:
grad(x -> sum(x .+ 1), [1.0, 2.0, 3.0])
# ==> (9.0, (NoTangent(), [1.0, 1.0, 1.0]))
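As an independent sanity check (not part of Yota), the gradient can be approximated with central finite differences. Each component should be close to 1, since every x[i] enters sum(x .+ 1) with coefficient 1. The function fd_grad below is a hypothetical helper written only for this check:

```julia
# Central finite differences: g[i] ≈ (f(x + h*e_i) - f(x - h*e_i)) / 2h
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h                  # perturb only the i-th coordinate
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

g = fd_grad(x -> sum(x .+ 1), [1.0, 2.0, 3.0])
println(g)  # approximately [1.0, 1.0, 1.0]
```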