onnxruntime
Quick Fix a Bug Caused by #19218
#19218 tried to fuse Gather/Slice into Split, but its logic has a problem. A scalar indices value and a 1-dim indices value in a Gather node produce different results: a scalar value removes the axis dim from the result tensor, while a 1-dim value keeps that dim, even when the dim value is 1. For example,
Node
|-> Gather(indices=[0], axis=axis)
|-> Gather(indices=[1], axis=axis)
|-> Slice(index=2, axis=axis)
is the same as
Node
|-> Split(axis=axis)
But
Node
|-> Gather(indices=0, axis=axis)
|-> Gather(indices=1, axis=axis)
|-> Slice(index=2, axis=axis)
is the same as
Node
|-> Split(axis=axis)
||-> Squeeze(axis=axis)
||-> Squeeze(axis=axis)
||->
The previous PR did not take this Squeeze/Unsqueeze case into account.
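To make the shape difference concrete, here is a minimal sketch of the ONNX Gather output-shape rule (the axis dim of the input is replaced by the shape of the indices tensor). The helper name `gather_output_shape` and the example shape `(4, 3, 8)` are illustrative assumptions and are not taken from the onnxruntime code base.

```python
def gather_output_shape(data_shape, indices_shape, axis):
    """Shape of Gather(data, indices, axis): data_shape with the axis dim
    replaced by indices_shape. Hypothetical helper for illustration only."""
    axis = axis % len(data_shape)  # normalize a negative axis
    return tuple(data_shape[:axis]) + tuple(indices_shape) + tuple(data_shape[axis + 1:])


data_shape = (4, 3, 8)  # assumed example input shape
axis = 1

# Scalar indices (indices=0) have shape (), so the axis dim disappears.
scalar_case = gather_output_shape(data_shape, (), axis)

# 1-dim indices (indices=[0]) have shape (1,), so the axis dim is kept with size 1.
one_dim_case = gather_output_shape(data_shape, (1,), axis)

print(scalar_case)   # (4, 8)
print(one_dim_case)  # (4, 1, 8)
```

A Split output along the same axis always keeps the axis dim, so it matches the 1-dim case directly and needs a Squeeze to match the scalar case.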
Ideally, a general fusion should not limit the number of Gather and Slice nodes. It is better to check all consumers, whether Gather or Slice: if the Gather indices and the Slice start/end values together cover the specific dim of the input tensor, we can fuse them into a Split, removing or adding Squeeze/Unsqueeze according to the dim count of the indices tensor in each Gather.
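The general check described here could look roughly like the following sketch. It is not the transformer's actual implementation: `plan_split`, the tuple encoding of consumers, and the assumption that every Slice has step 1 are all hypothetical choices made for illustration.

```python
def plan_split(dim_size, consumers):
    """Hypothetical planner. Each consumer is either
    ("gather", index, is_scalar) or ("slice", start, end), with Slice step assumed to be 1.
    Returns (split_sizes, squeeze_flags) ordered along the axis, or None when the
    consumers do not cover [0, dim_size) exactly once."""
    ranges = []
    for consumer in consumers:
        kind = consumer[0]
        if kind == "gather":
            _, index, is_scalar = consumer
            if index < 0:
                index += dim_size
            # A Gather with a single index reads one element along the axis.
            ranges.append((index, index + 1, is_scalar))
        elif kind == "slice":
            _, start, end = consumer
            if start < 0:
                start += dim_size
            if end < 0:
                end += dim_size
            end = min(end, dim_size)  # an oversized end means "to the end of the dim"
            ranges.append((start, end, False))
        else:
            return None  # any other consumer blocks the fusion

    ranges.sort()
    position = 0
    for start, end, _ in ranges:
        if start != position or end <= start:
            return None  # gap, overlap, or empty range
        position = end
    if position != dim_size:
        return None  # the tail of the dim is not consumed

    split_sizes = [end - start for start, end, _ in ranges]
    # Only outputs replacing a scalar-index Gather need a Squeeze after the Split.
    squeeze_flags = [is_scalar for _, _, is_scalar in ranges]
    return split_sizes, squeeze_flags


scalar_plan = plan_split(3, [("gather", 0, True), ("gather", 1, True), ("slice", 2, 3)])
one_dim_plan = plan_split(3, [("gather", 0, False), ("gather", 1, False), ("slice", 2, 3)])
gap_plan = plan_split(3, [("gather", 0, True), ("slice", 2, 3)])

print(scalar_plan)   # ([1, 1, 1], [True, True, False])
print(one_dim_plan)  # ([1, 1, 1], [False, False, False])
print(gap_plan)      # None
```

Returning the squeeze flags next to the split sizes keeps the two decisions together: whether fusion is legal at all, and which Split outputs need a Squeeze to preserve the original output shapes.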
@rui-ren, please check whether the fix still applies to your model. The UT cases in your PR cover only the transformer logic; because they do not execute the produced graph, they do not catch the case where the transformer generates a logically invalid graph. Since this fix is for ORTModule, you can add a Python UT case for ORTModule using an nn.Module that includes the model script, so that we can guarantee the transformer produces accurate results during execution.
Previous PR doesn't take such case related to Squeeze/Unsqueeze into account.
Hi @vincent, thank you for your PR. I was planning to add another PR to catch this squeeze/unsqueeze scenario. In my model,
Node
|-> Gather(indices=[0], axis=axis)
|-> Gather(indices=[1], axis=axis)
As such, I only considered the first case and left the scalar versus 1-dim Gather indices case for a separate PR. Thank you for the optimization.
LGTM, thanks!
BTW, tested this PR with lora and qlora, no perf degradation regarding throughput and accuracy.