Support for HMMs with num_states=1
I am trying to fit various HMM classes (LinearRegressionHMM or GaussianHMM) to my data, but they fail when I pass num_states=1. For num_states >= 2, everything works as expected. Is the lack of support for num_states=1 intended?
It's easy enough to write simple linear regression outside dynamax. However, that makes comparisons with the num_states >= 2 fits error-prone (for example, one might end up using different constants in the log-likelihood calculations).
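For reference when comparing against num_states >= 2 fits: with K = 1 the sum over hidden-state paths has a single term, and the only valid parameters are π₁ = 1 and A₁₁ = 1, so the HMM marginal likelihood collapses to an i.i.d. likelihood. Written out with the full normalizing constants (univariate Gaussian emissions with mean μ and variance σ² assumed for brevity):

```latex
p(x_{1:T}) = \sum_{z_{1:T}} \pi_{z_1}\, p(x_1 \mid z_1) \prod_{t=2}^{T} A_{z_{t-1} z_t}\, p(x_t \mid z_t)
\quad\xrightarrow{\;K=1,\ \pi_1 = A_{11} = 1\;}\quad
\log p(x_{1:T}) = \sum_{t=1}^{T} \log \mathcal{N}(x_t \mid \mu, \sigma^2)
= -\frac{T}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\sum_{t=1}^{T} (x_t - \mu)^2
```

Any single-state baseline written outside dynamax should keep the -(T/2) log(2πσ²) term so that its log-likelihood is on the same scale as the multi-state fits.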
If it helps, the error is raised by the Dirichlet distribution used during initialization:
  File "/Users/us/project/fitting.py", line 20, in fitEM
    params, props = hmm.initialize(key)
  File "/Users/us/dynaenv/lib/python3.10/site-packages/dynamax/hidden_markov_model/models/gaussian_hmm.py", line 649, in initialize
    params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
  File "/Users/us/dynaenv/lib/python3.10/site-packages/dynamax/hidden_markov_model/models/initial.py", line 45, in initialize
    initial_probs = tfd.Dirichlet(self.initial_probs_concentration).sample(seed=this_key)
...
ValueError: Argument `concentration` must have `event_size` at least 2.
Hi @umeshksingla! This is an interesting scenario.
We do assume that there are at least 2 hidden states (I suppose a model with only one hidden state is not really an HMM).
In practice, some functionality seems to work fine with one hidden state. However, as you have found, we run into an error whenever we interact with tfd.Dirichlet distributions, because they require at least two categories (here, states). To resolve this, we would need to check explicitly for the single-state case and treat it specially.
I am not totally sure that we want to add complexity to handle this somewhat niche, and potentially out-of-scope, use case. Perhaps it is a good idea; if not, we should at the very least document that num_states must be >= 2 and improve the error messages here.
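To make the "better error message" option concrete, here is a minimal sketch of an early check. `validate_num_states` is a hypothetical helper, not part of dynamax, and where it would be called (for example, in a model constructor) is an assumption.

```python
def validate_num_states(num_states: int) -> None:
    """Hypothetical guard (not a dynamax function).

    Fails early with a descriptive message instead of the downstream
    tfd.Dirichlet error about `event_size`.
    """
    if not isinstance(num_states, int) or num_states < 2:
        raise ValueError(
            f"num_states must be an integer >= 2, got {num_states!r}. "
            "The initial-state and transition priors are Dirichlet "
            "distributions, which need at least two categories."
        )


validate_num_states(3)  # passes silently
try:
    validate_num_states(1)
except ValueError as err:
    print(err)
```

Failing at construction time would point users at the real constraint rather than at an internal sampling call.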
For now, you can avoid the call to tfd.Dirichlet during initialization by specifying the initial probabilities and transition matrix manually. This is straightforward because, with only one state, each of these parameters can take only one value: initial_probs = [1.0] and transition_matrix = [[1.0]]. For instance:
from jax import numpy as jnp
from jax import random as jr
from dynamax.hidden_markov_model import GaussianHMM
hmm = GaussianHMM(num_states=1, emission_dim=1)
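# With one state, these are the only valid values.
initial_probs = jnp.array([1.0])
transition_matrix = jnp.array([[1.0]])
# Passing both explicitly skips the Dirichlet sampling in initialize().
params, props = hmm.initialize(initial_probs=initial_probs, transition_matrix=transition_matrix)
z, x = hmm.sample(params, key=jr.PRNGKey(0), num_timesteps=10)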
# z is Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)
You should be able to use this model for sampling as usual.
However, parameter learning (e.g. .fit_em() or .fit_sgd()) will not work, because the Dirichlet distribution is also used during fitting, even when the initial and transition parameters are marked as non-trainable, as in this example:
hmm = GaussianHMM(num_states=1, emission_dim=1)
initial_probs = jnp.array([1.0])
transition_matrix = jnp.array([[1.0]])
params, props = hmm.initialize(initial_probs=initial_probs, transition_matrix=transition_matrix)
z, x = hmm.sample(params, key=jr.PRNGKey(0), num_timesteps=100)
params_inf, props = hmm.initialize(key=jr.PRNGKey(10), initial_probs=initial_probs, transition_matrix=transition_matrix)
# Freeze the parameters that can only take one value.
props.initial.probs.trainable = False
props.transitions.transition_matrix.trainable = False
try:
    hmm.fit_em(params_inf, props, emissions=x)
except ValueError as e:
    print(f"Error: {e}")
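If you only need the single-state Gaussian fit itself, note that with one state every posterior state probability is 1, so the EM M-step for Gaussian emissions reduces to the sample mean and the (1/T) sample covariance, reached in a single iteration. The sketch below is plain NumPy, not dynamax code; `fit_single_state_gaussian` is a hypothetical name, and it applies no prior, whereas dynamax's GaussianHMM may regularize its M-step with a prior, so its estimates can differ slightly.

```python
import numpy as np


def fit_single_state_gaussian(emissions):
    """Closed-form fit of a one-state Gaussian 'HMM', i.e. an i.i.d. Gaussian.

    All posterior weights equal 1, so the M-step is the sample mean and the
    biased (divide by T) sample covariance. No prior is applied.
    """
    x = np.asarray(emissions, dtype=float)
    if x.ndim == 1:
        x = x[:, None]  # treat as shape (num_timesteps, 1)
    mean = x.mean(axis=0)
    centered = x - mean
    cov = centered.T @ centered / x.shape[0]
    return mean, cov


rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=0.5, size=(1000, 1))
mean, cov = fit_single_state_gaussian(x)
print(mean, cov)
```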
One workaround is to build a model with num_states=2 but fix the initial state distribution and transition matrix so that the second state is never visited, making the model behave as if it had only one state.
Here is an example:
from jax import numpy as jnp
from jax import random as jr
from dynamax.hidden_markov_model import GaussianHMM
hmm = GaussianHMM(num_states=2, emission_dim=1)
initial_probs = jnp.array([1.0, 0.])
transition_matrix = jnp.array([[1.0, 0.], [1.0, 0.0]])
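params, props = hmm.initialize(key=jr.PRNGKey(0), initial_probs=initial_probs, transition_matrix=transition_matrix)
# Sample emissions from this (effectively one-state) model.
_, x = hmm.sample(params, key=jr.PRNGKey(1), num_timesteps=100)
params_inf, props = hmm.initialize(key=jr.PRNGKey(100), initial_probs=initial_probs, transition_matrix=transition_matrix)
# Keep the zero entries fixed so the second state stays unreachable.
props.initial.probs.trainable = False
props.transitions.transition_matrix.trainable = False
params_fit, log_probs = hmm.fit_em(params_inf, props, emissions=x)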
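This may give reasonable parameter estimates; however, the log-probability calculations do not handle this setup well, and you may get jnp.inf or jnp.nan values, presumably because the zero probabilities enter the calculations as log(0) = -inf.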
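To illustrate how zero probabilities can turn into nan, here is a small NumPy forward algorithm in log space. It is not dynamax's JAX implementation, which may guard against this differently, so treat it as a sketch of the mechanism rather than a diagnosis. With the padded two-state parameters, an entry can receive only -inf terms; a naive log-sum-exp then computes -inf - (-inf) = nan, while a guarded version returns -inf and recovers exactly the one-state log-likelihood. The helper names and the example means and variances are illustrative assumptions.

```python
import numpy as np


def gaussian_logpdf(x, means, variances):
    # Univariate Gaussian log-density, evaluated for every state at once.
    return -0.5 * (np.log(2 * np.pi * variances) + (x - means) ** 2 / variances)


def logsumexp_naive(a):
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))  # nan when every entry is -inf


def logsumexp_safe(a):
    m = np.max(a)
    if np.isneginf(m):
        return -np.inf  # no probability mass at all: log(0)
    return m + np.log(np.sum(np.exp(a - m)))


def forward_log_likelihood(x, initial_probs, transition_matrix, means, variances, lse):
    with np.errstate(divide="ignore"):  # log(0) = -inf is intended here
        log_pi = np.log(initial_probs)
        log_A = np.log(transition_matrix)
    log_alpha = log_pi + gaussian_logpdf(x[0], means, variances)
    for t in range(1, len(x)):
        # log_alpha[j] = logsumexp_i(log_alpha[i] + log_A[i, j]) + log p(x_t | j)
        log_alpha = np.array([lse(log_alpha + log_A[:, j]) for j in range(len(means))])
        log_alpha = log_alpha + gaussian_logpdf(x[t], means, variances)
    return lse(log_alpha)


rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.5, size=50)

one_state = dict(initial_probs=np.array([1.0]), transition_matrix=np.array([[1.0]]),
                 means=np.array([2.0]), variances=np.array([2.25]))
# Padded two-state version: state 1 is never entered.
two_state = dict(initial_probs=np.array([1.0, 0.0]),
                 transition_matrix=np.array([[1.0, 0.0], [1.0, 0.0]]),
                 means=np.array([2.0, -5.0]), variances=np.array([2.25, 1.0]))

ll_one = forward_log_likelihood(x, **one_state, lse=logsumexp_safe)
ll_two = forward_log_likelihood(x, **two_state, lse=logsumexp_safe)
with np.errstate(invalid="ignore"):  # -inf - (-inf) produces nan here
    ll_two_naive = forward_log_likelihood(x, **two_state, lse=logsumexp_naive)
print(ll_one, ll_two, ll_two_naive)  # last value is nan
```

The guarded result also equals the i.i.d. Gaussian log-likelihood, which is a convenient sanity check when comparing single-state and multi-state fits.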