[EXAMPLE] In the flash attention example keep the max of all blocks seen in scores_max for numerical stability

Open · vpj opened this issue 1 month ago • 9 comments

In the flash attention example, keep the max of previous scores_max and max(acc_s) in scores_max for numerical stability

From the Flash Attention 2 paper, Algorithm 1:

$$m_i^{\text{new}} = \max(m_i, \tilde{m}_{ij})$$
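For intuition, here is a minimal NumPy sketch of the online-softmax recurrence this line comes from (illustrative names, not the TileLang kernel). The key step is the running maximum m_new = max(m, m_block):

```python
import numpy as np

def online_softmax_stats(blocks):
    """Running row max m_i and normalizer l_i over score blocks (FA2, Alg. 1)."""
    m = np.full(blocks[0].shape[0], -np.inf)   # m_i: running row max
    l = np.zeros(blocks[0].shape[0])           # l_i: running normalizer
    for s in blocks:                           # s: one (rows x cols) score block
        m_block = s.max(axis=1)                # row max of the current block
        m_new = np.maximum(m, m_block)         # m_i^new = max(m_i, m_block)
        l = l * np.exp(m - m_new) + np.exp(s - m_new[:, None]).sum(axis=1)
        m = m_new                              # carry the global max forward
    return m, l
```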

Summary by CodeRabbit

  • Refactor
    • Updated Softmax computation in flash attention example to revise maximum value tracking during per-row operations.

vpj · Oct 28 '25 14:10

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

github-actions[bot] · Oct 28 '25 14:10

Walkthrough

The Softmax computation in the flash attention example is modified to include an explicit max-reduction step. After computing per-row maxima, a parallel loop updates row-wise maximum values to the element-wise maximum of current and previous maximum values before subsequent scaling and exponentiation.

Changes

Flash Attention Softmax Update: examples/flash_attention/example_mha_fwd_bhsd.py
  • Added a parallel loop in Softmax that computes the element-wise maximum of scores_max[i] and scores_max_prev[i] for each row i after the initial reduce_max operation, so the maximum used in subsequent scaling and normalization steps reflects all blocks seen so far (sketched below).

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Verify the correctness of the element-wise max-reduction logic and its alignment with flash attention algorithm semantics
  • Confirm numerical stability implications of the updated maximum value computation
  • Validate that the parallel loop implementation is correct and maintains expected behavior

Poem

🐰 A max of maxes, soft and true,
Previous values meeting new,
Attention scores now more refined,
With wisdom of the row aligned! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped: CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The PR title "[EXAMPLE] In the flash attention example keep the max of all blocks seen in scores_max for numerical stability" directly describes the primary change in the changeset. The raw summary confirms that the main modification is updating scores_max to track the element-wise maximum of current and previous values (scores_max[i] = max(scores_max[i], scores_max_prev[i])), which is exactly what the title conveys with "keep the max of all blocks seen in scores_max." The title also correctly identifies the purpose (numerical stability) and the context (flash attention example), and is specific enough that a teammate scanning the history would understand it.
✨ Finishing touches
  • [ ] 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • [ ] Create PR with unit tests
  • [ ] Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 60567ba3b26a6940712b10d9575967a1d6fd4dd2 and f18e82983a097245c2e3ad30720791a44c22818e.

📒 Files selected for processing (1)
  • examples/flash_attention/example_mha_fwd_bhsd.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
tilelang/language/parallel.py (1)
  • Parallel (9-29)
🔇 Additional comments (1)
examples/flash_attention/example_mha_fwd_bhsd.py (1)

87-90: Critical fix for numerical stability – LGTM!

This correctly implements the running maximum update from the Flash Attention 2 paper (m_i^{new} = max(m_i, m̃_ij)). Without this change, scores_max would only reflect the maximum of the current block rather than the global maximum across all processed blocks, causing numerical instability in the softmax computation.

The implementation is correct:

  • scores_max_prev preserves the global max from previous iterations
  • After reduce_max (line 86), scores_max holds the current block's max
  • This loop updates scores_max to the true global max
  • Subsequent scaling operations (lines 97, 103) now use the correct global maximum
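A quick NumPy sanity check (hypothetical block split, outside the kernel) confirms the recurrence: with the running-max update, the blockwise normalizer matches the one computed over all columns at once.

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 256))               # 4 rows, 256 score columns
blocks = np.split(scores, 4, axis=1)             # four 64-wide blocks

m = np.full(4, -np.inf)                          # running row max
l = np.zeros(4)                                  # running normalizer
for s in blocks:
    m_new = np.maximum(m, s.max(axis=1))         # the update this PR adds
    l = l * np.exp(m - m_new) + np.exp(s - m_new[:, None]).sum(axis=1)
    m = m_new

ref = np.exp(scores - scores.max(axis=1, keepdims=True)).sum(axis=1)
assert np.allclose(l, ref)                       # normalizers agree
```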

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

coderabbitai[bot] · Oct 28 '25 14:10

cc @chengyupku @Rachmanino?

LeiWang1999 · Oct 28 '25 14:10

Is there any evidence that confirms such a change is necessary?

Rachmanino · Oct 28 '25 14:10

I think that without it, the current approach can overflow the accumulator and logsum when a block's max is significantly lower than the previous global max, since it effectively upscales prior contributions by a large factor (exp(m_prev - m_curr) > 1).
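A tiny numeric illustration (hypothetical magnitudes, with a half-precision accumulator assumed purely for demonstration; the example's actual accumulation dtype may differ):

```python
import numpy as np

m_prev, m_curr = 30.0, 0.0                  # prior global max vs. current block max
scale = np.exp(m_prev - m_curr)             # ~1.07e13, rescale applied to prior sums
acc = np.float16(1.0) * np.float16(scale)   # fp16 accumulator overflows to inf
print(scale, acc)
```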

vpj · Oct 28 '25 14:10

@codex what's your take on this change?

LeiWang1999 · Oct 29 '25 08:10

@codex review

LeiWang1999 · Oct 29 '25 09:10

Codex Review: Didn't find any major issues. Keep it up!

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you:

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".