Err value: IrError(OutputNodeNotFound("/linear/MatMul_output_0")) on linear model

Open maxwellflitton opened this issue 1 year ago • 8 comments

Describe the bug I've trained a simple linear model in PyTorch and exported it to ONNX. Running it with the ONNX library works fine. However, when I try to run it from wonnx I get the error Err value: IrError(OutputNodeNotFound("/linear/MatMul_output_0")). Looking at the model in Netron, everything seems to make sense, and in my session config I set the output name to 5 because that is the name of the model's output. I don't know why wonnx errors here when ONNX works fine.

To Reproduce Train a simple linear model with the following code:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

squarefoot = np.array([1000, 1200, 1500, 1800, 2000, 2200, 2500, 2800, 3000, 3200], dtype=np.float32)
num_floors = np.array([1, 1, 1.5, 1.5, 2, 2, 2.5, 2.5, 3, 3], dtype=np.float32)
house_price = np.array([200000, 230000, 280000, 320000, 350000, 380000, 420000, 470000, 500000, 520000], dtype=np.float32)

squarefoot_mean = squarefoot.mean()
squarefoot_std = squarefoot.std()
num_floors_mean = num_floors.mean()
num_floors_std = num_floors.std()
house_price_mean = house_price.mean()
house_price_std = house_price.std()

# Normalize the data (optional, but recommended for better convergence)
squarefoot = (squarefoot - squarefoot.mean()) / squarefoot.std()
num_floors = (num_floors - num_floors.mean()) / num_floors.std()
house_price = (house_price - house_price.mean()) / house_price.std()

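# Convert numpy arrays to PyTorch tensors
squarefoot_tensor = torch.from_numpy(squarefoot)
num_floors_tensor = torch.from_numpy(num_floors)
house_price_tensor = torch.from_numpy(house_price)

# X was not defined in the original snippet; stacking the two features
# into a (10, 2) input matrix is assumed here so the training loop runs
X = torch.stack([squarefoot_tensor, num_floors_tensor], dim=1)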


# Define the linear regression model
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(2, 1)  # 2 input features, 1 output

    def forward(self, x):
        return self.linear(x)

# Initialize the model
model = LinearRegressionModel()

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Training loop
num_epochs = 1000
for epoch in range(num_epochs):
    # Forward pass
    y_pred = model(X)

    # Compute the loss
    loss = criterion(y_pred.squeeze(), house_price_tensor)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print the progress
    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

test_squarefoot = torch.tensor([2800, 3200], dtype=torch.float32)
test_num_floors = torch.tensor([2.5, 3], dtype=torch.float32)
test_inputs = torch.stack([test_squarefoot, test_num_floors], dim=1)
test_inputs = torch.tensor([2800, 3], dtype=torch.float32)

# Test the model
with torch.no_grad():
    predicted_prices = model(test_inputs)
    predicted_prices = predicted_prices.squeeze().numpy()
    print("Predicted Prices:", predicted_prices)

I then perform an onnx export with the following code:

# export to ONNX and save file
torch.onnx.export(model, test_inputs, "./linear_test.onnx")

I then load the model in rust with the following code:

use std::collections::HashMap;
use ndarray::{ArrayD, CowArray};
use std::sync::Arc;
use wonnx::Session;
use wonnx::utils::{InputTensor, OutputTensor, tensor};
use wonnx::SessionConfig;

use std::fs::File;
use std::io::{Read, Result};

pub async fn load_model() {
    let mut file = File::open("./linear_test.onnx").unwrap();

    let mut buffer = Vec::new();

    file.read_to_end(&mut buffer).unwrap();
    let config = SessionConfig::new().with_outputs(Some(vec!["5".to_string()]));
    let session = Session::from_bytes_with_config(&buffer, &config).await.unwrap();
    let mut inputs = HashMap::new();
    inputs.insert("onnx::MatMul_0".to_string(), InputTensor::F32(vec![1000.0, 2.0].into()));
    let outputs = session.run(&inputs).await.unwrap();
    println!("file: {:?}", outputs);
}

and I get the error with the following line:

let session = Session::from_bytes_with_config(&buffer, &config).await.unwrap();

Expected behavior Merely to run a simple inference

Screenshots When inspecting the onnx file all the weights seem to match up or am I missing something here?

[Four screenshots of the model inspection, taken 2023-10-03 at 15:14:16, 15:14:29, 15:14:40 and 15:15:06]

Desktop

  • OS: macOS (Ventura V13.4)
  • Chip: Apple M2 Max
  • RAM: 96GB
  • Hard Drive: 3.62TB available of 4TB
  • model: 16-inch 2023

maxwellflitton avatar Oct 03 '23 14:10 maxwellflitton

Is there any update on this? Can anyone help? I've tried it again and I'm still getting the same error.

maxwellflitton avatar Jan 13 '24 02:01 maxwellflitton

@maxwellflitton the error indicates an output from an earlier node cannot be found in the ONNX file. Can you share your ONNX file?

pixelspark avatar Jan 13 '24 15:01 pixelspark

I have a similar issue. Here is the onnx file elections.zip (based on https://huggingface.github.io/candle/training/simplified.html)

In my case, it's IrError(OutputNodeNotFound("/ln1/Gemm_output_0")).

notdanilo avatar Mar 12 '24 02:03 notdanilo

I did some debugging. The error is coming from here:

https://github.com/webonnx/wonnx/blob/7880ed8e6d95857e731341bd022b5e2eb8d1bb75/wonnx/src/ir.rs#L24-L26

The shapes are acquired here:

https://github.com/webonnx/wonnx/blob/7880ed8e6d95857e731341bd022b5e2eb8d1bb75/wonnx/src/ir.rs#L198-L208

But I am assuming our ONNX files don't have definitions for them because they need to be inferred. Is that assumption correct? I don't know the ONNX standard.

notdanilo avatar Mar 12 '24 02:03 notdanilo

The onnx file needs to be pre-processed to infer shapes and save them back in the file. @maxwellflitton https://github.com/webonnx/wonnx?tab=readme-ov-file#shape-inference
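For anyone who wants to check this from Python before reaching for nnx, here is a minimal sketch. It assumes the standard onnx package's shape inference writes the kind of shape entries wonnx is looking for, which I have not verified against wonnx itself. The helper names (names_without_shape, infer_and_save) and the output path are made up for illustration.

```python
def names_without_shape(graph):
    """List node outputs that have no shape entry anywhere in the graph."""
    known = set()
    # Graph inputs, outputs, value_info and initializers all carry a name.
    for collection in (graph.input, graph.output, graph.value_info, graph.initializer):
        known.update(entry.name for entry in collection)
    return [
        name
        for node in graph.node
        for name in node.output
        if name and name not in known
    ]


def infer_and_save(src_path, dst_path):
    """Run ONNX shape inference and write the annotated model to dst_path."""
    import onnx  # third-party: pip install onnx
    from onnx import shape_inference

    model = onnx.load(src_path)
    before = names_without_shape(model.graph)
    inferred = shape_inference.infer_shapes(model)
    onnx.save(inferred, dst_path)
    after = names_without_shape(inferred.graph)
    return before, after


# Hypothetical usage with the file from this issue:
# before, after = infer_and_save("linear_test.onnx", "linear_test_prepared.onnx")
# print("missing before:", before)
# print("missing after:", after)
```

If the first list contains the name from your error message and the second list is empty, the prepared file is the one to load in wonnx.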

notdanilo avatar Mar 12 '24 04:03 notdanilo

@notdanilo I am facing the same issue getting SessionError(IrError(OutputNodeNotFound("onnx::Reshape_1988"))) with a model exported using pytorch 1.12

When I try to infer shapes and save them back to a file:

nnx prepare my_model.onnx my_model_prepared.onnx --set batch_size=1 --set sequence_length=255 -i

I get:

Error: Could not infer shapes: unsupported: Split

Any ideas why this is happening?

Dainerx avatar Mar 12 '24 07:03 Dainerx

@Dainerx

Split shape inference isn't supported in nnx. https://github.com/webonnx/wonnx?tab=readme-ov-file#supported-operators-ref-onnx-ir

Try this:

pip install onnx-simplifier
python -m onnxsim <input.onnx> <output.onnx>
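
To see up front which operators a model uses, so they can be compared with the supported operators list linked above, something like this rough sketch may help. The function names are made up, and the supported set is something you have to fill in yourself from the README; the one in the usage comment is only a placeholder, not wonnx's actual list.

```python
def unsupported_ops(graph, supported):
    """Return the sorted op types used in graph that are not in supported."""
    used = {node.op_type for node in graph.node}
    return sorted(used - set(supported))


def check_model(path, supported):
    """Load an ONNX file and report op types missing from supported."""
    import onnx  # third-party: pip install onnx

    return unsupported_ops(onnx.load(path).graph, supported)


# Hypothetical usage; the set below is a placeholder, copy the real
# list from the wonnx README before relying on the result:
# print(check_model("my_model.onnx", {"MatMul", "Add", "Gemm"}))
```

Running it on both the original and the simplified file shows whether onnxsim actually removed the operators that were causing trouble.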

notdanilo avatar Mar 12 '24 07:03 notdanilo

I'm running into a similar issue with the Equal and Not operations after running the model through onnxsim. It looks like they are unsupported by nnx.

Any pointers on where to get started if I wanted to send a PR? I know where this exists in the wonnx codebase, was just hoping there might be a reference that guided that implementation for the other ops

astnmsn avatar Mar 27 '24 02:03 astnmsn