dynamax
pass inputs into the LDS model
Hello,
I have a very basic question: how do I pass inputs ("X") of shape N × T × D into the LDS model (N trials, T time steps, and D-dimensional inputs)?
In the `linear_gaussian_ssm` `models.py` file, the `inputs` argument is typed `Optional[Float[Array, "ntime input_dim"]]`, so there is no dimension for trials (N)?
I tried to follow the Kalman filter/smoother example. The problem is that I also need to include d latent trajectories in the model (i.e., the state dimension should be D + d if I encode the covariates into the emission matrix).
Not sure how to do it correctly...
Hi @weigcdsb, I'm not sure I fully understand your use case. Could you explain it in a bit more detail? Then we'll see if I can help 😄.
In general, it should be possible to use `jax.vmap` to map filtering/smoothing over additional dimensions (as described here); however, this might not be suitable for all scenarios.
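A minimal sketch of the `jax.vmap` approach, assuming the parameters are shared across trials and each trial's emissions and inputs are stacked along a leading trial axis, i.e. `(N, T, emission_dim)` and `(N, T, input_dim)`. The filter below is a plain textbook Kalman filter written only for illustration (time-invariant matrices, offsets $b_t$ and $d_t$ omitted); it is not dynamax's implementation, and `kalman_filter_loglik` and the parameter dictionary keys are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def kalman_filter_loglik(params, emissions, inputs):
    """Marginal log-likelihood of ONE trial under a linear Gaussian SSM.

    emissions: (T, emission_dim), inputs: (T, input_dim).
    Hypothetical textbook filter for illustration; offsets b_t, d_t omitted.
    """
    F, B, H, D = params["F"], params["B"], params["H"], params["D"]
    Q, R = params["Q"], params["R"]

    # Input used at step t to predict z_{t+1}; the last entry is never used.
    next_inputs = jnp.concatenate([inputs[1:], jnp.zeros_like(inputs[:1])])

    def step(carry, xs):
        mean, cov, ll = carry  # predictive p(z_t | y_{1:t-1})
        y, u, u_next = xs
        # Update on y_t ~ N(H z_t + D u_t, R).
        y_mean = H @ mean + D @ u
        S = H @ cov @ H.T + R
        K = jnp.linalg.solve(S, H @ cov).T  # Kalman gain: cov H^T S^{-1}
        ll = ll + multivariate_normal.logpdf(y, y_mean, S)
        mean_f = mean + K @ (y - y_mean)
        cov_f = cov - K @ S @ K.T
        # Predict z_{t+1} ~ N(F z_t + B u_{t+1}, Q).
        mean_p = F @ mean_f + B @ u_next
        cov_p = F @ cov_f @ F.T + Q
        return (mean_p, cov_p, ll), None

    init = (params["m0"], params["S0"], jnp.zeros(()))
    (_, _, ll), _ = jax.lax.scan(step, init, (emissions, inputs, next_inputs))
    return ll


# Shared parameters: state_dim=2, emission_dim=3, input_dim=4.
keys = jax.random.split(jax.random.PRNGKey(0), 5)
params = {
    "F": 0.9 * jnp.eye(2),
    "B": 0.1 * jax.random.normal(keys[0], (2, 4)),
    "H": jax.random.normal(keys[1], (3, 2)),
    "D": jax.random.normal(keys[2], (3, 4)),
    "Q": 0.1 * jnp.eye(2),
    "R": 0.5 * jnp.eye(3),
    "m0": jnp.zeros(2),
    "S0": jnp.eye(2),
}

N, T = 8, 50
emissions = jax.random.normal(keys[3], (N, T, 3))  # (N, T, emission_dim)
inputs = jax.random.normal(keys[4], (N, T, 4))     # (N, T, input_dim)

# Map over the leading trial axis of emissions and inputs; params are shared.
batched_loglik = jax.vmap(kalman_filter_loglik, in_axes=(None, 0, 0))
lls = batched_loglik(params, emissions, inputs)    # shape (N,)
```

`in_axes=(None, 0, 0)` broadcasts the shared parameters and maps over the trial axis of emissions and inputs, so each trial still sees inputs of shape `(T, input_dim)`. The same pattern should carry over to dynamax's own filtering and smoothing functions, but check their exact signatures in the dynamax documentation rather than relying on this sketch.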
@gileshd, thanks for replying, and sorry for the confusion.
Using the notation from the comments in your `models.py` file:
$$p(y_t \mid z_t) = \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t)$$
where $p(z_t \mid z_{t-1}, u_t) = \mathcal{N}(z_t \mid F_t z_{t-1} + B_t u_t + b_t, Q_t)$ and $p(z_1) = \mathcal{N}(z_1 \mid m, S)$, for $t=1,\ldots,T$. Here, $u_t$ is an input of size `input_dim` (assume `input_dim = D`; it defaults to 0). If there are $N$ observations, then `emission_dim = N`. So the total inputs (stacking all $u_t$ together) should have dimension $N\times D\times T$.
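A dimension check on the emission equation above may help here. Writing $K$ for `state_dim` (a symbol introduced only for this check) and taking `input_dim` $= D$ and `emission_dim` $= N$ as stated, a single $D$-dimensional $u_t$ already reaches all $N$ emission dimensions through the $N \times D$ matrix $D_t$:

$$\underbrace{y_t}_{N \times 1} = \underbrace{H_t}_{N \times K}\,\underbrace{z_t}_{K \times 1} + \underbrace{D_t}_{N \times D}\,\underbrace{u_t}_{D \times 1} + \underbrace{d_t}_{N \times 1}$$

Stacking $u_1, \ldots, u_T$ for one trial therefore gives a $T \times D$ array; giving each emission dimension its own $D$-vector would instead require $N \times D$ numbers per time step.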
My question is how to pass the input $u_t$ into the LDS model. In the `linear_gaussian_ssm` `models.py` file, the comment says `inputs: Optional[Float[Array, "ntime input_dim"]]=None`, which means the dimension should be $T \times D$. So there's no option for multiple emissions, say $N>1$ (since we cannot pass a 3D array to the model)?
Hope this clarifies my question.
Correct. The input vector $u_t$ at each time step must be a D-dimensional vector, so `inputs` has shape `(T, D)` (or `None`). You can always flatten your 3D inputs outside of dynamax.
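One way to do that flattening, assuming each of the $N$ emission dimensions has its own $D$ covariates stored as an $N \times D \times T$ array: move time to the front and merge the last two axes, giving `inputs` of shape `(T, N*D)` with `input_dim = N*D`. To keep "emission $n$ only sees its own covariates", the $N \times ND$ input-weight matrix $D_t$ would need a block-diagonal structure. The per-emission weights `W` below are hypothetical, and the sketch only checks the algebra; it does not show whether dynamax can enforce that structure during fitting.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

N, D, T = 4, 3, 50  # emission dims, covariates per emission, time steps
key_u, key_w = jax.random.split(jax.random.PRNGKey(0))
u_3d = jax.random.normal(key_u, (N, D, T))  # original layout: N x D x T

# Flatten outside dynamax: move time to the front, then merge (N, D) -> N*D.
# Column n*D + d of `inputs` holds covariate d of emission n.
inputs = jnp.transpose(u_3d, (2, 0, 1)).reshape(T, N * D)  # (T, N*D)

# Hypothetical per-emission weights w_n (one D-vector per emission dim).
W = jax.random.normal(key_w, (N, D))
# Block-diagonal input-weight matrix: row n holds w_n in columns n*D:(n+1)*D.
D_mat = block_diag(*[W[n][None, :] for n in range(N)])  # (N, N*D)

# D_t u_t at every time step, computed from the flattened inputs ...
contrib_flat = inputs @ D_mat.T  # (T, N)
# ... matches sum_d w_{n,d} u_{n,d,t} computed directly on the 3D array.
contrib_direct = jnp.einsum("nd,ndt->tn", W, u_3d)  # (T, N)
```

If the matrix is instead learned without constraints, every emission dimension is free to use every other dimension's covariates.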