
add support for HMM inference with sparse transition matrices

Open murphyk opened this issue 3 years ago • 6 comments

See https://github.com/probml/JSL/blob/main/jsl/hmm/sparse_lib.py for an older implementation.

murphyk avatar Nov 14 '22 06:11 murphyk

+1. Lots of models have constrained transition matrices (e.g. left-to-right HMMs and change-point models).

slinderman avatar Nov 14 '22 07:11 slinderman

🙋 I can start looking at this !

emdupre avatar Nov 28 '22 21:11 emdupre

I've started just by using the Gaussian HMM example from the docs.

import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model import GaussianHMM

num_states = 6
num_emissions = 35

# Construct the HMM
hmm = GaussianHMM(num_states, num_emissions)

# Specify parameters of the HMM
initial_probs = jnp.ones(num_states) / num_states
transition_matrix = 0.8 * jnp.eye(num_states) + jnp.diag(
    jnp.tile(0.2, num_states - 1), k=1)

emission_means = jnp.column_stack([
    jnp.cos(jnp.linspace(0, 2 * jnp.pi, num_states + 1))[:-1],
    jnp.sin(jnp.linspace(0, 2 * jnp.pi, num_states + 1))[:-1],
    jnp.zeros((num_states, num_emissions - 2)),
    ])
emission_covs = jnp.tile(
    0.1**2 * jnp.eye(num_emissions),
    (num_states, 1, 1))

# Initialize the parameters struct with known values    
params, _ = hmm.initialize(
    initial_probs=initial_probs,
    transition_matrix=transition_matrix,
    emission_means=emission_means,
    emission_covariances=emission_covs)
true_states, emissions = hmm.sample(params, jr.PRNGKey(42), 100)

posterior = hmm.smoother(params, emissions)

This MWE runs without any errors, suggesting that sparse transition matrices aren't the issue per se, since this one is sparse:

DeviceArray([[0.8, 0.2, 0. , 0. , 0. , 0. ],
             [0. , 0.8, 0.2, 0. , 0. , 0. ],
             [0. , 0. , 0.8, 0.2, 0. , 0. ],
             [0. , 0. , 0. , 0.8, 0.2, 0. ],
             [0. , 0. , 0. , 0. , 0.8, 0.2],
             [0. , 0. , 0. , 0. , 0. , 0.8]], dtype=float32)

I can create an error by modifying the transition matrix to add a new, "dummy absorbing" state like so:

# Set up transition matrix with final dummy-absorbing state
transition_matrix = 0.8 * jnp.eye(num_states + 1) + jnp.diag(
    jnp.tile(0.2, num_states), k=1)
transition_matrix = transition_matrix.at[-1, -1].set(1)
transition_matrix
DeviceArray([[0.8, 0.2, 0. , 0. , 0. , 0. , 0. ],
             [0. , 0.8, 0.2, 0. , 0. , 0. , 0. ],
             [0. , 0. , 0.8, 0.2, 0. , 0. , 0. ],
             [0. , 0. , 0. , 0.8, 0.2, 0. , 0. ],
             [0. , 0. , 0. , 0. , 0.8, 0.2, 0. ],
             [0. , 0. , 0. , 0. , 0. , 0.8, 0.2],
             [0. , 0. , 0. , 0. , 0. , 0. , 1. ]], dtype=float32)

But that might be a bit too far afield from this original issue to discuss here !

Please let me know if I misunderstood the original posting, too 🙏

emdupre avatar Nov 29 '22 00:11 emdupre

What I had in mind is to exploit sparsity to speed up the O(K^2) computation at each step of forwards-backwards. Currently we just compute alpha(t) = A*alpha(t-1) and ignore any structure in A. My idea was to use https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html to exploit the sparsity in A. (The linked JSL code says it works with jax.experimental.sparse, but there are no demos or unit tests, so I am not sure that is true... Besides, JSL is deprecated; we only want to support dynamax.)
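As a rough sketch of this idea (not dynamax code, just an illustration of the matrix-vector step that forwards-backwards would call), the BCOO format in `jax.experimental.sparse` can store only the nonzero entries of A and multiply against a dense vector directly:

```python
import jax.numpy as jnp
from jax.experimental import sparse

num_states = 6

# Bidiagonal transition matrix from the example above: 11 nonzeros out of 36.
dense_A = 0.8 * jnp.eye(num_states) + jnp.diag(
    jnp.tile(0.2, num_states - 1), k=1)

# Store only the nonzero entries in batched-COO format.
A = sparse.BCOO.fromdense(dense_A)

# One step of the recursion alpha(t) = A * alpha(t-1): the sparse
# matvec touches only the stored entries rather than all K^2.
alpha = jnp.ones(num_states) / num_states
alpha_next = A @ alpha
```

Whether this pays off inside a `lax.scan`-based forward pass would need benchmarking, since BCOO operations lower to gather/scatter rather than a dense `dot`.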

An alternative approach is to implement the algorithms in the paper below, which only work for certain banded transition matrices, which can arise from discretizing an underlying continuous system.

@inproceedings{Felzenszwalb03,
  title     = {Fast Algorithms for Large State Space {HMMs} with Applications to Web Usage Analysis},
  booktitle = {Advances in Neural Information Processing Systems (NIPS)},
  year      = {2003},
  author    = {P. Felzenszwalb and D. Huttenlocher and J. Kleinberg}
}
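For intuition on why band structure helps even before any of the paper's further tricks: if A is banded, storing it as its diagonals (using the same offset convention as `jnp.diag`) gives an O(K * bandwidth) matvec instead of O(K^2). A minimal sketch, assuming this diagonal representation (the function name is mine, not from the paper or dynamax):

```python
import jax.numpy as jnp

def banded_matvec(diagonals, offsets, x):
    """Compute y = A @ x for a banded K-by-K matrix A.

    A is represented as A = sum_k jnp.diag(diag_k, k=k), i.e. each
    (diag, k) pair is a diagonal of length K - |k|, with k >= 0 a
    superdiagonal and k < 0 a subdiagonal. Cost is O(K * bandwidth)
    rather than O(K^2), since the zeros are never touched.
    """
    K = x.shape[0]
    y = jnp.zeros(K)
    for diag, k in zip(diagonals, offsets):
        if k >= 0:
            # A[i, i + k] = diag[i], so y[i] += diag[i] * x[i + k]
            y = y.at[:K - k].add(diag * x[k:])
        else:
            # A[j - k, j] = diag[j], so y[j - k] += diag[j] * x[j]
            y = y.at[-k:].add(diag * x[:K + k])
    return y
```

This is just the baseline saving from skipping zeros; my understanding is that the Felzenszwalb et al. algorithms go further for particular parametric band structures.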

murphyk avatar Nov 29 '22 03:11 murphyk

Ah, there's a slight confusion then. One issue is a bug that the current message passing code is returning NaNs in some cases, as @emdupre showed above. Another is a feature request to support the experimental sparse library for faster matrix-vector multiplies.

@emdupre, why don't you create a new issue for the bug in the current message passing code and reference this one.

slinderman avatar Nov 29 '22 04:11 slinderman

Thank you both, and sorry for the noise ! I've opened #290 for that discussion 🙇

emdupre avatar Nov 30 '22 00:11 emdupre