From 4c33697e41ca033d3f74917ca0cea48ff021ee39 Mon Sep 17 00:00:00 2001 From: Martijn Maas Date: Wed, 13 Dec 2023 16:03:55 +0100 Subject: [PATCH] Add option to make regions rectangles --- configs/extra_defaults.py | 1 + .../region/region_dataset_photo.yaml | 1 + page_xml/output_pageXML.py | 15 +- test/test_output_pageXML.py | 143 ++++++++++++++++++ 4 files changed, 158 insertions(+), 2 deletions(-) diff --git a/configs/extra_defaults.py b/configs/extra_defaults.py index 0c43b7d..343eabd 100644 --- a/configs/extra_defaults.py +++ b/configs/extra_defaults.py @@ -36,6 +36,7 @@ _C.PREPROCESS.REGION.REGIONS = [] _C.PREPROCESS.REGION.MERGE_REGIONS = [] _C.PREPROCESS.REGION.REGION_TYPE = [] +_c.PREPROCESS.REGION.RECTANGLE_REGIONS = [] _C.PREPROCESS.BASELINE = CN() _C.PREPROCESS.BASELINE.LINE_WIDTH = 5 diff --git a/configs/segmentation/region/region_dataset_photo.yaml b/configs/segmentation/region/region_dataset_photo.yaml index c6c4fbb..76bb5ff 100644 --- a/configs/segmentation/region/region_dataset_photo.yaml +++ b/configs/segmentation/region/region_dataset_photo.yaml @@ -9,6 +9,7 @@ PREPROCESS: REGIONS: ["Photo"] MERGE_REGIONS: [] REGION_TYPE: ["ImageRegion:Photo"] + RECTANGLE_REGIONS: ["Photo"] MODEL: MODE: region diff --git a/page_xml/output_pageXML.py b/page_xml/output_pageXML.py index c5347c1..6f7f2aa 100644 --- a/page_xml/output_pageXML.py +++ b/page_xml/output_pageXML.py @@ -56,6 +56,7 @@ def __init__( region_type: Optional[list[str]] = None, cfg: Optional[CfgNode] = None, whitelist: Optional[Iterable[str]] = None, + rectangle_regions: Optional[list[str]] = [], ) -> None: """ Class for the generation of the pageXML from class predictions on images @@ -71,6 +72,7 @@ def __init__( """ super().__init__(mode, line_width, regions, merge_regions, region_type) + self.rectangle_regions = rectangle_regions self.logger = logging.getLogger(get_logger_name()) self.output_dir = None @@ -157,6 +159,7 @@ def generate_single_page( page = PageData(xml_output_path) page.new_page(image_path.name, str(old_height), str(old_width)) + if self.cfg is not None: page.add_processing_step(get_git_hash(), self.cfg.LAYPA_UUID, self.cfg, self.whitelist) @@ -192,8 +195,16 @@ def generate_single_page( approx_poly = np.round((approx_poly * scaling)).astype(np.int32) region_coords = "" - for coords in approx_poly.reshape(-1, 2): - region_coords = region_coords + f" {coords[0]},{coords[1]}" + if region in self.rectangle_regions: + # find bounding box + rect = cv2.minAreaRect(approx_poly) + rect = cv2.boxPoints(rect) + for coords in rect: + region_coords = region_coords + f" {round(coords[0])},{round(coords[1])}" + else: + for coords in approx_poly.reshape(-1, 2): + region_coords = region_coords + f" {coords[0]},{coords[1]}" + region_coords = region_coords.strip() _uuid = uuid.uuid4() diff --git a/test/test_output_pageXML.py b/test/test_output_pageXML.py index 7c2f0a5..448d251 100644 --- a/test/test_output_pageXML.py +++ b/test/test_output_pageXML.py @@ -115,6 +115,149 @@ def test_region_not_square(self): self.assertEqual(1, len(image_coords_elements)) self.assertEqual("8,4 4,8 8,12 10,12 14,8 10,4", image_coords_elements[0].attrib.get("points")) + def test_rectangle_region_does_cotains_4_points(self): + output = tempfile.mktemp("_laypa_test") + xml = OutputPageXML( + "region", + output, + 5, + ["Photo"], + [], + ["ImageRegion:Photo"], + None, + [], + ["Photo"] + + ) + background = np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0, 0, 1, 1, 1], + [1, 1, 0, 0, 0, 0, 0, 0, 1, 1], + [1, 1, 1, 0, 0, 0, 0, 1, 1, 1], + [1, 1, 1, 1, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) + image = np.invert(background == 1) * 1 + array = np.array([background, image]) + tensor = torch.from_numpy(array) + + xml.generate_single_page(tensor, Path("/tmp/test.png"), 20, 20) + + page_path = path.join(output, "page", "test.xml") + self.assertTrue(path.exists(page_path), "Page file does not exist") + page = ET.parse(page_path) + namespaces = {"page": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"} + image_coords_elements = page.findall("./page:Page/page:ImageRegion/page:Coords", namespaces=namespaces) + self.assertEqual(1, len(image_coords_elements)) + coord_points = image_coords_elements[0].attrib.get("points") + self.assertEqual(4, coord_points.count(","), f"Contains more then 4 points: '{coord_points}'") + + def test_rectangle_region_does_create_floating_point_coords(self): + output = tempfile.mktemp("_laypa_test") + xml = OutputPageXML( + "region", + output, + 5, + ["Photo"], + [], + ["ImageRegion:Photo"], + None, + [], + ["Photo"] + + ) + background = np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0, 0, 1, 1, 1], + [1, 1, 0, 0, 0, 0, 0, 0, 1, 1], + [1, 1, 1, 0, 0, 0, 0, 1, 1, 1], + [1, 1, 1, 1, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) + image = np.invert(background == 1) * 1 + array = np.array([background, image]) + tensor = torch.from_numpy(array) + + xml.generate_single_page(tensor, Path("/tmp/test.png"), 20, 20) + + page_path = path.join(output, "page", "test.xml") + self.assertTrue(path.exists(page_path), "Page file does not exist") + page = ET.parse(page_path) + namespaces = {"page": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"} + image_coords_elements = page.findall("./page:Page/page:ImageRegion/page:Coords", namespaces=namespaces) + self.assertEqual(1, len(image_coords_elements)) + coord_points = image_coords_elements[0].attrib.get("points") + self.assertEqual(0, coord_points.count("."), f"Probably contains floating points: '{coord_points}'") + + def test_only_rectangle_region_one_type(self): + output = tempfile.mktemp("_laypa_test") + xml = OutputPageXML( + "region", + output, + 5, + ["Photo", "Text"], + [], + ["ImageRegion:Photo", "TextRegion:Text"], + None, + [], + ["Photo"] + + ) + background = np.array([[1, 1, 0, 0, 1, 1, 1, 1, 1, 1], + [1, 0, 0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 0, 1, 1, 1, 1, 1], + [1, 1, 0, 0, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 0, 0, 1, 1], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 1], + [1, 1, 1, 1, 1, 1, 0, 0, 1, 1]]) + + image = np.array([[0, 0, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) + + text = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0, 0]]) + array = np.array([background, image, text]) + tensor = torch.from_numpy(array) + + xml.generate_single_page(tensor, Path("/tmp/test.png"), 20, 20) + + page_path = path.join(output, "page", "test.xml") + self.assertTrue(path.exists(page_path), "Page file does not exist") + page = ET.parse(page_path) + namespaces = {"page": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"} + image_coords_elements = page.findall("./page:Page/page:ImageRegion/page:Coords", namespaces=namespaces) + self.assertEqual(1, len(image_coords_elements)) + image_coord_points = image_coords_elements[0].attrib.get("points") + self.assertEqual(4, image_coord_points.count(","), + f"ImageRegion Contains more then 4 points: '{image_coord_points}'") + text_coords_elements = page.findall("./page:Page/page:TextRegion/page:Coords", namespaces=namespaces) + self.assertEqual(1, len(text_coords_elements)) + text_coord_points = text_coords_elements[0].attrib.get("points") + self.assertEqual(6, text_coord_points.count(","), f"TextRegion less than 6 points: '{text_coord_points}'") + if __name__ == "__main__": unittest.main()