Yota.jl

List of operations that grad does not work with

Open jw3126 opened this issue 3 years ago • 5 comments

Here is a list of some operations that did not work for me. I am curious about the errors that mention ChainRules in their message. For instance, in the sum example I guess we are tracing too deep into the sum implementation; a higher-level sum rule already exists: https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl#L9
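To illustrate what such a higher-level rule looks like, here is a minimal sketch in plain Julia that mirrors the ChainRules `rrule` convention (a primal value plus a pullback) without depending on ChainRulesCore. The names `sum_rrule` and `sum_abs2_rrule` are hypothetical, `nothing` stands in for ChainRules' `NoTangent()`, and the real rules are the ones in the linked file, not this code.

```julia
# Hypothetical rrule-style sketches for `sum` (not ChainRules' actual code).
# Each returns (primal, pullback); the pullback maps the output cotangent `dy`
# to a tuple with one entry per argument, the first being the function itself.

# y = sum(x): every element contributes with weight 1, so dx = dy .* ones.
function sum_rrule(x::AbstractArray{<:Real})
    y = sum(x)
    sum_pullback(dy) = (nothing, fill(dy, size(x)))  # slots: (sum, x)
    return y, sum_pullback
end

# y = sum(abs2, x) = Σ xᵢ², so ∂y/∂xᵢ = 2xᵢ.
function sum_abs2_rrule(x::AbstractArray{<:Real})
    y = sum(abs2, x)
    sum_abs2_pullback(dy) = (nothing, nothing, 2 .* dy .* x)  # slots: (sum, abs2, x)
    return y, sum_abs2_pullback
end

y, back = sum_rrule([1.0, 2.0, 3.0])
println(y, " ", back(1.0)[2])
```

With a rule at this level, a tracer that stops at `sum` never has to descend into `mapreduce(identity, add_sum, ...)`.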

@dfdx ~~Maybe the trace used for gradtape should have an is_primitive that checks if the signature is covered by an rrule?~~ I realize now that Yota.is_primitive !== Ghost.is_primitive and that such a check already exists. I think the real issue is what should happen when tracing starts with a call that is already primitive, and it's not obvious what the best design is. Currently such a call is entered anyway, which is why, e.g., sum([1.0]) fails.

################################################################################
Yota.gradtape(sum, [1.0])
fails
No deriative rule found for op %42 = mapreduce(identity, add_sum, %2)::Float64, try defining it using ChainRules.rrule(::typeof(mapreduce), ::typeof(identity), ::typeof(Base.add_sum), ::Vector{Float64}) = ...
################################################################################
Yota.gradtape(sum, abs2, [1.0])
fails
No deriative rule found for op %30 = mapreduce(%2, add_sum, %3)::Float64, try defining it using ChainRules.rrule(::typeof(mapreduce), ::typeof(abs2), ::typeof(Base.add_sum), ::Vector{Float64}) = ...
################################################################################
Yota.gradtape(identity, 1.0)
fails
MethodError: no method matching call_signature(::Tape{Yota.GradCtx}, ::Ghost.Input)
Closest candidates are:
  call_signature(::Tape, ::Ghost.Call) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:516
################################################################################
Yota.gradtape(sin, 1.0)
fails
MethodError: Cannot `convert` an object of type Float64 to an object of type Ghost.Variable
Closest candidates are:
  convert(::Type{T}, ::T) where T at essentials.jl:205
  Ghost.Variable(::Any, ::Any) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:22
################################################################################
Yota.gradtape(*, 1.0)
fails
MethodError: no method matching call_signature(::Tape{Yota.GradCtx}, ::Ghost.Input)
Closest candidates are:
  call_signature(::Tape, ::Ghost.Call) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:516
################################################################################
Yota.gradtape(*, 1.0, 2.0)
fails
No deriative rule found for op %4 = mul_float(%2, %3)::Float64, try defining it using ChainRules.rrule(::Core.IntrinsicFunction, ::Float64, ::Float64) = ...

jw3126, Jul 01 '21 07:07

You are absolutely right: there's no way to represent a single primitive call f(args...) as a tape, at least not as a tape different from the one for args -> f(args...). I see several options here:

  1. Leave it as is, letting people trace the code of primitives even if it sometimes confuses them.
  2. Forbid tracing primitives. But what if that is exactly what somebody wants to do?
  3. Show a warning explaining that it's probably not what the user wants, but let them do it anyway.

I lean towards the last option, since it's unlikely that somebody will trace primitives outside the REPL, and warnings in the REPL are usually fine. I will let this idea mature for the next couple of days, though.

dfdx, Jul 03 '21 20:07

I think the tape args -> f(args...) option sounds natural. One way to think about a tape is as a list of all primitive calls that occur. If the entry point is already primitive, then the tape is just that one primitive call. I also expected that if you have function f(x); g(x) end, then trace(f,x) and trace(g,x) would be the same. That, too, would be consistent with tracing a primitive call returning a tape with just that call. What drawbacks do you see with this?

One could also add to the list: 1b. Throw an error by default, but let the error be disabled with a keyword that allows tracing into a primitive as it works currently. In general I favor an error that must be explicitly disabled over a warning.
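The options discussed so far (1, 1b, 3, and a one-call tape for a primitive entry point) can be sketched as a single keyword on a toy tracer entry point. This is a minimal illustration in plain Julia, not Yota or Ghost API: `TOY_PRIMITIVES`, `ToyCall`, and `entry_policy` are hypothetical, and the set of primitives stands in for a real `is_primitive` check against available rrules. The function only decides what to do at the entry point; it does not trace anything.

```julia
# Hypothetical sketch; nothing here is Yota/Ghost API.

# Stand-in for "functions covered by an rrule".
const TOY_PRIMITIVES = Set{Any}([sin, cos, sum, +, *])
is_toy_primitive(f) = f in TOY_PRIMITIVES

# One recorded call on a toy tape.
struct ToyCall
    fn::Any
    args::Tuple
end

# on_primitive selects the behaviour when the entry point f is already primitive:
#   :error        option 1b, refuse unless the caller opts in explicitly
#   :warn         option 3, warn and then trace into the implementation
#   :single_call  the tape is just the single primitive call
#   :trace_into   option 1, silently trace into the implementation
# Returns (:trace, nothing) if the tracer should go on recording f's body,
# or (:done, tape) if the tape is already complete.
function entry_policy(f, args...; on_primitive::Symbol = :error)
    is_toy_primitive(f) || return (:trace, nothing)
    if on_primitive === :error
        throw(ArgumentError("$f is already a primitive; pass " *
                            "on_primitive = :trace_into to trace its implementation"))
    elseif on_primitive === :warn
        @warn "$f is already a primitive; tracing into its implementation anyway"
        return (:trace, nothing)
    elseif on_primitive === :single_call
        return (:done, [ToyCall(f, args)])
    elseif on_primitive === :trace_into
        return (:trace, nothing)
    else
        throw(ArgumentError("unknown on_primitive value: $(repr(on_primitive))"))
    end
end
```

The `:error` default follows the preference for an error that must be explicitly disabled; making `:warn` the default would give option 3 instead.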

jw3126, Jul 03 '21 20:07

I also expected that if you have function f(x); g(x) end then trace(f,x) and trace(g,x) would be the same.

trace() already works like this:

julia> g(x) = 2x
g (generic function with 1 method)

julia> f(x) = g(x)
f (generic function with 1 method)

julia> trace(f, 1.0)[2]
Tape{Dict{Any, Any}}
  inp %1::typeof(f)
  inp %2::Float64
  %3 = *(2, %2)::Float64

julia> trace(g, 1.0)[2]
Tape{Dict{Any, Any}}
  inp %1::typeof(g)
  inp %2::Float64
  %3 = *(2, %2)::Float64

grad() behaves similarly, except for caching.

One way to think about a tape is that it is a list of all primitive calls that occur.

The first input to a tape is usually the object being called. In the case of args -> f(args...), this object is an anonymous function, which is fine. In the case of a primitive, it's unclear what we should put there instead.

The most straightforward way is to wrap the primitive in an anonymous function, but that would break the assumption tape[V(1)].fn == f, which may be useful for introspection and downstream transformations. It would also break on closures and callable structs.

The same applies to skipping the first argument altogether.

Putting the primitive itself as the first input also sounds weird: it would look like a recursive function, which it is not.
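The identity problem with wrapping can be seen in plain Julia without any tracer. In this sketch, `Scale` and `wrap` are hypothetical: `wrap` builds the `args -> f(args...)` wrapper, and it is the wrapper, not the callable struct, that would end up as the tape's first input.

```julia
# A callable struct: the field s.a is part of the function being called.
struct Scale
    a::Float64
end
(s::Scale)(x) = s.a * x

# The "wrap the primitive in an anonymous function" option.
wrap(f) = (args...) -> f(args...)

s = Scale(3.0)
w = wrap(s)

# Both compute the same value...
println(w(2.0) == s(2.0))
# ...but the object being called differs, so a check like tape[V(1)].fn == f
# would see `w` instead of `s`, and code that expects a Scale as the first
# input (e.g. to inspect or differentiate s.a) would not find one.
println(w === s, " ", w isa Scale)
```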

On the other hand, tracing a primitive function doesn't seem to be a big use case, so raising an error or a warning sounds like a reasonable solution to me, at least until we hit a real case where it's not enough.

dfdx, Jul 04 '21 14:07

Perhaps another entry for the list: it seems that grad currently does not work with LinearAlgebra.Adjoint.

julia> A = rand(100, 100)
julia> x = rand(100)
julia> Yota.grad(x -> 0.5 * x' * A * x, x)
ERROR: LoadError: No deriative rule found for op %8 = *(0.5, %7, %2)::Float64, try defining it using 

	ChainRulesCore.rrule(::typeof(*), ::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}}, ::Vector{Float64}) = ...
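The missing derivative can be written by hand. For y = a * (u * v) with a scalar a, an adjoint row vector u = p' and a column vector v, the partials are ∂y/∂a = u * v, ∂y/∂u = a * v' and ∂y/∂v = a * p. The sketch below packages them in the rrule shape. `mul3_rrule` is a hypothetical name; it returns `nothing` instead of `NoTangent()` so that it runs without ChainRulesCore, handles real inputs only, and is not ChainRules' own code.

```julia
using LinearAlgebra

# Hypothetical rrule-style rule for *(a::Real, u::Adjoint row vector, v::Vector),
# the call that failed above (%8 = *(0.5, %7, %2)). Real inputs only.
function mul3_rrule(a::Real, u::Adjoint{<:Real,<:AbstractVector}, v::AbstractVector{<:Real})
    s = u * v                          # row vector times column vector: a scalar
    y = a * s
    function mul3_pullback(dy)
        da = dy * s                    # ∂y/∂a = u * v
        du = (dy * a) * v'             # ∂y/∂u = a * v', kept as an Adjoint like u
        dv = (dy * a) * parent(u)      # ∂y/∂v = a * p, where u = p'
        return (nothing, da, du, dv)   # slots: (*, a, u, v)
    end
    return y, mul3_pullback
end

# Usage: compare the first component of dv with a central finite difference.
a, p, v = 0.5, [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
y, back = mul3_rrule(a, p', v)
_, da, du, dv = back(1.0)
h = 1e-6
e1 = [1.0, 0.0, 0.0]
fd = (a * (p' * (v + h * e1)) - a * (p' * (v - h * e1))) / (2h)
println(isapprox(fd, dv[1]; atol = 1e-6))
```

For the full expression in the report, the gradient with respect to x should equal 0.5 * (A + A') * x, which gives a second check once a rule is in place.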

lassepe, Feb 15 '22 10:02

Thanks! I opened the corresponding issue, JuliaDiff/ChainRules.jl#589, since other AD systems may benefit from such a rule too.

dfdx, Feb 15 '22 22:02