flash-linear-attention
Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton
Hi, I fetched the 3B version of the model from the Hugging Face Hub, and when I try to call loss.backward() (after model.train()) using the transformers library, I got...
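For context, a minimal sketch of the kind of training-mode backward pass being described (the checkpoint id below is a stand-in taken from another issue in this list, since the exact 3B id isn't quoted, and it assumes the fla package is installed so the custom architecture resolves):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import fla  # noqa: F401  # assumed to register the FLA model classes with transformers

model_id = "fla-hub/transformer-1.3B-100B"  # stand-in; the issue concerns a 3B checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).cuda()
model.train()

batch = tokenizer("a short test sentence", return_tensors="pt").to("cuda")
out = model(**batch, labels=batch["input_ids"])  # language-modeling loss from shifted labels
out.loss.backward()  # the step that reportedly raises an error
```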
I'm currently trying to use just the operators defined in `fla.ops`; however, because of the `__init__.py` script for the main package, it's not possible to do this without importing things...
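To illustrate the mechanics (the submodule path below is assumed from the repo layout, and the lazy-import pattern is only one possible remedy, not something the package necessarily does):

```python
# Python executes every parent package's __init__.py before a submodule import,
# so even a narrowly targeted import like this still runs fla/__init__.py and
# everything it pulls in:
import fla.ops.gla  # submodule path assumed

# One conventional remedy (PEP 562) would be for the root fla/__init__.py to
# resolve its subpackages lazily, roughly:
import importlib

def __getattr__(name):
    # Import subpackages such as `ops` only on first attribute access.
    return importlib.import_module(f".{name}", __name__)
```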
I've been trying to use the linear attention kernels in a model that I am compiling; however, the Triton kernel does not seem to work with torch.compile. Specifically, when...
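A stripped-down shape of such a repro (with a plain placeholder module standing in for the FLA layer, since the failing call isn't quoted here) would look roughly like:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        # Placeholder for the FLA linear-attention layer the issue refers to;
        # swapping this nn.Linear for a Triton-backed op is where compilation
        # reportedly breaks.
        self.mix = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mix(x)

model = Block().cuda()
compiled = torch.compile(model)  # fullgraph=True would make graph breaks fail loudly

x = torch.randn(2, 128, 64, device="cuda", requires_grad=True)
compiled(x).sum().backward()
```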
This pull request aims to enhance FLA support for RWKV6, improving both speed and performance in bf16, and also enables FLA on Intel cards. ## FLA ChunkRWKV6 Optimized Implementation This repository...
I have compared the speeds of GLA, Attention, and Flash Attention, as shown in the table below, and found that GLA has little to no advantage in terms of speed....
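The table itself isn't reproduced in this excerpt; for reference, a generic way to time such kernels on GPU (not necessarily the harness used for the comparison) is with CUDA events:

```python
import torch

def benchmark(fn, *args, warmup: int = 10, iters: int = 100) -> float:
    """Return average milliseconds per call of fn(*args) on the current GPU."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```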
Thanks for the incredibly clean repository! I am Sayak from the [Diffusers](https://github.com/huggingface/diffusers) team at Hugging Face. My question is probably very naive, so I apologize for that in advance. I...
I added a classification head to the pretrained Transformer++ model from https://huggingface.co/fla-hub/transformer-1.3B-100B/tree/main and fine-tuned it on the SST-2 dataset. However, the validation loss has remained constant since the beginning. Here's my code for...
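The script isn't shown in full; the general shape of such a setup (not the reporter's actual code; the pooling choice, loading mechanics, and head are assumptions) is roughly:

```python
import torch.nn as nn
from transformers import AutoModel

# Loading mechanics are an assumption: the fla package (or trust_remote_code)
# may be needed for this architecture to resolve.
backbone = AutoModel.from_pretrained("fla-hub/transformer-1.3B-100B")

class SST2Classifier(nn.Module):
    def __init__(self, backbone, num_labels: int = 2):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(backbone.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        hidden = self.backbone(input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = hidden[:, -1]      # last-token pooling, a common choice for causal backbones
        return self.head(pooled)    # logits for CrossEntropyLoss
```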
Great work! It appears that both GLA and RetNet are optimized only for causal cases. Is there an optimized linear attention for non-causal scenarios?
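For reference, in the non-causal case plain linear attention collapses to a single reordered matmul, since K^T V can be formed once for the whole sequence; a minimal PyTorch sketch (not an FLA kernel, and the feature map is just one common choice):

```python
import torch
import torch.nn.functional as F

def noncausal_linear_attention(q, k, v, eps: float = 1e-6):
    """Non-causal linear attention over (batch, seq, heads, dim) tensors.

    Softmax is replaced by a positive feature map phi, and associativity lets us
    compute q @ (k^T v) in O(n * d^2) instead of O(n^2 * d).
    """
    phi = lambda x: F.elu(x) + 1                     # a common positive feature map
    q, k = phi(q), phi(k)
    kv = torch.einsum("bnhd,bnhe->bhde", k, v)       # K^T V, summed over the sequence
    z = 1.0 / (torch.einsum("bnhd,bhd->bnh", q, k.sum(dim=1)) + eps)  # normalizer
    return torch.einsum("bnhd,bhde,bnh->bnhe", q, kv, z)
```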
The current FLA RWKV6 implementation has significant precision issues in pure bf16 mode. Below are the results from my experiments: CUDA bf16 (fp32 internal): y: 0.0016603531206355376 gr: 0.0017877683404764239 gk: 0.0017853925508536652...
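The metric behind these numbers isn't spelled out in this excerpt; a common way to quantify such gaps is the relative error of the pure-bf16 outputs and gradients against an fp32(-internal) reference, e.g.:

```python
import torch

def rel_err(approx: torch.Tensor, ref: torch.Tensor) -> float:
    """Mean absolute difference normalized by the reference magnitude."""
    approx, ref = approx.float(), ref.float()
    return ((approx - ref).abs().mean() / (ref.abs().mean() + 1e-12)).item()

# Typical usage: run the kernel once with fp32 internal accumulation as the
# reference and once in pure bf16, then compare y and the gradients
# (gr, gk, ...) pairwise with rel_err.
```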