# flash-linear-attention
Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton
### Proposal

The `chunk` and `fused_chunk` modes have complementary strengths in different scenarios. The interface should be unified so that the user is agnostic to the underlying implementation. The API...
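As a rough illustration of what a unified entry point could look like, here is a naive PyTorch sketch: a single `linear_attn` function (a hypothetical name) dispatching to a chunkwise or a non-chunked reference path behind a `mode="auto"` flag. Both backends and the dispatch threshold are placeholders, not FLA's actual kernels or policy.

```python
import torch

def _naive_linear_attn(q, k, v):
    # Plain causal (unnormalized) linear attention: o_t = sum_{s<=t} (q_t . k_s) v_s.
    attn = (q @ k.transpose(-1, -2)).tril()              # [B, H, T, T], causal mask
    return attn @ v

def _chunk_linear_attn(q, k, v, chunk_size=64):
    # Chunkwise form: intra-chunk attention plus a running inter-chunk state.
    B, H, T, K = q.shape
    V = v.shape[-1]
    o = torch.empty_like(v)
    S = q.new_zeros(B, H, K, V)                          # running state sum_s k_s^T v_s
    for i in range(0, T, chunk_size):
        qi = q[..., i:i + chunk_size, :]
        ki = k[..., i:i + chunk_size, :]
        vi = v[..., i:i + chunk_size, :]
        intra = (qi @ ki.transpose(-1, -2)).tril() @ vi  # within-chunk contribution
        o[..., i:i + chunk_size, :] = intra + qi @ S     # plus all previous chunks
        S = S + ki.transpose(-1, -2) @ vi                # update the running state
    return o

def linear_attn(q, k, v, mode: str = "auto", chunk_size: int = 64):
    """Hypothetical unified entry point hiding the chunk/fused_chunk split."""
    if mode == "auto":
        # Placeholder heuristic; a real dispatch rule would be tuned per GPU/seq len.
        mode = "fused_chunk" if q.shape[-2] <= 512 else "chunk"
    if mode == "chunk":
        return _chunk_linear_attn(q, k, v, chunk_size)
    if mode == "fused_chunk":
        return _naive_linear_attn(q, k, v)               # stands in for the fused kernel
    raise ValueError(f"unknown mode: {mode}")
```

The design point is that callers only ever see one function; which kernel actually runs becomes an internal detail that can still be forced explicitly for benchmarking.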
### Feature Request

1. Include a flag (e.g., `use_default_init`) in the model configuration or constructor. When set to `False`, this flag would disable FLA's default initialization logic entirely, allowing users...
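A minimal sketch of how such a flag could be threaded through a module, reusing the `use_default_init` name from the request; the module layout and the Xavier initializer below are purely illustrative, not FLA's actual initialization scheme.

```python
import torch.nn as nn

class LinearAttentionBlock(nn.Module):
    """Toy block illustrating the requested escape hatch."""

    def __init__(self, hidden_size: int = 1024, use_default_init: bool = True):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        if use_default_init:
            # Library-provided initialization (illustrative choice here).
            self.apply(self._init_weights)
        # With use_default_init=False, PyTorch's defaults are kept and users can
        # apply their own scheme (e.g. muP-style scaling) after construction.

    @staticmethod
    def _init_weights(module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
```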
### Proposal

Support context parallelism for all linear attention models.

### Rationale

One of the major advantages of linear attention is that it enables long-sequence modeling. However, for training...
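To make the rationale concrete, here is a naive sketch (not FLA code) of why context parallelism is cheap for linear attention: with the sequence split across ranks, only the `[K, V]`-sized recurrent state has to cross rank boundaries, not the tokens themselves. The function assumes unnormalized linear attention, an already-initialized `torch.distributed` process group, and serializes the ranks for clarity.

```python
import torch
import torch.distributed as dist

def cp_linear_attn_forward(q, k, v, group=None):
    """Sequence (context) parallel forward for unnormalized linear attention.

    Each rank holds a contiguous slice of the sequence of shape [B, H, T_local, *].
    """
    rank, world = dist.get_rank(group), dist.get_world_size(group)
    B, H, T_local, K = q.shape
    V = v.shape[-1]

    # Local chunkwise pass: causal attention within this rank's slice.
    intra = (q @ k.transpose(-1, -2)).tril() @ v
    local_state = k.transpose(-1, -2) @ v                 # [B, H, K, V]

    # Receive the accumulated state of all preceding ranks (zeros on rank 0).
    prev_state = torch.zeros(B, H, K, V, dtype=q.dtype, device=q.device)
    if rank > 0:
        dist.recv(prev_state, src=rank - 1, group=group)
    # Forward the updated prefix state to the next rank.
    if rank < world - 1:
        dist.send(prev_state + local_state, dst=rank + 1, group=group)

    return intra + q @ prev_state
```

A real implementation would overlap the state exchange with computation (e.g. ring-style) instead of serializing the ranks as done here.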
### Feature Request

Hello, thank you for all of your great work. I was wondering if it would be a reasonable inclusion to add even more fused linear activation functions...
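Such fusions are usually about avoiding a materialized intermediate rather than the arithmetic itself. The sketch below is not a Triton kernel and not part of FLA; it is a small `torch.autograd.Function` that only illustrates the memory-saving pattern a fused linear + SiLU would target: the pre-activation is recomputed in backward instead of being saved.

```python
import torch
import torch.nn.functional as F

class LinearSiLU(torch.autograd.Function):
    """Illustrative linear + SiLU with activation recomputation (hypothetical)."""

    @staticmethod
    def forward(ctx, x, weight):
        y = F.linear(x, weight)          # [..., out]
        ctx.save_for_backward(x, weight) # y itself is deliberately not stored
        return F.silu(y)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        y = F.linear(x, weight)                          # recompute instead of storing
        s = torch.sigmoid(y)
        dsilu = s * (1 + y * (1 - s))                    # d/dy [y * sigmoid(y)]
        grad_y = grad_out * dsilu
        grad_x = grad_y @ weight                         # since y = x @ W^T
        grad_w = grad_y.flatten(0, -2).t() @ x.flatten(0, -2)
        return grad_x, grad_w
```

Usage would be `LinearSiLU.apply(x, weight)`; an actual fused kernel would additionally merge the matmul and the activation into one launch.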
### Proposal

The current `chunk` mode normally loads 64x64 blocks, does the computation, and then saves the resulting hidden state, which can create an I/O burden. In Tri Dao's Mamba2 implementation...
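For reference, the naive two-pass structure below (plain PyTorch, unnormalized linear attention, sequence length assumed divisible by the chunk size) makes the I/O in question explicit: pass 1 materializes one `[K, V]` state per chunk, and pass 2 reads those states back to form outputs. It illustrates where the traffic comes from, not FLA's or Mamba2's actual kernels.

```python
import torch

def chunk_states_then_outputs(q, k, v, chunk_size=64):
    B, H, T, K = q.shape
    V = v.shape[-1]
    assert T % chunk_size == 0, "illustration assumes full chunks"
    N = T // chunk_size
    q, k, v = (x.reshape(B, H, N, chunk_size, -1) for x in (q, k, v))

    # Pass 1: per-chunk states and their exclusive prefix sum across chunks.
    # In a kernel, `chunk_states` is the [B, H, N, K, V] tensor written to HBM.
    chunk_states = k.transpose(-1, -2) @ v
    prefix = chunk_states.cumsum(2) - chunk_states        # state before each chunk

    # Pass 2: intra-chunk attention plus the contribution of all previous chunks.
    intra = (q @ k.transpose(-1, -2)).tril() @ v
    o = intra + q @ prefix
    return o.reshape(B, H, T, -1)
```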
### Proposal

Fuse shortconv and the output norm/gate into kernels, as in Mamba1 and Mamba2.

### Rationale

QKV ShortConv introduces three additional activations, resulting in non-negligible memory overhead.
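For context, the unfused pattern looks roughly like the depthwise causal convolution below (an illustrative PyTorch reference, not FLA's kernel). Applied separately to q, k, and v, each call keeps its own input activation alive for backward, which is the memory the proposed fusion would reclaim.

```python
import torch
import torch.nn.functional as F

def short_conv(x, weight):
    """Depthwise causal short convolution followed by a SiLU activation.

    x: [B, T, D], weight: [D, W] with a small window W (e.g. 4).
    """
    B, T, D = x.shape
    W = weight.shape[-1]
    x = x.transpose(1, 2)                              # [B, D, T]
    x = F.pad(x, (W - 1, 0))                           # left-pad so the conv is causal
    y = F.conv1d(x, weight.unsqueeze(1), groups=D)     # depthwise: one filter per channel
    return F.silu(y).transpose(1, 2)                   # [B, T, D]
```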
### Proposal

* We want to add `apply_tp` & `apply_cp` functions for each model, as their layer definitions can vary. Also see the comments in https://github.com/fla-org/flame/issues/4
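A sketch of what such per-model hooks might look like, using PyTorch's tensor-parallel API. The module paths in the plan (`attn.q_proj`, etc.), the `model.layers` layout, and the function signatures are assumptions for illustration, not an agreed-upon FLA interface.

```python
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def apply_tp(model: nn.Module, tp_mesh: DeviceMesh) -> nn.Module:
    """Hypothetical per-model hook: shard one model family's projections."""
    for layer in model.layers:
        parallelize_module(
            layer,
            tp_mesh,
            {
                # Column-parallel inputs, row-parallel output projection.
                "attn.q_proj": ColwiseParallel(),
                "attn.k_proj": ColwiseParallel(),
                "attn.v_proj": ColwiseParallel(),
                "attn.o_proj": RowwiseParallel(),
            },
        )
    return model

def apply_cp(model: nn.Module, cp_mesh: DeviceMesh) -> nn.Module:
    """Hypothetical per-model hook for context parallelism; left as a stub
    because the sequence-sharding strategy differs per layer type."""
    raise NotImplementedError
```

Keeping these as per-model functions (rather than one generic pass) lets each architecture ship a plan that matches its own layer definitions.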