AdvancedVI.jl
Add Tapir to AD tests
Closely related to this, I am wondering whether DifferentiationInterface.jl could feasibly replace the in-house AD interface entirely. My main concerns would be:
- Does it support GPUs out of the box?
- Does it support structured gradients (no flattening)?
Does it support GPUs out of the box
yes
Does it support structured gradients (no flattening)
no and it won’t
Hi @yebai, I tried adding Tapir, but it seems to share some issues with ReverseDiff's compiled tapes. Tapir appears to re-compile the target function at every call to gradients. I think this is due to the fact that we re-define the target function at every step (here), which is also why we currently can't use pre-compiled tapes in ReverseDiff. We currently need this to support the STL estimator, which requires stopping gradients, and not all AD frameworks provide a way to stop gradients. Not sure how to deal with this.
@willtebbutt might be able to help more.
@Red-Portal I've just taken a quick look at your link to where we re-define the function each time. I agree that we'll need to abstract that out and re-use it every time if we want to use Tapir.jl with any success.
Regarding stop gradients -- I think we can probably do this in Tapir.jl, but it would be great if you could open an issue about it so that we can discuss further. I think it's going to involve doing something a little bit strange.
Hi all, I gave this some thought. In the current state of things, it would be possible to avoid re-defining the function. But this will cause issues with subsampling (once we get there in the near future): the Turing model has to be updated at each step. Unless the Turing model recorded on the tapes can be mutated externally (is this possible?), this means we will have to redefine the objective at every step.
Okay the following works:
using DifferentiationInterface, ReverseDiff, LinearAlgebra, Random
struct A
    data
end
function main()
    rng = Random.default_rng()
    a = A(ones(10))
    # The closure captures `a`, so later mutation of `a.data` is visible to `f`
    f(x) = dot(a.data, x)
    println(gradient(f, AutoReverseDiff(true), ones(10)))
    a.data[:] = zeros(10)
    println(gradient(f, AutoReverseDiff(true), ones(10)))
end
Tapir doesn't, though I guess this is due to a different issue. The following code:
using DifferentiationInterface, Tapir, LinearAlgebra, Random
struct A
    data
end
function main()
    rng = Random.default_rng()
    a = A(ones(10))
    # Same closure as before, now differentiated with Tapir
    f(x) = dot(a.data, x)
    println(gradient(f, AutoTapir(), ones(10)))
    a.data[:] = zeros(10)
    println(gradient(f, AutoTapir(), ones(10)))
end
yields:
julia> main()
[ Info: Compiling rule for Tuple{var"#f#9"{A}, Vector{Float64}} in safe mode. Disable for best performance.
ERROR: MethodError: Cannot `convert` an object of type
Core.OpaqueClosure{Tuple{Any},Tuple{Union{Tapir.ZeroRData, Tapir.RData{@NamedTuple{a::Tapir.RData{@NamedTuple{data}}}}},Tapir.NoRData}} to an object of type
Core.OpaqueClosure{Tuple{Any},Tuple{Tapir.RData{@NamedTuple{a::Tapir.RData{@NamedTuple{data}}}},Tapir.NoRData}}
Closest candidates are:
convert(::Type{T}, ::T) where T
@ Base Base.jl:84
Stacktrace:
[1] Tapir.DerivedRule{…}(fwds_oc::Function, pb_oc::Function, isva::Val{…}, nargs::Val{…})
@ Tapir ~/.julia/packages/Tapir/7eB9t/src/interpreter/s2s_reverse_mode_ad.jl:654
[2] build_rrule(interp::Tapir.TapirInterpreter{…}, sig::Type{…}; safety_on::Bool, silence_safety_messages::Bool)
@ Tapir ~/.julia/packages/Tapir/7eB9t/src/interpreter/s2s_reverse_mode_ad.jl:808
[3] build_rrule
@ ~/.julia/packages/Tapir/7eB9t/src/interpreter/s2s_reverse_mode_ad.jl:741 [inlined]
[4] prepare_pullback(f::var"#f#9"{A}, backend::AutoTapir, x::Vector{Float64}, dy::Float64)
@ DifferentiationInterfaceTapirExt ~/.julia/packages/DifferentiationInterface/lN3yP/ext/DifferentiationInterfaceTapirExt/onearg.jl:8
[5] prepare_gradient(f::var"#f#9"{A}, backend::AutoTapir, x::Vector{Float64})
@ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/lN3yP/src/first_order/gradient.jl:59
[6] gradient(f::var"#f#9"{A}, backend::AutoTapir, x::Vector{Float64})
@ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/lN3yP/src/first_order/gradient.jl:74
[7] main()
@ Main ./REPL[23]:9
[8] top-level scope
@ REPL[24]:1
Some type information was truncated. Use `show(err)` to see complete types.
To be honest, I would prefer keeping the forward path as immutable as possible, but this is not going to be possible for subsampling unless we redefine the target function in every step.
@Red-Portal this is a really interesting failure case (I was aware that this could happen, but hadn't figured out a concrete case in which it would). I'm going to create a unit test out of it and fix it in Tapir.
In the meantime, if you make A parametric, I find that it fixes it locally. Something like:
struct A{T}
    data::T
end
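For readers following along, here is a minimal sketch of what changes when the struct is made parametric. The names `Untyped` and `Typed` are hypothetical stand-ins for the two versions of `A`, and the sketch only inspects field types; it makes no claim about what Tapir does internally with that information.

```julia
# Untyped field: Julia treats `data` as `Any`, an abstract type.
struct Untyped
    data
end

# Parametric field: the concrete type of `data` becomes part of the struct's type.
struct Typed{T}
    data::T
end

u = Untyped(ones(10))
t = Typed(ones(10))

# Inspect the declared field type of each instance's type.
println(fieldtype(typeof(u), :data))
println(fieldtype(typeof(t), :data))

# Only the parametric version exposes a concrete field type to the compiler.
println(isconcretetype(fieldtype(typeof(u), :data)))
println(isconcretetype(fieldtype(typeof(t), :data)))
```

With the parametric version, anything that works from the type of the closure or struct alone can see that `data` is a `Vector{Float64}`.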
@willtebbutt Could we just have an interface for differentiating f(x, data) with respect to x while receiving both x and data for each evaluation? It's such a common use case for any data related stuff, and I think it would massively simplify the use of any precompilation-based AD.
So you can definitely do this using Tapir.jl's value_and_gradient!! interface -- the limitation in the use-case above is DifferentiationInterface.jl, rather than Tapir.jl. There's a discussion about this here.
@willtebbutt, it might be easier if you could create a working example so @Red-Portal can adapt it.
Sure. Something like this:
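using Tapir, LinearAlgebra
# Pass the data as an explicit argument instead of capturing it in a closure
f(x, data) = dot(data, x)
x = randn(10)
data = randn(10)
# Build the rule once for these argument types
rule = Tapir.build_rrule(f, x, data)
Tapir.value_and_gradient!!(rule, f, x, data)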
should do the trick.
You should re-use rule each time you run Tapir.value_and_gradient!!.
Note that you don't have to have the same size inputs each time, just the same type -- Tapir.jl is a little less restrictive than ReverseDiff in this regard.
Just to clarify, Tapir is a source-to-source transformation-based AD. It is like Zygote but addresses many of its design limitations. So, Tapir works well with data- and input-dependent control flow. It even works with global variables as long as they don't mutate.
As @willtebbutt mentioned above, the only assumption is that the argument types stay the same as the ones for building a rule. When argument types change, a new rule should be built by calling Tapir.build_rrule.
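The "same types, new rule only when types change" policy described above can be sketched as a small cache keyed on the argument types. This is only an illustration: `build_rule_stub`, `cached_rule`, `signature`, and `RULE_CACHE` are hypothetical names, and the stub stands in for an expensive builder such as `Tapir.build_rrule`, so that the sketch runs without Tapir installed.

```julia
using LinearAlgebra

# Hypothetical stand-in for an expensive rule builder; counts how often it runs.
const BUILD_COUNT = Ref(0)
function build_rule_stub(f, args...)
    BUILD_COUNT[] += 1
    return (:rule_for, typeof(f), map(typeof, args))
end

# Key on the function type and argument types only, never on sizes or values.
signature(f, args...) = Tuple{typeof(f), map(typeof, args)...}

const RULE_CACHE = Dict{Any,Any}()

# Return the cached rule for this signature, building it on first use.
function cached_rule(f, args...)
    return get!(() -> build_rule_stub(f, args...), RULE_CACHE, signature(f, args...))
end

f(x, data) = dot(data, x)

cached_rule(f, randn(10), randn(10))                  # first call: builds a rule
cached_rule(f, randn(3), randn(3))                    # same types, new sizes: reuses it
cached_rule(f, randn(Float32, 3), randn(Float32, 3))  # new element type: builds again

println(BUILD_COUNT[])
```

In a real optimisation loop the cached rule would then be passed to `Tapir.value_and_gradient!!` at every step, with the current minibatch supplied as `data`.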
Okay, this sounds much more promising. I'll come back to this once the rng issue is resolved. In the meantime I'll restructure things such that we don't have to redefine the target function. Thanks for the pointers!
Might as well add Enzyme too while you're at it. I don't foresee any issues with the things mentioned above, and it does have support for rng, abstract types, and non-differentiated data.
@wsmoses Enzyme has been usable for a while since this PR, but just not tested against. The problem is that Enzyme really doesn't play nice with Distributions.jl, so it's pretty much unusable at the moment. Even the most basic models result in immediate segfaults. From my understanding, that is partially why Tapir.jl is being worked on. (Correct? @yebai @willtebbutt )
What about distributions.jl does it not play well with? I'm not aware of any outstanding issues on Enzyme.jl that are related?
But yeah if you have any problems please open issues and we'll work fast to get them resolved!
And while I can't speak to why the folks started work on the Taped AD tool, @yebai et al. have a project funded for the next three years to add Enzyme to Turing, so it is intended to be well supported [with substantial speedups already shown, but more work on the integration end]. https://www.turing.ac.uk/research/research-projects/development-composable-parallelisable-and-user-friendly-inference-and
Again happy to quickly fix any issues that you see :)
I'll take that as a reminder and try it again. Thanks!
Yes, it would be good to test Enzyme, too.
Tapir is a less ambitious project than Enzyme, initially focusing on a rewrite of Zygote. It has strengths and weaknesses. It's entirely written in Julia, with a code base that is much smaller and more hackable, but Tapir's performance will slightly lag behind Enzyme for the foreseeable future.
@Red-Portal your examples (both the abstractly-typed data field and using TaskLocalRNGs) should now both work on version 0.2.23 of Tapir.jl (the latest release).
Hi @willtebbutt. Thanks for the fixes! It seems there are still some issues left, though, as can be seen in #68. I currently don't have enough time to dig deeper into this (distilling the issues into smaller MWEs), but let me know if there is anything I can immediately help fix.
Update: I made a new PR (#71) because the previous one had a messed up git commit history.
Fixed by https://github.com/TuringLang/AdvancedVI.jl/pull/86