Skip to content

Commit

Permalink
Add sem_seg_to_classes_and_confidence and confidence_heatmap methods
Browse files
  • Loading branch information
stefanklut committed Jan 10, 2024
1 parent 640f0b6 commit 975d7df
Showing 1 changed file with 52 additions and 17 deletions.
69 changes: 52 additions & 17 deletions page_xml/output_pageXML.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,46 @@ def scale_to_range(

return tensor

@staticmethod
def sem_seg_to_classes_and_confidence(
sem_seg: torch.Tensor,
height: Optional[int] = None,
width: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Convert a single prediction into classes and confidence
Args:
sem_seg (torch.Tensor): sem_seg as tensor
Returns:
torch.Tensor, torch.Tensor: classes and confidence
"""
if height is not None and width is not None:
sem_seg = torch.nn.functional.interpolate(
sem_seg[None], size=(height, width), mode="bilinear", align_corners=False
)[0]

sem_seg_normalized = torch.nn.functional.softmax(sem_seg, dim=-3)
confidence, sem_seg_classes = torch.max(sem_seg_normalized, dim=-3)

return sem_seg_classes, confidence

@staticmethod
def confidence_heatmap(confidence: torch.Tensor):
    """
    Render a confidence tensor as a colored heatmap

    Args:
        confidence (torch.Tensor): confidence values; they are multiplied by 255 and cast
            to uint8, so values are presumably expected in [0, 1] (not checked here)

    Returns:
        np.ndarray: the PLASMA color-mapped image with its last (channel) axis reversed
    """
    # Bring the values to 8-bit range on the CPU so OpenCV can color them
    scaled = confidence * 255
    grayscale = scaled.cpu().numpy().astype(np.uint8)

    colored = cv2.applyColorMap(grayscale, cv2.COLORMAP_PLASMA)

    # OpenCV returns BGR channel order; reversing the last axis yields RGB
    return colored[..., ::-1]

def generate_single_page(
self,
sem_seg: torch.Tensor,
Expand Down Expand Up @@ -196,15 +236,14 @@ def generate_single_page(

if self.xml_regions.mode == "region":
confidence_output_path = self.page_dir.joinpath(image_path.stem + "_confidence.png")
sem_seg_normalized = torch.nn.functional.softmax(sem_seg, dim=-3)
confidence, sem_seg_classes = torch.max(sem_seg_normalized, dim=-3)
sem_seg_classes, confidence = self.sem_seg_to_classes_and_confidence(sem_seg)

# Apply a color map
scaled_confidence = self.scale_to_range(confidence, tensor_min=1 / len(self.regions), tensor_max=1.0)
confidence_grayscale = (scaled_confidence * 255).cpu().numpy().astype(np.uint8)
confidence_colored = cv2.applyColorMap(confidence_grayscale, cv2.COLORMAP_PLASMA)[..., ::-1]

# Apply a color map
confidence_heatmap = self.confidence_heatmap(confidence)
with AtomicFileName(file_path=confidence_output_path) as path:
save_image_array_to_path(str(path), confidence_colored)
save_image_array_to_path(str(path), confidence_heatmap)

sem_seg_classes = sem_seg_classes.cpu().numpy()
mean_confidence = torch.mean(scaled_confidence).cpu().numpy()
Expand Down Expand Up @@ -256,16 +295,13 @@ def generate_single_page(
elif self.xml_regions.mode in ["baseline", "start", "end", "separator"]:
sem_seg_output_path = self.page_dir.joinpath(image_path.stem + ".png")
confidence_output_path = self.page_dir.joinpath(image_path.stem + "_confidence.png")

sem_seg_normalized = torch.nn.functional.softmax(sem_seg, dim=-3)
confidence, sem_seg_classes = torch.max(sem_seg_normalized, dim=-3)
sem_seg_classes, confidence = self.sem_seg_to_classes_and_confidence(sem_seg, old_height, old_width)

# Apply a color map
scaled_confidence = self.scale_to_range(confidence, tensor_min=1 / len(self.xml_regions.regions), tensor_max=1.0)
confidence_grayscale = (scaled_confidence * 255).cpu().numpy().astype(np.uint8)
confidence_colored = cv2.applyColorMap(confidence_grayscale, cv2.COLORMAP_PLASMA)[..., ::-1]
confidence_heatmap = self.confidence_heatmap(scaled_confidence)
with AtomicFileName(file_path=confidence_output_path) as path:
save_image_array_to_path(str(path), confidence_colored)
save_image_array_to_path(str(path), confidence_heatmap)

sem_seg_classes = sem_seg_classes.cpu().numpy()
mean_confidence = torch.mean(scaled_confidence).cpu().numpy()
Expand All @@ -278,15 +314,14 @@ def generate_single_page(
elif self.xml_regions.mode in ["baseline_separator", "top_bottom"]:
sem_seg_output_path = self.page_dir.joinpath(image_path.stem + ".png")
confidence_output_path = self.page_dir.joinpath(image_path.stem + "_confidence.png")
sem_seg_normalized = torch.nn.functional.softmax(sem_seg, dim=-3)
confidence, sem_seg_classes = torch.max(sem_seg_normalized, dim=-3)

sem_seg_classes, confidence = self.sem_seg_to_classes_and_confidence(sem_seg, old_height, old_width)

# Apply a color map
scaled_confidence = self.scale_to_range(confidence, tensor_min=1 / len(self.xml_regions.regions), tensor_max=1.0)
confidence_grayscale = (scaled_confidence * 255).cpu().numpy().astype(np.uint8)
confidence_colored = cv2.applyColorMap(confidence_grayscale, cv2.COLORMAP_PLASMA)[..., ::-1]
confidence_heatmap = self.confidence_heatmap(scaled_confidence)
with AtomicFileName(file_path=confidence_output_path) as path:
save_image_array_to_path(str(path), confidence_colored)
save_image_array_to_path(str(path), confidence_heatmap)

sem_seg_classes = sem_seg_classes.cpu().numpy()
mean_confidence = torch.mean(scaled_confidence).cpu().numpy()
Expand Down

0 comments on commit 975d7df

Please sign in to comment.