diff --git a/eval.py b/eval.py index 84f1b78..f1318c5 100644 --- a/eval.py +++ b/eval.py @@ -124,7 +124,11 @@ def create_gt_visualization(image_filename, sem_seg_filename): def create_pred_visualization(image_filename): image = load_image(image_filename) logger.info(f"Predict: {image_filename}") - outputs = predictor(image) + if cfg.INPUT.FORMAT == "BGR": + input_image = image[..., ::-1] + else: + input_image = image + outputs = predictor(input_image) pred = torch.argmax(outputs[0]["sem_seg"], dim=-3).to("cpu") # outputs["panoptic_seg"] = (outputs["panoptic_seg"][0].to("cpu"), # outputs["panoptic_seg"][1])