mamba
Shape of state transition matrix A?
How can the code be modified to support A with shape (M, N, N) instead of (M, N), so that the state transition matrix can be more expressive (a full N×N matrix per channel rather than a diagonal one)?
That's not supported in the CUDA code, but you can experiment with selective_scan_ref, which is implemented in PyTorch (and is much slower). Instead of multiplying A pointwise with the previous hidden state (i.e. treating A as diagonal), you can make A square and use a matrix multiply instead.
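A minimal sketch of that change, written in NumPy so it runs on its own. It drops the batch dimension and assumes u and delta have shape (M, L), B and C are shared across channels with shape (N, L), and the input term is the simple delta * B * u form rather than an exact zero-order-hold discretization of B. The function names `expm_taylor`, `scan_diag` and `scan_dense` are hypothetical; this is not the actual selective_scan_ref code. The only real difference between the two scans is that the elementwise `exp(delta * A)` becomes a matrix exponential and the pointwise update becomes a matrix-vector product.

```python
import numpy as np


def expm_taylor(X, terms=20, squarings=6):
    """Batched matrix exponential of (..., N, N) arrays via scaling and
    squaring with a truncated Taylor series. Fine for the small, well-scaled
    matrices in this sketch; a library routine is preferable in practice."""
    n = X.shape[-1]
    scaled = X / (2 ** squarings)
    result = np.broadcast_to(np.eye(n), X.shape).copy()
    term = result.copy()
    for k in range(1, terms + 1):
        term = term @ scaled / k      # accumulates scaled^k / k!
        result = result + term
    for _ in range(squarings):
        result = result @ result      # undo scaling: exp(X) = exp(X / 2^s)^(2^s)
    return result


def scan_diag(u, delta, A, B, C):
    """Diagonal A. u, delta: (M, L); A: (M, N); B, C: (N, L). Returns y: (M, L)."""
    x = np.zeros(A.shape)                                   # hidden state, (M, N)
    ys = []
    for t in range(u.shape[1]):
        dA = np.exp(delta[:, t, None] * A)                  # elementwise exp, (M, N)
        dBu = delta[:, t, None] * B[None, :, t] * u[:, t, None]  # input term, (M, N)
        x = dA * x + dBu                                    # pointwise state update
        ys.append(x @ C[:, t])                              # readout, (M,)
    return np.stack(ys, axis=1)


def scan_dense(u, delta, A, B, C):
    """Full A. Same as scan_diag except A has shape (M, N, N)."""
    x = np.zeros(A.shape[:2])                               # hidden state, (M, N)
    ys = []
    for t in range(u.shape[1]):
        dA = expm_taylor(delta[:, t, None, None] * A)       # matrix exp, (M, N, N)
        dBu = delta[:, t, None] * B[None, :, t] * u[:, t, None]
        x = np.einsum("mnk,mk->mn", dA, x) + dBu            # matrix-vector state update
        ys.append(x @ C[:, t])
    return np.stack(ys, axis=1)
```

As a sanity check, building `A_dense = np.stack([np.diag(a) for a in A_diag])` should make `scan_dense` reproduce `scan_diag` up to floating-point error, before you try genuinely non-diagonal A. In a PyTorch version, `torch.linalg.matrix_exp` and `torch.einsum` would be natural replacements for `expm_taylor` and `np.einsum`. Note the cost: each step now needs an N×N matrix exponential and a matrix-vector product per channel instead of O(N) elementwise work, which is part of why the fast kernel assumes a diagonal A.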