ArviZ.jl
ArviZ.jl copied to clipboard
Error creating InferenceData from Turing Chains with extra info
trafficstars
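Currently, if the Turing `Chains` info contains certain objects (such as the model, sampler, and sampler state stored when sampling with `save_state=true`), we can't map them to the InferenceData info, and `from_mcmcchains` throws an error.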
using ArviZ, Turing
julia> @model function foo()
x ~ Normal()
end
foo (generic function with 1 method)
julia> chn = sample(foo(),NUTS(),200); # this is fine
julia> chn.info
NamedTuple()
julia> from_mcmcchains(chn)
InferenceData with groups:
> posterior
> sample_stats
julia> chn = sample(foo(),NUTS(),200;save_state=true) # this will error
julia> chn.info
(model = DynamicPPL.Model{var"#3#4", (), (), (), Tuple{}, Tuple{}}(:foo, var"#3#4"(), NamedTuple(), NamedTuple()), sampler = DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}(NUTS{Turing.Core.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}(-1, 0.65, 10, 1000.0, 0.0), DynamicPPL.Selector(0x00016a8da36513f2, :default, false)), samplerstate = Turing.Inference.HMCState{DynamicPPL.TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:x, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, AdvancedHMC.NUTS{AdvancedHMC.MultinomialTS, AdvancedHMC.GeneralisedNoUTurn, AdvancedHMC.Leapfrog{Float64}, Float64}, AdvancedHMC.Hamiltonian{AdvancedHMC.DiagEuclideanMetric{Float64, Vector{Float64}}, Turing.Inference.var"#logπ#54"{DynamicPPL.TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:x, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.Model{var"#3#4", (), (), (), Tuple{}, Tuple{}}}, Turing.Inference.var"#∂logπ∂θ#53"{DynamicPPL.TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:x, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.Model{var"#3#4", (), (), (), Tuple{}, Tuple{}}}}, AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}, AdvancedHMC.Adaptation.StanHMCAdaptor{AdvancedHMC.Adaptation.WelfordVar{Float64, Vector{Float64}}, 
AdvancedHMC.Adaptation.NesterovDualAveraging{Float64}}}(DynamicPPL.TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:x, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}((x = DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:x, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}(Dict(x => 1), [x], UnitRange{Int64}[1:1], [0.007224315165188178], Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0)], Set{DynamicPPL.Selector}[Set([DynamicPPL.Selector(0x00016a8da36513f2, :default, false)])], [0], Dict{String, BitVector}("del" => [0], "trans" => [1])),), Base.RefValue{Float64}(-0.9189646285694758), Base.RefValue{Int64}(0)), 299, NUTS{MultinomialTS,Generalised}(integrator=Leapfrog(ϵ=1.43), max_depth=10), Δ_max=1000.0), Hamiltonian(metric=DiagEuclideanMetric([1.0])), AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}([0.007224315165188178], [-0.5308342150394731], AdvancedHMC.DualValue{Float64, Vector{Float64}}(-0.9189646285694758, [0.007224315165188178]), AdvancedHMC.DualValue{Float64, Vector{Float64}}(-0.14089248192828682, [-0.5308342150394731])), StanHMCAdaptor(
pc=WelfordVar,
ssa=NesterovDualAveraging(γ=0.05, t_0=10.0, κ=0.75, δ=0.65, state.ϵ=1.425166901462951),
init_buffer=75, term_buffer=50, window_size=25,
state=window(76, 50), window_splits()
)))
julia> from_mcmcchains(chn)
ERROR: PyError ($(Expr(:escape, :(ccall(#= /Users/sethaxen/.julia/packages/PyCall/L0fLP/src/pyfncall.jl:43 =# @pysym(:PyObject_Call), PyPtr, (PyPtr, PyPtr, PyPtr), o, pyargsptr, kw))))) <class 'TypeError'>
TypeError("cannot pickle 'PyCall.jlwrap' object")
File "/Users/sethaxen/.julia/conda/3/lib/python3.8/site-packages/arviz/data/inference_data.py", line 1837, in concat
args_groups[group] = deepcopy(group_data) if copy else group_data
File "/Users/sethaxen/.julia/conda/3/lib/python3.8/copy.py", line 153, in deepcopy
y = copier(memo)
File "/Users/sethaxen/.julia/conda/3/lib/python3.8/site-packages/xarray/core/dataset.py", line 1425, in __deepcopy__
return self.copy(deep=True)
File "/Users/sethaxen/.julia/conda/3/lib/python3.8/site-packages/xarray/core/dataset.py", line 1322, in copy
attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs)
File "/Users/sethaxen/.julia/conda/3/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/Users/sethaxen/.julia/conda/3/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/Users/sethaxen/.julia/conda/3/lib/python3.8/copy.py", line 161, in deepcopy
rv = reductor(4)
We should probably filter the info on our end before InferenceData creation so that these errors can't happen.
On recent versions this no longer errors. Instead, the info appears to be silently excluded, which I don't think is what we want.