Issue with ReverseDiff and undef on Julia v1.7
Issue
Julia v1.7's implementation of copy! is not compatible with undef entries, which causes quite a few issues for us when we're using BangBang.jl. Ref: https://github.com/JuliaFolds2/BangBang.jl/issues/21, https://github.com/JuliaFolds2/BangBang.jl/pull/22, and their references.
After https://github.com/JuliaFolds2/BangBang.jl/pull/22, a lot of these issues were addressed, but it still doesn't quite cover the case where we're working with AD backends that use their own tracked types for tracing, e.g. ReverseDiff.jl.
In particular, something like
@model function demo(::Type{TV}=Vector{Float64}) where {TV}
    x = TV(undef, 1)
    x[1] ~ Normal()
end
will fail when used with ReverseDiff.jl because we end up in the NoBang version, and thus hit https://github.com/JuliaFolds2/BangBang.jl/issues/21 (even after https://github.com/JuliaFolds2/BangBang.jl/issues/22). The problem comes down to
eltype(x)::ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}
typeof(__varinfo__[@varname(x[1])])::ReverseDiff.TrackedReal{Float64, Float64, Nothing}
i.e. eltype(x) !== typeof(__varinfo__[@varname(x[1])]), and so the method instance at https://github.com/JuliaFolds2/BangBang.jl/blob/1e4455451378d150a0359e56d5e7ed75b74ddd6a/src/base.jl#L531-L539 is not hit when using ReverseDiff.jl.
In contrast, this is not an issue for, say, ForwardDiff.jl or non-AD evaluation, since there we always hit the mutating version, i.e. the check above passes.
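For context on why undef entries matter here, a minimal sketch using only Base Julia. `Boxed` is a hypothetical stand-in for an element type that is stored by reference (the assumption being that ReverseDiff's tracked reals behave this way, which the getindex failure in the stack trace suggests). Reading an unassigned entry throws, while writing to it is fine, so any code path that copies the old array before setting the new value will fail.

```julia
# `Boxed` is a hypothetical stand-in for a heap-allocated element type;
# it is not a ReverseDiff.jl type.
mutable struct Boxed
    value::Float64
end

xs = Vector{Boxed}(undef, 2)    # entries are unassigned references
ys = Vector{Float64}(undef, 2)  # entries hold arbitrary but readable bits

@show isassigned(xs, 1)
@show isassigned(ys, 1)

# Reading an unassigned entry throws an UndefRefError.
read_failed = try
    xs[1]
    false
catch err
    err isa UndefRefError
end
@show read_failed

# Writing to an unassigned entry works.
xs[1] = Boxed(1.0)
@show isassigned(xs, 1)
```

This is also why the Vector{Float64} case never runs into the problem: plain-bits element types are always considered assigned.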
Things to do
- [x] For now, I'll just disable this particular failing test (which occurs on https://github.com/TuringLang/Turing.jl/pull/2223), as it is holding up PRs.
- [ ] setindex!, i.e. the mutating version, is indeed valid in this case, so BangBang.jl is incorrectly reporting it as not possible. I'll raise a corresponding issue over there, but it's also somewhat unclear to me whether we'll ever be able to cover all the scenarios correctly, or whether we'll have to play this "catch up" game indefinitely. I.e. it may be worth considering implementing a slightly more stringent version of BangBang.AccessorsImpl.prefermutation which always uses setindex! in favour of BangBang.NoBang.setindex whenever we see an array.
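A rough sketch of what "always mutate when we see an array" could look like, in plain Base Julia. The name `set_prefer_mutation` is hypothetical and this is not how BangBang.jl or DynamicPPL.jl implement it; it only illustrates the dispatch idea of skipping the "is mutation possible?" check for arrays.

```julia
# Hypothetical helper, not part of BangBang.jl or DynamicPPL.jl.
# Arrays: always mutate in place and rely on `setindex!` to convert `v`
# to the element type of `xs`.
function set_prefer_mutation(xs::AbstractArray, v, I...)
    setindex!(xs, v, I...)
    return xs
end

# Immutable containers: fall back to a non-mutating update.
set_prefer_mutation(xs::Tuple, v, i::Integer) = Base.setindex(xs, v, i)

xs = Vector{Float64}(undef, 2)
ys = set_prefer_mutation(xs, 1, 1)        # the Int is converted to Float64
zs = set_prefer_mutation((1, 2), 5.0, 1)  # returns a new tuple
```

The trade-off is that a value which genuinely cannot be converted to the element type now raises an error from setindex! instead of silently widening the container.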
This was the debug script I was using:
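using DynamicPPL, ReverseDiff, LogDensityProblems, LogDensityProblemsAD, Distributions
# Also used below: BangBang (`possible`), LinearAlgebra (`Diagonal`), ADTypes (`AutoReverseDiff`).
using BangBang, LinearAlgebra, ADTypes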
s_global = nothing
s_global_i = nothing
@model function demo_assume_index_observe(
    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
) where {TV}
    # `assume` with indexing and `observe`
    s = TV(undef, length(x))
    global s_global
    s_global = s
    for i in eachindex(s)
        if haskey(__varinfo__, @varname(s[i]))
            global s_global_i
            s_global_i = __varinfo__[@varname(s[i])]
            @info "s[$i] varinfo" s_global_i
            @info "s[$i] check" BangBang.possible(BangBang._setindex!, s, s_global_i, 1)
        end
        @info "s[$i]" isassigned(s, i) typeof(s) eltype(s)
        s[i] ~ InverseGamma(2, 3)
    end
    m = TV(undef, length(x))
    for i in eachindex(m)
        @info "m[$i]" isassigned(m, i) eltype(m)
        m[i] ~ Normal(0, sqrt(s[i]))
    end
    x ~ MvNormal(m, Diagonal(s))
    return (; s=s, m=m, x=x, logp=DynamicPPL.getlogp(__varinfo__))
end
# (✓) WORKS!
model = demo_assume_index_observe()
f = ADgradient(AutoReverseDiff(), DynamicPPL.LogDensityFunction(model))
LogDensityProblems.logdensity_and_gradient(f, f.ℓ.varinfo[:])
# (×) BREAKS!
# NOTE: This requires the set up from `test/mcmc/abstractmcmc.jl`.
model = demo_assume_index_observe()
adtype = AutoReverseDiff()
sampler = initialize_nuts(model)
sampler_ext = DynamicPPL.Sampler(externalsampler(sampler; adtype, unconstrained=true), model)
sample(model, sampler_ext, 2; n_adapts=0, discard_initial=0)
### Error message:
#  [1] getindex
#    @ ./array.jl:861 [inlined]
#  [2] copy!(dst::Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, src::Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}})
#    @ Base ./abstractarray.jl:874
#  [3] _setindex
#    @ ~/.julia/packages/BangBang/2pcna/src/NoBang/base.jl:133 [inlined]
#  [4] may
#    @ ~/.julia/packages/BangBang/2pcna/src/core.jl:11 [inlined]
#  [5] setindex!!
#    @ ~/.julia/packages/BangBang/2pcna/src/base.jl:478 [inlined]
#  [6] set
#    @ ~/.julia/packages/BangBang/2pcna/src/accessors.jl:35 [inlined]
Note the very strange behavior: it works if I set up only the gradient computation by hand, but when I try to sample using the same adtype, we hit the issue! It seems to have something to do with how the different variables are created.
We dropped support for Julia 1.7.