add transformer engine fp8 attention
This PR adds the TransformerEngine fp8 attention implementation (https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py).
To enable it, set the fields below in the YAML config:
precision: amp_fp8
model:
  attn_config:
    attn_type: te_multihead_attention
    kv_n_heads: 8
  fc_type: te
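
For context, below is a minimal standalone sketch of the underlying TransformerEngine pieces this config maps onto. It is not the llm-foundry wiring itself; the dimensions, recipe settings, and tensor shapes are illustrative assumptions, and it needs an FP8-capable GPU (e.g. Hopper) with `transformer_engine` installed.

```python
# Illustrative sketch only: TransformerEngine's MultiheadAttention run under
# an fp8_autocast context. Dimensions and recipe settings are assumptions.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

hidden_size, n_heads = 1024, 16   # example sizes, chosen to satisfy FP8 alignment
seq_len, batch_size = 128, 2

# The TE attention module that attn_type: te_multihead_attention selects.
attn = te.MultiheadAttention(
    hidden_size=hidden_size,
    num_attention_heads=n_heads,
).cuda()

# Delayed-scaling FP8 recipe; HYBRID uses E4M3 for forward, E5M2 for gradients.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)

# Default input layout is (seq, batch, hidden).
x = torch.randn(seq_len, batch_size, hidden_size, device="cuda")

# GEMMs inside the TE module run in FP8 within this context.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = attn(x)

print(out.shape)  # torch.Size([128, 2, 1024])
```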