ChainRulesTestUtils.jl
Test Linearity
Pushforwards and pullbacks are required to be linear operators.
Formally, that means φ(𝛼x) = 𝛼φ(x) and φ(a+b) = φ(a) + φ(b), for a, b, x differentials (vector space elements) and 𝛼 a scalar.
An important corollary of that is φ(0) = 0, where the two 0s are the additive identities of the relevant vector spaces.
We can test this with iszero(φ(zero(d))), since we always have access to a differential d, and all differentials must define zero and iszero.
Similarly, we can test φ(𝛼x) = 𝛼φ(x) via φ(2d) == 2φ(d), since all differentials define multiplication by a scalar, and we have one d provided.
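A minimal sketch of these two checks, assuming a hand-written linear map φ standing in for a real pushforward or pullback (φ and the sample d here are hypothetical, not anything from the package):

```julia
using Test

# Hypothetical linear map for illustration: the pushforward of x -> 3x is d -> 3d.
φ(d) = 3 .* d

# The single differential we are given.
d = [1.0, -2.0, 0.5]

# φ(0) = 0, using only zero and iszero on the differential.
@test iszero(φ(zero(d)))

# φ(𝛼x) = 𝛼φ(x) with 𝛼 = 2, using only scalar multiplication.
@test φ(2 * d) == 2 * φ(d)
```

For floating-point differentials it is probably safer to compare with ≈ rather than ==, since a real rule may round differently on each side.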
We can't easily test φ(a+b) = φ(a) + φ(b) though, as we only have one differential.
We can test φ(d+d) = φ(d) + φ(d) but that is less interesting.
If we are generating them using rand_tangent we can easily generate another.
Though perhaps we should insist that things define the addition of their natural differentials with the structural differentials that rand_tangent will most often provide (though sometimes rand_tangent's structural differential will have extra nondifferentiable fields, so we can't trust it unless we insist the user actually overloads it).
Or we could avoid direct addition and do (primal + rand_tangent(d)) + d, though this is getting off-topic.
Possibly just skip that in the initial implementation of this.
The linear pullback can be extracted from the rrule via (f, primal...) -> (ds...) -> last(rrule(f, primal...))(ds...) (or equivalently: (f, primal) -> last(rrule(f, primal))). The linear pushforward can be extracted from frule via (f, primal...) -> (ds...) -> frule(ds..., f, primal...)
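A rough sketch of that extraction, assuming a hypothetical function triple with hand-written rules. It also assumes the current ChainRulesCore calling convention frule((Δself, Δargs...), f, args...), which orders and groups the arguments differently from the expression above, and it takes last of the frule result because frule returns the primal output together with the tangent. The helper names get_pullback and get_pushforward are made up for illustration.

```julia
using ChainRulesCore

# Hypothetical primal function, defined only for this example.
triple(x) = 3 * x

# Hand-written forward rule: returns (primal output, output tangent).
function ChainRulesCore.frule((_, ẋ), ::typeof(triple), x)
    return triple(x), 3 * ẋ
end

# Hand-written reverse rule: returns (primal output, pullback).
function ChainRulesCore.rrule(::typeof(triple), x)
    triple_pullback(ȳ) = (NoTangent(), 3 * ȳ)
    return triple(x), triple_pullback
end

# The pullback is the last element of what rrule returns.
get_pullback(f, primal...) = last(rrule(f, primal...))

# The pushforward closes over the primals; NoTangent() is passed as the
# tangent of f itself (an assumption that f carries no differentiable state).
get_pushforward(f, primal...) =
    (ds...) -> last(frule((NoTangent(), ds...), f, primal...))

pb = get_pullback(triple, 2.0)
pf = get_pushforward(triple, 2.0)

pb(1.0)  # (NoTangent(), 3.0)
pf(1.0)  # 3.0
```

Both pb and pf are then plain functions of the differentials, so the linearity checks can be applied to them without knowing which rule they came from.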
> We can't easily test φ(a+b) = φ(a) + φ(b) though, as we only have one differential.
Could we just introduce two scalars p and q s.t. p + q = 1, and check that
φ(d) = φ(p * d) + φ(q * d)
I think that would cover the cases we care about.
Fair, it isn't comprehensive, but yeah, I think:
6φ(d) = φ(2d) + φ(4d)
would catch a lot.
Just like iszero(φ(zero(d)))
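A small sketch of both split checks, assuming two hypothetical maps (one linear, one affine) to show what the checks accept and reject. The helper names and the choice p = 0.25, q = 0.75 are illustrative only.

```julia
using Test

# Hypothetical maps for illustration only.
φ_linear(d) = 3 .* d        # linear
φ_affine(d) = 3 .* d .+ 1   # not linear: φ(0) != 0

# The single differential we are given.
d = [1.0, -2.0, 0.5]

# φ(d) = φ(p * d) + φ(q * d) for scalars with p + q = 1.
function split_check(φ, d; p=0.25, q=0.75)
    return φ(d) ≈ φ(p * d) + φ(q * d)
end

# 6φ(d) = φ(2d) + φ(4d)
six_check(φ, d) = 6 * φ(d) ≈ φ(2 * d) + φ(4 * d)

@test split_check(φ_linear, d)
@test six_check(φ_linear, d)

# The affine map fails both, since its constant offset is counted
# a different number of times on each side.
@test !split_check(φ_affine, d)
@test !six_check(φ_affine, d)
```

Neither check is comprehensive, since both sides only ever see multiples of the same d, but they need nothing beyond scalar multiplication and addition of the outputs.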