
using ssm_state and conv_state during training

robflynnyh opened this issue on Jan 11 '24 · 10 comments

Currently, from the way the library is written, conv_state and ssm_state are only used when generating one step at a time via InferenceParams. It would be useful to use these during training with the selective scan as well. An example use case is training on documents larger than available memory: while this would introduce a stop-gradient between chunks of the document, it would still help avoid context fragmentation at inference time.

Are there any plans to enable this? If not, how difficult would it be to add this to the scan kernel? For the causal conv I guess you can just keep a cache of the last 3 activations.
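
For illustration, here is a minimal sketch of the chunked loop being described, with a toy depthwise causal conv plus a diagonal recurrence standing in for the real Mamba block. None of this is the mamba_ssm API; the shapes, names, and recurrence are assumptions. The last d_conv - 1 inputs are cached for the causal conv, and both carried states are detached so the gradient stops at each chunk boundary:

```python
import torch
import torch.nn.functional as F

d_model, d_conv, d_state, chunk_len = 16, 4, 8, 128

# toy parameters standing in for a single mamba-like layer (not the real block)
conv_weight = torch.nn.Parameter(torch.randn(d_model, 1, d_conv) * 0.1)  # depthwise causal conv
A = torch.nn.Parameter(-torch.ones(d_model, d_state))                    # exp(A) in (0, 1): decay
B = torch.nn.Parameter(torch.randn(d_model, d_state) * 0.1)
C = torch.nn.Parameter(torch.randn(d_model, d_state) * 0.1)

def chunk_forward(x, conv_cache, ssm_state):
    # x: (batch, chunk_len, d_model)
    # conv_cache: (batch, d_conv - 1, d_model)  -- last d_conv - 1 inputs of the previous chunk
    # ssm_state:  (batch, d_model, d_state)
    xc = torch.cat([conv_cache, x], dim=1)                # prepend the cached inputs
    new_conv_cache = xc[:, -(d_conv - 1):].detach()       # cache for the next chunk (stop-grad)
    u = F.conv1d(xc.transpose(1, 2), conv_weight, groups=d_model).transpose(1, 2)
    u = F.silu(u)

    ys, h = [], ssm_state
    for t in range(u.shape[1]):                           # sequential stand-in for the scan
        h = torch.exp(A) * h + B * u[:, t].unsqueeze(-1)  # (batch, d_model, d_state)
        ys.append((h * C).sum(-1))
    y = torch.stack(ys, dim=1)                            # (batch, chunk_len, d_model)
    return y, new_conv_cache, h.detach()                  # detach: stop-gradient between chunks

batch = 2
doc = torch.randn(batch, 4 * chunk_len, d_model)          # a "document" split into 4 chunks
conv_cache = torch.zeros(batch, d_conv - 1, d_model)
ssm_state = torch.zeros(batch, d_model, d_state)

for start in range(0, doc.shape[1], chunk_len):
    y, conv_cache, ssm_state = chunk_forward(doc[:, start:start + chunk_len], conv_cache, ssm_state)
    loss = y.pow(2).mean()                                # placeholder loss; real code would step an optimizer
    loss.backward()
```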

robflynnyh · Jan 11 '24

Agree it's a useful feature, it's on the roadmap!

albertfgu · Jan 11 '24

@albertfgu Would the states be differentiable? Unlike @robflynnyh, I would prefer not to stop the gradient during training.

sentialx · Jan 11 '24

Not sure, it's a lot of extra work to make it differentiable through the final state. And that's also not necessary for the main use case that we wanted to support (continuing training on the next state). What use case are you thinking of?

albertfgu · Jan 11 '24

I'm thinking of backpropagation through time on multiple chunks (e.g. 4) instead of fitting the full sequence in one huge window.

sentialx · Jan 12 '24

> I'm thinking of backpropagation through time on multiple chunks (e.g. 4) instead of fitting the full sequence in one huge window.

Would that take the same amount of activation memory as computing the full sequence?

tridao · Jan 12 '24

Hmm yeah, my bad

sentialx · Jan 12 '24

Are there any convenient ways to set up the initial state for mamba? I want to use TBPTT to train mamba on a longer context size, so there is no need to make the initial/final states of each chunk differentiable.
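
For reference, a TBPTT loop along these lines might look like the sketch below. It assumes a hypothetical state-passing forward, `y, state = model(x, state)`, with `state` a tuple of tensors such as (conv_state, ssm_state) per layer; this interface does not exist in mamba_ssm and is only an assumption for illustration:

```python
def tbptt_train(model, optimizer, loss_fn, sequence, targets, chunk_len):
    # sequence: (batch, total_len, d_model); targets aligned with sequence
    state = None                                      # let the model create its zero state
    for start in range(0, sequence.shape[1], chunk_len):
        x = sequence[:, start:start + chunk_len]
        t = targets[:, start:start + chunk_len]
        y, state = model(x, state)                    # assumed state-passing forward
        loss = loss_fn(y, t)
        optimizer.zero_grad()
        loss.backward()                               # gradients flow only within this chunk
        optimizer.step()
        state = tuple(s.detach() for s in state)      # truncate BPTT at the chunk boundary
```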

sustcsonglin · Jan 26 '24

Unfortunately not right now. It will hopefully be ready within a few weeks.

albertfgu · Jan 26 '24

> Unfortunately not right now. It will hopefully be ready within a few weeks.

Just checking to see if this is still in progress. I was going to start working on this myself, but I would hate to spend a couple weeks figuring it out just for an official implementation to come out at the same time.

ejmejm · Mar 03 '24

> Not sure, it's a lot of extra work to make it differentiable through the final state. And that's also not necessary for the main use case that we wanted to support (continuing training on the next state). What use case are you thinking of?

@albertfgu I think differentiable initial states are a critical feature for investigating state initialization. For example, we would be able to perform prefix tuning for mamba (similar to tuning the KV cache in Transformers).
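
As a rough sketch of that idea (the shapes and the `initial_states` argument below are hypothetical; nothing like this exists in mamba_ssm), one could freeze the backbone and learn only per-layer initial states:

```python
import torch
import torch.nn as nn

class LearnedInitialState(nn.Module):
    def __init__(self, n_layers, d_model, d_state, d_conv):
        super().__init__()
        # one learnable (ssm_state, conv_state) pair per layer; shapes are assumptions
        self.ssm_states = nn.Parameter(torch.zeros(n_layers, d_model, d_state))
        self.conv_states = nn.Parameter(torch.zeros(n_layers, d_model, d_conv - 1))

    def expand(self, batch_size):
        # broadcast the learned states over the batch dimension
        return (self.ssm_states.unsqueeze(1).expand(-1, batch_size, -1, -1),
                self.conv_states.unsqueeze(1).expand(-1, batch_size, -1, -1))

# usage sketch: freeze the backbone, train only the prefix states
# prefix = LearnedInitialState(n_layers=24, d_model=768, d_state=16, d_conv=4)
# for p in backbone.parameters():
#     p.requires_grad_(False)
# logits, _ = backbone(tokens, initial_states=prefix.expand(tokens.shape[0]))  # hypothetical forward
```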

woominsong · Mar 19 '24