Questions regarding `a` and `x` in backward CUDA kernel code
At L288 of `selective_scan_bwd_kernel.cuh`, `a` is defined as follows:
https://github.com/state-spaces/mamba/blob/3b0dde5a20659073af5684e966a81981e614789e/csrc/selective_scan/selective_scan_bwd_kernel.cuh#L288
- What does `a` store, and why is it defined in this manner?
- Where is the hidden state `x`, and is there a way to access it directly in the bwd kernel code? Does `thread_data` after the Scan op contain `x`?
Thank you very much for your help.
IIRC `a` stores exp(delta_p * A_val), or maybe the product of such terms up to position p. You should work out mathematically what `thread_data[i].y` is: it's the second component of the scan output. The first component, `thread_data[i].x`, is the product of all exp(delta_j * A_val) for j <= p, where p is the position in the sequence. The second component is exp(delta_p * A_val) * x_{p-1} + delta_p * u_p * B_p, i.e. the previous state decayed by one step plus the new input. Unrolled, it is the sum over j <= p of (product of exp(delta_k * A_val) for j < k <= p) * delta_j * u_j * B_j. This is the hidden state x.
I'm writing all this off the top of my head, so the indexing might be slightly wrong or off by one.
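A minimal pure-Python sketch of the scan described above, under stated assumptions: a single channel and a single state index, so `A_val`, `delta_p`, `u_p` and `B_p` are plain floats, and the initial state is zero. The function names are hypothetical (not from the Mamba code), and the real kernel runs this scan in parallel across threads rather than with a sequential loop. The sketch checks that the second component of the inclusive scan matches the recurrence x_p = exp(delta_p * A_val) * x_{p-1} + delta_p * u_p * B_p, and that the first component is the running product of the decay terms.

```python
import math
from itertools import accumulate

# Hypothetical helper names; a sketch of the math, not Mamba's actual code.

def scan_op(left, right):
    """Combine two (a, b) pairs, where `left` covers earlier positions.

    Applying h -> a0*h + b0 and then h -> a1*h + b1 gives
    h -> (a1*a0)*h + (a1*b0 + b1).
    """
    a0, b0 = left
    a1, b1 = right
    return (a1 * a0, a1 * b0 + b1)

def scan_states(delta, A, B, u):
    """Inclusive scan over per-position pairs (exp(delta_p*A), delta_p*u_p*B_p)."""
    pairs = [(math.exp(d * A), d * up * b) for d, b, up in zip(delta, B, u)]
    # accumulate calls scan_op(running_total, next_pair), i.e. (earlier, later).
    return list(accumulate(pairs, scan_op))

def sequential_states(delta, A, B, u):
    """Reference recurrence x_p = exp(delta_p*A) * x_{p-1} + delta_p*u_p*B_p, x_{-1} = 0."""
    x, out = 0.0, []
    for d, b, up in zip(delta, B, u):
        x = math.exp(d * A) * x + d * up * b
        out.append(x)
    return out

delta = [0.1, 0.5, 0.2, 0.3]
A = -1.0
B = [1.0, 0.5, -2.0, 0.7]
u = [1.0, 2.0, 0.5, -1.0]
scanned = scan_states(delta, A, B, u)
ref = sequential_states(delta, A, B, u)
for (prod_a, x), x_seq in zip(scanned, ref):
    print(prod_a, x, x_seq)  # second and third columns should agree
```

Because `scan_op` is associative, the pairs can be combined in any grouping, which is what allows a parallel scan to produce every x_p at once; the `.y`-like second component of each scanned pair is then the state at that position.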
Notice that the fwd kernel writes the sum of `thread_data[i].y * C_val` over the state dimension to the output, i.e. y_p = sum_n C_{p,n} * x_{p,n}.
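A similar hedged sketch of that output reduction, again pure Python with a hypothetical function name: one channel and N state indices, with `A` holding one value per state index and `B`, `C` holding one value per (position, state index). For each state index it runs the same recurrence and accumulates `x * C`, analogous to `thread_data[i].y * C_val`; it only covers the C-weighted sum, nothing else the kernel writes.

```python
import math

def selective_scan_output(delta, A, B, C, u):
    """y_p = sum_n C[p][n] * x_n[p], where x_n is the scanned state for state index n.

    delta, u: length-L lists (one channel); A: length-N list (one value per
    state index); B, C: L x N nested lists. Hypothetical name, sketch only.
    """
    L, N = len(delta), len(A)
    y = [0.0] * L
    for n in range(N):  # one pass per state index
        x = 0.0
        for p in range(L):
            x = math.exp(delta[p] * A[n]) * x + delta[p] * u[p] * B[p][n]
            y[p] += x * C[p][n]  # analogous to out += thread_data[i].y * C_val
    return y

print(selective_scan_output(
    delta=[1.0, 1.0],
    A=[0.0, 0.0],
    B=[[1.0, 1.0], [1.0, 1.0]],
    C=[[1.0, 2.0], [1.0, 2.0]],
    u=[1.0, 2.0],
))
```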