dynamic shape inference with onnx model #145

Open
jonas-doevenspeck opened this issue Jul 24, 2023 · 2 comments
jonas-doevenspeck commented Jul 24, 2023

After exporting the model to ONNX, ONNX Runtime fails to run inference on inputs whose shape differs from the one used for the export.

To reproduce:

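import numpy as np
import onnxruntime
import torch

from models.network_swinir import SwinIR  # path assumed from the SwinIR repo layout

model = SwinIR(
    upscale=4,
    in_chans=3,
    img_size=64,
    window_size=8,
    img_range=1.0,
    depths=[6, 6, 6, 6, 6, 6],
    embed_dim=180,
    num_heads=[6, 6, 6, 6, 6, 6],
    mlp_ratio=2,
    upsampler="pixelshuffle",
    resi_connection="1conv",
)

# load_state_dict() expects a dict, not a file path; the released checkpoints
# store the weights under 'params' (fall back to the raw dict otherwise)
state = torch.load("001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth", map_location="cpu")
model.load_state_dict(state.get("params", state))
model.eval()

# Export with dynamic batch, height and width axes
torch.onnx.export(
    model,
    torch.randn(1, 3, 64, 64),
    "swin_ir.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": [0, 2, 3], "output": [0, 2, 3]},
    opset_version=17,
    verbose=False,
)

# Run with a 72x72 input (still a multiple of window_size=8)
ort_session = onnxruntime.InferenceSession("swin_ir.onnx")
x = np.random.randn(1, 3, 72, 72).astype(np.float32)
ort_inputs = {ort_session.get_inputs()[0].name: x}
ort_outs = ort_session.run(None, ort_inputs)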

Error message:

[E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Reshape node. Name:'/layers.0/residual_group/blocks.1/attn/Reshape_1' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:40 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, onnxruntime::TensorShapeVector&, bool) gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{88,6,64,64}, requested shape:{1,64,6,64,64}

environment: pytorch==1.13.1, onnxruntime==1.15.1, onnx==1.12.0
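The requested shape in this error can be read off the window geometry. Assuming `WindowAttention` reshapes the attention scores to `(B_ // nW, nW, num_heads, N, N)` with `nW` taken from the attention mask, as in the Swin Transformer reference code, `{1,64,6,64,64}` means the mask still describes the 64 windows of the 64x64 export resolution (6 heads, N = 8*8 = 64 tokens per window). The helper below is only an illustrative sketch of that arithmetic; `num_windows` is a hypothetical name, not part of SwinIR.

```python
def num_windows(height, width, window_size):
    """Number of non-overlapping window_size x window_size windows in an H x W feature map."""
    return (height // window_size) * (width // window_size)


window_size = 8
tokens_per_window = window_size * window_size  # N in the attention reshape

export_windows = num_windows(64, 64, window_size)  # windows the exported mask was built for
test_windows = num_windows(72, 72, window_size)    # windows produced by the 72x72 test input

print(tokens_per_window, export_windows, test_windows)  # 64 64 81
```

Any input that yields a window count other than 64 cannot be reshaped against a mask whose window count was fixed at export time.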

@pg-alfredlee

Hi @jonas-doevenspeck, have you had any luck with ONNX model inference since?

SherryXieYuchen commented Jun 3, 2024

Hi, I fixed it by editing network_swinir.py:
1. In window_reverse(), replace
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x

with
B = windows.shape[0] / (H * W / window_size / window_size)
x = windows.view(B.numel(), H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B.numel(), H, W, -1)
return x
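Note that in eager PyTorch `windows.shape[0]` is a plain `int`, so `B` above is a Python `float` with no `.numel()`; and if `B` is a 0-dim tensor (as shape values may be while tracing), `.numel()` returns 1 rather than the batch size. Below is a minimal sketch of an alternative that lets `view()` infer the batch dimension with `-1`. It is not the repository's code, and whether it keeps the batch axis dynamic in the exported graph is an assumption to verify. The `window_partition` helper mirrors the inverse operation so the round trip can be checked.

```python
import torch


def window_partition(x, window_size):
    """Split (B, H, W, C) into (num_windows * B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def window_reverse(windows, window_size, H, W):
    """Merge (num_windows * B, window_size, window_size, C) back into (B, H, W, C).

    The batch size is inferred by view(-1, ...) instead of being computed in Python.
    """
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)


x = torch.randn(2, 72, 72, 180)
windows = window_partition(x, 8)
print(windows.shape)  # torch.Size([162, 8, 8, 180])
print(torch.equal(window_reverse(windows, 8, 72, 72), x))  # True
```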

2. In SwinTransformerBlock.forward(), replace
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
if self.input_resolution == x_size:
    attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
else:
    attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

with
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
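Why the second change matters: `torch.onnx.export` traces the model, and a Python-level `if` on shapes is evaluated only once during tracing, so whichever branch ran for the 64x64 dummy input (here the cached `self.attn_mask`) is frozen into the graph. The toy module below (hypothetical, not part of SwinIR) sketches the same effect with `torch.jit.trace`.

```python
import torch


class ShapeBranch(torch.nn.Module):
    """Toy module whose output depends on a Python-level shape check."""

    def forward(self, x):
        if x.shape[-1] == 4:  # evaluated once while tracing, then frozen
            return x * 0
        return x


module = ShapeBranch()
traced = torch.jit.trace(module, torch.ones(1, 4))  # emits a TracerWarning

x = torch.ones(1, 5)
print(module(x).sum().item())  # 5.0: eager mode takes the `return x` branch
print(traced(x).sum().item())  # 0.0: the traced graph still multiplies by zero
```

Calling `calculate_mask(x_size)` unconditionally removes that frozen branch, so the mask is intended to follow the size of the input being traced rather than a buffer built for the export resolution.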

Then re-run the PyTorch-to-ONNX export; the exported model should then accept dynamic input sizes at inference time. Hope it helps.
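A minimal sketch for checking the re-exported model at several resolutions. The helper name `check_dynamic_shapes`, the list of sizes and the single-output assumption are illustrative; `session` is expected to be an `onnxruntime.InferenceSession` created from the re-exported `swin_ir.onnx`.

```python
import numpy as np


def check_dynamic_shapes(session, sizes, upscale=4, channels=3):
    """Run `session` on random inputs of each (H, W) in `sizes`.

    Returns {(H, W): True/False}, True when the single output has the
    expected upscaled shape (1, channels, H * upscale, W * upscale).
    """
    input_name = session.get_inputs()[0].name
    results = {}
    for h, w in sizes:
        x = np.random.randn(1, channels, h, w).astype(np.float32)
        (out,) = session.run(None, {input_name: x})  # assumes exactly one output
        results[(h, w)] = tuple(out.shape) == (1, channels, h * upscale, w * upscale)
    return results


# Usage (requires the re-exported model):
# import onnxruntime
# session = onnxruntime.InferenceSession("swin_ir.onnx")
# print(check_dynamic_shapes(session, [(64, 64), (72, 72), (64, 88)]))
```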
