TransformerEngine
[Feature Request, JAX] Support fused Softmax op when sequence length is larger than 2048
The JAX custom call fails when training a model whose sequence length is larger than 2048. Nowadays the sequence length of most models is larger than 2048, typically at least 4096, so a 2048 limit for TE is too small!
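Until the fused kernel supports longer sequences, one possible workaround is an unfused softmax written in plain JAX. The sketch below assumes the op in question is a scaled, padding-masked softmax over the last axis. `scaled_masked_softmax_reference` is a hypothetical helper, not a TransformerEngine API, and the shapes are only illustrative. It gives up the memory and bandwidth savings of fusion, but it places no limit on the key length.

```python
import jax
import jax.numpy as jnp


def scaled_masked_softmax_reference(logits, mask, scale):
    """Hypothetical unfused fallback: softmax(scale * logits) over the last axis.

    logits: [batch, heads, q_len, k_len]
    mask:   boolean, True where attention is NOT allowed (broadcastable to logits)
    """
    x = logits * scale
    # Use the most negative finite value instead of -inf so rows stay finite.
    x = jnp.where(mask, jnp.finfo(x.dtype).min, x)
    return jax.nn.softmax(x, axis=-1)


# Illustrative shapes: a key length of 4096, beyond the 2048 limit in the issue.
batch, heads, q_len, k_len = 2, 1, 16, 4096
logits = jax.random.normal(
    jax.random.PRNGKey(0), (batch, heads, q_len, k_len), dtype=jnp.float32
)

# Padding mask: the second sequence only has 3000 valid key positions.
valid_lens = jnp.array([4096, 3000])
pad_mask = jnp.arange(k_len)[None, :] >= valid_lens[:, None]  # [batch, k_len]
pad_mask = pad_mask[:, None, None, :]                         # [batch, 1, 1, k_len]

probs = scaled_masked_softmax_reference(logits, pad_mask, scale=0.125)
```

Each row of `probs` sums to 1, and padded key positions receive zero probability.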