dynamax
Dynamax GLM-HMM using jagged input arrays, with the E and M steps run over different scopes
Summary
I'm interested in learning if/how it would be possible to fit a GLM-HMM (i.e. LogisticRegressionHMM, CategoricalRegressionHMM) with a jagged input list (i.e. a list whose elements are lists of different lengths), such that the E step could be run individually for each of the inner lists, whereas the M step would be run on the entire input.
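A minimal sketch of the per-session E-step, assuming the per-trial emission log-likelihoods have already been computed from the current GLM weights. It is a scaled forward-backward pass in NumPy that only ever sees one session, so a jagged list is handled by an ordinary Python loop and no session's trials influence another session's posteriors. e_step_single_session is a hypothetical helper written for illustration, not a dynamax function, and the log-likelihood values in the usage lines are random stand-ins.

```python
import numpy as np


def e_step_single_session(log_lik, pi0, A):
    """Scaled forward-backward pass for ONE session (hypothetical helper, not dynamax API).

    log_lik: (T, K) array, log p(y_t | z_t = k) for this session's trials under the
             current GLM weights (assumed precomputed).
    pi0:     (K,) initial-state distribution.
    A:       (K, K) transition matrix, A[i, j] = p(z_{t+1} = j | z_t = i).
    Returns gamma (T, K) posterior state probabilities, xi_sum (K, K) expected
    transition counts within this session, and the session's log marginal likelihood.
    """
    T, K = log_lik.shape
    # Subtract each trial's max log-likelihood for stability; it is added back at the end.
    shift = log_lik.max(axis=1, keepdims=True)
    lik = np.exp(log_lik - shift)

    # Forward pass: alpha[t] = p(z_t | y_1..y_t); c[t] is the normalizer at step t.
    alpha = np.zeros((T, K))
    c = np.zeros(T)
    a = pi0 * lik[0]
    c[0] = a.sum()
    alpha[0] = a / c[0]
    for t in range(1, T):
        a = (alpha[t - 1] @ A) * lik[t]
        c[t] = a.sum()
        alpha[t] = a / c[t]

    # Backward pass, rescaled with the same normalizers.
    beta = np.ones((T, K))
    for t in range(T - 2, -1, -1):
        beta[t] = A @ (lik[t + 1] * beta[t + 1]) / c[t + 1]

    gamma = alpha * beta
    gamma /= gamma.sum(axis=1, keepdims=True)

    # Expected transition counts, summed over consecutive trial pairs in this session only.
    xi_sum = np.zeros((K, K))
    for t in range(T - 1):
        xi_sum += alpha[t][:, None] * A * (lik[t + 1] * beta[t + 1])[None, :] / c[t + 1]

    log_marginal = np.log(c).sum() + shift.sum()
    return gamma, xi_sum, log_marginal


# Usage: a jagged list of three sessions with 5, 8 and 3 trials (random stand-in values).
rng = np.random.default_rng(0)
K = 2
pi0 = np.array([0.6, 0.4])
A = np.array([[0.9, 0.1], [0.2, 0.8]])
sessions_log_lik = [np.log(rng.uniform(0.1, 1.0, size=(T, K))) for T in (5, 8, 3)]
session_stats = [e_step_single_session(ll, pi0, A) for ll in sessions_log_lik]
subject_log_lik = sum(lm for _, _, lm in session_stats)
```

Each session yields its own (gamma, xi_sum, log_marginal) triple, and the subject-level log-likelihood is the sum of the per-session log marginals, since the sessions are treated as independent chains.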
Context
To motivate this question, suppose jagged_list is a list where each element is a session_list (i.e. jagged_list = [session_list_1, session_list_2, ... session_list_s]), and each session_list contains trial samples (i.e. session_list_k = [trial_1, trial_2, ... trial_t]). Because the number of trials varies per session, this is a jagged array. These data represent trials from a single subject across multiple sessions. Trials from previous or future sessions should not be used to infer state probabilities and transitions (E-step), but all trials should be used together for learning weights (M-step).
Note that this was previously supported in the SSM library: when the data were structured this way in SSM, the E-step was run for each session, followed by a single M-step across all sessions.
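One way to give a jagged_list the rectangular shape that batched JAX code expects is to pad every session to the length of the longest one and carry a boolean mask marking the real trials. The NumPy sketch below does only that; pad_sessions is a hypothetical helper, the inputs are random stand-ins, and whether any particular dynamax method accepts such a mask is an assumption that is not settled here.

```python
import numpy as np


def pad_sessions(jagged_list, pad_value=0.0):
    """Pad a jagged list of per-session arrays to one rectangular array (hypothetical helper).

    jagged_list: list of S array-likes; session k has shape (T_k, ...) with T_k varying
                 and identical trailing dimensions across sessions.
    Returns padded with shape (S, T_max, ...) and a boolean mask of shape (S, T_max)
    that is True for real trials and False for padding.
    """
    sessions = [np.asarray(s, dtype=float) for s in jagged_list]
    t_max = max(s.shape[0] for s in sessions)
    padded = np.full((len(sessions), t_max) + sessions[0].shape[1:], pad_value)
    mask = np.zeros((len(sessions), t_max), dtype=bool)
    for k, s in enumerate(sessions):
        padded[k, : s.shape[0]] = s      # real trials go first ...
        mask[k, : s.shape[0]] = True     # ... and the padding trails at the end
    return padded, mask


# Usage: three sessions of 4, 6 and 2 trials, each trial with 3 input features.
rng = np.random.default_rng(1)
jagged_list = [rng.normal(size=(T, 3)) for T in (4, 6, 2)]
padded, mask = pad_sessions(jagged_list)
print(padded.shape, mask.sum(axis=1))  # (3, 6, 3) [4 6 2]
```

If padded sessions are run through a forward-backward pass, trailing padded trials can be given an emission log-likelihood of 0 (likelihood 1); because each row of the transition matrix sums to one, such trailing steps leave the posteriors over the real trials unchanged. Their expected transition and emission statistics would still have to be masked out before the M-step, or they would bias the pooled estimates.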
Issues
I think there are two roadblocks that prevent this from being possible.
- Dynamax does not appear to support jagged arrays, because of how it is implemented in JAX.
  - Dynamax has a procedure for batched inputs; however, it requires each batch (e.g. session) to have the same number of time steps (e.g. trials).
- The current fit_em method runs the E step and M step for each batch (as opposed to an E-step for each batch followed by an M-step across all batches).
  - However, there are separate methods for the E and M steps. I'm just unsure how to properly summarize the batch/session-iterated E-step outputs (i.e. SuffStats) to pass into a single M-step call across all batches/sessions. Any advice here would be greatly appreciated.
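On summarizing per-session outputs: for an HMM fit to several independent sequences, the expected initial-state and transition counts simply add across sequences, and the GLM emission objective is a sum over trials weighted by each trial's posterior state probabilities. The sketch below pools per-session stats on that basis and runs one M-step for a binary (logistic) GLM-HMM. pooled_m_step is hypothetical, the per-session stats and data in the usage lines are random stand-ins for real E-step outputs, and the weighted Newton solve with a ridge penalty is a swapped-in choice; dynamax's own emission M-step may use a different optimizer and a different layout for its sufficient statistics.

```python
import numpy as np


def pooled_m_step(session_stats, session_inputs, session_choices, weights,
                  num_newton=20, l2=1.0):
    """One M-step over ALL sessions of a subject (hypothetical helper, not dynamax API).

    session_stats:   list of (gamma, xi_sum) pairs from per-session E-steps;
                     gamma is (T_k, K), xi_sum is (K, K).
    session_inputs:  list of (T_k, D) GLM input arrays.
    session_choices: list of (T_k,) binary choice arrays.
    weights:         (K, D) current per-state logistic weights, used as a warm start.
    Returns updated (pi0, A, weights).
    """
    # Each session is an independent chain, so its first-trial posterior counts toward
    # the initial-state distribution, and expected transition counts add across sessions.
    init_counts = sum(gamma[0] for gamma, _ in session_stats)
    trans_counts = sum(xi_sum for _, xi_sum in session_stats)
    pi0 = init_counts / init_counts.sum()
    A = trans_counts / trans_counts.sum(axis=1, keepdims=True)

    # The emission objective is a posterior-weighted sum over trials, so per-trial
    # quantities can be concatenated across sessions once the E-steps are done.
    G = np.concatenate([gamma for gamma, _ in session_stats])  # (N, K)
    X = np.concatenate(session_inputs)                          # (N, D)
    y = np.concatenate(session_choices).astype(float)           # (N,)

    new_weights = np.array(weights, dtype=float)
    for k in range(G.shape[1]):
        w = new_weights[k].copy()
        # Newton's method on the ridge-penalized, gamma-weighted Bernoulli log-likelihood.
        for _ in range(num_newton):
            p = 1.0 / (1.0 + np.exp(-X @ w))
            grad = X.T @ (G[:, k] * (y - p)) - l2 * w
            hess = -(X.T * (G[:, k] * p * (1.0 - p))) @ X - l2 * np.eye(X.shape[1])
            w = w - np.linalg.solve(hess, grad)
        new_weights[k] = w
    return pi0, A, new_weights


# Usage with random stand-ins for three sessions of 5, 8 and 4 trials.
rng = np.random.default_rng(2)
K, D = 2, 2
lengths = (5, 8, 4)
session_stats = [(rng.dirichlet(np.ones(K), size=T), rng.uniform(size=(K, K)) * (T - 1))
                 for T in lengths]
session_inputs = [rng.normal(size=(T, D)) for T in lengths]
session_choices = [rng.integers(0, 2, size=T) for T in lengths]
pi0, A, W = pooled_m_step(session_stats, session_inputs, session_choices, np.zeros((K, D)))
```

For a CategoricalRegressionHMM the per-state update would be a weighted multinomial regression instead of a logistic one; the pooling of expected counts and per-trial posteriors stays the same.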
Questions
- Is this summary of the issues and the comparison with SSM accurate?
- Is there a way to implement the desired behavior of session-level E-steps and a subject-level M-step (e.g., via padding or using the e_step and m_step methods)?
- Any additional thoughts or suggestions?
Thank you!
+1! I'd also be very interested in having this implemented in Dynamax
+1 Getting this implemented would also help me a lot