Skip to content

Commit

Permalink
fix feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastianMindee committed Oct 18, 2023
1 parent d0d9cee commit 7d74f3a
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 93 deletions.
123 changes: 99 additions & 24 deletions mindee/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from mindee.input.sources import LocalInputSource, UrlInputSource
from mindee.parsing.common.async_predict_response import AsyncPredictResponse
from mindee.parsing.common.document import Document, serialize_for_json
from mindee.parsing.common.feedback_response import FeedbackResponse
from mindee.parsing.common.inference import Inference, TypeInference
from mindee.parsing.common.predict_response import PredictResponse
from mindee.parsing.common.string_dict import StringDict


@dataclass
Expand Down Expand Up @@ -142,21 +144,29 @@ class MindeeParser:
input_doc: Union[LocalInputSource, UrlInputSource]
"""Document to be parsed."""
product_class: Type[Inference]
"""Product to parse"""
"""Product to parse."""
feedback: Optional[StringDict]
"""Dict representation of a feedback."""

def __init__(
self,
parser: Optional[ArgumentParser] = None,
parsed_args: Optional[Namespace] = None,
client: Optional[Client] = None,
input_doc: Optional[Union[LocalInputSource, UrlInputSource]] = None,
document_info: Optional[CommandConfig] = None,
) -> None:
self.parser = parser if parser else ArgumentParser(description="Mindee_API")
self.parsed_args = parsed_args if parsed_args else self._set_args()
self.client = client if client else Client(api_key=self.parsed_args.api_key)
if self.parsed_args.parse_type == "parse":
self.input_doc = input_doc if input_doc else self._get_input_doc()
self.client = (
client
if client
else Client(
api_key=self.parsed_args.api_key
if "api_key" in self.parsed_args
else None
)
)
self._set_input()
self.document_info = (
document_info if document_info else DOCUMENTS[self.parsed_args.product_name]
)
Expand All @@ -166,24 +176,27 @@ def call_endpoint(self) -> None:
if self.parsed_args.parse_type == "parse":
self.call_parse()
else:
self.call_fetch()
self.call_feedback()

def call_fetch(self) -> None:
"""Fetches an API's for a previously enqueued document."""
def call_feedback(self) -> None:
"""Sends feedback to an API."""
custom_endpoint: Optional[Endpoint] = None
if self.parsed_args.product_name == "custom":
custom_endpoint = self.client.create_endpoint(
self.parsed_args.endpoint_name,
self.parsed_args.account_name,
self.parsed_args.api_version,
)
response: PredictResponse = self.client.get_document(
self.document_info.doc_class, self.parsed_args.document_id, custom_endpoint
if self.feedback is None:
raise RuntimeError("Invalid feedback provided.")

response: FeedbackResponse = self.client.send_feedback(
self.document_info.doc_class,
self.parsed_args.document_id,
self.feedback,
custom_endpoint,
)
if self.parsed_args.output_type == "raw":
print(response.raw_http)
else:
print(response.document)
print(response.raw_http)

def call_parse(self) -> None:
"""Calls an endpoint with the appropriate method, and displays the results."""
Expand Down Expand Up @@ -272,10 +285,10 @@ def _set_args(self) -> Namespace:
parse_subparser = parse_product_subparsers.add_parser(name, help=info.help)

call_parser = parse_subparser.add_subparsers(
dest="parse_type",
dest="parse_type", required=True
)
parse_subp = call_parser.add_parser("parse")
fetch_subp = call_parser.add_parser("fetch")
feedback_subp = call_parser.add_parser("feedback")

self._add_main_options(parse_subp)
self._add_sending_options(parse_subp)
Expand All @@ -302,9 +315,22 @@ def _set_args(self) -> Namespace:
default=False,
)

self._add_main_options(fetch_subp)
self._add_display_options(fetch_subp)
self._add_doc_id_option(fetch_subp)
self._add_main_options(feedback_subp)
self._add_doc_id_option(feedback_subp)
self._add_feedback_options(feedback_subp)
feedback_subp.add_argument(
"-i",
"--input-type",
dest="input_type",
choices=["path", "file", "base64", "bytes", "local"],
default="local",
help="Specify how to handle the input.\n"
"- path: open a path (default).\n"
"- file: open as a file handle.\n"
"- base64: open a base64 encoded text file.\n"
"- bytes: open the contents as raw bytes.\n"
"- local: provide the feedback as a dict-like string.",
)

parsed_args = self.parser.parse_args()
return parsed_args
Expand All @@ -313,7 +339,7 @@ def _add_main_options(self, parser: ArgumentParser) -> None:
"""
Adds main options for most parsings.
:param parser: current parser/subparser.
:param parser: current parser.
"""
parser.add_argument(
"-k",
Expand All @@ -328,7 +354,7 @@ def _add_display_options(self, parser: ArgumentParser) -> None:
"""
Adds options related to output/display of a document (parse, parse-queued).
:param parser: current parser/subparser.
:param parser: current parser.
"""
parser.add_argument(
"-o",
Expand All @@ -346,7 +372,7 @@ def _add_sending_options(self, parser: ArgumentParser) -> None:
"""
Adds options for sending requests (parse, enqueue).
:param parser: current parser/subparser.
:param parser: current parser.
"""
parser.add_argument(
"-i",
Expand Down Expand Up @@ -382,18 +408,34 @@ def _add_doc_id_option(self, parser: ArgumentParser):
"""
Adds an option to provide the queue ID for an async document.
:param parser: current parser/subparser.
:param parser: current parser.
"""
parser.add_argument(
dest="document_id",
help="Async queue ID for a document (required)",
type=str,
)

def _add_feedback_options(self, parser: ArgumentParser):
    """
    Registers the optional ``-f``/``--feedback`` argument on a parser.

    The raw command-line value is converted with ``json.loads`` and stored
    under the ``feedback`` attribute of the parsed arguments.

    :param parser: current parser.
    """
    flags = ("-f", "--feedback")
    settings = {
        "dest": "feedback",
        "required": False,
        "type": json.loads,
        "help": "Feedback as a string",
    }
    parser.add_argument(*flags, **settings)

def _add_custom_options(self, parser: ArgumentParser):
"""
Adds options to custom-type documents.
:param parser: current parser/subparser.
:param parser: current parser.
"""
parser.add_argument(
"-a",
Expand Down Expand Up @@ -436,6 +478,39 @@ def _get_input_doc(self) -> Union[LocalInputSource, UrlInputSource]:
return self.client.source_from_url(self.parsed_args.path)
return self.client.source_from_path(self.parsed_args.path)

def _get_feedback_doc(self) -> StringDict:
    """
    Loads the feedback payload selected by ``parsed_args.input_type``.

    - ``file`` / ``bytes``: ``parsed_args.path`` is read as raw JSON bytes.
    - ``base64``: ``parsed_args.path`` is read as base64-encoded JSON text.
    - any other value: ``parsed_args.feedback`` is used as-is.

    :return: the feedback, as a dict holding a ``feedback`` key.
    :raises RuntimeError: if the payload is missing, is not a dict, or has
        no ``feedback`` key.
    """
    # Local import: only the base64 input type needs it.
    import base64

    input_type = self.parsed_args.input_type
    json_doc: StringDict = {}
    if input_type in ("file", "bytes"):
        # json.loads accepts bytes, so both input types share one code path.
        with open(self.parsed_args.path, "rb") as f_b:
            json_doc = json.loads(f_b.read())
    elif input_type == "base64":
        with open(self.parsed_args.path, "rt", encoding="ascii") as f_b64:
            # The file holds base64 text: decode it before parsing the JSON.
            json_doc = json.loads(base64.b64decode(f_b64.read()))
    else:
        # The feedback was given directly on the command line: return it,
        # rather than only validating it and leaving json_doc empty.
        json_doc = self.parsed_args.feedback
    if not isinstance(json_doc, dict) or "feedback" not in json_doc:
        raise RuntimeError("Invalid feedback.")
    return json_doc

def _set_input(self) -> None:
    """
    Prepares what the requested command works on.

    ``self.feedback`` is always reset to ``None`` first. For the ``feedback``
    command it is then set to the feedback given in the parsed arguments, or
    to the result of ``_get_feedback_doc()`` when none was given. For any
    other command, ``self.input_doc`` is set from ``_get_input_doc()``.
    """
    args = self.parsed_args
    self.feedback = None
    if args.parse_type != "feedback":
        self.input_doc = self._get_input_doc()
        return
    # Feedback from the arguments wins; the loader is only the fallback.
    self.feedback = args.feedback or self._get_feedback_doc()


def main() -> None:
"""Run the Command Line Interface."""
Expand Down
30 changes: 0 additions & 30 deletions mindee/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,36 +309,6 @@ def send_feedback(

return FeedbackResponse(feedback_response.json())

def get_document(
self,
product_class: Type[Inference],
document_id: str,
endpoint: Optional[Endpoint] = None,
) -> PredictResponse:
"""
Fetch prediction results from a document already processed.
:param product_class: The document class to use.
The response object will be instantiated based on this parameter.
:param document_id: The id of the document to send feedback to.
:param endpoint: For custom endpoints, an endpoint has to be given.
"""
if not document_id or len(document_id) == 0:
raise RuntimeError("Invalid document_id.")
if not endpoint:
endpoint = self._initialize_ots_endpoint(product_class)

response = endpoint.document_req_get(document_id)
if not response.ok:
raise handle_error(
str(product_class.endpoint_name),
response.json(),
response.status_code,
)

return PredictResponse(product_class, response.json())

def _make_request(
self,
product_class: Type[Inference],
Expand Down
38 changes: 19 additions & 19 deletions mindee/mindee_http/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,25 +146,6 @@ def document_feedback_req_put(
timeout=self.settings.request_timeout,
)

def document_req_get(self, document_id: str) -> requests.Response:
"""
Make a request to GET annotations for a document.
:param document_id: ID of the document
"""
params = {
"include_annotations": True,
"include_candidates": True,
"global_orientation": True,
}
response = requests.get(
f"{self.settings.url_root}/documents/{document_id}",
headers=self.settings.base_headers,
params=params,
timeout=self.settings.request_timeout,
)
return response


class CustomEndpoint(Endpoint):
"""Endpoint for all custom documents."""
Expand Down Expand Up @@ -243,6 +224,25 @@ def documents_req_get(self, page_id: int = 1) -> requests.Response:
)
return response

def document_req_get(self, document_id: str) -> requests.Response:
    """
    Send a GET request for a single document.

    The request targets ``{url_root}/documents/{document_id}`` and asks for
    annotations, candidates and the global orientation through its query
    parameters.

    :param document_id: ID of the document
    :return: the HTTP response, returned as-is.
    """
    return requests.get(
        f"{self.settings.url_root}/documents/{document_id}",
        headers=self.settings.base_headers,
        params={
            "include_annotations": True,
            "include_candidates": True,
            "global_orientation": True,
        },
        timeout=self.settings.request_timeout,
    )

def annotations_req_post(
self, document_id: str, annotations: dict
) -> requests.Response:
Expand Down
12 changes: 12 additions & 0 deletions tests/api/test_feedback_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import json
from mindee.parsing.common.feedback_response import FeedbackResponse


def test_empty_feedback_response():
    """Check that an empty feedback payload loads and leaves ``customer_address`` unset."""
    # The context manager closes the fixture file once the JSON is read.
    with open(
        "./tests/data/products/invoices/feedback_response/empty.json",
        encoding="utf-8",
    ) as json_file:
        response = json.load(json_file)
    feedback_response = FeedbackResponse(response)
    assert feedback_response is not None
    assert feedback_response.feedback["customer_address"] is None

7 changes: 1 addition & 6 deletions tests/api/test_response.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import json

import pytest

from mindee.input.sources import PathInput
from mindee.parsing.common.predict_response import PredictResponse
from mindee.product import ( # FinancialDocumentV1,; InvoiceV3,; PassportV1,; ReceiptV3,
ReceiptV4,
)
from mindee.product import ReceiptV4
from mindee.product.financial_document.financial_document_v1 import FinancialDocumentV1
from mindee.product.financial_document.financial_document_v1_document import (
FinancialDocumentV1Document,
Expand Down
2 changes: 1 addition & 1 deletion tests/data
21 changes: 13 additions & 8 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from argparse import Namespace
from sys import api_version

Expand Down Expand Up @@ -62,8 +63,9 @@ def ots_doc_enqueue_and_parse(monkeypatch):


@pytest.fixture
def ots_doc_fetch(monkeypatch):
def ots_doc_feedback(monkeypatch):
clear_envvars(monkeypatch)
dummy_feedback = '{"feedback": {"dummy_field": {"value": "dummy"}}}'
return Namespace(
api_key="dummy",
output_type="summary",
Expand All @@ -73,7 +75,10 @@ def ots_doc_fetch(monkeypatch):
api_version="dummy",
queue_id="dummy-queue-id",
call_method="parse-queued",
parse_type="fetch",
input_type="path",
path="./tests/data/file_types/pdf/blank.pdf",
parse_type="feedback",
feedback=json.loads(dummy_feedback),
)


Expand Down Expand Up @@ -155,13 +160,13 @@ def test_cli_invoice_splitter_enqueue(ots_doc_enqueue_and_parse):
parser.call_endpoint()


def test_cli_fetch(ots_doc_fetch):
ots_doc_fetch.document_id = "dummy-document-id"
ots_doc_fetch.api_key = ""
def test_cli_feedback(ots_doc_feedback):
ots_doc_feedback.document_id = "dummy-document-id"
ots_doc_feedback.api_key = ""
with pytest.raises(RuntimeError):
parser = MindeeParser(parsed_args=ots_doc_fetch)
parser = MindeeParser(parsed_args=ots_doc_feedback)
parser.call_endpoint()
ots_doc_fetch.api_key = "dummy"
ots_doc_feedback.api_key = "dummy"
with pytest.raises(MindeeHTTPClientException):
parser = MindeeParser(parsed_args=ots_doc_fetch)
parser = MindeeParser(parsed_args=ots_doc_feedback)
parser.call_endpoint()
Loading

0 comments on commit 7d74f3a

Please sign in to comment.