Turing.jl
Feature request: allow user to pass gradient tape into sample() function
When sampling a Turing model using ReverseDiff and memoization in a distributed setting (`MCMCDistributed()`), the gradient tape is compiled n times if there are n worker processes.
We can avoid the extra compilations by allowing the user to compile the tape once and pass it into the `sample` function.
Note, though, that you have to be careful not to share the tape: https://github.com/TuringLang/Turing.jl/issues/1412#issuecomment-698160681
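For reference, here is a minimal sketch of the setup that triggers the repeated compilation, assuming the ReverseDiff backend with tape caching enabled via `Turing.setrdcache`; the model and data are just placeholders:

```julia
using Distributed
addprocs(4)

@everywhere begin
    using Turing, ReverseDiff, Memoization

    # ReverseDiff with tape caching (memoization); the compiled tape is
    # process-local, so each worker compiles its own copy on first use.
    Turing.setadbackend(:reversediff)
    Turing.setrdcache(true)

    @model function demo(x)
        m ~ Normal(0, 1)
        x .~ Normal(m, 1)
    end
end

# With n = 4 workers, the tape gets compiled four times, once per process.
chn = sample(demo(randn(100)), NUTS(), MCMCDistributed(), 1_000, 4)
```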
IMO it would be great to make the setup more modular and less "magic". I'm not sure, though, whether one should be able to specify the tape as part of the `sample` calls. It seems to be relevant only for a very specific application, so maybe it should be bundled more clearly with the ReverseDiff AD choices. If the AD settings were more explicit (see https://github.com/TuringLang/Turing.jl/issues/1402), one could think about having a `ReverseDiffAD` singleton for ReverseDiff without caching and then allowing users to cache the tape once by calling `ReverseDiffCachedAD(model, sampler)`, which would precompute a tape for a specific model and sampler. That object could then be passed around explicitly and reused (although it is still not clear to me how exactly, since, as mentioned above, you need a separate one for each cache).
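At the ReverseDiff level, this is roughly what such a `ReverseDiffCachedAD(model, sampler)` constructor would encapsulate: record and compile the tape once, then reuse it for every gradient call. A small sketch using plain ReverseDiff, where the log density is just an assumed stand-in for whatever Turing would build from the model and sampler:

```julia
using ReverseDiff

# Assumed stand-in for the model's log density (not Turing's actual internals).
logdensity(θ) = -0.5 * sum(abs2, θ)

θ0 = randn(10)

# Record and compile the gradient tape once for a fixed input shape ...
tape = ReverseDiff.compile(ReverseDiff.GradientTape(logdensity, θ0))

# ... and reuse the compiled tape for every subsequent gradient evaluation.
grad = similar(θ0)
ReverseDiff.gradient!(grad, tape, randn(10))
```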
Since this issue seems to (only?) affect the point at which distributed sampling pays off, it might be good to take the time to think about such design choices before implementing a workaround.