diff --git a/docs/guide/cli.md b/docs/extras/guide/cli.md similarity index 100% rename from docs/guide/cli.md rename to docs/extras/guide/cli.md diff --git a/docs/guide/custom_v1.md b/docs/extras/guide/custom_v1.md similarity index 64% rename from docs/guide/custom_v1.md rename to docs/extras/guide/custom_v1.md index cdf7194a..5d276e97 100644 --- a/docs/guide/custom_v1.md +++ b/docs/extras/guide/custom_v1.md @@ -85,6 +85,53 @@ print(str(result.document.inference.prediction.fields["my-field"])) print(str(result.document.inference.prediction.classifications["my-classification"])) ``` + +# 🧪 Custom Line Items + +> **⚠️ Warning**: Custom Line Items are an **experimental** feature, results may vary. + + +Though not supported directly in the API, sometimes you might need to reconstitute line items by hand. +The library provides a tool for this very purpose: + +## columns_to_line_items() +The **columns_to_line_items()** function can be called from the document and page level prediction objects. + +It takes the following arguments: + +* **anchor_names** (`List[str]`): a list of the names of possible anchor (field) candidates for the horizontal placement of a line. If all provided anchors are invalid, the `CustomLineV1` won't be built. +* **field_names** (`List[str]`): a list of fields to retrieve the values from +* **height_tolerance** (`float`): Optional, the height tolerance used to build the line. It helps when the height of a line can vary unexpectedly. + +Example use: + +```python +# document-level +response.document.inference.prediction.columns_to_line_items( + anchor_names, + field_names, + 0.011 # optional, defaults to 0.01 +) + +# page-level +response.document.pages[0].prediction.columns_to_line_items( + anchor_names, + field_names, + 0.011 # optional, defaults to 0.01 +) +``` + +It returns a list of [CustomLineV1](#customlinev1) objects. + +## CustomLineV1 + +`CustomLineV1` represents a line as it has been read from column fields. 
It has the following attributes: + +* **row_number** (`int`): Number of a given line. Starts at 1. +* **fields** (`Dict[str, ListFieldValueV1]`): Dictionary of the fields associated with the line, indexed by their column name. +* **bbox** (`BBox`): Simple bounding box of the current line representing the 4 minimum & maximum coordinates as `float` values. + + # Questions? [Join our Slack](https://join.slack.com/t/mindee-community/shared_invite/zt-1jv6nawjq-FDgFcF2T5CmMmRpl9LLptw) diff --git a/examples/custom_line_items_reconstruction.py b/examples/custom_line_items_reconstruction.py new file mode 100644 index 00000000..c43d702d --- /dev/null +++ b/examples/custom_line_items_reconstruction.py @@ -0,0 +1,55 @@ +import os + +from mindee import Client, product +from mindee.parsing.common.predict_response import PredictResponse + +CUSTOM_ENDPOINT_NAME = os.getenv("CUSTOM_ENDPOINT_NAME", "my-endpoint-name") +CUSTOM_ACCOUNT_NAME = os.getenv("CUSTOM_ACCOUNT_NAME", "my-account-name") +CUSTOM_VERSION = os.getenv("CUSTOM_VERSION", "1") +CUSTOM_DOCUMENT_PATH = os.getenv("CUSTOM_DOCUMENT_PATH", "path/to/your/file.ext") + +# This example assumes you are following the associated tutorial: +# https://developers.mindee.com/docs/extracting-line-items-tutorial#line-reconstruction-code +anchors = ["category"] +columns = ["category", "previous_year_actual", "year_actual", "year_projection"] + + +def get_field_content(line, field) -> str: + if field in line.fields: + return str(line.fields[field].content) + return "" + + +def print_line(line) -> None: + category = get_field_content(line, "category") + previous_year_actual = get_field_content(line, "previous_year_actual") + year_actual = get_field_content(line, "year_actual") + year_projection = get_field_content(line, "year_projection") + # here ljust() fills the rest of the given size with spaces + string_line = ( + category.ljust(20, " ") + + previous_year_actual.ljust(10, " ") + + year_projection.ljust(10, " ") + + year_actual + ) + + 
print(string_line) + + +client = Client() + +custom_endpoint = client.create_endpoint( + CUSTOM_ENDPOINT_NAME, CUSTOM_ACCOUNT_NAME, CUSTOM_VERSION +) +input_doc = client.source_from_path(CUSTOM_DOCUMENT_PATH) + + +response: PredictResponse[product.CustomV1] = client.parse( + product.CustomV1, input_doc, endpoint=custom_endpoint +) +line_items = response.document.inference.prediction.columns_to_line_items( + anchors, columns +) + +for line in line_items: + print_line(line) diff --git a/mindee/geometry/bbox.py b/mindee/geometry/bbox.py index ac601306..7f3d4153 100644 --- a/mindee/geometry/bbox.py +++ b/mindee/geometry/bbox.py @@ -43,3 +43,31 @@ def get_bbox(points: Points) -> BBox: x_min = min(v[0] for v in points) x_max = max(v[0] for v in points) return BBox(x_min, y_min, x_max, y_max) + + +def merge_bbox(bbox_1: BBox, bbox_2: BBox) -> BBox: + """Merges two BBox.""" + return BBox( + min(bbox_1.x_min, bbox_2.x_min), + min(bbox_1.y_min, bbox_2.y_min), + max(bbox_1.x_max, bbox_2.x_max), + max(bbox_1.y_max, bbox_2.y_max), + ) + + +def extend_bbox(bbox: BBox, points: Points) -> BBox: + """ + Given a BBox and a sequence of points, calculate the surrounding bbox that encompasses all. + + :param bbox: initial BBox to extend. + :param points: Sequence of points to process. 
Accepts polygons and similar + """ + all_points = [] + for point in points: + all_points.append(point) + + y_min = min(v[1] for v in all_points) + y_max = max(v[1] for v in all_points) + x_min = min(v[0] for v in all_points) + x_max = max(v[0] for v in all_points) + return merge_bbox(bbox, BBox(x_min, y_min, x_max, y_max)) diff --git a/mindee/geometry/quadrilateral.py b/mindee/geometry/quadrilateral.py index 3d30e258..1744b47c 100644 --- a/mindee/geometry/quadrilateral.py +++ b/mindee/geometry/quadrilateral.py @@ -24,22 +24,6 @@ def centroid(self) -> Point: return get_centroid(self) -def get_bounding_box(points: Points) -> Quadrilateral: - """ - Given a sequence of points, calculate a bounding box that encompasses all points. - - :param points: Polygon to process. - :return: A bounding box that encompasses all points. - """ - x_min, y_min, x_max, y_max = get_bbox(points) - return Quadrilateral( - Point(x_min, y_min), - Point(x_max, y_min), - Point(x_max, y_max), - Point(x_min, y_max), - ) - - def quadrilateral_from_prediction(prediction: Sequence[list]) -> Quadrilateral: """ Transform a prediction into a Quadrilateral. @@ -54,3 +38,19 @@ def quadrilateral_from_prediction(prediction: Sequence[list]) -> Quadrilateral: Point(prediction[2][0], prediction[2][1]), Point(prediction[3][0], prediction[3][1]), ) + + +def get_bounding_box(points: Points) -> Quadrilateral: + """ + Given a sequence of points, calculate a bounding box that encompasses all points. + + :param points: Polygon to process. + :return: A bounding box that encompasses all points. 
+ """ + x_min, y_min, x_max, y_max = get_bbox(points) + return Quadrilateral( + Point(x_min, y_min), + Point(x_max, y_min), + Point(x_max, y_max), + Point(x_min, y_max), + ) diff --git a/mindee/parsing/custom/__init__.py b/mindee/parsing/custom/__init__.py index 3753c379..ccb7d8ac 100644 --- a/mindee/parsing/custom/__init__.py +++ b/mindee/parsing/custom/__init__.py @@ -1,3 +1,3 @@ from mindee.parsing.custom.classification import ClassificationFieldV1 -from mindee.parsing.custom.line_items import LineV1, get_line_items +from mindee.parsing.custom.line_items import CustomLineV1, get_line_items from mindee.parsing.custom.list import ListFieldV1, ListFieldValueV1 diff --git a/mindee/parsing/custom/line_items.py b/mindee/parsing/custom/line_items.py index 473078b4..06cedddd 100644 --- a/mindee/parsing/custom/line_items.py +++ b/mindee/parsing/custom/line_items.py @@ -1,27 +1,12 @@ from typing import Dict, List, Sequence -from mindee.geometry import ( - Quadrilateral, - get_bounding_box, - get_min_max_y, - is_point_in_y, - merge_polygons, -) +from mindee.error.mindee_error import MindeeError +from mindee.geometry.bbox import BBox, extend_bbox, get_bbox +from mindee.geometry.minmax import MinMax, get_min_max_y +from mindee.geometry.quadrilateral import get_bounding_box from mindee.parsing.custom.list import ListFieldV1, ListFieldValueV1 -def _array_product(array: Sequence[float]) -> float: - """ - Get the product of a sequence of floats. - - :array: List of floats - """ - product = 1.0 - for k in array: - product = product * k - return product - - def _find_best_anchor(anchors: Sequence[str], fields: Dict[str, ListFieldV1]) -> str: """ Find the anchor with the most rows, in the order specified by `anchors`. 
@@ -32,28 +17,132 @@ def _find_best_anchor(anchors: Sequence[str], fields: Dict[str, ListFieldV1]) -> anchor_rows = 0 for field in anchors: values = fields[field].values - if len(values) > anchor_rows: + if values and len(values) > anchor_rows: anchor_rows = len(values) anchor = field return anchor -def _get_empty_field() -> ListFieldValueV1: - """Return sample field with empty values.""" - return ListFieldValueV1({"content": "", "polygon": [], "confidence": 0.0}) - - -class LineV1: +class CustomLineV1: """Represent a single line.""" row_number: int + """Index of the row of a given line.""" fields: Dict[str, ListFieldValueV1] - bounding_box: Quadrilateral + """Fields contained in the line.""" + bbox: BBox + """Simplified bounding box of the line.""" + + def __init__(self, row_number: int): + self.row_number = row_number + self.bbox = BBox(1, 1, 0, 0) + self.fields = {} + + def update_field(self, field_name: str, field_value: ListFieldValueV1) -> None: + """ + Updates a field value if it exists. + + :param field_name: name of the field to update. + :param field_value: value of the field to set. + """ + if field_name in self.fields: + existing_field = self.fields[field_name] + existing_content = existing_field.content + merged_content: str = "" + if len(existing_content) > 0: + merged_content += existing_content + " " + merged_content += field_value.content + merged_polygon = get_bounding_box( + [*existing_field.polygon, *field_value.polygon] + ) + merged_confidence = existing_field.confidence * field_value.confidence + else: + merged_content = field_value.content + merged_confidence = field_value.confidence + merged_polygon = get_bounding_box(field_value.polygon) + + self.fields[field_name] = ListFieldValueV1( + { + "content": merged_content, + "confidence": merged_confidence, + "polygon": merged_polygon, + } + ) + + +def is_box_in_line( + line: CustomLineV1, bbox: BBox, height_line_tolerance: float +) -> bool: + """ + Checks if the bbox fits inside the line. 
+ + :param line: line whose bbox the candidate box is compared against. + :param bbox: bounding box of the candidate field. + :param height_line_tolerance: line height tolerance for custom line reconstruction. + """ + if abs(bbox.y_min - line.bbox.y_min) <= height_line_tolerance: + return True + return abs(line.bbox.y_min - bbox.y_min) <= height_line_tolerance + + +def prepare( + anchor_name: str, fields: Dict[str, ListFieldV1], height_line_tolerance: float +) -> List[CustomLineV1]: + """ + Prepares lines before filling them. + + :param anchor_name: name of the anchor. + :param fields: fields to build lines from. + :param height_line_tolerance: line height tolerance for custom line reconstruction. + """ + lines_prepared: List[CustomLineV1] = [] + try: + anchor_field: ListFieldV1 = fields[anchor_name] + except KeyError as exc: + raise MindeeError("No lines have been detected.") from exc + + current_line_number: int = 1 + current_line = CustomLineV1(current_line_number) + if anchor_field and len(anchor_field.values) > 0: + current_value: ListFieldValueV1 = anchor_field.values[0] + current_line.bbox = extend_bbox( + current_line.bbox, + current_value.polygon, + ) + + for i in range(1, len(anchor_field.values)): + current_value = anchor_field.values[i] + current_field_box = get_bbox(current_value.polygon) + if not is_box_in_line( + current_line, current_field_box, height_line_tolerance + ): + lines_prepared.append(current_line) + current_line_number += 1 + current_line = CustomLineV1(current_line_number) + current_line.bbox = extend_bbox( + current_line.bbox, + current_value.polygon, + ) + if ( + len( + [ + line + for line in lines_prepared + if line.row_number == current_line_number + ] + ) + == 0 + ): + lines_prepared.append(current_line) + return lines_prepared def get_line_items( - anchors: Sequence[str], columns: Sequence[str], fields: Dict[str, ListFieldV1] -) -> List[LineV1]: + anchors: Sequence[str], + field_names: Sequence[str], + fields: Dict[str, ListFieldV1], + height_line_tolerance: float = 
0.01, +) -> List[CustomLineV1]: """ Reconstruct line items from fields. @@ -61,51 +150,29 @@ def get_line_items( :columns: All fields which are columns :fields: List of field names to reconstruct table with """ - line_items: List[LineV1] = [] - anchor = _find_best_anchor(anchors, fields) + line_items: List[CustomLineV1] = [] + fields_to_transform: Dict[str, ListFieldV1] = {} + for field_name, field_value in fields.items(): + if field_name in field_names: + fields_to_transform[field_name] = field_value + anchor = _find_best_anchor(anchors, fields_to_transform) if not anchor: print(Warning("Could not find an anchor!")) return line_items - - # Loop on anchor items and create an item for each anchor item. - # This will create all rows with just the anchor column value. - for item in fields[anchor].values: - line_item = LineV1() - line_item.fields = {f: _get_empty_field() for f in columns} - line_item.fields[anchor] = item - line_items.append(line_item) - - # Loop on all created rows - for idx, line in enumerate(line_items): - # Compute sliding window between anchor item and the next - min_y, _ = get_min_max_y(line.fields[anchor].polygon) - if idx != len(line_items) - 1: - max_y, _ = get_min_max_y(line_items[idx + 1].fields[anchor].polygon) - else: - max_y = 1.0 # bottom of page - # Get candidates of each field included in sliding window and add it in line item - for field in columns: - field_words = [ - word - for word in fields[field].values - if is_point_in_y(word.polygon.centroid, min_y, max_y) - ] - line.fields[field].content = " ".join([v.content for v in field_words]) - try: - line.fields[field].polygon = merge_polygons( - [v.polygon for v in field_words] - ) - except ValueError: - pass - line.fields[field].confidence = _array_product( - [v.confidence for v in field_words] - ) - all_polygons = [line.fields[anchor].polygon] - for field in columns: - try: - all_polygons.append(line.fields[field].polygon) - except IndexError: - pass - line.bounding_box = 
get_bounding_box(merge_polygons(all_polygons)) - line.row_number = idx - return line_items + lines_prepared: List[CustomLineV1] = prepare( + anchor, fields_to_transform, height_line_tolerance + ) + + for current_line in lines_prepared: + for field_name, field in fields_to_transform.items(): + for list_field_value in field.values: + min_max_y: MinMax = get_min_max_y(list_field_value.polygon) + if ( + abs(min_max_y.max - current_line.bbox.y_max) + <= height_line_tolerance + and abs(min_max_y.min - current_line.bbox.y_min) + <= height_line_tolerance + ): + current_line.update_field(field_name, list_field_value) + + return lines_prepared diff --git a/mindee/parsing/standard/base.py b/mindee/parsing/standard/base.py index 6846eb71..55754282 100644 --- a/mindee/parsing/standard/base.py +++ b/mindee/parsing/standard/base.py @@ -1,6 +1,8 @@ from typing import Any, List, Optional, Type -from mindee.geometry import Point, Polygon, Quadrilateral, get_bounding_box +from mindee.geometry.point import Point +from mindee.geometry.polygon import Polygon +from mindee.geometry.quadrilateral import Quadrilateral, get_bounding_box from mindee.parsing.common.string_dict import StringDict diff --git a/mindee/parsing/standard/position.py b/mindee/parsing/standard/position.py index 53c60faf..f6f732bf 100644 --- a/mindee/parsing/standard/position.py +++ b/mindee/parsing/standard/position.py @@ -1,12 +1,8 @@ from typing import Optional -from mindee.geometry import ( - GeometryError, - Polygon, - Quadrilateral, - polygon_from_prediction, - quadrilateral_from_prediction, -) +from mindee.geometry.error import GeometryError +from mindee.geometry.polygon import Polygon, polygon_from_prediction +from mindee.geometry.quadrilateral import Quadrilateral, quadrilateral_from_prediction from mindee.parsing.common.string_dict import StringDict from mindee.parsing.standard.base import BaseField diff --git a/mindee/product/custom/__init__.py b/mindee/product/custom/__init__.py index e1335a9b..33330881 
100644 --- a/mindee/product/custom/__init__.py +++ b/mindee/product/custom/__init__.py @@ -1 +1 @@ -from mindee.product.custom.custom_v1 import CustomV1 +from mindee.product.custom.custom_v1 import CustomV1, CustomV1Document, CustomV1Page diff --git a/mindee/product/custom/custom_v1_document.py b/mindee/product/custom/custom_v1_document.py index f6e01e09..e1126e97 100644 --- a/mindee/product/custom/custom_v1_document.py +++ b/mindee/product/custom/custom_v1_document.py @@ -1,7 +1,8 @@ -from typing import Dict +from typing import Dict, List from mindee.parsing.common import Prediction, StringDict, clean_out_string from mindee.parsing.custom import ClassificationFieldV1, ListFieldV1 +from mindee.parsing.custom.line_items import CustomLineV1, get_line_items class CustomV1Document(Prediction): @@ -27,6 +28,26 @@ def __init__(self, raw_prediction: StringDict) -> None: elif "values" in field_contents: self.fields[field_name] = ListFieldV1(field_contents) + def columns_to_line_items( + self, + anchor_names: List[str], + field_names: List[str], + height_tolerance: float = 0.01, + ) -> List[CustomLineV1]: + """ + Order column fields into line items. + + :param anchor_names: list of possible anchor fields. + :param field_names: list of all column fields. + :param height_tolerance: height tolerance to apply to lines. 
+ """ + return get_line_items( + anchor_names, + field_names, + self.fields, + height_tolerance, + ) + def __str__(self) -> str: out_str = "" for classification_name, classification_value in self.classifications.items(): diff --git a/mindee/product/custom/custom_v1_page.py b/mindee/product/custom/custom_v1_page.py index d5bdd0be..a26f302d 100644 --- a/mindee/product/custom/custom_v1_page.py +++ b/mindee/product/custom/custom_v1_page.py @@ -1,7 +1,8 @@ -from typing import Dict, Optional +from typing import Dict, List, Optional from mindee.parsing.common import Prediction, StringDict, clean_out_string from mindee.parsing.custom import ListFieldV1 +from mindee.parsing.custom.line_items import CustomLineV1, get_line_items class CustomV1Page(Prediction): @@ -20,6 +21,26 @@ def __init__(self, raw_prediction: StringDict, page_id: Optional[int]) -> None: for field_name, field_contents in raw_prediction.items(): self.fields[field_name] = ListFieldV1(field_contents, page_id=page_id) + def columns_to_line_items( + self, + anchor_names: List[str], + field_names: List[str], + height_tolerance: float = 0.01, + ) -> List[CustomLineV1]: + """ + Order column fields into line items. + + :param anchor_names: list of possible anchor fields. + :param field_names: list of all column fields. + :param height_tolerance: height tolerance to apply to lines. 
+ """ + return get_line_items( + anchor_names, + field_names, + self.fields, + height_tolerance, + ) + def __str__(self) -> str: out_str = "" for field_name, field_value in self.fields.items(): diff --git a/tests/product/custom/test_custom_v1_line_items.py b/tests/product/custom/test_custom_v1_line_items.py index 48b96835..a81251e3 100644 --- a/tests/product/custom/test_custom_v1_line_items.py +++ b/tests/product/custom/test_custom_v1_line_items.py @@ -1,31 +1,39 @@ import json from mindee.parsing.common.document import Document -from mindee.parsing.custom import get_line_items -from mindee.product import CustomV1 +from mindee.parsing.common.page import Page +from mindee.product.custom.custom_v1 import CustomV1 +from mindee.product.custom.custom_v1_page import CustomV1Page + + +def do_tests(line_items): + assert len(line_items) == 3 + assert line_items[0].fields["beneficiary_name"].content == "JAMES BOND 007" + assert line_items[0].fields["beneficiary_birth_date"].content == "1970-11-11" + assert line_items[0].row_number == 1 + assert line_items[1].fields["beneficiary_name"].content == "HARRY POTTER" + assert line_items[1].fields["beneficiary_birth_date"].content == "2010-07-18" + assert line_items[1].row_number == 2 + assert line_items[2].fields["beneficiary_name"].content == "DRAGO MALFOY" + assert line_items[2].fields["beneficiary_birth_date"].content == "2015-07-05" + assert line_items[2].row_number == 3 def test_single_table_01(): json_data_path = ( - f"./tests/data/products/custom/response_v1/line_items/single_table_01.json" + "./tests/data/products/custom/response_v1/line_items/single_table_01.json" ) json_data = json.load(open(json_data_path, "r")) doc = Document(CustomV1, json_data["document"]).inference.prediction - anchors = ["beneficiary_birth_date"] + page = Page(CustomV1Page, json_data["document"]["inference"]["pages"][0]) + anchors = ["beneficiary_name"] columns = [ - "beneficiary_name", "beneficiary_birth_date", - "beneficiary_rank", 
"beneficiary_number", + "beneficiary_name", + "beneficiary_rank", ] - line_items = get_line_items(anchors, columns, doc.fields) - assert len(line_items) == 3 - assert line_items[0].fields["beneficiary_name"].content == "JAMES BOND 007" - assert line_items[0].fields["beneficiary_birth_date"].content == "1970-11-11" - assert line_items[0].row_number == 0 - assert line_items[1].fields["beneficiary_name"].content == "HARRY POTTER" - assert line_items[1].fields["beneficiary_birth_date"].content == "2010-07-18" - assert line_items[1].row_number == 1 - assert line_items[2].fields["beneficiary_name"].content == "DRAGO MALFOY" - assert line_items[2].fields["beneficiary_birth_date"].content == "2015-07-05" - assert line_items[2].row_number == 2 + line_items = doc.columns_to_line_items(anchors, columns, 0.011) + do_tests(line_items) + line_items_page = page.prediction.columns_to_line_items(anchors, columns, 0.011) + do_tests(line_items_page)