Yota.jl
list of operations that grad does not work with
Here is a list of some operations that did not work for me. I wonder about the errors that mention ChainRules in their message. For instance, in the sum example, I guess we are tracing too deep into the sum implementation, even though a higher-level sum rule exists: https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl#L9
@dfdx ~~Maybe the trace used for gradtape should have an is_primitive that checks if the signature is covered by an rrule?~~ I realize that Yota.is_primitive !== Ghost.is_primitive and there is already such a rule. I think the issue is what should happen when one starts tracing with a call that is already primitive. It is not obvious what the best design is. Currently, such a call is entered anyway, which is why e.g. sum([1.0]) fails.
################################################################################
Yota.gradtape(sum, [1.0])
fails
No deriative rule found for op %42 = mapreduce(identity, add_sum, %2)::Float64, try defining it using ChainRules.rrule(::typeof(mapreduce), ::typeof(identity), ::typeof(Base.add_sum), ::Vector{Float64}) = ...
################################################################################
Yota.gradtape(sum, abs2, [1.0])
fails
No deriative rule found for op %30 = mapreduce(%2, add_sum, %3)::Float64, try defining it using ChainRules.rrule(::typeof(mapreduce), ::typeof(abs2), ::typeof(Base.add_sum), ::Vector{Float64}) = ...
################################################################################
Yota.gradtape(identity, 1.0)
fails
MethodError: no method matching call_signature(::Tape{Yota.GradCtx}, ::Ghost.Input)
Closest candidates are:
call_signature(::Tape, ::Ghost.Call) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:516
################################################################################
Yota.gradtape(sin, 1.0)
fails
MethodError: Cannot `convert` an object of type Float64 to an object of type Ghost.Variable
Closest candidates are:
convert(::Type{T}, ::T) where T at essentials.jl:205
Ghost.Variable(::Any, ::Any) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:22
################################################################################
Yota.gradtape(*, 1.0)
fails
MethodError: no method matching call_signature(::Tape{Yota.GradCtx}, ::Ghost.Input)
Closest candidates are:
call_signature(::Tape, ::Ghost.Call) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:516
################################################################################
Yota.gradtape(*, 1.0, 2.0)
fails
No deriative rule found for op %4 = mul_float(%2, %3)::Float64, try defining it using ChainRules.rrule(::Core.IntrinsicFunction, ::Float64, ::Float64) = ...
You are absolutely right - there's no way to represent a single primitive f(args...) as a tape, at least as a tape different from the one for args -> f(args...). I see several options here:
- Leave it as is, letting people trace the code of primitives even if sometimes it will confuse them.
- Forbid tracing the primitives. But what if it is just what somebody wanted to do?
- Show a warning explaining that it's probably not what a user wants, but let them do it anyway.
I lean towards the last option, since it's unlikely that somebody will trace primitives outside the REPL, and warnings in the REPL are usually fine. I will let this idea mature for the next couple of days, though.
I think that the tape args -> f(args...) option sounds natural. One way to think about a tape is that it is a list of all primitive calls that occur. If the entry point was already primitive, then it is just this one primitive call.
I also expected that if you have function f(x); g(x) end then trace(f,x) and trace(g,x) would be the same. This again would be consistent with tracing a primitive call returning the tape with just that call. What drawbacks do you see with this?
One could also add to the list: 1b. Throw an error by default, but allow disabling it with a keyword that permits tracing into a primitive, as happens currently. I generally favor an error that must be explicitly disabled over a warning.
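To make the two suggestions above concrete, here is a minimal toy sketch in plain Julia. It is not how Yota or Ghost trace code: every name in it (ToyCall, Tracked, toy_trace, toy_is_primitive, allow_primitive) is hypothetical, and it records calls by operator overloading instead of IR tracing. Because this toy cannot descend into a primitive at all, the keyword here only switches between "throw" and "record the single call", which is a simplification of option 1b.

```julia
# Toy model only: none of these names are Yota.jl or Ghost.jl APIs.

# One recorded primitive call: the function and the plain values passed to it.
struct ToyCall
    fn::Function
    args::Tuple
end

# The tape being recorded; a plain global keeps the sketch short.
const TOY_TAPE = ToyCall[]

# A number wrapper whose primitive operations record themselves on the tape.
struct Tracked
    val::Float64
end

unwrap(x::Tracked) = x.val
unwrap(x::Real) = x

# Record `fn(args...)` on the tape and return the wrapped result.
function record(fn, args...)
    vals = map(unwrap, args)
    push!(TOY_TAPE, ToyCall(fn, vals))
    return Tracked(fn(vals...))
end

# The toy's primitives: calls on `Tracked` values are recorded, never entered.
Base.:*(a::Tracked, b::Tracked) = record(*, a, b)
Base.:*(a::Real, b::Tracked) = record(*, a, b)
Base.:*(a::Tracked, b::Real) = record(*, a, b)
Base.sin(a::Tracked) = record(sin, a)

toy_is_primitive(f) = f === (*) || f === sin

# Trace `f(args...)`. A primitive entry point throws unless explicitly allowed,
# in which case the tape consists of just that one call.
function toy_trace(f, args::Float64...; allow_primitive::Bool=false)
    if toy_is_primitive(f) && !allow_primitive
        error("$f is a primitive; pass allow_primitive=true to trace it anyway")
    end
    empty!(TOY_TAPE)
    f(map(Tracked, args)...)
    return copy(TOY_TAPE)
end

g(x) = 2x
f(x) = g(x)

tape_f = toy_trace(f, 1.0)
tape_g = toy_trace(g, 1.0)
tape_sin = toy_trace(sin, 1.0; allow_primitive=true)
```

In this toy, tape_f and tape_g both hold the single call *(2, 1.0), and tape_sin holds the single call sin(1.0), which is the consistency property described above. Note that the sketch sidesteps the question of what the first tape input should be, since it does not record inputs at all.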
I also expected that if you have function f(x); g(x) end then trace(f,x) and trace(g,x) would be the same.
trace() already works like this:
julia> g(x) = 2x
g (generic function with 1 method)
julia> f(x) = g(x)
f (generic function with 1 method)
julia> trace(f, 1.0)[2]
Tape{Dict{Any, Any}}
inp %1::typeof(f)
inp %2::Float64
%3 = *(2, %2)::Float64
julia> trace(g, 1.0)[2]
Tape{Dict{Any, Any}}
inp %1::typeof(g)
inp %2::Float64
%3 = *(2, %2)::Float64
grad() behaves similarly, with the exception of caching.
One way to think about a tape is that it is a list of all primitive calls that occur.
The first input to a tape is usually the object being called. In the case of args -> f(args...) this object is an anonymous function, which is fine. In the case of a primitive, it's unclear what we should put there instead.
The most straightforward way is to wrap the primitive into an anonymous function, but that will break the assumption that tape[V(1)].fn == f, which may be useful for introspection and downstream transformations. It will also break on closures/callable structs.
The same applies to skipping the first argument altogether.
Putting the primitive itself as the first input also sounds weird - it will look like a recursive function which it's not.
On the other hand, tracing a primitive function doesn't seem to be a big use case, so raising an error or warning sounds like a reasonable solution to me, at least until we hit a real case where it's not enough.
Perhaps another entry for the list: it seems that grad currently does not work with LinearAlgebra.Adjoint:
julia> A = rand(100, 100)
julia> x = rand(100)
julia> Yota.grad(x -> 0.5 * x' * A * x, x)
ERROR: LoadError: No deriative rule found for op %8 = *(0.5, %7, %2)::Float64, try defining it using
ChainRulesCore.rrule(::typeof(*), ::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}}, ::Vector{Float64}) = ...
Thanks! I posted the corresponding issue in JuliaDiff/ChainRules.jl#589 as other AD systems may benefit from such a rule too.
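For reference, the math that a rule for the signature in the error message would need can be sketched as a plain function. This is not an actual ChainRulesCore.rrule method and not code from ChainRules.jl; the name scaled_rowcol is hypothetical, and the sketch assumes real inputs, so that r * x equals dot(parent(r), x).

```julia
using LinearAlgebra

# Hypothetical helper, not part of Yota or ChainRules: value and pullback of
# y = a * (r * x), where r is an adjoint (row) vector and x a column vector.
function scaled_rowcol(a::Float64,
                       r::Adjoint{Float64,Vector{Float64}},
                       x::Vector{Float64})
    v = parent(r)      # the column vector underlying the adjoint
    s = dot(v, x)      # equals r * x for real inputs
    y = a * s
    # Given the output cotangent dy, return the cotangents of (a, r, x).
    pullback(dy) = (dy * s, (dy * a * x)', dy * a * v)
    return y, pullback
end

a = 0.5
v = [1.0, 2.0, 3.0]
x = [4.0, 5.0, 6.0]
y, pb = scaled_rowcol(a, v', x)
da, dr, dx = pb(1.0)
```

An actual rule would additionally return a tangent for the function itself (NoTangent() in ChainRulesCore) and would be defined in the package that owns the types, to avoid type piracy.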