Turing.jl

HMM example model with missing data yields a `Malformed dims` error

Status: Open. finf281 opened this issue 5 years ago · 1 comment

If we take the HMM example from the tutorials section of this project's documentation, everything works as expected. However, if we set one of the values in the generated `y` data vector to `missing`, we get a `Malformed dims` error when a `TArray` is created.

using Turing, MCMCChains
using Distributions
using StatsPlots
using Random

# Observed sequence to be explained by the HMM. Entry 8 is `missing`, which is
# what triggers the reported error.
y = [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, missing,
     3.0, 3.0, 2.0, 2.0, 2.0, 1.0, 1.0];

# Sequence length and number of hidden states.
N, K = length(y), 3;

# Turing model definition.
# Hidden Markov model with K hidden states and Gaussian emissions.
#
# Arguments:
#   y -- observation vector; y[i] is modelled as Normal(m[s[i]], 0.1).
#        When an element of y is `missing`, sampling this model fails with a
#        `Malformed dims` error while a TArray is created (this is the bug
#        the issue reports).
#   K -- number of hidden states.
@model BayesHmm(y, K) = begin
    # Get observation length.
    N = length(y)

    # State sequence s[1..N], allocated with `tzeros` (a TArray).
    # NOTE(review): presumably a TArray is used so the particle sampler
    # `PG(..., :s)` can keep per-particle copies of s -- confirm in Turing docs.
    s = tzeros(Int, N)

    # Emission means: m[i] is the mean of the Normal emission for state i.
    # NOTE(review): left as an untyped Vector, likely so it can hold
    # AD-tracked values during HMC -- confirm.
    m = Vector(undef, K)

    # Transition matrix stored row-wise: T[i] is the probability vector of
    # the next state given current state i (see the use in the loop below).
    T = Vector{Vector}(undef, K)

    # Priors: each transition row gets a symmetric Dirichlet(1/K, ..., 1/K),
    # and each emission mean is centred on its own state index i.
    for i = 1:K
        T[i] ~ Dirichlet(ones(K)/K)
        m[i] ~ Normal(i, 0.5)
    end

    # Initial state is drawn uniformly from 1:K; its observation follows.
    s[1] ~ Categorical(K)
    y[1] ~ Normal(m[s[1]], 0.1)

    # Each later state is drawn from the transition row of the previous
    # state; each observation depends only on the current state.
    for i = 2:N
        s[i] ~ Categorical(vec(T[s[i-1]]))
        y[i] ~ Normal(m[s[i]], 0.1)
    end
end;

# Gibbs sampler: HMC (step size 0.001, 7 leapfrog steps) for the continuous
# parameters m and T, and particle Gibbs (20 particles) for the discrete
# state sequence s.
g = Gibbs(HMC(0.001, 7, :m, :T), PG(20, :s))
# Pass K rather than repeating the literal 3, so the number of hidden states
# is defined in one place.
c = sample(BayesHmm(y, K), g, 100);

— finf281, Jul 21 '20, 21:07

Unfortunately, this is not easy to fix at the moment (I think).

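Here is a slightly hacky workaround. Instead of passing `y` with `missing` entries, pass only the observed values (`yobs`) together with their time indices (`obsidx`). The model then samples the missing observations as an explicit latent vector `ymis`: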

# Hidden Markov model with K hidden states and Gaussian emissions. It supports
# missing observations by taking them out of the data vector; this is a
# workaround for the `Malformed dims` error raised when `y` contains `missing`.
#
# Arguments:
#   yobs   -- the observed (non-missing) values. yobs is consumed in time
#             order, so yobs[j] is matched to the j-th observed time step;
#             `obsidx` is therefore expected to be sorted ascending (as
#             `findall` returns it).
#   obsidx -- time indices in 1:N at which an observation is available.
#   K      -- number of hidden states.
#   N      -- total sequence length (observed plus missing time steps).
#   Ty     -- element type of the latent vector `ymis` (default Float64).
#             NOTE(review): presumably exposed so Turing can substitute an
#             AD-compatible element type -- confirm.
#
# Time steps not listed in `obsidx` are sampled as the latent vector `ymis`,
# which a sampler can target by name (e.g. HMC on `:ymis`).
@model function hmm(yobs, obsidx, K, N, ::Type{Ty} = Float64) where {Ty}
    # Latent values for the time steps that have no observation.
    ymis = Vector{Ty}(undef, N - length(obsidx))

    # State sequence s[1..N] (a TArray, as in the original example).
    s = tzeros(Int, N)

    # Emission means: m[i] is the mean of the Normal emission for state i.
    m ~ arraydist([Normal(i, 0.5) for i in 1:K])

    # Transition matrix: column j, T[:, j], is the next-state distribution
    # given current state j.
    T ~ filldist(Dirichlet(K, 1 / K), K)

    # Cursors into yobs (observed values) and ymis (missing values).
    jo, jm = 1, 1

    for i in 1:N
        # The initial state is uniform over 1:K. Later states follow the
        # transition column of the previous state. The ternary is
        # short-circuiting, so s[0] is never read.
        s[i] ~ (i == 1 ? Categorical(K) : Categorical(T[:, s[i-1]]))

        # The same emission distribution applies whether or not step i is
        # observed; only the target of `~` differs.
        dist = Normal(m[s[i]], 0.1)
        if i ∈ obsidx
            yobs[jo] ~ dist
            jo += 1
        else
            ymis[jm] ~ dist
            jm += 1
        end
    end
end

# Observed sequence with one missing entry (time step 8).
y = [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, missing, 3.0, 3.0, 2.0, 2.0, 2.0, 1.0, 1.0];
N, K = length(y), 3;

# Split the data into the time indices of observed steps (ascending, as
# `findall` returns them) and the observed values. `collect(skipmissing(y))`
# yields the same values as `y[obsidx]`, but as a concrete Vector{Float64},
# so the model never receives a container whose element type includes Missing.
obsidx = findall(!ismissing, y)
yobs = collect(skipmissing(y))

# The model object is named `model` so it does not clash with the
# emission-mean variable `m` that the sampler below targets as `:m`.
model = hmm(yobs, obsidx, K, N)
# HMC (step size 0.01, 5 leapfrog steps) for the continuous variables
# m, T and the missing observations ymis; particle Gibbs (10 particles) for
# the discrete state sequence s.
g = Gibbs(HMC(0.01, 5, :m, :T, :ymis), PG(10, :s))
c = sample(model, g, 100);

— trappmartin, Oct 22 '20, 10:10