
Importing model with dynamic batch size fails

Open mortont opened this issue 2 years ago • 2 comments

I'm attempting to import a model with a dynamic batch size (a modified export of a Whisper encoder) and I'm getting the following error:

12:08:45.008 [warning] mel_spectrogram_dynamic_axes_1 has no specified dimension, assuming nil
** (ArgumentError) invalid dimension in axis 0 found in shape. Each dimension must be a positive integer, got nil in shape {nil, 80, 3000}
    (nx 0.4.0) lib/nx/shape.ex:60: Nx.Shape.validate!/3
    (nx 0.4.0) lib/nx.ex:2819: Nx.broadcast/3
    (axon_onnx 0.3.0) lib/axon_onnx/deserialize.ex:1402: anonymous fn/1 in AxonOnnx.Deserialize.recur_nodes/2
    (elixir 1.14.0) lib/map.ex:258: Map.do_map/2
    (elixir 1.14.0) lib/map.ex:252: Map.new_from_map/2
    (axon_onnx 0.3.0) lib/axon_onnx/deserialize.ex:1402: AxonOnnx.Deserialize.recur_nodes/2
    (elixir 1.14.0) lib/enum.ex:2468: Enum."-reduce/3-lists^foldl/2-0-"/3
    iex:5: (file)

I've been able to run it in onnxruntime, and it validates correctly with onnx.checker.check_model in Python, so is there something I'm missing? The input signature of the model is:

onnx.load('./tiny.en/encoder.onnx').graph.input
[name: "mel_spectrogram"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "mel_spectrogram_dynamic_axes_1"
      }
      dim {
        dim_value: 80
      }
      dim {
        dim_value: 3000
      }
    }
  }
}
]

mortont • Dec 02 '22 17:12

If you try importing but specify mel_spectrogram_dynamic_axes_1: 1, does the import work?

AxonOnnx.import("model.onnx", mel_spectrogram_dynamic_axes_1: 1)

For some deserializations to work, we need to know shapes up front. If this does not work, can you tell me where to get the model so I can debug it?
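
Alternatively, the dynamic axis can be pinned in the ONNX file itself before importing. Untested, but roughly something like this with the onnx Python package (paths taken from your snippet above):

import onnx

model = onnx.load("./tiny.en/encoder.onnx")

# The first dim of the input is the dynamic one (dim_param set, dim_value unset);
# overwrite it with a concrete batch size.
dim = model.graph.input[0].type.tensor_type.shape.dim[0]
dim.ClearField("dim_param")
dim.dim_value = 1

onnx.checker.check_model(model)
onnx.save(model, "./tiny.en/encoder_static.onnx")

Importing the re-saved file should then see a fully static {1, 80, 3000} input shape.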

seanmor5 • Dec 07 '22 10:12

If I specify the axis dimension or re-export the ONNX model without a dynamic dimension, it no longer fails in that way, but it does fail on another op (I have re-parametrized the axis labels, but it's the same model as before):

iex(1)> {model, params} = AxonOnnx.import("whisper-onnx/model.onnx", batch: 1, feature_size: 1, encoder_sequence: 1, decoder_sequence: 1)
** (ArgumentError) unable to build model from ONNX graph, expected value /model/encoder/layers.0/self_attn/Concat_output_0 to be constant value, but was :concatenate
    (axon_onnx 0.3.0) lib/axon_onnx/deserialize.ex:2286: AxonOnnx.Deserialize.constant!/4
    (axon_onnx 0.3.0) lib/axon_onnx/deserialize.ex:1201: AxonOnnx.Deserialize.recur_nodes/2
    (elixir 1.14.0) lib/enum.ex:2468: Enum."-reduce/3-lists^foldl/2-0-"/3
    (axon_onnx 0.3.0) lib/axon_onnx/deserialize.ex:44: AxonOnnx.Deserialize.graph_to_axon/2
    (axon_onnx 0.3.0) lib/axon_onnx/deserialize.ex:27: AxonOnnx.Deserialize.to_axon/2
    iex:1: (file)
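
For reference, the node that produces that value can be located with the onnx Python package (quick inspection sketch, path as in the iex call above):

import onnx

model = onnx.load("whisper-onnx/model.onnx")
target = "/model/encoder/layers.0/self_attn/Concat_output_0"

# Find the node whose output is the value axon_onnx expected to be a constant
for node in model.graph.node:
    if target in node.output:
        print(node.op_type, node.name, list(node.input))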

The model is an encoder-decoder transformer (Whisper tiny.en). The ONNX model is generated with:

python -m transformers.onnx --model=openai/whisper-tiny.en --feature=speech2seq-lm whisper-onnx/
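
The axis names passed to AxonOnnx.import above (batch, feature_size, encoder_sequence, decoder_sequence) are the dynamic axis labels on the exported graph inputs; they can be listed with something like this (rough sketch):

import onnx

model = onnx.load("whisper-onnx/model.onnx")
for inp in model.graph.input:
    # dim_param holds the symbolic axis name, dim_value the fixed size
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)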

mortont • Dec 08 '22 17:12