diff --git a/infer.py b/infer.py index 73dd8fc..e7ff8b8 100644 --- a/infer.py +++ b/infer.py @@ -39,7 +39,7 @@ def main(): testloader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=2) - model = config.model.to(device) + model = config.get_model().to(device) weights = glob.glob(os.path.join(args.weights_root, "*.pth.tar"))