Can't run a single linear layer

Open Ryul0rd opened this issue 2 years ago • 8 comments

Describe the bug

I try to export a single linear layer from PyTorch and get one of the following errors.

Error 1: GpuError(CompileError { node: "Gemm_0", error: InvalidInputShape { input_index: 1, input_shape: Shape { dims: [10, 784], data_type: F32 } } })

Error 2: IrError(OutputNodeNotFound("onnx::Add_4"))

I viewed the resulting ONNX file at netron.app and it appeared to be correct.

To Reproduce

  1. Run the following script
import torch

torch_model = torch.nn.Linear(784, 10)
model_input = torch.zeros((1, 784))    # This results in error 1. Changing shape to (784,) results in error 2
torch.onnx.export(torch_model,               # model being run
                  model_input,               # model input (or a tuple for multiple inputs)
                  "onnx/model.onnx",         # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=11,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'])   # the model's output names
  2. Optionally run onnx-simplifier, but it doesn't do anything on such a simple model.
  3. Run the following Rust program
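use std::path::Path;

fn main() {
    // Assumption: pollster drives the future; the original snippet did not show this part.
    #[cfg(not(target_arch = "wasm32"))]
    pollster::block_on(run());
}

async fn run() {
    let model_path = Path::new("onnx/model.onnx");
    let _session = wonnx::Session::from_path(model_path).await.unwrap();
}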

Expected behavior

The model should load successfully.

Desktop

PopOS 20.04

Ryul0rd avatar Sep 10 '22 05:09 Ryul0rd

Can you post the onnx file here?

pixelspark avatar Sep 10 '22 09:09 pixelspark

The 1 and 2 in the file names correspond to the 2 different errors mentioned above. linear.zip

Ryul0rd avatar Sep 10 '22 09:09 Ryul0rd

@pixelspark Any luck replicating?

Ryul0rd avatar Sep 17 '22 18:09 Ryul0rd

As for the first problem, long story short: WONNX does not implement Gemm when transA/transB!=0.

The error seems to be the result of this check:

let dim_n = output_shape.dim(1);
if dim_n != input_right_shape.dim(1) || dim_k != input_right_shape.dim(0) {
    return Err(CompileError::InvalidInputShape {
        input_index: 1,
        input_shape: input_right_shape.clone(),
    });
}

This check follows from the spec. Note that when transA or transB is non-zero (as it is in your example), the constraints are transposed. WONNX should throw an error earlier if it detects non-zero transA or transB, but for some reason it doesn't in this case (according to Netron, transB==1).
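To make the shape constraint concrete, here is a minimal Python sketch of the Gemm shape rule with the transpose flags taken into account. The helper name `gemm_output_shape` is made up for illustration and is not part of WONNX or ONNX; it only mirrors the rule that Gemm multiplies A' by B', where each operand is transposed when its flag is non-zero.

```python
def gemm_output_shape(a_shape, b_shape, trans_a=0, trans_b=0):
    """Return (M, N) for an ONNX-style Gemm, or raise ValueError on a mismatch.

    Hypothetical helper for illustration only.
    """
    # Gemm works on A' and B', where A' = A^T if transA is non-zero (same for B).
    m, k = (a_shape[1], a_shape[0]) if trans_a else a_shape
    k_b, n = (b_shape[1], b_shape[0]) if trans_b else b_shape
    if k != k_b:
        raise ValueError(f"inner dimensions differ: {k} vs {k_b}")
    return (m, n)


# Shapes from the report: the input is [1, 784], the exported weight is [10, 784].
print(gemm_output_shape((1, 784), (10, 784), trans_b=1))

# Ignoring transB gives the same kind of mismatch as error 1.
try:
    gemm_output_shape((1, 784), (10, 784))
except ValueError as err:
    print("rejected:", err)
```

With transB=1 the weight is read as [784, 10] and the shapes line up; read untransposed, the inner dimensions are 784 and 10, which is the situation the check above rejects.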

pixelspark avatar Sep 18 '22 18:09 pixelspark

As for the second error: it seems something is just wrong with the ONNX file. Netron also has difficulty finding the type for onnx::Add_4 input for the Add operation (notice in the below screenshot there is no '+' to the right of onnx::Add_4 in the right bar):


IIUC the name onnx::Add_4 is not valid anyway (it needs to be a valid C90 identifier). Something else to check is whether all nodes are in DAG order in the ONNX file.
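Both checks mentioned here (identifier syntax and DAG order) can be sketched in a few lines of Python. This is only an illustration on a simplified node list, not a reader for real ONNX files: the tuple layout, the function name `check_graph`, and the example graph are all assumptions, and the example graph is not the content of the attached model.

```python
import re

# C-style identifier: a letter or underscore, then letters, digits, or underscores.
C_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def check_graph(nodes, graph_inputs, initializers):
    """Return a list of problems found in a simplified node list.

    `nodes` is a list of (op_type, input_names, output_names) tuples in file order.
    Hypothetical helper for illustration only.
    """
    problems = []
    available = set(graph_inputs) | set(initializers)
    for op_type, inputs, outputs in nodes:
        for name in inputs:
            if C_IDENTIFIER.fullmatch(name) is None:
                problems.append(f"{op_type}: '{name}' is not a C identifier")
            if name not in available:
                problems.append(f"{op_type}: '{name}' is not defined before use")
        # Outputs become usable only by nodes that appear later in the file.
        available.update(outputs)
    return problems


# Made-up graph: an Add node whose first input is never produced by any node.
nodes = [("Add", ["onnx::Add_4", "bias"], ["output"])]
for problem in check_graph(nodes, graph_inputs=["input"], initializers=["bias"]):
    print(problem)
```

On this made-up graph the sketch reports both an invalid name and an input that nothing defines, which is the kind of situation an OutputNodeNotFound error would point to.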

pixelspark avatar Sep 18 '22 18:09 pixelspark

@pixelspark Thanks for the explanation, but I'm still not sure how to resolve the issue. Do you have any idea why PyTorch might be exporting these incorrectly? Is there a configuration I have wrong in my exporter script, for example?

Ryul0rd avatar Sep 19 '22 06:09 Ryul0rd

I am having the same issue, any ideas?

JulienSiems avatar Jul 18 '23 16:07 JulienSiems

@JulienSiems I got the same issue as @Ryul0rd. I think this is due to PyTorch's implementation of Linear, which stores the weight in transposed form, with shape (out_features, in_features):

    def __init__(self, ...) -> None:
        ...
        self.weight = Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        ...

and its forward is implemented like this:

    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias) 

which is equivalent to output = input.matmul(weight.t()) + bias. Thus, the exported ONNX graph contains a Gemm node with transB=1, which WONNX can't handle yet. I ended up implementing my own linear module (a slight modification of the original one):

import math

import torch
from torch import Tensor
from torch.nn import Parameter, init


class LinearCustom(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(
            # torch.empty((out_features, in_features), **factory_kwargs)  <-- change this
            torch.empty((in_features, out_features), **factory_kwargs)  # <-- to this
        )
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # The weight is stored as (in_features, out_features), so dimension 0 is the
        # fan-in; mode="fan_out" makes kaiming_uniform_ scale by that dimension.
        init.kaiming_uniform_(self.weight, a=math.sqrt(5), mode="fan_out")
        if self.bias is not None:
            fan_in = self.in_features
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        # return F.linear(input, self.weight, self.bias)  <-- change this
        out = torch.matmul(input, self.weight)  # <-- to this (no transpose needed)
        return out if self.bias is None else out + self.bias

    def extra_repr(self) -> str:
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )

I still need to run python3 -m onnxsim model.onnx model-sim.onnx to make the module naming valid and compatible with WONNX. With this the error is solved.
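As a sanity check on the layout change, here is a small dependency-free Python sketch showing that multiplying by the transposed (out, in) weight gives the same result as multiplying by a weight already stored as (in, out). The helper functions and the numbers are made up for illustration; they are not PyTorch or WONNX APIs.

```python
def transpose(m):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Plain row-by-column matrix product on lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add_bias(m, bias):
    """Add a bias vector to every row."""
    return [[v + b for v, b in zip(row, bias)] for row in m]


# in_features = 3, out_features = 2 (made-up values)
x = [[1.0, 2.0, 3.0]]
w_linear = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]  # (out, in), the torch.nn.Linear layout
bias = [0.1, -0.1]

w_custom = transpose(w_linear)  # (in, out), the LinearCustom layout

y_linear = add_bias(matmul(x, transpose(w_linear)), bias)  # x @ W^T + b
y_custom = add_bias(matmul(x, w_custom), bias)             # x @ W' + b, no transpose
print(y_linear == y_custom)
```

One consequence is that weights trained with torch.nn.Linear would have to be transposed before being copied into LinearCustom, presumably with something like linear.weight.data.t().contiguous().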

@pixelspark I'm wondering how many changes would be required to handle the case of transA/transB != 0? 🤔

ariaghora avatar Aug 24 '23 10:08 ariaghora