mamba
Help with _chunk_state_fwd.
Hello @tridao, first of all, congratulations on the great job you did with Mamba 2. Could you please explain the purpose and operation of the `_chunk_state_fwd` function and its kernel?
```python
states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
```
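For context, here is a minimal NumPy sketch of what this call appears to compute, written to follow the einsum in the pure-PyTorch reference `chunk_state_ref` in `ssd_chunk_state.py`. It is not the Triton kernel, and the shapes in the docstring are my reading of that reference. For each chunk, each head, and each batch element, it produces the SSM state at the chunk's last position as if the chunk had started from a zero state. Each position `l` contributes `dt_l * x_l B_l^T`, decayed by `exp(dA_cumsum[end] - dA_cumsum[l])`. As far as I can tell, these per-chunk states are then propagated across chunks in a later step (`_state_passing_fwd` in `ssd_combined.py`). The function name `chunk_state_sketch` is hypothetical.

```python
import numpy as np

def chunk_state_sketch(B, x, dt, dA_cumsum):
    """Per-chunk SSM states (hypothetical helper, mirrors chunk_state_ref's einsum).

    Assumed shapes:
      B:         (batch, seqlen, ngroups, dstate)
      x:         (batch, seqlen, nheads, headdim)
      dt:        (batch, nheads, nchunks, chunk_size)
      dA_cumsum: (batch, nheads, nchunks, chunk_size), cumsum of dt * A within each chunk
    Returns:
      states:    (batch, nchunks, nheads, headdim, dstate)
    """
    batch, seqlen, nheads, headdim = x.shape
    ngroups, dstate = B.shape[2], B.shape[3]
    _, _, nchunks, chunk_size = dt.shape
    assert nheads % ngroups == 0 and seqlen <= nchunks * chunk_size
    # Heads in the same group share one B: repeat each group nheads // ngroups times.
    B = np.repeat(B, nheads // ngroups, axis=2)
    # Zero-pad the sequence so it splits evenly into chunks (padded rows contribute 0).
    pad = nchunks * chunk_size - seqlen
    x = np.pad(x, ((0, 0), (0, pad), (0, 0), (0, 0)))
    B = np.pad(B, ((0, 0), (0, pad), (0, 0), (0, 0)))
    x = x.reshape(batch, nchunks, chunk_size, nheads, headdim)
    B = B.reshape(batch, nchunks, chunk_size, nheads, dstate)
    # Decay from position l to the END of its chunk: exp(sum of dt*A over (l, end]).
    decay = np.exp(dA_cumsum[..., -1:] - dA_cumsum)  # (b, h, c, l)
    # states[b,c,h,p,n] = sum_l B[b,c,l,h,n] * decay[b,h,c,l] * dt[b,h,c,l] * x[b,c,l,h,p]
    return np.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B, decay, dt, x)

# Usage on random inputs (shapes chosen arbitrarily for illustration).
rng = np.random.default_rng(0)
batch, seqlen, nheads, headdim, ngroups, dstate = 1, 6, 2, 3, 1, 4
nchunks, chunk_size = 2, 4
x = rng.standard_normal((batch, seqlen, nheads, headdim))
B = rng.standard_normal((batch, seqlen, ngroups, dstate))
dt = rng.random((batch, nheads, nchunks, chunk_size))
A = -rng.random(nheads)  # one negative scalar A per head, as in Mamba-2
dA_cumsum = np.cumsum(dt * A[None, :, None, None], axis=-1)
states = chunk_state_sketch(B, x, dt, dA_cumsum)
print(states.shape)  # (1, 2, 2, 3, 4)

# Sanity check: chunk 0, head 0 equals running the recurrence
# h_t = exp(dt_t * A) * h_{t-1} + dt_t * x_t B_t^T from a zero state.
h = np.zeros((headdim, dstate))
for l in range(chunk_size):
    h = np.exp(dt[0, 0, 0, l] * A[0]) * h + dt[0, 0, 0, l] * np.outer(x[0, l, 0], B[0, l, 0])
print(np.allclose(states[0, 0, 0], h))  # True
```

The recurrence check shows why the chunked form is useful: the sequential scan inside a chunk collapses into one weighted sum, which maps onto a matrix multiply.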
You can see the reference implementation: https://github.com/state-spaces/mamba/blob/8ffd905c91d207f5c0cc84fc2a2fb748655094f0/mamba_ssm/ops/triton/ssd_chunk_state.py#L960
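To connect this to the kernel itself, below is a plain-Python sketch of what one program instance of `_chunk_state_fwd_kernel` roughly does for a single (batch, chunk, head) tile, based on my reading of the kernel source. It computes a per-position `scale = exp(dA_cs_last - dA_cs_l) * dt_l`, multiplies it into the rows of `B`, and accumulates `x^T @ B` over the chunk positions. When `seq_idx` is given, positions whose sequence index differs from that of the chunk's last valid position get `scale = 0`. This keeps packed variable-length sequences from leaking state into each other. Block tiling over `headdim`/`dstate`, dtype handling, and pointer strides are omitted. The helper name `chunk_state_tile` is hypothetical.

```python
import numpy as np

def chunk_state_tile(x_c, B_c, dt_c, dA_cs_c, seq_idx_c=None):
    """One (batch, chunk, head) tile (hypothetical helper, not the Triton kernel).

    x_c: (L, headdim), B_c: (L, dstate), dt_c and dA_cs_c: (L,), seq_idx_c: (L,) or None.
    L is the number of valid positions in this chunk (the last chunk may be shorter).
    Returns the (headdim, dstate) state at the chunk's last valid position.
    """
    L = x_c.shape[0]
    # Per-position weight: decay from l to the chunk end, times dt_l.
    scale = np.exp(dA_cs_c[L - 1] - dA_cs_c) * dt_c
    if seq_idx_c is not None:
        # Drop positions that belong to a different packed sequence than the last one.
        scale = np.where(seq_idx_c == seq_idx_c[L - 1], scale, 0.0)
    # Scale the rows of B, then a single (headdim, L) @ (L, dstate) matmul.
    return x_c.T @ (B_c * scale[:, None])

rng = np.random.default_rng(1)
L, headdim, dstate = 4, 3, 2
x_c = rng.standard_normal((L, headdim))
B_c = rng.standard_normal((L, dstate))
dt_c = rng.random(L)
dA_cs_c = np.cumsum(dt_c * -0.5)      # assumed scalar A = -0.5 for this head
seq_idx_c = np.array([0, 0, 1, 1])    # two packed sequences share this chunk
masked = chunk_state_tile(x_c, B_c, dt_c, dA_cs_c, seq_idx_c)
only_second = chunk_state_tile(x_c[2:], B_c[2:], dt_c[2:], dA_cs_c[2:])
print(np.allclose(masked, only_second))  # True: sequence 0 does not leak in
```

With masking, the tile's state depends only on the positions of the sequence that ends the chunk. The cross-sequence decay terms never matter because those positions are zeroed before the matmul.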