DeepTab
[FEATURE] Potential optimization for the selective_scan_seq
Replacing explicit tensor operations with torch.einsum() in the Zero-Order-Hold transformation can improve both performance and readability.
Replacing the original Zero-Order-Hold transformation at line 518 of mamba_arch.py:
deltaA = torch.exp(delta.unsqueeze(-1) * A)
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
BX = deltaB * (x.unsqueeze(-1))
with:
deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
BX = torch.einsum('bld,bld,bln->bldn', delta, x, B)
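The two formulations produce identical results. A minimal sketch checking this numerically, assuming the usual Mamba shapes (batch b, sequence length l, inner dim d, state dim n) for delta, A, B, and x:

```python
import torch

# Hypothetical small shapes: batch b, sequence length l, inner dim d, state dim n.
b, l, d, n = 2, 16, 8, 4
delta = torch.rand(b, l, d)
A = torch.randn(d, n)
B = torch.randn(b, l, n)
x = torch.randn(b, l, d)

# Original formulation with explicit broadcasting.
deltaA_ref = torch.exp(delta.unsqueeze(-1) * A)
deltaB_ref = delta.unsqueeze(-1) * B.unsqueeze(2)
BX_ref = deltaB_ref * x.unsqueeze(-1)

# Proposed einsum formulation (note: exp must still wrap the deltaA contraction).
deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
BX = torch.einsum('bld,bld,bln->bldn', delta, x, B)

assert torch.allclose(deltaA, deltaA_ref, atol=1e-6)
assert torch.allclose(BX, BX_ref, atol=1e-6)
```

The einsum version also avoids materializing the intermediate deltaB tensor, which is where part of the speedup likely comes from.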
can reduce execution time by up to ~40% while requiring the same number of FLOPs (see attached plot).
Moreover, vectorizing the sequential scan loop itself does not yield any further improvement in execution time.