mamba
How to get all hidden_states of selective_scan_cuda?
Hi, I'm wondering how to get all the hidden states from selective_scan_cuda. It seems only the last hidden state is available:
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
You can't get all hidden states from the fast code. If you want them, you can materialize the states explicitly, which is easy to do directly in Python.
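A minimal NumPy sketch of the "materialize the states explicitly" approach follows. It is a plain sequential loop over the recurrence h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t, y_t = C_t · h_t, and it stores every h_t instead of only the last one. The function name selective_scan_all_states and the tensor shapes are assumptions loosely modeled on the kernel's argument list (u, delta, A, B, C, D, z, delta_bias, delta_softplus). This is not the library's code, and it only covers real-valued A with input-dependent B and C. Porting it to PyTorch is mechanical (np.einsum becomes torch.einsum, and so on). Keeping every state costs memory proportional to batch × dim × seqlen × dstate.

```python
import numpy as np


def selective_scan_all_states(u, delta, A, B, C, D=None, z=None,
                              delta_bias=None, delta_softplus=False):
    """Sequential selective scan that keeps every hidden state (hypothetical helper).

    Assumed shapes (an assumption, check against your own tensors):
      u, delta, z:    (batch, dim, seqlen)
      A:              (dim, dstate)
      B, C:           (batch, dstate, seqlen)
      D, delta_bias:  (dim,)
    Returns (out, states) where states has shape (batch, dim, seqlen, dstate).
    """
    if delta_bias is not None:
        delta = delta + delta_bias[None, :, None]
    if delta_softplus:
        delta = np.logaddexp(0.0, delta)  # numerically stable softplus
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    # Discretize every time step at once: exp(delta * A) and delta * B * u.
    deltaA = np.exp(np.einsum("bdl,dn->bdln", delta, A))
    deltaB_u = np.einsum("bdl,bnl,bdl->bdln", delta, B, u)
    x = np.zeros((batch, dim, dstate), dtype=deltaA.dtype)
    states = np.empty((batch, dim, seqlen, dstate), dtype=deltaA.dtype)
    ys = np.empty((batch, dim, seqlen), dtype=deltaA.dtype)
    for t in range(seqlen):
        # h_t = exp(delta_t A) * h_{t-1} + delta_t B_t u_t
        x = deltaA[:, :, t] * x + deltaB_u[:, :, t]
        states[:, :, t] = x  # keep the intermediate state
        # y_t = C_t . h_t, contracted over the state dimension
        ys[:, :, t] = np.einsum("bdn,bn->bd", x, C[:, :, t])
    out = ys if D is None else ys + u * D[None, :, None]
    if z is not None:
        out = out * (z / (1.0 + np.exp(-z)))  # gate with SiLU(z)
    return out, states


# Usage with small random inputs.
rng = np.random.default_rng(0)
b, d, n, L = 2, 3, 4, 5
u = rng.standard_normal((b, d, L))
delta = rng.random((b, d, L))
A = -rng.random((d, n))  # negative real A keeps the recurrence stable
Bm = rng.standard_normal((b, n, L))
Cm = rng.standard_normal((b, n, L))
out, states = selective_scan_all_states(u, delta, A, Bm, Cm)
final_state = states[:, :, -1]  # the last hidden state, shape (b, d, n)
```

Under these assumptions, states[:, :, -1] is the final hidden state, which should correspond to the single state the fused kernel returns. The other slices states[:, :, t] are the intermediate states the fast path does not expose.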
I'm wondering whether this final_state refers to the state marked by the red circle in the figure?