
Questions regarding `a` and `x` in backward CUDA kernel code

Open. SudhanshuBokade opened this issue 1 year ago • 1 comment

At L288 of `selective_scan_bwd_kernel.cuh`, `a` is defined as follows:

https://github.com/state-spaces/mamba/blob/3b0dde5a20659073af5684e966a81981e614789e/csrc/selective_scan/selective_scan_bwd_kernel.cuh#L288

  1. What does `a` store, and why is it defined in this manner?
  2. Where is the hidden state `x`, and is there a way to access it directly in the bwd kernel code? Does `thread_data` contain `x` after the scan op?

Thank you very much for your help.

SudhanshuBokade, Oct 04 '24 10:10

IIRC `a` stores `exp(delta_p * A_val)`, or maybe the product of such terms up to position p. You should work out mathematically what `thread_data[i].y` is: it's the second component of the scan output. The first component, `thread_data[i].x`, is the product of all `exp(delta_j * A_val)` for j <= p, where p is the position in the sequence. The second component is `exp(delta_p * A_val)` times the second component at position p - 1, plus `delta_p * u_p * B_p`. This is the hidden state `x`. I'm writing all this off the top of my head, so the indexing might be off by one.

Notice that the fwd kernel writes the sum of `thread_data[i].y * C_val` to the output.

tridao, Oct 04 '24 16:10
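To make the reply concrete, here is a minimal pure-Python sketch of the scan for a single state dimension. It is not the kernel code: the names `combine`, `inclusive_scan`, `a_guess` and the sample values are made up for illustration, and the scan pairs are assumed to be `(exp(delta_p * A_val), delta_p * u_p * B_p)` combined as `(a1 * a0, a1 * b0 + b1)`. The `a_guess` line is only one plausible reading of what `a` could hold (the hidden state minus the current input term) and has not been checked against the linked line.

```python
import math

def combine(ab0, ab1):
    """Compose two affine steps x -> a*x + b: apply ab0 first, then ab1."""
    a0, b0 = ab0
    a1, b1 = ab1
    return (a1 * a0, a1 * b0 + b1)

def inclusive_scan(elems):
    """Sequential stand-in for a block-wide inclusive scan."""
    out, acc = [], None
    for e in elems:
        acc = e if acc is None else combine(acc, e)
        out.append(acc)
    return out

# Made-up sample values for one channel and one state dimension.
A_val = -0.5
delta = [0.1, 0.2, 0.3, 0.4]
u = [1.0, -2.0, 0.5, 3.0]
B = [0.7, 0.1, -0.4, 0.9]
C = [1.0, 0.5, -1.0, 2.0]

# Scan inputs: (decay factor, input contribution) per position.
elems = [(math.exp(d * A_val), d * ui * b) for d, ui, b in zip(delta, u, B)]
scanned = inclusive_scan(elems)

decay_prod = [s[0] for s in scanned]  # plays the role of thread_data[i].x
hidden = [s[1] for s in scanned]      # plays the role of thread_data[i].y

# Reference: the plain recurrence x_p = exp(delta_p * A) * x_{p-1} + delta_p * u_p * B_p.
x_ref, x_prev = [], 0.0
for d, ui, b in zip(delta, u, B):
    x_prev = math.exp(d * A_val) * x_prev + d * ui * b
    x_ref.append(x_prev)

# Assumption: if a = hidden state minus the current input term,
# then a equals exp(delta_p * A) * x_{p-1}.
a_guess = [h - d * ui * b for h, d, ui, b in zip(hidden, delta, u, B)]

# Per-position output contribution for this state dimension; the real
# forward kernel would sum such terms over all state dimensions.
y_contrib = [h * c for h, c in zip(hidden, C)]
```

Under these assumptions the second scan component matches the recurrence exactly, so the hidden state can be read off the scan output without a separate buffer.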