
No deriative rule found for struct constructor

mcabbott opened this issue 2 years ago · 7 comments

I'm surprised by this error, which, if I understand right, comes from this line https://github.com/FluxML/Optimisers.jl/blob/master/src/destructure.jl#L31, constructing a struct which isn't in fact used. Is this the desired behaviour, or can all (default?) constructors be handled automatically somehow?

julia> using Yota, Optimisers, ChainRulesCore

julia> function gradient(f, xs...)
         println("Yota gradient!")
         _, g = Yota.grad(f, xs...)
         g[2:end]
       end;

julia> m1 = collect(1:3.0);

julia> gradient(m -> destructure(m)[1][1], m1)[1]
Yota gradient!
ERROR: No deriative rule found for op %30 = Optimisers.Restructure(%2, %19, %28)::Optimisers.Restructure{Vector{Float64}, Int64}, try defining it using 

	ChainRulesCore.rrule(::UnionAll, ::Vector{Float64}, ::Int64, ::Int64) = ...

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
    @ Yota ~/.julia/dev/Yota/src/grad.jl:170
  [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
    @ Yota ~/.julia/dev/Yota/src/grad.jl:211

julia> function ChainRulesCore.rrule(T::Type{<:Optimisers.Restructure}, v, i, j)
         back(tan) = (NoTangent(), tan.model, tan.offsets, tan.length)
         back(z::AbstractZero) = (z,z,z,z)
         T(v,i,j), back
       end

julia> gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
Yota gradient!
true

Xref https://github.com/FluxML/Optimisers.jl/issues/96

mcabbott avatar Jul 07 '22 13:07 mcabbott

It's a bug. Yota thinks Optimisers.Restructure(...) is a primitive and records it to the tape instead of tracing it down to (:new, T, args...) (represented as __new__(T, args...) on the tape). Yota thinks it's a primitive because typeof(Restructure) returns UnionAll, which belongs to module Core, and we don't trace deeper than that. I have an idea of how to fix it, will try to implement it tonight.
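
For illustration, a quick check of the UnionAll point (not Yota's actual dispatch code, just what the reflection functions report):

using Optimisers

# The unparameterized name `Optimisers.Restructure` is itself a UnionAll,
# and UnionAll is defined in Core, which Yota does not trace into:
typeof(Optimisers.Restructure)   # UnionAll
parentmodule(UnionAll)           # Core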

dfdx avatar Jul 07 '22 14:07 dfdx

Just to note, I was also seeing this in the latest release. But when I cloned the repository in order to add some @show statements for debugging, the problem went away (well, actually it was replaced by a different bug). @mcabbott are you seeing the same behavior in master?

cscherrer avatar Jul 07 '22 15:07 cscherrer

This seems to work on main:

using Yota, Optimisers, ChainRulesCore


function gradient(f, xs...)
    println("Yota gradient!")
    _, g = Yota.grad(f, xs...)
    g[2:end]
end;


# a quick hack, not really tested
function ChainRulesCore.rrule(::typeof(convert), ::DataType, x)
    # a more robust implementation would be to do backward conversion:
    #   return x, Δ -> (NoTangent(), NoTangent(), convert(typeof(x), Δ))
    # but it doesn't work for ZeroTangent(), so passing Δ as is
    return x, Δ -> (NoTangent(), NoTangent(), Δ)
end


m1 = collect(1:3.0);
gradient(m -> destructure(m)[1][1], m1)[1]

dfdx avatar Jul 09 '22 21:07 dfdx

I get the same error on latest Yota + Umlaut.

The rule for convert silences the error, but doesn't actually construct the requested struct. It isn't used above, but if I change the code to something which does use that part, it fails:

julia> gradient((m,v) -> destructure(m)[2](v)[1], m1, [1,2,3.0])
Yota gradient!
ERROR: MethodError: no method matching _rebuild(::Vector{Float64}, ::Int64, ::ZeroTangent, ::Int64; walk::typeof(Optimisers._Tangent_biwalk), prune::NoTangent)

Closest candidates are:
  _rebuild(::Any, ::Any, ::AbstractVector, ::Any; walk, kw...)
   @ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:82
  _rebuild(::Any, ::Any, ::AbstractVector) got unsupported keyword arguments "walk", "prune"
   @ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:82

Stacktrace:
  [1] (::Optimisers.var"#_flatten_back#18"{Vector{Float64}, Int64, Int64})(::Tangent{Tuple{Vector{Float64}, Int64, Int64}, Tuple{ZeroTangent, NoTangent, NoTangent}})
    @ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:77
  [2] mkcall(fn::Umlaut.Variable, args::Umlaut.Variable; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
  [3] mkcall
    @ ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179 [inlined]
  [4] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
    @ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:164

cf Zygote:

julia> gradient((m,v) -> destructure(m)[2](v)[1], m1, [1,2,3.0])
(nothing, [1.0, 0.0, 0.0])

mcabbott avatar Jul 10 '22 17:07 mcabbott

It looks like a different error: tracing constructors of UnionAll types, which caused the previous error, works correctly this time:

julia> trace((m,v) -> destructure(m)[2](v)[1], m1, [1,2,3.0]; ctx=GradCtx())
(1.0, Tape{GradCtx}
  inp %1::var"#212#213"
  inp %2::Vector{Float64}
  inp %3::Vector{Float64}
  %5, %6 = [%4] = rrule(YotaRuleConfig(), _flatten, %2)
  %8, %9 = [%7] = rrule(YotaRuleConfig(), indexed_iterate, %5, 1)
  %11, %12 = [%10] = rrule(YotaRuleConfig(), getfield, %8, 1)
  %14, %15 = [%13] = rrule(YotaRuleConfig(), getfield, %8, 2)
  %17, %18 = [%16] = rrule(YotaRuleConfig(), indexed_iterate, %5, 2, %14)
  %20, %21 = [%19] = rrule(YotaRuleConfig(), getfield, %17, 1)
  %23, %24 = [%22] = rrule(YotaRuleConfig(), getfield, %17, 2)
  %26, %27 = [%25] = rrule(YotaRuleConfig(), indexed_iterate, %5, 3, %23)
  %29, %30 = [%28] = rrule(YotaRuleConfig(), getfield, %26, 1)
  %32, %33 = [%31] = rrule(YotaRuleConfig(), apply_type, Optimisers.Restructure, Vector{Float64}, Int64)
  %35, %36 = [%34] = rrule(YotaRuleConfig(), apply_type, Optimisers.Restructure, Vector{Float64}, Int64)
  %38, %39 = [%37] = rrule(YotaRuleConfig(), convert, Vector{Float64}, %2)
  %41, %42 = [%40] = rrule(YotaRuleConfig(), convert, Int64, %20)
  %44, %45 = [%43] = rrule(YotaRuleConfig(), fieldtype, %35, 3)
  %47, %48 = [%46] = rrule(YotaRuleConfig(), convert, %44, %29)
  %50, %51 = [%49] = rrule(YotaRuleConfig(), __new__, %35, %38, %41, %47)    # <-- this is the internal constructor of Restructure
  %53, %54 = [%52] = rrule(YotaRuleConfig(), tuple, %11, %50)
  %56, %57 = [%55] = rrule(YotaRuleConfig(), getindex, %53, 2)
  %59, %60 = [%58] = rrule(YotaRuleConfig(), getproperty, %56, model)
  %62, %63 = [%61] = rrule(YotaRuleConfig(), getproperty, %56, offsets)
  %65, %66 = [%64] = rrule(YotaRuleConfig(), getproperty, %56, length)
  %68, %69 = [%67] = rrule(YotaRuleConfig(), _rebuild, %59, %62, %3, %65)
  %71, %72 = [%70] = rrule(YotaRuleConfig(), getindex, %68, 1)
)

_rebuild() doesn't accept ZeroTangent() as a cotangent value. The real question is whether ZeroTangent() is correct here, and if so, why Zygote doesn't hit the same problem. I will need to understand more about Optimisers internals and the generated graph to answer these questions.
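
For reference, one possible workaround in hand-written pullbacks is to short-circuit on AbstractZero before calling anything that expects a concrete array. A minimal sketch, not Optimisers' actual code:

using ChainRulesCore

# Hypothetical guard: if the incoming cotangent is an AbstractZero, return
# zero tangents directly instead of forwarding it to a rebuild function
# that only accepts an AbstractVector (as _rebuild does).
function guarded_flatten_back(rebuild, dm)
    dm isa AbstractZero && return (NoTangent(), ZeroTangent())
    return (NoTangent(), rebuild(dm))
end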

dfdx avatar Jul 10 '22 21:07 dfdx

It's not impossible there are bugs in _rebuild; sorry, it's pretty messy. Will take a look at some point.

Trying to find a simpler example of what I thought was the original problem, with a struct from ChainRules' test helpers:

julia> using Yota, ChainRulesCore

julia> struct Multiplier{T}  # from test_helpers in ChainRules
           x::T
       end

julia> (m::Multiplier)(y) = m.x * y

julia> function ChainRulesCore.rrule(m::Multiplier, y)
           Multiplier_pullback(dΩ) = (Tangent{typeof(m)}(; x = dΩ * y'), m.x' * dΩ)
           return m(y), Multiplier_pullback
       end

julia> grad(x -> x(3.0), Multiplier(5.0))  # perfect
(15.0, (ZeroTangent(), Tangent{Multiplier{Float64}}(x = 3.0,)))

julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: No deriative rule found for op %3 = Multiplier(%2)::Multiplier{Float64}, try defining it using 

	ChainRulesCore.rrule(::UnionAll, ::Float64) = ...

Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
   @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:170
 [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
   @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:211
...

julia> Yota.trace(x -> Multiplier(x)(3.0), 5.0; ctx=Yota.GradCtx())
(15.0, Tape{Yota.GradCtx}
  inp %1::var"#8#9"
  inp %2::Float64
  %3 = Multiplier(%2)::Multiplier{Float64}
  %5, %6 = [%4] = rrule(Yota.YotaRuleConfig(), %3, 3.0)
)

(jl_lUY4C1) pkg> st Yota
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_lUY4C1/Project.toml`
  [cd998857] Yota v0.7.3

That's the tagged version. On master:

(jl_lpE53K) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_lpE53K/Project.toml`
  [92992a2b] Umlaut v0.2.5 `https://github.com/dfdx/Umlaut.jl.git#main`
  [cd998857] Yota v0.7.4 `https://github.com/dfdx/Yota.jl.git#main`

julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: No deriative rule found for op %9 = convert(%7, %2)::Float64, try defining it using 

	ChainRulesCore.rrule(::typeof(convert), ::DataType, ::Float64) = ...

Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
   @ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:170
 [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
   @ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:211
 [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
   @ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:222
...

julia> Yota.trace(x -> Multiplier(x)(3.0), 5.0; ctx=Yota.GradCtx())
(15.0, Tape{Yota.GradCtx}
  inp %1::var"#14#15"
  inp %2::Float64
  %4, %5 = [%3] = rrule(Yota.YotaRuleConfig(), apply_type, Multiplier, Float64)
  %7, %8 = [%6] = rrule(Yota.YotaRuleConfig(), fieldtype, %4, 1)
  %9 = convert(%7, %2)::Float64
  %11, %12 = [%10] = rrule(Yota.YotaRuleConfig(), __new__, %4, %9)
  %14, %15 = [%13] = rrule(Yota.YotaRuleConfig(), %11, 3.0)
)

julia> function ChainRulesCore.rrule(::typeof(convert), ::DataType, x)
           # Version with re-conversion via ProjectTo? Maybe this is only right for number types...
           # Also, why not convert on the forward pass?
           return x, Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(Δ))
       end

julia> grad(x -> Multiplier(x)(3.0), 5.0)
(15.0, (ZeroTangent(), 3.0))

That's why your rule targets convert, since apply_type is now handled. But why is convert being called at all?
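
One way to see where that convert comes from is to inspect the lowered code of the default inner constructor (the exact printed IR varies by Julia version):

using InteractiveUtils   # for @code_lowered outside the REPL

struct Multiplier{T}
    x::T
end

# The lowered default constructor contains roughly a fieldtype / convert / %new
# sequence, matching the fieldtype / convert / __new__ calls in the trace above.
@code_lowered Multiplier{Float64}(5.0)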

mcabbott avatar Jul 12 '22 17:07 mcabbott

Sorry for the silence - I've been working on some bug fixes and improvements that may affect this question too. In particular, I added lineinfo to call nodes, and here's what it shows:

using ChainRulesCore
using Yota: GradCtx      # explicit imports, assuming the same setup as the snippets above
using Umlaut: trace

struct Multiplier{T}  # from test_helpers in ChainRules
    x::T
end

(m::Multiplier)(y) = m.x * y


function ChainRulesCore.rrule(m::Multiplier, y)
    Multiplier_pullback(dΩ) = (Tangent{typeof(m)}(; x = dΩ * y'), m.x' * dΩ)
    return m(y), Multiplier_pullback
end

mult1(x) = x(3.0)
mult2(x) = Multiplier(x)(3.0)

_, tape = trace(mult2, 5.0; ctx=GradCtx())

Result:

(15.0, Tape{GradCtx}
  inp %1::typeof(mult2)
  inp %2::Float64
  %4, %5 = [%3] = rrule(YotaRuleConfig(), apply_type, Multiplier, Float64)              # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
  %7, %8 = [%6] = rrule(YotaRuleConfig(), apply_type, Multiplier, Float64)              # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
  %9 = convert(Float64, %2)::Float64            # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
  %11, %12 = [%10] = rrule(YotaRuleConfig(), __new__, %7, %9)           # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
  %14, %15 = [%13] = rrule(YotaRuleConfig(), %11, 3.0)          # Main.mult2 at /home/azbs/work/Yota/src/_main3.jl:37
)

So convert(Float64, %2) happens in the object constructor, even though %2 is already a Float64. I tried to trick the compiler into not adding the convert(), but it seems to be just an inherent detail of the lowered code. Nevertheless, we can now simplify the rrule for convert to a stricter version:

function ChainRulesCore.rrule(::typeof(convert), ::Type{T}, x::T) where T
    return x, Δ -> (NoTangent(), NoTangent(), Δ)
end
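
As a quick sanity check, the stricter rule should still cover the Multiplier example, since the convert recorded on that tape is Float64 to Float64 (expected result as in the earlier run):

grad(x -> Multiplier(x)(3.0), 5.0)   # (15.0, (ZeroTangent(), 3.0))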

dfdx avatar Aug 08 '22 12:08 dfdx

Both the top example and the Multiplier one work on Yota 0.8 and Julia 1.8, which is great.

On Julia nightly, something seems to go wrong, perhaps of interest (and might be why I saw errors in https://github.com/FluxML/Optimisers.jl/pull/105):

julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: Unexpected expression: $(Expr(:static_parameter, 1))
Full IRCode:

 2 1 ─ %4 = $(Expr(:static_parameter, 1))::Core.Const(Float64)
  │   %1 = Core.apply_type(Main.Multiplier, %4)::Core.Const(Multiplier{Float64})
  │   %2 = (%1)(_2)::Multiplier{Float64}
  └──      return %2
  
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] trace_block!(t::Umlaut.Tracer{Yota.GradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:333
  [3] trace!(t::Umlaut.Tracer{Yota.GradCtx}, v_fargs::Tuple{UnionAll, Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:439
  [4] trace_call!(::Umlaut.Tracer{Yota.GradCtx}, ::Type, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:290
  [5] trace_block!(t::Umlaut.Tracer{Yota.GradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:315
  [6] trace!(t::Umlaut.Tracer{Yota.GradCtx}, v_fargs::Vector{Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:439
  [7] trace(f::Function, args::Float64; ctx::Yota.GradCtx, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:556
  [8] gradtape(f::Function, args::Float64; ctx::Yota.GradCtx, seed::Int64)
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:291

mcabbott avatar Aug 27 '22 04:08 mcabbott

Apparently, Julia 1.9 changes the way static parameters (i.e. type parameters, {T}) are used in IRCode. You can update Umlaut to 0.4.5 to account for this.

(No changes to the version of Yota itself are needed at the moment)
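
For anyone hitting the same error, updating should be enough (assuming the fix landed in Umlaut 0.4.5 as noted above):

pkg> add Umlaut@0.4.5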

dfdx avatar Aug 27 '22 13:08 dfdx

Great, thanks. Then I'll mark this as closed.

mcabbott avatar Aug 27 '22 15:08 mcabbott