TransformerEngine
[Feature Request][PyTorch] Support thd format for fp8 tensors in DotProductAttention
It seems that, at the moment, packed tensors in the `thd` format are not supported by `transformer_engine.pytorch.attention.DotProductAttention`. That is surprising, since this mode is clearly supported by `fused_attn_fwd` from the `fused_attn` `cpp_extensions`.
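
For reference, here is a minimal sketch of the call pattern this request is asking for: packed `q`/`k`/`v` tensors of shape `[total_tokens, num_heads, head_dim]` plus cumulative sequence lengths. The `qkv_format`, `cu_seqlens_*`, and `max_seqlen_*` argument names are assumptions about what a thd-enabled `DotProductAttention` interface could look like, not the current API.

```python
import torch
from transformer_engine.pytorch import DotProductAttention

# Two sequences of lengths 5 and 3 packed along a single token dimension ("t" in thd):
# q/k/v have shape [total_tokens, num_heads, head_dim] instead of [s, b, h, d].
seq_lens = torch.tensor([5, 3], dtype=torch.int32, device="cuda")
cu_seqlens = torch.cat(
    [
        torch.zeros(1, dtype=torch.int32, device="cuda"),
        torch.cumsum(seq_lens, dim=0, dtype=torch.int32),
    ]
)  # -> tensor([0, 5, 8])

total_tokens, num_heads, head_dim = int(seq_lens.sum()), 16, 64
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = DotProductAttention(num_attention_heads=num_heads, kv_channels=head_dim)

# Desired (hypothetical) call: pass packed tensors plus cumulative sequence lengths,
# which fused_attn_fwd already accepts at the cpp_extensions level.
out = attn(
    q, k, v,
    qkv_format="thd",             # assumed parameter name
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=int(seq_lens.max()),
    max_seqlen_kv=int(seq_lens.max()),
)
```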
I see that `FusedAttnFunc` is used in `FusedAttention`, but implementations for `FusedAttnFunc_kvpacked` and `FusedAttnFunc_qkvpacked` are not present. I suppose they could be added in the same way.