
Test with Tapir

Open · mhauru opened this issue 1 year ago • 35 comments

Closes #2247

mhauru avatar Jul 17 '24 13:07 mhauru

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 86.86%. Comparing base (a26ce11) to head (8325529). Report is 1 commit behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2289      +/-   ##
==========================================
+ Coverage   86.79%   86.86%   +0.07%     
==========================================
  Files          24       24              
  Lines        1598     1599       +1     
==========================================
+ Hits         1387     1389       +2     
+ Misses        211      210       -1     

:umbrella: View full report in Codecov by Sentry.

codecov[bot] avatar Jul 17 '24 15:07 codecov[bot]

@willtebbutt, could you please take a look at the logs of the failing CI runs and comment on whether these are likely to be Tapir issues, or something we need to address? In one case zero_tangent goes into infinite recursion, in another we expect LogDensityProblemsAD.ADGradient to return a LogDensityProblemsAD.ADGradientWrapper but we seem to get a Tapir.CoDual instead.
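For context, a minimal, purely illustrative Julia sketch of the kind of circular data structure that can send a naive recursive zero-tangent construction into non-termination (this is not Tapir.jl's actual code):

```julia
# A self-referential mutable struct: node.self === node.
mutable struct Node
    self::Node
    Node() = (n = new(); n.self = n; n)
end

# A naive structural recursion over fields never terminates here:
# a zero_tangent(n) that recurses into n.self (which is n itself) loops
# forever unless already-visited objects are tracked, e.g. via an IdDict.
```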

mhauru avatar Jul 17 '24 16:07 mhauru

Pull Request Test Coverage Report for Build 10679664210

Details

  • 5 of 5 (100.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.07%) to 87.085%

Totals Coverage Status
Change from base Build 10634847007: 0.07%
Covered Lines: 1389
Relevant Lines: 1595

💛 - Coveralls

coveralls avatar Jul 17 '24 21:07 coveralls

In one case zero_tangent goes into infinite recursion,

@willtebbutt can you help take a look at the zero_tangent recursion issue and fix that?

yebai avatar Jul 23 '24 12:07 yebai

Assuming that we're not considering the ℓ field part of the public interface, this looks like an issue in Turing.jl -- it assumes here that the ℓ field of the ADGradientWrapper is what is needed, rather than calling Base.parent.

(this is in addition to the recursion issue -- taking a look at that now)
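A minimal sketch of the distinction, with hypothetical variable names (assuming the ℓ field and Base.parent accessor described above):

```julia
# `ad_ℓ` is an ADGradientWrapper produced by LogDensityProblemsAD.ADgradient.
inner = ad_ℓ.ℓ             # fragile: reaches into a private field
inner = Base.parent(ad_ℓ)  # robust: goes through the public accessor
```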

willtebbutt avatar Jul 23 '24 13:07 willtebbutt

Assuming that we're not considering the ℓ field part of the public interface, this looks like an issue in Turing.jl -- it assumes here

@sunxd3 can this be replaced with the new setmodel interface introduced in https://github.com/TuringLang/DynamicPPL.jl/pull/626?

yebai avatar Jul 23 '24 15:07 yebai

Aye, the setmodel methods should be useful there :)

torfjelde avatar Jul 25 '24 20:07 torfjelde

Sorry for letting this go under my radar.

From the look of it, I don't think setmodel is directly useful, but we can modify setvarinfo to mirror the implementation of setmodel.

(model and varinfo are two fields of LogDensityFunction, so it's appropriate to have a setvarinfo that matches it)

Happy to make this modification; should I commit directly to this PR, @yebai? It should end up in DynamicPPL, but it may not be a good idea to experiment here.
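A rough sketch of what such a setvarinfo might look like (the constructor signature and field names of LogDensityFunction are assumptions for illustration; the real change belongs in DynamicPPL):

```julia
# Hypothetical: rebuild the LogDensityFunction with a new varinfo, mirroring
# how `setmodel` swaps out the model field. Field order is an assumption.
function setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo::DynamicPPL.AbstractVarInfo)
    return DynamicPPL.LogDensityFunction(varinfo, f.model, f.context)
end
```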

sunxd3 avatar Jul 31 '24 09:07 sunxd3

@willtebbutt @mhauru I think the setvarinfo-related errors are gone (at least from the look of it)

sunxd3 avatar Jul 31 '24 16:07 sunxd3

Thanks @sunxd3.

In addition to the infinite recursion one, there's also this:

MethodError: no method matching translate(::Val{Core.Intrinsics.atomic_pointerset})
  
  Closest candidates are:
    translate(::Val{Core.Intrinsics.copysign_float})
     @ Tapir ~/.julia/packages/Tapir/vvmqf/src/rrules/builtins.jl:55
    translate(::Val{Core.Intrinsics.slt_int})
     @ Tapir ~/.julia/packages/Tapir/vvmqf/src/rrules/builtins.jl:64
    translate(::Val{Core.Intrinsics.sdiv_int})
     @ Tapir ~/.julia/packages/Tapir/vvmqf/src/rrules/builtins.jl:64
    ...
  
  Stacktrace:
    [1] lift_intrinsic(::Core.IntrinsicFunction, ::Core.SSAValue, ::GlobalRef, ::Vararg{Any})
      @ Tapir ~/.julia/packages/Tapir/vvmqf/src/interpreter/ir_normalisation.jl:147
    [2] intrinsic_to_function
      @ ~/.julia/packages/Tapir/vvmqf/src/interpreter/ir_normalisation.jl:136 [inlined]
    [3] normalise!(ir::Core.Compiler.IRCode, spnames::Vector{Symbol})
      @ Tapir ~/.julia/packages/Tapir/vvmqf/src/interpreter/ir_normalisation.jl:27
    [4] build_rrule(interp::Tapir.TapirInterpreter{Tapir.DefaultCtx}, sig_or_mi::Core.MethodInstance; safety_on::Bool, silence_safety_messages::Bool)
      @ Tapir ~/.julia/packages/Tapir/vvmqf/src/interpreter/s2s_reverse_mode_ad.jl:769
    [5] build_rrule
      @ ~/.julia/packages/Tapir/vvmqf/src/interpreter/s2s_reverse_mode_ad.jl:747 [inlined]
    [6] (::Tapir.LazyDerivedRule{Tapir.TapirInterpreter{Tapir.DefaultCtx}, Tapir.DerivedRule{MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Tapir.CoDual{typeof(Base._unsafe_copyto!), Tapir.NoFData}, Tapir.CoDual{Matrix{Real}, Matrix{Any}}, Tapir.CoDual{Int64, Tapir.NoFData}, Tapir.CoDual{Matrix{Float64}, Matrix{Float64}}, Tapir.CoDual{Int64, Tapir.NoFData}, Tapir.CoDual{Int64, Tapir.NoFData}}, Tapir.CoDual{Matrix{Real}, Matrix{Any}}}}, MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Tapir.NoRData}, NTuple{6, Tapir.NoRData}}}, Val{false}, Val{6}}})(::Tapir.CoDual{typeof(Base._unsafe_copyto!), Tapir.NoFData}, ::Tapir.CoDual{Matrix{Real}, Matrix{Any}}, ::Tapir.CoDual{Int64, Tapir.NoFData}, ::Tapir.CoDual{Matrix{Float64}, Matrix{Float64}}, ::Tapir.CoDual{Int64, Tapir.NoFData}, ::Tapir.CoDual{Int64, Tapir.NoFData})
      @ Tapir ~/.julia/packages/Tapir/vvmqf/src/interpreter/s2s_reverse_mode_ad.jl:1252
    [7] RRuleZeroWrapper
      @ ~/.julia/packages/Tapir/vvmqf/src/interpreter/s2s_reverse_mode_ad.jl:244 [inlined]

which you can see in the abstractmcmc.jl CI run. I'm guessing this is a Tapir bug, but @willtebbutt if you disagree let me know.

mhauru avatar Aug 07 '24 13:08 mhauru

Yes, this is definitely a Tapir.jl limitation -- we're missing a rule for Core.Intrinsics.atomic_pointerset. I need to open an issue about this, and improve some of the error messages in Tapir.jl to make it clear what's going on here.

willtebbutt avatar Aug 07 '24 13:08 willtebbutt

@yebai @mhauru is it obvious to you whether Tapir.jl will re-derive its rule with a fresh Tapir.TapirContext() at each iteration of e.g. Gibbs(HMC(0.2, 3, :p; adtype=adbackend), PG(10, :x))? I'm seeing a lot of the same rule being derived repeatedly in some of the test/mcmc/Inference.jl tests, which makes me think this might be going on.

willtebbutt avatar Aug 07 '24 14:08 willtebbutt

Also, @mhauru, I've narrowed down the source of the error above involving Core.Intrinsics.atomic_pointerset to the really quite nasty-looking Base._unsetindex!. Going to try to get a fix out for this today.

willtebbutt avatar Aug 07 '24 14:08 willtebbutt

is it obvious to you whether Tapir.jl will re-derive its rule with a fresh Tapir.TapirContext() at each iteration of e.g. Gibbs(HMC(0.2, 3, :p; adtype=adbackend), PG(10, :x))?

Yes, that could happen. @sunxd3 probably knows more.

yebai avatar Aug 07 '24 14:08 yebai

I am not certain. @willtebbutt, will ADgradient(AutoTapir(), ::DynamicPPL.LogDensityFunction) trigger a rule re-derivation?

sunxd3 avatar Aug 07 '24 16:08 sunxd3

Yup, that'll do it.

willtebbutt avatar Aug 07 '24 16:08 willtebbutt

I am not 100% certain about this, but for the old Gibbs, maybe https://github.com/TuringLang/Turing.jl/blob/803d2f5672b7483c768b9987bcec0dc20257ebda/src/mcmc/gibbs.jl#L241C27-L241C87 will lead to a call to ADgradient? @torfjelde

sunxd3 avatar Aug 07 '24 16:08 sunxd3

Note: if this is an issue for Tapir.jl, is it also potentially a problem for compiled ReverseDiff.jl?

willtebbutt avatar Aug 08 '24 08:08 willtebbutt

It is a problem if setmodel is called every Gibbs step. This is conservative for correctness's sake, but it would be better to be able to reuse the derived rule (for Tapir) and tape (for ReverseDiff).

sunxd3 avatar Aug 08 '24 09:08 sunxd3

v0.2.31 of Tapir.jl should fix the problem around Core.Intrinsics.atomic_pointerset.

willtebbutt avatar Aug 08 '24 11:08 willtebbutt

This is conservative for correctness's sake, but it would be better to be able to reuse the derived rule (for Tapir) and tape (for ReverseDiff).

@sunxd3 we don't need to re-derive the rule for Tapir since it supports dynamic control flows; can you help remove that? For ReverseDiff, it is indeed safer to avoid using cached tape.
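To illustrate why cached ReverseDiff tapes are risky, here is a self-contained sketch (not Turing code): a compiled tape freezes whichever control-flow path was taken at recording time.

```julia
using ReverseDiff

# Control flow depends on the input value.
f(x) = x[1] > 0 ? sum(abs2, x) : sum(x)

# The tape records the `x[1] > 0` branch taken for this input.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, [1.0, 2.0]))

# Replaying it on an input that should take the other branch silently
# returns the gradient of the recorded branch (2x), not of sum(x).
g = zeros(2)
ReverseDiff.gradient!(g, tape, [-1.0, 2.0])  # g == [-2.0, 4.0], not [1.0, 1.0]
```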

yebai avatar Aug 08 '24 15:08 yebai

Yeah, Will and I just talked about this; I am going to address it. Although it might be annoying because we can't dispatch on individual ADGradientWrapper (ref https://github.com/tpapp/LogDensityProblemsAD.jl/pull/33). But maybe there is a way around it.

sunxd3 avatar Aug 08 '24 16:08 sunxd3

I am not 100% certain about this, but for the old Gibbs, maybe https://github.com/TuringLang/Turing.jl/blob/803d2f5672b7483c768b9987bcec0dc20257ebda/src/mcmc/gibbs.jl#L241C27-L241C87 will lead to call to ADgradient?

I thought that the old Gibbs incorrectly doesn't call setmodel, which leads to incorrect gradient tapes for compiled backends, e.g. ReverseDiff with a compiled tape.

E.g.

https://github.com/TuringLang/Turing.jl/blob/803d2f5672b7483c768b9987bcec0dc20257ebda/src/mcmc/hmc.jl#L247-L261

re-uses the hamiltonian from the state, which means that it's re-using the initially constructed log-density, unless something occurs in gibbs_state here

https://github.com/TuringLang/Turing.jl/blob/803d2f5672b7483c768b9987bcec0dc20257ebda/src/mcmc/gibbs.jl#L235-L236

Looking at gibbs_state:

https://github.com/TuringLang/Turing.jl/blob/803d2f5672b7483c768b9987bcec0dc20257ebda/src/mcmc/gibbs.jl#L111-L114

it doesn't seem to do anything, i.e. we're still using the gradient tape from the initial step.

This was fixed in Turing.Experimental.Gibbs, but I guess it has been a dormant issue in Turing.Inference.Gibbs for a long time now 😕

torfjelde avatar Aug 09 '24 11:08 torfjelde

It is a problem if setmodel is called every Gibbs step

I don't think it is? (and this leads to incorrect behavior for compiled ReverseDiff)

Although it might be annoying because we can't dispatch on individual ADGradientWrapper (ref https://github.com/tpapp/LogDensityProblemsAD.jl/pull/33). But maybe there is way around it.

Can we not use the same approach as we used for Turing.Experimental.Gibbs? Pass the ADType to setmodel (and start using setmodel in Turing.Inference.Gibbs too)?

torfjelde avatar Aug 09 '24 11:08 torfjelde

Can we not use the same approach as we used for Turing.Experimental.Gibbs? Pass the ADType to setmodel (and start using setmodel in Turing.Inference.Gibbs too)?

good thought, I think we can

sunxd3 avatar Aug 09 '24 12:08 sunxd3

@willtebbutt Apart from the zero_tangent circular-reference issue, there is only one remaining failure. Is this a known issue?

Got exception outside of a @test
  UndefVarError: `P` not defined
  Stacktrace:
    [1] fcodual_type(::Type{Type{AbstractArray{var"#s128", 1}}})
      @ Tapir ~/.julia/packages/Tapir/J4f0f/src/codual.jl:89
    [2] make_ad_stmts!(stmt::Expr, line::Tapir.ID, info::Tapir.ADInfo)
      @ Tapir ~/.julia/packages/Tapir/J4f0f/src/interpreter/s2s_reverse_mode_ad.jl:526

yebai avatar Aug 09 '24 15:08 yebai

I'm aware that this can happen (I've seen this kind of thing sporadically over the last year), but I don't properly understand why it happens. Re-running CI to get more info on the problem from the latest version of Tapir.jl (I've improved the stack traces). Will hunt it down today.

willtebbutt avatar Aug 12 '24 10:08 willtebbutt

From a quick look, I think the Tapir issue might be the same as e.g.

julia> f(::Type{Type{T}}) where {T} = Type{T}
f (generic function with 1 method)

julia> P = Type{Tuple{T}} where T
Type{Tuple{T}} where T

julia> f(P)
ERROR: UndefVarError: `T` not defined
Stacktrace:
 [1] f(::Type{Type{Tuple{T}} where T})
   @ Main ./REPL[1]:1
 [2] top-level scope
   @ REPL[8]:1

devmotion avatar Aug 12 '24 12:08 devmotion

Ahhh, yes, I believe that's exactly what's going on. I can build new unit tests for Tapir.jl from this, and extend whatever functions aren't working properly. Thanks for the pointer, @devmotion.

willtebbutt avatar Aug 12 '24 12:08 willtebbutt

A version of Tapir.jl which addresses this particular issue should be available in an hour or so.

willtebbutt avatar Aug 12 '24 12:08 willtebbutt