Index put loop model regression with ort==1.18
Describe the issue
The error is raised starting with onnxruntime 1.18; the same model runs fine with 1.17.3.
```
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Loop node. Name:'/Loop' Status Message: Non-zero status code returned while running ScatterND node. Name:'/ScatterND_10' Status Message: invalid indice found, indice = 8
```
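For context on the message itself (not the root cause of this regression): ScatterND rejects an update whose index falls outside the data tensor along some axis. Below is a minimal NumPy sketch of ScatterND with `reduction="none"`, following the reference semantics in the ONNX operator spec, with an explicit bounds check. `scatter_nd_checked` is a hypothetical helper; it is not ORT's implementation and it does not model negative-index handling (it simply rejects negative indices).

```python
import numpy as np


def scatter_nd_checked(data, indices, updates):
    """ScatterND (reduction='none') that raises on out-of-range indices.

    Hypothetical illustration only; negative indices are treated as invalid.
    """
    output = np.copy(data)
    k = indices.shape[-1]  # number of leading data axes each index addresses
    for idx in np.ndindex(indices.shape[:-1]):
        index = indices[idx]
        for axis in range(k):
            dim = data.shape[axis]
            if not 0 <= index[axis] < dim:
                raise ValueError(f"invalid indice found, indice = {index[axis]}")
        output[tuple(index)] = updates[idx]
    return output


data = np.zeros((2, 8), dtype=np.float32)
ok = scatter_nd_checked(data, np.array([[1, 7]]), np.array([5.0], dtype=np.float32))
# Index [1, 8] would raise: axis 1 has only 8 entries (valid range 0..7).
```

In other words, an `indice = 8` error means some ScatterND inside the `/Loop` body received an index equal to the size of the axis it writes into.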
To reproduce
(1) With the uploaded ONNX file test_index_put_loop.zip (unzipped to test_index_put_loop.onnx):
```python
import onnx
import onnxruntime
import torch

# Same second input as in (2): seq_length=4, batch_size=1
y = torch.randn(4, 1)

onnx_model = onnx.load("test_index_put_loop.onnx")
ort_session = onnxruntime.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
onnxruntime_input = {
    k.name: v.numpy(force=True)
    for k, v in zip(ort_session.get_inputs(), [y])
}
ort_session.run(None, onnxruntime_input)
```
(2) From PyTorch
```python
import torch
import onnx
import onnxruntime


@torch.jit.script
def ngram_attention_bias(
    sequence_length: int, ngram: int, device: torch.device, dtype: torch.dtype
):
    bias = torch.ones(
        (ngram, sequence_length), device=device, dtype=dtype
    ) * float("-inf")
    for stream_idx in range(ngram):
        for i in range(sequence_length):
            bias = bias * 2
            bias[stream_idx, i] = 5
            bias = bias * 5
            bias[0, 0] = 5

    for stream_idx in range(ngram):
        for i in range(sequence_length):
            bias[stream_idx, i] = 5
            bias[0, i] = 5
    return bias


class ScriptModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ngram = 2
        self.max_target_positions = 512

    def forward(self, hidden_states):
        seq_length, batch_size = hidden_states.shape[:2]
        predict_causal_mask = ngram_attention_bias(
            self.max_target_positions,
            self.ngram,
            hidden_states.device,
            hidden_states.dtype,
        )
        predict_causal_mask = predict_causal_mask[:, :seq_length]
        return predict_causal_mask


x = torch.randn(6, 2)
y = torch.randn(4, 1)
torch.onnx.export(
    torch.jit.script(ScriptModel()),
    x,
    "test_index_put_loop.onnx",
    input_names=["x"],
    dynamic_axes={"x": {0: "seq_length", 1: "batch_size"}},
)

onnx_model = onnx.load("test_index_put_loop.onnx")
ort_session = onnxruntime.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
onnxruntime_input = {
    k.name: v.numpy(force=True)
    for k, v in zip(ort_session.get_inputs(), [y])
}
ort_session.run(None, onnxruntime_input)
```
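To know what the exported graph should return for `y`, here is a sketch that mirrors `ngram_attention_bias` and `ScriptModel.forward` eagerly in NumPy. The names `ngram_attention_bias_ref` and `script_model_ref` are hypothetical, and the sketch assumes float32 input with the same `ngram=2` and `max_target_positions=512` as the module above. Because the second loop writes 5 to every element, the expected result is an all-5.0 tensor of shape `(ngram, seq_length)`, whatever the first loop produced.

```python
import numpy as np


def ngram_attention_bias_ref(sequence_length, ngram, dtype=np.float32):
    bias = np.full((ngram, sequence_length), -np.inf, dtype=dtype)
    # First loop: repeated scaling overflows finite values to inf; silence that warning.
    with np.errstate(over="ignore"):
        for stream_idx in range(ngram):
            for i in range(sequence_length):
                bias = bias * 2
                bias[stream_idx, i] = 5
                bias = bias * 5
                bias[0, 0] = 5
    # Second loop: overwrites every element with 5.
    for stream_idx in range(ngram):
        for i in range(sequence_length):
            bias[stream_idx, i] = 5
            bias[0, i] = 5
    return bias


def script_model_ref(hidden_states, ngram=2, max_target_positions=512):
    seq_length = hidden_states.shape[0]
    mask = ngram_attention_bias_ref(max_target_positions, ngram, hidden_states.dtype)
    return mask[:, :seq_length]


y = np.random.randn(4, 1).astype(np.float32)
expected = script_model_ref(y)
print(expected.shape)  # (2, 4)
```

The array `expected` can then be compared with the output of `ort_session.run` from either onnxruntime version, for example with `np.testing.assert_allclose`.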
Urgency
This was spotted in a PyTorch ONNX exporter test case.
Platform
Linux
OS Version
VERSION="2.0.20240301" MARINER
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
798cea2350a196a67ff7e0621ea125c7f2035f7c
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response
This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.
This issue has been automatically closed as 'not planned' because it has been marked as 'stale' for more than 30 days without activity. If you believe this is still an issue, please feel free to reopen it.