RuntimeError: Mismatched state_dict keys when loading custom ResNet model into EasyOCR due to inconsistent layer naming #1343

Open
nrosto opened this issue Dec 4, 2024 · 4 comments

nrosto commented Dec 4, 2024

I am experiencing a RuntimeError while trying to load my custom ResNet model into EasyOCR. The error indicates that there are missing keys and unexpected keys in the state_dict during model loading.

RuntimeError: Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: "module.FeatureExtraction.ConvNet.0.weight", "module.FeatureExtraction.ConvNet.1.weight", "module.FeatureExtraction.ConvNet.1.bias", "module.FeatureExtraction.ConvNet.1.running_mean", "module.FeatureExtraction.ConvNet.1.running_var", "module.FeatureExtraction.ConvNet.3.weight", "module.FeatureExtraction.ConvNet.4.weight", "module.FeatureExtraction.ConvNet.4.bias", "module.FeatureExtraction.ConvNet.4.running_mean", "module.FeatureExtraction.ConvNet.4.running_var", "module.FeatureExtraction.ConvNet.7.0.conv1.weight", "module.FeatureExtraction.ConvNet.7.0.bn1.weight", "module.FeatureExtraction.ConvNet.7.0.bn1.bias", "module.FeatureExtraction.ConvNet.7.0.bn1.running_mean", "module.FeatureExtraction.ConvNet.7.0.bn1.running_var", "module.FeatureExtraction.ConvNet.7.0.conv2.weight", "module.FeatureExtraction.ConvNet.7.0.bn2.weight", "module.FeatureExtraction.ConvNet.7.0.bn2.bias", "module.FeatureExtraction.ConvNet.7.0.bn2.running_mean", "module.FeatureExtraction.ConvNet.7.0.bn2.running_var", "module.FeatureExtraction.ConvNet.7.0.downsample.0.weight", "module.FeatureExtraction.ConvNet.7.0.downsample.1.weight", "module.FeatureExtraction.ConvNet.7.0.downsample.1.bias", "module.FeatureExtraction.ConvNet.7.0.downsample.1.running_mean", "module.FeatureExtraction.ConvNet.7.0.downsample.1.running_var", "module.FeatureExtraction.ConvNet.8.weight", "module.FeatureExtraction.ConvNet.9.weight", "module.FeatureExtraction.ConvNet.9.bias", "module.FeatureExtraction.ConvNet.9.running_mean", "module.FeatureExtraction.ConvNet.9.running_var", "module.FeatureExtraction.ConvNet.12.0.conv1.weight", "module.FeatureExtraction.ConvNet.12.0.bn1.weight", "module.FeatureExtraction.ConvNet.12.0.bn1.bias", "module.FeatureExtraction.ConvNet.12.0.bn1.running_mean", "module.FeatureExtraction.ConvNet.12.0.bn1.running_var", "module.FeatureExtraction.ConvNet.12.0.conv2.weight", 
"module.FeatureExtraction.ConvNet.12.0.bn2.weight", "module.FeatureExtraction.ConvNet.12.0.bn2.bias", "module.FeatureExtraction.ConvNet.12.0.bn2.running_mean", "module.FeatureExtraction.ConvNet.12.0.bn2.running_var", "module.FeatureExtraction.ConvNet.12.0.downsample.0.weight", "module.FeatureExtraction.ConvNet.12.0.downsample.1.weight", "module.FeatureExtraction.ConvNet.12.0.downsample.1.bias", "module.FeatureExtraction.ConvNet.12.0.downsample.1.running_mean", "module.FeatureExtraction.ConvNet.12.0.downsample.1.running_var", "module.FeatureExtraction.ConvNet.12.1.conv1.weight", "module.FeatureExtraction.ConvNet.12.1.bn1.weight", "module.FeatureExtraction.ConvNet.12.1.bn1.bias", "module.FeatureExtraction.ConvNet.12.1.bn1.running_mean", "module.FeatureExtraction.ConvNet.12.1.bn1.running_var", "module.FeatureExtraction.ConvNet.12.1.conv2.weight", "module.FeatureExtraction.ConvNet.12.1.bn2.weight", "module.FeatureExtraction.ConvNet.12.1.bn2.bias", "module.FeatureExtraction.ConvNet.12.1.bn2.running_mean", "module.FeatureExtraction.ConvNet.12.1.bn2.running_var", "module.FeatureExtraction.ConvNet.13.weight", "module.FeatureExtraction.ConvNet.14.weight", "module.FeatureExtraction.ConvNet.14.bias", "module.FeatureExtraction.ConvNet.14.running_mean", "module.FeatureExtraction.ConvNet.14.running_var", "module.FeatureExtraction.ConvNet.17.0.conv1.weight", "module.FeatureExtraction.ConvNet.17.0.bn1.weight", "module.FeatureExtraction.ConvNet.17.0.bn1.bias", "module.FeatureExtraction.ConvNet.17.0.bn1.running_mean", "module.FeatureExtraction.ConvNet.17.0.bn1.running_var", "module.FeatureExtraction.ConvNet.17.0.conv2.weight", "module.FeatureExtraction.ConvNet.17.0.bn2.weight", "module.FeatureExtraction.ConvNet.17.0.bn2.bias", "module.FeatureExtraction.ConvNet.17.0.bn2.running_mean", "module.FeatureExtraction.ConvNet.17.0.bn2.running_var", "module.FeatureExtraction.ConvNet.17.0.downsample.0.weight", "module.FeatureExtraction.ConvNet.17.0.downsample.1.weight", 
"module.FeatureExtraction.ConvNet.17.0.downsample.1.bias", "module.FeatureExtraction.ConvNet.17.0.downsample.1.running_mean", "module.FeatureExtraction.ConvNet.17.0.downsample.1.running_var", "module.FeatureExtraction.ConvNet.17.1.conv1.weight", "module.FeatureExtraction.ConvNet.17.1.bn1.weight", "module.FeatureExtraction.ConvNet.17.1.bn1.bias", "module.FeatureExtraction.ConvNet.17.1.bn1.running_mean", "module.FeatureExtraction.ConvNet.17.1.bn1.running_var", "module.FeatureExtraction.ConvNet.17.1.conv2.weight", "module.FeatureExtraction.ConvNet.17.1.bn2.weight", "module.FeatureExtraction.ConvNet.17.1.bn2.bias", "module.FeatureExtraction.ConvNet.17.1.bn2.running_mean", "module.FeatureExtraction.ConvNet.17.1.bn2.running_var", "module.FeatureExtraction.ConvNet.17.2.conv1.weight", "module.FeatureExtraction.ConvNet.17.2.bn1.weight", "module.FeatureExtraction.ConvNet.17.2.bn1.bias", "module.FeatureExtraction.ConvNet.17.2.bn1.running_mean", "module.FeatureExtraction.ConvNet.17.2.bn1.running_var", "module.FeatureExtraction.ConvNet.17.2.conv2.weight", "module.FeatureExtraction.ConvNet.17.2.bn2.weight", "module.FeatureExtraction.ConvNet.17.2.bn2.bias", "module.FeatureExtraction.ConvNet.17.2.bn2.running_mean", "module.FeatureExtraction.ConvNet.17.2.bn2.running_var", "module.FeatureExtraction.ConvNet.17.3.conv1.weight", "module.FeatureExtraction.ConvNet.17.3.bn1.weight", "module.FeatureExtraction.ConvNet.17.3.bn1.bias", "module.FeatureExtraction.ConvNet.17.3.bn1.running_mean", "module.FeatureExtraction.ConvNet.17.3.bn1.running_var", "module.FeatureExtraction.ConvNet.17.3.conv2.weight", "module.FeatureExtraction.ConvNet.17.3.bn2.weight", "module.FeatureExtraction.ConvNet.17.3.bn2.bias", "module.FeatureExtraction.ConvNet.17.3.bn2.running_mean", "module.FeatureExtraction.ConvNet.17.3.bn2.running_var", "module.FeatureExtraction.ConvNet.17.4.conv1.weight", "module.FeatureExtraction.ConvNet.17.4.bn1.weight", "module.FeatureExtraction.ConvNet.17.4.bn1.bias", 
"module.FeatureExtraction.ConvNet.17.4.bn1.running_mean", "module.FeatureExtraction.ConvNet.17.4.bn1.running_var", "module.FeatureExtraction.ConvNet.17.4.conv2.weight", "module.FeatureExtraction.ConvNet.17.4.bn2.weight", "module.FeatureExtraction.ConvNet.17.4.bn2.bias", "module.FeatureExtraction.ConvNet.17.4.bn2.running_mean", "module.FeatureExtraction.ConvNet.17.4.bn2.running_var", "module.FeatureExtraction.ConvNet.18.weight", "module.FeatureExtraction.ConvNet.19.weight", "module.FeatureExtraction.ConvNet.19.bias", "module.FeatureExtraction.ConvNet.19.running_mean", "module.FeatureExtraction.ConvNet.19.running_var", "module.FeatureExtraction.ConvNet.21.0.conv1.weight", "module.FeatureExtraction.ConvNet.21.0.bn1.weight", "module.FeatureExtraction.ConvNet.21.0.bn1.bias", "module.FeatureExtraction.ConvNet.21.0.bn1.running_mean", "module.FeatureExtraction.ConvNet.21.0.bn1.running_var", "module.FeatureExtraction.ConvNet.21.0.conv2.weight", "module.FeatureExtraction.ConvNet.21.0.bn2.weight", "module.FeatureExtraction.ConvNet.21.0.bn2.bias", "module.FeatureExtraction.ConvNet.21.0.bn2.running_mean", "module.FeatureExtraction.ConvNet.21.0.bn2.running_var", "module.FeatureExtraction.ConvNet.21.1.conv1.weight", "module.FeatureExtraction.ConvNet.21.1.bn1.weight", "module.FeatureExtraction.ConvNet.21.1.bn1.bias", "module.FeatureExtraction.ConvNet.21.1.bn1.running_mean", "module.FeatureExtraction.ConvNet.21.1.bn1.running_var", "module.FeatureExtraction.ConvNet.21.1.conv2.weight", "module.FeatureExtraction.ConvNet.21.1.bn2.weight", "module.FeatureExtraction.ConvNet.21.1.bn2.bias", "module.FeatureExtraction.ConvNet.21.1.bn2.running_mean", "module.FeatureExtraction.ConvNet.21.1.bn2.running_var", "module.FeatureExtraction.ConvNet.21.2.conv1.weight", "module.FeatureExtraction.ConvNet.21.2.bn1.weight", "module.FeatureExtraction.ConvNet.21.2.bn1.bias", "module.FeatureExtraction.ConvNet.21.2.bn1.running_mean", "module.FeatureExtraction.ConvNet.21.2.bn1.running_var", 
"module.FeatureExtraction.ConvNet.21.2.conv2.weight", "module.FeatureExtraction.ConvNet.21.2.bn2.weight", "module.FeatureExtraction.ConvNet.21.2.bn2.bias", "module.FeatureExtraction.ConvNet.21.2.bn2.running_mean", "module.FeatureExtraction.ConvNet.21.2.bn2.running_var", "module.FeatureExtraction.ConvNet.22.weight", "module.FeatureExtraction.ConvNet.23.weight", "module.FeatureExtraction.ConvNet.23.bias", "module.FeatureExtraction.ConvNet.23.running_mean", "module.FeatureExtraction.ConvNet.23.running_var", "module.FeatureExtraction.ConvNet.25.weight", "module.FeatureExtraction.ConvNet.26.weight", "module.FeatureExtraction.ConvNet.26.bias", "module.FeatureExtraction.ConvNet.26.running_mean", "module.FeatureExtraction.ConvNet.26.running_var".

Unexpected key(s) in state_dict: "module.FeatureExtraction.ConvNet.conv0_1.weight", "module.FeatureExtraction.ConvNet.bn0_1.weight", "module.FeatureExtraction.ConvNet.bn0_1.bias", "module.FeatureExtraction.ConvNet.bn0_1.running_mean", "module.FeatureExtraction.ConvNet.bn0_1.running_var", "module.FeatureExtraction.ConvNet.bn0_1.num_batches_tracked", "module.FeatureExtraction.ConvNet.conv0_2.weight", "module.FeatureExtraction.ConvNet.bn0_2.weight", "module.FeatureExtraction.ConvNet.bn0_2.bias", "module.FeatureExtraction.ConvNet.bn0_2.running_mean", "module.FeatureExtraction.ConvNet.bn0_2.running_var", "module.FeatureExtraction.ConvNet.bn0_2.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer1.0.conv1.weight", "module.FeatureExtraction.ConvNet.layer1.0.bn1.weight", "module.FeatureExtraction.ConvNet.layer1.0.bn1.bias", "module.FeatureExtraction.ConvNet.layer1.0.bn1.running_mean", "module.FeatureExtraction.ConvNet.layer1.0.bn1.running_var", "module.FeatureExtraction.ConvNet.layer1.0.bn1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer1.0.conv2.weight", "module.FeatureExtraction.ConvNet.layer1.0.bn2.weight", "module.FeatureExtraction.ConvNet.layer1.0.bn2.bias", "module.FeatureExtraction.ConvNet.layer1.0.bn2.running_mean", "module.FeatureExtraction.ConvNet.layer1.0.bn2.running_var", "module.FeatureExtraction.ConvNet.layer1.0.bn2.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer1.0.downsample.0.weight", "module.FeatureExtraction.ConvNet.layer1.0.downsample.1.weight", "module.FeatureExtraction.ConvNet.layer1.0.downsample.1.bias", "module.FeatureExtraction.ConvNet.layer1.0.downsample.1.running_mean", "module.FeatureExtraction.ConvNet.layer1.0.downsample.1.running_var", "module.FeatureExtraction.ConvNet.layer1.0.downsample.1.num_batches_tracked", "module.FeatureExtraction.ConvNet.conv1.weight", "module.FeatureExtraction.ConvNet.bn1.weight", "module.FeatureExtraction.ConvNet.bn1.bias", "module.FeatureExtraction.ConvNet.bn1.running_mean", 
"module.FeatureExtraction.ConvNet.bn1.running_var", "module.FeatureExtraction.ConvNet.bn1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer2.0.conv1.weight", "module.FeatureExtraction.ConvNet.layer2.0.bn1.weight", "module.FeatureExtraction.ConvNet.layer2.0.bn1.bias", "module.FeatureExtraction.ConvNet.layer2.0.bn1.running_mean", "module.FeatureExtraction.ConvNet.layer2.0.bn1.running_var", "module.FeatureExtraction.ConvNet.layer2.0.bn1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer2.0.conv2.weight", "module.FeatureExtraction.ConvNet.layer2.0.bn2.weight", "module.FeatureExtraction.ConvNet.layer2.0.bn2.bias", "module.FeatureExtraction.ConvNet.layer2.0.bn2.running_mean", "module.FeatureExtraction.ConvNet.layer2.0.bn2.running_var", "module.FeatureExtraction.ConvNet.layer2.0.bn2.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer2.0.downsample.0.weight", "module.FeatureExtraction.ConvNet.layer2.0.downsample.1.weight", "module.FeatureExtraction.ConvNet.layer2.0.downsample.1.bias", "module.FeatureExtraction.ConvNet.layer2.0.downsample.1.running_mean", "module.FeatureExtraction.ConvNet.layer2.0.downsample.1.running_var", "module.FeatureExtraction.ConvNet.layer2.0.downsample.1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer2.1.conv1.weight", "module.FeatureExtraction.ConvNet.layer2.1.bn1.weight", "module.FeatureExtraction.ConvNet.layer2.1.bn1.bias", "module.FeatureExtraction.ConvNet.layer2.1.bn1.running_mean", "module.FeatureExtraction.ConvNet.layer2.1.bn1.running_var", "module.FeatureExtraction.ConvNet.layer2.1.bn1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer2.1.conv2.weight", "module.FeatureExtraction.ConvNet.layer2.1.bn2.weight", "module.FeatureExtraction.ConvNet.layer2.1.bn2.bias", "module.FeatureExtraction.ConvNet.layer2.1.bn2.running_mean", "module.FeatureExtraction.ConvNet.layer2.1.bn2.running_var", "module.FeatureExtraction.ConvNet.layer2.1.bn2.num_batches_tracked", 
"module.FeatureExtraction.ConvNet.conv2.weight", "module.FeatureExtraction.ConvNet.bn2.weight", "module.FeatureExtraction.ConvNet.bn2.bias", "module.FeatureExtraction.ConvNet.bn2.running_mean", "module.FeatureExtraction.ConvNet.bn2.running_var", "module.FeatureExtraction.ConvNet.bn2.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer3.0.conv1.weight", "module.FeatureExtraction.ConvNet.layer3.0.bn1.weight", "module.FeatureExtraction.ConvNet.layer3.0.bn1.bias", "module.FeatureExtraction.ConvNet.layer3.0.bn1.running_mean", "module.FeatureExtraction.ConvNet.layer3.0.bn1.running_var", "module.FeatureExtraction.ConvNet.layer3.0.bn1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer3.0.conv2.weight", "module.FeatureExtraction.ConvNet.layer3.0.bn2.weight", "module.FeatureExtraction.ConvNet.layer3.0.bn2.bias", "module.FeatureExtraction.ConvNet.layer3.0.bn2.running_mean", "module.FeatureExtraction.ConvNet.layer3.0.bn2.running_var", "module.FeatureExtraction.ConvNet.layer3.0.bn2.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer3.0.downsample.0.weight", "module.FeatureExtraction.ConvNet.layer3.0.downsample.1.weight", "module.FeatureExtraction.ConvNet.layer3.0.downsample.1.bias", "module.FeatureExtraction.ConvNet.layer3.0.downsample.1.running_mean", "module.FeatureExtraction.ConvNet.layer3.0.downsample.1.running_var", "module.FeatureExtraction.ConvNet.layer3.0.downsample.1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer3.1.conv1.weight", "module.FeatureExtraction.ConvNet.layer3.1.bn1.weight", "module.FeatureExtraction.ConvNet.layer3.1.bn1.bias", "module.FeatureExtraction.ConvNet.layer3.1.bn1.running_mean", "module.FeatureExtraction.ConvNet.layer3.1.bn1.running_var", "module.FeatureExtraction.ConvNet.layer3.1.bn1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer3.1.conv2.weight", "module.FeatureExtraction.ConvNet.layer3.1.bn2.weight", "module.FeatureExtraction.ConvNet.layer3.1.bn2.bias", 
"module.FeatureExtraction.ConvNet.layer3.1.bn2.running_mean", "module.FeatureExtraction.ConvNet.layer3.1.bn2.running_var", "module.FeatureExtraction.ConvNet.layer3.1.bn2.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer3.2.conv1.weight", "module.FeatureExtraction.ConvNet.layer3.2.bn1.weight", "module.FeatureExtraction.ConvNet.layer3.2.bn1.bias", "module.FeatureExtraction.ConvNet.layer3.2.bn1.running_mean", "module.FeatureExtraction.ConvNet.layer3.2.bn1.running_var", "module.FeatureExtraction.ConvNet.layer3.2.bn1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer3.2.conv2.weight", "module.FeatureExtraction.ConvNet.layer3.2.bn2.weight", "module.FeatureExtraction.ConvNet.layer3.2.bn2.bias", "module.FeatureExtraction.ConvNet.layer3.2.bn2.running_mean", "module.FeatureExtraction.ConvNet.layer3.2.bn2.running_var", "module.FeatureExtraction.ConvNet.layer3.2.bn2.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer3.3.conv1.weight", "module.FeatureExtraction.ConvNet.layer3.3.bn1.weight", "module.FeatureExtraction.ConvNet.layer3.3.bn1.bias", "module.FeatureExtraction.ConvNet.layer3.3.bn1.running_mean", "module.FeatureExtraction.ConvNet.layer3.3.bn1.running_var", "module.FeatureExtraction.ConvNet.layer3.3.bn1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer3.3.conv2.weight", "module.FeatureExtraction.ConvNet.layer3.3.bn2.weight", "module.FeatureExtraction.ConvNet.layer3.3.bn2.bias", "module.FeatureExtraction.ConvNet.layer3.3.bn2.running_mean", "module.FeatureExtraction.ConvNet.layer3.3.bn2.running_var", "module.FeatureExtraction.ConvNet.layer3.3.bn2.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer3.4.conv1.weight", "module.FeatureExtraction.ConvNet.layer3.4.bn1.weight", "module.FeatureExtraction.ConvNet.layer3.4.bn1.bias", "module.FeatureExtraction.ConvNet.layer3.4.bn1.running_mean", "module.FeatureExtraction.ConvNet.layer3.4.bn1.running_var", "module.FeatureExtraction.ConvNet.layer3.4.bn1.num_batches_tracked", 
"module.FeatureExtraction.ConvNet.layer3.4.conv2.weight", "module.FeatureExtraction.ConvNet.layer3.4.bn2.weight", "module.FeatureExtraction.ConvNet.layer3.4.bn2.bias", "module.FeatureExtraction.ConvNet.layer3.4.bn2.running_mean", "module.FeatureExtraction.ConvNet.layer3.4.bn2.running_var", "module.FeatureExtraction.ConvNet.layer3.4.bn2.num_batches_tracked", "module.FeatureExtraction.ConvNet.conv3.weight", "module.FeatureExtraction.ConvNet.bn3.weight", "module.FeatureExtraction.ConvNet.bn3.bias", "module.FeatureExtraction.ConvNet.bn3.running_mean", "module.FeatureExtraction.ConvNet.bn3.running_var", "module.FeatureExtraction.ConvNet.bn3.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer4.0.conv1.weight", "module.FeatureExtraction.ConvNet.layer4.0.bn1.weight", "module.FeatureExtraction.ConvNet.layer4.0.bn1.bias", "module.FeatureExtraction.ConvNet.layer4.0.bn1.running_mean", "module.FeatureExtraction.ConvNet.layer4.0.bn1.running_var", "module.FeatureExtraction.ConvNet.layer4.0.bn1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer4.0.conv2.weight", "module.FeatureExtraction.ConvNet.layer4.0.bn2.weight", "module.FeatureExtraction.ConvNet.layer4.0.bn2.bias", "module.FeatureExtraction.ConvNet.layer4.0.bn2.running_mean", "module.FeatureExtraction.ConvNet.layer4.0.bn2.running_var", "module.FeatureExtraction.ConvNet.layer4.0.bn2.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer4.1.conv1.weight", "module.FeatureExtraction.ConvNet.layer4.1.bn1.weight", "module.FeatureExtraction.ConvNet.layer4.1.bn1.bias", "module.FeatureExtraction.ConvNet.layer4.1.bn1.running_mean", "module.FeatureExtraction.ConvNet.layer4.1.bn1.running_var", "module.FeatureExtraction.ConvNet.layer4.1.bn1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer4.1.conv2.weight", "module.FeatureExtraction.ConvNet.layer4.1.bn2.weight", "module.FeatureExtraction.ConvNet.layer4.1.bn2.bias", "module.FeatureExtraction.ConvNet.layer4.1.bn2.running_mean", 
"module.FeatureExtraction.ConvNet.layer4.1.bn2.running_var", "module.FeatureExtraction.ConvNet.layer4.1.bn2.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer4.2.conv1.weight", "module.FeatureExtraction.ConvNet.layer4.2.bn1.weight", "module.FeatureExtraction.ConvNet.layer4.2.bn1.bias", "module.FeatureExtraction.ConvNet.layer4.2.bn1.running_mean", "module.FeatureExtraction.ConvNet.layer4.2.bn1.running_var", "module.FeatureExtraction.ConvNet.layer4.2.bn1.num_batches_tracked", "module.FeatureExtraction.ConvNet.layer4.2.conv2.weight", "module.FeatureExtraction.ConvNet.layer4.2.bn2.weight", "module.FeatureExtraction.ConvNet.layer4.2.bn2.bias", "module.FeatureExtraction.ConvNet.layer4.2.bn2.running_mean", "module.FeatureExtraction.ConvNet.layer4.2.bn2.running_var", "module.FeatureExtraction.ConvNet.layer4.2.bn2.num_batches_tracked", "module.FeatureExtraction.ConvNet.conv4_1.weight", "module.FeatureExtraction.ConvNet.bn4_1.weight", "module.FeatureExtraction.ConvNet.bn4_1.bias", "module.FeatureExtraction.ConvNet.bn4_1.running_mean", "module.FeatureExtraction.ConvNet.bn4_1.running_var", "module.FeatureExtraction.ConvNet.bn4_1.num_batches_tracked", "module.FeatureExtraction.ConvNet.conv4_2.weight", "module.FeatureExtraction.ConvNet.bn4_2.weight", "module.FeatureExtraction.ConvNet.bn4_2.bias", "module.FeatureExtraction.ConvNet.bn4_2.running_mean", "module.FeatureExtraction.ConvNet.bn4_2.running_var", "module.FeatureExtraction.ConvNet.bn4_2.num_batches_tracked".

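One way to make the two lists above easier to compare is to reduce every key to the name of its first sub-module under ConvNet. The sketch below is only an illustration: `convnet_children` is a hypothetical helper (not part of EasyOCR or PyTorch), and it works on plain lists of key strings such as those returned by `state_dict().keys()`.

```python
def convnet_children(keys, prefix="module.FeatureExtraction.ConvNet."):
    """Return the distinct first-level sub-module names under `prefix`, in order."""
    seen = []
    for key in keys:
        if not key.startswith(prefix):
            continue  # ignore keys outside the feature extractor
        child = key[len(prefix):].split(".", 1)[0]
        if child not in seen:
            seen.append(child)
    return seen


# A few keys copied from the error message above.
missing = [
    "module.FeatureExtraction.ConvNet.0.weight",
    "module.FeatureExtraction.ConvNet.1.weight",
    "module.FeatureExtraction.ConvNet.7.0.conv1.weight",
]
unexpected = [
    "module.FeatureExtraction.ConvNet.conv0_1.weight",
    "module.FeatureExtraction.ConvNet.bn0_1.weight",
    "module.FeatureExtraction.ConvNet.layer1.0.conv1.weight",
]

print(convnet_children(missing))     # ['0', '1', '7']
print(convnet_children(unexpected))  # ['conv0_1', 'bn0_1', 'layer1']
```

Run on the full lists, this suggests that the network built at load time addresses its layers by position (as in an nn.Sequential), while the checkpoint uses attribute names.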
The architecture of my custom model:
Transformation: 'None'
FeatureExtraction: 'ResNet'
SequenceModeling: 'BiLSTM'
Prediction: 'CTC'

For training, I also use model.py and feature_extraction.py.

my_custom_model.py:
import torch.nn as nn


class ResNet_FeatureExtractor(nn.Module):
    """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """

    def __init__(self, input_channel, output_channel=512):
        super(ResNet_FeatureExtractor, self).__init__()
        # The attribute name "ConvNet" becomes part of every state_dict key.
        self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3])

    def forward(self, input):
        return self.ConvNet(input)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = self._conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        "3x3 convolution with padding"
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=1, bias=False)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, input_channel, output_channel, block, layers):
        super(ResNet, self).__init__()

        self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]

        self.inplanes = int(output_channel / 8)
        self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
        self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_2 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
        self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[0],
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])

        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
        self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[1],
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])

        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
        self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
        self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[2],
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])

        self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
        self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[3],
                                 kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
        self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
        self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[3],
                                 kernel_size=2, stride=1, padding=0, bias=False)
        self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv0_1(x)
        x = self.bn0_1(x)
        x = self.relu(x)
        x = self.conv0_2(x)
        x = self.bn0_2(x)
        x = self.relu(x)

        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.maxpool2(x)
        x = self.layer2(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.maxpool3(x)
        x = self.layer3(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.layer4(x)
        x = self.conv4_1(x)
        x = self.bn4_1(x)
        x = self.relu(x)
        x = self.conv4_2(x)
        x = self.bn4_2(x)
        x = self.relu(x)

        return x

class Model(nn.Module):

    def __init__(self, input_channel, output_channel, hidden_size, num_class):
        super(Model, self).__init__()

        self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
        self.FeatureExtraction_output = output_channel
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

        # BidirectionalLSTM is not shown in this issue; it has to be defined
        # or imported in this file as well.
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
        self.SequenceModeling_output = hidden_size

        self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)

    def forward(self, input, text):
        # (batch, channel, height, width) -> (batch, width, channel, height),
        # then average the height away so that width becomes the sequence axis.
        visual_feature = self.FeatureExtraction(input)
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))
        visual_feature = visual_feature.squeeze(3)

        contextual_feature = self.SequenceModeling(visual_feature)

        prediction = self.Prediction(contextual_feature.contiguous())

        return prediction
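
To sanity-check the architecture above, the spatial size of the feature map can be worked out by hand. The sketch below only tracks the layers that change height or width (the three max-pool layers plus conv4_1 and conv4_2), because every 3x3 convolution uses stride 1 and padding 1 and so keeps the size. The helper names and the input sizes 32x100 and 64x600 are assumptions for illustration, not values taken from this issue.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a conv/pool layer along one axis (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet_output_hw(height, width):
    """Height and width of the feature map produced by ResNet.forward above."""
    h, w = height, width
    h, w = conv_out(h, 2, 2, 0), conv_out(w, 2, 2, 0)  # maxpool1
    h, w = conv_out(h, 2, 2, 0), conv_out(w, 2, 2, 0)  # maxpool2
    h, w = conv_out(h, 2, 2, 0), conv_out(w, 2, 1, 1)  # maxpool3: stride (2, 1), padding (0, 1)
    h, w = conv_out(h, 2, 2, 0), conv_out(w, 2, 1, 1)  # conv4_1: stride (2, 1), padding (0, 1)
    h, w = conv_out(h, 2, 1, 0), conv_out(w, 2, 1, 0)  # conv4_2: kernel 2, no padding
    return h, w


print(resnet_output_hw(32, 100))  # (1, 26)
print(resnet_output_hw(64, 600))  # (3, 151)
```

The width that comes out is the sequence length seen by the BiLSTM, since Model.forward averages the remaining height away with AdaptiveAvgPool2d((None, 1)).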
@romanvelichkin

Are you having problems while loading the model or during training?

When you load a model, you pass the path to the folder that holds your models. That folder has to contain a 'model' folder with the model weights ('my_model_name.pth') and a 'user_network' folder with the model structure (classes etc.) in 'my_model_name.py', plus a 'my_model_name.yaml' file describing the model (characters etc.).

So it looks like this:

  • my_models
    • model
      • my_model_name.pth
    • user_network
      • my_model_name.py
      • my_model_name.yaml

You can download the default models and see for yourself what the folder and file structure looks like for them.
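
The layout described above can be checked with a few lines of standard-library Python before starting EasyOCR. This is a minimal sketch: `missing_model_files` is a hypothetical helper, and 'my_models' and 'my_model_name' are the placeholder names used in this comment.

```python
from pathlib import Path


def missing_model_files(root, name):
    """Return the expected files (relative to `root`) that do not exist."""
    root = Path(root)
    expected = [
        root / "model" / f"{name}.pth",          # trained weights
        root / "user_network" / f"{name}.py",    # network definition
        root / "user_network" / f"{name}.yaml",  # model description
    ]
    return [p.relative_to(root).as_posix() for p in expected if not p.is_file()]


if __name__ == "__main__":
    problems = missing_model_files("my_models", "my_model_name")
    print("missing:", problems if problems else "nothing")
```

An empty list only means the files are in place. It does not tell you whether the classes in the .py file match the keys stored in the .pth file.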

nrosto commented Dec 20, 2024

@romanvelichkin
The issues occur after training, when I try to load the model.
Yes, the structure is exactly as described.

The thing is, with a VGG model of the same architecture, everything works fine.

But when I try to train a ResNet model, it doesn't, and this error appears.

romanvelichkin commented Dec 22, 2024

Try renaming the layers.

Here is how I do it:

import torch

# Load the checkpoint that was saved from a DataParallel-wrapped model.
my_model_dict = torch.load('my_model.pth')

# Remove 'module.' from the layer names.
keys_dict = {}

for key in my_model_dict:
    keys_dict[key] = key.replace('module.', '')

for old_key, new_key in keys_dict.items():
    if old_key in my_model_dict:
        my_model_dict[new_key] = my_model_dict.pop(old_key)

# Create your model here. I trained mine for CRAFT (the detection part),
# so CRAFT has to be imported from wherever the class is defined.
model = CRAFT()
model.load_state_dict(my_model_dict)
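
The same renaming idea could be extended to the ConvNet names from the original error. The sketch below is an assumption-heavy illustration: the table NAME_TO_INDEX was read off by pairing the "unexpected" keys with the "missing" keys in the order they appear in the error message, and `remap_key` / `remap_state_dict` are hypothetical helpers. It only renames keys; the tensor shapes on both sides still have to match, and making the network definition used at load time identical to the one used for training may be the cleaner fix.

```python
# Checkpoint attribute name -> positional index expected by the loading model
# (inferred from the error message in this issue, not from EasyOCR's source).
NAME_TO_INDEX = {
    "conv0_1": "0", "bn0_1": "1", "conv0_2": "3", "bn0_2": "4",
    "layer1": "7", "conv1": "8", "bn1": "9",
    "layer2": "12", "conv2": "13", "bn2": "14",
    "layer3": "17", "conv3": "18", "bn3": "19",
    "layer4": "21", "conv4_1": "22", "bn4_1": "23",
    "conv4_2": "25", "bn4_2": "26",
}
PREFIX = "module.FeatureExtraction.ConvNet."


def remap_key(key):
    """Rename only the first sub-module name after PREFIX; leave the rest as is."""
    if not key.startswith(PREFIX):
        return key
    child, _, rest = key[len(PREFIX):].partition(".")
    new_child = NAME_TO_INDEX.get(child, child)
    return PREFIX + new_child + ("." + rest if rest else "")


def remap_state_dict(state_dict, drop_num_batches_tracked=False):
    """Return a new dict with remapped keys; values are passed through untouched."""
    return {
        remap_key(key): value
        for key, value in state_dict.items()
        if not (drop_num_batches_tracked and key.endswith("num_batches_tracked"))
    }


print(remap_key("module.FeatureExtraction.ConvNet.layer3.4.bn2.running_var"))
# module.FeatureExtraction.ConvNet.17.4.bn2.running_var
```

Note that bn1 inside a block (for example layer1.0.bn1) is left alone, because only the first name after the prefix is looked up. The drop_num_batches_tracked switch is there because the missing-key list in the error contains no num_batches_tracked entries; whether dropping them is needed depends on the model being loaded into.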

@dasantosa

Hi @nrosto! I am having the same problem after training. After running the train.py script, I want to export the model to ONNX, but I get this error when I try to load my trained model. Did you find a solution?
