
SplitToSequence cannot support float16 as input/output

Open xiaowuhu opened this issue 1 year ago • 8 comments

Describe the issue

The SplitToSequence operator does not support float16 as input, although the ONNX documentation says it does. This impacts the fp16 converter: because the output is a sequence, the user has to convert each element of the output sequence to fp16 one by one, which performs badly.
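As a sketch of the per-element overhead described above (this is an illustrative NumPy workaround, not code from the converter): when SplitToSequence is forced to run in float32, each tensor in the resulting sequence has to be cast back to float16 individually.

```python
import numpy as np

# Hypothetical workaround: SplitToSequence ran in float32, so every element
# of the output sequence must be cast back to float16 one by one.
seq_fp32 = [np.zeros((2, 5), dtype=np.float32),
            np.zeros((3, 5), dtype=np.float32)]
seq_fp16 = [t.astype(np.float16) for t in seq_fp32]  # one cast per element
print([str(t.dtype) for t in seq_fp16])
```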

To reproduce

  1. prepare a [5,5] input array with dtype = np.float16
  2. call op.SplitToSequence(input, dim=0, num_outputs=2)
  3. the expected output is a sequence of two tensors with shapes [2,5] and [3,5].

It works on float32.

Urgency

No response

Platform

Windows

OS Version

11

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.14

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

xiaowuhu avatar May 18 '23 08:05 xiaowuhu

The error persists in 1.15.1: onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Type Error: Type (seq(tensor(float16))) of output arg (output0) of node () does not match expected type (seq(tensor(float))).

justinchuby avatar Jul 05 '23 18:07 justinchuby

Summary

ORT raises [ONNXRuntimeError] : 1 : FAIL : Type Error: Type (seq(tensor(float16))) of output arg (_val_1) of node (_0x79a4c20_n23) does not match expected type (seq(tensor(float))). when executing test ops_test.TestOutputConsistencyFullGraphCPU.test_output_match_opinfo__chunk_cpu_float16 in ONNX Script TorchLib.

To recreate this report, use

CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__chunk_cpu_float16

To reproduce

import onnx
import onnxruntime as ort
import numpy as np
from numpy import array, float16, float32, float64, int32, int64

onnx_model_text = """
<
   ir_version: 8,
   opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18],
   producer_name: "pytorch",
   producer_version: "2.1.0"
>
torch_jit (float16[5,5,5] input_0) => (seq(float16[5,unk__4,5]) _val_1) {
   _val_1 = pkg.onnxscript.torch_lib.aten_chunk <chunks = 5, dim = 1> (input_0)
}
<
  domain: "pkg.onnxscript.torch_lib",
  opset_import: ["" : 18]
>
aten_chunk <chunks, dim>(self) => (return_val)
{
   neg_1 = Constant <value_ints = [-1]> ()
   self_shape = Shape (self)
   dim = Constant <value_int: int = @dim> ()
   dim_size = Gather <axis = 0> (self_shape, dim)
   chunks = Constant <value_int: int = @chunks> ()
   chunks_cast = CastLike (chunks, dim_size)
   num_per_chunk = Div (dim_size, chunks_cast)
   chunks_0 = Constant <value_int: int = @chunks> ()
   chunks_0_cast = CastLike (chunks_0, dim_size)
   tmp = Mod (dim_size, chunks_0_cast)
   int64_0 = Constant <value = int64 int64_0 {0}> ()
   int64_0_cast = CastLike (int64_0, tmp)
   tmp_1 = Greater (tmp, int64_0_cast)
   tmp_2 = Cast <to = 7> (tmp_1)
   num_per_chunk_3 = Add (tmp_2, num_per_chunk)
   num_chunk = Div (dim_size, num_per_chunk_3)
   tmp_4 = Reshape (num_chunk, neg_1)
   list_split = Expand (num_per_chunk_3, tmp_4)
   remainder = Mod (dim_size, num_per_chunk_3)
   int64_0_5 = Constant <value = int64 int64_0_5 {0}> ()
   int64_0_5_cast = CastLike (int64_0_5, remainder)
   cond = Greater (remainder, int64_0_5_cast)
   list_split_9 = If (cond) <then_branch = thenGraph_19 () => ( list_split_7) {
      tmp_6 = Reshape (remainder, neg_1)
      list_split_7 = Concat <axis = 0> (list_split, tmp_6)
   }, else_branch = elseGraph_19 () => ( list_split_8) {
      list_split_8 = Identity (list_split)
   }>
   return_val = SplitToSequence <axis: int = @dim> (self, list_split_9)
}
"""

ort_inputs = {'input_0': array([[[-2.047  ,  8.09   , -7.1    , -2.293  , -7.355  ],
        [-5.668  , -4.676  ,  0.3516 ,  1.371  , -8.875  ],
        [ 7.137  , -8.44   ,  7.523  ,  7.367  , -4.43   ],
        [ 3.016  ,  1.125  ,  8.81   , -3.312  ,  4.14   ],
        [ 0.545  ,  1.213  ,  4.375  , -3.797  , -5.562  ]],

       [[ 5.92   , -5.33   , -6.47   , -5.68   ,  6.785  ],
        [ 4.297  , -6.977  , -0.06152, -8.65   ,  0.2373 ],
        [-7.82   , -7.242  ,  7.375  , -2.152  , -0.835  ],
        [ 0.2812 ,  0.413  , -4.586  , -5.43   ,  5.035  ],
        [ 6.39   , -1.934  , -8.14   , -2.996  ,  7.656  ]],

       [[-2.629  ,  8.664  ,  4.797  , -0.5625 ,  6.484  ],
        [-3.621  ,  8.28   ,  4.05   , -3.357  ,  6.75   ],
        [ 0.03516,  1.907  , -4.586  , -2.268  , -5.51   ],
        [-1.354  , -2.021  , -5.555  , -7.188  ,  0.12305],
        [ 7.53   ,  4.86   , -1.169  ,  4.043  ,  6.062  ]],

       [[ 5.047  , -1.415  , -2.479  ,  2.11   , -4.05   ],
        [-0.03516, -8.164  ,  7.902  ,  6.44   , -4.746  ],
        [ 3.568  , -6.977  , -2.426  ,  5.75   ,  1.494  ],
        [-4.254  ,  0.3076 , -4.395  ,  1.397  ,  7.367  ],
        [ 6.133  ,  2.127  ,  0.747  ,  6.99   , -3.93   ]],

       [[-3.016  , -4.113  ,  5.035  ,  8.37   , -0.10547],
        [-3.332  , -2.012  ,  0.2373 ,  5.44   ,  4.88   ],
        [ 2.98   ,  7.004  ,  8.414  ,  2.9    ,  5.617  ],
        [ 4.438  ,  8.914  ,  3.05   ,  4.15   ,  2.021  ],
        [-7.34   ,  1.969  ,  3.375  ,  3.305  ,  2.479  ]]],
      dtype=float16)}

session_options = ort.SessionOptions()
session_options.graph_optimization_level = (
    ort.GraphOptimizationLevel.ORT_DISABLE_ALL
)
onnx_model = onnx.parser.parse_model(onnx_model_text)

session = ort.InferenceSession(
    onnx_model.SerializeToString(), session_options, providers=("CPUExecutionProvider",)
)
ort_outputs = session.run(None, ort_inputs)

Full error stack

Traceback (most recent call last):
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test_common.py", line 533, in _capture_graph_and_evaluate_torch_script_evaluator
    return _safe_ort_session_run(onnx_model.SerializeToString(), ort_inputs)
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test_common.py", line 349, in _safe_ort_session_run
    raise return_dict["error"]
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Type Error: Type (seq(tensor(float16))) of output arg (_val_1) of node (_0x79a4c20_n23) does not match expected type (seq(tensor(float))).

justinchuby avatar Aug 03 '23 05:08 justinchuby

https://github.com/microsoft/onnxruntime/pull/17117 is trying to fix this.

centwang avatar Aug 11 '23 09:08 centwang

Hi, I am facing the same issue for models using torch.repeat_interleave.

Edit: the issue is actually "fixed" upstream (no more SplitToSequence) in the export in pytorch 2.1, thanks to https://github.com/pytorch/pytorch/pull/100575

fxmarty avatar Oct 16 '23 11:10 fxmarty

Fixed upstream. Closing issue. Thanks all.

MaanavD avatar Mar 27 '24 21:03 MaanavD

@MaanavD we prob should keep any ort issues tagged with "dynamo" open

justinchuby avatar Mar 27 '24 21:03 justinchuby

@MaanavD we prob should keep any ort issues tagged with "dynamo" open

Was the issue fixed? If so, it is safe to close. Otherwise, we need to figure out whether this is an ORT or TorchLib issue and work to fix it.

thiagocrepaldi avatar May 01 '24 15:05 thiagocrepaldi

This is an ORT issue. We can run the repro script again to validate with the latest version.

justinchuby avatar May 01 '24 17:05 justinchuby