
Shape of state transition matrix A?

lxxXuan opened this issue • 1 comment

How can the code be modified to support A with shape (M, N, N) instead of (M, N), so that it can represent a more general (full rather than diagonal) state transition matrix?

lxxXuan, Feb 26 '24 07:02

That's not supported in the CUDA code, but you can experiment with `selective_scan_ref`, which is written in PyTorch (but is much slower). Instead of multiplying A with the previous hidden state pointwise (i.e. treating A as diagonal), you can make A square and use a matrix multiply instead.

tridao, Feb 26 '24 07:02
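The suggestion above can be sketched as follows. This is a minimal, hypothetical illustration, not the actual `selective_scan_ref` code: it is written in NumPy so it runs standalone, and the function names (`expm`, `selective_scan_diag_A`, `selective_scan_full_A`) and shapes are assumptions. Here `u` and `delta` are (batch, dim, L), `B` and `C` are (batch, N, L), and A is either (dim, N) for the diagonal case or (dim, N, N) for the dense case. Both scans assume the discretization A_bar = exp(delta * A) with a simplified input term delta * B * u; for a dense A, the elementwise `exp` becomes a matrix exponential and the pointwise state update becomes a matrix-vector product.

```python
import numpy as np


def expm(M, order=20):
    """Matrix exponential of square matrices with shape (..., n, n).

    Simple scaling-and-squaring with a truncated Taylor series; a stand-in
    (assumption) for torch.linalg.matrix_exp, adequate for small examples.
    """
    n = M.shape[-1]
    norm = np.abs(M).sum(axis=-1).max()  # largest row-sum norm in the batch
    s = int(np.ceil(np.log2(norm))) + 1 if norm > 1 else 0
    X = M / 2.0 ** s  # scale down so the Taylor series converges fast
    term = np.broadcast_to(np.eye(n), M.shape).copy()
    result = term.copy()
    for k in range(1, order + 1):
        term = term @ X / k
        result = result + term
    for _ in range(s):  # undo the scaling: exp(M) = exp(M / 2^s)^(2^s)
        result = result @ result
    return result


def selective_scan_diag_A(u, delta, A, B, C, D=None):
    """Sequential scan with diagonal A stored as (dim, N): pointwise update."""
    batch, dim, L = u.shape
    x = np.zeros((batch, dim, A.shape[-1]))
    ys = []
    for t in range(L):
        A_bar = np.exp(delta[:, :, t, None] * A[None])  # (batch, dim, N)
        Bu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
        x = A_bar * x + Bu  # elementwise: A acts as a diagonal matrix
        ys.append(np.einsum("bdn,bn->bd", x, C[:, :, t]))
    y = np.stack(ys, axis=-1)
    if D is not None:
        y = y + u * D[None, :, None]
    return y


def selective_scan_full_A(u, delta, A, B, C, D=None):
    """Same scan with a dense A of shape (dim, N, N): matrix-vector update."""
    batch, dim, L = u.shape
    x = np.zeros((batch, dim, A.shape[-1]))
    ys = []
    for t in range(L):
        # Per-token discretization: matrix exponential, (batch, dim, N, N)
        A_bar = expm(delta[:, :, t, None, None] * A[None])
        Bu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
        # Key change: matrix-vector product instead of elementwise product
        x = np.einsum("bdij,bdj->bdi", A_bar, x) + Bu
        ys.append(np.einsum("bdn,bn->bd", x, C[:, :, t]))
    y = np.stack(ys, axis=-1)
    if D is not None:
        y = y + u * D[None, :, None]
    return y


rng = np.random.default_rng(0)
batch, dim, N, L = 2, 3, 4, 5
u = rng.standard_normal((batch, dim, L))
delta = np.log1p(np.exp(rng.standard_normal((batch, dim, L))))  # softplus, > 0
A_diag = -np.exp(rng.standard_normal((dim, N)))  # negative entries for decay
B = rng.standard_normal((batch, N, L))
C = rng.standard_normal((batch, N, L))
D = rng.standard_normal(dim)

# Embed the diagonal A into a dense (dim, N, N) matrix
A_dense = np.zeros((dim, N, N))
A_dense[:, np.arange(N), np.arange(N)] = A_diag

y_diag = selective_scan_diag_A(u, delta, A_diag, B, C, D)
y_full = selective_scan_full_A(u, delta, A_dense, B, C, D)
print(np.allclose(y_diag, y_full))  # True
```

With A placed on the diagonal of a dense matrix the two scans agree, which is a quick check that only the update rule changed; any off-diagonal entries in `A_dense` then couple the state dimensions. In PyTorch, `torch.linalg.matrix_exp` and `torch.einsum` could replace the hand-rolled `expm` and `np.einsum`. Note the cost: the dense update is O(N^2) per channel per step, plus O(N^3) for each matrix exponential, versus O(N) for the diagonal update.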