🚀[FEA]: Add FLARE: Fast Low-rank Attention Routing Engine to model zoo
Is this a new feature, an improvement, or a change to existing functionality?
New Feature
How would you describe the priority of this feature request?
Low (would be nice)
Please provide a clear description of the problem you would like to solve.
I wanted to share our latest paper, titled "FLARE: Fast Low-rank Attention Routing Engine."
https://arxiv.org/abs/2508.12594
FLARE is a novel self-attention mechanism that learns a low-rank attention formulation that can be applied in linear time. FLARE achieves superior accuracy across diverse neural PDE surrogate benchmarks, and scales to unprecedented problem sizes (1 million tokens on a single GPU). Our code is available below.
https://github.com/vpuri3/FLARE.py
FLARE is built entirely from standard fused attention primitives and requires no custom kernels. The FLARE module itself lives at the link below. I am happy to contribute a PR to implement FLARE.
https://github.com/vpuri3/FLARE.py/blob/master/pdebench/models/flare.py
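For a rough sense of the mechanism, here is a heavily simplified sketch of the low-rank routing idea in plain PyTorch. This is illustration only: the class and parameter names are made up, and it omits the details of the actual implementation in flare.py linked above. The point is that a small set of learned latent queries per head gathers information from the N input tokens, and the tokens then read it back, so both steps are standard fused attention calls and the cost scales with N·R instead of N².

```python
# Heavily simplified sketch of the low-rank routing idea (illustration only; names
# are made up and the real implementation is in flare.py, linked above).
import torch
import torch.nn as nn
import torch.nn.functional as F

class LowRankAttentionSketch(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, rank: int = 128):
        super().__init__()
        assert dim % num_heads == 0
        self.h, self.d = num_heads, dim // num_heads
        # R learnable latent queries per head act as the low-rank "routing" tokens.
        self.latent_q = nn.Parameter(torch.randn(num_heads, rank, self.d) * self.d ** -0.5)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, N, dim)
        B, N, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        split = lambda t: t.view(B, N, self.h, self.d).transpose(1, 2)  # (B, h, N, d)
        q, k, v = split(q), split(k), split(v)
        lq = self.latent_q.unsqueeze(0).expand(B, -1, -1, -1)           # (B, h, R, d)
        # Encode: R latent queries attend over the N tokens -> cost O(N * R).
        z = F.scaled_dot_product_attention(lq, k, v)                    # (B, h, R, d)
        # Decode: the N tokens attend over the R latent summaries -> cost O(N * R).
        y = F.scaled_dot_product_attention(q, z, z)                     # (B, h, N, d)
        return self.proj(y.transpose(1, 2).reshape(B, N, self.h * self.d))
```

Because both steps go through `scaled_dot_product_attention`, the fast fused kernels apply out of the box, which is why no custom kernels are needed.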
Describe any alternatives you have considered
No response
Hi @vpuri3 ,
Thanks for opening the issue. First of all, I'm happy to discuss a possible PR with you.
I took a look at your paper, code, and benchmarks. Your model does have some interesting properties with the low-rank attention mechanism. If you don't mind a few questions for my own curiosity:
- In Figure 2 of your paper, you showcase performance vs. Transolver across a variety of slice dimension sizes. What motivated those choices? The ablation table in the original Transolver paper (Table 9 here: https://arxiv.org/pdf/2402.02366) doesn't give a clear indication of accuracy above 256 slices.
- Likewise, I don't understand some of the numbers in your benchmark table (your Table 2). Compare, for example, your Transolver numbers against (again) their Table 9: there is a bit of a difference between what the two papers report for Transolver.
- You mention scaling to 1M mesh points. How did accuracy change (or not) as you scaled? We have problems currently where the meshes are hundreds of millions of points, and I think in the future we'll see billion-point meshes. If accuracy keeps improving as you add more points, and you want to go bigger, that's a great use case for physicsnemo integration.
(These questions are largely motivated by our own goals for physicsnemo in this area: delivering high-performing models with excellent scaling and accuracy.)
If you'd like to move forward with a pull request, a couple action items:
- You should identify an example use case for the physicsnemo implementation. We have a number of examples already, including a few for transolver that might make sense to tweak for FLARE. Most of our models have at least one example showcasing their usage and performance / accuracy. The DrivAerNet dataset is one we feature often for large-mesh networks, though it's a pretty hefty benchmark. We also have good tools to compare models there. A model PR should ideally come with at least one example using the model.
- Whatever benchmark you pick, be prepared for us to request convergence curves / data from your PR implementation that match the results in the paper. We'll test too :).
- Our models require unit testing and stable API testing - take a look at some of the tests in physicsnemo/test/models/ to see what I mean (a rough sketch of the kind of test we expect is below this list).
- Develop first in the experimental models folder, and we'll ultimately move your PR once its API is stable.
- Transolver is implemented with a transformer_engine backend instead of pure torch. That seems doable here - if you submit a PR, you're welcome to integrate transformer_engine, but it's not mandatory, either.
- What's your timeline?
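To make the testing item concrete, here is roughly the shape of test we'd expect. The import path, constructor arguments, and shapes below are placeholders (not our actual API or test helpers); the existing tests in physicsnemo/test/models/ are the best reference for the real patterns.

```python
# Rough sketch of a model unit test (placeholder names; see physicsnemo/test/models/
# for the real patterns and common helpers).
import pytest
import torch

@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_flare_forward(device):
    if device.startswith("cuda") and not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    # Hypothetical import path and constructor -- adjust to whatever the PR exposes.
    from physicsnemo.experimental.models.flare import FLARE
    model = FLARE(in_dim=3, out_dim=1, hidden_dim=64, num_heads=4, rank=32).to(device)

    x = torch.randn(2, 1024, 3, device=device)  # (batch, mesh points, features)
    y = model(x)

    assert y.shape == (2, 1024, 1)      # output shape matches out_dim
    assert torch.isfinite(y).all()      # no NaNs / Infs in the forward pass
```

This only covers a forward/shape check; the stable-API tests cover more than that, so use the existing tests as the template for the rest.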
Hope this helps. Feel free to let us know what you're thinking, or if you have other questions. And thanks again for starting this discussion!
Hi @coreyjadams,
Thanks for getting back! I really appreciate your interest in FLARE. For context, I'd encourage you to take a look at v2 of our paper, where we've added deeper ablations and an expanded discussion of low-rank structure and scaling (see Sections 3.2, 5, and the related appendices).
Regarding your questions:
1. Motivation for the slice / rank choices in Figure 2
FLARE is designed to expose the trade-off between efficiency (low rank) and accuracy (high rank / full attention). As shown in Figure 5 (right) of the paper, increasing the rank monotonically improves accuracy — up to a point that depends on the problem.
- For example, Elasticity saturates at ranks 16–32, while Darcy requires nearly full attention.
- Figures 2 and 4 (middle) jointly illustrate that higher rank also increases the time-to-solution.
To demonstrate this trade-off, we chose representative ranks = 128, 512, 2048 (the same values used in the 1 M-point DrivAerML experiment). We applied equivalent slice counts for Transolver. While Transolver doesn’t explicitly discuss this trade-off, Table 9 in their paper shows test-loss improvements up to ~256 slices before flattening, so our choices sample some of that operating regime.
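As a rough way to see the trade-off (omitting constants and the pointwise layers, and using my shorthand rather than the paper's notation): with $N$ tokens, head dimension $d$, and rank $R$,

```math
\text{full attention: } \mathcal{O}(N^2 d) \qquad \text{vs.} \qquad \text{rank-}R \text{ routing: } \mathcal{O}(N R d),
```

so time-to-solution grows roughly linearly in $R$, while the expressiveness of the attention map is capped at rank $R$; the ranks listed above sample that spectrum.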
2. Benchmark comparisons (Table 2 differences)
We re-ran Transolver and other baselines across all benchmarks for consistency. Small mismatches are inevitable due to setup reproducibility (dataset preprocessing, optimizer schedules, etc.), but the overall ranking trends are consistent. The full code is available for inspection on our repo.
Below is a side-by-side summary of the relevant cases:
| Test loss (Rel L₂) ↓ | Elasticity | Darcy | Airfoil | Pipe |
|---|---|---|---|---|
| Transolver (their paper Table 2) | 6.4e-3 | 5.7e-3 | 5.3e-3 | 3.3e-3 |
| Transolver w/o conv (our paper) | 6.40e-3 | 18.6e-3 | 8.24e-3 | 4.87e-3 |
| Transolver w. conv (our paper) | N/A | 5.94e-3 | 5.50e-3 | 3.90e-3 |
| FLARE (ours) | 3.38e-3 | 5.14e-3 | 4.28e-3 | 2.85e-3 |
| Vanilla Transformer (ours) | 5.37e-3 | 4.38e-3 | 6.28e-3 | too slow |
A few observations:
- Our Transolver numbers are quite close to theirs, within expected variation.
- The Transolver architecture relies heavily on convolutional preprocessing (Appendix B.3 of their paper), which limits transferability to unstructured grids.
- Vanilla Transformer performs competitively with low-rank attention methods, providing a useful upper-bound reference.
3. Scaling to 1 M points and accuracy trends
On the “1 M points” test case, we scale the model architecture (number of blocks and attention rank) to a mesh of ≈ 1 M points. Figure 4 shows clear accuracy gains with more blocks and diminishing returns with higher rank.
To your broader question (“does accuracy improve with more points?”): yes, generally. Comparing Table 2 (40 k points) vs. Figure 4 (1 M points), we consistently observe higher accuracy with denser point sampling, up to a limit imposed by data noise and resolution.
One key learning from our 1M-point scaling study was that mixed-precision training (forward in FP16, backward in FP32) is essential for scaling to extremely large problems, since high-performance kernels (e.g., FlashAttention) are typically optimized for low precision. We also found that adapting off-the-shelf surrogate models to low-precision formulations can be nontrivial — many exhibit numerical blow-ups or require careful hyperparameter retuning to remain stable. It would be great to see the community move toward low-precision training setups.
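For reference, the standard PyTorch AMP recipe is the natural starting point here; below is a generic sketch (not our exact FP16-forward / FP32-backward setup, and the function and argument names are placeholders):

```python
# Generic mixed-precision training step with torch autocast + GradScaler.
# A sketch of the common AMP recipe, not the exact setup used in the paper.
import torch

def train_step(model, batch, loss_fn, optimizer, scaler):
    x, y = batch
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        pred = model(x)              # low-precision forward hits the fast fused attention kernels
        loss = loss_fn(pred, y)
    scaler.scale(loss).backward()    # loss scaling guards against FP16 gradient underflow
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip in FP32 after unscaling
    scaler.step(optimizer)           # optimizer step runs on the FP32 parameters
    scaler.update()
    return loss.detach()

# scaler = torch.cuda.amp.GradScaler()
```

The key property is that the forward pass uses the low-precision kernels (the part FlashAttention-style implementations are optimized for) while the parameters and optimizer step stay in FP32.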
Best,
Vedant