Turing.jl
Test with Tapir
Closes #2247
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 86.86%. Comparing base (a26ce11) to head (8325529). Report is 1 commit behind head on master.
Additional details and impacted files
@@ Coverage Diff @@
## master #2289 +/- ##
==========================================
+ Coverage 86.79% 86.86% +0.07%
==========================================
Files 24 24
Lines 1598 1599 +1
==========================================
+ Hits 1387 1389 +2
+ Misses 211 210 -1
:umbrella: View full report in Codecov by Sentry.
@willtebbutt, could you please take a look at the logs of the failing CI runs and comment on whether these are likely to be Tapir issues, or something we need to address? In one case, zero_tangent goes into infinite recursion; in another, we expect LogDensityProblemsAD.ADGradient to return a LogDensityProblemsAD.ADGradientWrapper, but we seem to get a Tapir.CoDual instead.
Pull Request Test Coverage Report for Build 10679664210
Details
- 5 of 5 (100.0%) changed or added relevant lines in 1 file are covered.
- No unchanged relevant lines lost coverage.
- Overall coverage increased (+0.07%) to 87.085%
| Totals | |
|---|---|
| Change from base Build 10634847007: | 0.07% |
| Covered Lines: | 1389 |
| Relevant Lines: | 1595 |
💛 - Coveralls
> In one case, zero_tangent goes into infinite recursion,
@willtebbutt can you help take a look at the zero_tangent recursion issue and fix that?
Assuming that we're not considering the ℓ field to be part of the public interface, this looks like an issue in Turing.jl -- it assumes here that the ℓ field of the ADGradientWrapper is what is needed, rather than calling Base.parent.
(this is in addition to the recursion issue -- taking a look at that now)
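For anyone following along, the difference between the two access patterns is roughly this (a sketch with a hypothetical helper name; the only API assumption is the `Base.parent` accessor for wrapped gradients mentioned above):

```julia
using LogDensityProblemsAD

# Hypothetical helper illustrating the suggested fix: recover the wrapped
# log-density from an ADGradientWrapper via the public `Base.parent`
# accessor instead of reaching into the private `ℓ` field.
function unwrap_logdensity(∇ℓ)
    # Fragile: `∇ℓ.ℓ` depends on a private field name that may change.
    # Preferred: the accessor provided by LogDensityProblemsAD.
    return parent(∇ℓ)
end
```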
> Assuming that we're not considering the ℓ field to be part of the public interface, this looks like an issue in Turing.jl -- it assumes here
@sunxd3 can this be replaced with the new setmodel interface introduced in https://github.com/TuringLang/DynamicPPL.jl/pull/626?
Aye, the setmodel methods should be useful there :)
Sorry for letting this go under my radar.
From the look of it, I don't think setmodel is directly useful, but we can modify setvarinfo to mirror the implementation of setmodel.
(model and varinfo are two fields of LogDensityFunction, so it's appropriate to have a setvarinfo that matches it)
Happy to make this modification; should I commit directly to this PR, @yebai? It should end up in DynamicPPL, but it may not be a good idea to experiment here.
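For concreteness, the modification being proposed might look something like this (an illustrative sketch only; it assumes LogDensityFunction keeps its `model` and `varinfo` fields and that setmodel is implemented with an immutable field update, e.g. via Accessors.jl):

```julia
using Accessors  # provides @set
import DynamicPPL

# Sketch: immutably replace the varinfo field of a LogDensityFunction,
# mirroring what setmodel does for the model field.
function setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo)
    return Accessors.@set f.varinfo = varinfo
end
```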
@willtebbutt @mhauru I think the setvarinfo related errors are gone (at least from the look of it)
Thanks @sunxd3.
In addition to the infinite recursion one, there's also this:
MethodError: no method matching translate(::Val{Core.Intrinsics.atomic_pointerset})
Closest candidates are:
translate(::Val{Core.Intrinsics.copysign_float})
@ Tapir ~/.julia/packages/Tapir/vvmqf/src/rrules/builtins.jl:55
translate(::Val{Core.Intrinsics.slt_int})
@ Tapir ~/.julia/packages/Tapir/vvmqf/src/rrules/builtins.jl:64
translate(::Val{Core.Intrinsics.sdiv_int})
@ Tapir ~/.julia/packages/Tapir/vvmqf/src/rrules/builtins.jl:64
...
Stacktrace:
[1] lift_intrinsic(::Core.IntrinsicFunction, ::Core.SSAValue, ::GlobalRef, ::Vararg{Any})
@ Tapir ~/.julia/packages/Tapir/vvmqf/src/interpreter/ir_normalisation.jl:147
[2] intrinsic_to_function
@ ~/.julia/packages/Tapir/vvmqf/src/interpreter/ir_normalisation.jl:136 [inlined]
[3] normalise!(ir::Core.Compiler.IRCode, spnames::Vector{Symbol})
@ Tapir ~/.julia/packages/Tapir/vvmqf/src/interpreter/ir_normalisation.jl:27
[4] build_rrule(interp::Tapir.TapirInterpreter{Tapir.DefaultCtx}, sig_or_mi::Core.MethodInstance; safety_on::Bool, silence_safety_messages::Bool)
@ Tapir ~/.julia/packages/Tapir/vvmqf/src/interpreter/s2s_reverse_mode_ad.jl:769
[5] build_rrule
@ ~/.julia/packages/Tapir/vvmqf/src/interpreter/s2s_reverse_mode_ad.jl:747 [inlined]
[6] (::Tapir.LazyDerivedRule{Tapir.TapirInterpreter{Tapir.DefaultCtx}, Tapir.DerivedRule{MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Tapir.CoDual{typeof(Base._unsafe_copyto!), Tapir.NoFData}, Tapir.CoDual{Matrix{Real}, Matrix{Any}}, Tapir.CoDual{Int64, Tapir.NoFData}, Tapir.CoDual{Matrix{Float64}, Matrix{Float64}}, Tapir.CoDual{Int64, Tapir.NoFData}, Tapir.CoDual{Int64, Tapir.NoFData}}, Tapir.CoDual{Matrix{Real}, Matrix{Any}}}}, MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Tapir.NoRData}, NTuple{6, Tapir.NoRData}}}, Val{false}, Val{6}}})(::Tapir.CoDual{typeof(Base._unsafe_copyto!), Tapir.NoFData}, ::Tapir.CoDual{Matrix{Real}, Matrix{Any}}, ::Tapir.CoDual{Int64, Tapir.NoFData}, ::Tapir.CoDual{Matrix{Float64}, Matrix{Float64}}, ::Tapir.CoDual{Int64, Tapir.NoFData}, ::Tapir.CoDual{Int64, Tapir.NoFData})
@ Tapir ~/.julia/packages/Tapir/vvmqf/src/interpreter/s2s_reverse_mode_ad.jl:1252
[7] RRuleZeroWrapper
@ ~/.julia/packages/Tapir/vvmqf/src/interpreter/s2s_reverse_mode_ad.jl:244 [inlined]
which you can see in the abstractmcmc.jl CI run. I'm guessing this is a Tapir bug, but @willtebbutt if you disagree let me know.
Yes, this is definitely a Tapir.jl limitation -- we're missing a rule for Core.Intrinsics.atomic_pointerset. I need to open an issue about this, and improve some of the error messages in Tapir.jl to make it clear what's going on here.
@yebai @mhauru is it obvious to you whether Tapir.jl will re-derive its rule with a fresh Tapir.TapirContext() at each iteration of e.g. Gibbs(HMC(0.2, 3, :p; adtype=adbackend), PG(10, :x))? I'm seeing a lot of the same rule being derived repeatedly in some of the test/mcmc/Inference.jl tests, which makes me think this might be going on.
Also, @mhauru I've narrowed down the source of the error above involving Core.Intrinsics.atomic_pointerset to the really quite nasty looking Base._unsetindex!. Going to try and get a fix out for this today.
> is it obvious to you whether Tapir.jl will re-derive its rule with a fresh Tapir.TapirContext() at each iteration of e.g. Gibbs(HMC(0.2, 3, :p; adtype=adbackend), PG(10, :x))?
Yes, that could happen. @sunxd3 probably knows more.
I am not certain. @willtebbutt, will ADgradient(AutoTapir(), ::DynamicPPL.LogDensityFunction) trigger a rule re-derivation?
Yup, that'll do it.
I am not 100% certain about this, but for the old Gibbs, maybe
https://github.com/TuringLang/Turing.jl/blob/803d2f5672b7483c768b9987bcec0dc20257ebda/src/mcmc/gibbs.jl#L241C27-L241C87
will lead to a call to ADgradient? @torfjelde
Note, if this is an issue for Tapir.jl, is it also potentially a problem for compiled ReverseDiff.jl?
It is a problem if setmodel is called every Gibbs step. This is conservative for correctness's sake, but it would be better to be able to reuse the derived rule (for Tapir) and tape (for ReverseDiff).
v0.2.31 of Tapir.jl should fix the problem around Core.Intrinsics.atomic_pointerset.
> This is conservative for correctness's sake, but it would be better to be able to reuse the derived rule (for Tapir) and tape (for ReverseDiff).
@sunxd3 we don't need to re-derive the rule for Tapir since it supports dynamic control flow; can you help remove that? For ReverseDiff, it is indeed safer to avoid using a cached tape.
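One possible shape for that change (an illustrative sketch, not the actual Turing.jl code): keep the existing gradient object whenever the underlying log-density is unchanged, and reconstruct it only when the model is actually replaced.

```julia
using LogDensityProblemsAD

# Sketch: reuse the derived rule (Tapir) or tape (ReverseDiff) when the
# wrapped log-density is unchanged; rebuild only when it differs.
function maybe_rebuild_gradient(∇ℓ_old, adtype, ℓ_new)
    if parent(∇ℓ_old) === ℓ_new
        return ∇ℓ_old                      # reuse: nothing to re-derive
    else
        return ADgradient(adtype, ℓ_new)   # model changed: rebuild
    end
end
```

Whether reuse is safe for ReverseDiff with a compiled tape depends on the model's control flow, as discussed above.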
Yeah, Will and I just talked about this, and I am going to address it. Although it might be annoying, because we can't dispatch on individual ADGradientWrapper subtypes (ref https://github.com/tpapp/LogDensityProblemsAD.jl/pull/33). But maybe there is a way around it.
> I am not 100% certain about this, but for the old Gibbs, maybe https://github.com/TuringLang/Turing.jl/blob/803d2f5672b7483c768b9987bcec0dc20257ebda/src/mcmc/gibbs.jl#L241C27-L241C87 will lead to a call to ADgradient?
I thought the old Gibbs incorrectly doesn't call setmodel, which will lead to stale gradient tapes for compiled backends, e.g. ReverseDiff with a compiled tape.
E.g.
https://github.com/TuringLang/Turing.jl/blob/803d2f5672b7483c768b9987bcec0dc20257ebda/src/mcmc/hmc.jl#L247-L261
re-uses the Hamiltonian from the state, which means that it's re-using the initially constructed log-density, unless something occurs in gibbs_state here:
https://github.com/TuringLang/Turing.jl/blob/803d2f5672b7483c768b9987bcec0dc20257ebda/src/mcmc/gibbs.jl#L235-L236
Looking at gibbs_state:
https://github.com/TuringLang/Turing.jl/blob/803d2f5672b7483c768b9987bcec0dc20257ebda/src/mcmc/gibbs.jl#L111-L114
it doesn't seem to do anything, i.e. we're still using the gradient tape from the initial step.
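To make the hazard concrete: a compiled ReverseDiff tape freezes the values and branches observed when it was built, so reusing it after the conditioned model changes silently yields the old gradient. A minimal self-contained sketch (hypothetical example, not Turing code):

```julia
using ReverseDiff

# The "model" depends on a parameter c that is NOT a tape input.
f(x, c) = c * sum(abs2, x)

# Build and compile a tape with c = 2.0 baked in.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(x -> f(x, 2.0), rand(3)))

x = [1.0, 2.0, 3.0]
g = similar(x)
ReverseDiff.gradient!(g, tape, x)  # gradient of 2.0 * sum(abs2, x), i.e. 4 .* x

# If the model later changes to c = 3.0 but the old tape is reused, the
# returned gradient is still the one for c = 2.0: the stale-tape bug.
```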
This was fixed in Turing.Experimental.Gibbs, but I guess it has been a dormant issue in Turing.Inference.Gibbs for a long time now 😕
> It is a problem if setmodel is called every Gibbs
I don't think it is? (and this leads to incorrect behavior for compiled ReverseDiff)
> Although it might be annoying because we can't dispatch on individual ADGradientWrapper (ref https://github.com/tpapp/LogDensityProblemsAD.jl/pull/33). But maybe there is a way around it.
Can we not use the same approach as we used for Turing.Experimental.Gibbs? Pass the ADType to setmodel (and start using setmodel in Turing.Inference.Gibbs too)?
> Can we not use the same approach as we used for Turing.Experimental.Gibbs? Pass the ADType to setmodel (and start using setmodel in Turing.Inference.Gibbs too)?
good thought, I think we can
@willtebbutt Apart from the zero_tangent circular-reference problem, there is only one remaining failure. Is this a known issue?
Got exception outside of a @test
UndefVarError: `P` not defined
Stacktrace:
[1] fcodual_type(::Type{Type{AbstractArray{var"#s128", 1}}})
@ Tapir ~/.julia/packages/Tapir/J4f0f/src/codual.jl:89
[2] make_ad_stmts!(stmt::Expr, line::Tapir.ID, info::Tapir.ADInfo)
@ Tapir ~/.julia/packages/Tapir/J4f0f/src/interpreter/s2s_reverse_mode_ad.jl:526
I'm aware that this can happen (I've seen this kind of thing sporadically over the last year), but I don't properly understand why it happens. Re-running CI in order to get more info on the problem from the latest version of Tapir.jl (I've improved the stack traces). Will hunt it down today.
From a quick look, I think the Tapir issue might be the same as e.g.
julia> f(::Type{Type{T}}) where {T} = Type{T}
f (generic function with 1 method)
julia> P = Type{Tuple{T}} where T
Type{Tuple{T}} where T
julia> f(P)
ERROR: UndefVarError: `T` not defined
Stacktrace:
[1] f(::Type{Type{Tuple{T}} where T})
@ Main ./REPL[1]:1
[2] top-level scope
@ REPL[8]:1
Ahhh, yes, I believe that's exactly what's going on. I can build new unit tests for Tapir.jl from this and extend whatever functions aren't working properly. Thanks for the pointer, @devmotion.
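As a note for anyone hitting the same thing, one common workaround is to avoid forcing Julia to bind the inner type variable at dispatch time, e.g. by accepting a plain ::Type and unwrapping any UnionAll by hand (a sketch, not necessarily how Tapir.jl will fix it):

```julia
# Dispatching on ::Type{Type{T}} where {T} fails for a UnionAll argument
# such as Type{Tuple{T}} where T, because T cannot be bound. Accepting a
# plain ::Type and unwrapping manually sidesteps the UndefVarError.
function strip_unionall(P::Type)
    return P isa UnionAll ? Base.unwrap_unionall(P) : P
end

strip_unionall(Type{Tuple{T}} where T)  # works; no UndefVarError
```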
A version of Tapir.jl which addresses this particular issue should be available in an hour or so.