
Consider allowing hidden state initialisation via ssm_state input parameter for selective_scan_fn

Open · govorunov opened this issue 11 months ago · 5 comments

Please, please consider adding an ssm_state input parameter to selective_scan_fn to allow hidden-state initialisation for the Mamba block. Please also consider making the hidden state differentiable; currently the selective_scan_fn docstring says:

Note that the gradient of the last state is not considered in the backward pass.
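To make the request concrete, here is a minimal pure-Python sketch of a single-channel selective scan that accepts an optional initial state. It roughly mirrors the recurrence in the repository's PyTorch reference implementation (selective_scan_ref), assuming delta is already positive (no softplus) and leaving out the z gating. The function name `selective_scan_ref_1ch` and its `h0` argument are hypothetical stand-ins for the proposed ssm_state parameter, not part of the mamba_ssm API. The usage at the bottom checks the property the request relies on: scanning a sequence in one pass gives the same outputs as scanning its first half and then resuming the second half from the returned last state.

```python
import math
from typing import List, Optional, Sequence, Tuple


def selective_scan_ref_1ch(
    u: Sequence[float],                    # input sequence, length L
    delta: Sequence[float],                # positive step sizes, length L
    A: Sequence[float],                    # diagonal state matrix, length N (negative entries)
    B: Sequence[Sequence[float]],          # input-dependent B_t, shape L x N
    C: Sequence[Sequence[float]],          # input-dependent C_t, shape L x N
    D: float = 0.0,                        # skip connection weight
    h0: Optional[Sequence[float]] = None,  # hypothetical initial state (the requested ssm_state)
) -> Tuple[List[float], List[float]]:
    """Single-channel selective scan; returns (outputs, last_state)."""
    n = len(A)
    h = list(h0) if h0 is not None else [0.0] * n  # zero state when none is given
    ys = []
    for t in range(len(u)):
        for i in range(n):
            # discretised update: h <- exp(delta * A) * h + delta * B * u
            h[i] = math.exp(delta[t] * A[i]) * h[i] + delta[t] * B[t][i] * u[t]
        ys.append(sum(C[t][i] * h[i] for i in range(n)) + D * u[t])
    return ys, h


# Illustrative (made-up) inputs: L = 6 time steps, N = 3 state dimensions.
L, N = 6, 3
u = [0.5, -1.0, 2.0, 0.3, -0.7, 1.1]
delta = [0.1, 0.2, 0.15, 0.3, 0.05, 0.25]
A = [-1.0, -0.5, -2.0]
B = [[0.1 * (t + i + 1) for i in range(N)] for t in range(L)]
C = [[0.2 * (t - i) for i in range(N)] for t in range(L)]

# One pass over the whole sequence...
full_y, full_h = selective_scan_ref_1ch(u, delta, A, B, C, D=0.5)
# ...versus two passes, the second initialised with the first pass's last state.
y1, h1 = selective_scan_ref_1ch(u[:3], delta[:3], A, B[:3], C[:3], D=0.5)
y2, h2 = selective_scan_ref_1ch(u[3:], delta[3:], A, B[3:], C[3:], D=0.5, h0=h1)
```

Because the state is the only thing carried between time steps, `y1 + y2` matches `full_y` and `h2` matches `full_h`; passing an all-zero `h0` reproduces today's behaviour.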

This change could open the path to an encoder-decoder Mamba architecture. The design would be analogous to RNN sequence-to-sequence models: a Mamba encoder runs over the input sequence, ignoring its outputs; its last hidden state then initialises the decoder, which starts from a <START> input token and unrolls the state recursively. For the encoder to be trainable, its last hidden state has to be differentiable; for the decoder to work, the Mamba block needs to accept a hidden state at initialisation. A differentiable last state would also open a route to encoder-only, BERT-like architectures for classification, embedding and similar problems.
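The encoder-decoder hand-off described above can be sketched in pure Python with a toy, single-channel recurrence. Everything in this sketch is hypothetical and illustrative: the `ToySelectiveSSM` class, its weight names, and the choice to make delta, B and C simple scalar functions of the current input are assumptions, not the Mamba block or the mamba_ssm API, and the <START> token is stood in for by the number 1.0. The encoder discards its outputs and returns only its last state; the decoder starts from that state and feeds each output back in as its next input.

```python
import math
from typing import List, Optional, Sequence, Tuple


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


class ToySelectiveSSM:
    """Hypothetical single-channel selective SSM: delta, B and C depend on the input."""

    def __init__(self, A: Sequence[float], wb: Sequence[float], wc: Sequence[float], wd: float):
        self.A = list(A)    # diagonal state matrix (negative entries)
        self.wb = list(wb)  # B_t = wb * x_t
        self.wc = list(wc)  # C_t = wc * x_t
        self.wd = wd        # delta_t = softplus(wd * x_t)

    def step(self, x: float, h: Sequence[float]) -> Tuple[float, List[float]]:
        """One recurrent step: returns (output, new_state)."""
        delta = softplus(self.wd * x)
        new_h = [math.exp(delta * a) * hi + delta * (b * x) * x
                 for a, b, hi in zip(self.A, self.wb, h)]
        y = sum((c * x) * hi for c, hi in zip(self.wc, new_h))
        return y, new_h

    def encode(self, xs: Sequence[float], h: Optional[Sequence[float]] = None) -> List[float]:
        """Run over the input, ignore the outputs, return only the last state."""
        state = list(h) if h is not None else [0.0] * len(self.A)
        for x in xs:
            _, state = self.step(x, state)
        return state

    def decode(self, h: Sequence[float], start: float, steps: int) -> List[float]:
        """Start from a <START> value and unroll by feeding each output back in."""
        outputs: List[float] = []
        x, state = start, list(h)
        for _ in range(steps):
            x, state = self.step(x, state)
            outputs.append(x)
        return outputs


START = 1.0  # hypothetical numeric stand-in for the <START> token embedding
enc = ToySelectiveSSM(A=[-1.0, -0.5], wb=[0.5, 1.0], wc=[1.0, 0.5], wd=0.7)
dec = ToySelectiveSSM(A=[-0.8, -0.3], wb=[0.4, 0.9], wc=[0.6, 0.2], wd=0.5)

h_enc = enc.encode([0.2, -0.4, 0.9])               # encoder: keep only the last state
with_state = dec.decode(h_enc, START, steps=4)      # decoder initialised from the encoder
from_zero = dec.decode([0.0, 0.0], START, steps=4)  # what a zero-initialised block sees
```

Training such a model end to end needs the decoder's loss to backpropagate through `h_enc` into the encoder's parameters; that is exactly the last-state gradient that the current selective_scan_fn backward pass does not consider.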

Related issues: #233, #101

PS: Excellent work! Very impressive (especially the CUDA part)!

govorunov, Mar 20 '24 01:03