tilelang
[EXAMPLE] In the flash attention example, keep the max of all blocks seen in scores_max for numerical stability
In the flash attention example, keep the max of previous scores_max and max(acc_s) in scores_max for numerical stability
From the Flash Attention 2 paper, Algorithm 1:
$$m_i^{\text{new}} = \max(m_i, \tilde{m}_{ij})$$
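For context, this running maximum then feeds the standard online-softmax rescaling in the same algorithm. A hedged restatement of those neighboring steps (notation follows the formula above, with indices simplified; this is a paraphrase, not a verbatim quote of the paper):

$$P_{ij} = \exp\!\left(S_{ij} - m_i^{\text{new}}\right), \qquad \ell_i^{\text{new}} = e^{\,m_i - m_i^{\text{new}}}\,\ell_i + \operatorname{rowsum}(P_{ij}), \qquad O_i \leftarrow e^{\,m_i - m_i^{\text{new}}}\,O_i + P_{ij} V_j$$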
Summary by CodeRabbit
- Refactor
- Updated the Softmax computation in the flash attention example to track the running maximum across all processed blocks during per-row operations.
Hi! Thank you for contributing to the TileLang project.
Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.
We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!
Walkthrough
The Softmax computation in the flash attention example is modified to include an explicit max-reduction step. After computing per-row maxima, a parallel loop updates row-wise maximum values to the element-wise maximum of current and previous maximum values before subsequent scaling and exponentiation.
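A minimal sketch of what that step might look like inside the kernel's per-block loop (buffer names such as acc_s, scores_max, scores_max_prev and the block_M extent follow the example's conventions and are assumptions here, not a verbatim quote of the file):

```python
# Assumes the usual `import tilelang.language as T` at module level.
# Sketch of the updated Softmax step inside the per-block loop of the kernel.
T.copy(scores_max, scores_max_prev)                 # keep the global max seen so far
T.reduce_max(acc_s, scores_max, dim=1, clear=True)  # per-row max of the current block only
for i in T.Parallel(block_M):
    # running maximum: m_i_new = max(m_i, m~_ij)
    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
```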
Changes
| Cohort / File(s) | Summary |
|---|---|
| Flash Attention Softmax Update (examples/flash_attention/example_mha_fwd_bhsd.py) | Added a parallel loop in Softmax to compute the element-wise maximum between scores_max[i] and scores_max_prev[i] for each row i after the initial reduce_max operation, updating the maximum values used in subsequent scaling and normalization steps. |
Estimated code review effort
3 (Moderate) | ~20 minutes
- Verify the correctness of the element-wise max-reduction logic and its alignment with flash attention algorithm semantics
- Confirm numerical stability implications of the updated maximum value computation
- Validate that the parallel loop implementation is correct and maintains expected behavior
Poem
A max of maxes, soft and true,
Previous values meeting new,
Attention scores now more refined,
With wisdom of the row aligned!
Pre-merge checks and finishing touches
Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | Warning | Docstring coverage is 0.00%, which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |
Passed checks (2 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title Check | Passed | The PR title "[EXAMPLE] In the flash attention example, keep the max of all blocks seen in scores_max for numerical stability" directly describes the primary change in the changeset. The raw summary confirms that the main modification is updating scores_max to track the element-wise maximum of current and previous values (scores_max[i] = max(scores_max[i], scores_max_prev[i])), which is exactly what the title conveys with "keep the max of all blocks seen in scores_max." The title also correctly identifies the purpose (numerical stability) and the context (flash attention example). The title is specific enough that a teammate scanning the history would clearly understand this is about improving the numerical stability of the flash attention implementation by preserving the maximum value across blocks. |
Finishing touches
- [ ] Generate docstrings
Generate unit tests (beta)
- [ ] Create PR with unit tests
- [ ] Post copyable unit tests in a comment
Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Commits
Reviewing files that changed from the base of the PR and between 60567ba3b26a6940712b10d9575967a1d6fd4dd2 and f18e82983a097245c2e3ad30720791a44c22818e.
Files selected for processing (1)
examples/flash_attention/example_mha_fwd_bhsd.py (1 hunks)
Additional context used
Code graph analysis (1)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
tilelang/language/parallel.py (1)
Parallel (9-29)
Additional comments (1)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
87-90: Critical fix for numerical stability. LGTM!
This correctly implements the running maximum update from the Flash Attention 2 paper (m_i^{new} = max(m_i, m̃_ij)). Without this change, scores_max would only reflect the maximum of the current block rather than the global maximum across all processed blocks, causing numerical instability in the softmax computation.
The implementation is correct:
- scores_max_prev preserves the global max from previous iterations
- After reduce_max (line 86), scores_max holds the current block's max
- This loop updates scores_max to the true global max
- Subsequent scaling operations (lines 97, 103) now use the correct global maximum (sketched below)
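A hedged sketch of those subsequent scaling operations, continuing the fragment from the walkthrough above; scores_scale, acc_s, acc_o, scale, block_N, and dim are assumed names following the example's conventions, and the exact line placement may differ from the file:

```python
# Downstream use of the corrected global max (sketch; names are assumptions).
for i in T.Parallel(block_M):
    # rescale factor for previously accumulated output: exp2((m_prev - m_new) * scale)
    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
    # exponentiate the current block's scores against the corrected global max
    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
    # bring prior output contributions in line with the new running max
    acc_o[i, j] *= scores_scale[i]
```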
cc @chengyupku @Rachmanino ?
Is there any evidence confirming that this change is necessary?
I think, without it, the current approach can lead to overflow in the accumulator and logsum when a block's max is significantly lower than the previous global max, as it effectively upscales prior contributions by a large factor (exp(m_prev - m_curr) > 1).
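To make that overflow concern concrete, here is a small plain-NumPy illustration of the rescale factor, with hypothetical magnitudes that are not taken from the example:

```python
import numpy as np

# Hypothetical per-row maxima: earlier blocks had large scores, the current block small ones.
m_prev = np.float32(80.0)    # global max from previously processed blocks
m_curr = np.float32(-20.0)   # max of the current block only

# Buggy rescale (current block's max used directly): factor exp(m_prev - m_curr) = exp(100)
scale_buggy = np.exp(m_prev - m_curr)     # overflows float32 -> inf

# Fixed rescale (running max m_new = max(m_prev, m_curr)): factor exp(m_prev - m_new) = exp(0)
m_new = np.maximum(m_prev, m_curr)
scale_fixed = np.exp(m_prev - m_new)      # exactly 1.0, no overflow

print(scale_buggy, scale_fixed)           # inf 1.0
```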
@codex what's your take on this change?
To use Codex here, create an environment for this repo.
@codex review
Codex Review: Didn't find any major issues. Keep it up!
About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you:
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".