DeepTab

[FEATURE] Potential optimization for the selective_scan_seq

Status: Open. nwuestefeld opened this issue 9 months ago (0 comments).

Replacing the explicit tensor operations in the zero-order-hold (ZOH) transformation with torch.einsum() can improve both performance and readability.
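
In index notation, the zero-order-hold transformation discussed here computes the two tensors below. The index names are read off the einsum subscripts and are an assumption about the intended shapes: b (batch), l (sequence position), d (channel) and n (state dimension).

```latex
\bar{A}_{bldn} = \exp\left(\Delta_{bld}\, A_{dn}\right),
\qquad
(\bar{B}x)_{bldn} = \Delta_{bld}\, B_{bln}\, x_{bld}
```

No index is summed over, so both outputs are pure element-wise (outer) products, which is consistent with the broadcast and einsum formulations performing the same arithmetic. Only \bar{A} is exponentiated; \bar{B} is taken as \Delta B, as in the original code.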

Replacing the original zero-order-hold transformation at line 518 of mamba_arch.py

    deltaA = torch.exp(delta.unsqueeze(-1) * A)
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
    BX = deltaB * (x.unsqueeze(-1))

with:

    deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
    BX = torch.einsum('bld,bld,bln->bldn', delta, x, B)

can improve execution time by up to ~40% while requiring the same number of FLOPs (see the attached plot).
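A quick way to check that the einsum version is a drop-in replacement is to compare both formulations on random tensors and time them. This is a minimal sketch, assuming the shapes implied by the einsum subscripts (delta and x: (b, l, d), A: (d, n), B: (b, l, n)); the function names zoh_broadcast, zoh_einsum and time_fn and the tensor sizes are hypothetical, and a single CPU timing like this does not reproduce the benchmark in the plot.

```python
import time

import torch


def zoh_broadcast(delta, A, B, x):
    # Original formulation: explicit unsqueeze + broadcasting.
    deltaA = torch.exp(delta.unsqueeze(-1) * A)    # (b, l, d, n)
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (b, l, d, n)
    BX = deltaB * x.unsqueeze(-1)                  # (b, l, d, n)
    return deltaA, BX


def zoh_einsum(delta, A, B, x):
    # Proposed formulation: the same element-wise products via torch.einsum.
    deltaA = torch.exp(torch.einsum("bld,dn->bldn", delta, A))
    BX = torch.einsum("bld,bld,bln->bldn", delta, x, B)
    return deltaA, BX


def time_fn(fn, *args, repeats=20):
    # Average wall-clock seconds per call on CPU (no CUDA synchronization here).
    fn(*args)  # warm-up call
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats


torch.manual_seed(0)
b, l, d, n = 8, 64, 32, 16  # hypothetical sizes, not taken from the plot
delta = torch.rand(b, l, d)
A = -torch.rand(d, n)
B = torch.randn(b, l, n)
x = torch.randn(b, l, d)

ref_A, ref_BX = zoh_broadcast(delta, A, B, x)
ein_A, ein_BX = zoh_einsum(delta, A, B, x)
print(torch.allclose(ref_A, ein_A, atol=1e-6), torch.allclose(ref_BX, ein_BX, atol=1e-5))
print(f"broadcast: {time_fn(zoh_broadcast, delta, A, B, x):.2e} s")
print(f"einsum:    {time_fn(zoh_einsum, delta, A, B, x):.2e} s")
```

On a GPU, move the tensors with .to("cuda") and call torch.cuda.synchronize() before reading the timer, since CUDA kernels run asynchronously.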

[Image: attached benchmark plot]

Moreover, vectorizing the sequential loop over the sequence positions does not improve execution time any further.
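For context, the loop is presumably the recurrence h = deltaA[:, t] * h + BX[:, t] over the sequence length (an assumption; the issue does not show it or say which vectorization was tried). One illustrative way to vectorize it uses deltaA = exp(delta * A): a product of deltaA over a range of steps becomes the exponential of a difference of cumulative sums, so all hidden states can be computed at once. The function names scan_sequential and scan_vectorized are hypothetical, and this variant trades the loop for an (l, l)-sized weight tensor, i.e. O(l^2) memory.

```python
import torch


def scan_sequential(deltaA, BX):
    # Recurrence h_t = deltaA_t * h_{t-1} + BX_t, one step per sequence position.
    b, l, d, n = deltaA.shape
    h = torch.zeros(b, d, n, dtype=deltaA.dtype)
    hs = []
    for t in range(l):
        h = deltaA[:, t] * h + BX[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)  # (b, l, d, n)


def scan_vectorized(deltaA_log, BX):
    # Closed form h_t = sum_{s <= t} exp(S_t - S_s) * BX_s, with S = cumsum of delta * A.
    # deltaA_log is delta * A (the exponent before torch.exp), so no log is needed.
    S = torch.cumsum(deltaA_log, dim=1)                 # (b, l, d, n)
    diff = S.unsqueeze(2) - S.unsqueeze(1)              # diff[:, t, s] = S_t - S_s
    l = deltaA_log.shape[1]
    causal = torch.ones(l, l, dtype=torch.bool).tril()  # True where s <= t
    # Mask future positions before exp so they become exactly zero.
    diff = diff.masked_fill(~causal[None, :, :, None, None], float("-inf"))
    weights = torch.exp(diff)                           # (b, l, l, d, n)
    return torch.einsum("btsdn,bsdn->btdn", weights, BX)


torch.manual_seed(0)
b, l, d, n = 2, 16, 4, 3  # hypothetical small sizes
delta = torch.rand(b, l, d, dtype=torch.float64)
A = -torch.rand(d, n, dtype=torch.float64)
B = torch.randn(b, l, n, dtype=torch.float64)
x = torch.randn(b, l, d, dtype=torch.float64)

deltaA_log = torch.einsum("bld,dn->bldn", delta, A)
BX = torch.einsum("bld,bld,bln->bldn", delta, x, B)
hs_loop = scan_sequential(torch.exp(deltaA_log), BX)
hs_vec = scan_vectorized(deltaA_log, BX)
print(torch.allclose(hs_loop, hs_vec))
```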

nwuestefeld, Mar 20 '25 12:03