Skip to content

Commit

Permalink
Merge pull request #20 from stefanklut/square-lines
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanklut authored Feb 1, 2024
2 parents 7f55d8c + 4e17ed7 commit 22a9bfd
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 36 deletions.
1 change: 1 addition & 0 deletions configs/extra_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

# Options for how baselines are rendered during preprocessing.
_C.PREPROCESS.BASELINE = CN()
# NOTE(review): presumably the thickness (in pixels) used when drawing baselines; confirm against the consumers of this key.
_C.PREPROCESS.BASELINE.LINE_WIDTH = 5
# Forwarded to XMLConverter as its ``square_lines`` argument. When enabled, drawn baseline pixels
# that the start/end assignment helper flags are left blank instead of colored. Off by default.
_C.PREPROCESS.BASELINE.SQUARE_LINES = False

# Resize settings group.
_C.PREPROCESS.RESIZE = CN()
_C.PREPROCESS.RESIZE.USE = False
Expand Down
2 changes: 1 addition & 1 deletion core/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def preprocess_datasets(
merge_regions=cfg.PREPROCESS.REGION.MERGE_REGIONS,
region_type=cfg.PREPROCESS.REGION.REGION_TYPE,
)
xml_converter = XMLConverter(xml_regions)
xml_converter = XMLConverter(xml_regions, cfg.PREPROCESS.BASELINE.SQUARE_LINES)

assert (n_regions := len(xml_converter.xml_regions.regions)) == (
n_classes := cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
Expand Down
8 changes: 6 additions & 2 deletions datasets/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@

def get_arguments() -> argparse.Namespace:
    """
    Parse the command line arguments for preprocessing a pageXML dataset

    The parser inherits the options of ``Preprocess.get_parser()`` and
    ``XMLRegions.get_parser()`` and adds the IO and XML converter options itself.

    Returns:
        argparse.Namespace: the parsed command line arguments
    """
    # ``parents`` may be given only once; the region options (merge_regions, region_type)
    # read later in main() come from the XMLRegions parser
    parser = argparse.ArgumentParser(
        parents=[Preprocess.get_parser(), XMLRegions.get_parser()],
        description="Preprocessing an annotated dataset of documents with pageXML",
    )

    io_args = parser.add_argument_group("IO")
    io_args.add_argument("-i", "--input", help="Input folder/file", nargs="+", action="extend", type=str)
    io_args.add_argument("-o", "--output", help="Output folder", required=True, type=str)

    # The converter flag is declared here directly instead of being inherited from a parent parser
    xml_converter_args = parser.add_argument_group("XML Converter")
    xml_converter_args.add_argument("--square-lines", help="Square the lines", action="store_true")

    return parser.parse_args()

Expand Down Expand Up @@ -530,7 +534,7 @@ def main(args) -> None:
merge_regions=args.merge_regions,
region_type=args.region_type,
)
xml_converter = XMLConverter(xml_regions)
xml_converter = XMLConverter(xml_regions, args.square_lines)
process = Preprocess(
input_paths=args.input,
output_dir=args.output,
Expand Down
2 changes: 1 addition & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def main(args) -> None:
merge_regions=cfg.PREPROCESS.REGION.MERGE_REGIONS,
region_type=cfg.PREPROCESS.REGION.REGION_TYPE,
)
xml_converter = XMLConverter(xml_regions)
xml_converter = XMLConverter(xml_regions, cfg.PREPROCESS.BASELINE.SQUARE_LINES)
metadata = metadata_from_classes(xml_regions.regions)

image_paths = get_file_paths(args.input, supported_image_formats, cfg.PREPROCESS.DISABLE_CHECK)
Expand Down
96 changes: 73 additions & 23 deletions page_xml/xml_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import numpy as np
from detectron2 import structures

from utils.vector_utils import point_top_bottom_assignment

sys.path.append(str(Path(__file__).resolve().parent.joinpath("..")))
from page_xml.xml_regions import XMLRegions
from page_xml.xmlPAGE import PageData
from utils.logging_utils import get_logger_name
from utils.vector_utils import (
point_at_start_or_end_assignment,
point_top_bottom_assignment,
)


def get_arguments() -> argparse.Namespace:
Expand All @@ -22,6 +24,9 @@ def get_arguments() -> argparse.Namespace:
io_args.add_argument("-i", "--input", help="Input file", required=True, type=str)
io_args.add_argument("-o", "--output", help="Output file", required=True, type=str)

xml_converter_args = parser.add_argument_group("XML Converter")
xml_converter_args.add_argument("--square-lines", help="Square the lines", action="store_true")

args = parser.parse_args()
return args

Expand Down Expand Up @@ -58,6 +63,7 @@ class XMLConverter:
def __init__(
self,
xml_regions: XMLRegions,
square_lines: bool = True,
) -> None:
"""
Class for turning a pageXML into an image with classes
Expand All @@ -67,6 +73,7 @@ def __init__(
"""
self.logger = logging.getLogger(get_logger_name())
self.xml_regions = xml_regions
self.square_lines = square_lines

@staticmethod
def _scale_coords(coords: np.ndarray, out_size: tuple[int, int], size: tuple[int, int]) -> np.ndarray:
Expand All @@ -83,7 +90,7 @@ def _bounding_box(array: np.ndarray) -> list[float]:

# Taken from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
@staticmethod
def id2rgb(id_map: int | np.ndarray) -> tuple | np.ndarray:
def id2rgb(id_map: int | np.ndarray) -> tuple[int, int, int] | np.ndarray:
if isinstance(id_map, np.ndarray):
rgb_shape = tuple(list(id_map.shape) + [3])
rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
Expand All @@ -97,6 +104,43 @@ def id2rgb(id_map: int | np.ndarray) -> tuple | np.ndarray:
id_map //= 256
return tuple(color)

def draw_line(
    self,
    image: np.ndarray,
    coords: np.ndarray,
    color: int | tuple[int, int, int],
    thickness: int = 1,
) -> tuple[np.ndarray, bool]:
    """
    Draw a single polyline onto an image and report whether it overlaps existing content

    The line is rendered on a blank canvas first and then merged into a new array,
    so the array passed in as ``image`` is not modified.

    Args:
        image (np.ndarray): image to draw on, single channel (H, W) or multi channel (H, W, C)
        coords (np.ndarray): points of the line as (x, y) pairs, rounded to int before drawing
        color (int | tuple[int, int, int]): color of the line
        thickness (int, optional): thickness of the line. Defaults to 1.

    Returns:
        tuple[np.ndarray, bool]: the image with the line merged in, and whether the new
            line covers a pixel that was already non-zero in ``image``
    """
    rounded_coords = np.round(coords).astype(np.int32)
    polyline = [rounded_coords.reshape(-1, 1, 2)]

    # np.zeros_like already returns a cleared canvas, no extra fill is needed
    temp_image = np.zeros_like(image)

    if self.square_lines:
        # Rasterize on a 2D mask so the pixel lookup below always yields (x, y) pairs,
        # also when ``image`` carries a channel axis (e.g. the RGB pano mask)
        line_mask = np.zeros(image.shape[:2], dtype=np.uint8)
        cv2.polylines(line_mask, polyline, False, 1, thickness)
        # np.where returns (row, col); reverse the columns to get (x, y)
        line_pixel_coords = np.column_stack(np.where(line_mask == 1))[:, ::-1]
        start_or_end = np.asarray(
            point_at_start_or_end_assignment(rounded_coords, line_pixel_coords),
            dtype=bool,
        )
        # Pixels the helper flags as start/end stay blank, only the remaining ones get the color.
        # Assigning per pixel lets an int color fill (H, W) and an RGB tuple fill (H, W, 3).
        kept_pixels = line_pixel_coords[~start_or_end]
        temp_image[kept_pixels[:, 1], kept_pixels[:, 0]] = color
    else:
        cv2.polylines(temp_image, polyline, False, color, thickness)

    drawn = temp_image != 0
    occupied = image != 0
    if image.ndim == 3:
        # Decide per pixel instead of per channel: two RGB colors may share no non-zero
        # channel, and a channel-wise merge would blend them into a different color
        drawn = drawn.any(axis=2, keepdims=True)
        occupied = occupied.any(axis=2, keepdims=True)

    overlap = bool(np.logical_and(drawn, occupied).any())
    # The new line wins wherever it was drawn, the rest of the image is kept
    image = np.where(drawn, temp_image, image)

    return image, overlap

## REGIONS

def build_region_instances(self, page: PageData, out_size: tuple[int, int], elements, class_dict) -> list[Instance]:
Expand Down Expand Up @@ -245,10 +289,9 @@ def build_baseline_instances(self, page: PageData, out_size: tuple[int, int], li
instances = []
for baseline_coords in page.iter_baseline_coords():
coords = self._scale_coords(baseline_coords, out_size, size)
rounded_coords = np.round(coords).astype(np.int32)
mask.fill(0)
# HACK Currently the simplest and quickest solution is used; this can probably be optimized
cv2.polylines(mask, [rounded_coords.reshape(-1, 1, 2)], False, 255, line_width)
mask, _ = self.draw_line(mask, coords, 255, thickness=line_width)
contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
if len(contours) == 0:
raise ValueError(f"{page.filepath} has no contours")
Expand Down Expand Up @@ -284,18 +327,22 @@ def build_baseline_pano(self, page: PageData, out_size: tuple[int, int], line_wi
pano_mask = np.zeros((*out_size, 3), np.uint8)
segments_info = []
_id = 1
total_overlap = False
for baseline_coords in page.iter_baseline_coords():
coords = self._scale_coords(baseline_coords, out_size, size)
rounded_coords = np.round(coords).astype(np.int32)
rgb_color = self.id2rgb(_id)
cv2.polylines(pano_mask, [rounded_coords.reshape(-1, 1, 2)], False, rgb_color, line_width)
pano_mask, overlap = self.draw_line(pano_mask, coords, rgb_color, thickness=line_width)
total_overlap = total_overlap or overlap
segment: SegmentsInfo = {
"id": _id,
"category_id": baseline_class,
"iscrowd": False,
}
segments_info.append(segment)
_id += 1

if total_overlap:
self.logger.warning(f"File {page.filepath} contains overlapping baseline pano")
if not pano_mask.any():
self.logger.warning(f"File {page.filepath} does not contains baseline pano")
return pano_mask, segments_info
Expand All @@ -307,23 +354,17 @@ def build_baseline_sem_seg(self, page: PageData, out_size: tuple[int, int], line
baseline_color = 1
size = page.get_size()
sem_seg = np.zeros(out_size, np.uint8)
binary_mask = np.zeros(out_size, dtype=np.uint8)
overlap = False
total_overlap = False
for baseline_coords in page.iter_baseline_coords():
coords = self._scale_coords(baseline_coords, out_size, size)
rounded_coords = np.round(coords).astype(np.int32)
binary_mask.fill(0)
cv2.polylines(binary_mask, [rounded_coords.reshape(-1, 1, 2)], False, baseline_color, line_width)
sem_seg, overlap = self.draw_line(sem_seg, coords, baseline_color, thickness=line_width)
total_overlap = total_overlap or overlap

overlap = np.logical_or(overlap, np.any(np.logical_and(sem_seg, binary_mask)))
# Add single line to full sem_seg
sem_seg = np.logical_or(sem_seg, binary_mask)

if overlap:
if total_overlap:
self.logger.warning(f"File {page.filepath} contains overlapping baseline sem_seg")
if not sem_seg.any():
self.logger.warning(f"File {page.filepath} does not contains baseline sem_seg")
return sem_seg.astype(np.uint8)
return sem_seg

# TOP BOTTOM

Expand All @@ -337,17 +378,21 @@ def build_top_bottom_sem_seg(self, page: PageData, out_size: tuple[int, int], li
size = page.get_size()
sem_seg = np.zeros(out_size, np.uint8)
binary_mask = np.zeros(out_size, dtype=np.uint8)
total_overlap = False
for baseline_coords in page.iter_baseline_coords():
coords = self._scale_coords(baseline_coords, out_size, size)
rounded_coords = np.round(coords).astype(np.int32)
binary_mask.fill(0)
cv2.polylines(binary_mask, [rounded_coords.reshape(-1, 1, 2)], False, baseline_color, line_width)
binary_mask, overlap = self.draw_line(binary_mask, coords, baseline_color, thickness=line_width)
total_overlap = total_overlap or overlap

# Add single line to full sem_seg
line_pixel_coords = np.column_stack(np.where(binary_mask == 1))[:, ::-1]
rounded_coords = np.round(coords).astype(np.int32)
top_bottom = point_top_bottom_assignment(rounded_coords, line_pixel_coords)
colored_top_bottom = np.where(top_bottom, top_color, bottom_color)
sem_seg[line_pixel_coords[:, 1], line_pixel_coords[:, 0]] = colored_top_bottom

if total_overlap:
self.logger.warning(f"File {page.filepath} contains overlapping top bottom sem_seg")
if not sem_seg.any():
self.logger.warning(f"File {page.filepath} does not contains top bottom sem_seg")
return sem_seg
Expand Down Expand Up @@ -417,15 +462,20 @@ def build_baseline_separator_sem_seg(self, page: PageData, out_size: tuple[int,

size = page.get_size()
sem_seg = np.zeros(out_size, np.uint8)
total_overlap = False
for baseline_coords in page.iter_baseline_coords():
coords = self._scale_coords(baseline_coords, out_size, size)
rounded_coords = np.round(coords).astype(np.int32)
cv2.polylines(sem_seg, [coords.reshape(-1, 1, 2)], False, baseline_color, line_width)
sem_seg, overlap = self.draw_line(sem_seg, rounded_coords, baseline_color, thickness=line_width)
total_overlap = total_overlap or overlap

coords_start = rounded_coords[0]
cv2.circle(sem_seg, coords_start, line_width, separator_color, -1)
coords_end = rounded_coords[-1]
cv2.circle(sem_seg, coords_end, line_width, separator_color, -1)

if total_overlap:
self.logger.warning(f"File {page.filepath} contains overlapping baseline separator sem_seg")
if not sem_seg.any():
self.logger.warning(f"File {page.filepath} does not contains baseline separator sem_seg")
return sem_seg
Expand Down Expand Up @@ -634,7 +684,7 @@ def to_pano(
merge_regions=args.merge_regions,
region_type=args.region_type,
)
XMLConverter(xml_regions)
XMLConverter(xml_regions, args.square_lines)

input_path = Path(args.input)
output_path = Path(args.output)
2 changes: 1 addition & 1 deletion page_xml/xml_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_parser(cls) -> argparse.ArgumentParser:
republic_merge_regions = ["resolution:Resumption,resumption,Insertion,insertion"]
parser = argparse.ArgumentParser(add_help=False)

region_args = parser.add_argument_group("regions")
region_args = parser.add_argument_group("Regions")

region_args.add_argument(
"-m",
Expand Down
2 changes: 2 additions & 0 deletions utils/regions_from_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def main(args):
image_paths = get_file_paths(args.input, supported_image_formats)
xml_paths = [image_path_to_xml_path(image_path) for image_path in image_paths]

# xml_paths = get_file_paths(args.input, [".xml"])

# Single thread
# regions_per_page = []
# for xml_path_i in tqdm(xml_paths):
Expand Down
Loading

0 comments on commit 22a9bfd

Please sign in to comment.