
Error in converting to ONNX model

Open ayazhassan opened this issue 2 years ago • 11 comments

I am getting the following error while trying to convert the pre-trained model to an ONNX model. Can you please look into it and confirm whether the pre-trained weights were generated with the current version of the model? The conversion code is provided after the error.

Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Traceback (most recent call last):
  File "/home/ayaz_khan/SCUNet/onnx.py", line 2, in <module>
    import torch.onnx
  File "/home/ayaz_khan/.local/lib/python3.9/site-packages/torch/onnx/__init__.py", line 57, in <module>
    from ._internal.onnxruntime import (
  File "/home/ayaz_khan/.local/lib/python3.9/site-packages/torch/onnx/_internal/onnxruntime.py", line 34, in <module>
    import onnx
  File "/home/ayaz_khan/SCUNet/onnx.py", line 25, in <module>
    convert_to_onnx(model_path, onnx_path)
  File "/home/ayaz_khan/SCUNet/onnx.py", line 8, in convert_to_onnx
    model.load_state_dict(torch.load(model_path))
  File "/home/ayaz_khan/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SCUNet:
    Missing key(s) in state_dict: "m_down1.2.weight", "m_down2.2.weight", "m_down3.2.weight".
    Unexpected key(s) in state_dict: "m_down1.3.trans_block.ln1.weight", "m_down1.3.trans_block.ln1.bias", "m_down1.3.trans_block.msa.relative_position_params", "m_down1.3.trans_block.msa.embedding_layer.weight", "m_down1.3.trans_block.msa.embedding_layer.bias", "m_down1.3.trans_block.msa.linear.weight", "m_down1.3.trans_block.msa.linear.bias", "m_down1.3.trans_block.ln2.weight", "m_down1.3.trans_block.ln2.bias", "m_down1.3.trans_block.mlp.0.weight", "m_down1.3.trans_block.mlp.0.bias", "m_down1.3.trans_block.mlp.2.weight", "m_down1.3.trans_block.mlp.2.bias", "m_down1.3.conv1_1.weight", "m_down1.3.conv1_1.bias", "m_down1.3.conv1_2.weight", "m_down1.3.conv1_2.bias", "m_down1.3.conv_block.0.weight", "m_down1.3.conv_block.2.weight", "m_down1.4.weight", "m_down1.2.trans_block.ln1.weight", "m_down1.2.trans_block.ln1.bias", "m_down1.2.trans_block.msa.relative_position_params", "m_down1.2.trans_block.msa.embedding_layer.weight", "m_down1.2.trans_block.msa.embedding_layer.bias", "m_down1.2.trans_block.msa.linear.weight", "m_down1.2.trans_block.msa.linear.bias", "m_down1.2.trans_block.ln2.weight", "m_down1.2.trans_block.ln2.bias", "m_down1.2.trans_block.mlp.0.weight", "m_down1.2.trans_block.mlp.0.bias", "m_down1.2.trans_block.mlp.2.weight", "m_down1.2.trans_block.mlp.2.bias", "m_down1.2.conv1_1.weight", "m_down1.2.conv1_1.bias", "m_down1.2.conv1_2.weight", "m_down1.2.conv1_2.bias", "m_down1.2.conv_block.0.weight", "m_down1.2.conv_block.2.weight", "m_down2.3.trans_block.ln1.weight", "m_down2.3.trans_block.ln1.bias", "m_down2.3.trans_block.msa.relative_position_params", "m_down2.3.trans_block.msa.embedding_layer.weight", "m_down2.3.trans_block.msa.embedding_layer.bias", "m_down2.3.trans_block.msa.linear.weight", "m_down2.3.trans_block.msa.linear.bias", "m_down2.3.trans_block.ln2.weight", "m_down2.3.trans_block.ln2.bias", "m_down2.3.trans_block.mlp.0.weight", "m_down2.3.trans_block.mlp.0.bias", "m_down2.3.trans_block.mlp.2.weight", "m_down2.3.trans_block.mlp.2.bias", 
"m_down2.3.conv1_1.weight", "m_down2.3.conv1_1.bias", "m_down2.3.conv1_2.weight", "m_down2.3.conv1_2.bias", "m_down2.3.conv_block.0.weight", "m_down2.3.conv_block.2.weight", "m_down2.4.weight", "m_down2.2.trans_block.ln1.weight", "m_down2.2.trans_block.ln1.bias", "m_down2.2.trans_block.msa.relative_position_params", "m_down2.2.trans_block.msa.embedding_layer.weight", "m_down2.2.trans_block.msa.embedding_layer.bias", "m_down2.2.trans_block.msa.linear.weight", "m_down2.2.trans_block.msa.linear.bias", "m_down2.2.trans_block.ln2.weight", "m_down2.2.trans_block.ln2.bias", "m_down2.2.trans_block.mlp.0.weight", "m_down2.2.trans_block.mlp.0.bias", "m_down2.2.trans_block.mlp.2.weight", "m_down2.2.trans_block.mlp.2.bias", "m_down2.2.conv1_1.weight", "m_down2.2.conv1_1.bias", "m_down2.2.conv1_2.weight", "m_down2.2.conv1_2.bias", "m_down2.2.conv_block.0.weight", "m_down2.2.conv_block.2.weight", "m_down3.3.trans_block.ln1.weight", "m_down3.3.trans_block.ln1.bias", "m_down3.3.trans_block.msa.relative_position_params", "m_down3.3.trans_block.msa.embedding_layer.weight", "m_down3.3.trans_block.msa.embedding_layer.bias", "m_down3.3.trans_block.msa.linear.weight", "m_down3.3.trans_block.msa.linear.bias", "m_down3.3.trans_block.ln2.weight", "m_down3.3.trans_block.ln2.bias", "m_down3.3.trans_block.mlp.0.weight", "m_down3.3.trans_block.mlp.0.bias", "m_down3.3.trans_block.mlp.2.weight", "m_down3.3.trans_block.mlp.2.bias", "m_down3.3.conv1_1.weight", "m_down3.3.conv1_1.bias", "m_down3.3.conv1_2.weight", "m_down3.3.conv1_2.bias", "m_down3.3.conv_block.0.weight", "m_down3.3.conv_block.2.weight", "m_down3.4.weight", "m_down3.2.trans_block.ln1.weight", "m_down3.2.trans_block.ln1.bias", "m_down3.2.trans_block.msa.relative_position_params", "m_down3.2.trans_block.msa.embedding_layer.weight", "m_down3.2.trans_block.msa.embedding_layer.bias", "m_down3.2.trans_block.msa.linear.weight", "m_down3.2.trans_block.msa.linear.bias", "m_down3.2.trans_block.ln2.weight", "m_down3.2.trans_block.ln2.bias", "m_down3.2.trans_block.mlp.0.weight", "m_down3.2.trans_block.mlp.0.bias", "m_down3.2.trans_block.mlp.2.weight", "m_down3.2.trans_block.mlp.2.bias", "m_down3.2.conv1_1.weight", "m_down3.2.conv1_1.bias", "m_down3.2.conv1_2.weight", "m_down3.2.conv1_2.bias", "m_down3.2.conv_block.0.weight", "m_down3.2.conv_block.2.weight", "m_body.2.trans_block.ln1.weight", "m_body.2.trans_block.ln1.bias", "m_body.2.trans_block.msa.relative_position_params", "m_body.2.trans_block.msa.embedding_layer.weight", "m_body.2.trans_block.msa.embedding_layer.bias", "m_body.2.trans_block.msa.linear.weight", "m_body.2.trans_block.msa.linear.bias", "m_body.2.trans_block.ln2.weight", "m_body.2.trans_block.ln2.bias", "m_body.2.trans_block.mlp.0.weight", "m_body.2.trans_block.mlp.0.bias", "m_body.2.trans_block.mlp.2.weight", "m_body.2.trans_block.mlp.2.bias", "m_body.2.conv1_1.weight", "m_body.2.conv1_1.bias", "m_body.2.conv1_2.weight", "m_body.2.conv1_2.bias", "m_body.2.conv_block.0.weight", "m_body.2.conv_block.2.weight", "m_body.3.trans_block.ln1.weight", "m_body.3.trans_block.ln1.bias", "m_body.3.trans_block.msa.relative_position_params", "m_body.3.trans_block.msa.embedding_layer.weight", "m_body.3.trans_block.msa.embedding_layer.bias", "m_body.3.trans_block.msa.linear.weight", "m_body.3.trans_block.msa.linear.bias", "m_body.3.trans_block.ln2.weight", "m_body.3.trans_block.ln2.bias", "m_body.3.trans_block.mlp.0.weight", "m_body.3.trans_block.mlp.0.bias", "m_body.3.trans_block.mlp.2.weight", "m_body.3.trans_block.mlp.2.bias", "m_body.3.conv1_1.weight", 
"m_body.3.conv1_1.bias", "m_body.3.conv1_2.weight", "m_body.3.conv1_2.bias", "m_body.3.conv_block.0.weight", "m_body.3.conv_block.2.weight", "m_up3.3.trans_block.ln1.weight", "m_up3.3.trans_block.ln1.bias", "m_up3.3.trans_block.msa.relative_position_params", "m_up3.3.trans_block.msa.embedding_layer.weight", "m_up3.3.trans_block.msa.embedding_layer.bias", "m_up3.3.trans_block.msa.linear.weight", "m_up3.3.trans_block.msa.linear.bias", "m_up3.3.trans_block.ln2.weight", "m_up3.3.trans_block.ln2.bias", "m_up3.3.trans_block.mlp.0.weight", "m_up3.3.trans_block.mlp.0.bias", "m_up3.3.trans_block.mlp.2.weight", "m_up3.3.trans_block.mlp.2.bias", "m_up3.3.conv1_1.weight", "m_up3.3.conv1_1.bias", "m_up3.3.conv1_2.weight", "m_up3.3.conv1_2.bias", "m_up3.3.conv_block.0.weight", "m_up3.3.conv_block.2.weight", "m_up3.4.trans_block.ln1.weight", "m_up3.4.trans_block.ln1.bias", "m_up3.4.trans_block.msa.relative_position_params", "m_up3.4.trans_block.msa.embedding_layer.weight", "m_up3.4.trans_block.msa.embedding_layer.bias", "m_up3.4.trans_block.msa.linear.weight", "m_up3.4.trans_block.msa.linear.bias", "m_up3.4.trans_block.ln2.weight", "m_up3.4.trans_block.ln2.bias", "m_up3.4.trans_block.mlp.0.weight", "m_up3.4.trans_block.mlp.0.bias", "m_up3.4.trans_block.mlp.2.weight", "m_up3.4.trans_block.mlp.2.bias", "m_up3.4.conv1_1.weight", "m_up3.4.conv1_1.bias", "m_up3.4.conv1_2.weight", "m_up3.4.conv1_2.bias", "m_up3.4.conv_block.0.weight", "m_up3.4.conv_block.2.weight", "m_up2.3.trans_block.ln1.weight", "m_up2.3.trans_block.ln1.bias", "m_up2.3.trans_block.msa.relative_position_params", "m_up2.3.trans_block.msa.embedding_layer.weight", "m_up2.3.trans_block.msa.embedding_layer.bias", "m_up2.3.trans_block.msa.linear.weight", "m_up2.3.trans_block.msa.linear.bias", "m_up2.3.trans_block.ln2.weight", "m_up2.3.trans_block.ln2.bias", "m_up2.3.trans_block.mlp.0.weight", "m_up2.3.trans_block.mlp.0.bias", "m_up2.3.trans_block.mlp.2.weight", "m_up2.3.trans_block.mlp.2.bias", "m_up2.3.conv1_1.weight", "m_up2.3.conv1_1.bias", "m_up2.3.conv1_2.weight", "m_up2.3.conv1_2.bias", "m_up2.3.conv_block.0.weight", "m_up2.3.conv_block.2.weight", "m_up2.4.trans_block.ln1.weight", "m_up2.4.trans_block.ln1.bias", "m_up2.4.trans_block.msa.relative_position_params", "m_up2.4.trans_block.msa.embedding_layer.weight", "m_up2.4.trans_block.msa.embedding_layer.bias", "m_up2.4.trans_block.msa.linear.weight", "m_up2.4.trans_block.msa.linear.bias", "m_up2.4.trans_block.ln2.weight", "m_up2.4.trans_block.ln2.bias", "m_up2.4.trans_block.mlp.0.weight", "m_up2.4.trans_block.mlp.0.bias", "m_up2.4.trans_block.mlp.2.weight", "m_up2.4.trans_block.mlp.2.bias", "m_up2.4.conv1_1.weight", "m_up2.4.conv1_1.bias", "m_up2.4.conv1_2.weight", "m_up2.4.conv1_2.bias", "m_up2.4.conv_block.0.weight", "m_up2.4.conv_block.2.weight", "m_up1.3.trans_block.ln1.weight", "m_up1.3.trans_block.ln1.bias", "m_up1.3.trans_block.msa.relative_position_params", "m_up1.3.trans_block.msa.embedding_layer.weight", "m_up1.3.trans_block.msa.embedding_layer.bias", "m_up1.3.trans_block.msa.linear.weight", "m_up1.3.trans_block.msa.linear.bias", "m_up1.3.trans_block.ln2.weight", "m_up1.3.trans_block.ln2.bias", "m_up1.3.trans_block.mlp.0.weight", "m_up1.3.trans_block.mlp.0.bias", "m_up1.3.trans_block.mlp.2.weight", "m_up1.3.trans_block.mlp.2.bias", "m_up1.3.conv1_1.weight", "m_up1.3.conv1_1.bias", "m_up1.3.conv1_2.weight", "m_up1.3.conv1_2.bias", "m_up1.3.conv_block.0.weight", "m_up1.3.conv_block.2.weight", "m_up1.4.trans_block.ln1.weight", "m_up1.4.trans_block.ln1.bias", 
"m_up1.4.trans_block.msa.relative_position_params", "m_up1.4.trans_block.msa.embedding_layer.weight", "m_up1.4.trans_block.msa.embedding_layer.bias", "m_up1.4.trans_block.msa.linear.weight", "m_up1.4.trans_block.msa.linear.bias", "m_up1.4.trans_block.ln2.weight", "m_up1.4.trans_block.ln2.bias", "m_up1.4.trans_block.mlp.0.weight", "m_up1.4.trans_block.mlp.0.bias", "m_up1.4.trans_block.mlp.2.weight", "m_up1.4.trans_block.mlp.2.bias", "m_up1.4.conv1_1.weight", "m_up1.4.conv1_1.bias", "m_up1.4.conv1_2.weight", "m_up1.4.conv1_2.bias", "m_up1.4.conv_block.0.weight", "m_up1.4.conv_block.2.weight".

import torch
import torch.onnx
from models.network_scunet import SCUNet  # Assuming this is the SCUNet model definition

def convert_to_onnx(model_path, onnx_path, input_shape=(1, 3, 256, 256)):
    # Load the pre-trained PyTorch model
    model = SCUNet()
    model.load_state_dict(torch.load(model_path))
    
    # Set the model to evaluation mode
    model.eval()

    # Define dummy input data
    dummy_input = torch.randn(input_shape)

    # Convert the model to ONNX format
    torch.onnx.export(model, dummy_input, onnx_path, verbose=True)

    print(f"Model converted to ONNX format and saved as {onnx_path}")

# Paths
model_path = "model_zoo/scunet_color_real_psnr.pth"
onnx_path = "./scunet_color_real_gan.onnx"

convert_to_onnx(model_path, onnx_path)
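
A note for readers hitting the same error: the traceback above actually shows two problems. First, the script is named onnx.py, so when torch.onnx internally runs "import onnx" it picks up this script instead of the onnx package (that is why the traceback jumps back into /home/ayaz_khan/SCUNet/onnx.py); renaming the script avoids the circular import. Second, the missing/unexpected keys suggest that SCUNet() built with default arguments has fewer blocks per stage than the released checkpoint. A minimal loading sketch, using the config/dim values from the working conversion script later in this thread:

import torch
from models.network_scunet import SCUNet

# config/dim values taken from the conversion script posted later in this thread
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load("model_zoo/scunet_color_real_psnr.pth"), strict=True)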

ayazhassan avatar Oct 26 '23 09:10 ayazhassan

I've converted the models to ONNX (dynamic axes for input size). The models work, but they only accept square input images with sides divisible by 128.
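
A simple workaround for that constraint (my addition, not from the original reply) is to pad the image up to the next square side divisible by 128 before inference and crop back afterwards. A minimal sketch, assuming a NumPy HWC image:

import numpy as np

def pad_to_square_multiple(img, multiple=128):
    # Pad an HWC image with reflected borders to the next square side
    # length divisible by the given multiple; return the padded image
    # and the original size for cropping the model output back.
    h, w = img.shape[:2]
    side = ((max(h, w) + multiple - 1) // multiple) * multiple
    padded = np.pad(img, ((0, side - h), (0, side - w), (0, 0)), mode='reflect')
    return padded, (h, w)

# After inference, crop back to the original size: restored = output[:h, :w]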

instant-high avatar Apr 24 '24 15:04 instant-high

I managed to export the model in ONNX form; here is my code:


import torch
from models.network_scunet import SCUNet

net = SCUNet()
onnx_path = "onnx_model_name.onnx"
torch.onnx.export(net, torch.randn((2, 3, 64, 64)), onnx_path)

But remember to run a full forward pass through the model's classes and methods before exporting, as a sanity check that everything executes. A check on the exported file is sketched below.
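
If the export completes, a quick structural sanity check on the resulting file may help (a sketch, assuming the onnx package is installed; the file name matches the snippet above):

import onnx

onnx_model = onnx.load("onnx_model_name.onnx")
onnx.checker.check_model(onnx_model)  # raises an exception if the exported graph is malformed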

jumpxiu avatar Jul 01 '24 06:07 jumpxiu

I've converted the models to ONNX (dynamic axes for input size). The models work, but they only accept square input images with sides divisible by 128.

Hello, I just checked your homepage but didn't find any content about SCUNet. Can we discuss it?

Dlew-DUT avatar Feb 14 '25 08:02 Dlew-DUT

I am getting the following error while trying to convert the pre-trained model to an ONNX model. [...]

Hello, Dr. Ayaz H. Khan, have you solved this issue?

Dlew-DUT avatar Feb 14 '25 08:02 Dlew-DUT

I've converted the models to ONNX (dynamic axes for input size). The models work, but they only accept square input images with sides divisible by 128.

Hello, I just checked your homepage but didn't find any content about SCUNet. Can we discuss it?

I only converted it for local use. I can search for my conversion script and post it here, if I find it.

instant-high avatar Feb 14 '25 09:02 instant-high

Thank you so much. Your reply has given me hope. Thank you.

Dlew-DUT avatar Feb 14 '25 09:02 Dlew-DUT

My conversion code; edit it to convert the model at full or half precision. (The input has to be divisible by 128: 128x128, 256x256, ...)

It takes some time, be patient... There are also some warnings, but the ONNX model works.

import argparse
import numpy as np
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='models/scunet_color_real_psnr.pth',
                    help='scunet_color_real_psnr.pth or scunet_color_real_gan.pth')
args = parser.parse_args()

n_channels = 3  # color model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'

from models.network_scunet import SCUNet as net
model = net(in_nc=n_channels, config=[4, 4, 4, 4, 4, 4, 4], dim=64)

model.load_state_dict(torch.load(args.model), strict=True)
model.eval()
for k, v in model.named_parameters():
    v.requires_grad = False

#model = model.half()
model = model.to(device)

onnx_path = 'scunet_color_real_psnr.onnx'

# full-precision model:
dummy_input = torch.randn(1, 3, 256, 256, dtype=torch.float32)
# half-precision model:
#dummy_input = torch.randn(1, 3, 256, 256, dtype=torch.float16)

dummy_input = dummy_input.to(device)

with torch.no_grad():
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        do_constant_folding=False,
        input_names=['input'],
        output_names=['output'],
        opset_version=12,
        dynamic_axes={'input': {2: 'height', 3: 'width'}, 'output': {2: 'height', 3: 'width'}}
    )

input("Conversion done")
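
To confirm the exported graph actually runs, a short check with onnxruntime (a sketch; the file name matches onnx_path above):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('scunet_color_real_psnr.onnx', providers=['CPUExecutionProvider'])
x = np.random.randn(1, 3, 256, 256).astype(np.float32)
y = session.run(None, {'input': x})[0]
print(y.shape)  # should mirror the input shape: (1, 3, 256, 256)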

instant-high avatar Feb 14 '25 10:02 instant-high

Thank you very much, let me give it a try. Happy Chinese New Year, and I wish you a prosperous and successful year ahead

Dlew-DUT avatar Feb 14 '25 11:02 Dlew-DUT

This is how I proceeded, thank you very much. But my resulting image (originally 128x128) is divided into 9 parts that contain the channel information separately. How do I get a normal image back? I use ONNX Runtime within my C++ code.

ikmaus avatar Feb 20 '25 07:02 ikmaus

The ONNX model's input and output shapes should both be (1, 3, h, w), i.e. (batch, channels, height, width).

It's been a very long time since I coded in C++, so here is the pre- and post-processing of the image using Python:

    def denoise(self, img):
        # HWC uint8 -> NCHW float32 in [0, 1]
        img = img.astype(np.float32)
        img = img.transpose((2, 0, 1))
        img = img / 255
        img = np.expand_dims(img, axis=0).astype(np.float32)

        res = self.session.run(None, {'input': img})[0]

        # NCHW float32 -> HWC uint8 in [0, 255]
        res = (res.squeeze().transpose((1, 2, 0)) * 255).clip(0, 255).astype(np.uint8)

        return res
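
For completeness, the session used above can be created like this (a sketch; the model file name is an assumption taken from the conversion script earlier in the thread):

import onnxruntime as ort

# assign to self.session in the class __init__
session = ort.InferenceSession('scunet_color_real_psnr.onnx', providers=['CPUExecutionProvider'])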

Maybe this is the way to do it in C++:

#include <cstring>
#include <opencv2/opencv.hpp>
#include <onnxruntime/core/session/onnxruntime_cxx_api.h>

cv::Mat denoise(Ort::Session& session, const cv::Mat& img) {
    // Normalize the uint8 image to float32 in [0, 1]
    cv::Mat floatImg;
    img.convertTo(floatImg, CV_32F, 1.0 / 255.0);

    // Convert HWC to CHW: copy each channel plane into one contiguous buffer
    std::vector<cv::Mat> channels(3);
    cv::split(floatImg, channels);

    const size_t planeSize = static_cast<size_t>(img.rows) * img.cols;
    std::vector<float> inputTensorValues(3 * planeSize);
    for (int i = 0; i < 3; i++) {
        std::memcpy(inputTensorValues.data() + i * planeSize,
                    channels[i].ptr<float>(), planeSize * sizeof(float));
    }

    // Input shape (1, 3, H, W)
    std::vector<int64_t> inputShape = {1, 3, img.rows, img.cols};

    // Create ONNX Runtime input tensor
    Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
        memoryInfo, inputTensorValues.data(), inputTensorValues.size(),
        inputShape.data(), inputShape.size());

    // Run inference; names match the export (input_names=['input'], output_names=['output'])
    const char* inputNames[] = {"input"};
    const char* outputNames[] = {"output"};
    auto outputTensors = session.Run(Ort::RunOptions{nullptr},
                                     inputNames, &inputTensor, 1,
                                     outputNames, 1);

    // Output is (1, 3, H, W); wrap each CHW plane and merge back to HWC
    float* outputData = outputTensors.front().GetTensorMutableData<float>();
    std::vector<cv::Mat> outputChannels = {
        cv::Mat(img.rows, img.cols, CV_32F, outputData),
        cv::Mat(img.rows, img.cols, CV_32F, outputData + planeSize),
        cv::Mat(img.rows, img.cols, CV_32F, outputData + 2 * planeSize)
    };

    cv::Mat finalImg;
    cv::merge(outputChannels, finalImg);

    // Scale back to [0, 255]; convertTo saturates, which handles the clipping
    finalImg.convertTo(finalImg, CV_8U, 255.0);

    return finalImg;
}

instant-high avatar Feb 20 '25 11:02 instant-high

Hello, thank you very much! That really was my problem: building the tensors from a cv::Mat the right way. Now it works as expected. I only have some problems with pixelation at the edges, but that has nothing to do with the model.

ikmaus avatar Feb 20 '25 15:02 ikmaus