inference check failure in test_rrule, @inferred rrule works
I wrapped up an MWE in ImplicitAD.jl for an rrule I defined (see the single test). My issue is that
test_rrule(one_one, 1.0;
check_inferred = true,
fdm = forward_fdm(5, 1),
atol = ϵ, rtol = ϵ)
gives an inference failure, while
@inferred rrule(ChainRulesTestUtils.TestConfig(), one_one, 1.0)
is fine.
test_rrule also checks that the pullback is inferable; perhaps that's what's giving the error?
Thanks. Indeed it does, but it seems to be coming from the rrule_via_ad call:
julia> g, pbg = @inferred rrule_via_ad(ChainRulesTestUtils.TestConfig(), one_one_core, 1.0, 2.0)
(162761.18047510285, ChainRulesTestUtils.var"#f_pb#43"{ChainRulesTestUtils.TestConfig, Tuple{Bool, Bool, Bool}, Tuple{typeof(one_one_core), Float64, Float64}, ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}(ChainRulesTestUtils.TestConfig(FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}(SVector{5,Int64}(-2, -1, 0, 1, 2), SVector{5,Float64}(0.08333333333333333, -0.6666666666666666, 0.0, 0.6666666666666666, -0.08333333333333333), (SVector{5,Float64}(-0.08333333333333333, 0.5, -1.5, 0.8333333333333334, 0.25), SVector{5,Float64}(0.08333333333333333, -0.6666666666666666, 0.0, 0.6666666666666666, -0.08333333333333333), SVector{5,Float64}(-0.25, -0.8333333333333334, 1.5, -0.5, 0.08333333333333333)), 10.0, 1.0, Inf, 0.05555555555555555, 1.4999999999999998, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}(SVector{7,Int64}(-3, -2, -1, 0, 1, 2, 3), SVector{7,Float64}(-0.5, 2.0, -2.5, 0.0, 2.5, -2.0, 0.5), (SVector{7,Float64}(0.5, -4.0, 12.5, -20.0, 17.5, -8.0, 1.5), SVector{7,Float64}(-0.5, 2.0, -2.5, 0.0, 2.5, -2.0, 0.5), SVector{7,Float64}(-1.5, 8.0, -17.5, 20.0, -12.5, 4.0, -0.5)), 10.0, 1.0, Inf, 0.5365079365079365, 10.0))), (true, false, false), (ImplicitAD.one_one_core, 1.0, 2.0), ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}(Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}())))
julia> @code_warntype pbg(1.0)
MethodInstance for (::ChainRulesTestUtils.var"#f_pb#43"{ChainRulesTestUtils.TestConfig, Tuple{Bool, Bool, Bool}, Tuple{typeof(one_one_core), Float64, Float64}, ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}})(::Float64)
from (::ChainRulesTestUtils.var"#f_pb#43")(ȳ) in ChainRulesTestUtils at /home/tamas/.julia/packages/ChainRulesTestUtils/vWKSm/src/rule_config.jl:39
Arguments
#self#::ChainRulesTestUtils.var"#f_pb#43"{ChainRulesTestUtils.TestConfig, Tuple{Bool, Bool, Bool}, Tuple{typeof(one_one_core), Float64, Float64}, ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}
ȳ::Float64
Body::Tuple
1 ─ %1 = Core.getfield(#self#, :config)::ChainRulesTestUtils.TestConfig
│ %2 = Base.getproperty(%1, :fdm)::Any
│ %3 = Core.getfield(#self#, :call)::Core.Const(ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}(Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}()))
│ %4 = Core.getfield(#self#, :primals)::Tuple{typeof(one_one_core), Float64, Float64}
│ %5 = Core.getfield(#self#, :is_ignored)::Tuple{Bool, Bool, Bool}
│ %6 = ChainRulesTestUtils._make_j′vp_call(%2, %3, ȳ, %4, %5)::Tuple
└── return %6
Given the MWE, I wonder if anyone could please dig into this. I am not familiar with the internals of this package.
Wait, I think that just typing the fdm field of TestConfig would do the trick. Testing, then making a PR.
Nope, that fixes the Any for %2 above, but _make_j′vp_call is still not inferred, which is not surprising given the Any[...] inside that function.
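For anyone following along, here is a minimal sketch of the two effects mentioned above, using only Base. The types UntypedConfig and TypedConfig and the functions get_fdm and collect_any are hypothetical stand-ins made up for illustration, not the actual ChainRulesTestUtils code.

```julia
# Hypothetical stand-ins, not the real ChainRulesTestUtils.TestConfig.
struct UntypedConfig
    fdm              # no annotation, so the field type is Any
end

struct TypedConfig{F}
    fdm::F           # field type is known once F is fixed
end

get_fdm(c) = c.fdm

# Collecting values in an Any[] discards their element types,
# so anything computed from the elements is inferred as Any.
function collect_any(x::Float64)
    out = Any[x, 2x]
    return out[1] + out[2]
end

# Base.return_types reports what the compiler infers for a signature.
untyped_ret = only(Base.return_types(get_fdm, (UntypedConfig,)))
typed_ret   = only(Base.return_types(get_fdm, (TypedConfig{Float64},)))
any_ret     = only(Base.return_types(collect_any, (Float64,)))

println((untyped_ret, typed_ret, any_ret))
```

Typing the field repairs the first case only; the Any[...] container still hides the element types, which matches what @code_warntype shows for the pullback.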
Thanks for checking, I'll have time next week to dig into this. For now maybe just set check_inferred=false?
Thanks, I appreciate it.
My current understanding of this issue is:
The inference failure comes from the pullback, in particular from the rrule_via_ad call, and more precisely from _make_j′vp_call, as you correctly point out. We can't easily make that call infer because the output (the tangents to xs) depends on whether we are ignoring a particular x, and that information is passed in as Booleans (the is_ignored field in the output above), whose values are not visible to type inference.
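To illustrate the point about the Booleans, here is a small sketch using only Base. The names Skipped and tangent_or_skip are hypothetical and are not part of ChainRulesTestUtils; the Val variant is shown only to contrast runtime flags with type-level flags, not as a proposed fix.

```julia
# Hypothetical marker standing in for a "no tangent" result.
struct Skipped end

# The returned type depends on the runtime value of `ignored`,
# so inference can only report a Union of both possibilities.
tangent_or_skip(x, ignored::Bool) = ignored ? Skipped() : x

# With the flag encoded in the type (Val), each method has one return type.
tangent_or_skip(x, ::Val{true})  = Skipped()
tangent_or_skip(x, ::Val{false}) = x

runtime_ret = only(Base.return_types(tangent_or_skip, (Float64, Bool)))
static_ret  = only(Base.return_types(tangent_or_skip, (Float64, Val{false})))

println(isconcretetype(runtime_ret))  # runtime flag: not a concrete type
println(static_ret)                   # type-level flag: a single concrete type
```

With several arguments, each carrying its own runtime flag, the possible output types multiply, which is roughly why the pullback body ends up as a bare Tuple.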
I see two ways around the issue:
- We can define an rrule for the one_one_core function, which will be called by the rrule_via_ad with the current TestConfig.
- We can pass in a different RuleConfig, which unlike the current one (TestConfig) will not call the finite differences methods under the hood to work out the rrule.
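A rough sketch of the first option, assuming ChainRulesCore is available. Since the definition of one_one_core is not shown in this thread, toy_core below is a hypothetical stand-in with a made-up formula; only the shape of the rule is meant to carry over.

```julia
using ChainRulesCore
using Test

# Hypothetical stand-in; the real one_one_core lives in the ImplicitAD.jl MWE.
toy_core(x, p) = x^2 * p

# Hand-written reverse rule: returns the primal value and a pullback.
function ChainRulesCore.rrule(::typeof(toy_core), x, p)
    y = toy_core(x, p)
    # Tangents are ordered as (function, x, p); the function has no tangent.
    toy_core_pullback(ȳ) = (NoTangent(), ȳ * 2x * p, ȳ * x^2)
    return y, toy_core_pullback
end

# Both the rule and its pullback should now pass @inferred.
y, pb = @inferred rrule(toy_core, 1.5, 2.0)
tangents = @inferred pb(1.0)
```

Because the pullback returns a fixed tuple of concrete values, there is no runtime decision about which tangents to produce, so nothing depends on the finite-difference fallback.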
I imagine this package wants to be AD-independent? In that case we could still use a particular AD system as a test dependency and use its RuleConfig if we want to test the inference.