dynamax

dynamax GLM-HMM using jagged input arrays with differing E and M steps

Open. jess-breda opened this issue 9 months ago • 1 comment

Summary

I'm interested in learning whether, and how, it would be possible to fit a GLM-HMM (e.g. LogisticRegressionHMM, CategoricalRegressionHMM) to a jagged input list (i.e. a list whose elements are lists of different lengths), such that the E-step is run separately on each inner list while the M-step is run on the entire input.

Context

To motivate this question, suppose jagged_list is a list where each element is a session_list (i.e. jagged_list = [session_list_1, session_list_2, ... session_list_s]) and each session_list contains trial samples (i.e. session_list_k = [trial_1, trial_2, ... trial_t]). Because the number of trials varies across sessions, this is a jagged array. These data represent trials from a single subject across multiple sessions. Trials from previous or future sessions should not be used to learn state probabilities and transitions (E-step), but all trials should be used together to learn the weights (M-step).
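For concreteness, here is a minimal NumPy sketch of this structure, plus one common workaround for the fixed-shape arrays that JAX batching expects: pad every session to the longest one and keep a boolean mask marking the real trials. The helper `pad_sessions` and the synthetic data are illustrative only; they are not part of dynamax or SSM.

```python
import numpy as np


def pad_sessions(sessions, pad_value=0.0):
    """Stack a jagged list of per-session arrays into one fixed-shape array.

    sessions: list of arrays, session k having shape (T_k, ...) with T_k trials.
    Returns (padded, mask): padded has shape (S, T_max, ...) and mask has
    shape (S, T_max), True for real trials and False for padding.
    """
    num_sessions = len(sessions)
    t_max = max(len(s) for s in sessions)
    trailing_shape = np.shape(sessions[0])[1:]
    padded = np.full((num_sessions, t_max) + trailing_shape, pad_value, dtype=float)
    mask = np.zeros((num_sessions, t_max), dtype=bool)
    for k, session in enumerate(sessions):
        t_k = len(session)
        padded[k, :t_k] = session  # real trials first, padding at the end
        mask[k, :t_k] = True
    return padded, mask


# Synthetic example: one subject, three sessions of 3, 5 and 2 trials,
# each trial described by 2 input features.
rng = np.random.default_rng(0)
jagged_list = [rng.normal(size=(t, 2)) for t in (3, 5, 2)]
padded, mask = pad_sessions(jagged_list)
print(padded.shape)       # (3, 5, 2)
print(mask.sum(axis=1))   # [3 5 2]
```

Any downstream E-step or M-step would then have to ignore the masked-out trials; whether dynamax's own methods can consume such a mask is an open assumption here.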

Note that this was previously supported in the SSM library: when the data were structured this way, SSM ran the E-step for each session, followed by the M-step across all sessions.
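A per-session E-step followed by a pooled M-step can be sketched for the HMM part of the model (initial-state and transition probabilities). This is a minimal NumPy illustration under stated assumptions, not SSM's or dynamax's code: it assumes each session's per-trial emission log-likelihoods under each state have already been computed (in a GLM-HMM they would come from the current GLM weights), and `e_step_session` and `pooled_em_step` are hypothetical names.

```python
import numpy as np


def e_step_session(log_lik, pi0, A):
    """Scaled forward-backward pass over ONE session.

    log_lik: (T, K) emission log-likelihood of each trial under each state.
    pi0: (K,) initial-state distribution; A: (K, K) transition matrix.
    Returns posterior state probabilities (T, K), expected transition
    counts summed over the session (K, K), and the session log-likelihood.
    """
    T, K = log_lik.shape
    shift = log_lik.max(axis=1, keepdims=True)
    lik = np.exp(log_lik - shift)  # per-trial rescaling for numerical stability
    alpha = np.zeros((T, K))
    scale = np.zeros(T)
    a = pi0 * lik[0]
    scale[0] = a.sum()
    alpha[0] = a / scale[0]
    for t in range(1, T):
        a = (alpha[t - 1] @ A) * lik[t]
        scale[t] = a.sum()
        alpha[t] = a / scale[t]
    beta = np.ones((T, K))
    for t in range(T - 2, -1, -1):
        beta[t] = A @ (lik[t + 1] * beta[t + 1]) / scale[t + 1]
    posteriors = alpha * beta  # rows already sum to 1 with this scaling
    trans_counts = np.zeros((K, K))
    for t in range(T - 1):
        trans_counts += (alpha[t][:, None] * A * (lik[t + 1] * beta[t + 1])[None, :]
                         / scale[t + 1])
    log_marginal = np.log(scale).sum() + shift.sum()  # undo the rescaling
    return posteriors, trans_counts, log_marginal


def pooled_em_step(session_log_liks, pi0, A):
    """E-step per session, then ONE M-step over all sessions (HMM part only)."""
    K = len(pi0)
    init_counts = np.zeros(K)
    trans_counts = np.zeros((K, K))
    total_log_lik = 0.0
    for log_lik in session_log_liks:  # sessions may have different lengths
        post, counts, ll = e_step_session(log_lik, pi0, A)
        init_counts += post[0]      # each session restarts from pi0
        trans_counts += counts      # no transitions across session boundaries
        total_log_lik += ll
    new_pi0 = init_counts / init_counts.sum()
    new_A = trans_counts / trans_counts.sum(axis=1, keepdims=True)
    return new_pi0, new_A, total_log_lik


# Synthetic example: two sessions (3 and 4 trials), K = 2 states.
rng = np.random.default_rng(0)
sessions = [rng.normal(size=(3, 2)), rng.normal(size=(4, 2))]
pi0 = np.array([0.6, 0.4])
A = np.array([[0.9, 0.1], [0.2, 0.8]])
new_pi0, new_A, total_log_lik = pooled_em_step(sessions, pi0, A)
```

Because each session gets its own forward-backward pass, the expected counts never include a transition from the last trial of one session to the first trial of the next; only the summed counts reach the M-step.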

Issues

I think there are two roadblocks that currently make this impossible.

  1. Dynamax does not appear to support jagged arrays because of its JAX implementation.
  • Dynamax has a procedure for batched inputs, but it requires every batch (e.g. session) to have the same number of time steps (e.g. trials).
  2. The current fit_em method runs the E-step and M-step for each batch (as opposed to an E-step for each batch followed by one M-step across all batches).
  • However, there are separate e_step and m_step methods. I'm just unsure how to properly combine the per-session E-step outputs (i.e. SuffStats) into a single M-step call across all sessions. Any advice here would be greatly appreciated.
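For the GLM weights, the relevant per-session E-step output is each trial's posterior state probability. A minimal NumPy sketch of pooling those across sessions, assuming binary choices (as in a logistic-regression GLM-HMM) and a small L2 penalty: concatenate trials and posteriors from all sessions, then fit one posterior-weighted logistic regression per state with Newton's method. `pooled_glm_m_step` is a hypothetical helper, not dynamax's `m_step`, and it does not operate on dynamax's SuffStats objects.

```python
import numpy as np


def pooled_glm_m_step(inputs, choices, posteriors, l2=1e-3, num_iters=25):
    """Weighted logistic-regression M-step pooled over ALL sessions.

    inputs: list of (T_k, D) arrays; choices: list of (T_k,) arrays of 0/1;
    posteriors: list of (T_k, K) posterior state probabilities from the
    per-session E-steps. Returns a (K, D) array of per-state weights.
    """
    X = np.concatenate(inputs)       # (N, D), N = total trials over sessions
    y = np.concatenate(choices)      # (N,)
    W = np.concatenate(posteriors)   # (N, K), weight of each trial per state
    num_states, dim = W.shape[1], X.shape[1]
    weights = np.zeros((num_states, dim))
    for k in range(num_states):
        w = np.zeros(dim)
        for _ in range(num_iters):
            p = 1.0 / (1.0 + np.exp(-X @ w))
            # Gradient and negative Hessian of the penalized weighted log-likelihood
            grad = X.T @ (W[:, k] * (y - p)) - l2 * w
            neg_hess = (X * (W[:, k] * p * (1 - p))[:, None]).T @ X + l2 * np.eye(dim)
            w = w + np.linalg.solve(neg_hess, grad)  # Newton step
        weights[k] = w
    return weights


# Synthetic example: three sessions of different lengths, 2 states,
# inputs = (stimulus, bias column); random posteriors stand in for E-step output.
rng = np.random.default_rng(1)
lengths = [40, 25, 60]
inputs = [np.column_stack([rng.normal(size=t), np.ones(t)]) for t in lengths]
choices = [rng.integers(0, 2, size=t).astype(float) for t in lengths]
posteriors = [rng.dirichlet(np.ones(2), size=t) for t in lengths]
weights = pooled_glm_m_step(inputs, choices, posteriors)
```

Since this M-step only sees the concatenated arrays, how the trials are split into sessions affects the weights only through the posteriors, which is exactly where session boundaries matter. A categorical GLM-HMM would need a weighted multinomial regression instead.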

Questions

  1. Is this summary of the issues, and the comparison with SSM, accurate?
  2. Is there a way to implement the desired behavior of session-level E-steps and a subject-level M-step (e.g., via padding or via the e_step and m_step methods)?
  3. Any additional thoughts or suggestions?

Thank you!

jess-breda, May 14 '24 21:05

+1! I'd also be very interested in having this implemented in Dynamax

atlaie, Jun 04 '24 09:06

+1. Getting this implemented would also help me a lot.

conormcgrory, Nov 21 '24 01:11