Refactoring AD Tests
Currently, there are many situations in which we run an entire sampler in order to test AD, such as many of the cases in https://github.com/TuringLang/Turing.jl/blob/master/test/mcmc/abstractmcmc.jl .
This is quite compute intensive. @yebai pointed out that perhaps we should just check that the gradient runs and is correct, rather than running thousands of MCMC iterations.
We should probably implement some functionality to benchmark various AD backends during this refactoring.
Thinking about this and #2338 a little, it's not super easy to dive into the internals of sample() to check that the gradient is correct, but what I'm thinking we could do is:
- seed (with StableRNGs) and sample for something like 10 iterations per AD backend and check that the results are numerically similar
- once we've established that all AD backends give the same result for the first 10 iters, we choose one blessed AD backend (I guess ForwardDiff) and run the full 1000 samples to check that the sampling result is correct
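A rough sketch of what that first check could look like (the model, seed, iteration count, and tolerance below are all just placeholders, and each backend's AD package needs to be loaded for its ADType to work):

```julia
using Turing, ADTypes, StableRNGs, Test
import ForwardDiff, ReverseDiff  # backends under test must be loaded

# Placeholder model -- the real tests would loop over whatever models we care about.
@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end
model = demo(1.5)

adtypes = [AutoForwardDiff(), AutoReverseDiff()]

# Same seed for every backend, so the short chains should agree up to
# floating-point differences in the gradients.
chains = map(adtypes) do adtype
    sample(StableRNG(468), model, NUTS(; adtype=adtype), 10; progress=false)
end

ref = Array(chains[1])
for chain in chains[2:end]
    @test isapprox(Array(chain), ref; rtol=1e-6)
end
```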
I don't know how much this will really speed it up, but it could be worth a shot. Thoughts ... let me use the magic ping @TuringLang/maintainers?
Alternatively, I guess we could call logdensity_and_gradient on a bunch of values but I don't feel that that inspires a lot of confidence. Fundamentally we're trying to show that sampling works with AD, not that AD works (imo 'whether AD works' should be tested in the AD packages).
> Fundamentally we're trying to show that sampling works with AD

We only need to test that sampling works for one AD backend. The other AD backends can be tested via logdensity_and_gradient.
I agree with the general sentiment here. In my view we should
- only test long-run sampling (i.e. to test the correctness of a sampler) once, using whichever AD we prefer.
- publish a collection of test cases which would form a single source of truth for "things Turing.jl would like to be able to differentiate" (which can be updated over time as we find problem cases / find new cases that we care about). This way, AD packages have clear guidance as to the things that Turing.jl wants to be able to differentiate, and can just hook them into their integration tests. For example, @mhauru recently added a collection of test cases for Bijectors in Mooncake which has roughly the right structure -- the key point is that there's a single vector of test cases (see the test_cases variable) which Mooncake.jl can just iterate through.
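For concreteness, the published collection could be as simple as an exported vector of (name, model) pairs; the module name and models below are purely illustrative:

```julia
module ADTestCases

using DynamicPPL, Distributions

@model function simple_normal(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

@model function constrained_variance(x)
    s ~ InverseGamma(2, 3)
    x ~ Normal(0, sqrt(s))
end

# Single source of truth: AD packages (and Turing.jl's own integration tests)
# just iterate over this vector.
const test_cases = [
    ("simple_normal", simple_normal(1.5)),
    ("constrained_variance", constrained_variance(1.5)),
]

end # module
```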
It might be that we want to run these test cases in Turing.jl for each backend that we claim supports Turing.jl, in addition to running the tests in Mooncake.jl, but that's something we could do later on.
Either way, I really think we should decouple "is this sampler correct?" from "does this AD work correctly?".
edit: to phrase the second point differently: in order to have robust AD support for Turing.jl, it is necessary that
- Turing.jl formalises what it needs from AD by publishing an iterable collection of test cases,
- Turing.jl keeps this list up-to-date, and
- ADs check that they can successfully differentiate all of these test cases as part of their integration testing, and fix any problems.
Ok, so it seems we want to construct a list of LogDensityFunctions, then. (Or a list of models to construct them from.) And I suspect these should go into DynamicPPL.TestUtils.
I'll start playing around with this in a new package, we can decide whether to put it in DPPL proper at a later point in time.
Agreed. We should think about the precise format though. For the sake of both running the tests in downstream AD packages, and for running them in the Turing.jl integration tests, might it make sense to produce things which can be consumed by DifferentiationInterface's testing functionality? See https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterfaceTest/dev/ for context.
Ooh. I like that. But I don't think LogDensityProblemsAD uses DI for all its backends just yet, right? In which case, I think we'd still need to test logdensity_and_gradient itself? (Please correct if wrong.) What we can do is try to future-proof it by making sure we don't stray too far from the DITest interface (assuming that eventually Turing will use DI for everything, either because LogDensityProblemsAD switches to DI for all backends, or e.g. because we define our own LogDensityModel interface).
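To make the logdensity_and_gradient route concrete, here's a sketch (the model, parameter values, and backend list are illustrative; each backend's package, and DI for the DI-backed ones, must be loaded in the environment):

```julia
using DynamicPPL, Distributions, LogDensityProblems, LogDensityProblemsAD, ADTypes, Test
import ForwardDiff, ReverseDiff

@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end
ldf = DynamicPPL.LogDensityFunction(demo(1.5))
θ = [0.3]

# ForwardDiff as the reference gradient.
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(
    ADgradient(AutoForwardDiff(), ldf), θ
)

# Every other backend should agree with the reference at a handful of points.
for adtype in (AutoReverseDiff(),)
    logp, grad = LogDensityProblems.logdensity_and_gradient(ADgradient(adtype, ldf), θ)
    @test logp ≈ ref_logp
    @test grad ≈ ref_grad
end
```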
Hmmm maybe you're right? That being said, LogDensityProblemsAD does now support arbitrary ADTypes, meaning that you can use ADTypes to make any backend supported by DifferentiationInterface work for you. e.g.
```julia
using ADTypes, Mooncake, Turing
NUTS(; adtype=ADTypes.AutoMooncake(; config=nothing))
```
so maybe it's fine?
Mostly notes for myself because Julia really lends itself to creating a real maze of function calls, but if anyone's curious.
Mooncake goes via DI so that's all fine. (NOTE: actually running logdensity_and_gradient with AutoMooncake(...) requires DI to be explicitly loaded in the environment, as the relevant method is in LogDensityProblemsADDifferentiationInterfaceExt. Neither Mooncake nor LogDensityProblemsAD happens to have DI as a hard dep. I've suggested improving the error message here https://github.com/tpapp/LogDensityProblemsAD.jl/issues/41)
However, running ADgradient(AutoForwardDiff(), ldf) where ldf isa DynamicPPL.LogDensityFunction still goes to the custom LogDensityProblemsADForwardDiffExt implementation rather than using DI - this is because of two things:
- DynamicPPLForwardDiffExt (https://github.com/TuringLang/DynamicPPL.jl/blob/f5890a11ef4002026e8edeb355b6e8a50c68d115/ext/DynamicPPLForwardDiffExt.jl#L16-L36) intercepts calls with AutoForwardDiff() and then in turn calls ADgradient(:ForwardDiff, ...)
- LogDensityProblemsADADTypesExt still special-cases ForwardDiff, so even if you pass it AutoForwardDiff it won't use DI, due to this function: https://github.com/tpapp/LogDensityProblemsAD.jl/blob/e3401f21b5a065df0d5de38b37fad0e6650618f3/ext/LogDensityProblemsADADTypesExt.jl#L43-L50
I managed to force AutoForwardDiff to go via DI by modifying both of these functions (plus some other stuff), and it does give the right results, but given that this has already been discussed in https://github.com/tpapp/LogDensityProblemsAD.jl/pull/29 I don't think it's my job to upstream it.
tl;dr: we can't really short-circuit LogDensityProblemsAD unless they decide to adopt DI for all backends, so I'll move forward with directly testing logdensity_and_gradient. I'll also throw in some DI tests in preparation for an eventual future where we do go via DI (it's just that those tests won't reflect the current reality of what happens when we run AD on a Turing model).
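Something like the following is what I have in mind for the forward-looking DI checks (again just a sketch; the model and backend list are placeholders, and the relevant AD packages need to be in the environment):

```julia
using DynamicPPL, Distributions, LogDensityProblems, DifferentiationInterface, ADTypes, Test
import ForwardDiff, ReverseDiff

@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end
ldf = DynamicPPL.LogDensityFunction(demo(1.5))
θ = [0.3]

# Differentiate the log density through DI directly, bypassing LogDensityProblemsAD.
logp = Base.Fix1(LogDensityProblems.logdensity, ldf)
ref_val, ref_grad = value_and_gradient(logp, AutoForwardDiff(), θ)

for backend in (AutoReverseDiff(),)
    val, grad = value_and_gradient(logp, backend, θ)
    @test val ≈ ref_val
    @test grad ≈ ref_grad
end
```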
Also, the code is here for now: https://github.com/penelopeysm/ModelTests.jl
@penelopeysm can you take a look at this so we can finish off #2411
🎉