ChainRulesCore.jl
Ability to identify rules that always return AbstractZero
This is related to `@non_differentiable`.
I think that for operator-overloading-based AD, if a rule's propagator is always going to return an `AbstractZero`, the correct thing to do is quite different:
one wants to accept the overloaded type but return a non-overloaded type.
I think this sounds reasonable, but I'm having trouble saying for sure without a concrete example. Any chance that you could concoct one?
Consider `size`.
The code that Nabla would generate right now from our `@non_differentiable size(::AbstractArray)` is:
```julia
function Base.size(x1::Node{<:AbstractArray{<:Any, N}}; kwargs...) where N
    (primal_val, pullback) = rrule(size, unbox(x1); kwargs...)
    t = tape(x1)  # don't shadow the `tape` function with a local of the same name
    branch = Branch(primal_val, size, (x1,), values(kwargs), t, length(t) + 1, pullback)
    push!(t, branch)
    return branch  # type is <:Node{NTuple{N, Int}}
end

@inline function preprocess(
    ::typeof(size), y::Branch, ȳ, x1::Union{Any, Node{<:Any}}
)
    # `pullback` is the one stored on the Branch `y` during the forward pass;
    # for `size` it always returns `(NO_FIELDS, DoesNotExist())`.
    return pullback(ȳ)
end

@inline function ∇(
    ::typeof(size), ::Type{Arg{N}}, p, ::Any, ::Any, x1::Union{Any, Node{<:Any}};
    kwargs...
) where N
    # p[1] is the cotangent for the function itself (dself), which we skip because
    # functors aren't supported, so argument N is found at p[N + 1].
    return p[N + 1]
end
```
But what we really want to do is:
```julia
function Base.size(x1::Node{<:AbstractArray{<:Any, N}}; kwargs...) where N
    return size(unbox(x1))  # type is NTuple{N, Int}; no Branch, no tape entry
end
```
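A self-contained toy can make the difference concrete. The `Tracked` type, `unbox`, `size_pullback` and `size_recorded` below are hypothetical stand-ins, not Nabla's real `Node`/`Branch` machinery. `size_recorded` mimics the current strategy: it pushes a tape entry and returns another tracked value. The `Base.size` overload mimics the desired strategy: it returns a plain tuple and leaves the tape alone.

```julia
# Hypothetical toy tracer (NOT Nabla's Node/Branch): a value plus the tape it lives on.
struct Tracked{T}
    val::T
    tape::Vector{Any}   # each entry records (f, primal output, pullback)
end

unbox(x::Tracked) = x.val

# Toy pullback for size: the cotangents it returns carry no information.
size_pullback(ȳ) = (:no_fields, :does_not_exist)

# Current strategy: record the call on the tape and return another tracked value.
function size_recorded(x::Tracked{<:AbstractArray})
    y = size(unbox(x))
    push!(x.tape, (size, y, size_pullback))
    return Tracked(y, x.tape)
end

# Desired strategy: accept the tracked type, return a plain tuple, leave the tape alone.
Base.size(x::Tracked{<:AbstractArray}) = size(unbox(x))

t = Any[]
x = Tracked(rand(2, 3), t)
s_old = size_recorded(x)   # a Tracked{Tuple{Int, Int}}; the tape now has one entry
s_new = size(x)            # a plain Tuple{Int, Int}; no new tape entry
```

In this toy, `s_new` is an ordinary tuple, so downstream calls such as `zeros(size(x))` need no extra overloads. By contrast, every consumer of `s_old` would have to know how to handle `Tracked`.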
Possibly we want some API like `cotangent_types(sig_type_tuple)`
that defaults to returning `Tuple{Any, Any, ...}` (or even something smarter?),
but for which `@non_differentiable` defines an overload returning `Tuple{Zero, DoesNotExist, DoesNotExist}`,
so that when generating rules we can decide to just run the primal.
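One possible shape for that API follows, as a sketch under assumptions. `cotangent_types` does not exist in ChainRulesCore, and `AbstractZero`, `Zero` and `DoesNotExist` are redefined locally as stand-ins so the block runs on its own. The default fills every slot with `Any`. The overload is what a `@non_differentiable size(::AbstractArray)` might emit. `always_zero` and `strategy` (also hypothetical) show how a rule generator could consume it.

```julia
# Local stand-ins for ChainRulesCore's zero types (not the real package).
abstract type AbstractZero end
struct Zero <: AbstractZero end
struct DoesNotExist <: AbstractZero end

# Hypothetical API: map a signature type Tuple{typeof(f), argtypes...} to a Tuple
# type holding one cotangent type per element (dself first).
# Default: nothing is known, so every slot is `Any`.
cotangent_types(::Type{S}) where {S<:Tuple} = NTuple{fieldcount(S), Any}

# What `@non_differentiable size(::AbstractArray)` might generate.
cotangent_types(::Type{<:Tuple{typeof(size), AbstractArray}}) =
    Tuple{DoesNotExist, DoesNotExist}

# True only if every cotangent slot is statically known to be a zero type.
always_zero(sig::Type{<:Tuple}) =
    all(T -> T <: AbstractZero, fieldtypes(cotangent_types(sig)))

# How a rule generator could branch on it.
strategy(sig) = always_zero(sig) ? :run_primal : :record_on_tape
```

Because the check depends only on the signature type, its result is fixed per signature and does not need to be recomputed on each call.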
OTOH, we could maybe pull this information out of type inference, and use some Tricks.jl trick to do that without it being super expensive, idk.
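A rough sketch of the type-inference route follows. It does not use Tricks.jl. Instead it asks `Core.Compiler.return_type`, an internal, unexported Julia API whose precision can vary between versions, what a pullback returns for given argument types, without running it. `toy_rrule`, `run_pullback` and `inferred_always_zero` are hypothetical, and the zero types are local stand-ins. If inference widens to anything other than a concrete tuple of zero types, the check conservatively answers `false`, meaning "keep recording on the tape".

```julia
# Local stand-ins, not ChainRulesCore's real types.
abstract type AbstractZero end
struct DoesNotExist <: AbstractZero end

# Hypothetical toy rrules: each returns (primal, pullback).
toy_rrule(::typeof(size), x::AbstractArray) = (size(x), ȳ -> (DoesNotExist(), DoesNotExist()))
toy_rrule(::typeof(sum), x::AbstractArray) = (sum(x), ȳ -> (DoesNotExist(), fill(ȳ, size(x))))

# Forward pass followed by the pullback; below we only infer this, never call it.
function run_pullback(f, x, ȳ)
    _, pb = toy_rrule(f, x)
    return pb(ȳ)
end

# Infer the pullback's return type for argument type X and cotangent type Ȳ.
function inferred_always_zero(f, X::Type, Ȳ::Type)
    T = Core.Compiler.return_type(run_pullback, Tuple{typeof(f), X, Ȳ})
    # Anything imprecise (Any, a Union, a Vararg tuple) is treated as "not provably zero".
    return isconcretetype(T) && T <: Tuple && all(S -> S <: AbstractZero, fieldtypes(T))
end
```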