Running main_test.py alone resulted in an error #17

Open
shushushulian opened this issue Dec 19, 2024 · 0 comments

I configured the following code with the same settings as the `test_single.sh` script, but running it fails with an error saying that the parameter sizes in the loaded checkpoint do not match the model.
```python
from gln.common.cmd_args import cmd_args
import argparse
import os
import random
import numpy as np
import torch
from gln.test.model_inference import RetroGLN

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Seed every RNG with the seed from cmd_args
random.seed(cmd_args.seed)
np.random.seed(cmd_args.seed)
torch.manual_seed(cmd_args.seed)

cmd_opt = argparse.ArgumentParser(description='Argparser for test only')
cmd_opt.add_argument('-model_for_test', default=None, help='model for test')
# Set the test arguments directly instead of parsing them from the command line
local_args = argparse.Namespace(
    model_for_test="dropbox/schneider50k.ckpt",
)

cmd_args.dropbox = "GLN-master/dropbox"
cmd_args.data_name = 'schneider50k'
# RetroGLN loads <model_for_test>/model.dump (see the traceback below)
model = RetroGLN(cmd_args.dropbox, local_args.model_for_test)
```

**Error content**

```
~/Code/Retro/mutilstep/GLN-master/gln/test/model_inference.py in __init__(self, dropbox, model_dump)
     37         model_file = os.path.join(model_dump, 'model.dump')
     38         self.gln = GraphPath(self.args)
---> 39         self.gln.load_state_dict(torch.load(model_file))
     40         self.gln.cuda()
     41         self.gln.eval()

~/.conda/envs/GLN/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1666         if len(error_msgs) > 0:
   1667             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1668                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1669         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1670 

RuntimeError: Error(s) in loading state_dict for GraphPath:
size mismatch for tpl_fwd_predicate.prod_enc.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
size mismatch for tpl_fwd_predicate.tpl_enc.prod_gnn.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
size mismatch for tpl_fwd_predicate.tpl_enc.react_gnn.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
size mismatch for prod_center_predicate.prod_enc.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
size mismatch for prod_center_predicate.prod_center_enc.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
size mismatch for reaction_predicate.prod_enc.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
size mismatch for reaction_predicate.react_enc.react_gnn.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
```
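One way to narrow this down is to compare the parameter shapes stored in the checkpoint with those of the freshly built `GraphPath` before calling `load_state_dict`. The sketch below is a hypothetical diagnostic helper (`compare_state_shapes` and `in_feature_dims` are not part of GLN or PyTorch); it works on plain `{name: shape}` dicts so it runs without torch, and the example data is copied from two entries of the error message.

```python
from typing import Dict, List, Set, Tuple

Shape = Tuple[int, ...]
Mismatch = Tuple[str, Shape, Shape]


def compare_state_shapes(
    ckpt_shapes: Dict[str, Shape], model_shapes: Dict[str, Shape]
) -> Tuple[List[str], List[str], List[Mismatch]]:
    """Hypothetical helper: diff two {param_name: shape} dicts.

    Returns (missing, unexpected, mismatched), where
      missing    = keys the model has but the checkpoint lacks,
      unexpected = keys the checkpoint has but the model lacks,
      mismatched = (name, checkpoint_shape, model_shape) for shared keys
                   whose shapes differ.
    """
    missing = sorted(model_shapes.keys() - ckpt_shapes.keys())
    unexpected = sorted(ckpt_shapes.keys() - model_shapes.keys())
    mismatched = [
        (name, ckpt_shapes[name], model_shapes[name])
        for name in sorted(ckpt_shapes.keys() & model_shapes.keys())
        if ckpt_shapes[name] != model_shapes[name]
    ]
    return missing, unexpected, mismatched


def in_feature_dims(mismatched: List[Mismatch]) -> Set[Tuple[int, int]]:
    """Collect (checkpoint_in, model_in) pairs for mismatched 2-D weights.

    Assumes the weights follow nn.Linear's (out_features, in_features)
    layout, so index 1 is the input dimension.
    """
    return {
        (old[1], new[1])
        for _, old, new in mismatched
        if len(old) == 2 and len(new) == 2
    }


if __name__ == "__main__":
    # Shapes copied from two entries of the RuntimeError above.
    ckpt = {
        "tpl_fwd_predicate.prod_enc.w_n2l.weight": (128, 39),
        "reaction_predicate.prod_enc.w_n2l.weight": (128, 39),
    }
    model = {
        "tpl_fwd_predicate.prod_enc.w_n2l.weight": (128, 66),
        "reaction_predicate.prod_enc.w_n2l.weight": (128, 66),
    }
    missing, unexpected, mismatched = compare_state_shapes(ckpt, model)
    for name, old, new in mismatched:
        print(f"{name}: checkpoint {old} vs model {new}")
    print("input-dim pairs:", in_feature_dims(mismatched))
```

To feed it real data, `{k: tuple(v.shape) for k, v in torch.load(model_file, map_location="cpu").items()}` gives the checkpoint shapes and `{k: tuple(v.shape) for k, v in self.gln.state_dict().items()}` gives the model shapes. Assuming each `w_n2l` is an `nn.Linear` that maps node (atom) features into the 128-dimensional latent space, every mismatch in the error is in the input dimension: the checkpoint expects 39 node features, while the current code builds 66. That suggests the node-featurization settings differ from those used to train the checkpoint, rather than the hidden size, although this is an inference from the shapes and not a confirmed cause.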