diff --git a/page_xml/xmlPAGE.py b/page_xml/xmlPAGE.py index 179d560..f7ddd15 100644 --- a/page_xml/xmlPAGE.py +++ b/page_xml/xmlPAGE.py @@ -182,10 +182,29 @@ def iter_baseline_coords(self): if len(split_str_coords) == 0: continue if len(split_str_coords) == 1: - split_str_coords = split_str_coords * 2 # double for polyline + split_str_coords = split_str_coords * 2 # double for cv2.polyline coords = np.array([i.split(",") for i in split_str_coords]).astype(np.int32) yield coords + def iter_class_baseline_coords(self, element, class_dict): + for class_node in self._iter_element(element): + element_type = self.get_region_type(class_node) + if element_type is None or element_type not in class_dict: + self.logger.warning(f'Element type "{element_type}" undefined in class dict {self.filepath}') + continue + element_class = class_dict[element_type] + for baseline_node in class_node.iterfind("".join([".//", self.base, "Baseline"])): + str_coords = baseline_node.attrib.get("points") + if str_coords is None: + continue + split_str_coords = str_coords.split() + if len(split_str_coords) == 0: + continue + if len(split_str_coords) == 1: + split_str_coords = split_str_coords * 2 # double for cv2.polyline + coords = np.array([i.split(",") for i in split_str_coords]).astype(np.int32) + yield element_class, coords + def iter_text_line_coords(self): for node in self._iter_element("TextLine"): coords = self.get_coords(node) diff --git a/page_xml/xml_converter.py b/page_xml/xml_converter.py index bbe7dd2..66b63fa 100644 --- a/page_xml/xml_converter.py +++ b/page_xml/xml_converter.py @@ -7,13 +7,13 @@ import cv2 import numpy as np from detectron2 import structures -from detectron2.config import configurable +from detectron2.config import CfgNode, configurable sys.path.append(str(Path(__file__).resolve().parent.joinpath(".."))) -from detectron2.config import CfgNode from page_xml.xml_regions import XMLRegions from page_xml.xmlPAGE import PageData +from utils.image_utils import save_image_array_to_path from utils.logging_utils import get_logger_name from utils.vector_utils import ( point_at_start_or_end_assignment, @@ -394,6 +394,24 @@ def build_baseline_sem_seg(self, page: PageData, out_size: tuple[int, int], line self.logger.warning(f"File {page.filepath} does not contains baseline sem_seg") return sem_seg + # CLASS BASELINES + + def build_class_baseline_sem_seg(self, page: PageData, out_size: tuple[int, int], line_width: int, elements, class_dict): + size = page.get_size() + sem_seg = np.zeros(out_size, np.uint8) + total_overlap = False + for element in elements: + for element_class, baseline_coords in page.iter_class_baseline_coords(element, class_dict): + coords = self._scale_coords(baseline_coords, out_size, size) + sem_seg, overlap = self.draw_line(sem_seg, coords, element_class, thickness=line_width) + total_overlap = total_overlap or overlap + + if total_overlap: + self.logger.warning(f"File {page.filepath} contains overlapping class baseline sem_seg") + if not sem_seg.any(): + self.logger.warning(f"File {page.filepath} does not contains class baseline sem_seg") + return sem_seg + # TOP BOTTOM def build_top_bottom_sem_seg(self, page: PageData, out_size: tuple[int, int], line_width: int): @@ -554,6 +572,15 @@ def to_sem_seg( line_width=self.xml_regions.line_width, ) return sem_seg + elif self.xml_regions.mode == "class_baseline": + sem_seg = self.build_class_baseline_sem_seg( + gt_data, + image_shape, + line_width=self.xml_regions.line_width, + elements=set(self.xml_regions.region_types.values()), + class_dict=self.xml_regions.region_classes, + ) + return sem_seg elif self.xml_regions.mode == "top_bottom": sem_seg = self.build_top_bottom_sem_seg( gt_data, @@ -714,7 +741,16 @@ def to_pano( merge_regions=args.merge_regions, region_type=args.region_type, ) - XMLConverter(xml_regions, args.square_lines) + xml_converter = XMLConverter(xml_regions, args.square_lines) input_path = Path(args.input) output_path = Path(args.output) + + sem_seg = xml_converter.to_sem_seg( + input_path, + original_image_shape=None, + image_shape=None, + ) + + # save image + save_image_array_to_path(output_path, sem_seg) diff --git a/page_xml/xml_regions.py b/page_xml/xml_regions.py index 351cb60..e13f329 100644 --- a/page_xml/xml_regions.py +++ b/page_xml/xml_regions.py @@ -29,7 +29,7 @@ def __init__( region_types (Optional[list[str]], optional): type of region for each region. Defaults to None. """ self.mode = mode - if self.mode == "region": + if mode == "region" or mode == "class_baseline": assert regions is not None self._regions = [] @@ -46,7 +46,7 @@ def __init__( self.regions = regions self.region_types = region_type self.merged_regions = merge_regions - else: + if mode != "region": assert line_width is not None self._regions = self._build_regions() @@ -105,7 +105,6 @@ def get_parser(cls) -> argparse.ArgumentParser: "-m", "--mode", default="region", - choices=["baseline", "region", "start", "end", "separator", "baseline_separator"], type=str, help="Output mode", ) @@ -227,7 +226,7 @@ def _build_regions(self) -> list[str]: list[str]: the names of all the classes currently used """ remaining_regions = ["background"] - if self.mode == "region": + if self.mode == "region" or self.mode == "class_baseline": removed_regions = set() if self.merged_regions is not None: for values in self.merged_regions.values():