ChainRulesTestUtils.jl

inference check failure in test_rrule, @inferred rrule works

Open tpapp opened this issue 3 years ago • 7 comments

I wrapped up an MWE in ImplicitAD.jl for an rrule I defined (see the single test). My issue is that

test_rrule(one_one, 1.0;
           check_inferred = true,
           fdm = forward_fdm(5, 1),
           atol = ϵ, rtol = ϵ)

gives an inference failure, while

@inferred rrule(ChainRulesTestUtils.TestConfig(), one_one, 1.0)

is fine.

tpapp avatar May 03 '22 13:05 tpapp

test_rrule also checks that the pullback is inferable; perhaps that's what's giving the error?

mzgubic avatar May 03 '22 13:05 mzgubic

Thanks. Indeed it does, but it seems to be coming from the rrule_via_ad call:

julia> g, pbg = @inferred rrule_via_ad(ChainRulesTestUtils.TestConfig(), one_one_core, 1.0, 2.0)
(162761.18047510285, ChainRulesTestUtils.var"#f_pb#43"{ChainRulesTestUtils.TestConfig, Tuple{Bool, Bool, Bool}, Tuple{typeof(one_one_core), Float64, Float64}, ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}(ChainRulesTestUtils.TestConfig(FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}(SVector{5,Int64}(-2, -1, 0, 1, 2), SVector{5,Float64}(0.08333333333333333, -0.6666666666666666, 0.0, 0.6666666666666666, -0.08333333333333333), (SVector{5,Float64}(-0.08333333333333333, 0.5, -1.5, 0.8333333333333334, 0.25), SVector{5,Float64}(0.08333333333333333, -0.6666666666666666, 0.0, 0.6666666666666666, -0.08333333333333333), SVector{5,Float64}(-0.25, -0.8333333333333334, 1.5, -0.5, 0.08333333333333333)), 10.0, 1.0, Inf, 0.05555555555555555, 1.4999999999999998, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}(SVector{7,Int64}(-3, -2, -1, 0, 1, 2, 3), SVector{7,Float64}(-0.5, 2.0, -2.5, 0.0, 2.5, -2.0, 0.5), (SVector{7,Float64}(0.5, -4.0, 12.5, -20.0, 17.5, -8.0, 1.5), SVector{7,Float64}(-0.5, 2.0, -2.5, 0.0, 2.5, -2.0, 0.5), SVector{7,Float64}(-1.5, 8.0, -17.5, 20.0, -12.5, 4.0, -0.5)), 10.0, 1.0, Inf, 0.5365079365079365, 10.0))), (true, false, false), (ImplicitAD.one_one_core, 1.0, 2.0), ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}(Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}())))

julia> @code_warntype pbg(1.0)
MethodInstance for (::ChainRulesTestUtils.var"#f_pb#43"{ChainRulesTestUtils.TestConfig, Tuple{Bool, Bool, Bool}, Tuple{typeof(one_one_core), Float64, Float64}, ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}})(::Float64)
  from (::ChainRulesTestUtils.var"#f_pb#43")(ȳ) in ChainRulesTestUtils at /home/tamas/.julia/packages/ChainRulesTestUtils/vWKSm/src/rule_config.jl:39
Arguments
  #self#::ChainRulesTestUtils.var"#f_pb#43"{ChainRulesTestUtils.TestConfig, Tuple{Bool, Bool, Bool}, Tuple{typeof(one_one_core), Float64, Float64}, ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}
  ȳ::Float64
Body::Tuple
1 ─ %1 = Core.getfield(#self#, :config)::ChainRulesTestUtils.TestConfig
│   %2 = Base.getproperty(%1, :fdm)::Any
│   %3 = Core.getfield(#self#, :call)::Core.Const(ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}(Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}()))
│   %4 = Core.getfield(#self#, :primals)::Tuple{typeof(one_one_core), Float64, Float64}
│   %5 = Core.getfield(#self#, :is_ignored)::Tuple{Bool, Bool, Bool}
│   %6 = ChainRulesTestUtils._make_j′vp_call(%2, %3, ȳ, %4, %5)::Tuple
└──      return %6

Given the MWE, I wonder if anyone could please dig into this. I am not familiar with the internals of this package.

tpapp avatar May 03 '22 13:05 tpapp

Wait, I think that just typing the fdm field of TestConfig would do the trick. Testing, then making a PR.

tpapp avatar May 03 '22 13:05 tpapp

Nope, that fixes the Any for %2 above, but _make_j′vp_call is still not inferred, which is not surprising given the Any[...] inside that function.
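
To illustrate why an Any[...] container tends to defeat inference, here is a minimal sketch using only the Test standard library. The helper names via_any_vector and via_tuple are hypothetical and do not reproduce the internals of _make_j′vp_call; they only show the general effect of routing values through a Vector{Any}.

```julia
using Test

# Hypothetical helpers for illustration; not the package's code.

# Collecting into a Vector{Any} erases the element types, so the tuple
# built from it is only inferred as an abstract `Tuple`.
via_any_vector(a, b) = Tuple(Any[a, 2b])

# Building the tuple directly keeps the element types visible to inference.
via_tuple(a, b) = (a, 2b)

# The direct version passes the inference check.
@inferred via_tuple(1.0, 2.0)

# The Vector{Any} version fails it: @inferred raises an error because the
# concrete return type does not match the inferred abstract type.
@test_throws ErrorException @inferred via_any_vector(1.0, 2.0)
```

Both helpers return the same values; only the type information available to the compiler differs.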

tpapp avatar May 03 '22 13:05 tpapp

Thanks for checking. I'll have time next week to dig into this. For now, maybe just set check_inferred=false?

mzgubic avatar May 03 '22 13:05 mzgubic

Thanks, I appreciate it.

tpapp avatar May 03 '22 14:05 tpapp

My current understanding of this issue is:

The inference failure comes from the pullback, in particular from the rrule_via_ad call and, more precisely, from _make_j′vp_call, as you correctly point out. We can't easily make that call inferable, because its output (the tangents to xs) depends on whether we are ignoring a particular x, and that information is passed in as boolean values (the is_ignored tuple) rather than in the types.
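
A small sketch of that effect, using only the Test standard library: when the "ignore this argument" flags are Bool values, each output slot can be one of two types, so the result is not concretely inferred. Encoding the flags as Val types is one general way to make the same computation inferable. The names tangents_by_value, tangents_by_type and slot are hypothetical, nothing stands in for a "no tangent" marker, and 2x stands in for a computed tangent; this is not how the package is implemented.

```julia
using Test

# Hypothetical illustration; `nothing` plays the role of "no tangent".

# Flags as Bool values: each slot is `nothing` or a Float64 depending on a
# runtime value, so inference only sees a Union for every slot.
tangents_by_value(xs::Tuple, ignored::Tuple) =
    map((x, ig) -> ig ? nothing : 2x, xs, ignored)

# Flags as Val types: dispatch selects the branch from the type alone.
slot(x, ::Val{true}) = nothing
slot(x, ::Val{false}) = 2x
tangents_by_type(xs::Tuple, ignored::Tuple) = map(slot, xs, ignored)

# The type-level version passes the inference check.
@inferred tangents_by_type((1.0, 2.0), (Val(true), Val(false)))

# The value-level version fails it.
@test_throws ErrorException @inferred tangents_by_value((1.0, 2.0), (true, false))
```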

I see two ways around the issue:

  1. We can define an rrule for the one_one_core function, which rrule_via_ad will then call with the current TestConfig.
  2. We can pass in a different RuleConfig which, unlike the current one (TestConfig), will not call the finite difference methods under the hood to work out the rrule.

I imagine this package wants to be AD-independent? In that case we could still use a particular AD system as a test dependency and use its RuleConfig if we want to test the inference.

mzgubic avatar May 10 '22 11:05 mzgubic