[Feature Request][PyTorch] Support thd format for fp8 tensors in DotProductAttention

Open alexdremov opened this issue 8 months ago • 0 comments

It seems that packed tensors in thd format are currently not supported by transformer_engine.pytorch.attention.DotProductAttention. That is surprising, since this mode is clearly supported by fused_attn_fwd from the fused_attn cpp_extensions.
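For readers unfamiliar with the layout: thd packs variable-length sequences back to back along a single token dimension (t = total tokens, h = heads, d = head dimension), with cumulative sequence lengths marking where each sequence starts. Below is a minimal pure-Python sketch of that idea. The helper names `pack_thd` and `unpack_thd` are hypothetical and are not part of Transformer Engine; real code would use tensors rather than nested lists.

```python
from itertools import accumulate


def pack_thd(sequences):
    """Pack variable-length sequences into one token-major list plus offsets.

    Each sequence is a list of tokens, each token a list of h heads,
    each head a list of d floats, so the packed result mirrors (t, h, d).
    Hypothetical helper for illustration only.
    """
    seq_lens = [len(seq) for seq in sequences]
    # cu_seqlens[i] is the index of the first token of sequence i;
    # the final entry equals the total number of tokens t.
    cu_seqlens = [0] + list(accumulate(seq_lens))
    # Concatenate all tokens along the token dimension, with no padding.
    packed = [token for seq in sequences for token in seq]
    return packed, cu_seqlens


def unpack_thd(packed, cu_seqlens):
    """Split a packed token list back into per-sequence lists."""
    return [packed[start:end] for start, end in zip(cu_seqlens, cu_seqlens[1:])]


def make_seq(length, heads=2, dim=4, fill=0.0):
    """Build a dummy sequence of shape (length, heads, dim)."""
    return [[[fill] * dim for _ in range(heads)] for _ in range(length)]


# Three sequences of lengths 3, 1 and 2 packed into t = 6 tokens.
sequences = [make_seq(3, fill=1.0), make_seq(1, fill=2.0), make_seq(2, fill=3.0)]
packed, cu_seqlens = pack_thd(sequences)
restored = unpack_thd(packed, cu_seqlens)
```

Here `cu_seqlens` comes out as `[0, 3, 4, 6]`, and slicing `packed` at those offsets recovers the original sequences. Compared with a padded batch layout, nothing is spent on padding tokens, which is the reason the packed layout matters for attention over sequences of uneven length.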

I see that FusedAttnFunc is used in FusedAttention, but implementations for FusedAttnFunc_kvpacked and FusedAttnFunc_qkvpacked are not present. I suppose they could be added in the same way.

alexdremov, Jun 13 '24 14:06