Turing.jl
Turing.jl copied to clipboard
HMM example model with missing data yields Malformed dims error
If we take the HMM example from the tutorials section of this project's documentation, everything works as expected. However, if we set one of the values in the generated `y` data vector to `missing`, we get a `Malformed dims` error when a `TArray` is created.
using Turing, MCMCChains
using Distributions
using StatsPlots
using Random
# Observed sequence; the `missing` entry marks a time step with no observation.
y = [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, missing, 3.0, 3.0, 2.0, 2.0, 2.0, 1.0, 1.0];
# Sequence length and number of hidden states.
N, K = length(y), 3;
# Turing model definition: a hidden Markov model with `K` hidden states.
#   y: sequence of observations, one entry per time step
#   K: number of hidden states
# Every state `i` gets a transition probability vector `T[i]` and an emission
# mean `m[i]`; every observation `y[i]` is modelled as Normal(m[s[i]], 0.1),
# where `s[i]` is the hidden state at step `i`.
# NOTE: according to the accompanying report, a `missing` entry in `y` makes
# sampling from this model fail with a "Malformed dims" error.
@model BayesHmm(y, K) = begin
    # Get observation length.
    N = length(y)
    # State sequence: one integer state index per observation, allocated with
    # `tzeros` (the report attributes the error to the creation of a `TArray`).
    s = tzeros(Int, N)
    # Emission means, one per state (untyped container, filled by `~` below).
    m = Vector(undef, K)
    # Transition probabilities, stored as one vector per source state.
    T = Vector{Vector}(undef, K)
    # Assign a prior to each row of the transition structure and to each
    # emission mean: `T[i]` is Dirichlet with all parameters equal to 1/K,
    # `m[i]` is Normal centred on the state index `i`.
    for i = 1:K
        T[i] ~ Dirichlet(ones(K)/K)
        m[i] ~ Normal(i, 0.5)
    end
    # First time step: the initial state is drawn from `Categorical(K)`, and
    # the first observation is emitted around that state's mean.
    s[1] ~ Categorical(K)
    y[1] ~ Normal(m[s[1]], 0.1)
    # Remaining time steps: the next state depends on the previous state via
    # `T[s[i-1]]`, and the observation on the current state's mean.
    for i = 2:N
        s[i] ~ Categorical(vec(T[s[i-1]]))
        y[i] ~ Normal(m[s[i]], 0.1)
    end
end;
# Gibbs sampler: HMC updates the continuous parameters `m` and `T`, while
# PG updates the discrete state sequence `s`.
g = Gibbs(
    HMC(0.001, 7, :m, :T),
    PG(20, :s),
)
# Draw 100 samples from the model conditioned on `y`, with 3 hidden states.
c = sample(BayesHmm(y, 3), g, 100);
Unfortunately, I don't think this is easy to fix at the moment.
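Here is a slightly hacky workaround: pass only the observed values (together with their time indices) to the model, and declare the missing values as a separate vector inside the model so that they are sampled like any other parameter: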
# Hidden Markov model that tolerates missing observations by keeping observed
# and unobserved values in two separate vectors.
#
# Arguments
#   yobs   : the observed values, in time order
#   obsidx : the time indices at which the values in `yobs` were observed
#   K      : number of hidden states
#   N      : total number of time steps (observed plus missing)
#   Ty     : element type of the buffer holding the missing values
#            (defaults to Float64)
@model function hmm(yobs, obsidx, K, N, ::Type{Ty} = Float64) where {Ty}
    # One slot for every time step that has no observation.
    n_missing = N - length(obsidx)
    ymis = Vector{Ty}(undef, n_missing)

    # Hidden state index for every time step.
    s = tzeros(Int, N)

    # Emission means: one Normal prior per state, centred on the state index.
    m ~ arraydist([Normal(k, 0.5) for k in 1:K])

    # Transition probabilities: column `j` holds the distribution over the
    # next state given that the previous state was `j`.
    T ~ filldist(Dirichlet(K, 1/K), K)

    # Cursors pointing at the next unused entry of `yobs` and of `ymis`.
    next_obs = 1
    next_mis = 1

    # First time step: draw the initial state, then either condition on the
    # observed value or draw a value for the missing one.
    s[1] ~ Categorical(K)
    emission = Normal(m[s[1]], 0.1)
    if in(1, obsidx)
        yobs[next_obs] ~ emission
        next_obs = next_obs + 1
    else
        ymis[next_mis] ~ emission
        next_mis = next_mis + 1
    end

    # Remaining time steps: the state depends on the previous state through
    # the matching column of `T`; the emission is handled as for step 1.
    for t in 2:N
        s[t] ~ Categorical(T[:, s[t - 1]])
        emission = Normal(m[s[t]], 0.1)
        if in(t, obsidx)
            yobs[next_obs] ~ emission
            next_obs = next_obs + 1
        else
            ymis[next_mis] ~ emission
            next_mis = next_mis + 1
        end
    end
end
# Observed sequence; the `missing` entry marks a time step with no observation.
y = [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, missing, 3.0, 3.0, 2.0, 2.0, 2.0, 1.0, 1.0];
# Sequence length and number of hidden states.
N = length(y);
K = 3;
# Time indices that carry an observation, and the values found there.
obsidx = findall(x -> !ismissing(x), y)
yobs = getindex(y, obsidx)
# Instantiate the model on the observed values and their time indices.
m = hmm(yobs, obsidx, K, N)
# Gibbs sampler: HMC updates `m`, `T` and the missing values `ymis`, while
# PG updates the discrete state sequence `s`.
g = Gibbs(
    HMC(0.01, 5, :m, :T, :ymis),
    PG(10, :s),
)
# Draw 100 samples.
c = sample(m, g, 100);