Skip to content

Commit

Permalink
Merge pull request #33 from stefanklut/region-baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanklut committed Apr 5, 2024
2 parents 275f411 + 107887d commit 8b624a7
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 8 deletions.
21 changes: 20 additions & 1 deletion page_xml/xmlPAGE.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,29 @@ def iter_baseline_coords(self):
if len(split_str_coords) == 0:
continue
if len(split_str_coords) == 1:
split_str_coords = split_str_coords * 2 # double for polyline
split_str_coords = split_str_coords * 2 # double for cv2.polyline
coords = np.array([i.split(",") for i in split_str_coords]).astype(np.int32)
yield coords

def iter_class_baseline_coords(self, element, class_dict):
    """
    Yield the baselines found below every node of the given element, paired with the class of that node.

    Nodes whose type is None or missing from ``class_dict`` are skipped with a warning.
    Baselines without a "points" attribute, or with a blank one, are skipped silently.

    Args:
        element: element name handed to ``self._iter_element`` to select the nodes to search
        class_dict: mapping from the type returned by ``self.get_region_type`` to a class value

    Yields:
        tuple: the class value looked up for the enclosing node and an ``np.int32`` array
            holding one row per whitespace-separated point of the baseline
    """
    for region_node in self._iter_element(element):
        region_type = self.get_region_type(region_node)
        is_known = region_type is not None and region_type in class_dict
        if not is_known:
            self.logger.warning(f'Element type "{region_type}" undefined in class dict {self.filepath}')
            continue

        label = class_dict[region_type]
        baseline_path = ".//" + self.base + "Baseline"
        for baseline in region_node.iterfind(baseline_path):
            raw_points = baseline.attrib.get("points")
            # A missing attribute is treated like an empty one: nothing to yield
            points = [] if raw_points is None else raw_points.split()
            if not points:
                continue
            if len(points) == 1:
                # a single point is doubled for cv2.polyline
                points = [points[0], points[0]]
            pairs = [point.split(",") for point in points]
            yield label, np.array(pairs).astype(np.int32)

def iter_text_line_coords(self):
for node in self._iter_element("TextLine"):
coords = self.get_coords(node)
Expand Down
42 changes: 39 additions & 3 deletions page_xml/xml_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import cv2
import numpy as np
from detectron2 import structures
from detectron2.config import configurable
from detectron2.config import CfgNode, configurable

sys.path.append(str(Path(__file__).resolve().parent.joinpath("..")))
from detectron2.config import CfgNode

from page_xml.xml_regions import XMLRegions
from page_xml.xmlPAGE import PageData
from utils.image_utils import save_image_array_to_path
from utils.logging_utils import get_logger_name
from utils.vector_utils import (
point_at_start_or_end_assignment,
Expand Down Expand Up @@ -394,6 +394,24 @@ def build_baseline_sem_seg(self, page: PageData, out_size: tuple[int, int], line
self.logger.warning(f"File {page.filepath} does not contains baseline sem_seg")
return sem_seg

# CLASS BASELINES

def build_class_baseline_sem_seg(self, page: PageData, out_size: tuple[int, int], line_width: int, elements, class_dict):
    """
    Draw the baselines of the requested elements into a mask, using the class of each baseline as its value.

    Args:
        page (PageData): page providing the original size and the class/baseline pairs
        out_size (tuple[int, int]): shape of the initial mask and target size for coordinate scaling
        line_width (int): thickness forwarded to ``self.draw_line``
        elements: element names, each forwarded to ``page.iter_class_baseline_coords``
        class_dict: type-to-class mapping forwarded to ``page.iter_class_baseline_coords``

    Returns:
        the mask after all lines are drawn; it starts as a zero-filled ``np.uint8`` array of
        shape ``out_size`` and is replaced by whatever ``self.draw_line`` returns
    """
    page_size = page.get_size()
    mask = np.zeros(out_size, np.uint8)
    overlap_seen = False

    for element in elements:
        class_baselines = page.iter_class_baseline_coords(element, class_dict)
        for label, baseline in class_baselines:
            scaled = self._scale_coords(baseline, out_size, page_size)
            mask, overlap = self.draw_line(mask, scaled, label, thickness=line_width)
            # Keep the first truthy overlap result; later ones cannot change the outcome
            if not overlap_seen:
                overlap_seen = overlap

    if overlap_seen:
        self.logger.warning(f"File {page.filepath} contains overlapping class baseline sem_seg")
    if not mask.any():
        self.logger.warning(f"File {page.filepath} does not contains class baseline sem_seg")
    return mask

# TOP BOTTOM

def build_top_bottom_sem_seg(self, page: PageData, out_size: tuple[int, int], line_width: int):
Expand Down Expand Up @@ -554,6 +572,15 @@ def to_sem_seg(
line_width=self.xml_regions.line_width,
)
return sem_seg
elif self.xml_regions.mode == "class_baseline":
sem_seg = self.build_class_baseline_sem_seg(
gt_data,
image_shape,
line_width=self.xml_regions.line_width,
elements=set(self.xml_regions.region_types.values()),
class_dict=self.xml_regions.region_classes,
)
return sem_seg
elif self.xml_regions.mode == "top_bottom":
sem_seg = self.build_top_bottom_sem_seg(
gt_data,
Expand Down Expand Up @@ -714,7 +741,16 @@ def to_pano(
merge_regions=args.merge_regions,
region_type=args.region_type,
)
XMLConverter(xml_regions, args.square_lines)
xml_converter = XMLConverter(xml_regions, args.square_lines)

input_path = Path(args.input)
output_path = Path(args.output)

sem_seg = xml_converter.to_sem_seg(
input_path,
original_image_shape=None,
image_shape=None,
)

# save image
save_image_array_to_path(output_path, sem_seg)
7 changes: 3 additions & 4 deletions page_xml/xml_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
region_types (Optional[list[str]], optional): type of region for each region. Defaults to None.
"""
self.mode = mode
if self.mode == "region":
if mode == "region" or mode == "class_baseline":
assert regions is not None

self._regions = []
Expand All @@ -46,7 +46,7 @@ def __init__(
self.regions = regions
self.region_types = region_type
self.merged_regions = merge_regions
else:
if mode != "region":
assert line_width is not None

self._regions = self._build_regions()
Expand Down Expand Up @@ -105,7 +105,6 @@ def get_parser(cls) -> argparse.ArgumentParser:
"-m",
"--mode",
default="region",
choices=["baseline", "region", "start", "end", "separator", "baseline_separator"],
type=str,
help="Output mode",
)
Expand Down Expand Up @@ -227,7 +226,7 @@ def _build_regions(self) -> list[str]:
list[str]: the names of all the classes currently used
"""
remaining_regions = ["background"]
if self.mode == "region":
if self.mode == "region" or self.mode == "class_baseline":
removed_regions = set()
if self.merged_regions is not None:
for values in self.merged_regions.values():
Expand Down

0 comments on commit 8b624a7

Please sign in to comment.