Error in converting to ONNX model #25

Open
ayazhassan opened this issue Oct 26, 2023 · 11 comments

ayazhassan commented Oct 26, 2023

I am getting the following error while trying to convert the pre-trained model to an ONNX model. Can you please look into it and confirm that the pre-trained weights were generated with the current, updated model? The conversion code is provided after the error.

Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Traceback (most recent call last):
  File "/home/ayaz_khan/SCUNet/onnx.py", line 2, in <module>
    import torch.onnx
  File "/home/ayaz_khan/.local/lib/python3.9/site-packages/torch/onnx/__init__.py", line 57, in <module>
    from ._internal.onnxruntime import (
  File "/home/ayaz_khan/.local/lib/python3.9/site-packages/torch/onnx/_internal/onnxruntime.py", line 34, in <module>
    import onnx
  File "/home/ayaz_khan/SCUNet/onnx.py", line 25, in <module>
    convert_to_onnx(model_path, onnx_path)
  File "/home/ayaz_khan/SCUNet/onnx.py", line 8, in convert_to_onnx
    model.load_state_dict(torch.load(model_path))
  File "/home/ayaz_khan/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SCUNet:
Missing key(s) in state_dict: "m_down1.2.weight", "m_down2.2.weight", "m_down3.2.weight".
Unexpected key(s) in state_dict: "m_down1.3.trans_block.ln1.weight", "m_down1.3.trans_block.ln1.bias", "m_down1.3.trans_block.msa.relative_position_params", "m_down1.3.trans_block.msa.embedding_layer.weight", "m_down1.3.trans_block.msa.embedding_layer.bias", "m_down1.3.trans_block.msa.linear.weight", "m_down1.3.trans_block.msa.linear.bias", "m_down1.3.trans_block.ln2.weight", "m_down1.3.trans_block.ln2.bias", "m_down1.3.trans_block.mlp.0.weight", "m_down1.3.trans_block.mlp.0.bias", "m_down1.3.trans_block.mlp.2.weight", "m_down1.3.trans_block.mlp.2.bias", "m_down1.3.conv1_1.weight", "m_down1.3.conv1_1.bias", "m_down1.3.conv1_2.weight", "m_down1.3.conv1_2.bias", "m_down1.3.conv_block.0.weight", "m_down1.3.conv_block.2.weight", "m_down1.4.weight", "m_down1.2.trans_block.ln1.weight", "m_down1.2.trans_block.ln1.bias", "m_down1.2.trans_block.msa.relative_position_params", "m_down1.2.trans_block.msa.embedding_layer.weight", "m_down1.2.trans_block.msa.embedding_layer.bias", "m_down1.2.trans_block.msa.linear.weight", "m_down1.2.trans_block.msa.linear.bias", "m_down1.2.trans_block.ln2.weight", "m_down1.2.trans_block.ln2.bias", "m_down1.2.trans_block.mlp.0.weight", "m_down1.2.trans_block.mlp.0.bias", "m_down1.2.trans_block.mlp.2.weight", "m_down1.2.trans_block.mlp.2.bias", "m_down1.2.conv1_1.weight", "m_down1.2.conv1_1.bias", "m_down1.2.conv1_2.weight", "m_down1.2.conv1_2.bias", "m_down1.2.conv_block.0.weight", "m_down1.2.conv_block.2.weight", "m_down2.3.trans_block.ln1.weight", "m_down2.3.trans_block.ln1.bias", "m_down2.3.trans_block.msa.relative_position_params", "m_down2.3.trans_block.msa.embedding_layer.weight", "m_down2.3.trans_block.msa.embedding_layer.bias", "m_down2.3.trans_block.msa.linear.weight", "m_down2.3.trans_block.msa.linear.bias", "m_down2.3.trans_block.ln2.weight", "m_down2.3.trans_block.ln2.bias", "m_down2.3.trans_block.mlp.0.weight", "m_down2.3.trans_block.mlp.0.bias", "m_down2.3.trans_block.mlp.2.weight", "m_down2.3.trans_block.mlp.2.bias", "m_down2.3.conv1_1.weight", "m_down2.3.conv1_1.bias", "m_down2.3.conv1_2.weight", "m_down2.3.conv1_2.bias", "m_down2.3.conv_block.0.weight", "m_down2.3.conv_block.2.weight", "m_down2.4.weight", "m_down2.2.trans_block.ln1.weight", "m_down2.2.trans_block.ln1.bias", "m_down2.2.trans_block.msa.relative_position_params", "m_down2.2.trans_block.msa.embedding_layer.weight", "m_down2.2.trans_block.msa.embedding_layer.bias", "m_down2.2.trans_block.msa.linear.weight", "m_down2.2.trans_block.msa.linear.bias", "m_down2.2.trans_block.ln2.weight", "m_down2.2.trans_block.ln2.bias", "m_down2.2.trans_block.mlp.0.weight", "m_down2.2.trans_block.mlp.0.bias", "m_down2.2.trans_block.mlp.2.weight", "m_down2.2.trans_block.mlp.2.bias", "m_down2.2.conv1_1.weight", "m_down2.2.conv1_1.bias", "m_down2.2.conv1_2.weight", "m_down2.2.conv1_2.bias", "m_down2.2.conv_block.0.weight", "m_down2.2.conv_block.2.weight", "m_down3.3.trans_block.ln1.weight", "m_down3.3.trans_block.ln1.bias", "m_down3.3.trans_block.msa.relative_position_params", "m_down3.3.trans_block.msa.embedding_layer.weight", "m_down3.3.trans_block.msa.embedding_layer.bias", "m_down3.3.trans_block.msa.linear.weight", "m_down3.3.trans_block.msa.linear.bias", "m_down3.3.trans_block.ln2.weight", "m_down3.3.trans_block.ln2.bias", "m_down3.3.trans_block.mlp.0.weight", "m_down3.3.trans_block.mlp.0.bias", "m_down3.3.trans_block.mlp.2.weight", "m_down3.3.trans_block.mlp.2.bias", "m_down3.3.conv1_1.weight", "m_down3.3.conv1_1.bias", "m_down3.3.conv1_2.weight", "m_down3.3.conv1_2.bias", "m_down3.3.conv_block.0.weight", 
"m_down3.3.conv_block.2.weight", "m_down3.4.weight", "m_down3.2.trans_block.ln1.weight", "m_down3.2.trans_block.ln1.bias", "m_down3.2.trans_block.msa.relative_position_params", "m_down3.2.trans_block.msa.embedding_layer.weight", "m_down3.2.trans_block.msa.embedding_layer.bias", "m_down3.2.trans_block.msa.linear.weight", "m_down3.2.trans_block.msa.linear.bias", "m_down3.2.trans_block.ln2.weight", "m_down3.2.trans_block.ln2.bias", "m_down3.2.trans_block.mlp.0.weight", "m_down3.2.trans_block.mlp.0.bias", "m_down3.2.trans_block.mlp.2.weight", "m_down3.2.trans_block.mlp.2.bias", "m_down3.2.conv1_1.weight", "m_down3.2.conv1_1.bias", "m_down3.2.conv1_2.weight", "m_down3.2.conv1_2.bias", "m_down3.2.conv_block.0.weight", "m_down3.2.conv_block.2.weight", "m_body.2.trans_block.ln1.weight", "m_body.2.trans_block.ln1.bias", "m_body.2.trans_block.msa.relative_position_params", "m_body.2.trans_block.msa.embedding_layer.weight", "m_body.2.trans_block.msa.embedding_layer.bias", "m_body.2.trans_block.msa.linear.weight", "m_body.2.trans_block.msa.linear.bias", "m_body.2.trans_block.ln2.weight", "m_body.2.trans_block.ln2.bias", "m_body.2.trans_block.mlp.0.weight", "m_body.2.trans_block.mlp.0.bias", "m_body.2.trans_block.mlp.2.weight", "m_body.2.trans_block.mlp.2.bias", "m_body.2.conv1_1.weight", "m_body.2.conv1_1.bias", "m_body.2.conv1_2.weight", "m_body.2.conv1_2.bias", "m_body.2.conv_block.0.weight", "m_body.2.conv_block.2.weight", "m_body.3.trans_block.ln1.weight", "m_body.3.trans_block.ln1.bias", "m_body.3.trans_block.msa.relative_position_params", "m_body.3.trans_block.msa.embedding_layer.weight", "m_body.3.trans_block.msa.embedding_layer.bias", "m_body.3.trans_block.msa.linear.weight", "m_body.3.trans_block.msa.linear.bias", "m_body.3.trans_block.ln2.weight", "m_body.3.trans_block.ln2.bias", "m_body.3.trans_block.mlp.0.weight", "m_body.3.trans_block.mlp.0.bias", "m_body.3.trans_block.mlp.2.weight", "m_body.3.trans_block.mlp.2.bias", "m_body.3.conv1_1.weight", "m_body.3.conv1_1.bias", "m_body.3.conv1_2.weight", "m_body.3.conv1_2.bias", "m_body.3.conv_block.0.weight", "m_body.3.conv_block.2.weight", "m_up3.3.trans_block.ln1.weight", "m_up3.3.trans_block.ln1.bias", "m_up3.3.trans_block.msa.relative_position_params", "m_up3.3.trans_block.msa.embedding_layer.weight", "m_up3.3.trans_block.msa.embedding_layer.bias", "m_up3.3.trans_block.msa.linear.weight", "m_up3.3.trans_block.msa.linear.bias", "m_up3.3.trans_block.ln2.weight", "m_up3.3.trans_block.ln2.bias", "m_up3.3.trans_block.mlp.0.weight", "m_up3.3.trans_block.mlp.0.bias", "m_up3.3.trans_block.mlp.2.weight", "m_up3.3.trans_block.mlp.2.bias", "m_up3.3.conv1_1.weight", "m_up3.3.conv1_1.bias", "m_up3.3.conv1_2.weight", "m_up3.3.conv1_2.bias", "m_up3.3.conv_block.0.weight", "m_up3.3.conv_block.2.weight", "m_up3.4.trans_block.ln1.weight", "m_up3.4.trans_block.ln1.bias", "m_up3.4.trans_block.msa.relative_position_params", "m_up3.4.trans_block.msa.embedding_layer.weight", "m_up3.4.trans_block.msa.embedding_layer.bias", "m_up3.4.trans_block.msa.linear.weight", "m_up3.4.trans_block.msa.linear.bias", "m_up3.4.trans_block.ln2.weight", "m_up3.4.trans_block.ln2.bias", "m_up3.4.trans_block.mlp.0.weight", "m_up3.4.trans_block.mlp.0.bias", "m_up3.4.trans_block.mlp.2.weight", "m_up3.4.trans_block.mlp.2.bias", "m_up3.4.conv1_1.weight", "m_up3.4.conv1_1.bias", "m_up3.4.conv1_2.weight", "m_up3.4.conv1_2.bias", "m_up3.4.conv_block.0.weight", "m_up3.4.conv_block.2.weight", "m_up2.3.trans_block.ln1.weight", "m_up2.3.trans_block.ln1.bias", 
"m_up2.3.trans_block.msa.relative_position_params", "m_up2.3.trans_block.msa.embedding_layer.weight", "m_up2.3.trans_block.msa.embedding_layer.bias", "m_up2.3.trans_block.msa.linear.weight", "m_up2.3.trans_block.msa.linear.bias", "m_up2.3.trans_block.ln2.weight", "m_up2.3.trans_block.ln2.bias", "m_up2.3.trans_block.mlp.0.weight", "m_up2.3.trans_block.mlp.0.bias", "m_up2.3.trans_block.mlp.2.weight", "m_up2.3.trans_block.mlp.2.bias", "m_up2.3.conv1_1.weight", "m_up2.3.conv1_1.bias", "m_up2.3.conv1_2.weight", "m_up2.3.conv1_2.bias", "m_up2.3.conv_block.0.weight", "m_up2.3.conv_block.2.weight", "m_up2.4.trans_block.ln1.weight", "m_up2.4.trans_block.ln1.bias", "m_up2.4.trans_block.msa.relative_position_params", "m_up2.4.trans_block.msa.embedding_layer.weight", "m_up2.4.trans_block.msa.embedding_layer.bias", "m_up2.4.trans_block.msa.linear.weight", "m_up2.4.trans_block.msa.linear.bias", "m_up2.4.trans_block.ln2.weight", "m_up2.4.trans_block.ln2.bias", "m_up2.4.trans_block.mlp.0.weight", "m_up2.4.trans_block.mlp.0.bias", "m_up2.4.trans_block.mlp.2.weight", "m_up2.4.trans_block.mlp.2.bias", "m_up2.4.conv1_1.weight", "m_up2.4.conv1_1.bias", "m_up2.4.conv1_2.weight", "m_up2.4.conv1_2.bias", "m_up2.4.conv_block.0.weight", "m_up2.4.conv_block.2.weight", "m_up1.3.trans_block.ln1.weight", "m_up1.3.trans_block.ln1.bias", "m_up1.3.trans_block.msa.relative_position_params", "m_up1.3.trans_block.msa.embedding_layer.weight", "m_up1.3.trans_block.msa.embedding_layer.bias", "m_up1.3.trans_block.msa.linear.weight", "m_up1.3.trans_block.msa.linear.bias", "m_up1.3.trans_block.ln2.weight", "m_up1.3.trans_block.ln2.bias", "m_up1.3.trans_block.mlp.0.weight", "m_up1.3.trans_block.mlp.0.bias", "m_up1.3.trans_block.mlp.2.weight", "m_up1.3.trans_block.mlp.2.bias", "m_up1.3.conv1_1.weight", "m_up1.3.conv1_1.bias", "m_up1.3.conv1_2.weight", "m_up1.3.conv1_2.bias", "m_up1.3.conv_block.0.weight", "m_up1.3.conv_block.2.weight", "m_up1.4.trans_block.ln1.weight", "m_up1.4.trans_block.ln1.bias", "m_up1.4.trans_block.msa.relative_position_params", "m_up1.4.trans_block.msa.embedding_layer.weight", "m_up1.4.trans_block.msa.embedding_layer.bias", "m_up1.4.trans_block.msa.linear.weight", "m_up1.4.trans_block.msa.linear.bias", "m_up1.4.trans_block.ln2.weight", "m_up1.4.trans_block.ln2.bias", "m_up1.4.trans_block.mlp.0.weight", "m_up1.4.trans_block.mlp.0.bias", "m_up1.4.trans_block.mlp.2.weight", "m_up1.4.trans_block.mlp.2.bias", "m_up1.4.conv1_1.weight", "m_up1.4.conv1_1.bias", "m_up1.4.conv1_2.weight", "m_up1.4.conv1_2.bias", "m_up1.4.conv_block.0.weight", "m_up1.4.conv_block.2.weight".

import torch
import torch.onnx
from models.network_scunet import SCUNet  # Assuming this is the SCUNet model definition

def convert_to_onnx(model_path, onnx_path, input_shape=(1, 3, 256, 256)):
    # Load the pre-trained PyTorch model
    model = SCUNet()
    model.load_state_dict(torch.load(model_path))
    
    # Set the model to evaluation mode
    model.eval()

    # Define dummy input data
    dummy_input = torch.randn(input_shape)

    # Convert the model to ONNX format
    torch.onnx.export(model, dummy_input, onnx_path, verbose=True)

    print(f"Model converted to ONNX format and saved as {onnx_path}")

# Paths
model_path = "model_zoo/scunet_color_real_psnr.pth"
onnx_path = "./scunet_color_real_gan.onnx"

convert_to_onnx(model_path, onnx_path)
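
For reference, the missing keys ("m_down1.2.weight", ...) together with the unexpected "*.2"/"*.3" transformer-block keys indicate that SCUNet() is being constructed with fewer blocks per stage than the checkpoint contains. A minimal loading sketch, assuming the released color checkpoints match config=[4,4,4,4,4,4,4] and dim=64 (the settings used in the conversion script posted later in this thread):

import torch
from models.network_scunet import SCUNet

# Assumption: the released color checkpoints use 4 blocks per stage and dim=64;
# adjust these arguments if your checkpoint was trained with a different config.
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
state_dict = torch.load("model_zoo/scunet_color_real_psnr.pth", map_location="cpu")
model.load_state_dict(state_dict, strict=True)  # should report no missing/unexpected keys
model.eval()

Separately, the traceback shows the local script onnx.py being picked up in place of the onnx package (the line "import onnx" resolves back to /home/ayaz_khan/SCUNet/onnx.py), so renaming the script avoids that shadowing.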

@instant-high

I've converted the models to ONNX (dynamic axes for input size). The models work, but they only accept square input images with side lengths divisible by 128.

jumpxiu commented Jul 1, 2024

I managed to export the model in ONNX form; here is my code:


import torch
from models.network_scunet import SCUNet

net = SCUNet()
onnx_path = "onnx_model_name.onnx"
torch.onnx.export(net, torch.randn((2, 3, 64, 64)), onnx_path)

But remember to run the model's classes and methods all the way through before doing so.

@Dlew-DUT

Hello @instant-high, I just checked your homepage but didn't find anything about SCUNet. Could we discuss it?

@Dlew-DUT

Hello, Dr. Ayaz H. Khan, have you solved this issue?

@instant-high

I only converted it for local use.
I can search for my conversion script and post it here, if I can find it.

@Dlew-DUT

Thank you so much. Your reply has given me hope. Thank you.


instant-high commented Feb 14, 2025

My conversion code; edit it to export either the full-precision or the half-precision model.
(The input size has to be divisible by 128: 128x128, 256x256, ...)

It takes some time, so be patient.
There are also some warnings, but the exported ONNX model works.

import argparse
import numpy as np
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='models/scunet_color_real_psnr.pth', help='scunet_color_real_psnr.pth, scunet_color_real_gan.pth')
args = parser.parse_args()

n_channels = 3 # color model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = 'cpu' 

from models.network_scunet import SCUNet as net
model = net(in_nc=n_channels,config=[4,4,4,4,4,4,4],dim=64)

model.load_state_dict(torch.load(args.model), strict=True)
model.eval()
for k, v in model.named_parameters():
    v.requires_grad = False
    
#model = model.half()
model = model.to(device)

onnx_path = 'scunet_color_real_psnr.onnx'

# model (float32):
dummy_input = torch.randn(1, 3, 256, 256, dtype=torch.float32)

# model half:
#dummy_input = torch.randn(1, 3, 256, 256, dtype=torch.float16)

# move the dummy input to the same device as the model
dummy_input = dummy_input.to(device)

with torch.no_grad():
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        do_constant_folding=False,
        input_names=['input'],
        output_names=['output'],
        opset_version=12,
        dynamic_axes={'input': {2: 'height', 3: 'width'}, 'output': {2: 'height', 3: 'width'}}
    )

    input("Conversion done")
    
#
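
Note that, unlike the snippet in the original post, this builds the network with config=[4,4,4,4,4,4,4] and dim=64, which is why load_state_dict(..., strict=True) succeeds here. A quick sanity check of the exported file (a sketch, assuming onnxruntime is installed and the 'input'/'output' names from the export above):

import numpy as np
import onnxruntime as ort

# Sketch: run the exported model once and check that the output keeps the input shape.
# Assumes 'scunet_color_real_psnr.onnx' was produced by the export script above.
session = ort.InferenceSession('scunet_color_real_psnr.onnx', providers=['CPUExecutionProvider'])
x = np.random.rand(1, 3, 256, 256).astype(np.float32)  # side lengths divisible by 128
y = session.run(['output'], {'input': x})[0]
print(y.shape)  # expected: (1, 3, 256, 256)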

@Dlew-DUT

Thank you very much, let me give it a try. Happy Chinese New Year, and I wish you a prosperous and successful year ahead.

ikmaus commented Feb 20, 2025

This is how I proceeded, thank you very much. But my resulting image (originally 128x128) is then divided into 9 parts, which contain the channel information separately. How do I get a normal image back? I use onnxruntime within my C++ code.

instant-high commented Feb 20, 2025

The ONNX model's input/output shapes should be
input (1, 3, h, w) and output (1, 3, h, w), i.e. (batch, channels, height, width).

It's been a very long time since I coded in C++,
so here is the pre- and post-processing of your image in Python:

    def denoise(self, img):
        # HWC uint8 -> CHW float32 in [0, 1], with a leading batch dimension
        img = img.astype(np.float32)
        img = img.transpose((2, 0, 1))
        img = img / 255
        img = np.expand_dims(img, axis=0).astype(np.float32)

        res = self.session.run(None, {'input': img})[0]

        # (1, 3, H, W) float32 -> HWC uint8 in [0, 255]
        res = (res.squeeze().transpose((1, 2, 0)) * 255).clip(0, 255).astype(np.uint8)

        return res
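
The self.session used above would be an onnxruntime InferenceSession; a minimal standalone sketch of the same pre-/post-processing (assuming the model file name from the conversion script earlier in the thread):

import cv2
import numpy as np
import onnxruntime as ort

# Assumption: 'scunet_color_real_psnr.onnx' is the file exported earlier in this thread.
session = ort.InferenceSession('scunet_color_real_psnr.onnx', providers=['CPUExecutionProvider'])

img = cv2.imread('noisy.png')                       # HWC, uint8
x = img.astype(np.float32).transpose((2, 0, 1)) / 255.0
x = np.expand_dims(x, axis=0).astype(np.float32)    # (1, 3, H, W)
out = session.run(None, {'input': x})[0]
out = (out.squeeze().transpose((1, 2, 0)) * 255).clip(0, 255).astype(np.uint8)
cv2.imwrite('denoised.png', out)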

Maybe this is the way to do it in C++:

#include <opencv2/opencv.hpp>
#include <onnxruntime/core/session/onnxruntime_cxx_api.h>

cv::Mat denoise(Ort::Session& session, const cv::Mat& img) {
    cv::Mat floatImg;
    img.convertTo(floatImg, CV_32F, 1.0 / 255);

    // Convert HWC to CHW format: split the channels and lay them out one after another
    std::vector<cv::Mat> channels(3);
    cv::split(floatImg, channels);
    for (int i = 0; i < 3; i++) {
        channels[i] = channels[i].reshape(1, 1);   // each channel becomes a 1 x (H*W) row
    }

    cv::Mat chwImg;
    cv::hconcat(channels, chwImg);                 // 1 x (3*H*W), CHW order

    // Input tensor shape (1, 3, H, W)
    std::vector<int64_t> inputShape = {1, 3, img.rows, img.cols};
    std::vector<float> inputTensorValues(chwImg.begin<float>(), chwImg.end<float>());

    // Create ONNX Runtime input tensor
    Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
        memoryInfo, inputTensorValues.data(), inputTensorValues.size(),
        inputShape.data(), inputShape.size());

    // Run inference; the export above used the names 'input' and 'output'
    const char* inputNames[] = {"input"};
    const char* outputNames[] = {"output"};
    auto outputTensors = session.Run(Ort::RunOptions{nullptr},
                                     inputNames, &inputTensor, 1,
                                     outputNames, 1);

    // Output tensor data is (1, 3, H, W) in CHW order
    float* outputData = outputTensors.front().GetTensorMutableData<float>();
    const int planeSize = img.rows * img.cols;

    // Convert CHW back to HWC by wrapping each plane and merging
    std::vector<cv::Mat> outputChannels = {
        cv::Mat(img.rows, img.cols, CV_32F, outputData + 0 * planeSize),
        cv::Mat(img.rows, img.cols, CV_32F, outputData + 1 * planeSize),
        cv::Mat(img.rows, img.cols, CV_32F, outputData + 2 * planeSize)
    };

    cv::Mat finalImg;
    cv::merge(outputChannels, finalImg);

    // Scale back to [0, 255] and convert to uint8 (convertTo saturates, i.e. clips)
    finalImg.convertTo(finalImg, CV_8U, 255.0);

    return finalImg;
}

ikmaus commented Feb 20, 2025

Hello, thank you very much! That really was my problem: building the tensors from a cv::Mat in the right way. Now it works as expected. I only have some problems with pixelation at the edges, but that has nothing to do with the model.
