TransformerEngine
[Feature Request, JAX] Support fused Softmax op when sequence length is larger than 2048
The JAX custom call fails when training a model with a sequence length larger than 2048. Nowadays, the sequence length in most models is larger than 2048, often at least 4096. A 2048 sequence-length limit for TE is too small!
Thanks for the feedback. We plan to remove the 2048 restriction.
@zlsh80826 Is this already fixed?
Yes, it is fixed in https://github.com/NVIDIA/TransformerEngine/pull/796. The softmax now supports sequence lengths up to 16384.
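For reference, below is a minimal plain-JAX sketch of the scaled, masked softmax computation that the fused kernel accelerates, run here at seqlen = 4096 to illustrate the kind of shape that previously hit the 2048 limit. This is an illustrative reference implementation only, not the TE fused op or its API; the function name `scaled_masked_softmax` and the chosen shapes are assumptions for the example.

```python
# Illustrative plain-JAX reference (NOT the TE fused kernel or its API).
import jax
import jax.numpy as jnp

def scaled_masked_softmax(logits, mask, scale=1.0):
    # logits: (batch, heads, seqlen, seqlen)
    # mask:   boolean, True where attention is disallowed
    logits = logits * scale
    logits = jnp.where(mask, -1e10, logits)
    return jax.nn.softmax(logits, axis=-1)

batch, heads, seqlen = 1, 1, 4096  # seqlen above the old 2048 limit
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (batch, heads, seqlen, seqlen), dtype=jnp.float32)

# Causal mask: disallow attending to future positions.
causal_mask = jnp.triu(jnp.ones((seqlen, seqlen), dtype=bool), k=1)[None, None, :, :]

probs = scaled_masked_softmax(logits, causal_mask, scale=1.0 / jnp.sqrt(64.0))
print(probs.shape)  # (1, 1, 4096, 4096)
```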