[Feature Request, JAX] Support fused Softmax op when sequence length is larger than 2048

Open • MoFHeka opened this issue 5 months ago • 1 comment

The JAX custom call fails when training a model with a sequence length larger than 2048. Most current models use sequence lengths well above 2048, typically at least 4096, so a 2048 limit in TE is too small.
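For reference, the computation being requested (scale the attention logits, mask out disallowed positions, then softmax over the key axis) can be written in plain JAX with no sequence-length limit. This is a minimal unfused sketch, not Transformer Engine's API: `scaled_masked_softmax` and `causal_mask` are hypothetical helpers, and the mask convention (`True` = position is dropped) is an assumption chosen for illustration. It is meant as a correctness reference or stopgap; how its speed and memory use compare with TE's fused kernel is not established here.

```python
import jax
import jax.numpy as jnp


def scaled_masked_softmax(logits, mask, scale):
    """Unfused reference for softmax(scale * logits) with masked positions excluded.

    logits: [batch, heads, q_len, k_len] attention scores.
    mask:   boolean array broadcastable to logits; True marks positions to drop
            (assumed convention, not necessarily TE's).
    scale:  scalar multiplier applied to the logits before masking.
    """
    scaled = logits * scale
    # Use the most negative finite value instead of -inf so that a fully
    # masked row yields a uniform distribution rather than NaN.
    neg = jnp.finfo(scaled.dtype).min
    scaled = jnp.where(mask, neg, scaled)
    return jax.nn.softmax(scaled, axis=-1)


def causal_mask(seq_len):
    # True strictly above the diagonal: query i may not attend to key j > i.
    return jnp.triu(jnp.ones((seq_len, seq_len), dtype=bool), k=1)


if __name__ == "__main__":
    batch, heads, seq_len = 1, 1, 4096  # longer than the 2048 limit in the report
    logits = jax.random.normal(
        jax.random.PRNGKey(0), (batch, heads, seq_len, seq_len), dtype=jnp.float32
    )
    probs = jax.jit(scaled_masked_softmax)(logits, causal_mask(seq_len), 0.125)
    print(probs.shape)
```

The full `[q_len, k_len]` score matrix is materialized here, so memory grows quadratically with sequence length regardless of which softmax implementation is used.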

MoFHeka • Jan 26 '24 18:01