onnxruntime

Quick Fix a Bug Caused by #19218

Open • centwang opened this issue 1 year ago • 1 comment

#19218 tried to fuse Gather/Slice into Split, but the logic has a problem. A scalar indices value and a 1-dim indices value in a Gather node produce different results: a scalar index produces a result tensor with the axis dim removed, while 1-dim indices keep that dim, even when its size is 1. For example,

Node
|-> Gather(indices=[0], axis=axis)
|-> Gather(indices=[1], axis=axis)
|-> Slice(index=2, axis=axis)

is the same as

Node
|-> Split(axis=axis)

But

Node
|-> Gather(indices=0, axis=axis)
|-> Gather(indices=1, axis=axis)
|-> Slice(index=2, axis=axis)

is the same as

Node
|-> Split(axis=axis)
    |-> Squeeze(axis=axis)
    |-> Squeeze(axis=axis)
    |-> (third output, no Squeeze)
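The shape difference can be reproduced outside ORT with NumPy, whose np.take follows the same rule as ONNX Gather (the output shape replaces the axis with the shape of indices, so a scalar index drops it). This is a minimal sketch, assuming the gathered axis has size 3 and reading "Slice(index=2)" as a slice from 2 to the end of that axis; the input tensor and variable names are illustrative only.

```python
import numpy as np

AXIS = 1
x = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # the shared input "Node"; dim 1 has size 3

# Case 1: 1-dim indices keep the axis (size 1), so a plain Split matches.
g0 = np.take(x, [0], axis=AXIS)            # shape (2, 1, 4)
g1 = np.take(x, [1], axis=AXIS)            # shape (2, 1, 4)
s2 = x[:, 2:, :]                           # Slice from 2 to end along AXIS (== 1): (2, 1, 4)
split = np.split(x, [1, 2], axis=AXIS)     # three (2, 1, 4) chunks
assert all(np.array_equal(a, b) for a, b in zip([g0, g1, s2], split))

# Case 2: scalar indices drop the axis, so Split alone yields the wrong shapes.
h0 = np.take(x, 0, axis=AXIS)              # shape (2, 4)
h1 = np.take(x, 1, axis=AXIS)              # shape (2, 4)
assert h0.shape != split[0].shape

# Squeezing the first two Split outputs restores equivalence; the Slice
# output is used as-is because Slice keeps the axis.
fused = [np.squeeze(split[0], axis=AXIS), np.squeeze(split[1], axis=AXIS), split[2]]
assert all(np.array_equal(a, b) for a, b in zip([h0, h1, s2], fused))
```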

The previous PR doesn't take these Squeeze/Unsqueeze cases into account.

Ideally, a general fusion solution should not limit the number of Gather and Slice nodes. It is better to check all consumers, whether Gather or Slice: if the indices of the Gather nodes and the start/end of the Slice nodes together cover the specific dim of the input tensor, we can fuse them into a Split, removing or adding Squeeze/Unsqueeze according to the dim count (rank) of each Gather node's indices tensor.
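The coverage check described above can be sketched in plain Python. This is not ORT's transformer code: the Gather/Slice dataclasses and plan_split_fusion are hypothetical stand-ins, and the sketch assumes each Gather has a single index, each Slice has step 1 with bounds already normalized, and it only computes where a Squeeze is needed (removing an existing downstream Unsqueeze is left out).

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union


@dataclass
class Gather:
    # Hypothetical stand-in for a Gather consumer on the split axis.
    # An int is a scalar index (axis dropped); a 1-element list keeps the axis.
    indices: Union[int, List[int]]


@dataclass
class Slice:
    # Hypothetical stand-in for a Slice consumer on the split axis.
    # Assumes step 1 and bounds already normalized to 0 <= start < end <= dim.
    start: int
    end: int


Plan = Tuple[List[int], List[Tuple[int, bool]]]


def plan_split_fusion(dim: int, consumers: Sequence[Union[Gather, Slice]]) -> Optional[Plan]:
    """Return (split sizes, [(consumer position, needs Squeeze)]) in axis order,
    or None when the consumers do not exactly tile [0, dim)."""
    ranges = []
    for pos, c in enumerate(consumers):
        if isinstance(c, Gather):
            if isinstance(c.indices, list):
                if len(c.indices) != 1:
                    return None  # a multi-index Gather is not one Split chunk here
                start, squeeze = c.indices[0], False  # 1-dim indices keep the axis
            else:
                start, squeeze = c.indices, True      # scalar index drops the axis
            if start < 0:
                start += dim                           # ONNX Gather allows negative indices
            ranges.append((start, start + 1, pos, squeeze))
        else:
            ranges.append((c.start, c.end, pos, False))  # Slice keeps the axis

    # Sorted by start, the ranges must cover [0, dim) with no gaps or overlaps.
    ranges.sort()
    cursor = 0
    for start, end, _, _ in ranges:
        if start != cursor or end <= start:
            return None
        cursor = end
    if cursor != dim:
        return None

    sizes = [end - start for start, end, _, _ in ranges]
    outputs = [(pos, squeeze) for _, _, pos, squeeze in ranges]
    return sizes, outputs
```

For example, plan_split_fusion(3, [Gather(0), Gather(1), Slice(2, 3)]) returns ([1, 1, 1], [(0, True), (1, True), (2, False)]): a Split with sizes [1, 1, 1] whose first two outputs need a Squeeze. Because the ranges are sorted before the coverage check, the consumers may appear in any order and in any number.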

@rui-ren, please check whether the fix can still be applied to your model. The UT cases in your PR only cover the transformer logic; since they never execute the produced graph, they cannot catch the transformer generating a logically invalid graph. Since this fix is for ORTModule, you can add a Python UT case for ORTModule that wraps the model script in an nn.Module, so that we can guarantee the transformer produces accurate results during execution.

centwang • Feb 22 '24 08:02

> Previous PR doesn't take such case related to Squeeze/Unsqueeze into account.

Hi @vincent, thank you for your PR. I was planning to add another PR to catch this squeeze/unsqueeze scenario. In my model,

Node
|-> Gather(indices=[0], axis=axis)
|-> Gather(indices=[1], axis=axis)

As such, I only considered the first case and planned to handle the scalar vs. 1-dim Gather indices case separately in another PR. Thank you for the optimization.

rui-ren • Feb 23 '24 03:02

LGTM, thanks!

rui-ren • Feb 29 '24 05:02

BTW, I tested this PR with LoRA and QLoRA: no degradation in throughput or accuracy.

rui-ren • Feb 29 '24 05:02