mamba
Adding Initial Value Support to Selective Scan Forward Kernel
@tridao Hello!
I am currently working with the selective scan forward kernel, specifically the step h_t = A*h_{t-1} + B*x_t, where h_0 is currently fixed at 0. I would like to change this so that h_0 can be set to a given initial value (init_value).
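Unrolling the recurrence shows what a nonzero h_0 would contribute. This assumes, for simplicity, that A and B are the same at every step, with x_k the input at step k:

```latex
h_t = A h_{t-1} + B x_t
    = A^{t} h_0 + \sum_{k=1}^{t} A^{t-k} B x_k
```

With h_0 = 0 only the sum remains. Supporting an initial value therefore amounts to adding the decayed term A^t h_0 at every step t.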
While reading the code, I noticed that the InclusiveScan call on Ktraits::BlockScanT(smem_scan) does not appear to accept an initial value. Here is the relevant line of code: Ktraits::BlockScanT(smem_scan).InclusiveScan()
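One way to frame the problem, as a sketch that is not taken from the mamba sources: each step is the affine map h -> a_t*h + b_t, where a_t plays the role of A and b_t the role of B*x_t. Composing two such maps gives another map of the same form, which is what makes the recurrence scannable. Under that view, an initial value becomes a scan prefix. Seeding the inclusive scan with the pair (1, h_0) makes the running b component start from h_0 instead of 0. The host-side C++17 check below uses the std::inclusive_scan overload that takes an init argument. The names Affine, combine, sequential, scanned and max_abs_diff are hypothetical. On the device side, cub::BlockScan::InclusiveScan also has an overload that takes a block prefix callback operator, which may be one place such a prefix could be injected. I have not checked how the kernel uses it.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// One scan element represents the affine map h -> a * h + b.
struct Affine {
    float a;
    float b;
};

// Compose two maps: apply `earlier` first, then `later`.
// later(earlier(h)) = later.a * (earlier.a * h + earlier.b) + later.b
Affine combine(const Affine& earlier, const Affine& later) {
    return {later.a * earlier.a, later.a * earlier.b + later.b};
}

// Reference: run h_t = a_t * h_{t-1} + b_t step by step, starting from h0.
std::vector<float> sequential(const std::vector<Affine>& steps, float h0) {
    std::vector<float> h;
    float state = h0;
    for (const Affine& s : steps) {
        state = s.a * state + s.b;
        h.push_back(state);
    }
    return h;
}

// Scan version: the init element (1, h0) is combined in front of the first step,
// so the b component of every inclusive result equals h_t for that h0.
// h0 = 0 gives (1, 0), the identity map, i.e. the current zero-initial-state behavior.
std::vector<float> scanned(const std::vector<Affine>& steps, float h0) {
    std::vector<Affine> out(steps.size());
    std::inclusive_scan(steps.begin(), steps.end(), out.begin(), combine,
                        Affine{1.0f, h0});
    std::vector<float> h;
    for (const Affine& e : out) {
        h.push_back(e.b);
    }
    return h;
}

// Largest element-wise difference between two equally sized sequences.
float max_abs_diff(const std::vector<float>& x, const std::vector<float>& y) {
    assert(x.size() == y.size());
    float worst = 0.0f;
    for (std::size_t i = 0; i < x.size(); ++i) {
        worst = std::fmax(worst, std::fabs(x[i] - y[i]));
    }
    return worst;
}
```

For example, with steps {0.5, 1}, {0.9, -2}, {1.1, 0.25}, {0.3, 3} and h0 = 2, both sequential and scanned produce approximately 2, -0.2, 0.03, 3.009. Keeping a in the seed as 1 rather than 0 leaves the cumulative decay (the running product of a_t) intact, in case later chunks need it.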
Could you provide some guidance on how to modify the code to support an initial value for h_0? Any help would be greatly appreciated.