TransformerEngine
Difficult to understand why "no fused attention kernel is available"
With the JAX API, I am getting this warning:
/usr/local/lib/python3.10/dist-packages/transformer_engine/jax/flax/transformer.py:477: UserWarning: Fused attention is not enabled. Because no fused attention kernel is available, fall back to unfused attention.
The warning leads me to this code, which in turn leads to this function with very complex boolean expressions.
Is there a recommended way to see what exactly is causing this function to return NVTE_Fused_Attn_Backend::NVTE_No_Backend?
A more informative warning for this case would be great.
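As a rough illustration of what a more informative warning could look like, here is a minimal Python sketch of a diagnostic helper that checks candidate constraints one at a time and reports every one that fails, instead of a single opaque message. The specific constraints below (half-precision dtype, head_dim limits, mask/bias combinations) are assumptions for illustration only; the real conditions live in the C++ backend-selection function and depend on the cuDNN version and GPU architecture.

```python
# Hypothetical sketch, not TransformerEngine's actual backend-selection logic.
# The checks below are illustrative assumptions; the authoritative conditions
# are the ones that make the C++ selector return
# NVTE_Fused_Attn_Backend::NVTE_No_Backend.
from dataclasses import dataclass


@dataclass
class AttnConfig:
    dtype: str            # e.g. "bfloat16" or "float16"
    head_dim: int
    max_seqlen_q: int
    max_seqlen_kv: int
    attn_bias_type: str   # e.g. "no_bias", "post_scale_bias"
    attn_mask_type: str   # e.g. "causal", "padding"


def explain_no_fused_backend(cfg: AttnConfig) -> list[str]:
    """Return human-readable reasons why a fused kernel might be unavailable."""
    reasons = []
    # Illustrative checks only; real limits vary with cuDNN version and GPU arch.
    if cfg.dtype not in ("bfloat16", "float16"):
        reasons.append(f"dtype {cfg.dtype} is not a half-precision type")
    if cfg.head_dim > 128 or cfg.head_dim % 8 != 0:
        reasons.append(
            f"head_dim {cfg.head_dim} outside assumed range (<=128, multiple of 8)"
        )
    if cfg.attn_mask_type == "causal" and cfg.max_seqlen_q != cfg.max_seqlen_kv:
        reasons.append("causal mask with differing q/kv sequence lengths")
    if cfg.attn_bias_type not in ("no_bias", "post_scale_bias"):
        reasons.append(f"unsupported bias type {cfg.attn_bias_type}")
    return reasons


if __name__ == "__main__":
    cfg = AttnConfig(dtype="float32", head_dim=80, max_seqlen_q=512,
                     max_seqlen_kv=512, attn_bias_type="no_bias",
                     attn_mask_type="causal")
    for reason in explain_no_fused_backend(cfg) or ["no obvious blocker found"]:
        print("fused attention unavailable:", reason)
```

Emitting a per-constraint list like this (or logging the config that was passed to the backend selector) would make it much easier to tell which input is ruling out the fused kernel.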
@zlsh80826 @cyanguwa @denera FYI