Skip to content

Commit

Permalink
Add option to make regions rectangles
Browse files Browse the repository at this point in the history
  • Loading branch information
MMaas3 committed Dec 13, 2023
1 parent d07d123 commit 4c33697
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 2 deletions.
1 change: 1 addition & 0 deletions configs/extra_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
_C.PREPROCESS.REGION.REGIONS = []
_C.PREPROCESS.REGION.MERGE_REGIONS = []
_C.PREPROCESS.REGION.REGION_TYPE = []
_c.PREPROCESS.REGION.RECTANGLE_REGIONS = []

_C.PREPROCESS.BASELINE = CN()
_C.PREPROCESS.BASELINE.LINE_WIDTH = 5
Expand Down
1 change: 1 addition & 0 deletions configs/segmentation/region/region_dataset_photo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ PREPROCESS:
REGIONS: ["Photo"]
MERGE_REGIONS: []
REGION_TYPE: ["ImageRegion:Photo"]
RECTANGLE_REGIONS: ["Photo"]

MODEL:
MODE: region
Expand Down
15 changes: 13 additions & 2 deletions page_xml/output_pageXML.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
region_type: Optional[list[str]] = None,
cfg: Optional[CfgNode] = None,
whitelist: Optional[Iterable[str]] = None,
rectangle_regions: Optional[list[str]] = [],
) -> None:
"""
Class for the generation of the pageXML from class predictions on images
Expand All @@ -71,6 +72,7 @@ def __init__(
"""
super().__init__(mode, line_width, regions, merge_regions, region_type)

self.rectangle_regions = rectangle_regions
self.logger = logging.getLogger(get_logger_name())

self.output_dir = None
Expand Down Expand Up @@ -157,6 +159,7 @@ def generate_single_page(

page = PageData(xml_output_path)
page.new_page(image_path.name, str(old_height), str(old_width))

if self.cfg is not None:
page.add_processing_step(get_git_hash(), self.cfg.LAYPA_UUID, self.cfg, self.whitelist)

Expand Down Expand Up @@ -192,8 +195,16 @@ def generate_single_page(
approx_poly = np.round((approx_poly * scaling)).astype(np.int32)

region_coords = ""
for coords in approx_poly.reshape(-1, 2):
region_coords = region_coords + f" {coords[0]},{coords[1]}"
if region in self.rectangle_regions:
# find bounding box
rect = cv2.minAreaRect(approx_poly)
rect = cv2.boxPoints(rect)
for coords in rect:
region_coords = region_coords + f" {round(coords[0])},{round(coords[1])}"
else:
for coords in approx_poly.reshape(-1, 2):
region_coords = region_coords + f" {coords[0]},{coords[1]}"

region_coords = region_coords.strip()

_uuid = uuid.uuid4()
Expand Down
143 changes: 143 additions & 0 deletions test/test_output_pageXML.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,149 @@ def test_region_not_square(self):
self.assertEqual(1, len(image_coords_elements))
self.assertEqual("8,4 4,8 8,12 10,12 14,8 10,4", image_coords_elements[0].attrib.get("points"))

def test_rectangle_region_does_cotains_4_points(self):
output = tempfile.mktemp("_laypa_test")
xml = OutputPageXML(
"region",
output,
5,
["Photo"],
[],
["ImageRegion:Photo"],
None,
[],
["Photo"]

)
background = np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 1, 1, 1, 1],
[1, 1, 1, 0, 0, 0, 0, 1, 1, 1],
[1, 1, 0, 0, 0, 0, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
image = np.invert(background == 1) * 1
array = np.array([background, image])
tensor = torch.from_numpy(array)

xml.generate_single_page(tensor, Path("/tmp/test.png"), 20, 20)

page_path = path.join(output, "page", "test.xml")
self.assertTrue(path.exists(page_path), "Page file does not exist")
page = ET.parse(page_path)
namespaces = {"page": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"}
image_coords_elements = page.findall("./page:Page/page:ImageRegion/page:Coords", namespaces=namespaces)
self.assertEqual(1, len(image_coords_elements))
coord_points = image_coords_elements[0].attrib.get("points")
self.assertEqual(4, coord_points.count(","), f"Contains more then 4 points: '{coord_points}'")

def test_rectangle_region_does_create_floating_point_coords(self):
output = tempfile.mktemp("_laypa_test")
xml = OutputPageXML(
"region",
output,
5,
["Photo"],
[],
["ImageRegion:Photo"],
None,
[],
["Photo"]

)
background = np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 1, 1, 1, 1],
[1, 1, 1, 0, 0, 0, 0, 1, 1, 1],
[1, 1, 0, 0, 0, 0, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
image = np.invert(background == 1) * 1
array = np.array([background, image])
tensor = torch.from_numpy(array)

xml.generate_single_page(tensor, Path("/tmp/test.png"), 20, 20)

page_path = path.join(output, "page", "test.xml")
self.assertTrue(path.exists(page_path), "Page file does not exist")
page = ET.parse(page_path)
namespaces = {"page": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"}
image_coords_elements = page.findall("./page:Page/page:ImageRegion/page:Coords", namespaces=namespaces)
self.assertEqual(1, len(image_coords_elements))
coord_points = image_coords_elements[0].attrib.get("points")
self.assertEqual(0, coord_points.count("."), f"Probably contains floating points: '{coord_points}'")

def test_only_rectangle_region_one_type(self):
output = tempfile.mktemp("_laypa_test")
xml = OutputPageXML(
"region",
output,
5,
["Photo", "Text"],
[],
["ImageRegion:Photo", "TextRegion:Text"],
None,
[],
["Photo"]

)
background = np.array([[1, 1, 0, 0, 1, 1, 1, 1, 1, 1],
[1, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[1, 1, 0, 0, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 0, 0, 1, 1],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 1],
[1, 1, 1, 1, 1, 1, 0, 0, 1, 1]])

image = np.array([[0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[0, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

text = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 1, 1, 0, 0]])
array = np.array([background, image, text])
tensor = torch.from_numpy(array)

xml.generate_single_page(tensor, Path("/tmp/test.png"), 20, 20)

page_path = path.join(output, "page", "test.xml")
self.assertTrue(path.exists(page_path), "Page file does not exist")
page = ET.parse(page_path)
namespaces = {"page": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"}
image_coords_elements = page.findall("./page:Page/page:ImageRegion/page:Coords", namespaces=namespaces)
self.assertEqual(1, len(image_coords_elements))
image_coord_points = image_coords_elements[0].attrib.get("points")
self.assertEqual(4, image_coord_points.count(","),
f"ImageRegion Contains more then 4 points: '{image_coord_points}'")
text_coords_elements = page.findall("./page:Page/page:TextRegion/page:Coords", namespaces=namespaces)
self.assertEqual(1, len(text_coords_elements))
text_coord_points = text_coords_elements[0].attrib.get("points")
self.assertEqual(6, text_coord_points.count(","), f"TextRegion less than 6 points: '{text_coord_points}'")


if __name__ == "__main__":
unittest.main()

0 comments on commit 4c33697

Please sign in to comment.