ChainRulesCore.jl

Ability to identify rules that always return AbstractZero

Open oxinabox opened this issue 5 years ago • 2 comments

This is related to `@non_differentiable`. I think that for operator-overloading-based AD, if a rule's propagator is always going to return an `AbstractZero`, the correct thing to do is quite different: one wants to accept the overloaded type, but return a non-overloaded type.

oxinabox, Nov 13 '20 14:11
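A minimal sketch of the distinction described in the comment above. `Tracked` is a hypothetical wrapper standing in for an operator-overloading AD's node type (it is not Nabla's `Node` or any ChainRulesCore API). A differentiable operation returns another tracked value, while a non-differentiable one such as `size` accepts the tracked value but returns a plain, untracked result.

```julia
# Hypothetical overloaded type standing in for an AD node (not Nabla's `Node`).
struct Tracked{T}
    val::T
end

# Differentiable operation: the result stays tracked so cotangents can flow back.
Base.:*(a::Real, x::Tracked) = Tracked(a * x.val)

# Non-differentiable operation: accept the tracked value but return a plain one,
# since its pullback could only ever return zeros.
Base.size(x::Tracked) = size(x.val)

x = Tracked([1.0 2.0 3.0; 4.0 5.0 6.0])
y = 2 * x    # a Tracked{Matrix{Float64}}
s = size(x)  # (2, 3), a plain Tuple{Int, Int}
```

Because `size` never produces a tracked value, nothing downstream of it has to pay for tracking either.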

I think this sounds reasonable, but I'm having trouble saying for sure without a concrete example. Any chance that you could concoct one?

willtebbutt, Dec 08 '20 22:12

Consider `size`.

The code that Nabla would currently generate from our `@non_differentiable size(::AbstractArray)` is:

```julia
# N is the number of dimensions (the second parameter of AbstractArray{T, N})
function Base.size(x1::Node{<:AbstractArray{<:Any, N}}; kwargs...) where N
    (primal_val, pullback) = rrule(size, unbox(x1); kwargs...)
    t = tape(x1)  # don't shadow the `tape` function with a local of the same name
    branch = Branch(primal_val, size, (x1,), kwargs.data, t, length(t) + 1, pullback)
    push!(t, branch)
    return branch  # type is <:Node{NTuple{N, Int}}
end
@inline function preprocess(
    ::typeof(size), y::Branch, ȳ, x1::Union{Any, Node{<:Any}}
)
    return pullback(ȳ)  # this will actually just return `NO_FIELDS, DoesNotExist()`
end
@inline function ∇(
    ::typeof(size), ::Type{Arg{N}}, p, ::Any, ::Any, x1::Union{Any, Node{<:Any}};
    kwargs...
) where N
    # p[1] is dself, which we skip as we don't support functors;
    # the cotangent for argument N is at p[N + 1]
    return p[N + 1]
end
```

But what we really want to do is:

```julia
function Base.size(x1::Node{<:AbstractArray{<:Any, N}}; kwargs...) where N
    return size(unbox(x1))  # type is NTuple{N, Int}; nothing is recorded on the tape
end
```

Possibly we want some API like:

```julia
cotangent_types(sig_type_tuple)
```

It would default to returning `Tuple{Any, Any, ...}` (or something smarter?), but `@non_differentiable` would overload it to return `Tuple{Zero, DoesNotExist, DoesNotExist}`, so that when generating rules we can decide to just run the primal. On the other hand, we could maybe pull this information out of type inference and use some Tricks.jl trick to do that without it being super expensive. I don't know.

oxinabox, Dec 09 '20 13:12
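A minimal sketch of how the proposed `cotangent_types` trait could drive code generation, under stated assumptions. `all_zero`, `overload_strategy`, and `toy_rrule` are hypothetical names, and `AbstractZero`, `Zero`, and `DoesNotExist` are redefined locally as stand-ins so the snippet runs without loading ChainRulesCore. The default returns one `Any` per element of the signature. The method that `@non_differentiable size(::AbstractArray)` might generate returns zero types for `(dself, x)`. The last lines sketch the type-inference alternative with `Base.return_types`, which runs inference at call time rather than using any Tricks.jl-style caching.

```julia
# Local stand-ins for ChainRulesCore's zero types (the package is not loaded here).
abstract type AbstractZero end
struct Zero <: AbstractZero end
struct DoesNotExist <: AbstractZero end

# Hypothetical trait: the types the pullback returns, one per element of the
# signature (including the function itself, i.e. dself).
# Default: nothing is known, so every cotangent could be anything.
cotangent_types(sig::Type{<:Tuple}) = NTuple{fieldcount(sig), Any}

# What `@non_differentiable size(::AbstractArray)` might generate: (dself, dx).
cotangent_types(::Type{<:Tuple{typeof(size), AbstractArray}}) = Tuple{Zero, DoesNotExist}

# True when every cotangent type is statically known to be a zero.
all_zero(::Type{T}) where {T<:Tuple} = all(S -> S <: AbstractZero, fieldtypes(T))

# Decide, when generating an overload, whether to record on the tape at all.
overload_strategy(sig::Type{<:Tuple}) =
    all_zero(cotangent_types(sig)) ? :run_primal : :record_on_tape

overload_strategy(Tuple{typeof(size), Matrix{Float64}})  # :run_primal
overload_strategy(Tuple{typeof(sum), Matrix{Float64}})   # :record_on_tape

# Alternative: ask type inference what a (toy, hypothetical) pullback returns.
toy_rrule(::typeof(size), x::AbstractArray) = (size(x), ȳ -> (Zero(), DoesNotExist()))
_, pb = toy_rrule(size, rand(2, 2))
inferred = only(Base.return_types(pb, (Any,)))  # Tuple{Zero, DoesNotExist}
all_zero(inferred)  # true
```

Dispatching on the signature type keeps the decision static, so a generator could emit the plain primal method for `size` and the recording method for everything else.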