diff --git a/pagexml/analysis/layout_stats.py b/pagexml/analysis/layout_stats.py index 7046c00..2c12980 100644 --- a/pagexml/analysis/layout_stats.py +++ b/pagexml/analysis/layout_stats.py @@ -95,6 +95,17 @@ def interpolate_baseline_points(points: List[Tuple[int, int]], return interpolated_baseline_points +def compute_points_distances(points1: List[Tuple[int, int]], points2: List[Tuple[int, int]], + step: int = 50): + if points1 is None or points2 is None: + return np.array([]) + b1_points = interpolate_baseline_points(points1, step=step) + b2_points = interpolate_baseline_points(points2, step=step) + distances = np.array([abs(b2_points[curr_x] - b1_points[curr_x]) for curr_x in b1_points + if curr_x in b2_points]) + return distances + + def compute_baseline_distances(line1: Union[pdm.PageXMLTextLine, List[pdm.PageXMLTextLine]], line2: Union[pdm.PageXMLTextLine, List[pdm.PageXMLTextLine]], step: int = 50) -> np.ndarray: @@ -124,12 +135,7 @@ def compute_baseline_distances(line1: Union[pdm.PageXMLTextLine, List[pdm.PageXM points2 = line2.baseline.points if line2.baseline.points is not None else [] else: points2 = [point for line in line2 for point in line.baseline.points if line.baseline.points is not None] - if points1 is None or points2 is None: - return np.array([]) - b1_points = interpolate_baseline_points(points1, step=step) - b2_points = interpolate_baseline_points(points2, step=step) - distances = np.array([abs(b2_points[curr_x] - b1_points[curr_x]) for curr_x in b1_points - if curr_x in b2_points]) + distances = compute_points_distances(points1, points2, step=step) if len(distances) == 0: avg1 = average_baseline_height(line1) avg2 = average_baseline_height(line2) @@ -137,6 +143,21 @@ def compute_baseline_distances(line1: Union[pdm.PageXMLTextLine, List[pdm.PageXM return distances +def get_bottom_points(line: pdm.PageXMLTextLine) -> List[Tuple[int, int]]: + right_most = [p for p in line.coords.points if p[0] == line.coords.right][0] + right_most_index = line.coords.points.index(right_most) + return line.coords.points[right_most_index:] + + +def compute_bounding_box_distances(line1: Union[pdm.PageXMLTextLine, List[pdm.PageXMLTextLine]], + line2: Union[pdm.PageXMLTextLine, List[pdm.PageXMLTextLine]], + step: int = 50): + points1 = get_bottom_points(line1) + points2 = get_bottom_points(line2) + distances = compute_points_distances(points1, points2, step=step) + return distances + + def average_baseline_height(line: Union[pdm.PageXMLTextLine, List[pdm.PageXMLTextLine]]) -> int: """Compute the average (mean) baseline height for comparing lines that are not horizontally aligned. @@ -179,8 +200,9 @@ def average_baseline_height(line: Union[pdm.PageXMLTextLine, List[pdm.PageXMLTex return int(total_avg) -def sort_coords_above_below_baseline(line: pdm.PageXMLTextLine, debug: int = 0) -> Tuple[List[Tuple[int, int]], - List[Tuple[int, int]]]: +def sort_coords_above_below_baseline(line: pdm.PageXMLTextLine, + debug: int = 0) -> Tuple[List[Tuple[int, int]], + List[Tuple[int, int]]]: """Split the list of bounding polygon coordinates of a line in sets of points above and below the baseline. When a line has no baseline or no bounding polygon, empty lists are returned @@ -195,10 +217,17 @@ def sort_coords_above_below_baseline(line: pdm.PageXMLTextLine, debug: int = 0) ci_c = 0 below_baseline = [] above_baseline = [] - if line.baseline is None: + if line.baseline is None or line.coords is None: + return above_baseline, below_baseline + if not line.baseline or not line.coords: + return above_baseline, below_baseline + if line.coords.right < line.baseline.left: + return above_baseline, below_baseline + if line.coords.left > line.baseline.right: return above_baseline, below_baseline interpolated_baseline_points = [i for i in interpolate_baseline_points(line.baseline.points, step=50).items()] if debug > 2: + print('baseline_points:', line.baseline.points) print('interpolated_baseline_points:', interpolated_baseline_points) sorted_coord_points = sorted(line.coords.points, key=lambda p: p[0]) if debug > 0: @@ -206,7 +235,7 @@ def sort_coords_above_below_baseline(line: pdm.PageXMLTextLine, debug: int = 0) print('len(sorted_coord_points):', len(sorted_coord_points)) if debug > 1: print('ci_c:', ci_c) - num_baseline_points = len(line.baseline.points) + num_baseline_points = len(interpolated_baseline_points) num_coord_points = len(sorted_coord_points) for ci_b, curr_b in enumerate(interpolated_baseline_points): curr_bx, curr_by = curr_b @@ -223,24 +252,28 @@ def sort_coords_above_below_baseline(line: pdm.PageXMLTextLine, debug: int = 0) if debug > 0: print(f'sort_above_below - curr_c ({ci_c}): {curr_c}') ci_c += 1 - if curr_cy > curr_by: - if debug > 0: - print(f'sort_above_below - below') - below_baseline.append(curr_c) - elif curr_cy < curr_by: + if curr_cy < curr_by: if debug > 0: print(f'sort_above_below - above') above_baseline.append(curr_c) else: if debug > 0: - print(f'sort_above_below - neither') - pass + print(f'sort_above_below - below') + below_baseline.append(curr_c) return above_baseline, below_baseline -def get_text_heights(line: pdm.PageXMLTextLine, step: int = 50) -> np.array: - above_baseline, below_baseline = sort_coords_above_below_baseline(line) +def get_text_heights(line: pdm.PageXMLTextLine, step: int = 50, + ignore_errors: bool = True, debug: int = 0) -> np.array: + above_baseline, below_baseline = sort_coords_above_below_baseline(line, debug=debug) + if len(above_baseline) == 0: + if ignore_errors is False: + ValueError(f'line {line.id} has no bounding coordinates above baseline') + return None + if len(below_baseline) == 0: + if ignore_errors is False: + ValueError(f'Warning: line {line.id} has no bounding coordinates below baseline') int_base = interpolate_baseline_points(line.baseline.points, step=step) int_above = interpolate_baseline_points(above_baseline, step=step) @@ -249,10 +282,13 @@ def get_text_heights(line: pdm.PageXMLTextLine, step: int = 50) -> np.array: if x in int_above: height[x] = int_base[x] - int_above[x] + if len(height) == 0: + print() + return None return np.array(list(height.values())) -def get_height_stats(line_heights: np.array) -> Dict[str, int]: +def compute_height_stats(line_heights: np.array) -> Dict[str, int]: return { 'max': line_heights.max(), 'min': line_heights.min(), @@ -261,6 +297,24 @@ def get_height_stats(line_heights: np.array) -> Dict[str, int]: } +def get_line_height_stats(line: pdm.PageXMLTextLine, step: int = 50, + ignore_errors: bool = False, debug: int = 0) -> Union[Dict[str, int], None]: + try: + line_heights = get_text_heights(line, step=step, ignore_errors=ignore_errors, debug=debug) + if debug > 0: + print('get_line_height_stats - line_heights:', line_heights) + if line_heights is None: + return None + return compute_height_stats(line_heights) + except IndexError: + print('ERROR INFO:') + print('get_line_height_stats - line.baseline:', line.baseline) + print('get_line_height_stats - line.coords:', line.coords) + raise + except AttributeError: + return None + + def get_line_distances(lines: List[pdm.PageXMLTextLine]) -> List[np.ndarray]: all_distances = [] for li, curr_line in enumerate(lines): @@ -268,7 +322,10 @@ def get_line_distances(lines: List[pdm.PageXMLTextLine]) -> List[np.ndarray]: if li + 1 < len(lines): next_line = lines[li + 1] if next_line: - distances = compute_baseline_distances(curr_line, next_line) + if curr_line.baseline and next_line.baseline: + distances = compute_baseline_distances(curr_line, next_line) + else: + distances = compute_bounding_box_distances(curr_line, next_line) all_distances.append(distances) return all_distances diff --git a/pagexml/analysis/text_stats.py b/pagexml/analysis/text_stats.py index 07b8b44..f3b08f1 100644 --- a/pagexml/analysis/text_stats.py +++ b/pagexml/analysis/text_stats.py @@ -548,7 +548,6 @@ def _set_merged_with(self, lines: Iterable[Union[str, Dict[str, str]]], min_common_freq: int = 1000) -> None: prev_words = [] typical_start_words, typical_end_words = get_typical_start_end_words(self) - li = 0 for li, line in enumerate(lines): if line["text"] is None: continue diff --git a/pagexml/column_parser.py b/pagexml/column_parser.py index cca1dc4..63ea874 100644 --- a/pagexml/column_parser.py +++ b/pagexml/column_parser.py @@ -187,14 +187,16 @@ def make_derived_column(lines: List[pdm.PageXMLTextLine], metadata: dict, page_i def merge_columns(columns: List[pdm.PageXMLColumn], doc_id: str, metadata: dict) -> pdm.PageXMLColumn: - """Merge two columns into one, sorting lines by baseline height.""" - merged_lines = [line for col in columns for line in col.get_lines()] - merged_lines = list(set(merged_lines)) - sorted_lines = sorted(merged_lines, key=lambda x: x.baseline.y) - merged_coords = pdm.parse_derived_coords(sorted_lines) + """Merge a list of columns into one. First, all text regions of all columns are + checked for spatial overlap, whereby overlapping text regions are merged. + Within the merged text regions, lines are sorted by baseline height.""" + trs = [tr for col in columns for tr in col.text_regions] + merged_tr = pagexml_helper.merge_textregions(trs, metadata) + merged_coords = copy.deepcopy(merged_tr.coords) merged_col = pdm.PageXMLColumn(doc_id=doc_id, doc_type='index_column', metadata=metadata, coords=merged_coords, - lines=merged_lines) + text_regions=[merged_tr]) + merged_col.set_as_parent([merged_tr]) return merged_col diff --git a/pagexml/helper/pagexml_helper.py b/pagexml/helper/pagexml_helper.py index ad3967a..1bbeee7 100644 --- a/pagexml/helper/pagexml_helper.py +++ b/pagexml/helper/pagexml_helper.py @@ -24,6 +24,8 @@ def elements_overlap(element1: pdm.PageXMLDoc, element2: pdm.PageXMLDoc, if v_overlap / element2.coords.height > threshold: if h_overlap / element2.coords.width > threshold: return True + else: + return False else: return False @@ -81,6 +83,43 @@ def horizontal_group_lines(lines: List[pdm.PageXMLTextLine]) -> List[List[pdm.Pa return horizontally_grouped_lines +def merge_sets(sets: List[Set[any]], min_overlap: int = 1) -> List[Set[any]]: + merged_sets = [] + + while len(sets) > 0: + current_set = sets.pop(0) + merged_set = set(current_set) + + i = 0 + while i < len(sets): + if len(merged_set.intersection(sets[i])) >= min_overlap: + merged_set.update(sets[i]) + sets.pop(i) + else: + i += 1 + + merged_sets.append(merged_set) + + return merged_sets + + +def merge_textregions(text_regions: List[pdm.PageXMLTextRegion], + metadata: dict = None, doc_id: str = None) -> Union[pdm.PageXMLTextRegion, None]: + """Merge two text_regions into one, sorting lines by baseline height.""" + if len(text_regions) == 0: + return None + merged_lines = [line for tr in text_regions for line in tr.get_lines()] + merged_lines = list(set(merged_lines)) + sorted_lines = sorted(merged_lines, key=lambda x: x.baseline.y) + merged_coords = pdm.parse_derived_coords(sorted_lines) + merged_tr = pdm.PageXMLTextRegion(doc_id=doc_id, doc_type='index_text_region', + metadata=metadata, coords=merged_coords, + lines=sorted_lines) + if doc_id is None: + merged_tr.set_derived_id(text_regions[0].parent.id) + return merged_tr + + def horizontally_merge_lines(lines: List[pdm.PageXMLTextLine]) -> List[pdm.PageXMLTextLine]: """Sort lines vertically and merge horizontally adjacent lines.""" horizontally_grouped_lines = horizontal_group_lines(lines) @@ -368,7 +407,8 @@ def __iter__(self): def make_line_text(line: pdm.PageXMLTextLine, do_merge: bool, - end_word: str, merge_word: str, word_break_chars: Union[str, Set[str]] = '-') -> str: + end_word: str, merge_word: str, + word_break_chars: Union[str, Set[str], List[str]] = '-') -> str: line_text = line.text if len(line_text) >= 2 and line_text[-1] in word_break_chars and line_text[-2] in word_break_chars: # remove the redundant line break char @@ -402,7 +442,7 @@ def make_line_range(text: str, line: pdm.PageXMLTextLine, line_text: str) -> Dic def make_text_region_text(lines: List[pdm.PageXMLTextLine], - word_break_chars: List[str] = '-', + word_break_chars: Union[str, Set[str], List[str]] = '-', wbd: text_stats.WordBreakDetector = None) -> Tuple[Union[str, None], List[Dict[str, any]]]: """Turn the text lines in a region into a single paragraph of text, with a list of line ranges that indicates how the text of each line corresponds to character offsets in the paragraph. @@ -428,6 +468,7 @@ def make_text_region_text(lines: List[pdm.PageXMLTextLine], prev_words = text_helper.get_line_words(prev_line.text, word_break_chars=word_break_chars) \ if prev_line.text else [] if len(lines) > 1: + remove_prefix_word_break = False for curr_line in lines[1:]: if curr_line.text is None or curr_line.text == '': do_merge = False @@ -440,10 +481,17 @@ def make_text_region_text(lines: List[pdm.PageXMLTextLine], if prev_line.text is not None: do_merge, merge_word = text_stats.determine_word_break(curr_words, prev_words, wbd=wbd, - word_break_chars=word_break_chars) + word_break_chars=word_break_chars, + debug=False) # print(do_merge, merge_word) prev_line_text = make_line_text(prev_line, do_merge, prev_words[-1], merge_word, word_break_chars=word_break_chars) + if remove_prefix_word_break and prev_line_text.startswith('„'): + prev_line_text = prev_line_text[1:] + if '„' in word_break_chars and prev_words[-1].endswith('„') and curr_line.text.startswith('„'): + remove_prefix_word_break = True + else: + remove_prefix_word_break = False # print(prev_line_text) else: prev_line_text = '' diff --git a/pagexml/helper/text_helper.py b/pagexml/helper/text_helper.py index 2f6a070..043b806 100644 --- a/pagexml/helper/text_helper.py +++ b/pagexml/helper/text_helper.py @@ -8,11 +8,14 @@ import pagexml.parser as parser -def read_lines_from_line_files(pagexml_line_files: Union[str, List[str]]) -> Generator[str, None, None]: +def read_lines_from_line_files(pagexml_line_files: Union[str, List[str]], + has_headers: bool = True) -> Generator[str, None, None]: if isinstance(pagexml_line_files, str): pagexml_line_files = [pagexml_line_files] - for line_file in pagexml_line_files: + for li, line_file in enumerate(pagexml_line_files): with gzip.open(line_file, 'rt') as fh: + if has_headers is True and li > 0: + _headers = next(fh) for line in fh: yield line @@ -108,11 +111,11 @@ def __init__(self, pagexml_files: Union[str, List[str]] = None, raise TypeError(f"MUST use one of the following optional arguments: " f"'pagexml_files', 'pagexml_docs' or 'pagexml_line_file'.") if pagexml_line_files: - self.pagexml_line_files = make_list(pagexml_line_files) + self.pagexml_line_files = sorted(make_list(pagexml_line_files)) if pagexml_files: - self.pagexml_files = make_list(pagexml_files) + self.pagexml_files = sorted(make_list(pagexml_files)) if pagexml_docs: - self.pagexml_docs = make_list(pagexml_docs) + self.pagexml_docs = sorted(make_list(pagexml_docs)) def __iter__(self) -> Generator[Dict[str, str], None, None]: if self.groupby is None: @@ -149,7 +152,7 @@ def _iter_from_pagexml_docs(self, pagexml_doc_iterator) -> Generator[Dict[str, a yield line def _iter_from_line_file(self) -> Generator[Dict[str, any], None, None]: - line_iterator = read_lines_from_line_files(self.pagexml_line_files) + line_iterator = read_lines_from_line_files(self.pagexml_line_files, has_headers=self.has_headers) if self.has_headers is True: header_line = next(line_iterator) self.line_file_headers = header_line.strip().split('\t') @@ -190,7 +193,12 @@ def read_pagexml_docs_from_line_file(line_files: Union[str, List[str]], has_head # print(line_dict) doc_coords, tr_coords, line_coords = None, None, None if add_bounding_box is True: - doc_coords = transform_box_to_coords(line_dict['doc_box']) + try: + doc_coords = transform_box_to_coords(line_dict['doc_box']) + except ValueError: + print(line_dict['doc_box']) + print(line_dict) + raise tr_coords = transform_box_to_coords(line_dict['textregion_box']) # print('\t', tr_coords, line_dict['textregion_box']) if line_dict['line_box'] is None: diff --git a/pagexml/model/physical_document_model.py b/pagexml/model/physical_document_model.py index 27d8c6f..8211872 100644 --- a/pagexml/model/physical_document_model.py +++ b/pagexml/model/physical_document_model.py @@ -443,7 +443,7 @@ def json(self) -> Dict[str, any]: def set_derived_id(self, parent_id: str): box_string = f"{self.coords.x}-{self.coords.y}-{self.coords.w}-{self.coords.h}" self.id = f"{parent_id}-{self.main_type}-{box_string}" - self.metadata['id'] = self.id + # self.metadata['id'] = self.id class LogicalStructureDoc(StructureDoc): @@ -664,6 +664,7 @@ def add_child(self, child: PageXMLDoc): self.text_regions.append(child) else: raise TypeError(f'unknown child type: {child.__class__.__name__}') + self.coords = parse_derived_coords(self.text_regions + self.lines) @property def json(self) -> Dict[str, any]: @@ -726,7 +727,7 @@ def get_inner_text_regions(self) -> List[PageXMLTextRegion]: def get_lines(self) -> List[PageXMLTextLine]: lines: List[PageXMLTextLine] = [] if self.text_regions: - if self.reading_order: + if self.reading_order and all([tr.id in self.reading_order for tr in self.text_regions]): for tr in sorted(self.text_regions, key=lambda t: self.reading_order_number[t.id]): lines += tr.get_lines() else: @@ -741,7 +742,7 @@ def get_words(self) -> Union[List[str], List[PageXMLWord]]: if self.text is not None: return self.text.split(' ') if self.lines: - for line in self.get_lines(): + for line in self.lines: if line.words: words += line.words elif line.text: @@ -820,30 +821,61 @@ def get_lines(self): for column in sorted(self.columns): lines += column.get_lines() # Second, add lines from text_regions + if self.extra: for tr in self.extra: lines += tr.get_lines() - elif self.text_regions: - if self.reading_order: + if self.text_regions: + # print('get_lines - reading_order_number:', self.reading_order_number) + # print('get_lines - reading_order:', self.reading_order) + if self.reading_order and all([tr.id in self.reading_order for tr in self.text_regions]): for tr in sorted(self.text_regions, key=lambda t: self.reading_order_number[t]): lines += tr.get_lines() else: for tr in sorted(self.text_regions): lines += tr.get_lines() + if self.lines: + raise AttributeError(f'page {self.id} has lines as direct property') return lines def add_child(self, child: PageXMLDoc, as_extra: bool = False): + # print('as_extra:', as_extra) + # print('stats before adding:', self.stats) child.set_parent(self) - if isinstance(child, PageXMLColumn): + if as_extra and (isinstance(child, PageXMLColumn) or isinstance(child, PageXMLTextRegion)): + self.extra.append(child) + elif isinstance(child, PageXMLColumn) or child.__class__.__name__ == 'PageXMLColumn': self.columns.append(child) elif isinstance(child, PageXMLTextLine): self.lines.append(child) elif isinstance(child, PageXMLTextRegion): - if as_extra: - self.extra.append(child) - else: - self.text_regions.append(child) + self.text_regions.append(child) else: raise TypeError(f'unknown child type: {child.__class__.__name__}') + self.coords = parse_derived_coords(self.extra + self.columns + self.text_regions + self.lines) + # print('stats after adding:', self.stats) + + def get_all_text_regions(self): + text_regions = [tr for col in self.columns for tr in col.text_regions] + text_regions.extend([tr for tr in self.extra]) + return text_regions + + def get_text_regions_in_reading_order(self, include_extra: bool = True): + text_regions = [] + if len(self.text_regions) > 0: + text_regions.extend(self.text_regions) + if hasattr(self, 'columns'): + for col in sorted(self.columns): + text_regions.extend(col.get_text_regions_in_reading_order()) + if include_extra and hasattr(self, 'extra'): + text_regions.extend(sorted(self.extra)) + return text_regions + + def get_inner_text_regions(self) -> List[PageXMLTextRegion]: + text_regions = self.get_all_text_regions() + inner_trs = [] + for tr in text_regions: + inner_trs.extend(tr.get_inner_text_regions()) + return inner_trs @property def json(self) -> Dict[str, any]: @@ -871,8 +903,9 @@ def stats(self): } if self.columns: stats['columns'] = len(self.columns) + if self.extra: stats['extra'] = len(self.extra) - elif self.text_regions: + if self.text_regions: stats['text_regions'] = len(self.text_regions) return stats diff --git a/pagexml/parser.py b/pagexml/parser.py index d0c83fa..7433608 100644 --- a/pagexml/parser.py +++ b/pagexml/parser.py @@ -1,5 +1,6 @@ import glob import json +import os import re from xml.parsers import expat from datetime import datetime @@ -381,7 +382,9 @@ def parse_pagexml_files_from_directory(pagexml_directories: List[str], if isinstance(pagexml_directories, str): pagexml_directories = [pagexml_directories] for pagexml_directory in pagexml_directories: - dir_files = glob.glob(pagexml_directory, recursive=True) + # print('dir:', pagexml_directory) + dir_files = glob.glob(os.path.join(pagexml_directory, '**/*.xml'), recursive=True) + # print('num files:', len(dir_files)) pagexml_files = [fname for fname in dir_files if fname.endswith('.xml')] if show_progress is True: for pagexml_file in tqdm(pagexml_files, desc=f'Parsing files from directory {pagexml_directory}'): @@ -392,6 +395,7 @@ def parse_pagexml_files_from_directory(pagexml_directories: List[str], def parse_pagexml_files_from_archive(archive_file: str, ignore_errors: bool = False, + silent_mode: bool = False, encoding: str = 'utf-8') -> Generator[pdm.PageXMLScan, None, None]: """Parse a list of PageXML files from an archive (e.g. zip, tar) and return each PageXML file as a PageXMLScan object. @@ -400,6 +404,8 @@ def parse_pagexml_files_from_archive(archive_file: str, ignore_errors: bool = Fa :type archive_file: str :param ignore_errors: whether to ignore errors when parsing individual PageXML files :type ignore_errors: bool + :param ignore_errors: whether to ignore errors warnings when parsing individual PageXML files + :type ignore_errors: bool :param encoding: the encoding of the file (default utf-8) :type encoding: str :return: a PageXMLScan object @@ -407,14 +413,17 @@ def parse_pagexml_files_from_archive(archive_file: str, ignore_errors: bool = Fa """ for pagefile_info, pagefile_data in read_page_archive_file(archive_file): try: - yield parse_pagexml_file(pagefile_info['archived_filename'], pagexml_data=pagefile_data, - encoding=encoding) + scan = parse_pagexml_file(pagefile_info['archived_filename'], pagexml_data=pagefile_data, + encoding=encoding) + scan.metadata['pagefile_info'] = pagefile_info + yield scan except (KeyError, AttributeError, IndexError, ValueError, TypeError, FileNotFoundError, expat.ExpatError) as err: if ignore_errors is True: - print(f"Skipping file with parser error: {pagefile_info['archived_filename']}") - print(err) + if silent_mode is False: + print(f"Skipping file with parser error: {pagefile_info['archived_filename']}") + print(err) continue else: raise diff --git a/pyproject.toml b/pyproject.toml index d7186d6..bf8b92e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ packages = [{ include = "pagexml" }] [tool.poetry.dependencies] python = "^3.8,<3.12" -fuzzy-search = "^1.5.0" +fuzzy-search = "^2.0.0a" matplotlib = "^3.7.0" numpy = "^1.22.3" pandas = "^1.5.3" diff --git a/tests/analysis-layout_stats_test.py b/tests/analysis-layout_stats_test.py index 2222a68..e4d209b 100644 --- a/tests/analysis-layout_stats_test.py +++ b/tests/analysis-layout_stats_test.py @@ -1,9 +1,49 @@ import unittest +import pagexml.analysis.layout_stats as layout_stats +import pagexml.model.physical_document_model as pdm +import pagexml.parser as parser + class TestLayoutStats(unittest.TestCase): - def test_something(self): - self.assertEqual(True, 1 == 1) + + def setUp(self) -> None: + self.page_file = 'data/example.xml' + self.page_doc = parser.parse_pagexml_file(self.page_file) + self.tr = self.page_doc.text_regions[1] + + def test_sort_above_below(self): + line = self.tr.lines[0] + above, below = layout_stats.sort_coords_above_below_baseline(line) + self.assertEqual(len(line.coords.points), len(above) + len(below)) + + def test_sort_above_below_has_no_above(self): + line1 = self.tr.lines[0] + line2 = self.tr.lines[1] + line1.coords.points = line2.coords.points + above, below = layout_stats.sort_coords_above_below_baseline(line1) + self.assertEqual(0, len(above)) + + def test_sort_above_below_has_no_below(self): + line1 = self.tr.lines[0] + line2 = self.tr.lines[1] + line2.coords.points = line1.coords.points + above, below = layout_stats.sort_coords_above_below_baseline(line2) + self.assertEqual(0, len(below)) + + def test_sort_above_below_returns_empty_when_coords_left_of_baseline(self): + line1 = self.tr.lines[0] + line1.coords = pdm.Coords([(100, 100), (150, 100), (150, 200), (100, 200)]) + line1.baseline = pdm.Coords([(300, 150), (400, 150)]) + above, below = layout_stats.sort_coords_above_below_baseline(line1) + self.assertEqual(0, len(below) + len(above)) + + def test_sort_above_below_returns_empty_when_coords_right_of_baseline(self): + line1 = self.tr.lines[0] + line1.coords = pdm.Coords([(1000, 100), (1500, 100), (1500, 200), (1000, 200)]) + line1.baseline = pdm.Coords([(300, 150), (400, 150)]) + above, below = layout_stats.sort_coords_above_below_baseline(line1) + self.assertEqual(0, len(below) + len(above)) if __name__ == '__main__':