From d4643f3a2137e90e60b8eba0d5bb06bb25552841 Mon Sep 17 00:00:00 2001 From: Guglielmo Camporese Date: Fri, 8 Jul 2022 10:32:04 +0200 Subject: [PATCH] fix predict --- main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index a6d2fa1..4b4d69d 100644 --- a/main.py +++ b/main.py @@ -158,10 +158,13 @@ def main(args): _ = model.eval() device = next(model.parameters()).device for x, x_path in tqdm(ds, desc='Save predictions'): + H, W = x.shape[-2:] + x = transforms.Resize((256, 256))(x) x = x.unsqueeze(0).to(device) logits = model(x).detach().cpu() preds = F.softmax(logits, 1).argmax(1)[0] * 255 # [h, w] - preds = Image.fromarray(preds.numpy().astype(np.uint8), 'P') + preds = Image.fromarray(preds.numpy().astype(np.uint8), 'L') + preds = preds.resize((W, H)) preds.save(f'{x_path}.png') else: