Skip to content

Commit

Permalink
:sparkes: add support for full text ocr extra (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastianMindee authored Sep 4, 2024
1 parent 17c4207 commit 7d85b5f
Show file tree
Hide file tree
Showing 35 changed files with 183 additions and 60 deletions.
25 changes: 22 additions & 3 deletions mindee/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def parse(
page_options: Optional[PageOptions] = None,
cropper: bool = False,
endpoint: Optional[Endpoint] = None,
full_text: bool = False,
) -> PredictResponse:
"""
Call prediction API on the document and parse the results.
Expand All @@ -89,6 +90,7 @@ def parse(
:param include_words: Whether to include the full text for each page.
This performs a full OCR operation on the server and will increase response time.
Only available on financial document APIs.
:param close_file: Whether to ``close()`` the file after parsing it.
Set to ``False`` if you need to access the file after this operation.
Expand All @@ -101,6 +103,7 @@ def parse(
This performs a cropping operation on the server and will increase response time.
:param endpoint: For custom endpoints, an endpoint has to be given.
:param full_text: Whether to include the full OCR text response in compatible APIs.
"""
if input_source is None:
raise MindeeClientError("No input document provided.")
Expand All @@ -118,7 +121,13 @@ def parse(
page_options.page_indexes,
)
return self._make_request(
product_class, input_source, endpoint, include_words, close_file, cropper
product_class,
input_source,
endpoint,
include_words,
close_file,
cropper,
full_text,
)

def enqueue(
Expand All @@ -130,6 +139,7 @@ def enqueue(
page_options: Optional[PageOptions] = None,
cropper: bool = False,
endpoint: Optional[Endpoint] = None,
full_text: bool = False,
) -> AsyncPredictResponse:
"""
Enqueues a document to an asynchronous endpoint.
Expand All @@ -154,6 +164,8 @@ def enqueue(
This performs a cropping operation on the server and will increase response time.
:param endpoint: For custom endpoints, an endpoint has to be given.
:param full_text: Whether to include the full OCR text response in compatible APIs.
"""
if input_source is None:
raise MindeeClientError("No input document provided.")
Expand All @@ -177,6 +189,7 @@ def enqueue(
include_words,
close_file,
cropper,
full_text,
)

def load_prediction(
Expand Down Expand Up @@ -246,6 +259,7 @@ def enqueue_and_parse(
initial_delay_sec: float = 4,
delay_sec: float = 2,
max_retries: int = 30,
full_text: bool = False,
) -> AsyncPredictResponse:
"""
Enqueues to an asynchronous endpoint and automatically polls for a response.
Expand Down Expand Up @@ -274,6 +288,8 @@ def enqueue_and_parse(
:param delay_sec: Delay between each polling attempt. This should not be shorter than 2 seconds.
:param max_retries: Total amount of polling attempts.
:param full_text: Whether to include the full OCR text response in compatible APIs.
"""
self._validate_async_params(initial_delay_sec, delay_sec, max_retries)
if not endpoint:
Expand All @@ -286,6 +302,7 @@ def enqueue_and_parse(
page_options,
cropper,
endpoint,
full_text,
)
logger.debug(
"Successfully enqueued document with job id: %s", queue_result.job.id
Expand Down Expand Up @@ -352,9 +369,10 @@ def _make_request(
include_words: bool,
close_file: bool,
cropper: bool,
full_text: bool,
) -> PredictResponse:
response = endpoint.predict_req_post(
input_source, include_words, close_file, cropper
input_source, include_words, close_file, cropper, full_text
)

dict_response = response.json()
Expand All @@ -376,14 +394,15 @@ def _predict_async(
include_words: bool = False,
close_file: bool = True,
cropper: bool = False,
full_text: bool = False,
) -> AsyncPredictResponse:
"""Sends a document to the queue, and sends back an asynchronous predict response."""
if input_source is None:
raise MindeeClientError("No input document provided")
if not endpoint:
endpoint = self._initialize_ots_endpoint(product_class)
response = endpoint.predict_async_req_post(
input_source, include_words, close_file, cropper
input_source, include_words, close_file, cropper, full_text
)

dict_response = response.json()
Expand Down
17 changes: 10 additions & 7 deletions mindee/mindee_http/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def predict_req_post(
include_words: bool = False,
close_file: bool = True,
cropper: bool = False,
full_text: bool = False,
) -> requests.Response:
"""
Make a request to POST a document for prediction.
Expand All @@ -42,10 +43,11 @@ def predict_req_post(
:param include_words: Include raw OCR words in the response
:param close_file: Whether to `close()` the file after parsing it.
:param cropper: Whether to include Mindee cropping results.
:param full_text: Whether to include the full OCR text response in compatible APIs.
:return: requests response
"""
return self._custom_request(
"predict", input_source, include_words, close_file, cropper
"predict", input_source, include_words, close_file, cropper, full_text
)

def predict_async_req_post(
Expand All @@ -54,6 +56,7 @@ def predict_async_req_post(
include_words: bool = False,
close_file: bool = True,
cropper: bool = False,
full_text: bool = False,
) -> requests.Response:
"""
Make an asynchronous request to POST a document for prediction.
Expand All @@ -62,10 +65,11 @@ def predict_async_req_post(
:param include_words: Include raw OCR words in the response
:param close_file: Whether to `close()` the file after parsing it.
:param cropper: Whether to include Mindee cropping results.
:param full_text: Whether to include the full OCR text response in compatible APIs.
:return: requests response
"""
return self._custom_request(
"predict_async", input_source, include_words, close_file, cropper
"predict_async", input_source, include_words, close_file, cropper, full_text
)

def _custom_request(
Expand All @@ -75,11 +79,15 @@ def _custom_request(
include_words: bool = False,
close_file: bool = True,
cropper: bool = False,
full_text: bool = False,
):
data = {}
if include_words:
data["include_mvision"] = "true"

if full_text:
data["full_text_ocr"] = "true"

params = {}
if cropper:
params["cropper"] = "true"
Expand Down Expand Up @@ -111,11 +119,6 @@ def document_queue_req_get(self, queue_id: str) -> requests.Response:
Sends a request matching a given queue_id. Returns either a Job or a Document.
:param queue_id: queue_id received from the API
:param include_words: Whether to include the full text for each page.
This performs a full OCR operation on the server and will increase response time.
:param cropper: Whether to include cropper results for each page.
This performs a cropping operation on the server and will increase response time.
"""
return requests.get(
f"{self.settings.url_root}/documents/queue/{queue_id}",
Expand Down
26 changes: 25 additions & 1 deletion mindee/parsing/common/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Document(Generic[TypePrediction, TypePage]):
"""Result of the base inference"""
id: str
"""Id of the document as sent back by the server"""
extras: Optional[Extras]
extras: Extras
"""Potential Extras fields sent back along the prediction"""
ocr: Optional[Ocr]
"""Potential raw text results read by the OCR (limited feature)"""
Expand All @@ -47,6 +47,7 @@ def __init__(
self.ocr = Ocr(raw_response["ocr"])
if "extras" in raw_response and raw_response["extras"]:
self.extras = Extras(raw_response["extras"])
self._inject_full_text_ocr(raw_response)
self.inference = inference_type(raw_response["inference"])
self.n_pages = raw_response["n_pages"]

Expand All @@ -57,3 +58,26 @@ def __str__(self) -> str:
f":Filename: {self.filename}\n\n"
f"{self.inference}"
)

def _inject_full_text_ocr(self, raw_prediction: StringDict) -> None:
    """
    Rebuild a document-level ``full_text_ocr`` extra from the per-page extras.

    The ``content`` of every page carrying a ``full_text_ocr`` extra is joined
    with newlines and stored on ``self.extras``. Nothing is done when there are
    no pages or when the first page carries no ``full_text_ocr`` extra.

    :param raw_prediction: Raw document payload; pages are read from
        ``raw_prediction["inference"]["pages"]``.
    """
    # ``or`` fallbacks: a key that is present but null must behave like a
    # missing key instead of raising on the next lookup.
    pages = (raw_prediction.get("inference") or {}).get("pages") or []

    # One entry per page: the page's ``full_text_ocr`` dict, or None if absent.
    page_texts = [(page.get("extras") or {}).get("full_text_ocr") for page in pages]

    # NOTE(review): as before, only the first page decides whether the
    # feature is present at all -- confirm this is intended.
    if not page_texts or not page_texts[0]:
        return

    # Pages without usable text are skipped rather than aborting the parse.
    full_text_content = "\n".join(
        page_text["content"]
        for page_text in page_texts
        if page_text and page_text.get("content") is not None
    )

    artificial_text_obj = {"content": full_text_content}

    # ``extras`` is only assigned when the server sent document-level extras,
    # so it may be missing entirely here.
    current_extras = getattr(self, "extras", None)
    if current_extras is None:
        self.extras = Extras({"full_text_ocr": artificial_text_obj})
    else:
        current_extras.add_artificial_extra({"full_text_ocr": artificial_text_obj})
15 changes: 14 additions & 1 deletion mindee/parsing/common/extras/extras.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

from mindee.parsing.common.extras.cropper_extra import CropperExtra
from mindee.parsing.common.extras.full_text_ocr_extra import FullTextOcrExtra
from mindee.parsing.common.string_dict import StringDict


Expand All @@ -12,12 +13,15 @@ class Extras:
"""

cropper: Optional[CropperExtra]
full_text_ocr: Optional[FullTextOcrExtra]

def __init__(self, raw_prediction: StringDict) -> None:
if "cropper" in raw_prediction and raw_prediction["cropper"]:
self.cropper = CropperExtra(raw_prediction["cropper"])
if "full_text_ocr" in raw_prediction and raw_prediction["full_text_ocr"]:
self.full_text_ocr = FullTextOcrExtra(raw_prediction["full_text_ocr"])
for key, extra in raw_prediction.items():
if key != "cropper":
if key not in ["cropper", "full_text_ocr"]:
setattr(self, key, extra)

def __str__(self) -> str:
Expand All @@ -26,3 +30,12 @@ def __str__(self) -> str:
if not attr.startswith("__"):
out_str += f":{attr}: {getattr(self, attr)}\n"
return out_str

def add_artificial_extra(self, raw_prediction: StringDict):
    """
    Attach an extra that was reconstructed client-side rather than sent as-is by the server.

    Only the ``full_text_ocr`` key is handled; a missing or empty entry leaves
    the instance untouched.

    :param raw_prediction: Dictionary holding the reconstructed extra(s).
    """
    full_text_payload = raw_prediction.get("full_text_ocr")
    if not full_text_payload:
        return
    self.full_text_ocr = FullTextOcrExtra(full_text_payload)
20 changes: 20 additions & 0 deletions mindee/parsing/common/extras/full_text_ocr_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Optional

from mindee.parsing.common.string_dict import StringDict


class FullTextOcrExtra:
    """
    Full Text OCR result.

    Both attributes are always defined on an instance; each one is ``None``
    when the payload does not provide it.
    """

    # Text read by the OCR, or None when the payload has no "content" key.
    content: Optional[str]
    # Language reported in the payload, or None when there is no "language" key.
    language: Optional[str]

    def __init__(self, raw_prediction: StringDict) -> None:
        """
        :param raw_prediction: Raw ``full_text_ocr`` dictionary; may be empty or ``None``.
        """
        # Always assign both attributes: leaving them unset made ``__str__``
        # (and any direct attribute read) raise AttributeError.
        payload = raw_prediction or {}
        self.content = payload.get("content")
        self.language = payload.get("language")

    def __str__(self) -> str:
        """Return the OCR text, or an empty string when there is none."""
        return self.content or ""
4 changes: 2 additions & 2 deletions mindee/product/barcode_reader/barcode_reader_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, raw_prediction: StringDict):
self.pages = []
for page in raw_prediction["pages"]:
try:
page_production = page["prediction"]
page_prediction = page["prediction"]
except KeyError:
continue
if page_production:
if page_prediction:
self.pages.append(Page(BarcodeReaderV1Document, page))
4 changes: 2 additions & 2 deletions mindee/product/cropper/cropper_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(self, raw_prediction: StringDict):
self.pages = []
for page in raw_prediction["pages"]:
try:
page_production = page["prediction"]
page_prediction = page["prediction"]
except KeyError:
continue
if page_production:
if page_prediction:
self.pages.append(Page(CropperV1Page, page))
4 changes: 2 additions & 2 deletions mindee/product/eu/driver_license/driver_license_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(self, raw_prediction: StringDict):
self.pages = []
for page in raw_prediction["pages"]:
try:
page_production = page["prediction"]
page_prediction = page["prediction"]
except KeyError:
continue
if page_production:
if page_prediction:
self.pages.append(Page(DriverLicenseV1Page, page))
4 changes: 2 additions & 2 deletions mindee/product/eu/license_plate/license_plate_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, raw_prediction: StringDict):
self.pages = []
for page in raw_prediction["pages"]:
try:
page_production = page["prediction"]
page_prediction = page["prediction"]
except KeyError:
continue
if page_production:
if page_prediction:
self.pages.append(Page(LicensePlateV1Document, page))
4 changes: 2 additions & 2 deletions mindee/product/financial_document/financial_document_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, raw_prediction: StringDict):
self.pages = []
for page in raw_prediction["pages"]:
try:
page_production = page["prediction"]
page_prediction = page["prediction"]
except KeyError:
continue
if page_production:
if page_prediction:
self.pages.append(Page(FinancialDocumentV1Document, page))
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, raw_prediction: StringDict):
self.pages = []
for page in raw_prediction["pages"]:
try:
page_production = page["prediction"]
page_prediction = page["prediction"]
except KeyError:
continue
if page_production:
if page_prediction:
self.pages.append(Page(BankAccountDetailsV1Document, page))
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, raw_prediction: StringDict):
self.pages = []
for page in raw_prediction["pages"]:
try:
page_production = page["prediction"]
page_prediction = page["prediction"]
except KeyError:
continue
if page_production:
if page_prediction:
self.pages.append(Page(BankAccountDetailsV2Document, page))
4 changes: 2 additions & 2 deletions mindee/product/fr/carte_grise/carte_grise_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, raw_prediction: StringDict):
self.pages = []
for page in raw_prediction["pages"]:
try:
page_production = page["prediction"]
page_prediction = page["prediction"]
except KeyError:
continue
if page_production:
if page_prediction:
self.pages.append(Page(CarteGriseV1Document, page))
4 changes: 2 additions & 2 deletions mindee/product/fr/carte_vitale/carte_vitale_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, raw_prediction: StringDict):
self.pages = []
for page in raw_prediction["pages"]:
try:
page_production = page["prediction"]
page_prediction = page["prediction"]
except KeyError:
continue
if page_production:
if page_prediction:
self.pages.append(Page(CarteVitaleV1Document, page))
Loading

0 comments on commit 7d85b5f

Please sign in to comment.