ChainRulesCore.jl
Possible to write rules for methods not collections of methods?
I'm not advocating for anything here. I'm just stating some facts, and wish to ascertain whether a particular design choice is technically feasible or not.
AD Effectively Operates on Individual Methods
First, recall that AD operates on the level of methods -- (in the absence of a generically-typed rule) AD does not know anything about the semantics of a function, it just sees a collection of bits of code.
For example, running Zygote on
my_sum(x::AbstractMatrix) = sum(x)
will produce something equivalent to
function Zygote._pullback(ctx::AContext, ::typeof(my_sum), x::AbstractMatrix)
    y, sum_pullback = Zygote._pullback(ctx, sum, x)
    function my_sum_pullback(dy)
        _, dx = sum_pullback(dy)
        return nothing, dx
    end
    return y, my_sum_pullback
end
If I now add another method
my_sum(x::Diagonal) = sum(diag(x))
Zygote will automatically specialise and produce something like
function Zygote._pullback(ctx::AContext, ::typeof(my_sum), x::Diagonal)
    tmp, diag_pullback = Zygote._pullback(ctx, diag, x)
    y, sum_pullback = Zygote._pullback(ctx, sum, tmp)
    function my_sum_pullback(dy)
        _, dtmp = sum_pullback(dy)
        _, dx = diag_pullback(dtmp)
        return nothing, dx
    end
    return y, my_sum_pullback
end
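To make the composed pullback concrete without depending on Zygote internals, here is a minimal hand-written sketch. `toy_pullback` is a hypothetical stand-in for `Zygote._pullback` (it is not part of any package), and the rules for `sum` and `diag` are written out by hand purely for illustration.

```julia
using LinearAlgebra

my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))

# Hypothetical stand-ins for Zygote._pullback, written by hand for two primitives.
toy_pullback(::typeof(sum), x) = sum(x), dy -> (nothing, fill(dy, size(x)))
toy_pullback(::typeof(diag), x::Diagonal) = diag(x), dd -> (nothing, Diagonal(dd))

# The composition that would be generated for the Diagonal method of my_sum.
function toy_pullback(::typeof(my_sum), x::Diagonal)
    tmp, diag_pullback = toy_pullback(diag, x)
    y, sum_pullback = toy_pullback(sum, tmp)
    function my_sum_pullback(dy)
        _, dtmp = sum_pullback(dy)
        _, dx = diag_pullback(dtmp)
        return nothing, dx
    end
    return y, my_sum_pullback
end

D = Diagonal([1.0, 2.0, 3.0])
y, back = toy_pullback(my_sum, D)
_, dx = back(1.0)
println(y)
println(dx)
```

The point of the sketch is that the cotangent for the Diagonal method stays structured (a `Diagonal`), which is exactly what would be lost if a generic AbstractMatrix rule were used instead.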
While Zygote uses Julia's multiple dispatch system to achieve this behaviour via a single loosely-typed generated function, it produces different outputs depending upon the method of a function hit by the types of the arguments, rather than simply the type of the arguments. It's able to do this because generated functions have access to the IR associated with a particular method.
ChainRules Operates More Generically
This is well understood, but worth pointing out. As implemented in all existing AD systems which support them, a rule applies to every method of a function that the rule's signature covers. So in the my_sum examples above, if I were to define
function ChainRulesCore.rrule(::typeof(my_sum), x::AbstractMatrix)
# some code
end
it will apply to both methods, blocking codegen for the more specialised method.
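The mismatch can be seen with plain Base reflection. This is a minimal sketch: `toy_rrule` is a hypothetical stand-in for `ChainRulesCore.rrule` (so that nothing beyond the standard library is needed), with a catch-all fallback that mimics "no rule".

```julia
using LinearAlgebra

my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))

# Hypothetical stand-in for ChainRulesCore.rrule.
toy_rrule(::Any, ::Any...) = nothing  # fallback meaning "no rule defined"
function toy_rrule(::typeof(my_sum), x::AbstractMatrix)
    my_sum_pullback(dy) = (nothing, fill(dy, size(x)))
    return my_sum(x), my_sum_pullback
end

D = Diagonal([1.0, 2.0, 3.0])

# The primal call dispatches to the specialised Diagonal method...
primal = which(my_sum, Tuple{typeof(D)})
# ...but rule lookup by argument types still lands on the AbstractMatrix rule.
rule = which(toy_rrule, Tuple{typeof(my_sum), typeof(D)})

println(primal.sig)
println(rule.sig)
```

Because the rule is selected purely by dispatch on the argument types, it has no way of knowing that a more specific primal method exists.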
Would it be possible to make rules apply to methods also?
To take Zygote as a concrete example, would it be technically feasible to make Zygote treat rules as being equivalent to its own codegen-ed code, so that the rrule above is only hit when the my_sum(::AbstractMatrix) method is hit, while codegen proceeds as usual for the my_sum(::Diagonal) method?
Specifically
# hits rrule because my_sum(::AbstractMatrix) is most specialised method applicable
# to Matrix{Float64}.
Zygote.pullback(my_sum, randn(5, 5))
# does not hit rrule because my_sum(::Diagonal) applies to Diagonal{Float64, Vector{Float64}}.
Zygote.pullback(my_sum, Diagonal(randn(5)))
I do not believe it is possible, in the language that is Julia v1.7
I do not believe it is possible, in the language that is Julia v1.7
Is this to say that it would have been possible in 1.6, but is no longer in 1.7, or that it was never possible?
Also, why do you believe it not to be possible? Is it that you're not sure how it can be done, or do you have a particular reason to think that it cannot?
Related: https://github.com/JuliaDiff/ChainRules.jl/issues/237
Is this to say that it would have been possible in 1.6, but is no longer in 1.7, or that it was never possible?
I mean to say that it has never been possible in any Julia version up to and including 1.7, but I can't yet say how the language might change in 1.8.
Also, why do you believe it not to be possible? Is it that you're not sure how it can be done, or do you have a particular reason to think that it cannot?
It is that I don't see how it can be done. Thus the "I do not believe".
Possibly something can be done in Zygote by doing the source code transform at a different point in the compilation pipeline from where it is done now. Right now it is too early: lowered IR is before types are known, so it can't work out which method is being hit anywhere in the first place. And I suspect typed IR is too late: by that stage it isn't working with Methods, but with MethodInstances (specialized on concrete types).
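For reference, a minimal sketch of the views being contrasted here, using only Base reflection (`which`, `code_lowered`, `code_typed`). `my_sum` is the example from the top of the thread; nothing here is Zygote-specific, and it only illustrates the distinction rather than where Zygote hooks in.

```julia
using LinearAlgebra

my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))

argtypes = Tuple{Diagonal{Float64, Vector{Float64}}}

# A Method is identified by its declared signature, before any specialisation.
m = which(my_sum, argtypes)
println(m.sig)

# Lowered IR for the matching method: no inferred types are attached yet.
lowered = first(code_lowered(my_sum, argtypes))

# Typed IR: inference has run for these concrete argument types, so each
# result is a `CodeInfo => return type` pair.
typed, rettype = first(code_typed(my_sum, argtypes))
println(rettype)
```

The Method's signature mentions only `Diagonal`, whereas the typed result is tied to the concrete `Diagonal{Float64, Vector{Float64}}` it was inferred for.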
I mean to say that it has never been possible in any Julia version up to and including 1.7, but I can't yet say how the language might change in 1.8.
Cool.
It is that I don't see how it can be done. Thus the "I do not believe".
Also cool.
Right now it is too early: lowered IR is before types are known, so it can't work out which method is being hit anywhere in the first place.
Agreed regarding the IR, but Methods have the types in their signature, do they not?
I've just had a play around, and I think something like this might do it. Not sure if I'm really allowed to use which inside a generated function though... possibly I need some backedges.
using ChainRules
# A specialised rule. Without this, a rule isn't hit for sin(5.0). With this, a rule is hit.
# ChainRules.@non_differentiable sin(::Float64)
@generated function has_a_rule(f, args...)
    # Inside a generated function, f and args are the types of the arguments.
    T = Tuple{f, args...}
    # Find the primal method which would be hit by the types provided.
    primal_method = which(T)
    # Find the rrule method which would be hit by the types provided.
    rrule_method = which(ChainRules.rrule, T)
    # Obtain the signature of the rrule method without reference to the rrule function itself.
    rrule_sig = Tuple{rrule_method.sig.parameters[2:end]...}
    # Find the method of the original function that the rrule's signature would hit, if any.
    rule_primal_method = try
        which(rrule_sig)
    catch
        nothing
    end
    # Check to see if they're the same method.
    use_chain_rule = rule_primal_method !== nothing && primal_method === rule_primal_method
    return use_chain_rule ? :true : :false
end
has_a_rule(sin, 5.0)
edit: per the example above:
using LinearAlgebra
my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))
ChainRules.@non_differentiable my_sum(::AbstractMatrix)
has_a_rule(my_sum, randn(5, 5)) # returns true
has_a_rule(my_sum, Diagonal(randn(5))) # returns false
Not sure if I'm really allowed to use which inside a generated function though... possibly I need some backedges.
Pretty sure you are not. Zygote basically does, and does put in the backedges. But note how very unreliable it is at picking up new and updated methods after calls.
Not ideal, but at least there's a precedent 😂
In some sense what we have right now is kinda like this but the complement.
If one sticks to a policy where, any time you implement a method, you either implement an rrule or @opt_out of an rrule with the same signature, then you basically get this.
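A minimal sketch of that policy for the my_sum example, assuming ChainRulesCore 1.x is installed. The rule body is illustrative only; the `@opt_out` line is the part that hands the Diagonal method back to the AD system.

```julia
using ChainRulesCore
using LinearAlgebra

my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))

# Rule written for the generic method.
function ChainRulesCore.rrule(::typeof(my_sum), x::AbstractMatrix)
    my_sum_pullback(dy) = (NoTangent(), fill(dy, size(x)))
    return my_sum(x), my_sum_pullback
end

# Opt out with the same signature as the specialised method, so that an AD
# system which respects opt-outs differentiates that method's own code instead.
@opt_out ChainRulesCore.rrule(::typeof(my_sum), ::Diagonal)

dense_result = rrule(my_sum, ones(2, 2))
diag_result = rrule(my_sum, Diagonal(ones(2)))
println(dense_result === nothing)
println(diag_result === nothing)
```

With the opt-out in place, `rrule` returns `nothing` for the Diagonal case, which is the signal that no rule applies there.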
Note: @simeonschaub is essentially trying to do this in https://github.com/FluxML/Zygote.jl/pull/909