Is the version constraint on `flash-attn` real?
setup.py requires flash-attn <= 2.0.4, which seems like an odd constraint -- is it a real requirement or a historical artifact? It makes it tricky to have TE and flash-attn installed in the same environment (we use some of the latest flash-attn features), especially if you want to interact with flash-attn directly rather than through the TE abstractions.
Same, I also have this issue!
#506 increases the scope of flash-attn versions supported in TE
One thing to note is that flash-attn 2.1 changed the behavior of the causal flag for cross attention (i.e. when `seqlen_q != seqlen_k`): https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag. On the TE side, to keep behavior consistent with previous versions and across backends, we will need to disable FA for that case when FA v2.1+ is installed. I believe FA's new behavior is better for some use cases, though, so we will probably introduce some opt-in mechanism for it.
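For anyone trying to follow what changed: per the linked README, when `seqlen_q != seqlen_k` and `causal=True`, flash-attn pre-2.1 aligns the causal mask to the top-left corner of the `(seqlen_q, seqlen_k)` score matrix, while 2.1+ aligns it to the bottom-right. Here's a minimal sketch of the two mask layouts plus the kind of version gate TE would need (the function names here are illustrative, not TE's actual API):

```python
def causal_mask(seqlen_q, seqlen_k, bottom_right):
    """Build the boolean causal mask (True = attention allowed).

    bottom_right=False -> pre-2.1 behavior (mask anchored at top-left)
    bottom_right=True  -> 2.1+ behavior (mask anchored at bottom-right)
    """
    offset = seqlen_k - seqlen_q if bottom_right else 0
    return [[j <= i + offset for j in range(seqlen_k)]
            for i in range(seqlen_q)]

def _major_minor(version):
    # e.g. "2.1.0" -> (2, 1); good enough for this sketch
    return tuple(int(x) for x in version.split(".")[:2])

def can_use_flash_attn(causal, cross_attention, fa_version):
    """Illustrative gate: fall back to a non-FA backend for causal cross
    attention on FA 2.1+, where the mask alignment no longer matches the
    other backends."""
    if causal and cross_attention and _major_minor(fa_version) >= (2, 1):
        return False
    return True
```

For example, with `seqlen_q=2, seqlen_k=5`, the top-left mask only lets query row 0 see key 0, while the bottom-right mask lets it see keys 0-3, so the two versions genuinely attend to different tokens.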