
OnnxImportException: Unexpected error while parsing layer Unsqueeze__2283:0 of type Unsqueeze when importing models trained with JAX

Open: hayden-donnelly opened this issue 1 year ago • 1 comment

The models that produce this error are trained with JAX and then converted to a TF SavedModel in one of two ways:

Conversion with polymorphic shapes (ONNX models derived from this version make Netron hang, so this conversion may not be correct):

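import tensorflow as tf
from jax.experimental import jax2tf

# `state` (trained model state with .params and .apply_fn) and
# image_width, image_height, channels come from the training script.
tf_module = tf.Module()
# Wrap the JAX parameters in tf.Variables so they are saved with the module.
state_vars = tf.nest.map_structure(tf.Variable, state.params)
tf_module.vars = tf.nest.flatten(state_vars)
predict_fn = jax2tf.convert(
    state.apply_fn,
    enable_xla=False,  # lower to standard TF ops rather than XLA-specific ones
    # "...": parameter shapes are static; images and noise share batch dim b.
    polymorphic_shapes=["...", ("b, 256, 256, 1", "b, 1, 1, 1")]
)

# Two positional arguments with an unknown (None) batch dimension.
input_signature = [
    tf.TensorSpec(shape=(None, image_width, image_height, channels), dtype=tf.float32),
    tf.TensorSpec(shape=(None, 1, 1, 1), dtype=tf.float32)
]

@tf.function(autograph=False, input_signature=input_signature)
def predict(images, noise_variances):
    return predict_fn({'params': state_vars}, (images, noise_variances))

tf_module.predict = predict
tf.saved_model.save(tf_module, "./data/temp/saved_model")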
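The polymorphic_shapes entry only makes sense if it agrees with input_signature: the constant dimensions 256, 256 and 1 line up only when image_width, image_height and channels are 256, 256 and 1, and the batch dimension left as None in the signature should be the symbolic b. The following is a rough sketch of that consistency check; check_poly_shape is a hypothetical helper, not part of jax2tf, and it handles only plain comma-separated dimension lists (no "..." and no expressions such as 2*b).

```python
from typing import List, Optional, Sequence


def check_poly_shape(poly: str, tf_shape: Sequence[Optional[int]]) -> List[str]:
    """Hypothetical helper: compare a jax2tf-style polymorphic shape string
    such as "b, 256, 256, 1" with a TensorSpec shape such as (None, 256, 256, 1).

    Returns a list of human-readable problems (empty if none were found).
    Only plain dimension lists are handled: no "...", no expressions like "2*b".
    """
    dims = [d.strip() for d in poly.split(",")]
    if len(dims) != len(tf_shape):
        return [f"rank mismatch: {len(dims)} polymorphic dims vs {len(tf_shape)} TF dims"]

    problems = []
    for axis, (spec, actual) in enumerate(zip(dims, tf_shape)):
        if spec.isdigit():
            # Constant dimension: the signature should know it and agree with it.
            if actual is None:
                problems.append(f"axis {axis}: constant {spec} but TF dim is None")
            elif int(spec) != actual:
                problems.append(f"axis {axis}: constant {spec} but TF dim is {actual}")
        elif not spec.isidentifier():
            problems.append(f"axis {axis}: unsupported spec {spec!r}")
        # A symbolic dimension (e.g. "b") is accepted for None or a concrete size.
    return problems


# The shapes from the first conversion, assuming 256x256 single-channel images.
print(check_poly_shape("b, 256, 256, 1", (None, 256, 256, 1)))  # []
print(check_poly_shape("b, 1, 1, 1", (None, 1, 1, 1)))          # []
print(check_poly_shape("b, 256, 256, 1", (None, 128, 128, 1)))
```

The check is deliberately conservative: it flags a constant polymorphic dimension paired with a None signature dimension, which is the kind of mismatch jax2tf is expected to reject when tracing.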

Conversion without polymorphic shapes:

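import tensorflow as tf
from jax.experimental import jax2tf

tf_module = tf.Module()
# Wrap the JAX parameters in tf.Variables so they are saved with the module.
state_vars = tf.nest.map_structure(tf.Variable, state.params)
tf_module.vars = tf.nest.flatten(state_vars)
# Default jax2tf options, no polymorphic shapes: every shape is static.
predict_fn = jax2tf.convert(state.apply_fn)

# A single argument: a tuple (images, noise_variances) with batch size fixed to 1.
input_signature = [(
    tf.TensorSpec(shape=(1, image_width, image_height, channels), dtype=tf.float32),
    tf.TensorSpec(shape=(1, 1, 1, 1), dtype=tf.float32)
)]

@tf.function(autograph=False, input_signature=input_signature)
def predict(data):
    return predict_fn({'params': state_vars}, data)

tf_module.predict = predict
tf.saved_model.save(tf_module, "./data/temp/saved_model")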

Both conversions ultimately lead to the same error.

Once the models are in TF SavedModel format, they are converted to ONNX with the following tf2onnx command:

python3.9 -m tf2onnx.convert --saved-model ./saved_model --output ./pix_onnx_test.onnx --opset 15
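The --opset 15 flag matters for the Unsqueeze node in the error below. In the ONNX operator history, Unsqueeze-1 and Unsqueeze-11 take axes as a required attribute, while Unsqueeze-13 and later take it as a second input tensor, so a model exported at opset 15 is expected to use the input form. The sketch below just encodes that rule; unsqueeze_axes_location is a hypothetical helper, not part of tf2onnx or onnx. Whether re-running tf2onnx with a lower value such as --opset 12 gets the model past the Unity importer is an untested assumption, since other operators also change between opsets.

```python
def unsqueeze_axes_location(opset: int) -> str:
    """Return "attribute" or "input": where an Unsqueeze node keeps `axes`
    in the given default-domain ONNX opset."""
    if opset < 1:
        raise ValueError(f"invalid opset: {opset}")
    # Unsqueeze-1 and Unsqueeze-11: required INTS attribute named "axes".
    # Unsqueeze-13 onward: axes is the second (int64 tensor) input.
    return "input" if opset >= 13 else "attribute"


for opset in (11, 12, 13, 15):
    print(opset, unsqueeze_axes_location(opset))
```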

Finally, importing the ONNX model into Unity (Barracuda) produces the following error:

OnnxImportException: Unexpected error while parsing layer Unsqueeze__2283:0 of type Unsqueeze.
Couldn't find attribute axes of type Ints

Json: { "input": [ "StatefulPartitionedCall/jax2tf_apply_/DDIM/ResidualBlock_1/pjit_fn_/AddV2:0", "einsum139953479970976_ba_left_set__2179" ], "output": [ "Unsqueeze__2283:0" ], "name": "Unsqueeze__2283", "opType": "Unsqueeze" }
  at Unity.Barracuda.ONNX.ONNXNodeWrapper.FindAttribute (System.String name, Onnx.AttributeProto+Types+AttributeType type) [0x00010] in C:\Users\Hayden\Desktop\ntg-unity\Library\PackageCache\[email protected]\Barracuda\Runtime\ONNX\ONNXNodeWrapper.cs:281 
  at Unity.Barracuda.ONNX.ONNXNodeWrapper.GetRequiredIntArray (System.String name) [0x00000] in C:\Users\Hayden\Desktop\ntg-unity\Library\PackageCache\[email protected]\Barracuda\Runtime\ONNX\ONNXNodeWrapper.cs:329 
  at Unity.Barracuda.ONNX.ONNXNodeWrapper.get_Axes () [0x00000] in C:\Users\Hayden\Desktop\ntg-unity\Library\PackageCache\[email protected]\Barracuda\Runtime\ONNX\ONNXNodeWrapper.cs:107 
  at Unity.Barracuda.ONNX.ONNXModelConverter+<>c__DisplayClass23_0.<UseStandardImporter>b__5 (Unity.Barracuda.ModelBuilder net, Unity.Barracuda.ONNX.ONNXNodeWrapper node) [0x0005f] in C:\Users\Hayden\Desktop\ntg-unity\Library\PackageCache\[email protected]\Barracuda\Runtime\ONNX\ONNXModelConverter.cs:306 
  at Unity.Barracuda.ONNX.ONNXModelConverter.ConvertOnnxModel (Onnx.ModelProto onnxModel) [0x003a7] in C:\Users\Hayden\Desktop\ntg-unity\Library\PackageCache\[email protected]\Barracuda\Runtime\ONNX\ONNXModelConverter.cs:2798 

Unity.Barracuda.ONNX.ONNXModelConverter.Err (Unity.Barracuda.Model model, System.String layerName, System.String message, System.String extendedMessage, System.String debugMessage) (at Library/PackageCache/[email protected]/Barracuda/Runtime/ONNX/ONNXModelConverter.cs:3298)
Unity.Barracuda.ONNX.ONNXModelConverter.ConvertOnnxModel (Onnx.ModelProto onnxModel) (at Library/PackageCache/[email protected]/Barracuda/Runtime/ONNX/ONNXModelConverter.cs:2807)
Unity.Barracuda.ONNX.ONNXModelConverter.Convert (Google.Protobuf.CodedInputStream inputStream) (at Library/PackageCache/[email protected]/Barracuda/Runtime/ONNX/ONNXModelConverter.cs:155)
Unity.Barracuda.ONNX.ONNXModelConverter.Convert (System.String filePath) (at Library/PackageCache/[email protected]/Barracuda/Runtime/ONNX/ONNXModelConverter.cs:83)
Unity.Barracuda.ONNXModelImporter.OnImportAsset (UnityEditor.AssetImporters.AssetImportContext ctx) (at Library/PackageCache/[email protected]/Barracuda/Editor/ONNXModelImporter.cs:58)
UnityEditor.AssetImporters.ScriptedImporter.GenerateAssetData (UnityEditor.AssetImporters.AssetImportContext ctx) (at <36f62d8e760b48f7af5d32916f997ce1>:0)
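The JSON dump in the log shows why the lookup fails: the node has two inputs (the AddV2 output and einsum139953479970976_ba_left_set__2179, presumably the axes tensor) and no attribute entry, while the stack trace shows the importer calling GetRequiredIntArray for an attribute named axes. A small stdlib-only sketch that inspects such a dump; diagnose_unsqueeze and its messages are hypothetical, and it assumes the dump follows protobuf JSON conventions where empty repeated fields are omitted.

```python
import json

# The node dump printed by the Barracuda importer (copied from the log above).
NODE_JSON = (
    '{ "input": [ "StatefulPartitionedCall/jax2tf_apply_/DDIM/ResidualBlock_1/pjit_fn_/AddV2:0", '
    '"einsum139953479970976_ba_left_set__2179" ], "output": [ "Unsqueeze__2283:0" ], '
    '"name": "Unsqueeze__2283", "opType": "Unsqueeze" }'
)


def diagnose_unsqueeze(node_json: str) -> str:
    """Hypothetical helper: say whether an Unsqueeze node dump carries `axes`
    as an attribute (opset < 13 style) or as a second input (opset >= 13 style)."""
    node = json.loads(node_json)
    if node.get("opType") != "Unsqueeze":
        return "not an Unsqueeze node"
    # Protobuf JSON omits empty repeated fields, so a missing key means "none".
    attr_names = [a.get("name") for a in node.get("attribute", [])]
    inputs = node.get("input", [])
    if "axes" in attr_names:
        return "axes attribute present (opset < 13 style)"
    if len(inputs) == 2:
        return f"axes passed as input {inputs[1]!r} (opset >= 13 style)"
    return "no axes found"


print(diagnose_unsqueeze(NODE_JSON))
# axes passed as input 'einsum139953479970976_ba_left_set__2179' (opset >= 13 style)
```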

hayden-donnelly, Sep 05 '23 02:09