Can't run a single linear layer
Describe the bug
I export a single linear layer from PyTorch and get one of the following errors when loading it with WONNX.
Error 1:
GpuError(CompileError { node: "Gemm_0", error: InvalidInputShape { input_index: 1, input_shape: Shape { dims: [10, 784], data_type: F32 } } })
Error 2:
IrError(OutputNodeNotFound("onnx::Add_4"))
I viewed the resulting ONNX file at netron.app and it appeared to be correct.
To Reproduce
- Run the following script
import torch

torch_model = torch.nn.Linear(784, 10)
model_input = torch.zeros((1, 784))  # This results in error 1. Changing shape to (784,) results in error 2

torch.onnx.export(torch_model,              # model being run
                  model_input,              # model input (or a tuple for multiple inputs)
                  "onnx/model.onnx",        # where to save the model (can be a file or file-like object)
                  export_params=True,       # store the trained parameter weights inside the model file
                  opset_version=11,         # the ONNX version to export the model to
                  do_constant_folding=True, # whether to execute constant folding for optimization
                  input_names=['input'],    # the model's input names
                  output_names=['output'])  # the model's output names
- Optionally run onnx-simplifier but it doesn't do anything on such a simple model.
- Run the following Rust program
use std::path::Path;

fn main() {
    #[cfg(not(target_arch = "wasm32"))]
    {
        pollster::block_on(run());
    }
}

async fn run() {
    let model_path = Path::new("onnx/model.onnx");
    let _session = wonnx::Session::from_path(model_path).await.unwrap();
}
Expected behavior
The model should load successfully.
Desktop: PopOS 20.04
Can you post the onnx file here?
The 1 and 2 in the file names correspond to the 2 different errors mentioned above. linear.zip
@pixelspark Any luck replicating?
As for the first problem, long story short: WONNX does not implement Gemm when transA/transB!=0.
The error seems to be the result of this check:
let dim_n = output_shape.dim(1);
// ...
if dim_n != input_right_shape.dim(1) || dim_k != input_right_shape.dim(0) {
    return Err(CompileError::InvalidInputShape {
        input_index: 1,
        input_shape: input_right_shape.clone(),
    });
}
This check follows from the spec. Note that when transA or transB is non-zero (as it is in your example), the constraints are transposed. WONNX should throw an error earlier if it detects non-zero transA or transB, but for some reason it doesn't in this case (according to Netron, transB==1).
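To make the shape mismatch concrete, here is a plain NumPy sketch (not wonnx code): PyTorch exports the Linear weight as a (10, 784) tensor and sets transB=1, so Gemm computes A @ B.T, while the check above expects the transB=0 layout where B is (K, N) = (784, 10), hence the rejection.

import numpy as np

A = np.zeros((1, 784))    # Gemm input A: (M, K) = (1, 784)
B = np.zeros((10, 784))   # weight as PyTorch exports it: (N, K) = (10, 784)

# With transB=0 the spec requires B to be (K, N) = (784, 10), so
# "A @ B" would be a shape mismatch -- this is the case the quoted
# check validates.
# With transB=1 (what the exported model actually uses), Gemm computes:
Y = A @ B.T               # result: (M, N) = (1, 10)
print(Y.shape)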
As for the second error: it seems something is just wrong with the ONNX file. Netron also has difficulty finding the type for the onnx::Add_4 input of the Add operation (notice in the screenshot below that there is no '+' to the right of onnx::Add_4 in the right bar):

(screenshot: Netron view of the Add node's inputs)

IIUC the name onnx::Add_4 is not valid anyway (it needs to be a valid C90 identifier). Something else to check is whether all nodes are in DAG order in the ONNX file.
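A quick way to inspect both points is a small sketch with the onnx Python package (the path assumes the export script above):

import onnx

model = onnx.load("onnx/model.onnx")

# List every node with its inputs/outputs; this shows names like
# "onnx::Add_4" and lets you verify that nodes appear in DAG order.
for node in model.graph.node:
    print(node.name, node.op_type, list(node.input), "->", list(node.output))

# The checker flags structural problems in the graph.
onnx.checker.check_model(model)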
@pixelspark Thanks for the explanation, but I'm still not sure exactly how to resolve the issue. Do you have any idea why PyTorch might be exporting these incorrectly? Is there a configuration I have wrong in my exporter script, for example?
I am having the same issue, any ideas?
@JulienSiems I got the same issue as @Ryul0rd. I think this is due to PyTorch's implementation of Linear, which stores its weights in a transposed layout:
def __init__(self, ...) -> None:
    ...
    self.weight = Parameter(
        torch.empty((out_features, in_features), **factory_kwargs)
    )
    ...
and its forward is implemented like this:
def forward(self, input: Tensor) -> Tensor:
    return F.linear(input, self.weight, self.bias)
which is equivalent to output = input.matmul(weight.t()). Thus, the exported ONNX produces a Gemm with transB=1, which can't be handled yet in the current state of WONNX. I ended up implementing my own linear module (a slight modification of the original one):
import math

import torch
from torch import Tensor
from torch.nn import Parameter, init


class LinearCustom(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(
            # torch.empty((out_features, in_features), **factory_kwargs) <-- change this
            torch.empty((in_features, out_features), **factory_kwargs)  # <-- to this
        )
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        # return F.linear(input, self.weight, self.bias) <-- change this
        return torch.matmul(input, self.weight) + self.bias  # <-- to this

    def extra_repr(self) -> str:
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )
I still need to run python3 -m onnxsim model.onnx model-sim.onnx to make the module naming valid and compatible with WONNX. With this the error is solved.
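For completeness, a minimal export sketch using the LinearCustom module above (mirroring the export settings from the original report; the onnxsim step still follows):

model = LinearCustom(784, 10)
model_input = torch.zeros((1, 784))

torch.onnx.export(
    model,
    model_input,
    "onnx/model.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
)
# Then simplify so the naming is WONNX-friendly:
#   python3 -m onnxsim onnx/model.onnx onnx/model-sim.onnx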
@pixelspark I'm wondering how many changes would be required to implement handling of the transA/transB!=0 case? 🤔
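Alternatively, one could probably pre-process the exported file outside WONNX to get rid of transB altogether, by transposing the weight initializer with the onnx Python package. An untested sketch (it assumes the Gemm weight is stored as a graph initializer, which it is when exporting with export_params=True):

import onnx
from onnx import numpy_helper

model = onnx.load("onnx/model.onnx")
initializers = {t.name: t for t in model.graph.initializer}

for node in model.graph.node:
    if node.op_type != "Gemm":
        continue
    for attr in node.attribute:
        if attr.name == "transB" and attr.i == 1:
            weight = initializers[node.input[1]]
            transposed = numpy_helper.to_array(weight).T.copy()
            weight.CopyFrom(numpy_helper.from_array(transposed, weight.name))
            attr.i = 0  # B is now stored as (K, N), so no transpose is needed

onnx.save(model, "onnx/model-notransb.onnx")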