ChainRulesTestUtils.jl

Add function for checking modules for type piracy, ambiguities, etc. in defined rules

Open · sethaxen opened this issue 3 years ago · 1 comment

Consider an example like:

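```julia
using ChainRulesCore

struct Foo end

(f::Foo)(x) = x^2

# The mistake: typeof(Foo) === DataType, so this rule is defined for
# DataType as a whole, not for Foo. The intended signature is rrule(f::Foo, x).
function ChainRulesCore.rrule(::typeof(Foo), x)
    Foo_pullback(Δy) = (NO_FIELDS, x' * Δy + Δy * x')
    return x^2, Foo_pullback
end
```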

Because typeof(Foo) is DataType, this method matches rrule(T, x) for essentially any type T (for example the constructor call Float64(x)), not just Foo. This common mistake is highly piratical and will subtly but completely break AD packages like Zygote for anyone whose session loads this rule. The signature should use ::Foo, not ::typeof(Foo).
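A minimal sketch of why this breaks things. To keep it runnable without any packages, it uses a stand-in generic function `rrule` in place of `ChainRulesCore.rrule` and returns `nothing` instead of `NO_FIELDS`/`NoTangent()` (both are assumptions for illustration); the dispatch behaviour is ordinary Julia multiple dispatch, so it is the same for the real `rrule`.

```julia
# Stand-in for ChainRulesCore.rrule (assumption: keeps the sketch package-free).
function rrule end

struct Foo end
(f::Foo)(x) = x^2

# The mistaken signature from the issue: the argument type is DataType.
rrule(::typeof(Foo), x) = (x^2, Δy -> (nothing, 2x * Δy))

# The intended rule for the callable struct dispatches on an instance of Foo.
rrule(f::Foo, x) = (f(x), Δy -> (nothing, 2x * Δy))

println(typeof(Foo))                # DataType
# Float64 is also a DataType, so the mistaken rule intercepts Float64(3.0):
println(first(rrule(Float64, 3.0))) # 9.0, although Float64(3.0) == 3.0
println(first(rrule(Foo(), 3.0)))   # 9.0, via the intended rule
```

An AD package that asks for `rrule(Float64, x)` while differentiating a conversion would silently get the derivative of `x^2`.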

On Slack, @oxinabox suggested:

We should think about tools to make this more obvious. Maybe something that lists what all the rules defined in a package actually target?
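One way such a listing could work, sketched under assumptions: walk `methods(f)` for the rule function and keep the methods whose `module` is the package being inspected. `FakeCore` stands in for ChainRulesCore and `rule_signatures` is a hypothetical helper name, not an existing API.

```julia
# Stand-in for ChainRulesCore so the sketch runs without packages (assumption:
# with the real package you would pass ChainRulesCore.rrule and your own module).
module FakeCore
function rrule end
end

module MyPkg
import ..FakeCore: rrule

struct Foo end
(f::Foo)(x) = x^2

rrule(::typeof(Foo), x) = (x^2, Δy -> (nothing, 2x * Δy))  # mistaken target
rrule(f::Foo, x) = (f(x), Δy -> (nothing, 2x * Δy))        # intended target
end

# Hypothetical helper: signatures of every method of `f` defined in `mod`.
# Slot 1 of each signature is typeof(f); slot 2 is what the rule targets.
rule_signatures(f, mod::Module) = [m.sig for m in methods(f) if m.module === mod]

for sig in rule_signatures(FakeCore.rrule, MyPkg)
    println(sig)
end
```

For the mistaken rule, the second slot of the printed signature is `DataType` rather than `MyPkg.Foo`, which is exactly what such a listing would make obvious.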

I suggested:

Something like ChainRulesTestUtils.check_all_rules(MyModule), which checks as much as possible for things like type piracy, ambiguities, etc.?

sethaxen · Mar 23 '21 05:03

Idk about checking type piracy more broadly (though I have written code for that before, and it's on Discourse).

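But in particular, checking for rules being added to DataType, Union and UnionAll would be good, to catch rrule(::typeof(MyType), args...) when rrule(::Type{<:MyType}, args...) was intended. (typeof(MyType) is DataType for a concrete type, UnionAll for a parametric type such as Vector, and Union for a union type.)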
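A minimal sketch of that check, assuming it inspects the slot of each rule's signature that names the targeted function and flags it when it is declared as exactly `DataType`, `UnionAll` or `Union`. `suspicious_rule_methods` is a hypothetical name, not an existing ChainRulesTestUtils function, and `FakeCore` again stands in for ChainRulesCore so the sketch needs no packages.

```julia
# Stand-in for ChainRulesCore.rrule (assumption: keeps the sketch package-free).
module FakeCore
function rrule end
end

module MyPkg
import ..FakeCore: rrule

struct Foo end
(f::Foo)(x) = x^2
struct Bar{T}
    x::T
end

rrule(::typeof(Foo), x) = (x^2, Δy -> (nothing, 2x * Δy))  # bad: typeof(Foo) === DataType
rrule(::typeof(Bar), x) = (Bar(x), Δy -> (nothing, Δy))    # bad: typeof(Bar) === UnionAll
rrule(f::Foo, x) = (f(x), Δy -> (nothing, 2x * Δy))        # fine: dispatches on an instance
end

# The types that `::typeof(SomeType)` can produce in a signature.
const SUSPICIOUS = (DataType, UnionAll, Union)

"""
    suspicious_rule_methods(f, mod; fpos = 2)

Hypothetical helper: methods of `f` defined in `mod` whose signature slot
`fpos` (slot 1 is `typeof(f)` itself) is declared as exactly `DataType`,
`UnionAll` or `Union`.
"""
function suspicious_rule_methods(f, mod::Module; fpos::Int = 2)
    bad = Method[]
    for m in methods(f)
        m.module === mod || continue                      # only rules from `mod`
        params = Base.unwrap_unionall(m.sig).parameters   # strip any `where` clause
        length(params) >= fpos || continue
        any(T -> params[fpos] === T, SUSPICIOUS) && push!(bad, m)
    end
    return bad
end

for m in suspicious_rule_methods(FakeCore.rrule, MyPkg)
    println("suspicious rule: ", m.sig)
end
```

For `frule`, whose first argument is the tuple of tangents, and for the `rrule(config, f, args...)` form, the targeted function presumably sits one slot later, so `fpos = 3` would be the setting to try there. A constructor rule written as `rrule(::Type{<:Bar}, x)` is not flagged, because its slot holds `Type{<:Bar}` rather than `UnionAll` itself.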

I think that would be a very solid first step.

oxinabox · Mar 23 '21 09:03