AdvancedVI.jl

Add Tapir to AD tests

Open yebai opened this issue 1 year ago • 25 comments

yebai, Jun 06 '24 07:06

Closely related to this, I am wondering whether DifferentiationInterface.jl could entirely replace the in-house AD interface. My main concerns are:

  • Does it support GPUs out of the box
  • Does it support structured gradients (no flattening)

Red-Portal, Jun 06 '24 07:06

Does it support GPUs out of the box

yes

Does it support structured gradients (no flattening)

No, and it won’t.

yebai, Jun 06 '24 07:06

Hi @yebai, I tried adding Tapir, but it seems to share some issues with ReverseDiff's compiled tapes: Tapir appears to re-compile the target function at every call to the gradient. I think this is due to the fact that we re-define the target function at every step (here), which is also why we currently can't use pre-compiled tapes in ReverseDiff. We need this to support the STL estimator, which requires stopping gradients, and not all AD frameworks provide a way to stop gradients. Not sure how to deal with this.

Red-Portal, Jun 13 '24 03:06
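As a rough illustration of the stop-gradient requirement in the comment above, the sketch below avoids a dedicated stop-gradient operator by passing a frozen copy of the variational parameters as a separate, non-differentiated argument. Everything here is a hypothetical assumption for illustration, not AdvancedVI's implementation: a 1-D Gaussian q reparameterized as z = μ + σε, a standard-normal target, and the helper names logq, logp, stl_integrand and fd_gradient. Central finite differences stand in for an AD backend.

```julia
# Hypothetical 1-D setup (assumption): q = Normal(μ, σ), reparameterized as
# z = μ + σ * ε with ε ~ Normal(0, 1), and a standard-normal target density.
logq(z, μ, σ) = -log(σ) - (z - μ)^2 / (2 * σ^2) - log(2π) / 2
logp(z) = -z^2 / 2 - log(2π) / 2

# STL-style integrand. λ = [μ, σ] is differentiated; λ_stop is a frozen copy
# of the same values passed as plain data, so the parameters inside logq get
# no gradient (the effect a stop-gradient operation would have).
function stl_integrand(λ, λ_stop, ε)
    μ, σ = λ
    z = μ + σ * ε
    return logp(z) - logq(z, λ_stop[1], λ_stop[2])
end

# Central finite differences w.r.t. the first argument only. This is a
# stand-in for an AD backend, which would likewise treat args... as constants.
function fd_gradient(f, x, args...; h = 1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp, args...) - f(xm, args...)) / (2h)
    end
    return g
end
```

Because λ_stop enters only as data, the gradient keeps the path through z but drops the direct dependence of log q on its parameters, which is the STL idea. With a real AD backend, the same structure means the objective can be defined once instead of re-creating a closure at every step.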

@willtebbutt might be able to help more.

yebai, Jun 13 '24 15:06

@Red-Portal I've just taken a quick look at your link to where we re-define the function each time. I agree that we'll need to abstract that out and re-use it every time if we want to use Tapir.jl with any success.

Regarding stop gradients -- I think we can probably do this in Tapir.jl, but it would be great if you could open an issue about it so that we can discuss further. I think it's going to involve doing something a little bit strange.

willtebbutt, Jun 13 '24 16:06

Hi all, I gave this some thought. In the current state of things, it would be possible to avoid re-defining the function. However, this will cause issues with subsampling (once we get there in the near future), since the Turing model has to be updated at each step. Unless the Turing model recorded on the tapes can be mutated externally (is this possible?), we will have to redefine the objective at every step.

Red-Portal, Jun 13 '24 21:06
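For the subsampling case, one option (a hedged sketch, not AdvancedVI code) is to define the objective once and pass the current minibatch as an explicit, non-differentiated argument, so nothing has to be redefined or mutated between steps. The least-squares loss, the names loss, grad_loss and sgd, and the analytic gradient standing in for an AD call are all hypothetical.

```julia
using Random

# Hypothetical least-squares objective, defined once at top level. batch is
# plain data: a new minibatch is passed in at every step, but it is never
# differentiated and the function itself is never redefined.
loss(x, batch) = sum(abs2, batch.A * x - batch.y) / (2 * length(batch.y))

# Analytic gradient of loss w.r.t. x. It stands in for an AD call that
# accepts extra non-differentiated arguments (e.g. a reused prepared rule).
grad_loss(x, batch) = batch.A' * (batch.A * x - batch.y) / length(batch.y)

function sgd(x, A, y; batchsize = 4, steps = 500, lr = 0.1,
             rng = Random.default_rng())
    n = size(A, 1)
    for _ in 1:steps
        idx = rand(rng, 1:n, batchsize)
        # The contents change every step, but the types do not.
        batch = (A = A[idx, :], y = y[idx])
        x = x - lr * grad_loss(x, batch)
    end
    return x
end
```

Each step builds a fresh NamedTuple batch, but its type stays the same, which matters for AD backends that prepare work once per argument type.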

Okay, the following works:

```julia
using DifferentiationInterface, ReverseDiff, LinearAlgebra, Random

struct A
    data
end

function main()
    rng = Random.default_rng()
    a   = A(ones(10))
    f(x) = dot(a.data, x)
    println(gradient(f, AutoReverseDiff(; compile=true), ones(10)))
    a.data[:] = zeros(10)  # mutate the captured data in place
    println(gradient(f, AutoReverseDiff(; compile=true), ones(10)))
end
```

Tapir doesn't, though I guess this is due to a different issue? The following code:

```julia
using DifferentiationInterface, Tapir, LinearAlgebra, Random

struct A
    data
end

function main()
    rng = Random.default_rng()
    a   = A(ones(10))
    f(x) = dot(a.data, x)
    println(gradient(f, AutoTapir(), ones(10)))
    a.data[:] = zeros(10)  # mutate the captured data in place
    println(gradient(f, AutoTapir(), ones(10)))
end
```

yields:

```
julia> main()
[ Info: Compiling rule for Tuple{var"#f#9"{A}, Vector{Float64}} in safe mode. Disable for best performance.
ERROR: MethodError: Cannot `convert` an object of type 
  Core.OpaqueClosure{Tuple{Any},Tuple{Union{Tapir.ZeroRData, Tapir.RData{@NamedTuple{a::Tapir.RData{@NamedTuple{data}}}}},Tapir.NoRData}} to an object of type 
  Core.OpaqueClosure{Tuple{Any},Tuple{Tapir.RData{@NamedTuple{a::Tapir.RData{@NamedTuple{data}}}},Tapir.NoRData}}

Closest candidates are:
  convert(::Type{T}, ::T) where T
   @ Base Base.jl:84

Stacktrace:
 [1] Tapir.DerivedRule{…}(fwds_oc::Function, pb_oc::Function, isva::Val{…}, nargs::Val{…})
   @ Tapir ~/.julia/packages/Tapir/7eB9t/src/interpreter/s2s_reverse_mode_ad.jl:654
 [2] build_rrule(interp::Tapir.TapirInterpreter{…}, sig::Type{…}; safety_on::Bool, silence_safety_messages::Bool)
   @ Tapir ~/.julia/packages/Tapir/7eB9t/src/interpreter/s2s_reverse_mode_ad.jl:808
 [3] build_rrule
   @ ~/.julia/packages/Tapir/7eB9t/src/interpreter/s2s_reverse_mode_ad.jl:741 [inlined]
 [4] prepare_pullback(f::var"#f#9"{A}, backend::AutoTapir, x::Vector{Float64}, dy::Float64)
   @ DifferentiationInterfaceTapirExt ~/.julia/packages/DifferentiationInterface/lN3yP/ext/DifferentiationInterfaceTapirExt/onearg.jl:8
 [5] prepare_gradient(f::var"#f#9"{A}, backend::AutoTapir, x::Vector{Float64})
   @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/lN3yP/src/first_order/gradient.jl:59
 [6] gradient(f::var"#f#9"{A}, backend::AutoTapir, x::Vector{Float64})
   @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/lN3yP/src/first_order/gradient.jl:74
 [7] main()
   @ Main ./REPL[23]:9
 [8] top-level scope
   @ REPL[24]:1
Some type information was truncated. Use `show(err)` to see complete types.
```

Red-Portal, Jun 13 '24 21:06

To be honest, I would prefer to keep the forward path as immutable as possible, but that won't be possible for subsampling unless we redefine the target function at every step.

Red-Portal, Jun 13 '24 22:06

@Red-Portal this is a really interesting failure case (I was aware that this could happen, but hadn't figured out a concrete case in which it would). I'm going to create a unit test out of it and fix it in Tapir.

In the meantime, I find that making A parametric fixes it locally. Something like:

```julia
struct A{T}
    data::T
end
```

willtebbutt, Jun 14 '24 09:06
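A plausible reason the parametric version helps: the error output shows Tapir.RData{@NamedTuple{data}}, i.e. a field with no declared type (Any). The sketch below, using the hypothetical struct names AUntyped and ATyped, shows how a type parameter makes the field's concrete type part of the struct's own type, which type inference (and presumably Tapir's rule construction) can then rely on.

```julia
# Untyped field: its declared type is Any, so code that reads a.data cannot
# infer a concrete type from the struct's type alone.
struct AUntyped
    data
end

# Parametric field: the concrete field type is part of the struct's type.
struct ATyped{T}
    data::T
end

a_untyped = AUntyped(ones(10))
a_typed   = ATyped(ones(10))

println(fieldtype(typeof(a_untyped), :data))   # Any
println(fieldtype(typeof(a_typed), :data))     # Vector{Float64}
```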

@willtebbutt Could we just have an interface for differentiating f(x, data) with respect to x, while receiving both x and data at each evaluation? It's such a common use case for anything data-related, and I think it would massively simplify the use of any precompilation-based AD.

Red-Portal, Jun 14 '24 09:06

So you can definitely do this using Tapir.jl's value_and_gradient!! interface -- the limitation in the use-case above is DifferentiationInterface.jl, rather than Tapir.jl. There's a discussion about this here.

willtebbutt, Jun 14 '24 10:06

@willtebbutt, it might be easier if you could create a working example so @Red-Portal can adapt it.

yebai, Jun 14 '24 10:06

Sure. Something like this:

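```julia
using Tapir, LinearAlgebra

f(x, data) = dot(data, x)
x = randn(10)
data = randn(10)

# Build the rule once, then re-use it for every call with the same argument types.
rule = Tapir.build_rrule(f, x, data)
Tapir.value_and_gradient!!(rule, f, x, data)
```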

should do the trick.

You should re-use rule each time you run Tapir.value_and_gradient!!.

Note that you don't have to have the same size inputs each time, just the same type -- Tapir.jl is a little less restrictive than ReverseDiff in this regard.

willtebbutt, Jun 14 '24 10:06

Just to clarify, Tapir is a source-to-source transformation-based AD. It is similar to Zygote but addresses many of its design limitations. As a result, Tapir works well with data- and input-dependent control flow. It even works with global variables, as long as they are not mutated.

As @willtebbutt mentioned above, the only assumption is that the argument types stay the same as those used to build the rule. When the argument types change, a new rule should be built by calling Tapir.build_rrule.

yebai, Jun 14 '24 11:06
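Following the advice to re-use a rule as long as the argument types stay the same (sizes may change), a small type-keyed cache is one way to organize this. RuleCache and getrule! are hypothetical helpers, not part of Tapir.jl; build is whatever prepares a rule from example arguments, for example a wrapper around Tapir.build_rrule as in the snippet above.

```julia
# Hypothetical helper (not part of Tapir.jl): caches one prepared "rule" per
# tuple of argument types. build prepares a rule from example arguments;
# with Tapir one might pass (args...) -> Tapir.build_rrule(args...).
struct RuleCache{B}
    build::B
    rules::Dict{Any,Any}
end

RuleCache(build) = RuleCache(build, Dict{Any,Any}())

function getrule!(cache::RuleCache, args...)
    # Key on types only: new values or array sizes reuse the cached rule,
    # while a change of type (e.g. Float64 -> Float32) triggers a rebuild.
    key = map(typeof, args)
    return get!(() -> cache.build(args...), cache.rules, key)
end
```

Assuming the Tapir calls from the snippet above, a gradient step would then look like Tapir.value_and_gradient!!(getrule!(cache, f, x, data), f, x, data).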

Okay, this sounds much more promising. I'll come back to this once the rng issue is resolved. In the meantime I'll restructure things such that we don't have to redefine the target function. Thanks for the pointers!

Red-Portal, Jun 14 '24 23:06

Might as well add Enzyme too while you're at it. I don't foresee any issues with the things mentioned above; it does support RNGs, abstract types, and non-differentiated data.

wsmoses, Jun 18 '24 20:06

@wsmoses Enzyme has been usable since this PR, but we just haven't tested against it. The problem is that Enzyme really doesn't play nicely with Distributions.jl, so it's pretty much unusable at the moment: even the most basic models result in immediate segfaults. My understanding is that this is partly why Tapir.jl is being worked on (correct, @yebai @willtebbutt?).

Red-Portal, Jun 18 '24 20:06

What about Distributions.jl does it not play well with? I'm not aware of any related outstanding issues on Enzyme.jl.

wsmoses, Jun 18 '24 21:06

But yeah if you have any problems please open issues and we'll work fast to get them resolved!

wsmoses, Jun 18 '24 21:06

And while I can't speak to why the folks started work on the taped AD tool, @yebai et al. have a project funded for the next three years to add Enzyme to Turing, so it is intended to be well supported (substantial speedups have already been shown, but more work is needed on the integration end). https://www.turing.ac.uk/research/research-projects/development-composable-parallelisable-and-user-friendly-inference-and

Again happy to quickly fix any issues that you see :)

wsmoses, Jun 18 '24 21:06

I'll take that as a reminder and try it again. Thanks!

Red-Portal, Jun 18 '24 21:06

Yes, it would be good to test Enzyme, too.

Tapir is a less ambitious project than Enzyme, initially focusing on a rewrite of Zygote. It has strengths and weaknesses: it's written entirely in Julia, with a code base that is much smaller and more hackable, but its performance will lag slightly behind Enzyme's for the foreseeable future.

yebai, Jun 19 '24 20:06

@Red-Portal your examples (the abstractly-typed data field and the use of TaskLocalRNGs) should now both work on version 0.2.23 of Tapir.jl (the latest release).

willtebbutt, Jun 24 '24 15:06

Hi @willtebbutt, thanks for the fixes! However, it seems there are still some issues left, as can be seen in #68. I currently don't have enough time to dig deeper into this (i.e., distill the issues into smaller MWEs), but let me know if there is anything I can help fix right away.

Red-Portal, Jun 24 '24 22:06

Update: I made a new PR (#71) because the previous one had a messed-up git commit history.

Red-Portal, Jun 25 '24 20:06

Fixed by https://github.com/TuringLang/AdvancedVI.jl/pull/86

yebai, Sep 03 '24 07:09