ChainRulesTestUtils.jl
Tests are slow: use vjvp?
Compare Zygote's gradtest
julia> @btime gradtest(x -> sum(abs2, x), randn(4, 3, 2))
17.977 μs (236 allocations: 5.86 KiB)
and CRTU's test_rrule
julia> @btime test_rrule(Zygote.ZygoteRuleConfig(), x -> sum(abs2, x), randn(4, 3, 2); rrule_f=rrule_via_ad)
1.490 ms (6884 allocations: 480.10 KiB)
It's nearly 100x slower (about 83x on these timings). Do we understand why? I didn't have time to look into it, so I'm just making an issue.
ChainRules tests take a pretty long time to run so this might be worth improving.
A big part of it is that FiniteDifferences.jl is slower, but more accurate, than what Zygote's gradtest does.
gradtest is equivalent to central_fdm(3, 1; adapt=0), whereas we use central_fdm(5, 1; adapt=1),
so we query two additional points near the point of interest, and we have one step of adaptation to determine the optimal step size.
We could try tinkering with that. Getting it faster that way might be possible and might not break too many tests, though we might also need to relax some of the atol/rtol.
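To get a feel for the trade-off between the two settings, here is a rough sketch that counts function evaluations and measures the error of each method on a scalar function. It assumes FiniteDifferences.jl is installed; the counter `ncalls` and the wrapper `f` are made up for illustration, and the exact evaluation counts may differ between package versions.

```julia
using FiniteDifferences

# Hypothetical counter: records how often the test function is evaluated.
const ncalls = Ref(0)
f(x) = (ncalls[] += 1; sin(x))

cheap = central_fdm(3, 1; adapt=0)    # 3-point method, no step-size adaptation
default = central_fdm(5, 1; adapt=1)  # 5-point method, one adaptation step

for (name, fdm) in (("central_fdm(3, 1; adapt=0)", cheap),
                    ("central_fdm(5, 1; adapt=1)", default))
    ncalls[] = 0
    est = fdm(f, 1.0)  # estimate d/dx sin(x) at x = 1
    println(name, ": error = ", abs(est - cos(1.0)), ", evaluations = ", ncalls[])
end
```

The cheaper method should need fewer evaluations per derivative at the cost of a larger error, which is why tolerances might have to be loosened if the default were changed.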
Another thing to consider is that we're currently generating entire vjps, rather than vjvps (vector-Jacobian-vector products), which would be sufficient.
The thing that we really need to test for reverse-mode is that
< J' ȳ, ẋ> ≈ <ȳ, J ẋ>
rrules are good at computing < J' ȳ, ẋ>, while finite differencing is good at computing <ȳ, J ẋ>. We're currently using finite differencing to approximate J' ȳ itself, which means we make O(length(ẋ)) times more calls to FiniteDifferences than we ought to.
The only extra thing we should need to make this work is the ability to compute inner products between the output of rrules, which probably isn't much more difficult than comparing between the output of FiniteDifferences and rrule anyway.
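The identity above can be sketched in a few lines. This is only an illustration, not what ChainRulesTestUtils does: `pullback` is a hand-written stand-in for the output of an rrule, and `jvp_fd` is a hypothetical helper that uses a plain central difference rather than FiniteDifferences.jl.

```julia
using LinearAlgebra

# Function under test, with a hand-written pullback standing in for an rrule.
f(x) = abs2.(x)                 # elementwise, so J = Diagonal(2x)
pullback(x, ȳ) = 2 .* x .* ȳ    # J' ȳ

# Directional derivative J ẋ from one central difference along ẋ (two calls to f).
jvp_fd(f, x, ẋ; h=1e-6) = (f(x .+ h .* ẋ) .- f(x .- h .* ẋ)) ./ (2h)

x = randn(4, 3, 2)
ẋ = randn(size(x))  # random input tangent
ȳ = randn(size(x))  # random output cotangent

lhs = dot(pullback(x, ȳ), ẋ)   # <J' ȳ, ẋ> from reverse mode
rhs = dot(ȳ, jvp_fd(f, x, ẋ))  # <ȳ, J ẋ> from finite differences
@show lhs rhs isapprox(lhs, rhs; rtol=1e-6, atol=1e-6)
```

The finite-difference side costs two evaluations of f regardless of the size of x, whereas approximating every entry of J' ȳ needs a number of evaluations proportional to length(x).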
I've been doing this for a while in TemporalGPs.jl and it seems to work really well. I wrote loads of code to hack around ChainRulesTestUtils not being up to what I needed prior to ADIA -- didn't want to contribute it back at the time because I wasn't completely sure whether this was the right way to go about things, but I'm now convinced that it is.
Yeah, and dot, i.e. the inner product, is something all tangent types should overload.
One problem, maybe: if it fails, that won't tell you where it failed, will it?
> Yeah, and dot, i.e. the inner product, is something all tangent types should overload.
Indeed.
> One problem, maybe: if it fails, that won't tell you where it failed, will it?
Yeah, this is a problem. It does make life a bit trickier when it comes to debugging. I've typically found that you want to retain the ability to (slowly) compute the vjp for debugging purposes. Fortunately, you don't really need things like the ability to check for equality lying around to do this -- you'll always be doing it by eye.
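A slow debugging fallback of that kind might look like the sketch below. Everything here is hypothetical: `vjp_fd` builds the full vjp one coordinate at a time with a plain central difference, and `buggy_pullback` is a deliberately wrong pullback used to show how the mismatch becomes visible by eye.

```julia
f(x) = abs2.(x)

# Deliberately wrong pullback: the last entry is off by 0.5.
buggy_pullback(x, ȳ) = 2 .* x .* ȳ .+ [0.0, 0.0, 0.5]

# Full vjp by finite differences: two calls to f per entry of x, so slow for large x.
function vjp_fd(f, x, ȳ; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = 1  # i-th basis direction
        g[i] = sum(ȳ .* (f(x .+ h .* e) .- f(x .- h .* e))) / (2h)
    end
    return g
end

x = [1.0, 2.0, 3.0]
ȳ = [0.5, -1.0, 2.0]

# Print the two vjps side by side; the row that disagrees points at the bug.
foreach(println, zip(buggy_pullback(x, ȳ), vjp_fd(f, x, ȳ)))
```

Only the third pair should disagree, which locates the error in a way the single inner-product comparison cannot.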
One of the real benefits of doing things this way is that you can test AD at scale. Whereas the current way of doing things really requires small problem sizes, the inner product approach can handle any problem size for which you're happy to make a small handful of function evaluations. The advantage is less that it's better to test on big problems, and more that it's convenient to be able to test any old problem you have lying around, regardless of its size.
Had a quick stab at doing this (I'll not be pushing this further myself in the immediate future, just wanted to see what it might look like) https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/208
On the surface of it, it doesn't look like we have any substantial practical impediments to doing this, but I've not dug into the details.