Skip to content

Commit

Permalink
fix predict
Browse files Browse the repository at this point in the history
  • Loading branch information
Guglielmo Camporese committed Jul 8, 2022
1 parent 761f54a commit d4643f3
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,13 @@ def main(args):
_ = model.eval()
device = next(model.parameters()).device
for x, x_path in tqdm(ds, desc='Save predictions'):
H, W = x.shape[-2:]
x = transforms.Resize((256, 256))(x)
x = x.unsqueeze(0).to(device)
logits = model(x).detach().cpu()
preds = F.softmax(logits, 1).argmax(1)[0] * 255 # [h, w]
preds = Image.fromarray(preds.numpy().astype(np.uint8), 'P')
preds = Image.fromarray(preds.numpy().astype(np.uint8), 'L')
preds = preds.resize((W, H))
preds.save(f'{x_path}.png')

else:
Expand Down

0 comments on commit d4643f3

Please sign in to comment.