Parallel HMM posterior sample
This PR implements a parallel version of HMM posterior sampling using associative scan (see https://github.com/probml/dynamax/issues/341). The scan elements $E_{ij}$ are vectors specifying a sample
$$z_j \sim p(z_j \mid z_i)$$
for each possible value of $z_i$. They can be thought of as functions $E : [1,...,n] \to [1,...,n]$ where the associative operator is function composition. This implementation passes the test written for serial sampling (which is commented out for some reason). It starts performing better than serial sampling when the sequence length exceeds a few thousand (I'm a little mystified as to why it takes so long for the crossover to happen).
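For reference, here is a minimal sketch of the idea (the helper names `compose` and `sample_from_elements` are illustrative, not necessarily how this PR implements it): each element is a length-$K$ integer vector mapping a conditioning state to a sampled state, and the associative operator is composition of these index maps.

```python
import jax
import jax.numpy as jnp

def compose(a, b):
    # Apply `a` first, then `b`: compose(a, b)[..., k] = b[..., a[..., k]].
    # Any leading axes are the batch axes that associative_scan operates over.
    return jnp.take_along_axis(b, a, axis=-1)

def sample_from_elements(elements):
    # elements[t, k] = sampled value of z_t given z_{t+1} = k, with the last
    # row holding the final-state sample replicated K times.
    # Flip so the scan composes from the end of the sequence inward.
    suffix = jax.lax.associative_scan(compose, jnp.flip(elements, axis=0))
    samples = jnp.flip(suffix, axis=0)
    # All K columns are identical (they share the same final-state sample),
    # so any one column is a draw of the full sequence z_{1:T}.
    return samples[:, 0]

# Tiny example with K = 3 states and T = 4 timesteps (arbitrary element values):
elements = jnp.array([[1, 0, 2],
                      [2, 2, 0],
                      [0, 1, 1],
                      [2, 2, 2]])   # last row: final sample (= 2) replicated
print(sample_from_elements(elements))   # -> [2 2 1 2]
```

Since function composition is associative, `jax.lax.associative_scan` can evaluate all of these suffix compositions in $O(\log T)$ parallel depth. The script below is the benchmark used for the timing comparison above.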
```python
from dynamax.hidden_markov_model.inference_test import random_hmm_args
from dynamax.hidden_markov_model import hmm_posterior_sample, parallel_hmm_posterior_sample
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import time

num_states = 2
num_iters = 5
timesteps = np.logspace(0, 6, 10).astype(int)

serial_times, parallel_times = [], []
for num_timesteps in timesteps:
    print(num_timesteps)
    serial_time, parallel_time = 0, 0
    for itr in range(num_iters + 1):
        args = random_hmm_args(jr.PRNGKey(itr), num_timesteps, 5)

        # Serial sampling.
        t = time.time()
        hmm_posterior_sample(jr.PRNGKey(itr), *args)
        elapsed = time.time() - t
        print('s', elapsed)
        if itr > 0: serial_time += elapsed  # skip the first (compilation) iteration

        # Parallel sampling.
        t = time.time()
        parallel_hmm_posterior_sample(jr.PRNGKey(itr), *args)
        elapsed = time.time() - t
        print('p', elapsed)
        if itr > 0: parallel_time += elapsed

    serial_times.append(serial_time / num_iters)
    parallel_times.append(parallel_time / num_iters)

plt.plot(timesteps, serial_times, label='serial')
plt.plot(timesteps, parallel_times, label='parallel')
plt.legend(loc='upper left')
plt.xscale('log')
plt.yscale('log')
plt.ylabel('Runtime (s)')
plt.xlabel('Sequence length')
plt.gcf().set_size_inches((3, 2))
```
Thanks @calebweinreb! To clarify, I would say that the associative operator takes in two sets of samples, $$z_s \sim p(z_s \mid x_{1:s}, z_{s+1})$$ and $$z_{s+1} \sim p(z_{s+1} \mid x_{1:t}, z_{t+1})$$ for all values of $z_{s+1} \in [K]$ and $z_{t+1} \in [K]$, respectively.
Then, assuming $t > s$, the associative operator returns a sample $$z_s \sim p(z_s \mid x_{1:t}, z_{t+1})$$ for all $z_{t+1} \in [K]$.
The final message is a sample $z_T \sim p(z_T \mid x_{1:T})$, replicated $K$ times so that it is the same shape as the preceding messages.
The output of the associative scan thus yields samples of $z_{1:T} \sim p(z_{1:T} \mid x_{1:T})$. The output has shape (T, K), but all columns are identical since they all start from the same final state, so it suffices to take the first column of the output matrix.
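If each set of samples is stored as a length-$K$ vector indexed by the state it is conditioned on, the operator above reduces to an index lookup. A tiny illustrative example with $K = 3$ (the sample values are made up):

```python
import jax.numpy as jnp

first  = jnp.array([1, 0, 2])   # first[k]  = sample of z_s     from p(z_s     | x_{1:s}, z_{s+1} = k)
second = jnp.array([2, 2, 0])   # second[k] = sample of z_{s+1} from p(z_{s+1} | x_{1:t}, z_{t+1} = k)

# Combined element: sample of z_s from p(z_s | x_{1:t}, z_{t+1} = k), for every k.
combined = first[second]        # -> [2, 2, 1]
```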
This looks really neat @calebweinreb!
One question about the timing results - is that on a CPU or GPU? I remember the behaviour being a bit different for different backends in the context of LGSSM inference (for instance, results from Adrien).
Hi Scott, thanks for clarifying! I think we landed on a good way of articulating the algorithm over Slack. I'll repost here in case others are interested:
- Let's assume an HMM with $K$ hidden states and $T$ time-steps.
- The initial messages $E_{t,t+1}$ are samples from $p(z_t \mid x_{1:t}, z_{t+1})$ for all possible values of $z_{t+1}$.
- The initial message for the final timestep, $E_{T}$, is a sample from the last filtering distribution, $p(z_T \mid x_{1:T})$, repeated $K$ times so that it's the same shape as the other messages (see the sketch after this list).
- In the first iteration, the associative operator gives you samples from $p(z_t \mid x_{1:t+1}, z_{t+2})$ for all values of $z_{t+2}$. It does so by sampling $z_{t+1} \sim p(z_{t+1} \mid x_{1:t+1}, z_{t+2})$ and then sampling $z_t$ conditioned on $z_{t+1}$.
- This step is repeated recursively in the associative scan. At any intermediate point, the message $E_{i,j}$ stores samples $z_i \sim p(z_i \mid x_{1:j-1}, z_j)$ for each possible value of $z_j$.
- The final output is an array of shape (T,K) where the columns (which are all the same because they share the same final state) each contain the final sampled sequence $z_{1:T}$.
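Here is a hypothetical sketch of how the initial messages could be built from the forward-filtering output (it assumes the filtered state probabilities and the transition matrix are available; the function name and layout are illustrative, not this PR's API):

```python
import jax.numpy as jnp
import jax.random as jr

def make_initial_messages(key, filtered_probs, transition_matrix):
    # filtered_probs[t, i]    = p(z_t = i | x_{1:t})      (shape (T, K), from forward filtering)
    # transition_matrix[i, j] = p(z_{t+1} = j | z_t = i)  (shape (K, K))
    K = filtered_probs.shape[1]
    key_mid, key_last = jr.split(key)

    # logits[t, k, i] = log p(z_t = i | x_{1:t}) + log p(z_{t+1} = k | z_t = i),
    # so sampling over the last axis draws z_t ~ p(z_t | x_{1:t}, z_{t+1} = k).
    logits = jnp.log(filtered_probs[:-1, None, :]) + jnp.log(transition_matrix.T)[None, :, :]
    middle = jr.categorical(key_mid, logits)                   # shape (T-1, K)

    # Final message: a single draw from the last filtering distribution,
    # replicated K times so it has the same shape as the other messages.
    last = jr.categorical(key_last, jnp.log(filtered_probs[-1]))
    return jnp.vstack([middle, jnp.full((1, K), last)])
```

Running the composition scan from the earlier sketch on these messages (and taking any one column of the result) then gives a draw of the full sequence $z_{1:T}$.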
> One question about the timing results - is that on a CPU or GPU?
I ran the test on a GPU. I assume on a CPU, parallel would always do worse?