onnxruntime
Fix num splits bug
Description
Fixes a bug in the num splits computation: the heuristic was not applied correctly because the wrong sequence length was passed to the heuristic function.
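To illustrate why the sequence length argument matters, here is a minimal sketch of a split-count heuristic in the style used by flash attention kernels. It is not the onnxruntime implementation: the function name `num_splits_sketch`, its parameters, the block size of 128, and the 0.8 / 0.85 thresholds are all assumptions chosen for illustration. The scenario in the usage lines (the new-token length being passed instead of the total key/value length) is likewise an assumed example of an incorrect length, not a statement of what the original code did.

```python
import math


def num_splits_sketch(batch_nheads_mblocks, num_sms, kv_seqlen,
                      block_n=128, max_splits=128):
    """Hypothetical, simplified split-count heuristic (illustration only).

    batch_nheads_mblocks: number of independent work items
                          (batch * heads * query blocks).
    num_sms:              number of streaming multiprocessors on the GPU.
    kv_seqlen:            key/value sequence length given to the heuristic.
    """
    # Number of key/value blocks available to distribute across splits.
    num_n_blocks = math.ceil(kv_seqlen / block_n)

    # If the work items already nearly fill the GPU, do not split.
    if batch_nheads_mblocks >= 0.8 * num_sms:
        return 1

    # Cannot split into more pieces than there are key/value blocks.
    max_splits = min(max_splits, num_sms, num_n_blocks)

    # Efficiency of each candidate: fraction of the last wave that is used.
    efficiencies = []
    for splits in range(1, max_splits + 1):
        n_waves = batch_nheads_mblocks * splits / num_sms
        efficiencies.append(n_waves / math.ceil(n_waves))

    # Pick the smallest split count that is close to the best efficiency.
    best = max(efficiencies)
    for splits, eff in enumerate(efficiencies, start=1):
        if eff >= 0.85 * best:
            return splits
    return 1


# Assumed decoding scenario: 32 work items on a GPU with 108 SMs.
splits_wrong_length = num_splits_sketch(32, 108, kv_seqlen=1)     # new-token length
splits_total_length = num_splits_sketch(32, 108, kv_seqlen=8192)  # total length
```

In this sketch, a short length caps the number of key/value blocks at one, so the heuristic can only return a single split no matter how long the cached sequence is. Passing the total sequence length lets it choose more than one split, which is the behavior a long-sequence workload needs.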
Motivation and Context
We were seeing significant performance degradation in flash attention at long sequence lengths because of this misconfiguration.
Please update every call site of get_num_splits_and_buffer_sizes to use the total sequence length.