[Feature Request, JAX] Support fused Softmax op when sequence length is larger than 2048

Open MoFHeka opened this issue 1 year ago • 1 comments

The JAX custom call fails when training a model with a sequence length larger than 2048. Most models today use a sequence length greater than 2048, typically at least 4096, so a 2048 limit in TE is too small!

MoFHeka, Jan 26 '24 18:01

Thanks for the feedback. We plan to remove the 2048 restriction.

zlsh80826, Jan 29 '24 14:01

@zlsh80826 Is this already fixed?

ptrendx, May 16 '24 18:05

Yes, it is fixed in https://github.com/NVIDIA/TransformerEngine/pull/796. The softmax now supports up to seqlen = 16384.

zlsh80826, May 17 '24 01:05
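
For readers on a build that still has a sequence-length limit, here is a minimal sketch of an unfused fallback in plain JAX. It is not TransformerEngine code: `MAX_FUSED_SEQLEN`, `fused_softmax_supported`, and `reference_softmax` are hypothetical names used for illustration, and the 16384 value is taken from the reply above rather than from the library itself. Only `jax.nn.softmax` is a real JAX call.

```python
import jax
import jax.numpy as jnp

# Assumption: limit quoted in this thread; verify against your installed TE version.
MAX_FUSED_SEQLEN = 16384


def fused_softmax_supported(seqlen: int, max_seqlen: int = MAX_FUSED_SEQLEN) -> bool:
    """Hypothetical guard: True if seqlen is within the quoted fused-kernel limit."""
    return seqlen <= max_seqlen


def reference_softmax(logits: jax.Array, scale: float = 1.0) -> jax.Array:
    """Unfused scaled softmax over the last (key) axis, using plain JAX."""
    return jax.nn.softmax(logits * scale, axis=-1)


# Attention logits with shape (batch, heads, q_len, kv_len), kv_len = 4096.
logits = jnp.zeros((1, 2, 8, 4096), dtype=jnp.float32)
probs = reference_softmax(logits)
```

The guard only decides which path to take; the plain JAX version works at any sequence length, at the cost of losing the fused kernel's performance.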