Skip to content

Commit

Permalink
Working pageXML creator now with baselines
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanklut committed Feb 15, 2024
1 parent 9f336dd commit 0873534
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 275 deletions.
8 changes: 6 additions & 2 deletions page_xml/baseline_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def round_up(array: np.ndarray):
return np.floor(array + 0.5)


def baseline_converter(image: np.ndarray, minimum_height: int = 3, minimum_width: int = 15):
def baseline_converter(
image: np.ndarray,
minimum_width: int = 15,
minimum_height: int = 3,
):
output = cv2.connectedComponentsWithStats(image, connectivity=8)
num_labels = output[0]
labels = output[1]
Expand All @@ -31,10 +35,10 @@ def baseline_converter(image: np.ndarray, minimum_height: int = 3, minimum_width
if len(baseline) < 2:
continue
baseline = cv2.approxPolyDP(np.array(baseline, dtype=np.float32), 1, False).reshape(-1, 2)
baseline = round_up(baseline).astype(int)

if np.max(baseline[:, 0]) - np.min(baseline[:, 0]) < minimum_width:
continue
baselines.append(baseline)

return baselines

Expand Down
202 changes: 91 additions & 111 deletions page_xml/output_pageXML.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
sys.path.append(str(Path(__file__).resolve().parent.joinpath("..")))
from core.setup import get_git_hash
from datasets.dataset import classes_to_colors
from page_xml.pageXML_creator import PageXMLCreator
from page_xml.baseline_extractor import baseline_converter
from page_xml.pageXML_creator import Baseline, PageXMLCreator, Region, TextLine
from page_xml.xml_regions import XMLRegions
from utils.copy_utils import copy_mode
from utils.image_utils import save_image_array_to_path
Expand Down Expand Up @@ -54,7 +55,7 @@ def __init__(
whitelist: Optional[Iterable[str]] = None,
rectangle_regions: Optional[Iterable[str]] = None,
min_region_size: int = 10,
external_processing: bool = False,
external_processing: bool = True,
grayscale: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -165,18 +166,72 @@ def generate_image_from_sem_seg(self, sem_seg: torch.Tensor, old_height: int, ol

return image

def add_baselines_to_page(self, page: PageXMLCreator, sem_seg: torch.Tensor, image_path: Path, old_height, old_width):
pass
def add_baselines_to_page(
self,
pageXML_creator: PageXMLCreator,
sem_seg: torch.Tensor,
old_height: int,
old_width: int,
upscale: bool = False,
) -> PageXMLCreator:
page = pageXML_creator.pageXML.find("Page")
if page is None:
raise ValueError("Page not found in pageXML")
# image_name = page.attrib["imageFilename"]
page.append(
Region.with_tag(
"TextRegion",
np.asarray([[0, 0], [0, old_height], [old_width, old_height], [old_width, 0]]),
id=f"textregion_{uuid.uuid4()}",
)
)

height, width = sem_seg.shape[-2:]

def add_regions_to_page(self, page: PageXMLCreator, sem_seg: torch.Tensor, old_height, old_width) -> PageXMLCreator:
scaling = np.asarray([old_width, old_height] / np.asarray([width, height]))
if upscale:
sem_seg = torch.nn.functional.interpolate(
sem_seg[None], size=(old_height, old_width), mode="bilinear", align_corners=False
)[0]
scaling = np.asarray([1, 1])

sem_seg_image = torch.argmax(sem_seg, dim=-3).cpu().numpy().astype(np.uint8)

coords_baselines = baseline_converter(sem_seg_image, minimum_width=15 / scaling[0], minimum_height=3 / scaling[1])

text_region = page.find("TextRegion")
if text_region is None:
raise ValueError("TextRegion not found in pageXML")
for coords_baseline in coords_baselines:
coords_baseline = (coords_baseline * scaling).astype(np.float32)
bbox = cv2.boundingRect(coords_baseline)
coords_text_line = np.array(
[
[bbox[0], bbox[1]],
[bbox[0], bbox[1] + bbox[3]],
[bbox[0] + bbox[2], bbox[1] + bbox[3]],
[bbox[0] + bbox[2], bbox[1]],
]
)
text_line = TextLine(coords_text_line, id=f"textline_{uuid.uuid4()}")
baseline = Baseline(coords_baseline)
text_line.append(baseline)
text_region.append(text_line)

return pageXML_creator

def add_regions_to_page(
self,
pageXML_creator: PageXMLCreator,
sem_seg: torch.Tensor,
old_height: int,
old_width: int,
) -> PageXMLCreator:
if self.output_dir is None:
raise TypeError("Output dir is None")
if self.page_dir is None:
raise TypeError("Page dir is None")

if old_height is None or old_width is None:
old_height, old_width = sem_seg.shape[-2:]

height, width = sem_seg.shape[-2:]

scaling = np.asarray([old_width, old_height] / np.asarray([width, height]))
Expand All @@ -185,6 +240,10 @@ def add_regions_to_page(self, page: PageXMLCreator, sem_seg: torch.Tensor, old_h

region_id = 0

page = pageXML_creator.pageXML.find("Page")
if page is None:
raise ValueError("Page not found in pageXML")

for class_id, region in enumerate(self.xml_regions.regions):
# Skip background
if class_id == 0:
Expand Down Expand Up @@ -225,24 +284,16 @@ def add_regions_to_page(self, page: PageXMLCreator, sem_seg: torch.Tensor, old_h
region_coords = region_coords.strip()

_uuid = uuid.uuid4()
text_reg = page.add_region(region_type, f"region_{_uuid}_{region_id}", region, region_coords)

return page

def process_tensor(self, page: PageXMLCreator, sem_seg: torch.Tensor, image_path: Path, old_height, old_width):
if self.output_dir is None:
raise TypeError("Output dir is None")
if self.page_dir is None:
raise TypeError("Page dir is None")

if self.external_processing:
sem_seg_output_path = self.page_dir.joinpath(image_path.stem + ".png")
colored_image = self.generate_image_from_sem_seg(sem_seg, old_height, old_width)
with AtomicFileName(file_path=sem_seg_output_path) as path:
save_image_array_to_path(str(path), colored_image.astype(np.uint8))
page.append(
Region.with_tag(
region_type,
poly,
region,
id=f"region_{_uuid}_{region_id}",
)
)

if self.xml_regions.mode == "region":
pass
return pageXML_creator

def generate_single_page(
self,
Expand All @@ -251,108 +302,37 @@ def generate_single_page(
old_height: Optional[int] = None,
old_width: Optional[int] = None,
):
"""
Convert a single prediction into a page
Args:
sem_seg (torch.Tensor): sem_seg as tensor
image_path (Path): Image path, used for path name
old_height (Optional[int], optional): height of the original image. Defaults to None.
old_width (Optional[int], optional): width of the original image. Defaults to None.
Raises:
TypeError: Output dir has not been set
TypeError: Page dir has not been set
NotImplementedError: mode is not known
"""
if self.output_dir is None:
raise TypeError("Output dir is None")
if self.page_dir is None:
raise TypeError("Page dir is None")

xml_output_path = self.page_dir.joinpath(image_path.stem + ".xml")

if old_height is None or old_width is None:
old_height, old_width = sem_seg.shape[-2:]

height, width = sem_seg.shape[-2:]

scaling = np.asarray([old_width, old_height] / np.asarray([width, height]))

page = PageXMLCreator(xml_output_path)
page.new_page(image_path.name, str(old_height), str(old_width))
pageXML_creator = PageXMLCreator()
pageXML_creator.add_page(image_path.name, old_height, old_width)

if self.cfg is not None:
page.add_processing_step(get_git_hash(), self.cfg.LAYPA_UUID, self.cfg, self.whitelist)

if self.xml_regions.mode == "region":
sem_seg = torch.argmax(sem_seg, dim=-3).cpu().numpy()

region_id = 0

for region in self.xml_regions.regions:
if region == "background":
continue
binary_region_mask = np.zeros_like(sem_seg).astype(np.uint8)
binary_region_mask[sem_seg == self.xml_regions.regions_to_classes[region]] = 1

region_type = self.xml_regions.region_types[region]

contours, hierarchy = cv2.findContours(binary_region_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

for cnt in contours:
# remove small objects
if cnt.shape[0] < 4:
continue
if cv2.contourArea(cnt) < self.min_region_size:
continue

region_id += 1

region_coords = ""
if region in self.rectangle_regions:
# find bounding box
rect = cv2.minAreaRect(cnt)
poly = cv2.boxPoints(rect) * scaling
else:
# soft a bit the region to prevent spikes
epsilon = 0.0005 * cv2.arcLength(cnt, True)
approx_poly = cv2.approxPolyDP(cnt, epsilon, True)
pageXML_creator.add_processing_step(get_git_hash(), self.cfg.LAYPA_UUID, self.cfg, self.whitelist)

approx_poly = np.round((approx_poly * scaling)).astype(np.int32)

poly = approx_poly.reshape(-1, 2)

for coords in poly:
region_coords = region_coords + f" {round(coords[0])},{round(coords[1])}"

region_coords = region_coords.strip()

_uuid = uuid.uuid4()
text_reg = page.add_region(region_type, f"region_{_uuid}_{region_id}", region, region_coords)
elif self.xml_regions.mode in ["baseline", "start", "end", "separator"]:
# Push the calculation to outside of the python code <- mask is used by minion
sem_seg_output_path = self.page_dir.joinpath(image_path.stem + ".png")
sem_seg = torch.nn.functional.interpolate(
sem_seg[None], size=(old_height, old_width), mode="bilinear", align_corners=False
)[0]
sem_seg_image = torch.argmax(sem_seg, dim=-3).cpu().numpy()
with AtomicFileName(file_path=sem_seg_output_path) as path:
save_image_array_to_path(str(path), (sem_seg_image * 255).astype(np.uint8))
elif self.xml_regions.mode in ["baseline_separator", "top_bottom"]:
if self.external_processing:
sem_seg_output_path = self.page_dir.joinpath(image_path.stem + ".png")
sem_seg = torch.nn.functional.interpolate(
sem_seg[None], size=(old_height, old_width), mode="bilinear", align_corners=False
)[0]
sem_seg_image = torch.argmax(sem_seg, dim=-3).cpu().numpy()
colored_image = self.generate_image_from_sem_seg(sem_seg, old_height, old_width)
with AtomicFileName(file_path=sem_seg_output_path) as path:
save_image_array_to_path(str(path), (sem_seg_image * 128).clip(0, 255).astype(np.uint8))
save_image_array_to_path(str(path), colored_image.astype(np.uint8))
pageXML_creator.pageXML.save_xml(self.page_dir.joinpath(image_path.stem + ".xml"))
return

if self.xml_regions.mode == "region":
pageXML_creator = self.add_regions_to_page(pageXML_creator, sem_seg, old_height, old_width)
pageXML_creator.pageXML.save_xml(self.page_dir.joinpath(image_path.stem + ".xml"))
elif self.xml_regions.mode == "baseline":
pageXML_creator = self.add_baselines_to_page(pageXML_creator, sem_seg, old_height, old_width)
pageXML_creator.pageXML.save_xml(self.page_dir.joinpath(image_path.stem + ".xml"))
else:
raise NotImplementedError(f"Mode {self.xml_regions.mode} not implemented")

# TODO Overwrite when multiple image have the same name but different extension
page.save_xml()

def generate_single_page_wrapper(self, info):
"""
Convert a single prediction into a page
Expand Down
Loading

0 comments on commit 0873534

Please sign in to comment.