Basic example doesn't work with Pluto notebooks

Open • VarLad opened this issue 2 years ago • 1 comment

(Screenshot of the failing example in Pluto not preserved.)

The same example works normally in the REPL.

VarLad • Jun 27 '22 22:06

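Thanks for reporting this! It turns out Pluto transforms the code so it's traced differently than in a normal REPL: the anonymous function becomes a closure that captures `+` and `sum` as fields, and reading those fields adds `getfield` calls to the tape. In the REPL: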

julia> Yota.Umlaut.trace(x -> sum(x .+ 1), [1.0, 2.0, 3.0])
(9.0, Tape{BaseCtx}
  inp %1::var"#101#102"
  inp %2::Vector{Float64}
  %3 = broadcasted(+, %2, 1)::Broadcasted{}
  %4 = materialize(%3)::Vector{Float64}
  %5 = sum(%4)::Float64
)

In Pluto:

(9.0,
Tape{Umlaut.BaseCtx}
  inp %1::var"#3#4"{typeof(+), typeof(sum)}
  inp %2::Vector{Float64}
  %3 = getfield(%1, :sum)::typeof(sum)    # <---
  %4 = getfield(%1, :+)::typeof(+)        # <---
  %5 = broadcasted(%4, %2, 1)::Broadcasted{}
  %6 = materialize(%5)::Vector{Float64}
  %7 = %3(%6)::Float64
)
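The extra `getfield` calls come from how Julia represents closures: variables an anonymous function captures from an enclosing local scope are stored as fields of the closure object, and the function body reads them back with `getfield`. A minimal plain-Julia sketch of that mechanism, assuming a hypothetical helper `make_closure` (not part of Yota or Pluto) whose locals `op` and `reducer` play the role of `+` and `sum` in the Pluto tape:

```julia
# Hypothetical helper: `op` and `reducer` are local variables here, so the
# anonymous function captures them and stores them as fields of its closure type.
function make_closure(op, reducer)
    return x -> reducer(op.(x, 1))
end

f = make_closure(+, sum)

# The captured functions are ordinary fields of the closure object...
println(fieldnames(typeof(f)))

# ...and the closure body reads them with `getfield`, which a tracer records
# the same way as `getfield(%1, :sum)` and `getfield(%1, :+)` in the Pluto tape.
println(getfield(f, :reducer) === sum)

# Calling it still computes sum([1.0, 2.0, 3.0] .+ 1).
println(f([1.0, 2.0, 3.0]))
```

In the REPL version, `+` and `sum` are globals, so the closure has no fields and the tape calls them directly.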

One workaround is to define the missing rrule for `getfield`:

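using ChainRulesCore

# Field access is treated as non-differentiable: no tangent for `getfield`, the object, or the field name.
# Fine for closures whose fields are functions like `+` and `sum`; gradients into data-holding fields are dropped.
ChainRulesCore.rrule(::typeof(getfield), x, f::Symbol) =
    (getfield(x, f), dy -> (NoTangent(), NoTangent(), NoTangent()))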

After this, the example gives the correct result:

grad(x -> sum(x .+ 1), [1.0, 2.0, 3.0])
# ==> (9.0, (NoTangent(), [1.0, 1.0, 1.0]))

dfdx • Jun 29 '22 21:06