flash-attention
v1 algorithm typo in v2 paper
Hi @tridao,
I believe the statement of the v1 algorithm in the v2 paper (background section) needs the following minor fix in its last equation:
diff --git a/src/background.tex b/src/background.tex
index 0f496e8..157c818 100644
--- a/src/background.tex
+++ b/src/background.tex
@@ -129,7 +129,7 @@ rescale to get the right output at the end:
m^{(2)} &= \max(m^{(1)}, \mathrm{rowmax}(\vS^{(2)})) = m \\
\ell^{(2)} &= e^{m^{(1)} - m^{(2)}} \ell^{(1)} + \mathrm{rowsum}(e^{\vS^{(2)} - m^{(2)}}) = \mathrm{rowsum}(e^{\vS^{(1)} - m}) + \mathrm{rowsum}(e^{\vS^{(2)} - m}) = \ell \\
\tilde{\vP}^{(2)} &= \diag(\ell^{(2)})^{-1} e^{\vS^{(2)} - m^{(2)}} \\
- \vO^{(2)} &= \diag(\ell^{(1)} / \ell^{(2)})^{-1} \vO^{(1)} + \tilde{\vP}^{(2)} \vV^{(2)} = \diag(\ell^{(2)})^{-1} e^{s^{(1)} - m} \vV^{(1)} + \diag(\ell^{(2)})^{-1} e^{s^{(2)} - m} \vV^{(2)} = \vO.
+ \vO^{(2)} &= \diag(\ell^{(1)} / \ell^{(2)}) e^{m^{(1)} - m^{(2)}} \vO^{(1)} + \tilde{\vP}^{(2)} \vV^{(2)} = \diag(\ell^{(2)})^{-1} e^{s^{(1)} - m} \vV^{(1)} + \diag(\ell^{(2)})^{-1} e^{s^{(2)} - m} \vV^{(2)} = \vO.
\end{align*}
We show how \sysnameone uses online softmax to enable tiling
I removed the $-1$ exponent on the diagonal term, and added the $e^{m^{(1)} - m^{(2)}}$ correction term.
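To spell out why both changes are needed, here is a short derivation. It assumes $\mathbf{O}^{(1)} = \mathrm{diag}(\ell^{(1)})^{-1} e^{\mathbf{S}^{(1)} - m^{(1)}} \mathbf{V}^{(1)}$, which is how I read the first-block step just above this hunk (it is not part of the diff), and it treats $e^{m^{(1)} - m^{(2)}}$ as a per-row scaling:

$$
\begin{aligned}
\mathrm{diag}\left(\ell^{(1)}/\ell^{(2)}\right) e^{m^{(1)}-m^{(2)}} \mathbf{O}^{(1)}
&= \mathrm{diag}(\ell^{(2)})^{-1}\, \mathrm{diag}(\ell^{(1)})\, e^{m^{(1)}-m^{(2)}}\, \mathrm{diag}(\ell^{(1)})^{-1} e^{\mathbf{S}^{(1)}-m^{(1)}} \mathbf{V}^{(1)} \\
&= \mathrm{diag}(\ell^{(2)})^{-1} e^{\mathbf{S}^{(1)}-m^{(2)}} \mathbf{V}^{(1)}
\end{aligned}
$$

Since $m^{(2)} = m$, this is exactly the first term on the right-hand side of the equation. With the $-1$ exponent and no correction term, the $\ell^{(1)}$ factors do not cancel and the $m^{(1)}$ shift is never replaced by $m$.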
Before:
$$\mathbf{O}^{(2)} = \mathrm{diag}\left(\ell^{(1)}/\ell^{(2)}\right)^{-1} \mathbf{O}^{(1)} + \mathbf{\tilde{P}}^{(2)}\mathbf{V}^{(2)}$$
After:
$$\mathbf{O}^{(2)} = \mathrm{diag}\left(\ell^{(1)}/\ell^{(2)}\right) e^{m^{(1)}-m^{(2)}} \mathbf{O}^{(1)} + \mathbf{\tilde{P}}^{(2)}\mathbf{V}^{(2)}$$
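For anyone who wants to check this numerically, below is a minimal sketch in plain Python (standard library only). It compares the two versions above against ordinary softmax attention on a toy example with two key/value blocks. The data and all helper names (`exp_rows`, `matmul`, `scale_rows`, and so on) are made up for illustration and are not taken from the paper or the repository. I assume the first-block step is $\mathbf{O}^{(1)} = \mathrm{diag}(\ell^{(1)})^{-1} e^{\mathbf{S}^{(1)} - m^{(1)}} \mathbf{V}^{(1)}$.

```python
import math

# Hypothetical toy data: 2 query rows, two key/value blocks of 2 keys each.
S1 = [[0.5, -1.0], [2.0, 0.3]]   # scores against block 1
S2 = [[1.5, 0.2], [-0.7, 3.0]]   # scores against block 2
V1 = [[1.0, 2.0], [3.0, 4.0]]
V2 = [[-1.0, 0.5], [2.0, -3.0]]

def exp_rows(S, m):
    """exp(S - m), with the per-row shift m applied to every column."""
    return [[math.exp(s - mi) for s in row] for row, mi in zip(S, m)]

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def scale_rows(A, c):
    """diag(c) @ A: multiply row i of A by c[i]."""
    return [[ci * a for a in row] for row, ci in zip(A, c)]

def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

# Reference: softmax over the concatenated scores [S1, S2], times [V1; V2].
m = [max(r1 + r2) for r1, r2 in zip(S1, S2)]
E1, E2 = exp_rows(S1, m), exp_rows(S2, m)
l = [sum(r1) + sum(r2) for r1, r2 in zip(E1, E2)]
O_ref = scale_rows(add(matmul(E1, V1), matmul(E2, V2)), [1.0 / x for x in l])

# Block 1: local max, local normalizer, normalized partial output.
m1 = [max(r) for r in S1]
P1 = exp_rows(S1, m1)
l1 = [sum(r) for r in P1]
O1 = scale_rows(matmul(P1, V1), [1.0 / x for x in l1])

# Block 2: updated max and normalizer, as in the m^(2) and l^(2) equations.
m2 = [max(a, max(r)) for a, r in zip(m1, S2)]
P2 = exp_rows(S2, m2)
l2 = [math.exp(a - b) * x + sum(r) for a, b, x, r in zip(m1, m2, l1, P2)]
PV2 = scale_rows(matmul(P2, V2), [1.0 / x for x in l2])  # P~(2) V(2)

# "Before": diag(l1 / l2)^{-1} O1 + P~(2) V(2)
O_before = add(scale_rows(O1, [b / a for a, b in zip(l1, l2)]), PV2)

# "After": diag(l1 / l2) e^{m1 - m2} O1 + P~(2) V(2)
rescale = [a / b * math.exp(x - y) for a, b, x, y in zip(l1, l2, m1, m2)]
O_after = add(scale_rows(O1, rescale), PV2)

err_before = max_abs_diff(O_before, O_ref)
err_after = max_abs_diff(O_after, O_ref)
print("max error, before:", err_before)
print("max error, after: ", err_after)
```

On this example the "After" update agrees with the reference up to floating-point round-off, while the "Before" update does not. I picked the scores so that the running maximum changes between blocks in both rows, since otherwise the $e^{m^{(1)}-m^{(2)}}$ factor would equal 1 and that part of the fix would not show up.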
Original rendered algorithm for reference: [image of the equations as rendered in the paper]
I got the LaTeX source from arXiv.