mamba
using ssm_state and conv_state during training
Currently, the way the library is written, conv_state and ssm_state are only used when generating one step at a time via InferenceParams. It would be useful to use these during training with the selective scan as well. An example use case is training on documents larger than available memory: splitting the document into chunks introduces a stop-gradient between chunks, but carrying the states across chunks would still avoid context fragmentation at inference time.
Are there any plans to enable this? If not, how difficult would it be to add this to the scan kernel? For the causal conv I guess you can just keep a cache of the last 3 activations.
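For concreteness, here is a rough sketch of what this could look like from the training-loop side. The `initial_states=` keyword and the returned states are hypothetical (today the module only accepts `inference_params` for step-by-step decoding); this is just the shape of the API being requested, not something that works now:

```python
# Hypothetical TBPTT loop: split a long document into chunks and carry
# (conv_state, ssm_state) across them, detaching so the gradient stops at
# chunk boundaries. The `initial_states=` kwarg and the returned states do
# NOT exist in mamba_ssm today -- they only sketch the API this issue asks for.

def train_on_long_document(model, optimizer, document, chunk_len):
    states = None  # carried (conv_state, ssm_state) tensors; None for the first chunk
    for start in range(0, document.size(1), chunk_len):  # document is (batch, seqlen, ...)
        chunk = document[:, start:start + chunk_len]
        out, states = model(chunk, initial_states=states)  # hypothetical signature
        loss = out.float().pow(2).mean()  # placeholder; use your real LM loss here
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Stop-gradient between chunks: keep the state values, drop the graph.
        states = tuple(s.detach() for s in states)

# For the causal conv1d, the cache would just be the last (d_conv - 1) inputs of
# the previous chunk (3 timesteps for the default d_conv=4), prepended to the
# next chunk before the convolution.
```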
Agree it's a useful feature, it's on the roadmap!
@albertfgu Would the states be differentiable? Unlike @robflynnyh, I would prefer not to stop the gradient during training.
Not sure, it's a lot of extra work to make it differentiable through the final state. And that's also not necessary for the main use case that we wanted to support (continuing training on the next state). What use case are you thinking of?
I'm thinking of backpropagation through time over multiple chunks (e.g. 4) instead of fitting the full sequence in one huge window.
Would that take the same amount of activation memory as computing the full sequence?
Hmm yeah, my bad
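To spell out why it wouldn't save memory: without detaching, autograd has to keep every chunk's activation graph alive until the final backward, so k chunks cost the same peak activation memory as one pass over the full sequence. A minimal sketch, reusing the hypothetical `initial_states=` kwarg from above:

```python
def bptt_over_chunks(model, chunks):
    """BPTT across chunks WITHOUT detaching the carried state (sketch).

    Because no .detach() is applied, autograd keeps the activation graph of
    every chunk alive until backward(), so peak activation memory matches
    running the full sequence in one pass. `initial_states=` is the same
    hypothetical kwarg as in the sketch above.
    """
    total_loss = 0.0
    states = None
    for chunk in chunks:
        out, states = model(chunk, initial_states=states)
        total_loss = total_loss + out.float().pow(2).mean()  # placeholder loss
    total_loss.backward()  # reaches the activations of every chunk at once
```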
Is there a convenient way to set up the initial state for Mamba? I want to use TBPTT to train Mamba on a longer context size, so there is no need to make the initial/final states of each chunk differentiable.
Unfortunately not right now. It will hopefully be ready within a few weeks.
Just checking to see if this is still in progress. I was going to start working on this myself, but I would hate to spend a couple weeks figuring it out just for an official implementation to come out at the same time.
@albertfgu I think differentiable initial states are a critical feature for investigating state initialization. For example, we would be able to perform prefix tuning for Mamba (similar to tuning the KV-cache in Transformers).
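As a concrete example of what that would enable, here is a prefix-tuning-style sketch where the per-layer initial SSM state is a learned parameter and the backbone stays frozen. The `initial_states=` kwarg is the same hypothetical one as in the sketches above, and the shapes (one `(d_inner, d_state)` state per layer, with d_inner = expand * d_model) are assumptions:

```python
import torch
import torch.nn as nn

class LearnedInitialState(nn.Module):
    """Prefix-tuning-style learnable initial SSM state (sketch).

    Assumes a differentiable initial state and the hypothetical
    `initial_states=` kwarg from the sketches above. Shapes are assumptions:
    one (d_inner, d_state) state per layer, with d_inner = expand * d_model.
    """
    def __init__(self, n_layers, d_inner, d_state):
        super().__init__()
        self.init_ssm = nn.ParameterList(
            [nn.Parameter(torch.zeros(d_inner, d_state)) for _ in range(n_layers)]
        )

    def forward(self, batch_size):
        # Broadcast each learned per-layer state over the batch dimension.
        return [p.unsqueeze(0).expand(batch_size, -1, -1) for p in self.init_ssm]

# Usage sketch (backbone frozen, only the initial state is trained):
#   prefix = LearnedInitialState(n_layers=24, d_inner=2 * 768, d_state=16)
#   out, _ = model(tokens, initial_states=prefix(tokens.size(0)))  # hypothetical
```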