Turing.jl
How to save a model / fit and load it? Issue with JLD2 for "reconstructing"
(this issue is somewhat related to #2308)
I'm trying to save models to disk and then load and use them in a new session.
Here's an MWE, starting with making and saving a model:
using Turing
using JLD2
@model function mymodel(y)
    μ ~ Normal(0, 2)
    σ ~ truncated(Normal(0, 3), 0.0, Inf)
    for i in 1:length(y)
        y[i] ~ Normal(μ, σ)
    end
end
fit = mymodel([1, 2, 3, 4, 5])
jldsave("model.jld2"; model=mymodel, fit=fit)
Now, in a new session, if I do the following it errors:
using Turing
using JLD2
loaded = jldopen("model.jld2", "r+")
loaded["model"]
┌ Warning: type Main.#mymodel does not exist in workspace; reconstructing
└ @ JLD2 C:\Users\domma\.julia\packages\JLD2\twZ5D\src\data\reconstructing_datatypes.jl:492
loaded["model"]([1, 2, 3, 4, 5])
ERROR: MethodError: objects of type JLD2.ReconstructedSingleton{Symbol("#mymodel")} are not callable
Stacktrace:
[1] top-level scope
@ c:\Users\domma\Dropbox\RECHERCHE\Studies\DoggoNogo\study1\analysis\1_models_make.jl:154
How to correctly save/load Turing models?
@devmotion IIRC, we can't serialise Turing models due to a DynamicPPL limitation. Is that still the case, and if so, is that fixable?
Models can be serialized, we even have a test for it: https://github.com/TuringLang/DynamicPPL.jl/blob/138bd40acdfc47d7b00e25a2adaf9fec986f9646/test/serialization.jl The serialization issues in DynamicPPL should have been fixed by https://github.com/TuringLang/DynamicPPL.jl/pull/134. I haven't checked the MWE above but I wonder if it's rather a JLD2 than a Turing/DynamicPPL issue.
Thanks for looking into this. Is there a more robust alternative to JLD2 for saving and loading models? I initially picked JLD2 for saving the chains (note that it works for that), following this thread.
kind bump in case someone has any good suggestions on how to save & load models
Serialization should work (and is tested). Can you check whether Serialization.serialize and Serialization.deserialize work for you?
Unfortunately it doesn't seem like it works:
using Turing
using Serialization
@model function mymodel(y)
    μ ~ Normal(0, 2)
    σ ~ truncated(Normal(0, 3), 0.0, Inf)
    for i in 1:length(y)
        y[i] ~ Normal(μ, σ)
    end
end
fit = mymodel([1, 2, 3, 4, 5])
Serialization.serialize("model.turing", Dict("model" => mymodel, "fit" => fit))
Restart session:
using Turing
using Serialization
loaded = Serialization.deserialize("model.turing")
loaded["model"]([1, 2, 3, 4, 5])
ERROR: UndefVarError: `#mymodel` not defined
Stacktrace:
[1] deserialize_datatype(s::Serializer{IOStream}, full::Bool)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1399
[2] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:867
[3] deserialize(s::Serializer{IOStream})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
[4] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:874
[5] deserialize
@ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
[6] deserialize_dict(s::Serializer{IOStream}, T::Type{Dict{String, Any}})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1529
[7] deserialize(s::Serializer{IOStream}, T::Type{Dict{String, Any}})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1536
[8] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:883
[9] deserialize(s::Serializer{IOStream})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
[10] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:920
[11] deserialize
@ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
[12] deserialize(s::IOStream)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:801
[13] open(f::typeof(deserialize), args::String; kwargs::@Kwargs{})
@ Base .\io.jl:396
[14] open
@ .\io.jl:393 [inlined]
[15] deserialize(filename::String)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:811
[16] top-level scope
@ c:\Users\domma\Dropbox\RECHERCHE\Studies\DoggoNogo\study1\analysis\activate.jl:29
mymodel is a regular Julia function, so it suffers from the same limitations regarding (de)serializations as any other Julia function, whereas fit is an object of type DynamicPPL.Model and behaves differently.
You can (de)serialize mymodel e.g. in the following way:
# Session 1: define the model and serialize its methods
using Turing
using Serialization
@model function mymodel(y)
    μ ~ Normal(0, 2)
    σ ~ truncated(Normal(0, 3), 0.0, Inf)
    for i in 1:length(y)
        y[i] ~ Normal(μ, σ)
    end
end
Serialization.serialize("model.turing", methods(mymodel))

# Session 2 (new session): declare the function, then deserialize the saved methods
using Turing
using Serialization
function mymodel end # this is required
Serialization.deserialize("model.turing")
mymodel([1, 2, 3, 4])
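To illustrate the point about regular Julia functions: as far as I understand, Serialization stores a named function as a reference to its name rather than as its code, so the name has to exist wherever the file is read. The sketch below uses only the standard library. plain_model is a hypothetical stand-in for a function defined with @model, and a fresh Julia process stands in for the new session.

```julia
using Serialization

# Hypothetical stand-in for a function defined with `@model`.
plain_model(y) = sum(y)

path = joinpath(mktempdir(), "model.bin")
serialize(path, plain_model)

# Same session: the name resolves, so the same function object comes back.
same_session = deserialize(path)

# Fresh process (stand-in for a new session) in which `plain_model` is not defined.
code = """
using Serialization
try
    deserialize($(repr(path)))
    print("ok")
catch err
    print(typeof(err))
end
"""
new_session = read(`$(Base.julia_cmd()) --startup-file=no -e $code`, String)

# On Julia 1.10 this is expected to report an UndefVarError,
# the same kind of error as in the report above.
println(new_session)
```

This is why the empty declaration (function mymodel end) is needed before deserializing in the new session.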
Unfortunately, that didn't do the trick either. A new error came up that I couldn't make sense of, even after googling what AccessorsImpl is:
julia> loaded = Serialization.deserialize("model.turing")
ERROR: UndefVarError: `AccessorsImpl` not defined
Stacktrace:
[1] deserialize_module(s::Serializer{IOStream})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:997
[2] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:896
[3] deserialize(s::Serializer{IOStream})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
[4] deserialize_datatype(s::Serializer{IOStream}, full::Bool)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1398
[5] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:867
[6] deserialize(s::Serializer{IOStream})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
[7] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:874
[8] deserialize(s::Serializer{IOStream})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
[9] deserialize_expr(s::Serializer{IOStream}, len::Int64)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1291
[10] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:894
[11] deserialize_fillarray!(A::Vector{Any}, s::Serializer{IOStream})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1281
[12] deserialize_array(s::Serializer{IOStream})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1273
[13] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:865
[14] deserialize
@ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
[15] deserialize(s::Serializer{IOStream}, ::Type{Core.CodeInfo})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1133
[16] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:960
[17] deserialize
@ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
[18] deserialize(s::Serializer{IOStream}, ::Type{Method})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1044
[19] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:960
[20] deserialize_fillarray!(A::Vector{Method}, s::Serializer{IOStream})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1281
[21] deserialize_array(s::Serializer{IOStream})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1273
[22] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:865
[23] deserialize(s::Serializer{IOStream}, t::DataType)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1501
[24] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:883
[25] deserialize
@ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
[26] deserialize_dict(s::Serializer{IOStream}, T::Type{Dict{String, Any}})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1529
[27] deserialize(s::Serializer{IOStream}, T::Type{Dict{String, Any}})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1536
[28] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:883
[29] deserialize(s::Serializer{IOStream})
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
[30] handle_deserialize(s::Serializer{IOStream}, b::Int32)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:920
[31] deserialize
@ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
[32] deserialize(s::IOStream)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:801
[33] open(f::typeof(deserialize), args::String; kwargs::@Kwargs{})
@ Base .\io.jl:396
[34] open
@ .\io.jl:393 [inlined]
[35] deserialize(filename::String)
@ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:811
[36] top-level scope
@ c:\Users\domma\Dropbox\RECHERCHE\Studies\DoggoNogo\study1\analysis\activate.jl:31
I really appreciate your help though.
The main reason I'm saving the model itself is to be able to refit it later on new data (which surely must be a common use case! in R it is common to save and share and download and re-use big fitted models)
As doing it this way - assuming it is even possible - is clunky and unwieldy, the alternatives I see are:
- Having an update() method (discussed in #2308)
- Is it possible to extract the model object/method from the fitted object? In other words, as far as I understand, a Turing model is often defined as a function (which is hard to serialize), which gets turned into a DynamicPPL object through the @model macro. Can we recover/reconstruct that object from the fitted version?
Thanks again @devmotion
The example above works fine for me, I don't get this error. Did you try a more complicated example?
My setup:
(jl_L5QqA0) pkg> st
Status `/private/var/folders/n6/98_7bm0j0hb57zv3l3tj8sxh0000gn/T/jl_L5QqA0/Project.toml`
⌃ [fce5fe82] Turing v0.33.3
Info Packages marked with ⌃ have new versions available and may be upgradable.
julia> versioninfo()
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: macOS (arm64-apple-darwin22.4.0)
CPU: 10 × Apple M2 Pro
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 6 virtual cores)
Environment:
JULIA_PKG_USE_CLI_GIT = true
JULIA_PKG_PRESERVE_TIERED_INSTALLED = true
It does work, the culprit was that I was reloading it using a different Turing version 🤦
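Since the failure came from a version mismatch, one way to catch it early is to write version metadata ahead of the payload and check it before deserializing the payload. This is only a sketch: save_versioned and load_versioned are hypothetical helpers, not part of Turing or Serialization, and the Turing version is passed in as a plain string (in a real script, string(pkgversion(Turing)) on Julia 1.9+ could supply it).

```julia
using Serialization

# Hypothetical helper: write version metadata first, then the payload.
function save_versioned(path::AbstractString, payload; extra = Dict{String,String}())
    meta = merge(Dict("julia" => string(VERSION)), extra)
    open(path, "w") do io
        serialize(io, meta)     # only Base types, so readable in any session
        serialize(io, payload)  # may depend on package versions
    end
    return path
end

# Hypothetical helper: check the metadata before touching the payload.
function load_versioned(path::AbstractString; expected = Dict("julia" => string(VERSION)))
    open(path, "r") do io
        meta = deserialize(io)
        for (name, version) in expected
            saved = get(meta, name, "unknown")
            saved == version || @warn "Version mismatch for $name" saved current = version
        end
        return deserialize(io)
    end
end

path = joinpath(mktempdir(), "payload.bin")
save_versioned(path, [1, 2, 3, 4, 5]; extra = Dict("Turing" => "0.33.3"))
data = load_versioned(path)
```

Because the metadata is read first, the warning appears before the payload is deserialized, which is the step that failed with a different Turing version.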
Great, cheers, I'll close this!
(but I would still suggest that making that process more convenient would be a nice feature ☺️)
One last shot, I know I'm asking for a lot, but the above solution is not very convenient for programmatic usage. In my case, I define and save a lot of models, and then I load them and use them via a loop. The code is made to work on an arbitrary number of models with arbitrary names:
Hence I would like to call the model directly from the loaded dict, i.e. loaded["model"]([1, 2, 3, 4, 5]), without using the original function name, rather than calling the re-defined function mymodel([1, 2, 3, 4]), because that would require bespoke code for every model.
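One possible way to get a loop-friendly workflow, sketched under the assumption that the model definitions live in a source file that both the saving and the loading script include: keep a registry Dict from names to model functions and store only the name (plus data or chains) on disk. MODELS and this layout are hypothetical, not a Turing feature.

```julia
using Turing

# Contents of a hypothetical shared file (e.g. models.jl),
# `include`d by both the saving and the loading script.
@model function mymodel(y)
    μ ~ Normal(0, 2)
    σ ~ truncated(Normal(0, 3), 0.0, Inf)
    for i in 1:length(y)
        y[i] ~ Normal(μ, σ)
    end
end

# Hypothetical registry: stored name => model function.
const MODELS = Dict{String,Function}("mymodel" => mymodel)

# Loop over stored names without writing per-model code.
for name in ["mymodel"]
    model = MODELS[name]([1, 2, 3, 4, 5])  # condition on (new) data
    chain = sample(model, NUTS(), 100; progress = false)
end
```

The trade-off is that the model source has to travel with the saved files, but nothing function-like needs to be serialized.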
@penelopeysm can you help add the trick to docs / FAQ?
AccessorsImpl is defined in BangBang.jl. DynamicPPL's @model macro uses it to signal reference of mutation. Ref: https://github.com/TuringLang/DynamicPPL.jl/blob/c9410de91cfdeffeb939022fabdf042c72c71690/src/compiler.jl#L464 and https://github.com/JuliaFolds2/BangBang.jl/blob/7f61170ec6e4b883f5ece892225d61b9e7b04f8e/src/accessors.jl#L1.
the above solution is not very convenient for programmatic usage
The problem boils down to being able to fit a model on data without having to (re)define the original function name, to allow for workflows such as:
(pseudocode)
# 1. Define models
@model function m1(y, x)
    ...
end
fit1 = m1(data)
@model function m2(y, x)
    ...
end
fit2 = m2(data)
# 2. Save models (save is a hypothetical function)
save(fit1, "fit1")
save(fit2, "fit2")
# 3. In a new script (load is a hypothetical function)
for m in ["fit1", "fit2"]
    fit = load(m)
    fit(newdata)
    predict(fit, ...)
end
Do you think it might be solved by implementing an update() method (#2308)?
Context: an example use case where this flexibility is IMO a critical feature is when models are fit / sampled on external machines (high-performance clusters): the output is ideally saved and then downloaded by researchers, who can then manipulate these models independently (for reporting, postprocessing, predictions, analysis, etc.).
Out of curiosity, might the update of predict() in https://github.com/TuringLang/DynamicPPL.jl/pull/651 help with the issue above? (i.e., being able to run predict() on a model stored inside, e.g., a dictionary, or loaded from a file, so that it doesn't require the function definition of the model to be available)
I don't think it would help, but can definitely try later
closing for now, will try to write docs on it soon :)
being able to run predict() on a model stored inside, e.g., a dictionary, or loaded from a file, so that it doesn't require the function definition of the model to be available
This will not help 😕