ChainRulesTestUtils.jl
Add function for checking modules for type piracy, ambiguities, etc. in defined rules
Given an example like
```julia
using ChainRulesCore

struct Foo end
(f::Foo)(x) = x^2

# The mistake: `typeof(Foo)` is `DataType`, not `Foo`
function ChainRulesCore.rrule(::typeof(Foo), x)
    Foo_pullback(Δy) = (NO_FIELDS, x' * Δy + Δy * x')
    return x^2, Foo_pullback
end
```
Because `typeof(Foo)` is `DataType`, this common mistake is highly piratical: the rule matches `rrule(T, x)` for every type `T`, so it hijacks single-argument calls such as `Float64(x)`. It will subtly but completely break AD packages like Zygote for anyone who loads a package defining this rule. The signature should use `::Foo`, not `::typeof(Foo)`.
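To see the damage concretely, here is a Base-only sketch. `bad_rule` and `good_rule` are hypothetical stand-ins for `rrule` (so it runs without ChainRulesCore), each with a catch-all method returning `nothing` to mirror the generic `rrule` fallback. The `::typeof(Foo)` method matches any type passed as the first argument, while `::Foo` and `::Type{Foo}` match only what was intended.

```julia
struct Foo end
struct Bar end

# `bad_rule` and `good_rule` are hypothetical stand-ins for `rrule`.
# Their catch-all methods return `nothing`, like the generic `rrule` fallback.
bad_rule(args...) = nothing
bad_rule(::typeof(Foo), x) = :foo_rule         # same as `bad_rule(::DataType, x)`

good_rule(args...) = nothing
good_rule(::Foo, x) = :foo_call_rule           # calls on an instance: `Foo()(x)`
good_rule(::Type{Foo}, x) = :foo_type_rule     # calls on `Foo` itself

println(typeof(Foo))            # DataType
println(bad_rule(Bar, 1.0))     # foo_rule: the "Foo" rule also fires for Bar
println(bad_rule(Float64, 1))   # foo_rule: ...and for Float64(1)
println(good_rule(Bar, 1.0))    # nothing
println(good_rule(Foo(), 1.0))  # foo_call_rule
```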
On Slack, @oxinabox suggested:
> We should think about tools to make this more obvious. Maybe something that lists the true things that have been targeted by all rules defined in a package?
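One way to list what a package's rules actually target is to walk `methods(rrule)` and keep only the methods whose `module` field is the package. This is a minimal sketch under stated assumptions: `FakeCore` stands in for ChainRulesCore, `MyPackage` for the package under test, and `rule_targets` is a hypothetical helper, not an existing ChainRulesTestUtils function. It also relies on `Base.unwrap_unionall`, which is unexported.

```julia
# Hypothetical stand-in for ChainRulesCore, so the sketch runs with Base alone.
module FakeCore
rrule(args...) = nothing
end

# Hypothetical package under test.
module MyPackage
import ..FakeCore
struct Foo end
FakeCore.rrule(::typeof(Foo), x) = nothing   # the mistaken rule
FakeCore.rrule(::Foo, x) = nothing           # the intended rule
end

"""
    rule_targets(f, mod)

Hypothetical helper: return the declared type of the first (callee) argument
of every method of `f` that was defined in module `mod`.
"""
function rule_targets(f, mod::Module)
    targets = Any[]
    for m in methods(f)
        m.module === mod || continue
        sig = Base.unwrap_unionall(m.sig)   # strip any `where` clauses
        # parameters[1] is typeof(f); parameters[2] is the callee the rule targets
        length(sig.parameters) >= 2 && push!(targets, sig.parameters[2])
    end
    return targets
end

rule_targets(FakeCore.rrule, MyPackage)  # DataType and MyPackage.Foo, in some order
```

Seeing `DataType` in that list next to `MyPackage.Foo` makes the mistaken rule stand out immediately.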
I suggested:
> Something like `ChainRulesTestUtils.check_all_rules(MyModule)`, which as much as possible checks for things like type piracy, ambiguities, etc.?
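For the ambiguity part, one possible approach (an assumption, not an existing ChainRulesTestUtils API) is to compare every pair of rules a module defines using the unexported `Base.isambiguous`. `FakeCore`, `MyPackage`, and `ambiguous_rule_pairs` are hypothetical names used only for illustration.

```julia
# Hypothetical stand-in for ChainRulesCore, so the sketch runs with Base alone.
module FakeCore
rrule(args...) = nothing
end

# Hypothetical package whose two rules overlap on (Foo(), ::Int, ::Int)
# without either being more specific, so that call would be ambiguous.
module MyPackage
import ..FakeCore
struct Foo end
FakeCore.rrule(::Foo, x::Int, y) = nothing
FakeCore.rrule(::Foo, x, y::Int) = nothing
end

"""
    ambiguous_rule_pairs(f, mod)

Hypothetical check: return every pair of methods of `f` defined in `mod`
that Julia considers ambiguous with each other.
"""
function ambiguous_rule_pairs(f, mod::Module)
    ms = [m for m in methods(f) if m.module === mod]
    pairs = Tuple{Method,Method}[]
    for i in eachindex(ms), j in (i + 1):lastindex(ms)
        Base.isambiguous(ms[i], ms[j]) && push!(pairs, (ms[i], ms[j]))
    end
    return pairs
end

for (m1, m2) in ambiguous_rule_pairs(FakeCore.rrule, MyPackage)
    println("ambiguous: ", m1, " vs ", m2)
end
```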
Idk about checking type piracy more broadly (though I have written code for that before and it's on Discourse). But in particular, checking for rules being added to `DataType`, `Union`, and `UnionAll` would be good, to catch `rrule(::typeof(MyType), args...)` when `rrule(::Type{<:MyType}, args...)` was intended.
I think that would be a very solid first start.
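The `DataType`/`Union`/`UnionAll` check could be sketched with the same kind of method-table walk: flag any rule defined in the module whose first (callee) argument is declared as exactly one of those three types, since that is what `::typeof(SomeType)` expands to for a concrete type, a parametric type, or a type union. `FakeCore`, `MyPackage`, and `suspicious_rules` are hypothetical names, and `Base.unwrap_unionall` is unexported.

```julia
# Hypothetical stand-in for ChainRulesCore, so the sketch runs with Base alone.
module FakeCore
rrule(args...) = nothing
end

# Hypothetical package under test.
module MyPackage
import ..FakeCore
struct Foo end
struct Bar{T} end
FakeCore.rrule(::typeof(Foo), x) = nothing                 # DataType: mistake
FakeCore.rrule(::typeof(Bar), x) = nothing                 # UnionAll: mistake
FakeCore.rrule(::typeof(Union{Foo,Nothing}), x) = nothing  # Union: mistake
FakeCore.rrule(::Type{<:Bar}, x) = nothing                 # intended
end

# Callee types that almost certainly came from `::typeof(SomeType)`.
const SUSPICIOUS = (DataType, Union, UnionAll)

"""
    suspicious_rules(f, mod)

Hypothetical check: return the methods of `f` defined in `mod` whose first
(callee) argument is declared as exactly `DataType`, `Union`, or `UnionAll`.
"""
function suspicious_rules(f, mod::Module)
    bad = Method[]
    for m in methods(f)
        m.module === mod || continue
        params = Base.unwrap_unionall(m.sig).parameters
        length(params) >= 2 || continue
        any(T -> params[2] === T, SUSPICIOUS) && push!(bad, m)
    end
    return bad
end

for m in suspicious_rules(FakeCore.rrule, MyPackage)
    println("suspicious rule: ", m)
end
```

Here the three `::typeof(...)` rules are reported while the intended `::Type{<:Bar}` rule is not. If some rule ever needed to target `DataType` on purpose, a real check would need an allow-list, which this sketch does not have.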