diff --git a/docs/assets/img/demo/driver_license_ru/template.jpg b/docs/assets/templates/driver_license_ru_01/driver_license_ru.jpg similarity index 100% rename from docs/assets/img/demo/driver_license_ru/template.jpg rename to docs/assets/templates/driver_license_ru_01/driver_license_ru.jpg diff --git a/docs/assets/templates/driver_license_ru_01/driver_license_ru.yml b/docs/assets/templates/driver_license_ru_01/driver_license_ru.yml new file mode 100644 index 0000000..7054a5e --- /dev/null +++ b/docs/assets/templates/driver_license_ru_01/driver_license_ru.yml @@ -0,0 +1,175 @@ +id: "driver_license_ru" +name: "Driver License RU" +source: "driver_license_ru.jpg" +mutators: + source: + target: +keypoints: + title: + x: 453 + y: 55 + w: 792 + h: 70 + matches: + min: 15 + max: 50 + rus_symbol: + x: 138 + y: 33 + w: 187 + h: 109 + matches: + min: 2 + max: 30 + heading_bar: + x: 441 + y: 154 + w: 96 + h: 610 + matches: + min: 5 + max: 30 + heading_4b: + x: 802 + y: 432 + w: 67 + h: 49 + matches: + min: 1 + max: 10 + heading_67: + x: 58 + y: 663 + w: 51 + h: 106 + matches: + min: 3 + max: 10 + b_b1: + x: 531 + y: 681 + w: 141 + h: 78 + matches: + min: 0 + max: 20 +matching: + engine: sift_flann + config: + sift_flann: + sensitivity: 0.7 +supervision: + engine: combinatorial + config: + combinatorial: + min_match_factor: 0.1 + max_transformation_error: 5 + result: best_score +features: + last_name_ru: + x: 525 + y: 160 + w: 600 + h: 45 + class: line_with_russian_text + last_name_en: + x: 525 + y: 200 + w: 600 + h: 35 + class: line_with_english_text + name_ru: + x: 525 + y: 235 + w: 600 + h: 45 + class: line_with_russian_text + name_en: + x: 525 + y: 275 + w: 600 + h: 35 + class: line_with_english_text + birthday: + x: 525 + y: 310 + w: 600 + h: 45 + class: line_with_english_text + place_of_birth_ru: + x: 525 + y: 350 + w: 600 + h: 40 + class: line_with_russian_text + place_of_birth_en: + x: 525 + y: 390 + w: 600 + h: 35 + class: line_with_english_text + issue_date: + x: 525 + y: 430 + w: 250 + h: 45 + class: line_with_english_text + expiry_date: + x: 875 + y: 430 + w: 250 + h: 45 + class: line_with_english_text + issue_authority_ru: + x: 525 + y: 475 + w: 600 + h: 40 + class: line_with_russian_text + issue_authority_en: + x: 525 + y: 510 + w: 600 + h: 40 + class: line_with_english_text + identifier: + x: 525 + y: 550 + w: 600 + h: 50 + class: line_with_english_text + issue_place_ru: + x: 525 + y: 595 + w: 600 + h: 40 + class: line_with_russian_text + issue_place_en: + x: 525 + y: 630 + w: 600 + h: 40 + class: line_with_english_text + face: + x: 87 + y: 192 + w: 313 + h: 460 +feature_classes: + line_with_text: + abstract: yes + mutators: + interpretation: + method: ocr_tesseract + config: + config: --dpi 1000 + line_with_russian_text: + inherits: line_with_text + interpretation: + config: + lang: rus + line_with_english_text: + inherits: line_with_text + interpretation: + config: + lang: eng \ No newline at end of file diff --git a/docs/assets/img/demo/driver_license_ru/examples/01.jpg b/docs/assets/templates/driver_license_ru_01/examples/01.jpg similarity index 100% rename from docs/assets/img/demo/driver_license_ru/examples/01.jpg rename to docs/assets/templates/driver_license_ru_01/examples/01.jpg diff --git a/docs/assets/img/demo/driver_license_ru/template_show_features.jpg b/docs/assets/templates/driver_license_ru_01/show_features.jpg similarity index 100% rename from docs/assets/img/demo/driver_license_ru/template_show_features.jpg rename to docs/assets/templates/driver_license_ru_01/show_features.jpg diff --git a/docs/assets/img/demo/driver_license_ru/template_show_keypoints.jpg b/docs/assets/templates/driver_license_ru_01/show_keypoints.jpg similarity index 100% rename from docs/assets/img/demo/driver_license_ru/template_show_keypoints.jpg rename to docs/assets/templates/driver_license_ru_01/show_keypoints.jpg diff --git a/docs/assets/img/demo/driver_license_ru/test/01/000.png b/docs/assets/templates/driver_license_ru_01/test/01/000.png similarity index 100% rename from docs/assets/img/demo/driver_license_ru/test/01/000.png rename to docs/assets/templates/driver_license_ru_01/test/01/000.png diff --git a/docs/assets/img/demo/driver_license_ru/test/01/000_match_title.png b/docs/assets/templates/driver_license_ru_01/test/01/000_match_title.png similarity index 100% rename from docs/assets/img/demo/driver_license_ru/test/01/000_match_title.png rename to docs/assets/templates/driver_license_ru_01/test/01/000_match_title.png diff --git a/docs/assets/img/demo/driver_license_ru/test/01/001_match_rus_symbol.png b/docs/assets/templates/driver_license_ru_01/test/01/001_match_rus_symbol.png similarity index 100% rename from docs/assets/img/demo/driver_license_ru/test/01/001_match_rus_symbol.png rename to docs/assets/templates/driver_license_ru_01/test/01/001_match_rus_symbol.png diff --git a/docs/assets/img/demo/driver_license_ru/test/01/002_match_heading_bar.png b/docs/assets/templates/driver_license_ru_01/test/01/002_match_heading_bar.png similarity index 100% rename from docs/assets/img/demo/driver_license_ru/test/01/002_match_heading_bar.png rename to docs/assets/templates/driver_license_ru_01/test/01/002_match_heading_bar.png diff --git a/docs/assets/img/demo/driver_license_ru/test/01/003_match_heading_4b.png b/docs/assets/templates/driver_license_ru_01/test/01/003_match_heading_4b.png similarity index 100% rename from docs/assets/img/demo/driver_license_ru/test/01/003_match_heading_4b.png rename to docs/assets/templates/driver_license_ru_01/test/01/003_match_heading_4b.png diff --git a/docs/assets/img/demo/driver_license_ru/test/01/004_match_heading_67.png b/docs/assets/templates/driver_license_ru_01/test/01/004_match_heading_67.png similarity index 100% rename from docs/assets/img/demo/driver_license_ru/test/01/004_match_heading_67.png rename to docs/assets/templates/driver_license_ru_01/test/01/004_match_heading_67.png diff --git a/docs/assets/img/demo/driver_license_ru/test/01/005_match_b_b1.png b/docs/assets/templates/driver_license_ru_01/test/01/005_match_b_b1.png similarity index 100% rename from docs/assets/img/demo/driver_license_ru/test/01/005_match_b_b1.png rename to docs/assets/templates/driver_license_ru_01/test/01/005_match_b_b1.png diff --git a/docs/assets/img/demo/driver_license_ru/test/01/supervision_result.png b/docs/assets/templates/driver_license_ru_01/test/01/supervision_result.png similarity index 100% rename from docs/assets/img/demo/driver_license_ru/test/01/supervision_result.png rename to docs/assets/templates/driver_license_ru_01/test/01/supervision_result.png diff --git a/docs/dev/changelog.md b/docs/dev/changelog.md index ff2f1ac..8989267 100644 --- a/docs/dev/changelog.md +++ b/docs/dev/changelog.md @@ -1,5 +1,28 @@ # Changelog +## Release 1.2.0 (beta) + +### Major changes + +* Implemented the OfficialEye API. Now it is possible to interact with the program programatically, without the need of running the CLI. +* Reimplemented the CLI as a layer on top of the new API. Thus, the API and the internal implementation no longer contain any code that is specific to the CLI user interface. In particular, it is now easy to implement different frontends that rely on OfficialEye as a backend service. +* Implemented a framework for transparent and process-safe interaction with the API backend. +* Switched from thread-based to process-based parallelism for resource-intensive backend operations. +* Substrantially improved the CLI user interface. +* Numerous other related architecture changes aimed at the long-term stability of the software. +* Integrated a new error handling system and related debugging facilitibes. +* Removed the (legacy) orthogonal regression supervision engine. +* Temporarily disabled the ability to generate visualizations. + +### Minor changes + +* Removed the `--worker` argument from the `run` and `test` commands, because it has become redundant and unnececcary in light of the new architecture. +* Implemented a new approach to handling image outputting in the CLI, that is much more flexible compared to the previous one. +* Improved type annotations. +* Removed the `--visualize` argument from the `test` command. + +[View on GitHub](https://github.com/ZeroBone/OfficialEye/releases/tag/1.2.0){ .md-button } + ## Release 1.1.5 (beta) * Added an `--interpret` option to the `run` and `test` commands, allowing one to optionally use a different target image for the interpretation phase. diff --git a/docs/index.md b/docs/index.md index 956bbf0..545071c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -11,9 +11,6 @@ icon: octicons/info-16 OfficialEye is an advanced tool designed to extract information from flat documents, such as passports or application forms, through image analysis. Leveraging state-of-the-art symbolic AI techniques, OfficialEye empowers users to effortlessly transform raw images into standardized canonical forms, facilitating the seamless extraction and processing of crucial information. -![GitHub Release](https://img.shields.io/github/v/release/ZeroBone/OfficialEye?label=latest%20release) -![PyPI - Version](https://img.shields.io/pypi/v/officialeye) - ## Key features diff --git a/docs/usage/examples.md b/docs/usage/examples.md index 27587af..6957071 100644 --- a/docs/usage/examples.md +++ b/docs/usage/examples.md @@ -1,5 +1,6 @@ --- icon: octicons/star-16 +render_macros: true --- # Examples @@ -15,7 +16,7 @@ Start by obtaining a properly positioned and high-quality image of a driver's li For demonstration purposes, let's use the following scan of a driver's license. ???+ tip "Template image" - ![Driver license photo](../assets/img/demo/driver_license_ru/template.jpg){ loading=lazy } + ![Driver license photo](../assets/templates/driver_license_ru_01/driver_license_ru.jpg){ loading=lazy } Next, we initialize a new template configuration file using the following command. @@ -119,7 +120,7 @@ officialeye show driver_license_ru.yml --hide-features ``` ???+ tip "Visualization of the template's keypoints" - ![Visualization of the template's keypoints](../assets/img/demo/driver_license_ru/template_show_keypoints.jpg){ loading=lazy } + ![Visualization of the template's keypoints](../assets/templates/driver_license_ru_01/show_keypoints.jpg){ loading=lazy } #### Features @@ -282,14 +283,14 @@ officialeye show driver_license_ru.yml --hide-keypoints ``` ???+ tip "Visualization of the template's features" - ![Visualization of the template's features](../assets/img/demo/driver_license_ru/template_show_features.jpg){ loading=lazy } + ![Visualization of the template's features](../assets/templates/driver_license_ru_01/show_features.jpg){ loading=lazy } ### Testing document analysis To test OfficialEye's document analysis and processing, we need an example image containing the document type the template is configured for, in this case, a photo of a driver's license. For the sake of the present demonstration, we shall use the following image. ???+ example "example_01.jpg" - ![Driver license photo](../assets/img/demo/driver_license_ru/examples/01.jpg){ loading=lazy } + ![Driver license photo](../assets/templates/driver_license_ru_01/examples/01.jpg){ loading=lazy } We can now tell OfficialEye to run the analysis algorithm and visualize the result by running the @@ -300,7 +301,7 @@ officialeye test example_01.jpg driver_license_ru.yml command, where `example_01.jpg` is the path to the input image (see above). The tool visualizes the result by replacing feature regions in the template image by the transformed version of the corresponding regions in the input image. ???+ example "Result" - ![Driver license photo](../assets/img/demo/driver_license_ru/test/01/supervision_result.png){ loading=lazy } + ![Driver license photo](../assets/templates/driver_license_ru_01/test/01/supervision_result.png){ loading=lazy } As we can see, OficialEye was able to successfully use the template we created to detect the driver's license in the input image, zoom and rotate all features accordingly. @@ -343,181 +344,7 @@ As above, `example_01.jpg` is the path to the input image. As a result, we get t For the sake of completeness and convenience, we provide the full version of the template configuration file of the present example. ```yaml title="driver_license_ru.yml" -id: "driver_license_ru" -name: "Driver License RU" -source: "driver_license_ru.jpg" -mutators: - source: - target: -keypoints: - title: - x: 453 - y: 55 - w: 792 - h: 70 - matches: - min: 15 - max: 50 - rus_symbol: - x: 138 - y: 33 - w: 187 - h: 109 - matches: - min: 2 - max: 30 - heading_bar: - x: 441 - y: 154 - w: 96 - h: 610 - matches: - min: 5 - max: 30 - heading_4b: - x: 802 - y: 432 - w: 67 - h: 49 - matches: - min: 1 - max: 10 - heading_67: - x: 58 - y: 663 - w: 51 - h: 106 - matches: - min: 3 - max: 10 - b_b1: - x: 531 - y: 681 - w: 141 - h: 78 - matches: - min: 0 - max: 20 -matching: - engine: sift_flann - config: - sift_flann: - sensitivity: 0.7 -supervision: - engine: combinatorial - config: - combinatorial: - min_match_factor: 0.1 - max_transformation_error: 5 - result: best_score -features: - last_name_ru: - x: 525 - y: 160 - w: 600 - h: 45 - class: line_with_russian_text - last_name_en: - x: 525 - y: 200 - w: 600 - h: 35 - class: line_with_english_text - name_ru: - x: 525 - y: 235 - w: 600 - h: 45 - class: line_with_russian_text - name_en: - x: 525 - y: 275 - w: 600 - h: 35 - class: line_with_english_text - birthday: - x: 525 - y: 310 - w: 600 - h: 45 - class: line_with_english_text - place_of_birth_ru: - x: 525 - y: 350 - w: 600 - h: 40 - class: line_with_russian_text - place_of_birth_en: - x: 525 - y: 390 - w: 600 - h: 35 - class: line_with_english_text - issue_date: - x: 525 - y: 430 - w: 250 - h: 45 - class: line_with_english_text - expiry_date: - x: 875 - y: 430 - w: 250 - h: 45 - class: line_with_english_text - issue_authority_ru: - x: 525 - y: 475 - w: 600 - h: 40 - class: line_with_russian_text - issue_authority_en: - x: 525 - y: 510 - w: 600 - h: 40 - class: line_with_english_text - identifier: - x: 525 - y: 550 - w: 600 - h: 50 - class: line_with_english_text - issue_place_ru: - x: 525 - y: 595 - w: 600 - h: 40 - class: line_with_russian_text - issue_place_en: - x: 525 - y: 630 - w: 600 - h: 40 - class: line_with_english_text - face: - x: 87 - y: 192 - w: 313 - h: 460 -feature_classes: - line_with_text: - abstract: yes - mutators: - interpretation: - method: ocr_tesseract - config: - config: --dpi 1000 - line_with_russian_text: - inherits: line_with_text - interpretation: - config: - lang: rus - line_with_english_text: - inherits: line_with_text - interpretation: - config: - lang: eng +{% include 'assets/templates/driver_license_ru_01/driver_license_ru.yml' %} ``` [Getting started](getting-started/index.md){ .md-button .md-button--primary} \ No newline at end of file diff --git a/docs_dynamic/deprecated.md b/docs_dynamic/deprecated.md index 7ee8bb4..85040f8 100644 --- a/docs_dynamic/deprecated.md +++ b/docs_dynamic/deprecated.md @@ -111,33 +111,4 @@ feature_classes: # (31)! 36. Optional mutator-specific configuration. 37. An interpretation method defines the way in which the mutated feature location should be processed further. For example, the `ocr_tesseract` method will apply the Tesseract OCR to the image. 38. Name of the interpretation method. -39. Intepretation-method-specific configuration values. - -# Basic usage - -A good introduction to OfficialEye would be to show how to re-create the example of the [home page](index.md). First, we need a high-quality image of an example document. For this tutorial, we shall use the following example of a German identity card: - -![Example of an identity card used in Germany](assets/img/identity_card_de.jpg "Example of an identity card used in Germany") - -Broadly speaking, we now need to explain OfficialEye, which parts of this example document contain information we are interested in, and which parts are present on any document of this kind and can be used to recognize the document on other images. OfficialEye uses a concept called *template* to conveniently capture in a single unit the example document together with this information. In this case, it makes sense to create a template called `German ID Card` and identifier `id_de`: - -```bash -officialeye create demo/templates/id_de.yml --name "German ID Card" --id id_de --force -``` - -This command creates the configuration file `demo/templates/id_de.yml` for the new template, so that we don't have to configure everything from scratch. - -
- ![Driver license photo](../../assets/img/demo/driver_license_ru_01.jpg){ width="600", loading=lazy } -
Driver's license example photo
-
- -
-!!! tip "Template" - * No further requirements - -!!! example "Example document" - * test -
- -hl_lines="5 6 7 8 9 10 11 12" \ No newline at end of file +39. Intepretation-method-specific configuration values. \ No newline at end of file diff --git a/docs_dynamic/gen_api.py b/docs_dynamic/gen_api.py index 11d1634..7db7d76 100644 --- a/docs_dynamic/gen_api.py +++ b/docs_dynamic/gen_api.py @@ -14,7 +14,7 @@ def snake_case_to_title(input_str: str, /) -> str: src = Path(__file__).parent.parent / "src" internal_module = src / "officialeye" / "_internal" -api_module = src / "officialeye" / "api" +api_module = src / "officialeye" / "_api" mod_symbol = '' diff --git a/mkdocs.yml b/mkdocs.yml index 8f61848..a4fd535 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -52,6 +52,8 @@ plugins: - social - search: separator: '[\s\u200b\-_,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])' + - macros: + render_by_default: false - minify: minify_html: true - privacy: diff --git a/pdm.lock b/pdm.lock index a8b6ee5..525fcf7 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "doc", "test"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:4fd01be3b6ebe5bfb0b815c10f5fd102585705db56533e299df2351109dde6d2" +content_hash = "sha256:7dcc5b397d044fd48a992675e75661fff2f88eac1b003f7c9e542527dc22354f" [[package]] name = "babel" @@ -326,6 +326,20 @@ files = [ {file = "Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8"}, ] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +requires_python = ">=3.8" +summary = "Python port of markdown-it. Markdown parsing, done right!" +groups = ["default"] +dependencies = [ + "mdurl~=0.1", +] +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + [[package]] name = "markupsafe" version = "2.1.3" @@ -366,6 +380,17 @@ files = [ {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, ] +[[package]] +name = "mdurl" +version = "0.1.2" +requires_python = ">=3.7" +summary = "Markdown URL utilities" +groups = ["default"] +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + [[package]] name = "mergedeep" version = "1.3.4" @@ -446,6 +471,24 @@ files = [ {file = "mkdocs_literate_nav-0.6.1.tar.gz", hash = "sha256:78a7ab6d878371728acb0cdc6235c9b0ffc6e83c997b037f4a5c6ff7cef7d759"}, ] +[[package]] +name = "mkdocs-macros-plugin" +version = "1.0.5" +requires_python = ">=3.8" +summary = "Unleash the power of MkDocs with macros and variables" +groups = ["doc"] +dependencies = [ + "jinja2", + "mkdocs>=0.17", + "python-dateutil", + "pyyaml", + "termcolor", +] +files = [ + {file = "mkdocs-macros-plugin-1.0.5.tar.gz", hash = "sha256:fe348d75f01c911f362b6d998c57b3d85b505876dde69db924f2c512c395c328"}, + {file = "mkdocs_macros_plugin-1.0.5-py3-none-any.whl", hash = "sha256:f60e26f711f5a830ddf1e7980865bf5c0f1180db56109803cdd280073c1a050a"}, +] + [[package]] name = "mkdocs-material" version = "9.5.3" @@ -744,7 +787,7 @@ name = "pygments" version = "2.17.2" requires_python = ">=3.7" summary = "Pygments is a syntax highlighting package written in Python." -groups = ["doc"] +groups = ["default", "doc"] files = [ {file = "pygments-2.17.2-py3-none-any.whl", hash = "sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c"}, {file = "pygments-2.17.2.tar.gz", hash = "sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367"}, @@ -932,6 +975,21 @@ files = [ {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, ] +[[package]] +name = "rich" +version = "13.7.0" +requires_python = ">=3.7.0" +summary = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +groups = ["default"] +dependencies = [ + "markdown-it-py>=2.2.0", + "pygments<3.0.0,>=2.13.0", +] +files = [ + {file = "rich-13.7.0-py3-none-any.whl", hash = "sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235"}, + {file = "rich-13.7.0.tar.gz", hash = "sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa"}, +] + [[package]] name = "six" version = "1.16.0" @@ -957,6 +1015,17 @@ files = [ {file = "strictyaml-1.7.3.tar.gz", hash = "sha256:22f854a5fcab42b5ddba8030a0e4be51ca89af0267961c8d6cfa86395586c407"}, ] +[[package]] +name = "termcolor" +version = "2.4.0" +requires_python = ">=3.8" +summary = "ANSI color formatting for output in terminal" +groups = ["doc"] +files = [ + {file = "termcolor-2.4.0-py3-none-any.whl", hash = "sha256:9297c0df9c99445c2412e832e882a7884038a25617c60cea2ad69488d4040d63"}, + {file = "termcolor-2.4.0.tar.gz", hash = "sha256:aab9e56047c8ac41ed798fa36d892a37aca6b3e9159f3e0c24bc64a9b3ac7b7a"}, +] + [[package]] name = "tinycss2" version = "1.2.1" diff --git a/pyproject.toml b/pyproject.toml index 3c9eb15..9c52c96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "officialeye" -version = "1.1.5" +version = "1.2.0" description = "AI-powered generic document-analysis tool" authors = [ {name = "Alexander Mayorov", email = "zb@zerobone.net"}, @@ -13,6 +13,7 @@ dependencies = [ "pyyaml", "strictyaml==1.7.3", "z3-solver", + "rich" ] requires-python = ">=3.10" readme = "README.md" @@ -27,7 +28,7 @@ classifiers = [ ] [project.scripts] -officialeye = "officialeye._internal.main:cli" +officialeye = "officialeye._cli.main:main" [project.urls] Homepage = "https://github.com/ZeroBone/OfficialEye" @@ -48,7 +49,7 @@ docs-build = "mkdocs build" docs-deploy = "mkdocs gh-deploy --force" # other scripts -count-loc = {shell = "find src/officialeye -name '*.py' | xargs wc -l"} +count-loc = {shell = "find src -name '*.py' | xargs wc -l"} # scripts to be called by the CI (i.e., GitHub actions) ci-pytest = {shell = "pytest src/tests/"} @@ -64,7 +65,8 @@ doc = [ "mkdocstrings", "mkdocstrings[python]", "mkdocs-gen-files", - "mkdocs-literate-nav" + "mkdocs-literate-nav", + "mkdocs-macros-plugin" ] test = [ "pytest" @@ -122,7 +124,8 @@ select = [ ignore = [ "D417", - "SIM108" # Usage of ternary operators instead of if-then-else + "SIM108", # Usage of ternary operators instead of if-then-else + "SIM117" # Use a single `with` statement with multiple contexts instead of nested `with` statements ] [tool.ruff.lint.pydocstyle] diff --git a/src/officialeye/__init__.py b/src/officialeye/__init__.py index a1e4176..138ae97 100644 --- a/src/officialeye/__init__.py +++ b/src/officialeye/__init__.py @@ -1,3 +1,64 @@ """ Root module. -""" \ No newline at end of file +""" + +# disable unused imports ruff check +# ruff: noqa: F401 + +# Config +# noinspection PyProtectedMember +from officialeye._api.config import Config, InterpretationConfig, MatcherConfig, MutatorConfig, SupervisorConfig + +# Context +# noinspection PyProtectedMember +from officialeye._api.context import Context + +# Misc +# noinspection PyProtectedMember +from officialeye._api.future import Future, wait + +# Image-processing +# noinspection PyProtectedMember +from officialeye._api.image import IImage, Image + +# Mutators +# noinspection PyProtectedMember +from officialeye._api.mutator import IMutator, Mutator + +# noinspection PyProtectedMember +from officialeye._api.template.feature import IFeature + +# Interpretation-related imports +# noinspection PyProtectedMember +from officialeye._api.template.interpretation import IInterpretation, Interpretation + +# noinspection PyProtectedMember +from officialeye._api.template.interpretation_result import IInterpretationResult + +# noinspection PyProtectedMember +from officialeye._api.template.keypoint import IKeypoint + +# Matching-related imports +# noinspection PyProtectedMember +from officialeye._api.template.match import IMatch, Match + +# noinspection PyProtectedMember +from officialeye._api.template.matcher import IMatcher, Matcher + +# noinspection PyProtectedMember +from officialeye._api.template.matching_result import IMatchingResult + +# Regions, features and keypoints +# noinspection PyProtectedMember +from officialeye._api.template.region import IRegion, Region + +# noinspection PyProtectedMember +from officialeye._api.template.supervision_result import ISupervisionResult, SupervisionResult + +# Supervision-related imports +# noinspection PyProtectedMember +from officialeye._api.template.supervisor import ISupervisor, Supervisor + +# Template-related +# noinspection PyProtectedMember +from officialeye._api.template.template import ITemplate, Template diff --git a/src/officialeye/__version__.py b/src/officialeye/__version__.py index ead16bc..8496f1c 100644 --- a/src/officialeye/__version__.py +++ b/src/officialeye/__version__.py @@ -1,8 +1,17 @@ __title__ = "officialeye" __description__ = "An AI-powered generic document-analysis tool." __url__ = "https://officialeye.zerobone.net" -__version__ = "1.1.5" +__version__ = "1.2.0" __author__ = "Alexander Mayorov" __author_email__ = "zb@zerobone.net" __license__ = "GPL-3.0" -__copyright__ = "Copyright Alexander Mayorov" +__copyright__ = "Copyright 2024 Alexander Mayorov (zb@zerobone.net, zerobone.net)" +__github_url__ = "github.com/ZeroBone/OfficialEye" +__github_full_url__ = "https://github.com/ZeroBone/OfficialEye" +__ascii_logo__ = """ ____ _________ _ __ ______ + / __ \\/ __/ __(_)____(_)___ _/ / / ____/_ _____ + / / / / /_/ /_/ / ___/ / __ `/ / / __/ / / / / _ \\ +/ /_/ / __/ __/ / /__/ / /_/ / / / /___/ /_/ / __/ +\\____/_/ /_/ /_/\\___/_/\\__,_/_/ /_____/\\__, /\\___/ + /____/ +""" diff --git a/src/officialeye/_api/__init__.py b/src/officialeye/_api/__init__.py new file mode 100644 index 0000000..4a2a615 --- /dev/null +++ b/src/officialeye/_api/__init__.py @@ -0,0 +1,6 @@ +""" +This module contains everything that should be visible to the API user at the top level of imports, +i.e., all symbols that should be importable via `from officialeye import symbol` +""" + +# TODO: write docstrings for all public methods of the api diff --git a/src/officialeye/_api/config.py b/src/officialeye/_api/config.py new file mode 100644 index 0000000..cf4ebd1 --- /dev/null +++ b/src/officialeye/_api/config.py @@ -0,0 +1,102 @@ +""" +Module for abstracting out the ability to inject custom configurations specified using dictionaries. +The goal of this module is to provide a nice API for validated user-specified configurations +and safely retrieving information from there. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Callable + +from officialeye.error.errors.general import ErrInvalidKey + +if TYPE_CHECKING: + from officialeye.types import ConfigDict + + +class Config(ABC): + + def __init__(self, config_dict: ConfigDict, /): + self._config_dict = config_dict + + @abstractmethod + def _get_invalid_key_error(self, key: str, /): + raise NotImplementedError() + + def get(self, key: str, /, *, value_preprocessor: Callable[[str], any] | None = None, default=None): + + if key not in self._config_dict: + + if default is None: + raise self._get_invalid_key_error(key) + + return default + + _value = self._config_dict[key] + + # apply value preprocessor + if value_preprocessor is not None: + _value = value_preprocessor(_value) + + return _value + + +class MutatorConfig(Config): + + def __init__(self, config_dict: ConfigDict, mutator_id: str, /): + + super().__init__(config_dict) + + self._mutator_id = mutator_id + + def _get_invalid_key_error(self, key: str, /): + return ErrInvalidKey( + f"while reading configuration of the '{self._mutator_id}' mutator.", + f"Could not find a value for key '{key}'." + ) + + +class MatcherConfig(Config): + + def __init__(self, config_dict: ConfigDict, matcher_id: str, /): + + super().__init__(config_dict) + + self._matcher_id = matcher_id + + def _get_invalid_key_error(self, key: str, /): + return ErrInvalidKey( + f"while reading configuration of the '{self._matcher_id}' matcher.", + f"Could not find a value for key '{key}'." + ) + + +class SupervisorConfig(Config): + + def __init__(self, config_dict: ConfigDict, matcher_id: str, /): + + super().__init__(config_dict) + + self._matcher_id = matcher_id + + def _get_invalid_key_error(self, key: str, /): + return ErrInvalidKey( + f"while reading configuration of the '{self._matcher_id}' supervisor.", + f"Could not find a value for key '{key}'." + ) + + +class InterpretationConfig(Config): + + def __init__(self, config_dict: ConfigDict, interpretation_id: str, /): + + super().__init__(config_dict) + + self._interpretation_id = interpretation_id + + def _get_invalid_key_error(self, key: str, /): + return ErrInvalidKey( + f"while reading configuration of the '{self._interpretation_id}' interpretation.", + f"Could not find a value for key '{key}'." + ) diff --git a/src/officialeye/_api/context.py b/src/officialeye/_api/context.py new file mode 100644 index 0000000..02a3c6b --- /dev/null +++ b/src/officialeye/_api/context.py @@ -0,0 +1,151 @@ +""" +Module represeting the OfficialEye context. +""" + +from __future__ import annotations + +from concurrent.futures import Future as PythonFuture +from concurrent.futures import ProcessPoolExecutor +from types import TracebackType +from typing import TYPE_CHECKING, Dict + +from officialeye._api.future import Future +from officialeye._api.mutator import IMutator + +# noinspection PyProtectedMember +from officialeye._api_builtins.init import initialize_builtins + +# noinspection PyProtectedMember +from officialeye._internal.feedback.abstract import AbstractFeedbackInterface + +# noinspection PyProtectedMember +from officialeye._internal.feedback.dummy import DummyFeedbackInterface +from officialeye.error.errors.general import ErrInvalidIdentifier +from officialeye.error.errors.internal import ErrInvalidState +from officialeye.error.errors.template import ErrTemplateInvalidMutator + +if TYPE_CHECKING: + from officialeye.types import ConfigDict, InterpretationFactory, MatcherFactory, MutatorFactory, SupervisorFactory + + +class Context: + + def __init__(self, /, *, afi: AbstractFeedbackInterface | None = None): + self._entered: bool = False + self._disposed: bool = False + + if afi is None: + self._afi = DummyFeedbackInterface() + else: + self._afi = afi + + self._executor = ProcessPoolExecutor() + + self._mutator_factories: Dict[str, MutatorFactory] = {} + self._matcher_factories: Dict[str, MatcherFactory] = {} + self._supervisor_factories: Dict[str, SupervisorFactory] = {} + self._interpretation_factories: Dict[str, InterpretationFactory] = {} + + # initialize with built-in mutators + initialize_builtins(self) + + def _get_afi(self) -> AbstractFeedbackInterface: + return self._afi + + def _submit_task(self, task, description: str, *args, **kwargs) -> Future: + + afi_fork = self._afi.fork(description) + + python_future: PythonFuture = self._executor.submit( + task, + *args, + **kwargs, + # Arguments that need to be always passed to the internal implementation when starting tasks. + # It is very important that the argument dictionary is picklable, because it will be passed from the parent + # process to a child process by the ProcessPoolExecutor. + afi=afi_fork, + mutator_factories=self._mutator_factories, + matcher_factories=self._matcher_factories, + supervisor_factories=self._supervisor_factories, + interpretation_factories=self._interpretation_factories + ) + + return Future(self, python_future, afi_fork=afi_fork) + + def register_mutator(self, mutator_id: str, factory: MutatorFactory, /) -> None: + + if mutator_id in self._mutator_factories: + raise ErrInvalidIdentifier( + f"while adding the '{mutator_id}' mutator.", + "A mutator with the same id has already been registered." + ) + + self._mutator_factories[mutator_id] = factory + + def register_matcher(self, matcher_id: str, factory: MatcherFactory, /) -> None: + + if matcher_id in self._matcher_factories: + raise ErrInvalidIdentifier( + f"while adding the '{matcher_id}' matcher.", + "A matcher with the same id has already been registered." + ) + + self._matcher_factories[matcher_id] = factory + + def register_supervisor(self, supervisor_id: str, factory: SupervisorFactory, /) -> None: + + if supervisor_id in self._matcher_factories: + raise ErrInvalidIdentifier( + f"while adding the '{supervisor_id}' matcher.", + "A supervisor with the same id has already been registered." + ) + + self._supervisor_factories[supervisor_id] = factory + + def register_interpretation(self, interpretation_id: str, factory: InterpretationFactory, /) -> None: + + if interpretation_id in self._interpretation_factories: + raise ErrInvalidIdentifier( + f"while adding the '{interpretation_id}' interpretation.", + "An interpretation with the same id has already been registered." + ) + + self._interpretation_factories[interpretation_id] = factory + + def get_mutator(self, mutator_id: str, config: ConfigDict, /) -> IMutator: + + if mutator_id not in self._mutator_factories: + raise ErrTemplateInvalidMutator( + f"while looking for a factory generating mutator '{mutator_id}'.", + "A mutator with this id has not been registered." + ) + + return self._mutator_factories[mutator_id](config) + + def __enter__(self): + + if self._entered: + raise ErrInvalidState( + "while entering the api context.", + "The context has already been entered, which is illegal state." + ) + + self._entered = True + return self + + def __exit__(self, exception_type: any, exception_value: BaseException | None, traceback: TracebackType | None): + assert self._entered + self._entered = False + + if self._disposed: + raise ErrInvalidState( + "while leaving the api context.", + "The resources have already been disposed." + ) + + self.dispose(exception_type, exception_value, traceback) + + def dispose(self, exception_type: any = None, exception_value: BaseException | None = None, traceback: TracebackType | None = None) -> None: + self._afi.dispose(exception_type, exception_value, traceback) + self._executor.shutdown(wait=True) + self._disposed = True diff --git a/src/officialeye/_api/detection.py b/src/officialeye/_api/detection.py new file mode 100644 index 0000000..b4d000a --- /dev/null +++ b/src/officialeye/_api/detection.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from concurrent.futures import ALL_COMPLETED +from typing import TYPE_CHECKING, List + +from officialeye._api.future import Future, wait +from officialeye._api.template.supervision_result import ISupervisionResult + +# noinspection PyProtectedMember +from officialeye._internal.feedback.verbosity import Verbosity + +# noinspection PyProtectedMember +from officialeye._internal.template.external_supervision_result import ExternalSupervisionResult +from officialeye.error.error import OEError +from officialeye.error.errors.internal import ErrInternal +from officialeye.error.errors.supervision import ErrSupervisionCorrespondenceNotFound + +if TYPE_CHECKING: + from officialeye._api.context import Context + from officialeye._api.image import IImage + from officialeye._api.template.template_interface import ITemplate + + +def detect(context: Context, *templates: ITemplate, target: IImage) -> ISupervisionResult: + + futures: List[Future] = [ + template.detect_async(target=target) for template in templates + ] + + done, not_done = wait(futures, return_when=ALL_COMPLETED) + + if len(not_done) > 0: + # noinspection PyProtectedMember + context._get_afi().warn(Verbosity.DEBUG, "Some template analysis futures were not completed.") + + regular_errors: List[OEError] = [] + + best_result: ISupervisionResult | None = None + best_result_score: float = -1.0 + + for completed_future in done: + + assert isinstance(completed_future, Future) + + if completed_future.cancelled(): + # noinspection PyProtectedMember + context._get_afi().warn(Verbosity.DEBUG, "A template analysis future was cancelled.") + continue + + error = completed_future.exception() + + if error is not None: + # there has been an error during the execution of the future + # the cause of the error might have been critical, in which case we should raise an exception immediately, + # or it might be, for example, due to one of the templates not matching the image at all, which is regular behavior + # therefore, we need to distinguish between a critical error and a regular one + + if not isinstance(error, OEError): + err = ErrInternal( + "while analyzing target image against multiple templates.", + "One of the individual analysis workers has crashed due to an external error." + ) + err.add_external_cause(error) + raise err + + assert isinstance(error, OEError) + + if error.is_regular: + # noinspection PyProtectedMember + context._get_afi().warn( + Verbosity.DEBUG, + f"A template analysis worker has returned a regular error {error.code} ({error.code_text})." + ) + regular_errors.append(error) + continue + + # we are dealing with a non-regular OfficialEye error + raise error + + result = completed_future.result() + assert result is not None + assert isinstance(result, ExternalSupervisionResult) + + # noinspection PyProtectedMember + context._get_afi().info(Verbosity.DEBUG, f"Template analysis worker yielded a result with score {result.score}.") + + if result.score > best_result_score: + best_result_score = result.score + best_result = result + + if best_result is None: + error = ErrSupervisionCorrespondenceNotFound( + "while processing the target image analysis results.", + "Could not establish correspondence of the image with any of the templates provided." + ) + + for worker_error in regular_errors: + error.add_cause(worker_error) + + raise error + + return best_result diff --git a/src/officialeye/_api/future.py b/src/officialeye/_api/future.py new file mode 100644 index 0000000..4a35ac8 --- /dev/null +++ b/src/officialeye/_api/future.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from concurrent.futures import ALL_COMPLETED +from concurrent.futures import Future as PythonFuture +from concurrent.futures import wait as python_wait +from typing import TYPE_CHECKING, Any, Dict, Iterable, Set, Tuple + +# noinspection PyProtectedMember +from officialeye._internal.api_implementation import IApiInterfaceImplementation + +# noinspection PyProtectedMember +from officialeye._internal.feedback.abstract import AbstractFeedbackInterface + +if TYPE_CHECKING: + from officialeye._api.context import Context + + +class Future: + + def __init__(self, context: Context, python_future: PythonFuture, /, *, afi_fork: AbstractFeedbackInterface): + self._context = context + self._future = python_future + self._afi_fork = afi_fork + + self._afi_joined = False + + def cancel(self) -> bool: + """ + Attempt to cancel the call. + If the call is currently being executed and cannot be canceled, then the method will return False, + otherwise the call will be canceled, and the method will return True. + """ + return self._future.cancel() + + def cancelled(self) -> bool: + """ Return True if the call was successfully canceled. """ + return self._future.cancelled() + + def running(self) -> bool: + """ Return True if the call is currently being executed and cannot be canceled. """ + return self._future.running() + + def done(self) -> bool: + """ Return True if the call was successfully canceled or finished running. """ + return self._future.done() + + def _afi_join(self): + if not self._afi_joined: + self._afi_joined = True + # noinspection PyProtectedMember + self._context._get_afi().join(self._afi_fork, self._future) + + def result(self, timeout: float | None = None) -> Any: + """ + Return the value returned by the call. If the call hasn’t yet completed, then this method will wait up to timeout seconds. + If the call hasn’t completed in timeout seconds, then a TimeoutError will be raised. Timeout can be an int or float. + If timeout is not specified or None, there is no limit to the wait time. + + If the future is canceled before completing, then CancelledError will be raised. + + If the call raised, this method will raise the same exception. + """ + + result = self._future.result(timeout=timeout) + + assert isinstance(result, IApiInterfaceImplementation), \ + "Every call to an internal API function should return a proper public API interface implementation" + + result.set_api_context(self._context) + + self._afi_join() + + return result + + def exception(self, timeout: float | None = None) -> Any: + """ + Return the exception raised by the call. + If the call hasn’t yet completed, then this method will wait up to timeout seconds. + If the call hasn’t completed in timeout seconds, then a TimeoutError will be raised. + Timeout can be an int or float. + If timeout is not specified or None, there is no limit to the wait time. + + If the future is canceled before completing, then CancelledError will be raised. + + If the call completed without raising, None is returned. + """ + + err = self._future.exception(timeout=timeout) + + self._afi_join() + + return err + + +def wait(futures: Iterable[Future], /, *, timeout: float | None = None, return_when=ALL_COMPLETED) -> Tuple[Set[Future], Set[Future]]: + + # noinspection PyProtectedMember + python_futures = (f._future for f in futures) + + original_futures: Dict[PythonFuture, Future] = {} + + for future in futures: + # noinspection PyProtectedMember + original_futures[future._future] = future + + done, not_done = python_wait(python_futures, timeout=timeout, return_when=return_when) + + corresponding_done = set((original_futures[d] for d in done)) + corresponding_not_done = set((original_futures[d] for d in not_done)) + + return corresponding_done, corresponding_not_done diff --git a/src/officialeye/_api/image.py b/src/officialeye/_api/image.py new file mode 100644 index 0000000..7b01cd6 --- /dev/null +++ b/src/officialeye/_api/image.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List + +import cv2 +import numpy as np + +from officialeye.error.errors.io import ErrIOInvalidPath + +if TYPE_CHECKING: + from officialeye._api.context import Context + from officialeye._api.mutator import IMutator + + +class IImage(ABC): + + def __init__(self): + super().__init__() + + @abstractmethod + def load(self) -> np.ndarray: + raise NotImplementedError() + + @abstractmethod + def apply_mutators(self, *mutators: IMutator): + raise NotImplementedError() + + +class Image(IImage): + + def __init__(self, context: Context, /, *, path: str): + super().__init__() + + self._context = context + self._mutators: List[IMutator] = [] + self._path = path + + def load(self) -> np.ndarray: + + if not os.path.isfile(self._path): + raise ErrIOInvalidPath( + f"while loading image located at '{self._path}'.", + "This path does not refer to a file." + ) + + if not os.access(self._path, os.R_OK): + raise ErrIOInvalidPath( + f"while loading image located at '{self._path}'.", + "The file at this path is not readable." + ) + + img = cv2.imread(self._path, cv2.IMREAD_COLOR) + + for mutator in self._mutators: + img = mutator.mutate(img) + + return img + + def apply_mutators(self, *mutators: IMutator): + self._mutators += mutators diff --git a/src/officialeye/_api/mutator.py b/src/officialeye/_api/mutator.py new file mode 100644 index 0000000..238e27b --- /dev/null +++ b/src/officialeye/_api/mutator.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import numpy as np + +from officialeye._api.config import MutatorConfig + +if TYPE_CHECKING: + from officialeye.types import ConfigDict + + +class IMutator(ABC): + + @property + def config(self) -> MutatorConfig: + raise NotImplementedError() + + @abstractmethod + def mutate(self, img: np.ndarray, /) -> np.ndarray: + raise NotImplementedError() + + +class Mutator(IMutator, ABC): + + def __init__(self, mutator_id: str, config_dict: ConfigDict, /): + super().__init__() + + self.mutator_id = mutator_id + + self._config = MutatorConfig(config_dict, mutator_id) + + @property + def config(self) -> MutatorConfig: + return self._config diff --git a/src/officialeye/_api/template/__init__.py b/src/officialeye/_api/template/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/officialeye/_api/template/feature.py b/src/officialeye/_api/template/feature.py new file mode 100644 index 0000000..020d709 --- /dev/null +++ b/src/officialeye/_api/template/feature.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Iterable + +import numpy as np + +from officialeye._api.template.region import IRegion + +if TYPE_CHECKING: + from officialeye._api.mutator import IMutator + + +class IFeature(IRegion, ABC): + + def __str__(self) -> str: + return f"Feature '{self.identifier}'" + + @abstractmethod + def get_mutators(self) -> Iterable[IMutator]: + """ + Returns: + A list of mutators from the feature class of the feature, in the order in which they are to be applied. + """ + raise NotImplementedError() + + def apply_mutators_to_image(self, img: np.ndarray, /) -> np.ndarray: + """ + Takes an image and applies the mutators defined in the corresponding feature class. + + Arguments: + img: The image that should be transformed. + + Returns: + The resulting image. + """ + + for mutator in self.get_mutators(): + img = mutator.mutate(img) + + return img diff --git a/src/officialeye/_api/template/interpretation.py b/src/officialeye/_api/template/interpretation.py new file mode 100644 index 0000000..9b762e9 --- /dev/null +++ b/src/officialeye/_api/template/interpretation.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import numpy as np + +from officialeye._api.config import InterpretationConfig +from officialeye._api.template.feature import IFeature + +if TYPE_CHECKING: + from officialeye.types import ConfigDict, FeatureInterpretation + + +class IInterpretation(ABC): + + @property + def config(self) -> InterpretationConfig: + raise NotImplementedError() + + @abstractmethod + def interpret(self, feature_img: np.ndarray, feature: IFeature, /) -> FeatureInterpretation: + raise NotImplementedError() + + +class Interpretation(IInterpretation, ABC): + + def __init__(self, interpretation_id: str, config_dict: ConfigDict, /): + super().__init__() + + self.interpretation_id = interpretation_id + + self._config = InterpretationConfig(config_dict, interpretation_id) + + @property + def config(self) -> InterpretationConfig: + return self._config diff --git a/src/officialeye/_api/template/interpretation_result.py b/src/officialeye/_api/template/interpretation_result.py new file mode 100644 index 0000000..84e4bb5 --- /dev/null +++ b/src/officialeye/_api/template/interpretation_result.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from officialeye._api.template.feature import IFeature + +if TYPE_CHECKING: + from officialeye._api.template.template_interface import ITemplate + from officialeye.types import FeatureInterpretation + + +class IInterpretationResult(ABC): + + @property + @abstractmethod + def template(self) -> ITemplate: + raise NotImplementedError() + + @abstractmethod + def get_feature_interpretation(self, feature: IFeature, /) -> FeatureInterpretation: + raise NotImplementedError() diff --git a/src/officialeye/_api/template/keypoint.py b/src/officialeye/_api/template/keypoint.py new file mode 100644 index 0000000..4e0cb80 --- /dev/null +++ b/src/officialeye/_api/template/keypoint.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + +from officialeye._api.template.region import IRegion + + +class IKeypoint(IRegion, ABC): + + @property + @abstractmethod + def matches_min(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def matches_max(self) -> int: + raise NotImplementedError() + + def __str__(self) -> str: + return f"Keypoint '{self.identifier}'" diff --git a/src/officialeye/_api/template/match.py b/src/officialeye/_api/template/match.py new file mode 100644 index 0000000..99d899b --- /dev/null +++ b/src/officialeye/_api/template/match.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import numpy as np + +from officialeye._api.template.keypoint import IKeypoint + +if TYPE_CHECKING: + from officialeye._api.template.template import ITemplate + + +class IMatch(ABC): + + @property + @abstractmethod + def template(self) -> ITemplate: + raise NotImplementedError() + + @property + @abstractmethod + def keypoint(self) -> IKeypoint: + raise NotImplementedError() + + @property + @abstractmethod + def template_point(self) -> np.ndarray: + raise NotImplementedError() + + @property + @abstractmethod + def target_point(self) -> np.ndarray: + raise NotImplementedError() + + @abstractmethod + def get_score(self) -> float: + raise NotImplementedError() + + def get_original_template_point(self) -> np.ndarray: + """Returns the coordinates of the point lying in the keypoint, in the coordinate system of the underlying template.""" + return self.template_point + self.keypoint.top_left + + def __str__(self) -> str: + return (f"Match: Point ({self.target_point[0]}, {self.target_point[1]}) matches ({self.template_point[0]}, {self.template_point[1]}) " + f"in {self.keypoint} of {self.template}.") + + def __eq__(self, o: Any) -> bool: + + if not isinstance(o, IMatch): + return False + + if self.template != o.template: + return False + + if self.keypoint != o.keypoint: + return False + + return (np.array_equal(self.template_point, o.template_point) + and np.array_equal(self.target_point, o.target_point)) + + def __lt__(self, o: Any) -> bool: + assert isinstance(o, Match) + return self.get_score() < o.get_score() + + def __hash__(self): + return hash(( + self.template.identifier, + self.keypoint.identifier, + np.dot(self.template_point, self.template_point), + np.dot(self.target_point, self.target_point) + )) + + +class Match(IMatch): + + def __init__(self, template: ITemplate, keypoint: IKeypoint, /, *, + region_point: np.ndarray, target_point: np.ndarray, score: float = 0.0): + super().__init__() + + self._template = template + self._keypoint = keypoint + + self._region_point = region_point + self._target_point = target_point + + self._score = score + + def get_score(self) -> float: + return self._score + + def set_score(self, new_score: float, /): + self._score = new_score + + @property + def template(self) -> ITemplate: + return self._template + + @property + def keypoint(self) -> IKeypoint: + return self._keypoint + + @property + def template_point(self) -> np.ndarray: + return self._region_point.copy() + + @property + def target_point(self) -> np.ndarray: + return self._target_point.copy() diff --git a/src/officialeye/_api/template/matcher.py b/src/officialeye/_api/template/matcher.py new file mode 100644 index 0000000..f08c92b --- /dev/null +++ b/src/officialeye/_api/template/matcher.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Iterable + +from officialeye._api.config import MatcherConfig +from officialeye._api.template.keypoint import IKeypoint +from officialeye._api.template.match import IMatch + +if TYPE_CHECKING: + from officialeye._api.template.template import ITemplate + from officialeye.types import ConfigDict + + +class IMatcher(ABC): + + @property + @abstractmethod + def config(self) -> MatcherConfig: + raise NotImplementedError() + + @abstractmethod + def setup(self, template: ITemplate, /) -> None: + raise NotImplementedError() + + @abstractmethod + def match(self, keypoint: IKeypoint, /) -> None: + raise NotImplementedError() + + @abstractmethod + def get_matches_for_keypoint(self, keypoint: IKeypoint, /) -> Iterable[IMatch]: + raise NotImplementedError() + + +class Matcher(IMatcher, ABC): + + def __init__(self, matcher_id: str, config_dict: ConfigDict, /): + super().__init__() + + self.matcher_id = matcher_id + + self._config = MatcherConfig(config_dict, matcher_id) + + @property + def config(self) -> MatcherConfig: + return self._config diff --git a/src/officialeye/_api/template/matching_result.py b/src/officialeye/_api/template/matching_result.py new file mode 100644 index 0000000..1570bee --- /dev/null +++ b/src/officialeye/_api/template/matching_result.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Iterable + +from officialeye._api.template.match import IMatch + +if TYPE_CHECKING: + from officialeye._api.template.template_interface import ITemplate + + +class IMatchingResult(ABC): + + @property + @abstractmethod + def template(self) -> ITemplate: + raise NotImplementedError() + + @abstractmethod + def get_all_matches(self) -> Iterable[IMatch]: + raise NotImplementedError() + + @abstractmethod + def get_total_match_count(self) -> int: + raise NotImplementedError() + + @abstractmethod + def get_matches_for_keypoint(self, keypoint_id: str, /) -> Iterable[IMatch]: + raise NotImplementedError() diff --git a/src/officialeye/_api/template/region.py b/src/officialeye/_api/template/region.py new file mode 100644 index 0000000..04adca3 --- /dev/null +++ b/src/officialeye/_api/template/region.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api_builtins.mutator.crop import CropMutator + +if TYPE_CHECKING: + from officialeye._api.image import IImage + from officialeye._api.template.template import ITemplate + + +class IRegion(ABC): + + @property + @abstractmethod + def identifier(self) -> str: + raise NotImplementedError() + + @property + @abstractmethod + def x(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def y(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def w(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def h(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def template(self) -> ITemplate: + raise NotImplementedError() + + @property + def top_left(self) -> np.ndarray: + return np.array([self.x, self.y]) + + @property + def top_right(self) -> np.ndarray: + return np.array([self.x + self.w, self.y]) + + @property + def bottom_left(self) -> np.ndarray: + return np.array([self.x, self.y + self.h]) + + @property + def bottom_right(self) -> np.ndarray: + return np.array([self.x + self.w, self.y + self.h]) + + def get_image(self) -> IImage: + _mutator = CropMutator(dict(x=self.x, y=self.y, w=self.w, h=self.h)) + _img: IImage = self.template.get_mutated_image() + _img.apply_mutators(_mutator) + return _img + + def insert_into_image(self, target: np.ndarray, transformed_version: np.ndarray = None): + + assert target.shape[0] == self.template.height + assert target.shape[1] == self.template.width + + if transformed_version is None: + transformed_version = self.get_image().load() + + target[self.y: self.y + self.h, self.x: self.x + self.w] = transformed_version + + def __str__(self) -> str: + return f"Region '{self.identifier}'" + + def __eq__(self, o: Any) -> bool: + + if not isinstance(o, Region): + return False + + return self.template.identifier == o.template.identifier and self.identifier == o.identifier + + def __hash__(self): + return hash((self.template, self.identifier)) + + +class Region(IRegion): + + def __init__(self, template: ITemplate, /, *, identifier: str, x: int, y: int, w: int, h: int): + self._template = template + self._identifier = identifier + self._x = x + self._y = y + self._w = w + self._h = h + + @property + def identifier(self) -> str: + return self._identifier + + @property + def x(self) -> int: + return self._x + + @property + def y(self) -> int: + return self._y + + @property + def w(self) -> int: + return self._w + + @property + def h(self) -> int: + return self._h + + @property + def template(self) -> ITemplate: + return self._template diff --git a/src/officialeye/_api/template/supervision_result.py b/src/officialeye/_api/template/supervision_result.py new file mode 100644 index 0000000..58878a0 --- /dev/null +++ b/src/officialeye/_api/template/supervision_result.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import sys +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Dict + +import cv2 +import numpy as np + +from officialeye._api.template.feature import IFeature +from officialeye._api.template.interpretation_result import IInterpretationResult +from officialeye._api.template.match import IMatch +from officialeye.error.errors.general import ErrObjectNotInitialized + +if TYPE_CHECKING: + from officialeye._api.future import Future + from officialeye._api.image import IImage + from officialeye._api.template.matching_result import IMatchingResult + from officialeye._api.template.template_interface import ITemplate + + +class ISupervisionResult(ABC): + + @property + @abstractmethod + def template(self) -> ITemplate: + raise NotImplementedError() + + @property + @abstractmethod + def matching_result(self) -> IMatchingResult: + raise NotImplementedError() + + @property + @abstractmethod + def score(self) -> float: + raise NotImplementedError() + + @property + @abstractmethod + def delta(self) -> np.ndarray: + raise NotImplementedError() + + @property + @abstractmethod + def delta_prime(self) -> np.ndarray: + raise NotImplementedError() + + @property + @abstractmethod + def transformation_matrix(self) -> np.ndarray: + raise NotImplementedError() + + def translate(self, template_point: np.ndarray, /) -> np.ndarray: + """ + Translates the given template point into a target point. That is, given a position in the template's coordinate system, this function + outputs the corresponding position in the target image's coordinate system, according to the affine transformation model. + """ + assert template_point.shape == (2,) + return self.transformation_matrix @ (template_point - self.delta) + self.delta_prime + + @abstractmethod + def get_match_weight(self, match: IMatch, /) -> float: + raise NotImplementedError() + + @abstractmethod + def interpret_async(self, /, *, target: IImage) -> Future: + raise NotImplementedError() + + @abstractmethod + def interpret(self, /, **kwargs) -> IInterpretationResult: + raise NotImplementedError() + + def get_weighted_mse(self, /) -> float: + + error = 0.0 + singificant_match_count = 0 + + for match in self.matching_result.get_all_matches(): + + match_weight = self.get_match_weight(match) + + if match_weight < sys.float_info.epsilon: + continue + + singificant_match_count += 1 + + s = match.get_original_template_point() + + # calculate prediction + p = self.translate(s) + + # calculate destination + d = match.target_point + + current_error = p - d + current_error_value = np.dot(current_error, current_error) + + error += current_error_value * match_weight + + return error / singificant_match_count + + def warp_feature(self, feature: IFeature, target: np.ndarray, /) -> np.ndarray: + + target_tl = self.translate(feature.top_left) + target_tr = self.translate(feature.top_right) + target_bl = self.translate(feature.bottom_left) + target_br = self.translate(feature.bottom_right) + + dest_tl = np.array([0, 0], dtype=np.float64) + dest_tr = np.array([feature.w, 0], dtype=np.float64) + dest_br = np.array([feature.w, feature.h], dtype=np.float64) + dest_bl = np.array([0, feature.h], dtype=np.float64) + + source_points = [target_tl, target_tr, target_br, target_bl] + destination_points = [dest_tl, dest_tr, dest_br, dest_bl] + + homography = cv2.getPerspectiveTransform(np.float32(source_points), np.float32(destination_points)) + + return cv2.warpPerspective( + target, + np.float32(homography), + (feature.w, feature.h), + flags=cv2.INTER_LINEAR + ) + + +class SupervisionResult: + + def __init__(self, /, **kwargs): + # offset in the template's coordinates + self._delta: np.ndarray | None = None + # offset in the target image's coordinates + self._delta_prime: np.ndarray | None = None + + self._transformation_matrix: np.ndarray | None = None + + # keys: matches + # values: weights assigned by the supervision engine to each match (assigning is optional) + # the higher the weight, the more we trust the correctness of the match and the greater its individual impact should be. + # by default, the weight is 1. + self._match_weights: Dict[IMatch, float] = {} + + # an optional value the supervision engine can set, representing how confident the engine is in the result + self._score = 0.0 + + self.set(**kwargs) + + def set(self, /, *, delta: np.ndarray | None = None, delta_prime: np.ndarray | None = None, + transformation_matrix: np.ndarray | None = None, score: float | None = None): + + if delta is not None: + assert delta.shape == (2,) + self._delta = delta + + if delta_prime is not None: + assert delta_prime.shape == (2,) + self._delta_prime = delta_prime + + if transformation_matrix is not None: + assert transformation_matrix.shape == (2, 2) + self._transformation_matrix = transformation_matrix + + if score is not None: + self._score = score + + def set_match_weight(self, match: IMatch, weight: float, /): + assert weight >= 0 + self._match_weights[match] = weight + + def get_score(self) -> float: + assert self._score >= 0.0 + return self._score + + @property + def delta(self) -> np.ndarray: + + if self._delta is None: + raise ErrObjectNotInitialized( + "while trying to access the 'delta' parameter of the supervision result instance.", + "This parameter has not been set." + ) + + return self._delta.copy() + + @property + def delta_prime(self) -> np.ndarray: + + if self._delta_prime is None: + raise ErrObjectNotInitialized( + "while trying to access the 'delta_prime' parameter of the supervision result instance.", + "This parameter has not been set." + ) + + return self._delta_prime.copy() + + @property + def transformation_matrix(self) -> np.ndarray: + + if self._transformation_matrix is None: + raise ErrObjectNotInitialized( + "while trying to access the 'transformation_matrix' parameter of the supervision result instance.", + "This parameter has not been set." + ) + + return self._transformation_matrix.copy() diff --git a/src/officialeye/_api/template/supervisor.py b/src/officialeye/_api/template/supervisor.py new file mode 100644 index 0000000..d4ebe9a --- /dev/null +++ b/src/officialeye/_api/template/supervisor.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Iterable + +from officialeye._api.config import SupervisorConfig +from officialeye._api.template.matching_result import IMatchingResult +from officialeye._api.template.supervision_result import SupervisionResult +from officialeye._api.template.template_interface import ITemplate + +if TYPE_CHECKING: + from officialeye.types import ConfigDict + + +class ISupervisor(ABC): + + @property + @abstractmethod + def config(self) -> SupervisorConfig: + raise NotImplementedError() + + @abstractmethod + def setup(self, template: ITemplate, matching_result: IMatchingResult, /) -> None: + raise NotImplementedError() + + @abstractmethod + def supervise(self, template: ITemplate, matching_result: IMatchingResult, /) -> Iterable[SupervisionResult]: + raise NotImplementedError() + + +class Supervisor(ISupervisor, ABC): + + def __init__(self, supervisor_id: str, config_dict: ConfigDict, /): + super().__init__() + + self._supervisor_id = supervisor_id + self._config = SupervisorConfig(config_dict, supervisor_id) + + @property + def config(self) -> SupervisorConfig: + return self._config diff --git a/src/officialeye/_api/template/template.py b/src/officialeye/_api/template/template.py new file mode 100644 index 0000000..9aecdcf --- /dev/null +++ b/src/officialeye/_api/template/template.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from concurrent.futures import Future +from typing import TYPE_CHECKING, Iterable + +from officialeye._api.image import IImage +from officialeye._api.template.template_interface import ITemplate + +# noinspection PyProtectedMember +from officialeye._internal.api.load import template_load + +# noinspection PyProtectedMember +from officialeye._internal.template.external_template import ExternalTemplate + +if TYPE_CHECKING: + from officialeye._api.context import Context + from officialeye._api.template.feature import IFeature + from officialeye._api.template.keypoint import IKeypoint + from officialeye._api.template.supervision_result import ISupervisionResult + + +class Template(ITemplate): + + def __init__(self, context: Context, /, *, path: str): + super().__init__() + + self._context = context + self._path = path + + # None indicates that the template has not yet been loaded + self._external_template: ExternalTemplate | None = None + + def load(self) -> None: + """ + Loads the template into memory for further processing. + + If you prefer lazy-evaluation, do not call this method. + Instead, run the desired operations with the template, and the necessary resources will be loaded on-the-fly. + Use this method only if you really want to preload the template. + """ + + if self._external_template is not None: + # the template has already been loaded, nothing to do + return + + # noinspection PyProtectedMember + future = self._context._submit_task(template_load, "Loading template...", self._path) + + self._external_template = future.result() + + assert self._external_template is not None + assert isinstance(self._external_template, ExternalTemplate) + + def detect_async(self, /, *, target: IImage) -> Future: + self.load() + return self._external_template.detect_async(target=target) + + def detect(self, /, **kwargs) -> ISupervisionResult: + self.load() + return self._external_template.detect(**kwargs) + + def get_image(self) -> IImage: + self.load() + return self._external_template.get_image() + + def get_mutated_image(self) -> IImage: + self.load() + return self._external_template.get_mutated_image() + + @property + def identifier(self) -> str: + self.load() + return self._external_template.identifier + + @property + def name(self) -> str: + self.load() + return self._external_template.name + + @property + def width(self) -> int: + self.load() + return self._external_template.width + + @property + def height(self) -> int: + self.load() + return self._external_template.height + + @property + def keypoints(self) -> Iterable[IKeypoint]: + self.load() + return self._external_template.keypoints + + @property + def features(self) -> Iterable[IFeature]: + self.load() + return self._external_template.features + + def get_feature(self, feature_id: str, /) -> IFeature | None: + self.load() + return self._external_template.get_feature(feature_id) + + def get_keypoint(self, keypoint_id: str, /) -> IKeypoint | None: + self.load() + return self._external_template.get_keypoint(keypoint_id) diff --git a/src/officialeye/_api/template/template_interface.py b/src/officialeye/_api/template/template_interface.py new file mode 100644 index 0000000..320c0ea --- /dev/null +++ b/src/officialeye/_api/template/template_interface.py @@ -0,0 +1,93 @@ +from abc import ABC, abstractmethod +from typing import Any, Iterable + +from officialeye._api.future import Future +from officialeye._api.image import IImage +from officialeye._api.template.feature import IFeature +from officialeye._api.template.keypoint import IKeypoint +from officialeye._api.template.supervision_result import ISupervisionResult + + +class ITemplate(ABC): + + def __init__(self): + super().__init__() + + @abstractmethod + def load(self) -> None: + """ + Loads the template into memory for further processing. + + If you prefer lazy-evaluation, do not call this method. + Instead, run the desired operations with the template, and the necessary resources will be loaded on-the-fly. + Use this method only if you really want to preload the template. + """ + + raise NotImplementedError() + + @abstractmethod + def detect_async(self, /, *, target: IImage) -> Future: + raise NotImplementedError() + + @abstractmethod + def detect(self, /, **kwargs) -> ISupervisionResult: + raise NotImplementedError() + + @abstractmethod + def get_image(self) -> IImage: + raise NotImplementedError() + + @abstractmethod + def get_mutated_image(self) -> IImage: + raise NotImplementedError() + + @property + @abstractmethod + def identifier(self) -> str: + raise NotImplementedError() + + @property + @abstractmethod + def name(self) -> str: + raise NotImplementedError() + + @property + @abstractmethod + def width(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def height(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def keypoints(self) -> Iterable[IKeypoint]: + raise NotImplementedError() + + @property + @abstractmethod + def features(self) -> Iterable[IFeature]: + raise NotImplementedError() + + @abstractmethod + def get_feature(self, feature_id: str, /) -> IFeature | None: + raise NotImplementedError() + + @abstractmethod + def get_keypoint(self, keypoint_id: str, /) -> IKeypoint | None: + raise NotImplementedError() + + def __str__(self) -> str: + return f"Template '{self.identifier}'." + + def __eq__(self, o: Any) -> bool: + + if not isinstance(o, ITemplate): + return False + + return self.identifier == o.identifier + + def __hash__(self): + return hash(self.identifier) diff --git a/src/officialeye/_api_builtins/__init__.py b/src/officialeye/_api_builtins/__init__.py new file mode 100644 index 0000000..232dd0b --- /dev/null +++ b/src/officialeye/_api_builtins/__init__.py @@ -0,0 +1,3 @@ +""" +All algorithms built into the API by default. +""" \ No newline at end of file diff --git a/src/officialeye/_api_builtins/init.py b/src/officialeye/_api_builtins/init.py new file mode 100644 index 0000000..6b6b9f7 --- /dev/null +++ b/src/officialeye/_api_builtins/init.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +# noinspection PyProtectedMember +from officialeye._api.mutator import IMutator + +# noinspection PyProtectedMember +from officialeye._api.template.matcher import IMatcher +from officialeye._api_builtins.interpretation.file import FileInterpretation +from officialeye._api_builtins.interpretation.file_temp import FileTempInterpretation +from officialeye._api_builtins.interpretation.ocr_tesseract import TesseractInterpretation +from officialeye._api_builtins.matcher.sift_flann import SiftFlannMatcher +from officialeye._api_builtins.mutator.clahe import CLAHEMutator +from officialeye._api_builtins.mutator.grayscale import GrayscaleMutator +from officialeye._api_builtins.mutator.non_local_means_denoising import NonLocalMeansDenoisingMutator +from officialeye._api_builtins.mutator.rotate import RotateMutator +from officialeye._api_builtins.supervisor.combinatorial import CombinatorialSupervisor +from officialeye._api_builtins.supervisor.least_squares_regression import LeastSquaresRegressionSupervisor + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.context import Context + + # noinspection PyProtectedMember + from officialeye._api.template.interpretation import IInterpretation + + # noinspection PyProtectedMember + from officialeye._api.template.supervisor import ISupervisor + from officialeye.types import ConfigDict + + +""" +Mutator generators +""" + + +def _gen_mutator_grayscale(config: ConfigDict, /) -> IMutator: + return GrayscaleMutator(config) + + +def _gen_mutator_non_local_means_denoising(config: ConfigDict, /) -> IMutator: + return NonLocalMeansDenoisingMutator(config) + + +def _gen_mutator_clahe(config: ConfigDict, /) -> IMutator: + return CLAHEMutator(config) + + +def _gen_mutator_rotate(config: ConfigDict, /) -> IMutator: + return RotateMutator(config) + + +""" +Matcher generators +""" + + +def _gen_matcher_sift_flann(config: ConfigDict, /) -> IMatcher: + return SiftFlannMatcher(config) + + +""" +Supervisor generators +""" + + +def _gen_supervisor_combinatorial(config: ConfigDict, /) -> ISupervisor: + return CombinatorialSupervisor(config) + + +def _gen_supervisor_least_squares_regression(config: ConfigDict, /) -> ISupervisor: + return LeastSquaresRegressionSupervisor(config) + + +""" +Interpretation generators +""" + + +def _gen_interpretation_file(config: ConfigDict, /) -> IInterpretation: + return FileInterpretation(config) + + +def _gen_interpretation_file_temp(config: ConfigDict, /) -> IInterpretation: + return FileTempInterpretation(config) + + +def _gen_interpretation_ocr_tesseract(config: ConfigDict, /) -> IInterpretation: + return TesseractInterpretation(config) + + +def initialize_builtins(context: Context, /): + + # register mutators + context.register_mutator(GrayscaleMutator.MUTATOR_ID, _gen_mutator_grayscale) + context.register_mutator(NonLocalMeansDenoisingMutator.MUTATOR_ID, _gen_mutator_non_local_means_denoising) + context.register_mutator(CLAHEMutator.MUTATOR_ID, _gen_mutator_clahe) + context.register_mutator(RotateMutator.MUTATOR_ID, _gen_mutator_rotate) + + # register matchers + context.register_matcher(SiftFlannMatcher.MATCHER_ID, _gen_matcher_sift_flann) + + # register supervisors + context.register_supervisor(CombinatorialSupervisor.SUPERVISOR_ID, _gen_supervisor_combinatorial) + context.register_supervisor(LeastSquaresRegressionSupervisor.SUPERVISOR_ID, _gen_supervisor_combinatorial) + + # register interpretations + context.register_interpretation(FileInterpretation.INTERPRETATION_ID, _gen_interpretation_file) + context.register_interpretation(FileTempInterpretation.INTERPRETATION_ID, _gen_interpretation_file_temp) + context.register_interpretation(TesseractInterpretation.INTERPRETATION_ID, _gen_interpretation_ocr_tesseract) diff --git a/src/officialeye/_internal/interpretation/methods/__init__.py b/src/officialeye/_api_builtins/interpretation/__init__.py similarity index 100% rename from src/officialeye/_internal/interpretation/methods/__init__.py rename to src/officialeye/_api_builtins/interpretation/__init__.py diff --git a/src/officialeye/_api_builtins/interpretation/file.py b/src/officialeye/_api_builtins/interpretation/file.py new file mode 100644 index 0000000..98ad2af --- /dev/null +++ b/src/officialeye/_api_builtins/interpretation/file.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import cv2 +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.template.interpretation import Interpretation + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.template.feature import IFeature + from officialeye.types import ConfigDict, FeatureInterpretation + + +class FileInterpretation(Interpretation): + + INTERPRETATION_ID = "file" + + def __init__(self, config_dict: ConfigDict, /): + super().__init__(FileInterpretation.INTERPRETATION_ID, config_dict) + + self._path = self.config.get("path", value_preprocessor=str) + + def interpret(self, feature_img: np.ndarray, feature: IFeature, /) -> FeatureInterpretation: + + os.makedirs(os.path.dirname(self._path), exist_ok=True) + + cv2.imwrite(self._path, feature_img) + + return None diff --git a/src/officialeye/_api_builtins/interpretation/file_temp.py b/src/officialeye/_api_builtins/interpretation/file_temp.py new file mode 100644 index 0000000..9eb858a --- /dev/null +++ b/src/officialeye/_api_builtins/interpretation/file_temp.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import tempfile +from typing import TYPE_CHECKING + +import cv2 +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.template.interpretation import Interpretation + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.template.feature import IFeature + from officialeye.types import ConfigDict, FeatureInterpretation + + +class FileTempInterpretation(Interpretation): + + INTERPRETATION_ID = "file_temp" + + def __init__(self, config_dict: ConfigDict, /): + super().__init__(FileTempInterpretation.INTERPRETATION_ID, config_dict) + + self._format = self.config.get("format", default="png", value_preprocessor=str) + + def interpret(self, feature_img: np.ndarray, feature: IFeature, /) -> FeatureInterpretation: + + with tempfile.NamedTemporaryFile(prefix="officialeye_", suffix=f".{self._format}", delete=False) as fp: + fp.close() + + cv2.imwrite(fp.name, feature_img) + + return fp.name diff --git a/src/officialeye/_api_builtins/interpretation/ocr_tesseract.py b/src/officialeye/_api_builtins/interpretation/ocr_tesseract.py new file mode 100644 index 0000000..10786c7 --- /dev/null +++ b/src/officialeye/_api_builtins/interpretation/ocr_tesseract.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from pytesseract import pytesseract + +# noinspection PyProtectedMember +from officialeye._api.template.interpretation import Interpretation + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.template.feature import IFeature + from officialeye.types import ConfigDict, FeatureInterpretation + + +class TesseractInterpretation(Interpretation): + + INTERPRETATION_ID = "ocr_tesseract" + + def __init__(self, config_dict: ConfigDict, /): + super().__init__(TesseractInterpretation.INTERPRETATION_ID, config_dict) + + self._tesseract_lang = self.config.get("lang", default="eng", value_preprocessor=str) + self._tesseract_config = self.config.get("config", default="", value_preprocessor=str) + + def interpret(self, feature_img: np.ndarray, feature: IFeature, /) -> FeatureInterpretation: + return pytesseract.image_to_string(feature_img, lang=self._tesseract_lang, config=self._tesseract_config).strip() diff --git a/src/officialeye/_internal/matching/matchers/__init__.py b/src/officialeye/_api_builtins/matcher/__init__.py similarity index 100% rename from src/officialeye/_internal/matching/matchers/__init__.py rename to src/officialeye/_api_builtins/matcher/__init__.py diff --git a/src/officialeye/_api_builtins/matcher/sift_flann.py b/src/officialeye/_api_builtins/matcher/sift_flann.py new file mode 100644 index 0000000..386e8a8 --- /dev/null +++ b/src/officialeye/_api_builtins/matcher/sift_flann.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Iterable, List + +import cv2 +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.template.keypoint import IKeypoint + +# noinspection PyProtectedMember +from officialeye._api.template.match import IMatch, Match + +# noinspection PyProtectedMember +from officialeye._api.template.matcher import Matcher +from officialeye.error.errors.matching import ErrMatchingInvalidEngineConfig + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.template.template import ITemplate + from officialeye.types import ConfigDict + + +_FLANN_INDEX_KDTREE = 1 + + +def _preprocess_sensitivity(value: str, /) -> float: + value = float(value) + + if value < 0.0: + raise ErrMatchingInvalidEngineConfig( + f"while loading the '{SiftFlannMatcher.MATCHER_ID}' keypoint matcher", + f"The `sensitivity` value ({value}) cannot be negative." + ) + + if value > 1.0: + raise ErrMatchingInvalidEngineConfig( + f"while loading the '{SiftFlannMatcher.MATCHER_ID}' keypoint matcher", + f"The `sensitivity` value ({value}) cannot exceed 1.0." + ) + + return value + + +class SiftFlannMatcher(Matcher): + + MATCHER_ID = "sift_flann" + + def __init__(self, config_dict: ConfigDict, /): + super().__init__(SiftFlannMatcher.MATCHER_ID, config_dict) + + self._sensitivity = self.config.get("sensitivity", default=0.7, value_preprocessor=_preprocess_sensitivity) + + self._img: np.ndarray | None = None + self._sift = None + + self._keypoints_target = None + self._destination_target = None + self._template: ITemplate | None = None + self._matches: Dict[IKeypoint, List[Match]] | None = {} + + def setup(self, template: ITemplate, /) -> None: + + assert template is not None + + self._img = template.get_mutated_image().load() + + # initialize the SIFT engine in CV2 + # noinspection PyUnresolvedReferences + self._sift = cv2.SIFT_create() + + # pre-compute the sift keypoints in the target image + self._keypoints_target, self._destination_target = self._sift.detectAndCompute(self._img, None) + + self._template = template + + self._matches = {} + + def match(self, keypoint: IKeypoint, /) -> None: + + _original_pattern_image = keypoint.get_image().load() + + pattern = cv2.cvtColor(_original_pattern_image, cv2.COLOR_BGR2GRAY) + + keypoints_pattern, destination_pattern = self._sift.detectAndCompute(pattern, None) + + index_params = { + "algorithm": _FLANN_INDEX_KDTREE, + "trees": 5 + } + + search_params = { + "checks": 50 + } + + flann = cv2.FlannBasedMatcher(index_params, search_params) + matches = flann.knnMatch(destination_pattern, self._destination_target, k=2) + + # we need to draw only good matches, so create a mask + matches_mask = [[0, 0] for _ in range(len(matches))] + + result: List[Match] = [] + + # filter matches + for i, (m, n) in enumerate(matches): + + if m.distance >= self._sensitivity * n.distance: + continue + + matches_mask[i] = [1, 0] + + pattern_point = keypoints_pattern[m.queryIdx].pt + target_point = self._keypoints_target[m.trainIdx].pt + + # maybe one should consider rounding values here, instead of simply stripping the floating-point part + pattern_point = np.array(pattern_point, dtype=int) + target_point = np.array(target_point, dtype=int) + + match = Match(self._template, keypoint, region_point=pattern_point, target_point=target_point) + match.set_score(self._sensitivity * n.distance - m.distance) + + result.append(match) + + assert keypoint not in self._matches + self._matches[keypoint] = result + + def get_matches_for_keypoint(self, keypoint: IKeypoint, /) -> Iterable[IMatch]: + assert keypoint in self._matches + return self._matches[keypoint] diff --git a/src/officialeye/_internal/mutation/mutators/__init__.py b/src/officialeye/_api_builtins/mutator/__init__.py similarity index 100% rename from src/officialeye/_internal/mutation/mutators/__init__.py rename to src/officialeye/_api_builtins/mutator/__init__.py diff --git a/src/officialeye/_internal/mutation/mutators/clahe.py b/src/officialeye/_api_builtins/mutator/clahe.py similarity index 68% rename from src/officialeye/_internal/mutation/mutators/clahe.py rename to src/officialeye/_api_builtins/mutator/clahe.py index 561395c..4d7b621 100644 --- a/src/officialeye/_internal/mutation/mutators/clahe.py +++ b/src/officialeye/_api_builtins/mutator/clahe.py @@ -1,8 +1,15 @@ -from typing import Dict +from __future__ import annotations + +from typing import TYPE_CHECKING import cv2 +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.mutator import Mutator -from officialeye._internal.mutation.mutator import Mutator +if TYPE_CHECKING: + from officialeye.types import ConfigDict class CLAHEMutator(Mutator): @@ -12,10 +19,10 @@ class CLAHEMutator(Mutator): MUTATOR_ID = "clahe" - def __init__(self, config: Dict[str, any], /): + def __init__(self, config: ConfigDict, /): super().__init__(CLAHEMutator.MUTATOR_ID, config) - def mutate(self, img: cv2.Mat, /) -> cv2.Mat: + def mutate(self, img: np.ndarray, /) -> np.ndarray: lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) l_channel, a, b = cv2.split(lab) diff --git a/src/officialeye/_api_builtins/mutator/crop.py b/src/officialeye/_api_builtins/mutator/crop.py new file mode 100644 index 0000000..d244f9e --- /dev/null +++ b/src/officialeye/_api_builtins/mutator/crop.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.mutator import Mutator + +if TYPE_CHECKING: + from officialeye.types import ConfigDict + + +class CropMutator(Mutator): + + MUTATOR_ID = "crop" + + def __init__(self, config_dict: ConfigDict, /): + super().__init__(CropMutator.MUTATOR_ID, config_dict) + + self._x = self.config.get("x", default=0, value_preprocessor=int) + self._y = self.config.get("y", default=0, value_preprocessor=int) + self._w = self.config.get("w", value_preprocessor=int) + self._h = self.config.get("h", value_preprocessor=int) + + def mutate(self, img: np.ndarray, /) -> np.ndarray: + return img[self._y:self._y + self._h, self._x:self._x + self._w] diff --git a/src/officialeye/_api_builtins/mutator/grayscale.py b/src/officialeye/_api_builtins/mutator/grayscale.py new file mode 100644 index 0000000..b125778 --- /dev/null +++ b/src/officialeye/_api_builtins/mutator/grayscale.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import cv2 +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.mutator import Mutator + +if TYPE_CHECKING: + from officialeye.types import ConfigDict + + +class GrayscaleMutator(Mutator): + + MUTATOR_ID = "grayscale" + + def __init__(self, config: ConfigDict, /): + super().__init__(GrayscaleMutator.MUTATOR_ID, config) + + def mutate(self, img: np.ndarray, /) -> np.ndarray: + return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) diff --git a/src/officialeye/_internal/mutation/mutators/non_local_means_denoising.py b/src/officialeye/_api_builtins/mutator/non_local_means_denoising.py similarity index 63% rename from src/officialeye/_internal/mutation/mutators/non_local_means_denoising.py rename to src/officialeye/_api_builtins/mutator/non_local_means_denoising.py index 1949738..fea9ac4 100644 --- a/src/officialeye/_internal/mutation/mutators/non_local_means_denoising.py +++ b/src/officialeye/_api_builtins/mutator/non_local_means_denoising.py @@ -1,34 +1,32 @@ -from typing import Dict +from __future__ import annotations + +from typing import TYPE_CHECKING import cv2 +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.mutator import Mutator +from officialeye.error.errors.template import ErrTemplateInvalidMutator -from officialeye._internal.error.errors.template import ErrTemplateInvalidMutator -from officialeye._internal.mutation.mutator import Mutator +if TYPE_CHECKING: + from officialeye.types import ConfigDict class NonLocalMeansDenoisingMutator(Mutator): MUTATOR_ID = "non_local_means_denoising" - def __init__(self, config: Dict[str, any], /): + def __init__(self, config: ConfigDict, /): super().__init__(NonLocalMeansDenoisingMutator.MUTATOR_ID, config) - # setup configuration loading - - self.get_config().set_value_preprocessor("colored", bool) - self.get_config().set_value_preprocessor("h", int) - self.get_config().set_value_preprocessor("hForColorComponents", int) - self.get_config().set_value_preprocessor("templateWindowSize", int) - self.get_config().set_value_preprocessor("searchWindowSize", int) - # load data from configuration + self._colored_mode = self.config.get("colored", default=True, value_preprocessor=bool) - self._colored_mode = self.get_config().get("colored", default=True) - - self._conf_h = self.get_config().get("h", default=10) - self._conf_hForColorComponents = self.get_config().get("hForColorComponents", default=10) - self._conf_templateWindowSize = self.get_config().get("templateWindowSize", default=7) - self._conf_searchWindowSize = self.get_config().get("searchWindowSize", default=21) + self._conf_h = self.config.get("h", default=10, value_preprocessor=int) + self._conf_hForColorComponents = self.config.get("hForColorComponents", default=10, value_preprocessor=int) + self._conf_templateWindowSize = self.config.get("templateWindowSize", default=7, value_preprocessor=int) + self._conf_searchWindowSize = self.config.get("searchWindowSize", default=21, value_preprocessor=int) # validate templateWindowSize if self._conf_templateWindowSize < 1: @@ -56,7 +54,7 @@ def __init__(self, config: Dict[str, any], /): f"The 'searchWindowSize' parameter must be odd, got '{self._conf_searchWindowSize}'." ) - def mutate(self, img: cv2.Mat, /) -> cv2.Mat: + def mutate(self, img: np.ndarray, /) -> np.ndarray: if self._colored_mode: return cv2.fastNlMeansDenoisingColored( diff --git a/src/officialeye/_internal/mutation/mutators/rotate.py b/src/officialeye/_api_builtins/mutator/rotate.py similarity index 68% rename from src/officialeye/_internal/mutation/mutators/rotate.py rename to src/officialeye/_api_builtins/mutator/rotate.py index 01b4a4f..77ca1f2 100644 --- a/src/officialeye/_internal/mutation/mutators/rotate.py +++ b/src/officialeye/_api_builtins/mutator/rotate.py @@ -1,9 +1,16 @@ -from typing import Dict +from __future__ import annotations + +from typing import TYPE_CHECKING import cv2 +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.mutator import Mutator +from officialeye.error.errors.template import ErrTemplateInvalidMutator -from officialeye._internal.error.errors.template import ErrTemplateInvalidMutator -from officialeye._internal.mutation.mutator import Mutator +if TYPE_CHECKING: + from officialeye.types import ConfigDict class RotateMutator(Mutator): @@ -13,7 +20,7 @@ class RotateMutator(Mutator): MUTATOR_ID = "rotate" - def __init__(self, config: Dict[str, any], /): + def __init__(self, config: ConfigDict, /): super().__init__(RotateMutator.MUTATOR_ID, config) def _angle_preprocessor(angle_text: str) -> int: @@ -27,11 +34,9 @@ def _angle_preprocessor(angle_text: str) -> int: return angle - self.get_config().set_value_preprocessor("angle", _angle_preprocessor) - - self._angle = self.get_config().get("angle") + self._angle = self.config.get("angle", value_preprocessor=_angle_preprocessor) - def mutate(self, img: cv2.Mat, /) -> cv2.Mat: + def mutate(self, img: np.ndarray, /) -> np.ndarray: if self._angle == 0: # we do not need to rotate the image at all diff --git a/src/officialeye/_api_builtins/supervisor/__init__.py b/src/officialeye/_api_builtins/supervisor/__init__.py new file mode 100644 index 0000000..1f3f50e --- /dev/null +++ b/src/officialeye/_api_builtins/supervisor/__init__.py @@ -0,0 +1,3 @@ +""" +A collection of all supervisors built into OfficialEye. +""" \ No newline at end of file diff --git a/src/officialeye/_internal/supervision/supervisors/combinatorial.py b/src/officialeye/_api_builtins/supervisor/combinatorial.py similarity index 58% rename from src/officialeye/_internal/supervision/supervisors/combinatorial.py rename to src/officialeye/_api_builtins/supervisor/combinatorial.py index abbbaa2..dfabe99 100644 --- a/src/officialeye/_internal/supervision/supervisors/combinatorial.py +++ b/src/officialeye/_api_builtins/supervisor/combinatorial.py @@ -1,24 +1,43 @@ +from __future__ import annotations + import random -from typing import Dict, Generator +from typing import TYPE_CHECKING, Dict, Iterable, List import numpy as np import z3 -from officialeye._internal.context.context import Context -from officialeye._internal.error.errors.supervision import ErrSupervisionInvalidEngineConfig -from officialeye._internal.logger.singleton import get_logger -from officialeye._internal.matching.match import Match -from officialeye._internal.matching.result import MatchingResult -from officialeye._internal.supervision.result import SupervisionResult -from officialeye._internal.supervision.supervisor import Supervisor +# noinspection PyProtectedMember +from officialeye._api.template.match import IMatch + +# noinspection PyProtectedMember +from officialeye._api.template.matching_result import IMatchingResult + +# noinspection PyProtectedMember +from officialeye._api.template.supervision_result import SupervisionResult + +# noinspection PyProtectedMember +from officialeye._api.template.supervisor import Supervisor + +# noinspection PyProtectedMember +from officialeye._api.template.template_interface import ITemplate + +# noinspection PyProtectedMember +from officialeye._internal.context.singleton import get_internal_afi + +# noinspection PyProtectedMember +from officialeye._internal.feedback.verbosity import Verbosity +from officialeye.error.errors.supervision import ErrSupervisionInvalidEngineConfig + +if TYPE_CHECKING: + from officialeye.types import ConfigDict class CombinatorialSupervisor(Supervisor): - ENGINE_ID = "combinatorial" + SUPERVISOR_ID = "combinatorial" - def __init__(self, context: Context, template_id: str, kmr: MatchingResult, /): - super().__init__(context, CombinatorialSupervisor.ENGINE_ID, template_id, kmr) + def __init__(self, config_dict: ConfigDict, /): + super().__init__(CombinatorialSupervisor.SUPERVISOR_ID, config_dict) # setup configuration def _min_match_factor_preprocessor(v: str) -> float: @@ -27,19 +46,19 @@ def _min_match_factor_preprocessor(v: str) -> float: if v > 1.0: raise ErrSupervisionInvalidEngineConfig( - f"while loading the '{CombinatorialSupervisor.ENGINE_ID}' supervisor", + f"while loading the '{CombinatorialSupervisor.SUPERVISOR_ID}' supervisor", f"The `min_match_factor` value ({v}) cannot exceed 1.0." ) if v < 0.0: raise ErrSupervisionInvalidEngineConfig( - f"while loading the '{CombinatorialSupervisor.ENGINE_ID}' supervisor", + f"while loading the '{CombinatorialSupervisor.SUPERVISOR_ID}' supervisor", f"The `min_match_factor` value ({v}) cannot be negative." ) return v - self.get_config().set_value_preprocessor("min_match_factor", _min_match_factor_preprocessor) + self._min_match_factor = self.config.get("min_match_factor", default=0.1, value_preprocessor=_min_match_factor_preprocessor) def _max_transformation_error_preprocessor(v: str) -> int: @@ -47,19 +66,19 @@ def _max_transformation_error_preprocessor(v: str) -> int: if v < 0: raise ErrSupervisionInvalidEngineConfig( - f"while loading the '{CombinatorialSupervisor.ENGINE_ID}' supervisor.", + f"while loading the '{CombinatorialSupervisor.SUPERVISOR_ID}' supervisor.", f"The `max_transformation_error` value ({v}) cannot be negative." ) if v > 5000: raise ErrSupervisionInvalidEngineConfig( - f"while loading the '{CombinatorialSupervisor.ENGINE_ID}' supervisor.", + f"while loading the '{CombinatorialSupervisor.SUPERVISOR_ID}' supervisor.", f"The `max_transformation_error` value ({v}) is too high." ) return v - self.get_config().set_value_preprocessor("max_transformation_error", _max_transformation_error_preprocessor) + self._max_transformation_error = self.config.get("max_transformation_error", value_preprocessor=_max_transformation_error_preprocessor) def _z3_timeout_preprocessor(v: str) -> int: @@ -67,38 +86,46 @@ def _z3_timeout_preprocessor(v: str) -> int: if v < 1: raise ErrSupervisionInvalidEngineConfig( - f"while loading the '{CombinatorialSupervisor.ENGINE_ID}' supervisor.", + f"while loading the '{CombinatorialSupervisor.SUPERVISOR_ID}' supervisor.", f"The `z3_timeout` value ({v}) cannot be negative or zero." ) return v - self.get_config().set_value_preprocessor("z3_timeout", _z3_timeout_preprocessor) + self._z3_timeout = self.config.get("z3_timeout", default=2500, value_preprocessor=_z3_timeout_preprocessor) # initialize all engine-specific values - self._z3_context = z3.Context() + self._z3_context: z3.Context | None = None # create variables for components of the translation matrix - self._transformation_matrix = np.array([ - [z3.Real("a", ctx=self._z3_context), z3.Real("b", ctx=self._z3_context)], - [z3.Real("c", ctx=self._z3_context), z3.Real("d", ctx=self._z3_context)] - ], dtype=z3.AstRef) + self._transformation_matrix: np.ndarray | None = None # keys: matches (instances of Match) # values: z3 integer variables representing the errors for each match, # i.e., how consistent the match is with the affine transformation model - self._match_weight: Dict[Match, z3.ArithRef] = {} + self._match_weight: Dict[IMatch, z3.ArithRef] = {} - for match in self._kmr.get_matches(): - self._match_weight[match] = z3.Real(f"w_{match.get_debug_identifier()}", ctx=self._z3_context) + self._minimum_weight_to_enforce: float | None = None - _config_min_match_factor = self.get_config().get("min_match_factor", default=0.1) - get_logger().debug(f"Min match factor: {_config_min_match_factor}") + def setup(self, template: ITemplate, matching_result: IMatchingResult, /) -> None: - self._minimum_weight_to_enforce = self._kmr.get_total_match_count() * _config_min_match_factor - self._max_transformation_error = self.get_config().get("max_transformation_error") + self._z3_context = z3.Context() - def _get_consistency_check(self, match: Match, delta: np.ndarray, delta_prime: np.ndarray, /) -> z3.AstRef: + self._transformation_matrix = np.array([ + [z3.Real("a", ctx=self._z3_context), z3.Real("b", ctx=self._z3_context)], + [z3.Real("c", ctx=self._z3_context), z3.Real("d", ctx=self._z3_context)] + ], dtype=z3.AstRef) + + for match in matching_result.get_all_matches(): + self._match_weight[match] = z3.Real( + f"w_{match.template_point[0]}_{match.template_point[1]}_{match.target_point[0]}_{match.target_point[1]}", + ctx=self._z3_context + ) + + # calculate the minimum weight that we need to enforce + self._minimum_weight_to_enforce = matching_result.get_total_match_count() * self._min_match_factor + + def _get_consistency_check(self, match: IMatch, delta: np.ndarray, delta_prime: np.ndarray, /) -> z3.AstRef: """ Generates a z3 formula asserting the consistency of the match with the affine linear transformation model. Consistency does not mean ideal matching of coordinates; rather, the template position with the affine @@ -115,7 +142,7 @@ def _get_consistency_check(self, match: Match, delta: np.ndarray, delta_prime: n translated_template_point = self._transformation_matrix @ (template_point - delta) + delta_prime translated_template_point_x, translated_template_point_y = translated_template_point - target_point_x, target_point_y = match.get_target_point() + target_point_x, target_point_y = match.target_point return z3.And( translated_template_point_x - target_point_x <= self._max_transformation_error, @@ -124,15 +151,15 @@ def _get_consistency_check(self, match: Match, delta: np.ndarray, delta_prime: n target_point_y - translated_template_point_y <= self._max_transformation_error, ) - def _run(self) -> Generator[SupervisionResult, None, None]: + def supervise(self, template: ITemplate, matching_result: IMatchingResult, /) -> Iterable[SupervisionResult]: - weights_lower_bounds = z3.And(*(self._match_weight[match] >= 0 for match in self._kmr.get_matches()), self._z3_context) - weights_upper_bounds = z3.And(*(self._match_weight[match] <= 1 for match in self._kmr.get_matches()), self._z3_context) + weights_lower_bounds = z3.And(*(self._match_weight[match] >= 0 for match in matching_result.get_all_matches()), self._z3_context) + weights_upper_bounds = z3.And(*(self._match_weight[match] <= 1 for match in matching_result.get_all_matches()), self._z3_context) - total_weight = z3.Sum(*(self._match_weight[match] for match in self._kmr.get_matches())) + total_weight = z3.Sum(*(self._match_weight[match] for match in matching_result.get_all_matches())) solver = z3.Optimize(ctx=self._z3_context) - solver.set("timeout", self.get_config().get("z3_timeout", default=2500)) + solver.set("timeout", self._z3_timeout) solver.add(weights_lower_bounds) solver.add(weights_upper_bounds) @@ -140,21 +167,21 @@ def _run(self) -> Generator[SupervisionResult, None, None]: solver.maximize(total_weight) - for keypoint_id in self._kmr.get_keypoint_ids(): - keypoint_matches = list(self._kmr.matches_for_keypoint(keypoint_id)) + for keypoint in template.keypoints: + keypoint_matches: List[IMatch] = list(matching_result.get_matches_for_keypoint(keypoint.identifier)) if len(keypoint_matches) == 0: continue # TODO: think whether this is a good algorithm design decision, and improve it if not - anchor_match = keypoint_matches[random.randint(0, len(keypoint_matches) - 1)] + anchor_match: IMatch = random.choice(keypoint_matches) delta = anchor_match.get_original_template_point() - delta_prime = anchor_match.get_target_point() + delta_prime = anchor_match.target_point solver.push() - for match in self._kmr.get_matches(): + for match in matching_result.get_all_matches(): solver.add(z3.Implies( self._match_weight[match] > 0, # consistency check @@ -165,12 +192,12 @@ def _run(self) -> Generator[SupervisionResult, None, None]: result = solver.check() if result == z3.unsat: - get_logger().warn("Could not satisfy the imposed constraints.", fg="red") + get_internal_afi().warn(Verbosity.INFO_VERBOSE, "Could not satisfy the imposed constraints.") solver.pop() continue if result == z3.unknown: - get_logger().warn("Could not decide the satifiability of the imposed constraints.", fg="red") + get_internal_afi().warn(Verbosity.INFO_VERBOSE, "Could not decide the satifiability of the imposed constraints.") solver.pop() continue @@ -185,16 +212,17 @@ def _run(self) -> Generator[SupervisionResult, None, None]: # extract transformation matrix from model transformation_matrix = model_evaluator(self._transformation_matrix) - _result = SupervisionResult(self.template_id, self._kmr, delta, delta_prime, transformation_matrix) - # add fixed constant to make sure that the score value is always non-negative - _result.set_score(model_total_weight) + _result = SupervisionResult( + delta=delta, + delta_prime=delta_prime, + transformation_matrix=transformation_matrix, + score=model_total_weight + ) - for match in self._kmr.get_matches(): + for match in matching_result.get_all_matches(): match_weight = model_evaluator(self._match_weight[match]) _result.set_match_weight(match, match_weight) - get_logger().debug(f"Error: {_result.get_weighted_mse()} Total weight and score: {model_total_weight}") - yield _result solver.pop() diff --git a/src/officialeye/_api_builtins/supervisor/least_squares_regression.py b/src/officialeye/_api_builtins/supervisor/least_squares_regression.py new file mode 100644 index 0000000..798d42d --- /dev/null +++ b/src/officialeye/_api_builtins/supervisor/least_squares_regression.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable + +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.template.matching_result import IMatchingResult + +# noinspection PyProtectedMember +from officialeye._api.template.supervision_result import SupervisionResult + +# noinspection PyProtectedMember +from officialeye._api.template.supervisor import Supervisor + +# noinspection PyProtectedMember +from officialeye._api.template.template_interface import ITemplate + +if TYPE_CHECKING: + from officialeye.types import ConfigDict + +_IND_A = 0 +_IND_B = 1 +_IND_C = 2 +_IND_D = 3 + + +class LeastSquaresRegressionSupervisor(Supervisor): + + SUPERVISOR_ID = "least_squares_regression" + + def __init__(self, config_dict: ConfigDict, /): + super().__init__(LeastSquaresRegressionSupervisor.SUPERVISOR_ID, config_dict) + + def setup(self, template: ITemplate, matching_result: IMatchingResult, /) -> None: + pass + + def supervise(self, template: ITemplate, matching_result: IMatchingResult, /) -> Iterable[SupervisionResult]: + + match_count = matching_result.get_total_match_count() + + for anchor_match in matching_result.get_all_matches(): + + delta = anchor_match.get_original_template_point() + delta_prime = anchor_match.target_point + + matrix = np.zeros((match_count << 1, 4), dtype=np.float64) + rhs = np.zeros(match_count << 1, dtype=np.float64) + + for i, match in enumerate(matching_result.get_all_matches()): + first_constraint_id = i << 1 + second_constraint_id = first_constraint_id + 1 + + s = match.get_original_template_point() + d = match.target_point + + matrix[first_constraint_id][_IND_A] = s[0] - delta[0] + matrix[first_constraint_id][_IND_B] = s[1] - delta[1] + rhs[first_constraint_id] = d[0] - delta_prime[0] + + matrix[second_constraint_id][_IND_C] = s[0] - delta[0] + matrix[second_constraint_id][_IND_D] = s[1] - delta[1] + rhs[second_constraint_id] = d[1] - delta_prime[1] + + regression_matrix = matrix.T @ matrix + regression_matrix = np.linalg.inv(regression_matrix) + rhs_applied = matrix.T @ rhs + x = regression_matrix @ rhs_applied + + transformation_matrix = np.array([ + [x[_IND_A], x[_IND_B]], + [x[_IND_C], x[_IND_D]] + ]) + + _result = SupervisionResult( + delta=delta, + delta_prime=delta_prime, + transformation_matrix=transformation_matrix + ) + + yield _result diff --git a/src/officialeye/_cli/__init__.py b/src/officialeye/_cli/__init__.py new file mode 100644 index 0000000..ec255b4 --- /dev/null +++ b/src/officialeye/_cli/__init__.py @@ -0,0 +1,3 @@ +""" +Module containing everything CLI-specific. +""" \ No newline at end of file diff --git a/src/officialeye/_cli/context.py b/src/officialeye/_cli/context.py new file mode 100644 index 0000000..7ae7682 --- /dev/null +++ b/src/officialeye/_cli/context.py @@ -0,0 +1,194 @@ +import os +import random +from tempfile import NamedTemporaryFile +from types import TracebackType +from typing import List + +import click +import cv2 +import numpy as np +from rich.prompt import Confirm + +from officialeye import Context +from officialeye.__version__ import __ascii_logo__ +from officialeye._cli.ui import TerminalUI, Verbosity + + +class CLIContext: + + def __init__(self, **kwargs): + self._api: Context | None = None + self._ui: TerminalUI | None = None + + self.handle_exceptions = True + self.visualization_generation = False + self.export_directory = None + + self.verbosity = Verbosity.QUIET + self.disable_logo = False + + self._export_counter = 1 + self._not_deleted_temporary_files: List[str] = [] + + self.set_params(**kwargs) + + def set_params(self, /, *, handle_exceptions: bool | None = None, visualization_generation: bool | None = None, + export_directory: str | None = None, verbosity: Verbosity | None = None, disable_logo: bool | None = None): + if handle_exceptions is not None: + self.handle_exceptions = handle_exceptions + + if visualization_generation is not None: + self.visualization_generation = visualization_generation + if export_directory is not None: + self.export_directory = export_directory + + if verbosity is not None: + self.verbosity = verbosity + if disable_logo is not None: + self.disable_logo = disable_logo + + def __enter__(self): + assert self._api is None + assert self._ui is None + + assert self._export_counter == 1 + assert len(self._not_deleted_temporary_files) == 0 + + self._ui = TerminalUI(self.verbosity) + self._api = Context(afi=self._ui) + + return self + + def dispose(self): + + self._assert_entered() + + if len(self._not_deleted_temporary_files) > 0 and (self.verbosity == Verbosity.QUIET or Confirm.ask( + ":question: Should temporary files created above be cleaned up now?", + default=True, console=self._ui.get_console() + ) + ): + files_removed = 0 + + # cleanup temporary files + for temp_file_path in self._not_deleted_temporary_files: + if os.path.isfile(temp_file_path): + os.unlink(temp_file_path) + files_removed += 1 + + self._ui.info(Verbosity.INFO, f"Successfully removed {files_removed} temporary file(s).") + + # reset fields related to file exporting + self._export_counter = 1 + self._not_deleted_temporary_files = [] + + # dispose the api together with all resources it manages, such as the AFI + self._api.dispose() + + self._api = None + self._ui = None + + def __exit__(self, exception_type: any, exception_value: BaseException, traceback: TracebackType): + + self._assert_entered() + + if not self.handle_exceptions: + self.dispose() + # tell python that we do not want to handle the exception + return None + + # handle the possible exception + if exception_value is None: + self.dispose() + return None + + self._ui.handle_uncaught_error(exception_type, exception_value, traceback) + + # free all allocated resources + self.dispose() + + # tell python that we have handled the exception ourselves + return True + + def print_logo(self): + + if self.disable_logo: + return + + logo_color = random.choice([ + "purple", + "yellow", + "red", + "green", + "cyan" + ]) + + self._ui.echo(Verbosity.INFO, __ascii_logo__, style=f"bold {logo_color}") + + def print_intro(self): + + self.print_logo() + + # print preliminary warnings if necessary + if not self.handle_exceptions: + self._ui.warn(Verbosity.INFO, "Raw error mode enabled. Use this mode only if you know precisely what you are doing!", ) + + if self.verbosity >= Verbosity.DEBUG: + self._ui.warn(Verbosity.INFO, "Debug mode enabled. Disable for production use to improve performance.") + + if self.visualization_generation: + self._ui.warn(Verbosity.INFO, "Visualization generation mode enabled. Disable for production use to improve performance.") + + def _assert_entered(self): + assert self._api is not None, "The context must be entered for this method to work correctly." + assert self._ui is not None, "The context must be entered for this method to work correctly." + + def get_api_context(self) -> Context: + self._assert_entered() + return self._api + + def get_terminal_ui(self) -> TerminalUI: + self._assert_entered() + return self._ui + + def _allocate_file_name(self) -> str: + self._assert_entered() + file_name = "%03d.png" % self._export_counter + self._export_counter += 1 + return file_name + + def allocate_file_for_export(self, /, *, file_name: str = "") -> str: + + self._assert_entered() + + file_suffix = ".png" if file_name == "" else file_name + + if self.export_directory is None: + with NamedTemporaryFile(prefix="officialeye_", suffix=f"_{file_suffix}", delete=False) as fp: + fp.close() + self._not_deleted_temporary_files.append(fp.name) + return fp.name + + if file_name == "": + file_name = self._allocate_file_name() + + return os.path.join(self.export_directory, file_name) + + def export_image(self, img: np.ndarray, /, *, file_name: str = "") -> str: + + export_file_path = self.allocate_file_for_export(file_name=file_name) + + cv2.imwrite(export_file_path, img) + + self._ui.info(Verbosity.INFO, f"Exported [b]{export_file_path}[/].") + + return export_file_path + + def export_and_show_image(self, img: np.ndarray, /, *, file_name: str = ""): + path = self.export_image(img, file_name=file_name) + + if self.verbosity != Verbosity.QUIET and Confirm.ask( + ":question: Would you like to open the image in an image viewer (as provided by the OS)?", + default=True, console=self._ui.get_console() + ): + click.launch(path, locate=False) diff --git a/src/officialeye/_internal/template/create.py b/src/officialeye/_cli/create.py similarity index 95% rename from src/officialeye/_internal/template/create.py rename to src/officialeye/_cli/create.py index c5bdee4..8976ee8 100644 --- a/src/officialeye/_internal/template/create.py +++ b/src/officialeye/_cli/create.py @@ -1,10 +1,13 @@ import os -from officialeye._internal.error.errors.io import ErrIOInvalidPath -from officialeye._internal.logger.singleton import get_logger +from officialeye._cli.context import CLIContext +# noinspection PyProtectedMember +from officialeye._internal.feedback.verbosity import Verbosity +from officialeye.error.errors.io import ErrIOInvalidPath -def create_example_template_config_file(template_path: str, template_image: str, template_id: str, template_name: str, force_mode: bool, /): + +def do_create(context: CLIContext, /, *, template_path: str, template_image: str, template_id: str, template_name: str, force_mode: bool): # validate the path first if os.path.isdir(template_path): @@ -172,4 +175,4 @@ def create_example_template_config_file(template_path: str, template_image: str, with open(template_path, "w") as fh: fh.write(template_yml) - get_logger().info(f"Initialized template configuration file at '{template_path}'.") + context.get_terminal_ui().echo(Verbosity.INFO, f":party_popper: Initialized template configuration file at '{template_path}'!") diff --git a/src/officialeye/_cli/main.py b/src/officialeye/_cli/main.py new file mode 100644 index 0000000..88ac7ae --- /dev/null +++ b/src/officialeye/_cli/main.py @@ -0,0 +1,164 @@ +""" +OfficialEye CLI frontend main entry point. +""" + +from typing import List + +import click + +from officialeye.__version__ import __github_full_url__, __github_url__, __version__ +from officialeye._cli.context import CLIContext +from officialeye._cli.create import do_create +from officialeye._cli.run import do_run +from officialeye._cli.show import do_show +from officialeye._cli.test import do_test +from officialeye._cli.ui import Verbosity + +_context = CLIContext() + + +@click.group() +@click.option("-d", "--debug", is_flag=True, show_default=True, default=False, help="Enable debug mode.") +@click.option("--edir", type=click.Path(exists=True, file_okay=True, readable=True), help="Specify export directory.") +@click.option("-q", "--quiet", is_flag=True, show_default=True, default=False, help="Disable standard output messages.") +@click.option("-v", "--verbose", is_flag=True, show_default=True, default=False, help="Enable verbose logging.") +@click.option("-dl", "--disable-logo", is_flag=True, show_default=True, default=False, help="Disable the officialeye logo.") +@click.option("-re", "--raw-errors", is_flag=True, show_default=False, default=False, help="Do not handle errors.") +def main(debug: bool, edir: str, quiet: bool, verbose: bool, disable_logo: bool, raw_errors: bool): + global _context + + # configure context + if quiet: + verbosity = Verbosity.QUIET + elif debug: + if verbose: + verbosity = Verbosity.DEBUG_VERBOSE + else: + verbosity = Verbosity.DEBUG + else: + # info verbosity + if verbose: + verbosity = Verbosity.INFO_VERBOSE + else: + verbosity = Verbosity.INFO + + _context.set_params( + export_directory=edir, + handle_exceptions=not raw_errors, + verbosity=verbosity, + disable_logo=disable_logo + ) + + +# noinspection PyShadowingBuiltins +@click.command() +@click.argument("template_path", type=click.Path(exists=False, file_okay=True, readable=True, writable=True)) +@click.argument("template_image", type=click.Path(exists=True, file_okay=True, readable=True, writable=False)) +@click.option("--id", type=str, show_default=False, default="example", help="Specify the template identifier.") +@click.option("--name", type=str, show_default=False, default="Example", help="Specify the template name.") +@click.option("--force", is_flag=True, show_default=True, default=False, help="Create missing directories and overwrite file.") +def create(template_path: str, template_image: str, id: str, name: str, force: bool): + """Creates a new template configuration file at the specified path.""" + + global _context + + with _context as context: + context.print_logo() + + do_create( + context, + template_path=template_path, + template_image=template_image, + template_id=id, + template_name=name, + force_mode=force + ) + + +@click.command() +@click.argument("template_path", type=click.Path(exists=True, file_okay=True, readable=True)) +@click.option("--hide-features", is_flag=True, show_default=False, default=False, help="Do not visualize the locations of features.") +@click.option("--hide-keypoints", is_flag=True, show_default=False, default=False, help="Do not visualize the locations of keypoints.") +def show(template_path: str, hide_features: bool, hide_keypoints: bool): + """Exports template as an image with features visualized.""" + + global _context + + with _context as context: + do_show(context, template_path=template_path, hide_features=hide_features, hide_keypoints=hide_keypoints) + + +@click.command() +@click.argument("target_path", type=click.Path(exists=True, file_okay=True, readable=True)) +@click.argument("template_paths", type=click.Path(exists=True, file_okay=True, readable=True), nargs=-1) +@click.option("--show-features", is_flag=True, show_default=False, default=False, help="Visualize the locations of features.") +def test(target_path: str, template_paths: List[str], show_features: bool): + """Visualizes the analysis of an image using one or more templates.""" + + global _context + + with _context as context: + do_test( + context, + target_path=target_path, + template_paths=template_paths, + show_features=show_features + ) + + +@click.command() +@click.argument("target_path", type=click.Path(exists=True, file_okay=True, readable=True)) +@click.argument("template_paths", type=click.Path(exists=True, file_okay=True, readable=True), nargs=-1) +@click.option("--interpret", type=click.Path(exists=True, file_okay=True, readable=True), + default=None, help="Use the image at the specified path to run the interpretation phase.") +@click.option("--visualize", is_flag=True, show_default=False, default=False, help="Generate visualizations of intermediate steps.") +def run(target_path: str, template_paths: List[str], interpret: str | None, visualize: bool): + """Applies one or more templates to an image.""" + + global _context + + # TODO: think whether this is a good design choice + _context.set_params(visualization_generation=visualize) + + with _context as context: + do_run( + context, + target_path=target_path, + template_paths=template_paths, + interpret_path=interpret, + visualize=visualize + ) + + +@click.command() +def homepage(): + """Go to the officialeye's official GitHub homepage.""" + + global _context + + with _context as context: + context.get_terminal_ui().info(Verbosity.INFO, f"GitHub: [link={__github_full_url__}]{__github_url__}[/link]") + + click.launch(__github_full_url__) + + +@click.command() +def version(): + """Print the version of OfficialEye.""" + + global _context + + with _context as context: + context.print_logo() + context.get_terminal_ui().info(Verbosity.INFO, f"Version: {__version__}") + + +main.add_command(create) +main.add_command(show) +main.add_command(test) +main.add_command(run) +main.add_command(homepage) +main.add_command(version) + +if __name__ == "__main__": + main() diff --git a/src/officialeye/_cli/run.py b/src/officialeye/_cli/run.py new file mode 100644 index 0000000..2fdab3c --- /dev/null +++ b/src/officialeye/_cli/run.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List + +from rich.json import JSON +from rich.table import Table + +# noinspection PyProtectedMember +from officialeye._api.detection import detect + +# noinspection PyProtectedMember +from officialeye._api.image import Image + +# noinspection PyProtectedMember +from officialeye._api.template.template import Template +from officialeye._cli.context import CLIContext + +# noinspection PyProtectedMember +from officialeye._internal.feedback.verbosity import Verbosity + +if TYPE_CHECKING: + from officialeye.types import FeatureInterpretation + + +def do_run(context: CLIContext, /, *, target_path: str, template_paths: List[str], interpret_path: str | None, visualize: bool): + # print OfficialEye logo and other introductory information (if necessary) + context.print_intro() + + # TODO: implement visualization generation + # TODO: update the example in the documentation + + api_context = context.get_api_context() + + target_image = Image(api_context, path=target_path) + + interpretation_target_image = target_image if interpret_path is None else Image(api_context, path=interpret_path) + + templates = [Template(api_context, path=template_path) for template_path in template_paths] + + result = detect(api_context, *templates, target=target_image) + + interpretation_result = result.interpret(target=interpretation_target_image) + + table = Table() + + table.add_column("Feature", justify="right") + table.add_column("Interpretation", justify="left") + + for feature in interpretation_result.template.features: + interpretation: FeatureInterpretation = interpretation_result.get_feature_interpretation(feature) + + interpretation_visualization = JSON.from_data(interpretation, indent=4) + + table.add_row(feature.identifier, interpretation_visualization) + + context.get_terminal_ui().echo(Verbosity.INFO, table) diff --git a/src/officialeye/_cli/show.py b/src/officialeye/_cli/show.py new file mode 100644 index 0000000..68633f1 --- /dev/null +++ b/src/officialeye/_cli/show.py @@ -0,0 +1,54 @@ +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.template.template import Template +from officialeye._cli.context import CLIContext +from officialeye._cli.utils import visualize_feature, visualize_keypoint + +# noinspection PyProtectedMember +from officialeye._internal.feedback.verbosity import Verbosity + + +def _visualize_regions(template: Template, background_img: np.ndarray, /, *, + hide_features: bool = False, hide_keypoints: bool = False) -> np.ndarray: + + img = background_img + + if not hide_features: + for feature in template.features: + img = visualize_feature(feature, img) + + if not hide_keypoints: + for keypoint in template.keypoints: + img = visualize_keypoint(keypoint, img) + + return img + + +def do_show(context: CLIContext, /, *, template_path: str, **kwargs): + + # print OfficialEye logo and other introductory information (if necessary) + context.print_intro() + + template = Template(context.get_api_context(), path=template_path) + + template_img = template.get_image().load() + + template_img_mutated = template.get_mutated_image().load() + + if template_img_mutated.shape == template_img.shape: + background_img = template_img_mutated + else: + + context.get_terminal_ui().warn( + Verbosity.INFO_VERBOSE, + f"One of the source mutators of the '{template.identifier}' template has changed the shape of the image. " + f"To ensure that the regions of the template are visualized correctly, " + f"the original template image had to be used as the background." + ) + + background_img = template_img + + visualization = _visualize_regions(template, background_img, **kwargs) + + context.export_and_show_image(visualization, file_name=f"{template.identifier}.png") diff --git a/src/officialeye/_cli/test.py b/src/officialeye/_cli/test.py new file mode 100644 index 0000000..330bae1 --- /dev/null +++ b/src/officialeye/_cli/test.py @@ -0,0 +1,83 @@ +from typing import List + +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.detection import detect + +# noinspection PyProtectedMember +from officialeye._api.image import Image + +# noinspection PyProtectedMember +from officialeye._api.template.template import Template + +# noinspection PyProtectedMember +from officialeye._api.template.template_interface import ITemplate +from officialeye._cli.context import CLIContext +from officialeye._cli.utils import visualize_feature + +# noinspection PyProtectedMember +from officialeye._internal.feedback.verbosity import Verbosity + + +def _get_background(context: CLIContext, template: ITemplate, /) -> np.ndarray: + + raw_image = template.get_image().load() + mutated_image = template.get_mutated_image().load() + + if raw_image.shape == mutated_image.shape: + return mutated_image + + context.get_terminal_ui().warn( + Verbosity.INFO, + f"Could not use the mutated version of the '{template.identifier}' template " + f"because one of the source mutators did not preserve the shape of the image. " + f"Falling back to the non-mutated version of the source image." + ) + + return raw_image + + +def do_test(context: CLIContext, /, *, + target_path: str, template_paths: List[str], show_features: bool): + # print OfficialEye logo and other introductory information (if necessary) + context.print_intro() + + api_context = context.get_api_context() + + target_image = Image(api_context, path=target_path) + + templates = [Template(api_context, path=template_path) for template_path in template_paths] + + result = detect(api_context, *templates, target=target_image) + + visualization = _get_background(context, result.template) + target_image_mat = target_image.load() + + for feature in result.template.features: + feature_image_mat = result.warp_feature(feature, target_image_mat) + + feature_image_mutated_mat = feature.apply_mutators_to_image(feature_image_mat) + + if feature_image_mat.shape == feature_image_mutated_mat.shape: + # mutators didn't change the shape of the image + feature.insert_into_image(visualization, feature_image_mutated_mat) + else: + # some mutator has altered the shape of the feature image. + # this means that we can no longer safely insert the mutated feature into the visualization. + # therefore, we have to fall back to inserting the feature image unmutated + context.get_terminal_ui().warn( + Verbosity.INFO, + f"Could not visualize the '{feature.identifier}' feature of the '{feature.template.identifier}' template, " + f"because one of the mutators (corresponding to this feature) did not preserve the shape of the image. " + f"Falling back to the non-mutated version of the feature image." + ) + + feature.insert_into_image(visualization, feature_image_mat) + + if show_features: + # visualize features on the image + for feature in result.template.features: + visualization = visualize_feature(feature, visualization) + + context.export_and_show_image(visualization, file_name=f"{result.template.identifier}.png") diff --git a/src/officialeye/_cli/ui.py b/src/officialeye/_cli/ui.py new file mode 100644 index 0000000..7991184 --- /dev/null +++ b/src/officialeye/_cli/ui.py @@ -0,0 +1,461 @@ +from __future__ import annotations + +from concurrent.futures import Future +from contextlib import contextmanager +from multiprocessing import Pipe + +# noinspection PyProtectedMember +from multiprocessing.connection import Connection, wait +from threading import Lock, Thread +from types import TracebackType +from typing import Any, Dict, List, Set, Tuple + +from rich.console import Console, ConsoleRenderable +from rich.panel import Panel +from rich.progress import Progress, SpinnerColumn, TaskID, TextColumn +from rich.theme import Theme +from rich.traceback import Traceback + +# noinspection PyProtectedMember +from officialeye._internal.context.feedback import InternalFeedbackInterface, IPCMessageType + +# noinspection PyProtectedMember +from officialeye._internal.feedback.abstract import AbstractFeedbackInterface + +# noinspection PyProtectedMember +from officialeye._internal.feedback.verbosity import Verbosity +from officialeye.error.error import OEError +from officialeye.error.errors.general import ErrOperationNotSupported +from officialeye.error.errors.internal import ErrInternal + +_THEME_TAG_INFO = "info" +_THEME_TAG_INFO_VERBOSE = "infov" +_THEME_TAG_DEBUG = "debug" +_THEME_TAG_DEBUG_VERBOSE = "debugv" + +_THEME_TAG_WARN = "warn" +_THEME_TAG_ERR = "err" + +_THEME: Dict[str, str] = { + _THEME_TAG_INFO: "bold green", + _THEME_TAG_INFO_VERBOSE: "bold cyan", + _THEME_TAG_DEBUG: "bold purple", + _THEME_TAG_DEBUG_VERBOSE: "bold magenta", + _THEME_TAG_WARN: "bold yellow", + _THEME_TAG_ERR: "bold red" +} + +_THEME_MAP: Dict[Verbosity, str] = { + Verbosity.INFO: _THEME_TAG_INFO, + Verbosity.INFO_VERBOSE: _THEME_TAG_INFO_VERBOSE, + Verbosity.DEBUG: _THEME_TAG_DEBUG, + Verbosity.DEBUG_VERBOSE: _THEME_TAG_DEBUG_VERBOSE +} + + +def _wrap_exception(exception_value: BaseException, /) -> OEError: + + if isinstance(exception_value, OEError): + return exception_value + + if isinstance(exception_value, BaseException): + oe_error = ErrInternal( + "while leaving the CLI context.", + "An internal error occurred.", + ) + oe_error.add_external_cause(exception_value) + return oe_error + + return ErrInternal( + "while determining the raised exception type.", + "Could not decide how to handle the raised error." + ) + + +def _child_listener(listener: _ChildrenListener, /): + + # a known child is a child that the listener thread is either already activaly listening to, or has already stopped listening to due to + # the receival of a message over IPC, indicating that the child is done with the work + known_children: Set[int] = set() + + while True: + + # first, we wait for one of the children to send some message + with listener.children_lock: + if len(listener.children) == 0: + break + + # determine whether there are children the children listener thread does not know about + # in this case, those children should be found and added to the known list, + # and the corresponding lock should be acquired, indicating that the connection with the current + # child has not yet been shut down gracefully + for child_id in listener.children: + + if child_id in known_children: + continue + + child = listener.children[child_id] + + # indicate that we are listening to that child, and that once it is done with its work, + # our handling of messages that might still be pending on the IPC, is to be respected + child.is_being_listened_to.acquire() + + # now we know the child + known_children.add(child_id) + + connections = [ + listener.children[child_id].connection for child_id in listener.children + ] + + assert len(connections) >= 1 + + # TODO: consider introducing a mechanism allowing one to let the child listener thread know about the change in the children dictionary + # TODO: this will improve performance, because there will be no need to wait for the timeout to expire + # TODO: this idea can be implemented, for example, by introducing a new dummy connection designed only to communicate 'refresh' messages + # TODO: with the child listener thread + wait(connections, timeout=1.0) + + messages_to_handle: List[Tuple[Any, _Child]] = [] + + with listener.children_lock: + if len(listener.children) == 0: + break + + for child_id in listener.children: + child = listener.children[child_id] + + if not child.connection.poll(): + continue + + message = child.connection.recv() + + messages_to_handle.append((message, child)) + + for message, child in messages_to_handle: + listener.handle_message(message, child) + + +class _Child: + + def __init__(self, child_id: int, task_id: TaskID, connection: Connection, /): + self.child_id = child_id + self.task_id = task_id + self.connection = connection + self.is_being_listened_to = Lock() + + +# noinspection PyProtectedMember +class _ChildrenListener: + + def __init__(self, terminal_ui, /): + + self._terminal_ui = terminal_ui + + self._progress = Progress( + SpinnerColumn(), + TextColumn("{task.description}"), + TextColumn("[cyan]{task.fields[status]}"), + console=self._terminal_ui._console, + auto_refresh=True, + disable=self._terminal_ui._verbosity == Verbosity.QUIET, + transient=True + ) + + self.children: Dict[int, _Child] = {} + self.children_lock = Lock() + + self._children_listener: Thread | None = None + + def handle_message(self, message: tuple, child: _Child, /): + + message_type, args, kwargs = message + + is_task_done: bool = False + + with self._terminal_ui.as_author(child.child_id): + if message_type == IPCMessageType.ECHO: + self._terminal_ui.echo(*args, **kwargs) + elif message_type == IPCMessageType.INFO: + self._terminal_ui.info(*args, **kwargs) + elif message_type == IPCMessageType.WARN: + self._terminal_ui.warn(*args, **kwargs) + elif message_type == IPCMessageType.ERROR: + self._terminal_ui.error(*args, **kwargs) + elif message_type == IPCMessageType.UPDATE_STATUS: + new_status_text: str = args[0] + assert isinstance(new_status_text, str) + self._progress.update(child.task_id, status=new_status_text) + elif message_type == IPCMessageType.TASK_DONE: + is_task_done = True + else: + raise AssertionError("Unknown IPC message type received by parent process.") + + if is_task_done: + + task_done_successfully: bool = args[0] + assert isinstance(task_done_successfully, bool) + + self._terminal_ui.info( + Verbosity.DEBUG, + f"Child has indicated that the task is done (success={task_done_successfully}), " + "releasing the corresponding lock to enable graceful shutdown of the child listener." + ) + + if task_done_successfully: + self._progress.update(child.task_id, completed=100, status="[green]:heavy_check_mark: Success![/]") + else: + self._progress.update(child.task_id, completed=100, status="[red]:heavy_multiplication_x: Error![/]") + + child.is_being_listened_to.release() + + def listen_to(self, child_id: int, connection: Connection, description: str, /): + + # create a new task associated with the child + task_id = self._progress.add_task(description, status="") + + with self.children_lock: + assert child_id not in self.children, "Child ID is not unique." + self.children[child_id] = _Child(child_id, task_id, connection) + + if self._children_listener is None: + # we have added the first child. therefore, the progress bar needs to be started. + self._progress.start() + + # we need to also start a thread listening for messages from children + self._children_listener = Thread(target=_child_listener, name="Child Process Listener", args=(self,)) + self._children_listener.start() + + def stop_listening_to(self, child_id: int, /): + + # we first attempt to wait until the child listener thread has done processing all messages sent by the child + # in other words we want to try and gracefully stop listening to the child + + with self.children_lock: + + if child_id not in self.children: + self._terminal_ui.warn( + Verbosity.DEBUG, + f"Could not stop listening for child {child_id} because it could not be found among children the main process is listening to." + ) + return + + _child = self.children[child_id] + + self._terminal_ui.info( + Verbosity.DEBUG, + f"Waiting for all messages from child {child_id} to be processed by the child listener thread." + ) + + child_lock = _child.is_being_listened_to + + if child_lock.acquire(blocking=True, timeout=4): + child_lock.release() + + self._terminal_ui.info( + Verbosity.DEBUG, + f"The child listener thread indicated that it has processed all messages from child {child_id}." + ) + else: + self._terminal_ui.warn( + Verbosity.DEBUG, + f"The child listener thread did not indicate that it has processed all messages from child {child_id}!" + ) + + # we now proceed with removing the child completely + + last_child_removed = False + + with self.children_lock: + + if child_id not in self.children: + self._terminal_ui.warn( + Verbosity.DEBUG, + f"Could not stop listening for child {child_id} because it could not be found among children the main process is listening to." + ) + return + + child = self.children[child_id] + + child.connection.close() + + del self.children[child_id] + + if len(self.children) == 0: + # we have removed the last child + last_child_removed = True + + if last_child_removed: + self._terminal_ui.info(Verbosity.DEBUG_VERBOSE, "Last child removed, stopping the child listener and the progress bar.") + + # stop the thread listening for messages from children + if self._children_listener is not None: + self._terminal_ui.info(Verbosity.DEBUG_VERBOSE, "Joining the children listener thread.") + + self._children_listener.join() + self._children_listener = None + + self._terminal_ui.info(Verbosity.DEBUG_VERBOSE, "Children listener thread successfully joined.") + + # stop the progress bar + self._terminal_ui.info(Verbosity.DEBUG_VERBOSE, "Stopping the progress bar due to removal of last child.") + self._progress.stop() + self._terminal_ui.info(Verbosity.DEBUG_VERBOSE, "Stopped the progress bar due to removal of last child.") + + def remove_all_children(self): + + while True: + + with self.children_lock: + children_to_be_removed = list(self.children.keys()) + + if len(children_to_be_removed) == 0: + break + + for child_id in children_to_be_removed: + self.stop_listening_to(child_id) + + def dispose(self): + self._terminal_ui.info(Verbosity.DEBUG_VERBOSE, "Dispoing child listener...") + self.remove_all_children() + + +class TerminalUI(AbstractFeedbackInterface): + + def __init__(self, verbosity: Verbosity, /): + super().__init__(verbosity) + + self._console = Console(highlight=False, theme=Theme(_THEME)) + self._err_console = Console(stderr=True, theme=Theme(_THEME)) + + self._children_listener: _ChildrenListener = _ChildrenListener(self) + self._fork_counter: int = 0 + + self._last_printed_message_author: int | None = None + + def get_console(self) -> Console: + return self._console + + def _print_message_authors(self) -> bool: + return self._verbosity >= Verbosity.DEBUG + + @contextmanager + def as_author(self, author: int): + + print_authors = self._print_message_authors() + + if print_authors and self._last_printed_message_author != author: + # the same author prints the message + self._console.rule(f"Messages by worker #{author}") + + self._last_printed_message_author = None + + yield self + + if print_authors: + self._last_printed_message_author = author + + def echo(self, verbosity: Verbosity, message: str | ConsoleRenderable = "", /, *, err: bool = False, **kwargs: Any) -> None: + + assert verbosity != Verbosity.QUIET + + if self._last_printed_message_author is not None: + self._console.rule("Messages by the main process") + self._last_printed_message_author = None + + if self._verbosity < verbosity: + return + + console = self._err_console if err else self._console + + if not console.is_interactive: + kwargs.setdefault("crop", False) + kwargs.setdefault("overflow", "ignore") + + console.print(message, **kwargs) + + def info(self, verbosity: Verbosity, message: str, /) -> None: + assert verbosity != Verbosity.QUIET + _tag = _THEME_MAP[verbosity] + self.echo(verbosity, f"[{_tag}]INFO [/] {message}", highlight=True) + + def warn(self, verbosity: Verbosity, message: str, /) -> None: + assert verbosity != Verbosity.QUIET + self.echo(verbosity, f"[{_THEME_TAG_WARN}]WARN [/] {message}") + + def error(self, verbosity: Verbosity, message: str, /) -> None: + assert verbosity != Verbosity.QUIET + self.echo(verbosity, f"[{_THEME_TAG_ERR}]ERROR[/] {message}", err=True) + + def update_status(self, new_status_text: str, /) -> None: + raise ErrOperationNotSupported( + "while updating status of a task.", + "The terminal UI does not support this operation." + ) + + def _print_oe_error(self, error: OEError, /, *, verbosity: Verbosity = Verbosity.INFO): + + self.error(verbosity, f"Error {error.code} ('{error.code_text}') occurred in module '{error.module}' {error.while_text}") + self.error(verbosity, f"Problem: {error.problem_text}") + + error_details = error.get_details() + + if error_details is not None and verbosity != Verbosity.QUIET: + detail_panel = Panel(error_details, title="Error details", expand=False, border_style="red", highlight=True) + self._err_console.print(detail_panel) + + causes = error.get_causes() + external_causes = error.get_external_causes() + + if len(causes) + len(external_causes) > 0: + self._err_console.rule("The above error has been caused by the errors listed below") + + for cause in causes: + self._print_oe_error(cause, verbosity=verbosity) + + for external_cause in external_causes: + rich_traceback = Traceback.from_exception(external_cause.__class__, external_cause, None) + self._err_console.print(rich_traceback) + + def handle_uncaught_error(self, exception_type: any, exception_value: BaseException, traceback: TracebackType, /): + + # remove all children to make sure that the progress bar dissapears and + # does not interfere with the printing of the error to the terminal + self._children_listener.remove_all_children() + + self._print_oe_error(_wrap_exception(exception_value)) + + if self._verbosity >= Verbosity.DEBUG_VERBOSE: + # self._err_console.rule("Error details") + rich_traceback = Traceback.from_exception(exception_type, exception_value, traceback) + self._err_console.print(rich_traceback) + + def dispose(self, exception_type: any = None, exception_value: BaseException | None = None, traceback: TracebackType | None = None) -> None: + self._children_listener.dispose() + + def fork(self, description: str, /) -> AbstractFeedbackInterface: + + self.info(Verbosity.DEBUG_VERBOSE, "AbstractFeedbackInterface: fork()") + + rx, tx = Pipe(duplex=False) + + assert isinstance(rx, Connection) + assert isinstance(tx, Connection) + + self._fork_counter += 1 + child_id = self._fork_counter + + child = InternalFeedbackInterface(self._verbosity, child_id, tx) + + self._children_listener.listen_to(child_id, rx, description) + + return child + + def join(self, child: AbstractFeedbackInterface, future: Future, /) -> None: + + assert isinstance(child, InternalFeedbackInterface), "Invalid child type" + + child_id = child.get_child_id() + + self.info(Verbosity.DEBUG_VERBOSE, f"AbstractFeedbackInterface: join() of child #{child_id}") + + self._children_listener.stop_listening_to(child_id) diff --git a/src/officialeye/_cli/utils.py b/src/officialeye/_cli/utils.py new file mode 100644 index 0000000..e7ad467 --- /dev/null +++ b/src/officialeye/_cli/utils.py @@ -0,0 +1,51 @@ +from typing import Tuple + +import cv2 +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.template.feature import IFeature + +# noinspection PyProtectedMember +from officialeye._api.template.keypoint import IKeypoint + +# noinspection PyProtectedMember +from officialeye._api.template.region import IRegion + +_LABEL_COLOR_DEFAULT = (0, 0, 0xff) +_VISUALIZATION_SCALE_COEFF = 1.0 / 1400.0 + +_FEATURE_RECT_COLOR = (0, 0xff, 0) +_KEYPOINT_RECT_COLOR = (0, 0, 0xff) + + +def visualize_region(region: IRegion, img: np.ndarray, /, *, rect_color: Tuple[int, int, int], label_color=_LABEL_COLOR_DEFAULT) -> np.ndarray: + + img = cv2.rectangle(img, (region.x, region.y), (region.x + region.w, region.y + region.h), rect_color, 4) + + label_origin = ( + region.x + int(10 * img.shape[0] * _VISUALIZATION_SCALE_COEFF), + region.y + int(30 * img.shape[0] * _VISUALIZATION_SCALE_COEFF) + ) + + font_scale = img.shape[0] * _VISUALIZATION_SCALE_COEFF + img = cv2.putText( + img, + region.identifier, + label_origin, + cv2.FONT_HERSHEY_SIMPLEX, + font_scale, + label_color, + int(2 * img.shape[0] * _VISUALIZATION_SCALE_COEFF), + cv2.LINE_AA + ) + + return img + + +def visualize_feature(feature: IFeature, img: np.ndarray, /) -> np.ndarray: + return visualize_region(feature, img, rect_color=_FEATURE_RECT_COLOR) + + +def visualize_keypoint(keypoint: IKeypoint, img: np.ndarray, /) -> np.ndarray: + return visualize_region(keypoint, img, rect_color=_KEYPOINT_RECT_COLOR) diff --git a/src/officialeye/_internal/__init__.py b/src/officialeye/_internal/__init__.py index c6096d4..35400b1 100644 --- a/src/officialeye/_internal/__init__.py +++ b/src/officialeye/_internal/__init__.py @@ -4,4 +4,8 @@ WARNING! Do not import it unless you know precisely what you are doing. Instead, use the public API to interact with OfficialEye programatically. -""" \ No newline at end of file +""" + +import z3 + +z3.set_param("parallel.enable", True) diff --git a/src/officialeye/_internal/_types.py b/src/officialeye/_internal/_types.py new file mode 100644 index 0000000..13668af --- /dev/null +++ b/src/officialeye/_internal/_types.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypeVar + +if TYPE_CHECKING: + from typing import Protocol + + SpinnerT = TypeVar("SpinnerT", bound="Spinner") + + class Spinner(Protocol): + def update(self, text: str) -> None: + ... + + def __enter__(self: SpinnerT) -> SpinnerT: + ... + + def __exit__(self, *args: Any) -> None: + ... + + class RichProtocol(Protocol): + def __rich__(self) -> str: + ... diff --git a/src/officialeye/_internal/api/__init__.py b/src/officialeye/_internal/api/__init__.py new file mode 100644 index 0000000..c83b8e2 --- /dev/null +++ b/src/officialeye/_internal/api/__init__.py @@ -0,0 +1,4 @@ +""" +This module provides a set of functions connecting the API interface with the internal implementation interface. +In other words, the functions of this module form a low-level API that should be called internally in the actual public API. +""" \ No newline at end of file diff --git a/src/officialeye/_internal/api/detect.py b/src/officialeye/_internal/api/detect.py new file mode 100644 index 0000000..a03dd87 --- /dev/null +++ b/src/officialeye/_internal/api/detect.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import cv2 +import numpy as np + +from officialeye._internal.context.singleton import get_internal_context +from officialeye._internal.template.schema.loader import load_template + +if TYPE_CHECKING: + from officialeye._internal.template.external_supervision_result import ExternalSupervisionResult + from officialeye._internal.template.internal_supervision_result import InternalSupervisionResult + + +def template_detect(template_path: str, /, *, target_path: str, **kwargs) -> ExternalSupervisionResult: + + from officialeye._internal.template.external_supervision_result import ExternalSupervisionResult + + with get_internal_context().setup(**kwargs): + template = load_template(template_path) + + target: np.ndarray = cv2.imread(target_path, cv2.IMREAD_COLOR) + + # TODO: move the following to a separate internal api method + """ + + """ + + internal_supervision_result: InternalSupervisionResult = template.do_detect(target) + + return ExternalSupervisionResult(internal_supervision_result) diff --git a/src/officialeye/_internal/api/interpret.py b/src/officialeye/_internal/api/interpret.py new file mode 100644 index 0000000..d95ca3b --- /dev/null +++ b/src/officialeye/_internal/api/interpret.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import cv2 + +from officialeye._internal.context.singleton import get_internal_context + +# noinspection PyProtectedMember +from officialeye._internal.template.external_interpretation_result import ExternalInterpretationResult +from officialeye._internal.template.schema.loader import load_template + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.template.supervision_result import ISupervisionResult + + +def template_interpret(template_path: str, supervision_result: ISupervisionResult, /, *, + interpretation_target_path: str, **kwargs) -> ExternalInterpretationResult: + + with get_internal_context().setup(**kwargs): + + template = load_template(template_path) + + interpretation_target = cv2.imread(interpretation_target_path, cv2.IMREAD_COLOR) + + # TODO: make sure that the target image and the interpretation target images have the same shape, similar to the following snippet + """ + if target.shape != interpretation_target.shape: + raise ErrInvalidImage( + "while making sure that the target image and the interpretation target images have the same shape.", + f"The shapes mismatch. " + f"The target image has shape {target.shape}, while the interpretation target image has shape {interpretation_target.shape}." + ) + """ + + feature_interpretation_dict = {} + + for feature in template.features: + + feature_class = feature.get_feature_class() + + if feature_class is None: + continue + + feature_img = supervision_result.warp_feature(feature, interpretation_target) + feature_img_mutated = feature.apply_mutators_to_image(feature_img) + interpretation = feature.interpret_image(feature_img_mutated) + + feature_interpretation_dict[feature.identifier] = interpretation + + return ExternalInterpretationResult(template, feature_interpretation_dict) diff --git a/src/officialeye/_internal/api/load.py b/src/officialeye/_internal/api/load.py new file mode 100644 index 0000000..53594b9 --- /dev/null +++ b/src/officialeye/_internal/api/load.py @@ -0,0 +1,10 @@ +from officialeye._internal.context.singleton import get_internal_context +from officialeye._internal.template.external_template import ExternalTemplate +from officialeye._internal.template.schema.loader import load_template + + +def template_load(template_path: str, /, **kwargs) -> ExternalTemplate: + + with get_internal_context().setup(**kwargs): + template = load_template(template_path) + return ExternalTemplate(template) diff --git a/src/officialeye/_internal/api_implementation.py b/src/officialeye/_internal/api_implementation.py new file mode 100644 index 0000000..1f9f6f6 --- /dev/null +++ b/src/officialeye/_internal/api_implementation.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.context import Context + + +class IApiInterfaceImplementation(ABC): + + @abstractmethod + def set_api_context(self, context: Context, /) -> None: + """ + This method should be used to propagate the public API's context to the objects returned by the internal implementation of the API. + Those objects are called 'external' and should be picklable if the API context has not yet been set via this method. + If it was, then all methods guaranteed by the corresponding object's public API interface can be implemented properly. + """ + raise NotImplementedError() + + @abstractmethod + def clear_api_context(self) -> None: + """ + This method should clear the public API's context stored in the current object, and in all internal objects implementing this interface. + It is essential that after running this method, the object is picklable. + """ + raise NotImplementedError() diff --git a/src/officialeye/_internal/config/__init__.py b/src/officialeye/_internal/config/__init__.py deleted file mode 100644 index 8543c2c..0000000 --- a/src/officialeye/_internal/config/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -Module for abstracting out the ability to inject custom configurations specified using dictionaries. -The goal of this module is to provide a nice API for validated user-specified configurations -and safely retrieving information from there. -""" diff --git a/src/officialeye/_internal/config/config.py b/src/officialeye/_internal/config/config.py deleted file mode 100644 index 15a8edd..0000000 --- a/src/officialeye/_internal/config/config.py +++ /dev/null @@ -1,34 +0,0 @@ -import abc -from typing import Callable, Dict - - -class Config(abc.ABC): - - def __init__(self, config_dict: Dict[str, any], /): - self._config_dict = config_dict - - self._value_preprocessors: Dict[str, Callable[[str], any]] = {} - - def set_value_preprocessor(self, key: str, preprocessor: Callable[[str], any], /): - self._value_preprocessors[key] = preprocessor - - @abc.abstractmethod - def _get_invalid_key_error(self, key: str, /): - raise NotImplementedError() - - def get(self, key: str, /, *, default=None): - - if key not in self._config_dict: - - if default is None: - raise self._get_invalid_key_error(key) - - return default - - _value = self._config_dict[key] - - # apply value preprocessor - if key in self._value_preprocessors: - _value = self._value_preprocessors[key](_value) - - return _value diff --git a/src/officialeye/_internal/context/context.py b/src/officialeye/_internal/context/context.py index eec6399..d6fb665 100644 --- a/src/officialeye/_internal/context/context.py +++ b/src/officialeye/_internal/context/context.py @@ -1,124 +1,158 @@ # needed to not break type annotations if we are not in type checking mode from __future__ import annotations -import os -import tempfile -from typing import TYPE_CHECKING, Dict, List, Union +from types import TracebackType +from typing import TYPE_CHECKING, Dict -import click -import cv2 -import z3 - -from officialeye._internal.error.error import OEError -from officialeye._internal.error.errors.internal import ErrInternal -from officialeye._internal.error.errors.template import ErrTemplateIdNotUnique -from officialeye._internal.logger.singleton import get_logger +from officialeye._internal.feedback.abstract import AbstractFeedbackInterface +from officialeye._internal.feedback.dummy import DummyFeedbackInterface +from officialeye.error.error import OEError +from officialeye.error.errors.general import ErrInvalidKey +from officialeye.error.errors.template import ErrTemplateIdNotUnique if TYPE_CHECKING: - from officialeye._internal.io.driver import IODriver - from officialeye._internal.template.template import Template - + # noinspection PyProtectedMember + # noinspection PyProtectedMember + from officialeye._api.mutator import IMutator + from officialeye._api.template.interpretation import IInterpretation -# TODO: move part of the Context class methods to IO driver + # noinspection PyProtectedMember + from officialeye._api.template.matcher import IMatcher -class Context: + # noinspection PyProtectedMember + from officialeye._api.template.supervisor import ISupervisor + from officialeye._internal.template.internal_template import InternalTemplate + from officialeye.types import ConfigDict, InterpretationFactory, MatcherFactory, MutatorFactory, SupervisorFactory - def __init__(self, manager, /, *, visualization_generation: bool = False): - self._manager = manager - self._io_driver: Union[IODriver, None] = None +class InternalContext: - self._visualization_generation = visualization_generation + def __init__(self): + self._afi = DummyFeedbackInterface() - self._export_counter = 1 - self._not_deleted_temporary_files: List[str] = [] + self._mutator_factories: Dict[str, MutatorFactory] = {} + self._matcher_factories: Dict[str, MatcherFactory] = {} + self._supervisor_factories: Dict[str, SupervisorFactory] = {} + self._interpretation_factories: Dict[str, InterpretationFactory] = {} # keys: template ids # values: template - self._loaded_templates: Dict[str, Template] = {} + self._loaded_templates: Dict[str, InternalTemplate] = {} + + # keys: paths to templates + # values: corresponding template ids + self._template_ids: Dict[str, str] = {} + + def setup(self, /, *, afi: AbstractFeedbackInterface, mutator_factories: Dict[str, MutatorFactory], + matcher_factories: Dict[str, MatcherFactory], supervisor_factories: Dict[str, SupervisorFactory], + interpretation_factories: Dict[str, InterpretationFactory]) -> InternalContext: + assert afi is not None + + assert mutator_factories is not None + assert matcher_factories is not None + assert supervisor_factories is not None + + self._afi = afi + self._mutator_factories = mutator_factories + self._matcher_factories = matcher_factories + self._supervisor_factories = supervisor_factories + self._interpretation_factories = interpretation_factories + + return self + + def __enter__(self): + return None - z3.set_param("parallel.enable", True) + def __exit__(self, exception_type: any, exception_value: BaseException | None, traceback: TracebackType | None): + # inform the parent process that the current task is done + self._afi.dispose(exception_type, exception_value, traceback) + self._afi = DummyFeedbackInterface() - def visualization_generation_enabled(self) -> bool: - return self._visualization_generation + def get_afi(self) -> AbstractFeedbackInterface: + return self._afi - def get_io_driver(self) -> IODriver: + def get_mutator(self, mutator_id: str, mutator_config: ConfigDict, /) -> IMutator: - if self._io_driver is None: - raise ErrInternal( - "while trying to access officialeye's IO driver.", - "The present officialeye context does not have an IO Driver set." + # TODO: (low priority) consider caching mutators that have the same id and configuration + + if mutator_id not in self._mutator_factories: + raise ErrInvalidKey( + f"while loading mutator '{mutator_id}'.", + "Unknown mutator. Has this mutator been properly loaded?" ) - return self._io_driver + return self._mutator_factories[mutator_id](mutator_config) + + def get_matcher(self, matcher_id: str, matcher_config: ConfigDict, /) -> IMatcher: - def set_io_driver(self, io_driver: IODriver, /): - assert io_driver is not None - self._io_driver = io_driver + # TODO: (low priority) consider caching matchers that have the same id and configuration + + if matcher_id not in self._matcher_factories: + raise ErrInvalidKey( + f"while loading matcher '{matcher_id}'.", + "Unknown matcher. Has this matcher been properly loaded?" + ) - def add_template(self, template: Template, /): + return self._matcher_factories[matcher_id](matcher_config) - if template.template_id in self._loaded_templates: + def get_supervisor(self, supervisor_id: str, supervisor_config: ConfigDict, /) -> ISupervisor: + + # TODO: (low priority) consider caching supervisors that have the same id and configuration + + if supervisor_id not in self._supervisor_factories: + raise ErrInvalidKey( + f"while loading supervisor '{supervisor_id}'.", + "Unknown supervisor. Has this supervisor been properly loaded?" + ) + + return self._supervisor_factories[supervisor_id](supervisor_config) + + def get_interpretation(self, interpretation_id: str, interpretation_config: ConfigDict, /) -> IInterpretation: + + # TODO: (low priority) consider caching interpretations that have the same id and configuration + + if interpretation_id not in self._interpretation_factories: + raise ErrInvalidKey( + f"while loading interpretation '{interpretation_id}'.", + "Unknown interpretation. Has this interpretation method been properly loaded?" + ) + + return self._interpretation_factories[interpretation_id](interpretation_config) + + def add_template(self, template: InternalTemplate, /): + + template_path = template.get_path() + + assert template_path not in self._template_ids, "A template from the same path has already been loaded" + + if template.identifier in self._loaded_templates: raise ErrTemplateIdNotUnique( - f"while loading template '{template.template_id}'", + f"while loading template '{template.identifier}'", "A template with the same id has already been loaded." ) - self._loaded_templates[template.template_id] = template + self._loaded_templates[template.identifier] = template + self._template_ids[template_path] = template.identifier try: template.validate() except OEError as err: # rollback the loaded template - del self._loaded_templates[template.template_id] + del self._loaded_templates[template.identifier] + del self._template_ids[template_path] + # reraise the cause raise err - def get_template(self, template_id: str, /) -> Template: + def get_template(self, template_id: str, /) -> InternalTemplate: assert template_id in self._loaded_templates, "Unknown template id" return self._loaded_templates[template_id] - def _allocate_file_name(self) -> str: - file_name = "%03d.png" % self._export_counter - self._export_counter += 1 - return file_name - - def allocate_file_for_export(self, /, *, file_name: str = "") -> str: - - if self._manager.export_directory is None: - with tempfile.NamedTemporaryFile(prefix="officialeye_", suffix=".png", delete=False) as fp: - fp.close() - self._not_deleted_temporary_files.append(fp.name) - return fp.name - - if file_name == "": - file_name = self._allocate_file_name() - - return os.path.join(self._manager.export_directory, file_name) - - def export_image(self, img: cv2.Mat, /, *, file_name: str = "") -> str: - export_file_path = self.allocate_file_for_export(file_name=file_name) - cv2.imwrite(export_file_path, img) - get_logger().info(f"Exported '{export_file_path}'.") - return export_file_path - - def _export_and_show_image(self, img: cv2.Mat, /, *, file_name: str = ""): - path = self.export_image(img, file_name=file_name) - click.launch(path, locate=False) - click.pause() - - def export_primary_image(self, img: cv2.Mat, /, *, file_name: str = ""): - if get_logger().quiet_mode: - # just save the image, do not export - self.export_image(img, file_name=file_name) - else: - self._export_and_show_image(img, file_name=file_name) - - def _cleanup_temporary_files(self): - while len(self._not_deleted_temporary_files) > 0: - temp_file = self._not_deleted_temporary_files.pop() - if os.path.exists(temp_file): - os.unlink(temp_file) - - def dispose(self): - self._cleanup_temporary_files() + def get_template_by_path(self, template_path: str, /) -> InternalTemplate | None: + + if template_path not in self._template_ids: + return None + + template_id = self._template_ids[template_path] + + return self.get_template(template_id) diff --git a/src/officialeye/_internal/context/feedback.py b/src/officialeye/_internal/context/feedback.py new file mode 100644 index 0000000..e2d24c4 --- /dev/null +++ b/src/officialeye/_internal/context/feedback.py @@ -0,0 +1,73 @@ +import enum +from concurrent.futures import Future + +# noinspection PyProtectedMember +from multiprocessing.connection import Connection +from types import TracebackType +from typing import Any + +# noinspection PyProtectedMember +from officialeye._internal.feedback.abstract import AbstractFeedbackInterface + +# noinspection PyProtectedMember +from officialeye._internal.feedback.verbosity import Verbosity + + +class IPCMessageType(enum.IntEnum): + ECHO = 0 + INFO = 1 + WARN = 2 + ERROR = 3 + + UPDATE_STATUS = 4 + # means that the task is completed in any way, including throwing an exception + TASK_DONE = 5 + + +class InternalFeedbackInterface(AbstractFeedbackInterface): + + def __init__(self, verbosity: Verbosity, child_id: int, tx: Connection, /): + super().__init__(verbosity) + + assert tx is not None + + self._child_id = child_id + self._tx = tx + + def get_child_id(self) -> int: + return self._child_id + + def _send_ipc_message(self, message_type: IPCMessageType, *args, **kwargs): + ipc_message = (message_type, args, kwargs) + self._tx.send(ipc_message) + + def echo(self, *args: Any, **kwargs: Any) -> None: + self._send_ipc_message(IPCMessageType.ECHO, *args, **kwargs) + + def info(self, *args: Any, **kwargs: Any) -> None: + self._send_ipc_message(IPCMessageType.INFO, *args, **kwargs) + + def warn(self, *args: Any, **kwargs: Any) -> None: + self._send_ipc_message(IPCMessageType.WARN, *args, **kwargs) + + def error(self, *args: Any, **kwargs: Any) -> None: + self._send_ipc_message(IPCMessageType.ERROR, *args, **kwargs) + + def update_status(self, new_status_text: str, /) -> None: + self._send_ipc_message(IPCMessageType.UPDATE_STATUS, new_status_text) + + def dispose(self, exception_type: any = None, exception_value: BaseException | None = None, traceback: TracebackType | None = None) -> None: + + task_done_successfully: bool = exception_value is None + + self._send_ipc_message(IPCMessageType.TASK_DONE, task_done_successfully) + + self._tx.close() + + def fork(self, description: str, /) -> AbstractFeedbackInterface: + # the internal feedback interface isn't meant to be forked + raise NotImplementedError() + + def join(self, child: AbstractFeedbackInterface, future: Future, /) -> None: + # the internal feedback interface isn't meant to be forked, so it cannot be joined either + raise NotImplementedError() diff --git a/src/officialeye/_internal/context/manager.py b/src/officialeye/_internal/context/manager.py deleted file mode 100644 index 8eddb75..0000000 --- a/src/officialeye/_internal/context/manager.py +++ /dev/null @@ -1,66 +0,0 @@ -from types import TracebackType -from typing import Union - -from officialeye._internal.context.context import Context -from officialeye._internal.error.error import OEError -from officialeye._internal.error.errors.internal import ErrInternal - - -class ContextManager: - - def __init__(self, /, *, handle_exceptions: bool = True, visualization_generation: bool = False, - export_directory: Union[str, None] = None): - - self._context: Union[Context, None] = None - - self.handle_exceptions = handle_exceptions - - self.visualization_generation = visualization_generation - - self.export_directory = export_directory - - def __enter__(self) -> Context: - - if self._context is not None: - raise ErrInternal( - "while entering an officialeye context.", - "The present context manager has already got an associated context. Are you trying to reuse the context manager?" - ) - - self._context = Context(self, visualization_generation=self.visualization_generation) - - return self._context - - def __exit__(self, exception_type: any, exception_value: BaseException, traceback: TracebackType): - - if self._context is None: - raise ErrInternal( - "while leaving an officialeye context.", - "The present context manager has no context associated with it." - ) - - self._context.dispose() - - if not self.handle_exceptions: - return - - # handle the possible exception - if exception_value is None: - # there is no exception, nothing to handle - return - - if isinstance(exception_value, OEError): - oe_error = exception_value - elif isinstance(exception_value, BaseException): - oe_error = ErrInternal( - "while leaving an officialeye context.", - "An internal error occurred.", - ) - oe_error.add_external_cause(exception_value) - else: - raise AssertionError() - - self._context.get_io_driver().handle_error(oe_error) - - # tell python that we have handled the exception ourselves - return True diff --git a/src/officialeye/_internal/context/singleton.py b/src/officialeye/_internal/context/singleton.py new file mode 100644 index 0000000..f27628d --- /dev/null +++ b/src/officialeye/_internal/context/singleton.py @@ -0,0 +1,13 @@ +from officialeye._internal.context.context import InternalContext +from officialeye._internal.feedback.abstract import AbstractFeedbackInterface + +_internal_context: InternalContext = InternalContext() + + +def get_internal_context() -> InternalContext: + global _internal_context + return _internal_context + + +def get_internal_afi() -> AbstractFeedbackInterface: + return get_internal_context().get_afi() diff --git a/src/officialeye/_internal/diffobject/difference_expansion.py b/src/officialeye/_internal/diffobject/difference_expansion.py index 82862ee..c96a24d 100644 --- a/src/officialeye/_internal/diffobject/difference_expansion.py +++ b/src/officialeye/_internal/diffobject/difference_expansion.py @@ -1,10 +1,11 @@ from typing import Dict +from officialeye._internal.context.singleton import get_internal_afi from officialeye._internal.diffobject.difference_modes import DIFF_MODE_ADD, DIFF_MODE_OVERRIDE, DIFF_MODE_REMOVE from officialeye._internal.diffobject.exception import DiffObjectException from officialeye._internal.diffobject.specification import DiffObjectSpecification from officialeye._internal.diffobject.specification_entry import DiffObjectSpecificationEntry -from officialeye._internal.logger.singleton import get_logger +from officialeye._internal.feedback.verbosity import Verbosity class DiffObjectExpansion: @@ -57,8 +58,11 @@ def _do_add(specification_dict: Dict[str, any], full_key = f"{previous_keys}{key}" - get_logger().debug_verbose(f"Key: '{full_key}' Specification value: {specification_entry} " - f"Object value: {object_value} Current value: {current_value}") + get_internal_afi().info( + Verbosity.DEBUG_VERBOSE, + f"Key: '{full_key}' Specification value: {specification_entry} " + f"Object value: {object_value} Current value: {current_value}" + ) if isinstance(specification_entry, dict): # the specification says that there is a nested dictionary at the present key. diff --git a/src/officialeye/_internal/error/errors/internal.py b/src/officialeye/_internal/error/errors/internal.py deleted file mode 100644 index 753601f..0000000 --- a/src/officialeye/_internal/error/errors/internal.py +++ /dev/null @@ -1,9 +0,0 @@ -from officialeye._internal.error.codes import ERR_INTERNAL -from officialeye._internal.error.error import OEError -from officialeye._internal.error.modules import ERR_MODULE_INTERNAL - - -class ErrInternal(OEError): - - def __init__(self, while_text: str, problem_text: str, /): - super().__init__(ERR_MODULE_INTERNAL, ERR_INTERNAL[0], ERR_INTERNAL[1], while_text, problem_text, is_regular=False) diff --git a/src/officialeye/_internal/feedback/__init__.py b/src/officialeye/_internal/feedback/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/officialeye/_internal/feedback/abstract.py b/src/officialeye/_internal/feedback/abstract.py new file mode 100644 index 0000000..88e770b --- /dev/null +++ b/src/officialeye/_internal/feedback/abstract.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from concurrent.futures import Future +from types import TracebackType +from typing import TYPE_CHECKING, Any + +from officialeye._internal.feedback.verbosity import Verbosity + +if TYPE_CHECKING: + from officialeye._internal._types import RichProtocol + + +class AbstractFeedbackInterface(ABC): + + def __init__(self, verbosity: Verbosity, /): + self._verbosity = verbosity + + @abstractmethod + def echo( + self, + verbosity: Verbosity, + message: str | RichProtocol = "", /, *, + err: bool = False, + **kwargs: Any + ) -> None: + raise NotImplementedError() + + @abstractmethod + def info(self, verbosity: Verbosity, message: str, /) -> None: + raise NotImplementedError() + + @abstractmethod + def warn(self, verbosity: Verbosity, message: str, /) -> None: + raise NotImplementedError() + + @abstractmethod + def error(self, verbosity: Verbosity, message: str, /) -> None: + raise NotImplementedError() + + @abstractmethod + def update_status(self, new_status_text: str, /) -> None: + raise NotImplementedError() + + @abstractmethod + def dispose(self, exception_type: any = None, exception_value: BaseException | None = None, traceback: TracebackType | None = None) -> None: + raise NotImplementedError() + + @abstractmethod + def fork(self, description: str, /) -> AbstractFeedbackInterface: + raise NotImplementedError() + + @abstractmethod + def join(self, child: AbstractFeedbackInterface, future: Future, /) -> None: + raise NotImplementedError() diff --git a/src/officialeye/_internal/feedback/dummy.py b/src/officialeye/_internal/feedback/dummy.py new file mode 100644 index 0000000..3b5d567 --- /dev/null +++ b/src/officialeye/_internal/feedback/dummy.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from concurrent.futures import Future +from types import TracebackType +from typing import TYPE_CHECKING, Any + +from officialeye._internal.feedback.abstract import AbstractFeedbackInterface +from officialeye._internal.feedback.verbosity import Verbosity + +if TYPE_CHECKING: + from officialeye._internal._types import RichProtocol + + +class DummyFeedbackInterface(AbstractFeedbackInterface): + + def __init__(self, /): + super().__init__(Verbosity.QUIET) + + def echo(self, verbosity: Verbosity, message: str | RichProtocol = "", /, *, err: bool = False, **kwargs: Any) -> None: + pass + + def info(self, verbosity: Verbosity, message: str, /) -> None: + pass + + def warn(self, verbosity: Verbosity, message: str, /) -> None: + pass + + def error(self, verbosity: Verbosity, message: str, /) -> None: + pass + + def update_status(self, new_status_text: str, /) -> None: + pass + + def dispose(self, exception_type: any = None, exception_value: BaseException | None = None, traceback: TracebackType | None = None) -> None: + pass + + def fork(self, description: str, /) -> AbstractFeedbackInterface: + return DummyFeedbackInterface() + + def join(self, child: AbstractFeedbackInterface, future: Future, /) -> None: + pass + + diff --git a/src/officialeye/_internal/feedback/verbosity.py b/src/officialeye/_internal/feedback/verbosity.py new file mode 100644 index 0000000..2a7b04f --- /dev/null +++ b/src/officialeye/_internal/feedback/verbosity.py @@ -0,0 +1,9 @@ +import enum + + +class Verbosity(enum.IntEnum): + QUIET = -1 + INFO = 0 + INFO_VERBOSE = 1 + DEBUG = 2 + DEBUG_VERBOSE = 3 diff --git a/src/officialeye/_internal/interpretation/__init__.py b/src/officialeye/_internal/interpretation/__init__.py deleted file mode 100644 index 4143b9b..0000000 --- a/src/officialeye/_internal/interpretation/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Module containing everything related to interpretation methods. -""" diff --git a/src/officialeye/_internal/interpretation/config.py b/src/officialeye/_internal/interpretation/config.py deleted file mode 100644 index 39564ca..0000000 --- a/src/officialeye/_internal/interpretation/config.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Dict - -from officialeye._internal.config.config import Config -from officialeye._internal.error.errors.template import ErrTemplateInvalidInterpretation - - -class InterpretationMethodConfig(Config): - - def __init__(self, config_dict: Dict[str, any], interpretation_method: str, /): - - super().__init__(config_dict) - - self._interpretation_method = interpretation_method - - def _get_invalid_key_error(self, key: str, /): - - return ErrTemplateInvalidInterpretation( - f"while reading configuration of the '{self._interpretation_method}' interpretation method.", - f"Could not find a value for key '{key}'." - ) diff --git a/src/officialeye/_internal/interpretation/loader.py b/src/officialeye/_internal/interpretation/loader.py deleted file mode 100644 index 9250e35..0000000 --- a/src/officialeye/_internal/interpretation/loader.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Dict - -from officialeye._internal.context.context import Context -from officialeye._internal.error.errors.template import ErrTemplateInvalidInterpretation -from officialeye._internal.interpretation.method import InterpretationMethod -from officialeye._internal.interpretation.methods.file import FileMethod -from officialeye._internal.interpretation.methods.file_temp import FileTempMethod -from officialeye._internal.interpretation.methods.ocr_tesseract import TesseractMethod - - -def load_interpretation_method(context: Context, method_id: str, config_dict: Dict[str, any], /) -> InterpretationMethod: - - if method_id == TesseractMethod.METHOD_ID: - return TesseractMethod(context, config_dict) - - if method_id == FileTempMethod.METHOD_ID: - return FileTempMethod(context, config_dict) - - if method_id == FileMethod.METHOD_ID: - return FileMethod(context, config_dict) - - raise ErrTemplateInvalidInterpretation( - f"while loading interpretation method '{method_id}'.", - "Unknown interpretation method id." - ) diff --git a/src/officialeye/_internal/interpretation/method.py b/src/officialeye/_internal/interpretation/method.py deleted file mode 100644 index f8dbdbd..0000000 --- a/src/officialeye/_internal/interpretation/method.py +++ /dev/null @@ -1,26 +0,0 @@ -import abc -from typing import Dict - -import cv2 - -from officialeye._internal.context.context import Context -from officialeye._internal.interpretation.config import InterpretationMethodConfig -from officialeye._internal.interpretation.serializable import Serializable - - -class InterpretationMethod(abc.ABC): - - def __init__(self, context: Context, method_id: str, config_dict: Dict[str, any], /): - super().__init__() - - self._context = context - self.method_id = method_id - - self._config = InterpretationMethodConfig(config_dict, method_id) - - def get_config(self) -> InterpretationMethodConfig: - return self._config - - @abc.abstractmethod - def interpret(self, feature_img: cv2.Mat, template_id: str, feature_id: str, /) -> Serializable: - raise NotImplementedError() diff --git a/src/officialeye/_internal/interpretation/methods/file.py b/src/officialeye/_internal/interpretation/methods/file.py deleted file mode 100644 index e6b784d..0000000 --- a/src/officialeye/_internal/interpretation/methods/file.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -from typing import Dict - -import cv2 - -from officialeye._internal.context.context import Context -from officialeye._internal.error.errors.template import ErrTemplateInvalidInterpretation -from officialeye._internal.interpretation.method import InterpretationMethod -from officialeye._internal.interpretation.serializable import Serializable - - -class FileMethod(InterpretationMethod): - - METHOD_ID = "file" - - def __init__(self, context: Context, config_dict: Dict[str, any]): - super().__init__(context, FileMethod.METHOD_ID, config_dict) - - self._path = self.get_config().get("path") - - def interpret(self, feature_img: cv2.Mat, template_id: str, feature_id: str, /) -> Serializable: - - feature = self._context.get_template(template_id).get_feature(feature_id) - - feature_class_generator = feature.get_feature_class().get_features() - - # check if the generator generates at least two elements - if sum(1 for _ in zip(range(2), feature_class_generator, strict=False)) == 2: - raise ErrTemplateInvalidInterpretation( - "while applying the '{FileMethod.METHOD_ID}' interpretation method.", - "This method cannot be applied if there are at least two features inheriting the feature class defining this method." - ) - - os.makedirs(os.path.dirname(self._path), exist_ok=True) - - cv2.imwrite(self._path, feature_img) - - return None diff --git a/src/officialeye/_internal/interpretation/methods/file_temp.py b/src/officialeye/_internal/interpretation/methods/file_temp.py deleted file mode 100644 index 1915436..0000000 --- a/src/officialeye/_internal/interpretation/methods/file_temp.py +++ /dev/null @@ -1,27 +0,0 @@ -import tempfile -from typing import Dict - -import cv2 - -from officialeye._internal.context.context import Context -from officialeye._internal.interpretation.method import InterpretationMethod -from officialeye._internal.interpretation.serializable import Serializable - - -class FileTempMethod(InterpretationMethod): - - METHOD_ID = "file_temp" - - def __init__(self, context: Context, config_dict: Dict[str, any]): - super().__init__(context, FileTempMethod.METHOD_ID, config_dict) - - self._format = self.get_config().get("format", default="png") - - def interpret(self, feature_img: cv2.Mat, template_id: str, feature_id: str, /) -> Serializable: - - with tempfile.NamedTemporaryFile(prefix="officialeye_", suffix=f".{self._format}", delete=False) as fp: - fp.close() - - cv2.imwrite(fp.name, feature_img) - - return fp.name diff --git a/src/officialeye/_internal/interpretation/methods/ocr_tesseract.py b/src/officialeye/_internal/interpretation/methods/ocr_tesseract.py deleted file mode 100644 index 5151e17..0000000 --- a/src/officialeye/_internal/interpretation/methods/ocr_tesseract.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Dict - -import cv2 -from pytesseract import pytesseract - -from officialeye._internal.context.context import Context -from officialeye._internal.interpretation.method import InterpretationMethod -from officialeye._internal.interpretation.serializable import Serializable - - -class TesseractMethod(InterpretationMethod): - - METHOD_ID = "ocr_tesseract" - - def __init__(self, context: Context, config_dict: Dict[str, any]): - super().__init__(context, TesseractMethod.METHOD_ID, config_dict) - - self._tesseract_lang = self.get_config().get("lang", default="eng") - self._tesseract_config = self.get_config().get("config", default="") - - def interpret(self, feature_img: cv2.Mat, template_id: str, feature_id: str, /) -> Serializable: - return pytesseract.image_to_string(feature_img, lang=self._tesseract_lang, config=self._tesseract_config).strip() diff --git a/src/officialeye/_internal/interpretation/serializable.py b/src/officialeye/_internal/interpretation/serializable.py deleted file mode 100644 index 4e87a6b..0000000 --- a/src/officialeye/_internal/interpretation/serializable.py +++ /dev/null @@ -1,3 +0,0 @@ -from typing import TypeAlias - -Serializable: TypeAlias = dict[str, "Serializable"] | list["Serializable"] | str | int | float | bool | None diff --git a/src/officialeye/_internal/io/__init__.py b/src/officialeye/_internal/io/__init__.py deleted file mode 100644 index b8cee4b..0000000 --- a/src/officialeye/_internal/io/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -Module responsible for handling IO (input and output). -In this context, IO means everything related to the way we output complex data. -Simply logging does not count as -""" \ No newline at end of file diff --git a/src/officialeye/_internal/io/driver.py b/src/officialeye/_internal/io/driver.py deleted file mode 100644 index 2c3c19f..0000000 --- a/src/officialeye/_internal/io/driver.py +++ /dev/null @@ -1,27 +0,0 @@ -import abc -from abc import ABC - -import cv2 - -from officialeye._internal.context.context import Context -from officialeye._internal.error.error import OEError -from officialeye._internal.supervision.result import SupervisionResult -from officialeye._internal.template.template import Template - - -class IODriver(ABC): - - def __init__(self, context: Context): - self._context = context - - @abc.abstractmethod - def handle_supervision_result(self, target: cv2.Mat, result: SupervisionResult, /): - raise NotImplementedError() - - @abc.abstractmethod - def handle_show_result(self, template: Template, img: cv2.Mat, /): - raise NotImplementedError() - - @abc.abstractmethod - def handle_error(self, error: OEError, /): - raise NotImplementedError() diff --git a/src/officialeye/_internal/io/drivers/__init__.py b/src/officialeye/_internal/io/drivers/__init__.py deleted file mode 100644 index 5c41285..0000000 --- a/src/officialeye/_internal/io/drivers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -This module contains all IO Driver's built into OfficialEye. -""" \ No newline at end of file diff --git a/src/officialeye/_internal/io/drivers/run.py b/src/officialeye/_internal/io/drivers/run.py deleted file mode 100644 index 61b7f25..0000000 --- a/src/officialeye/_internal/io/drivers/run.py +++ /dev/null @@ -1,64 +0,0 @@ -import json -import sys - -import cv2 - -from officialeye._internal.context.context import Context -from officialeye._internal.error.error import OEError -from officialeye._internal.error.errors.io import ErrIOOperationNotSupportedByDriver -from officialeye._internal.io.driver import IODriver -from officialeye._internal.supervision.result import SupervisionResult -from officialeye._internal.template.template import Template - - -def _output_dict(d: dict): - json.dump(d, sys.stdout, indent=4, ensure_ascii=False) - sys.stdout.write("\n") - sys.stdout.flush() - - -class RunIODriver(IODriver): - - def __init__(self, context: Context, /): - super().__init__(context) - - def handle_show_result(self, template: Template, img: cv2.Mat, /): - raise ErrIOOperationNotSupportedByDriver( - f"while trying to output the result of showing the template '{template.template_id}'", - "Driver 'run' does not support this operation." - ) - - def handle_error(self, error: OEError, /): - _output_dict({ - "ok": False, - "err": error.serialize() - }) - - def handle_supervision_result(self, target: cv2.Mat, result: SupervisionResult, /): - - assert result is not None - - template = self._context.get_template(result.template_id) - - feature_interpretation_dict = {} - - # extract the features from the target image - for feature in template.features(): - - feature_class = feature.get_feature_class() - - if feature_class is None: - continue - - feature_img = result.get_feature_warped_region(target, feature) - feature_img_mutated = feature.apply_mutators_to_image(feature_img) - interpretation = feature.interpret_image(feature_img_mutated) - - feature_interpretation_dict[feature.region_id] = interpretation - - _output_dict({ - "ok": True, - "template": result.template_id, - "score": result.get_score(), - "features": feature_interpretation_dict - }) diff --git a/src/officialeye/_internal/io/drivers/test.py b/src/officialeye/_internal/io/drivers/test.py deleted file mode 100644 index 08d88c4..0000000 --- a/src/officialeye/_internal/io/drivers/test.py +++ /dev/null @@ -1,55 +0,0 @@ -import cv2 - -from officialeye._internal.context.context import Context -from officialeye._internal.error.error import OEError -from officialeye._internal.io.driver import IODriver -from officialeye._internal.logger.singleton import get_logger -from officialeye._internal.supervision.result import SupervisionResult -from officialeye._internal.template.template import Template - - -class TestIODriver(IODriver): - - def __init__(self, context: Context, /): - super().__init__(context) - - self.visualize_features: bool = True - - def handle_supervision_result(self, target: cv2.Mat, result: SupervisionResult, /): - - assert result is not None - - template = self._context.get_template(result.template_id) - - application_image = template.load_source_image() - - # extract the features from the target image - for feature in template.features(): - feature_img = result.get_feature_warped_region(target, feature) - - feature_img_mutated = feature.apply_mutators_to_image(feature_img) - - if feature_img.shape == feature_img_mutated.shape: - # mutators didn't change the shape of the image - feature.insert_into_image(application_image, feature_img_mutated) - else: - # some mutator has altered the shape of the feature image. - # this means that we can no longer safely insert the mutated feature into the visualization. - # therefore, we have to fall back to inserting the feature image unmutated - get_logger().warn(f"Could not visualize the '{feature.region_id}' feature of the '{feature.get_template().template_id}' template, " - f"because one of the mutators (corresponding to this feature) did not preserve the shape of the image. " - f"Falling back to the non-mutated version of the feature image.") - feature.insert_into_image(application_image, feature_img) - - if self.visualize_features: - # visualize features on the image - for feature in template.features(): - application_image = feature.visualize(application_image) - - self._context.export_primary_image(application_image, file_name="supervision_result.png") - - def handle_show_result(self, template: Template, img: cv2.Mat, /): - self._context.export_primary_image(img, file_name=f"{template.template_id}.png") - - def handle_error(self, error: OEError, /): - get_logger().error_oe_error(error) diff --git a/src/officialeye/_internal/logger/__init__.py b/src/officialeye/_internal/logger/__init__.py deleted file mode 100644 index c01651d..0000000 --- a/src/officialeye/_internal/logger/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Module responsible for all CLI logging. -""" \ No newline at end of file diff --git a/src/officialeye/_internal/logger/logger.py b/src/officialeye/_internal/logger/logger.py deleted file mode 100644 index b80f50b..0000000 --- a/src/officialeye/_internal/logger/logger.py +++ /dev/null @@ -1,114 +0,0 @@ -from typing import Callable - -import click - -from officialeye._internal.error.error import OEError -from officialeye.meta import OFFICIALEYE_CLI_LOGO - - -def _do_print_oe_error(output_func: Callable[[str], None], error: OEError, /): - output_func(f"Error {error.code} in module {error.module}: {error.code_text}") - output_func(f"Error occurred {error.while_text}") - output_func(f"Problem: {error.problem_text}") - - causes = error.get_causes() - external_causes = error.get_external_causes() - - if len(causes) + len(external_causes) > 0: - output_func("The above error has been caused by the following error(s).") - - for cause in causes: - _do_print_oe_error(output_func, cause) - - for external_cause in external_causes: - output_func(str(external_cause)) - - -class Logger: - - def __init__(self, /, *, debug_mode: bool = False, quiet_mode: bool = False, verbose_mode: bool = False, disable_logo: bool = False): - - self.debug_mode = debug_mode - self.quiet_mode = quiet_mode - self.verbose_mode = verbose_mode - self.disable_logo = disable_logo - - def debug(self, msg, *args, prefix: bool = True, **kwargs): - - if not self.debug_mode: - return - - if self.quiet_mode: - return - - if prefix: - click.secho("DEBUG", bold=True, bg="yellow", nl=False) - click.echo(" ", nl=False) - - click.secho(msg, *args, **kwargs) - - def info(self, msg: any, *args, prefix: bool = True, **kwargs): - - if self.quiet_mode: - return - - if prefix: - click.secho("INFO", bold=True, bg="blue", nl=False) - click.echo(" ", nl=False) - - click.secho(msg, *args, **kwargs) - - def warn(self, msg: any, *args, prefix: bool = True, **kwargs): - - if self.quiet_mode: - return - - if prefix: - click.secho("WARN", bold=True, bg="yellow", nl=False) - click.echo(" ", nl=False) - - click.secho(msg, *args, **kwargs) - - def debug_verbose(self, msg: any, *args, prefix: bool = True, **kwargs): - - if not self.debug_mode: - return - - if self.quiet_mode: - return - - if not self.verbose_mode: - return - - if prefix: - click.secho("DEBUG", bold=True, bg="yellow", nl=False) - click.echo(" ", nl=False) - - click.secho(msg, *args, **kwargs) - - def error(self, msg: any, *args, prefix: bool = True, **kwargs): - - if self.quiet_mode: - return - - for line in msg.splitlines(): - if prefix: - click.secho("ERROR", bold=True, bg="red", nl=False, err=True) - click.echo(" ", nl=False, err=True) - click.secho(line, *args, **kwargs, err=True) - - def logo(self): - - if self.quiet_mode: - return - - if self.disable_logo: - return - - click.secho(OFFICIALEYE_CLI_LOGO, fg="red") - - def error_oe_error(self, error: OEError, /): - _do_print_oe_error(self.error, error) - - def debug_oe_error(self, error: OEError, /): - _do_print_oe_error(self.debug, error) diff --git a/src/officialeye/_internal/logger/singleton.py b/src/officialeye/_internal/logger/singleton.py deleted file mode 100644 index e0faa25..0000000 --- a/src/officialeye/_internal/logger/singleton.py +++ /dev/null @@ -1,7 +0,0 @@ -from officialeye._internal.logger.logger import Logger - -_logger = Logger() - - -def get_logger() -> Logger: - return _logger diff --git a/src/officialeye/_internal/main.py b/src/officialeye/_internal/main.py deleted file mode 100644 index 02fc024..0000000 --- a/src/officialeye/_internal/main.py +++ /dev/null @@ -1,184 +0,0 @@ -""" -OfficialEye main entry point. -""" - -from typing import List, Union - -import click -import cv2 - -from officialeye._internal.context.manager import ContextManager -from officialeye._internal.io.drivers.run import RunIODriver -from officialeye._internal.io.drivers.test import TestIODriver -from officialeye._internal.logger.singleton import get_logger -from officialeye._internal.template.analyze import do_analyze -from officialeye._internal.template.create import create_example_template_config_file -from officialeye._internal.template.schema.loader import load_template -from officialeye.meta import OFFICIALEYE_GITHUB, OFFICIALEYE_VERSION - -_context_manager: ContextManager = ContextManager() - - -@click.group() -@click.option("-d", "--debug", is_flag=True, show_default=True, default=False, help="Enable debug mode.") -@click.option("--edir", type=click.Path(exists=True, file_okay=True, readable=True), help="Specify export directory.") -@click.option("-q", "--quiet", is_flag=True, show_default=True, default=False, help="Disable standard output messages.") -@click.option("-v", "--verbose", is_flag=True, show_default=True, default=False, help="Enable verbose logging.") -@click.option("-dl", "--disable-logo", is_flag=True, show_default=True, default=False, help="Disable the officialeye logo.") -@click.option("-re", "--raw-errors", is_flag=True, show_default=False, default=False, help="Do not handle errors.") -def cli(debug: bool, edir: str, quiet: bool, verbose: bool, disable_logo: bool, raw_errors: bool): - global _context_manager - - # configure logger - get_logger().debug_mode = debug - get_logger().quiet_mode = quiet - get_logger().verbose_mode = verbose - get_logger().disable_logo = disable_logo - - # print OfficialEye logo if necessary - get_logger().logo() - - # configure context manager - if edir is not None: - _context_manager.export_directory = edir - - if raw_errors: - get_logger().warn("Raw error mode enabled. Use this mode only if you know precisely what you are doing!") - _context_manager.handle_exceptions = False - - # print preliminary warning if necessary - if get_logger().debug_mode: - get_logger().warn("Debug mode enabled. Disable for production use to improve performance.") - - -# noinspection PyShadowingBuiltins -@click.command() -@click.argument("template_path", type=click.Path(exists=False, file_okay=True, readable=True, writable=True)) -@click.argument("template_image", type=click.Path(exists=True, file_okay=True, readable=True, writable=False)) -@click.option("--id", type=str, show_default=False, default="example", help="Specify the template identifier.") -@click.option("--name", type=str, show_default=False, default="Example", help="Specify the template name.") -@click.option("--force", is_flag=True, show_default=True, default=False, help="Create missing directories and overwrite file.") -def create(template_path: str, template_image: str, id: str, name: str, force: bool): - """Creates a new template configuration file at the specified path.""" - create_example_template_config_file(template_path, template_image, id, name, force) - - -@click.command() -@click.argument("template_path", type=click.Path(exists=True, file_okay=True, readable=True)) -@click.option("--hide-features", is_flag=True, show_default=False, default=False, help="Do not visualize the locations of features.") -@click.option("--hide-keypoints", is_flag=True, show_default=False, default=False, help="Do not visualize the locations of keypoints.") -def show(template_path: str, hide_features: bool, hide_keypoints: bool): - """Exports template as an image with features visualized.""" - - global _context_manager - - with _context_manager as oe_context: - # setup IO driver - oe_context.set_io_driver(TestIODriver(oe_context)) - - # load template - template = load_template(oe_context, template_path) - - # render resulting image - img = template.show(hide_features=hide_features, hide_keypoints=hide_keypoints) - - # show rendered image - oe_context.get_io_driver().handle_show_result(template, img) - - -@click.command() -@click.argument("target_path", type=click.Path(exists=True, file_okay=True, readable=True)) -@click.argument("template_paths", type=click.Path(exists=True, file_okay=True, readable=True), nargs=-1) -@click.option("--workers", type=int, default=4, show_default=True, help="Specify number of threads to use for the pool of workers.") -@click.option("--interpret", type=click.Path(exists=True, file_okay=True, readable=True), - default=None, help="Use the image at the specified path to run the interpretation phase.") -@click.option("--show-features", is_flag=True, show_default=False, default=False, help="Visualize the locations of features.") -@click.option("--visualize", is_flag=True, show_default=False, default=False, help="Generate visualizations of intermediate steps.") -def test(target_path: str, template_paths: List[str], workers: int, interpret: Union[str, None], show_features: bool, visualize: bool): - """Visualizes the analysis of an image using one or more templates.""" - - global _context_manager - - _context_manager.visualization_generation = visualize - - if _context_manager.visualization_generation: - get_logger().warn("Visualization generation mode enabled. Disable for production use to improve performance.") - - with (_context_manager as oe_context): - - # setup IO driver - io_driver = TestIODriver(oe_context) - io_driver.visualize_features = show_features - oe_context.set_io_driver(io_driver) - - # load target image - target = cv2.imread(target_path, cv2.IMREAD_COLOR) - - # load interpretation target image if necessary - interpretation_target: Union[cv2.Mat, None] = \ - None if interpret is None else cv2.imread(interpret, cv2.IMREAD_COLOR) - - # load templates - templates = [load_template(oe_context, template_path) for template_path in template_paths] - - # perform analysis - do_analyze(oe_context, target, templates, num_workers=workers, interpretation_target=interpretation_target) - - -@click.command() -@click.argument("target_path", type=click.Path(exists=True, file_okay=True, readable=True)) -@click.argument("template_paths", type=click.Path(exists=True, file_okay=True, readable=True), nargs=-1) -@click.option("--workers", type=int, default=4, show_default=True, help="Specify number of threads to use for the pool of workers.") -@click.option("--interpret", type=click.Path(exists=True, file_okay=True, readable=True), - default=None, help="Use the image at the specified path to run the interpretation phase.") -@click.option("--visualize", is_flag=True, show_default=False, default=False, help="Generate visualizations of intermediate steps.") -def run(target_path: str, template_paths: List[str], workers: int, interpret: Union[str, None], visualize: bool): - """Applies one or more templates to an image.""" - - global _context_manager - - _context_manager.visualization_generation = visualize - - if _context_manager.visualization_generation: - get_logger().warn("Visualization generation mode enabled. Disable for production use to improve performance.") - - with _context_manager as oe_context: - # setup IO driver - oe_context.set_io_driver(RunIODriver(oe_context)) - - # load target image - target = cv2.imread(target_path, cv2.IMREAD_COLOR) - - # load interpretation target image if necessary - interpretation_target: Union[cv2.Mat, None] = \ - None if interpret is None else cv2.imread(interpret, cv2.IMREAD_COLOR) - - # load templates - templates = [load_template(oe_context, template_path) for template_path in template_paths] - - # perform analysis - do_analyze(oe_context, target, templates, num_workers=workers, interpretation_target=interpretation_target) - - -@click.command() -def homepage(): - """Go to the officialeye's official GitHub homepage.""" - get_logger().info(f"Opening {OFFICIALEYE_GITHUB}") - click.launch(OFFICIALEYE_GITHUB) - - -@click.command() -def version(): - """Print the version of OfficialEye.""" - get_logger().info(f"Version: {OFFICIALEYE_VERSION}") - - -cli.add_command(create) -cli.add_command(show) -cli.add_command(test) -cli.add_command(run) -cli.add_command(homepage) -cli.add_command(version) - -if __name__ == "__main__": - cli() diff --git a/src/officialeye/_internal/matching/__init__.py b/src/officialeye/_internal/matching/__init__.py deleted file mode 100644 index abfa32c..0000000 --- a/src/officialeye/_internal/matching/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Module handling everything related to matching, i.e., to the process of finding correspondences -between the template and a given image. -""" \ No newline at end of file diff --git a/src/officialeye/_internal/matching/match.py b/src/officialeye/_internal/matching/match.py deleted file mode 100644 index de40feb..0000000 --- a/src/officialeye/_internal/matching/match.py +++ /dev/null @@ -1,75 +0,0 @@ -import numpy as np - -from officialeye._internal.context.context import Context -from officialeye._internal.template.region.keypoint import TemplateKeypoint - - -class Match: - - def __init__(self, context: Context, template_id: str, keypoint_region_id: str, - region_point: np.ndarray, target_point: np.ndarray, /, *, score: float = 0.0): - - self._context = context - - self.template_id = template_id - self.keypoint_id = keypoint_region_id - - assert region_point.shape[0] == 2 - assert target_point.shape[0] == 2 - - self._region_point = region_point - self._target_point = target_point - - self._score = score - - def set_score(self, new_score: float, /): - self._score = new_score - - def get_score(self) -> float: - return self._score - - def get_template_point(self) -> np.ndarray: - return self._region_point.copy() - - def get_template(self): - return self._context.get_template(self.template_id) - - def get_keypoint(self) -> TemplateKeypoint: - return self.get_template().get_keypoint(self.keypoint_id) - - def get_original_template_point(self) -> np.ndarray: - """Returns the coordinates of the point lying in the keypoint, in the coordinate system of the underlying template.""" - return self._region_point + self.get_keypoint().get_top_left_vec() - - def get_target_point(self) -> np.ndarray: - return self._target_point.copy() - - def __lt__(self, other) -> bool: - assert isinstance(other, Match) - return self.get_score() < other.get_score() - - def __eq__(self, o): - if not isinstance(o, Match): - return False - if self.template_id != o.template_id: - return False - if self.keypoint_id != o.keypoint_id: - return False - return (np.array_equal(self._region_point, o._region_point) - and np.array_equal(self._target_point, o._target_point)) - - def __ne__(self, __value): - return not self == __value - - def __hash__(self): - return hash((self.template_id, self.keypoint_id, np.dot(self._region_point, self._target_point))) - - def __str__(self) -> str: - return "%s_%s: (%4d, %4d) <-> (%4d, %4d)" % (self.template_id, self.keypoint_id, - int(self._region_point[0]), int(self._region_point[1]), - int(self._target_point[0]), int(self._target_point[1])) - - def get_debug_identifier(self) -> str: - return "%s_%s_%04d_%04d_%04d_%04d" % (self.template_id, self.keypoint_id, - int(self._region_point[0]), int(self._region_point[1]), - int(self._target_point[0]), int(self._target_point[1])) diff --git a/src/officialeye/_internal/matching/matcher.py b/src/officialeye/_internal/matching/matcher.py deleted file mode 100644 index fd26917..0000000 --- a/src/officialeye/_internal/matching/matcher.py +++ /dev/null @@ -1,54 +0,0 @@ -import abc -from abc import ABC - -import cv2 - -from officialeye._internal.context.context import Context -from officialeye._internal.logger.singleton import get_logger -from officialeye._internal.matching.matcher_config import KeypointMatcherConfig -from officialeye._internal.matching.result import MatchingResult - - -class Matcher(ABC): - # TODO: migrate matcher to a separate module - - def __init__(self, context: Context, engine_id: str, template_id: str, img: cv2.Mat, /): - super().__init__() - - self._context = context - self._engine_id = engine_id - - self.template_id = template_id - - # retreive configurations for all keypoint matching engines - matching_config = self.get_template().get_matching_config() - - assert isinstance(matching_config, dict) - - # get the configuration for the particular engine of interest - if self._engine_id in matching_config: - config_dict = matching_config[self._engine_id] - else: - get_logger().warn( - self._context, - f"Could not find any configuration entries for the '{self._engine_id}' matching engine that is being used." - ) - config_dict = {} - - self._config = KeypointMatcherConfig(config_dict, engine_id) - - self._img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - - @abc.abstractmethod - def match_keypoint(self, pattern: cv2.Mat, keypoint_id: str, /): - raise NotImplementedError() - - @abc.abstractmethod - def match_finish(self) -> MatchingResult: - raise NotImplementedError() - - def get_template(self): - return self._context.get_template(self.template_id) - - def get_config(self) -> KeypointMatcherConfig: - return self._config diff --git a/src/officialeye/_internal/matching/matcher_config.py b/src/officialeye/_internal/matching/matcher_config.py deleted file mode 100644 index 003f216..0000000 --- a/src/officialeye/_internal/matching/matcher_config.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Dict - -from officialeye._internal.config.config import Config -from officialeye._internal.error.errors.matching import ErrMatchingInvalidEngineConfig - - -class KeypointMatcherConfig(Config): - - def __init__(self, config_dict: Dict[str, any], matching_engine_id: str, /): - super().__init__(config_dict) - - self._matching_engine_id = matching_engine_id - - def _get_invalid_key_error(self, key: str, /): - return ErrMatchingInvalidEngineConfig( - f"while reading configuration of the '{self._matching_engine_id}' matching engine.", - f"Could not find a value for key '{key}'." - ) diff --git a/src/officialeye/_internal/matching/matchers/sift_flann.py b/src/officialeye/_internal/matching/matchers/sift_flann.py deleted file mode 100644 index 51d1d43..0000000 --- a/src/officialeye/_internal/matching/matchers/sift_flann.py +++ /dev/null @@ -1,112 +0,0 @@ -import cv2 -import numpy as np - -from officialeye._internal.context.context import Context -from officialeye._internal.error.errors.matching import ErrMatchingInvalidEngineConfig -from officialeye._internal.matching.match import Match -from officialeye._internal.matching.matcher import Matcher -from officialeye._internal.matching.result import MatchingResult - -_FLANN_INDEX_KDTREE = 1 - - -class SiftFlannMatcher(Matcher): - - ENGINE_ID = "sift_flann" - - def __init__(self, context: Context, template_id: str, img: cv2.Mat, /): - super().__init__(context, SiftFlannMatcher.ENGINE_ID, template_id, img) - - def _preprocess_sensitivity(value: any) -> float: - - value = float(value) - - if value < 0.0: - raise ErrMatchingInvalidEngineConfig( - f"while loading the '{SiftFlannMatcher.ENGINE_ID}' keypoint matcher", - f"The `sensitivity` value ({self._sensitivity}) cannot be negative." - ) - - if value > 1.0: - raise ErrMatchingInvalidEngineConfig( - f"while loading the '{SiftFlannMatcher.ENGINE_ID}' keypoint matcher", - f"The `sensitivity` value ({self._sensitivity}) cannot exceed 1.0." - ) - - return value - - self.get_config().set_value_preprocessor("sensitivity", _preprocess_sensitivity) - - self._sensitivity = self.get_config().get("sensitivity", default=0.7) - - self._debug_images = [] - self._result = MatchingResult(self._context, template_id) - - # initialize the SIFT engine in CV2 - # noinspection PyUnresolvedReferences - self._sift = cv2.SIFT_create() - - # pre-compute the sift keypoints in the target image - self._keypoints_target, self._destination_target = self._sift.detectAndCompute(self._img, None) - - def match_keypoint(self, pattern: cv2.Mat, keypoint_id: str, /): - - pattern = cv2.cvtColor(pattern, cv2.COLOR_BGR2GRAY) - - keypoints_pattern, destination_pattern = self._sift.detectAndCompute(pattern, None) - - index_params = { - "algorithm": _FLANN_INDEX_KDTREE, - "trees": 5 - } - - search_params = { - "checks": 50 - } - - flann = cv2.FlannBasedMatcher(index_params, search_params) - matches = flann.knnMatch(destination_pattern, self._destination_target, k=2) - - # we need to draw only good matches, so create a mask - matches_mask = [[0, 0] for _ in range(len(matches))] - - # filter matches - for i, (m, n) in enumerate(matches): - - if m.distance >= self._sensitivity * n.distance: - continue - - matches_mask[i] = [1, 0] - - pattern_point = keypoints_pattern[m.queryIdx].pt - target_point = self._keypoints_target[m.trainIdx].pt - - # maybe one should consider rounding values here, instead of simply stripping the floating-point part - pattern_point = np.array(pattern_point, dtype=int) - target_point = np.array(target_point, dtype=int) - - match = Match(self._context, self.template_id, keypoint_id, pattern_point, target_point) - match.set_score(self._sensitivity * n.distance - m.distance) - - self._result.add_match(match) - - if self._context.visualization_generation_enabled(): - - # noinspection PyTypeChecker - debug_image = cv2.drawMatchesKnn( - pattern, - keypoints_pattern, - self._img, - self._keypoints_target, - matches, - None, - matchColor=(0, 0xff, 0), - singlePointColor=(0xff, 0, 0), - matchesMask=matches_mask, - flags=cv2.DrawMatchesFlags_DEFAULT - ) - - self._context.export_image(debug_image, file_name=f"match_{keypoint_id}.png") - - def match_finish(self): - return self._result diff --git a/src/officialeye/_internal/matching/result.py b/src/officialeye/_internal/matching/result.py deleted file mode 100644 index a87b501..0000000 --- a/src/officialeye/_internal/matching/result.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Dict, List - -from officialeye._internal.context.context import Context -from officialeye._internal.error.errors.matching import ErrMatchingMatchCountOutOfBounds -from officialeye._internal.logger.singleton import get_logger -from officialeye._internal.matching.match import Match - - -class MatchingResult: - - def __init__(self, context: Context, template_id: str, /): - self._context = context - self._template_id = template_id - - # keys: keypoint ids - # values: matches with this keypoint - self._matches_dict: Dict[str, List[Match]] = {} - - for keypoint in self.get_template().keypoints(): - self._matches_dict[keypoint.region_id] = [] - - def remove_all_matches(self): - self._matches_dict = {} - - def add_match(self, match: Match, /): - assert match.keypoint_id in self._matches_dict - self._matches_dict[match.keypoint_id].append(match) - - def get_matches(self): - for keypoint_id in self._matches_dict: - for match in self._matches_dict[keypoint_id]: - yield match - - def get_total_match_count(self) -> int: - match_count = 0 - for keypoint_id in self._matches_dict: - match_count += len(self._matches_dict[keypoint_id]) - return match_count - - def get_keypoint_ids(self): - for keypoint_id in self._matches_dict: - yield keypoint_id - - def matches_for_keypoint(self, keypoint_id: str, /): - for match in self._matches_dict[keypoint_id]: - yield match - - def get_template(self): - return self._context.get_template(self._template_id) - - def validate(self): - - get_logger().debug("Validating the keypoint matching result.") - - assert len(self._matches_dict) > 0 - - total_match_count = 0 - - # verify that for every keypoint, it has been matched a number of times that is in the desired bounds - for keypoint_id in self._matches_dict: - keypoint = self.get_template().get_keypoint(keypoint_id) - - keypoint_matches_min = keypoint.get_matches_min() - keypoint_matches_max = keypoint.get_matches_max() - - keypoint_matches_count = len(self._matches_dict[keypoint_id]) - - if keypoint_matches_count < keypoint_matches_min: - raise ErrMatchingMatchCountOutOfBounds( - f"while checking that keypoint '{keypoint_id}' of template '{self._template_id}' " - f"has been matched a sufficient number of times", - f"Expected at least {keypoint_matches_min} matches, got {keypoint_matches_count}" - ) - - if keypoint_matches_count > keypoint_matches_max: - get_logger().debug( - f"Keypoint '{keypoint_id}' of template '{self._template_id}' has too many matches " - f"(matches: {keypoint_matches_count} max: {keypoint_matches_max}). Cherry-picking the best matches.") - # cherry-pick the best matches - self._matches_dict[keypoint_id] = sorted(self._matches_dict[keypoint_id])[:keypoint_matches_max] - keypoint_matches_count = keypoint_matches_max - - get_logger().debug(f"Keypoint '{keypoint_id}' of template '{self._template_id}' has been matched {keypoint_matches_count} times " - f"(min: {keypoint_matches_min} max: {keypoint_matches_max}).") - - total_match_count += keypoint_matches_count - - assert total_match_count >= 0 - if total_match_count == 0: - raise ErrMatchingMatchCountOutOfBounds( - f"while checking that there has been at least one match for template '{self._template_id}'.", - "There have been no matches." - ) - - def debug_print(self): - get_logger().debug(f"Found {self.get_total_match_count()} matched points!") - - get_logger().debug_verbose("Listing matched points:") - for match in self.get_matches(): - get_logger().debug_verbose(f"> {match}") diff --git a/src/officialeye/_internal/mutation/__init__.py b/src/officialeye/_internal/mutation/__init__.py deleted file mode 100644 index 3e2a45c..0000000 --- a/src/officialeye/_internal/mutation/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Module capturing everything related to mutation/mutators. -""" \ No newline at end of file diff --git a/src/officialeye/_internal/mutation/config.py b/src/officialeye/_internal/mutation/config.py deleted file mode 100644 index 12c0881..0000000 --- a/src/officialeye/_internal/mutation/config.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import Dict - -from officialeye._internal.config.config import Config -from officialeye._internal.error.errors.template import ErrTemplateInvalidMutator - - -class MutatorConfig(Config): - - def __init__(self, config_dict: Dict[str, any], mutator_id: str, /): - - super().__init__(config_dict) - - self._mutator_id = mutator_id - - def _get_invalid_key_error(self, key: str, /): - return ErrTemplateInvalidMutator( - f"while reading configuration of the '{self._mutator_id}' mutator.", - f"Could not find a value for key '{key}'." - ) diff --git a/src/officialeye/_internal/mutation/loader.py b/src/officialeye/_internal/mutation/loader.py deleted file mode 100644 index 22b531c..0000000 --- a/src/officialeye/_internal/mutation/loader.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Dict - -from officialeye._internal.error.errors.template import ErrTemplateInvalidMutator -from officialeye._internal.mutation.mutator import Mutator -from officialeye._internal.mutation.mutators.clahe import CLAHEMutator -from officialeye._internal.mutation.mutators.grayscale import GrayscaleMutator -from officialeye._internal.mutation.mutators.non_local_means_denoising import NonLocalMeansDenoisingMutator -from officialeye._internal.mutation.mutators.rotate import RotateMutator - - -def load_mutator(mutator_id: str, config: Dict[str, any], /) -> Mutator: - - # TODO: make a container allowing one to dynamically load mutators (add such a container to OfficialEye's context). - - if mutator_id == GrayscaleMutator.MUTATOR_ID: - return GrayscaleMutator(config) - - if mutator_id == NonLocalMeansDenoisingMutator.MUTATOR_ID: - return NonLocalMeansDenoisingMutator(config) - - if mutator_id == CLAHEMutator.MUTATOR_ID: - return CLAHEMutator(config) - - if mutator_id == RotateMutator.MUTATOR_ID: - return RotateMutator(config) - - raise ErrTemplateInvalidMutator( - f"while loading mutator '{mutator_id}'.", - "Unknown mutator id." - ) - - -def load_mutator_from_dict(mutator_dict: Dict[str, any], /) -> Mutator: - - assert "id" in mutator_dict - - mutator_id = mutator_dict["id"] - - mutator_config = mutator_dict["config"] if "config" in mutator_dict else {} - - return load_mutator(mutator_id, mutator_config) diff --git a/src/officialeye/_internal/mutation/mutator.py b/src/officialeye/_internal/mutation/mutator.py deleted file mode 100644 index 4453967..0000000 --- a/src/officialeye/_internal/mutation/mutator.py +++ /dev/null @@ -1,23 +0,0 @@ -import abc -from typing import Dict - -import cv2 - -from officialeye._internal.mutation.config import MutatorConfig - - -class Mutator(abc.ABC): - - def __init__(self, mutator_id: str, config_dict: Dict[str, any], /): - super().__init__() - - self.mutator_id = mutator_id - - self._config = MutatorConfig(config_dict, mutator_id) - - def get_config(self) -> MutatorConfig: - return self._config - - @abc.abstractmethod - def mutate(self, img: cv2.Mat, /) -> cv2.Mat: - raise NotImplementedError() diff --git a/src/officialeye/_internal/mutation/mutators/grayscale.py b/src/officialeye/_internal/mutation/mutators/grayscale.py deleted file mode 100644 index 7c3394a..0000000 --- a/src/officialeye/_internal/mutation/mutators/grayscale.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Dict - -import cv2 - -from officialeye._internal.mutation.mutator import Mutator - - -class GrayscaleMutator(Mutator): - - MUTATOR_ID = "grayscale" - - def __init__(self, config: Dict[str, any], /): - super().__init__(GrayscaleMutator.MUTATOR_ID, config) - - def mutate(self, img: cv2.Mat, /) -> cv2.Mat: - return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) diff --git a/src/officialeye/_internal/supervision/__init__.py b/src/officialeye/_internal/supervision/__init__.py index 604aebf..6d6e68f 100644 --- a/src/officialeye/_internal/supervision/__init__.py +++ b/src/officialeye/_internal/supervision/__init__.py @@ -1,3 +1,5 @@ """ Module containing everything related to supervision. -""" \ No newline at end of file +""" + +# TODO: get rid of this module and move the visualization logic into a module named in a better way diff --git a/src/officialeye/_internal/supervision/result.py b/src/officialeye/_internal/supervision/result.py deleted file mode 100644 index 689ee9b..0000000 --- a/src/officialeye/_internal/supervision/result.py +++ /dev/null @@ -1,123 +0,0 @@ -import sys -from typing import Dict, Set - -import cv2 -import numpy as np - -from officialeye._internal.matching.match import Match -from officialeye._internal.matching.result import MatchingResult - - -class SupervisionResult: - - def __init__(self, template_id: str, kmr: MatchingResult, - delta: np.ndarray, delta_prime: np.ndarray, transformation_matrix: np.ndarray, /): - - self.template_id = template_id - self._kmr = kmr - - assert delta.shape == (2,) - assert delta_prime.shape == (2,) - assert transformation_matrix.shape == (2, 2) - - # offset in the template's coordinates - self._delta = delta - # offset in the target image's coordinates - self._delta_prime = delta_prime - # self.dpo = delta_prime.copy() - - self._transformation_matrix = transformation_matrix - - # keys: matches - # values: weights assigned by the supervision engine to each match (assigning is optional) - # the higher the weight, the more we trust the correctness of the match and the greater its individual impact should be. - # by default, the weight is 1. - self._match_weights: Dict[Match, float] = {} - - # an optional value the supervision engine can set, representing how confident the engine is that the result is of high quality - self._score = 0.0 - - def get_score(self) -> float: - assert self._score >= 0.0 - return self._score - - def set_score(self, new_score: float, /): - assert new_score >= 0 - assert isinstance(new_score, float) - self._score = new_score - - def get_match_weight(self, match: Match, /) -> float: - if match in self._match_weights: - return self._match_weights[match] - return 1.0 - - def set_match_weight(self, match: Match, weight: float, /): - self._match_weights[match] = weight - - def template_point_to_target_point(self, template_point: np.ndarray, /) -> np.ndarray: - assert template_point.shape == (2,) - assert self._delta.shape == (2,) - assert self._delta_prime.shape == (2,) - return self._transformation_matrix @ (template_point - self._delta) + self._delta_prime - - def get_feature_warped_region(self, target: cv2.Mat, feature) -> cv2.Mat: - - feature_tl = feature.get_top_left_vec() - feature_tr = feature.get_top_right_vec() - feature_bl = feature.get_bottom_left_vec() - feature_br = feature.get_bottom_right_vec() - - target_tl = self.template_point_to_target_point(feature_tl) - target_tr = self.template_point_to_target_point(feature_tr) - target_bl = self.template_point_to_target_point(feature_bl) - target_br = self.template_point_to_target_point(feature_br) - - dest_tl = np.array([0, 0], dtype=np.float64) - dest_tr = np.array([feature.w, 0], dtype=np.float64) - dest_br = np.array([feature.w, feature.h], dtype=np.float64) - dest_bl = np.array([0, feature.h], dtype=np.float64) - - source_points = [target_tl, target_tr, target_br, target_bl] - destination_points = [dest_tl, dest_tr, dest_br, dest_bl] - - homography = cv2.getPerspectiveTransform(np.float32(source_points), np.float32(destination_points)) - - return cv2.warpPerspective( - target, - np.float32(homography), - (feature.w, feature.h), - flags=cv2.INTER_LINEAR - ) - - def get_relevant_keypoint_ids(self) -> Set[str]: - rk = set() - for match in self._kmr.get_matches(): - rk.add(match.get_keypoint().region_id) - assert len(rk) > 0 - return rk - - def get_keypoint_matching_result(self) -> MatchingResult: - return self._kmr - - def get_weighted_mse(self, /) -> float: - error = 0.0 - singificant_match_count = 0 - for match in self._kmr.get_matches(): - - match_weight = self.get_match_weight(match) - - if match_weight < sys.float_info.epsilon: - continue - - singificant_match_count += 1 - - s = match.get_original_template_point() - # calculate prediction - p = self.template_point_to_target_point(s) - # calculate destination - d = match.get_target_point() - current_error = p - d - current_error_value = np.dot(current_error, current_error) - error += current_error_value * match_weight - - return error / singificant_match_count diff --git a/src/officialeye/_internal/supervision/supervisor.py b/src/officialeye/_internal/supervision/supervisor.py deleted file mode 100644 index d45b0c8..0000000 --- a/src/officialeye/_internal/supervision/supervisor.py +++ /dev/null @@ -1,133 +0,0 @@ -import abc -import random -from abc import ABC -from typing import Generator, Union - -from officialeye._internal.context.context import Context -from officialeye._internal.logger.singleton import get_logger -from officialeye._internal.matching.result import MatchingResult -from officialeye._internal.supervision.result import SupervisionResult -from officialeye._internal.supervision.supervisor_config import SupervisorConfig - -_SUPERVISION_RESULT_FIRST = "first" -_SUPERVISION_RESULT_RANDOM = "random" -_SUPERVISION_RESULT_BEST_MSE = "best_mse" -_SUPERVISION_RESULT_BEST_SCORE = "best_score" - - -class Supervisor(ABC): - - def __init__(self, context: Context, engine_id: str, template_id: str, kmr: MatchingResult, /): - super().__init__() - - self._context = context - self._engine_id = engine_id - - self.template_id = template_id - self._kmr = kmr - - get_logger().debug(f"Total match count: {self._kmr.get_total_match_count()}") - - # initialize configuration manager - supervision_config = self.get_template().get_supervision_config() - - assert isinstance(supervision_config, dict) - - if self._engine_id in supervision_config: - config_dict = supervision_config[self._engine_id] - else: - get_logger().warn(f"Could not find any configuration entries for the '{self._engine_id}' supervision engine.") - config_dict = {} - - self._config = SupervisorConfig(config_dict, self._engine_id) - - def get_template(self): - return self._context.get_template(self.template_id) - - @abc.abstractmethod - def _run(self) -> Generator[SupervisionResult, None, None]: - raise NotImplementedError() - - def _run_first(self) -> Union[SupervisionResult, None]: - results_generator = self._run() - return next(results_generator, None) - - def _run_random(self) -> Union[SupervisionResult, None]: - results = list(self._run()) - return None if len(results) == 0 else results[random.randint(0, len(results) - 1)] - - def _run_best_mse(self) -> Union[SupervisionResult, None]: - - results = list(self._run()) - - if len(results) == 0: - return None - - best_result = results[0] - best_result_mse = best_result.get_weighted_mse() - - for result_id, result in enumerate(results): - result_mse = result.get_weighted_mse() - - get_logger().debug_verbose(f"Result #{result_id + 1} has MSE {result_mse}") - - if result_mse < best_result_mse: - best_result_mse = result_mse - best_result = result - - get_logger().debug(f"Best result has MSE {best_result_mse}") - - return best_result - - def _run_best_score(self) -> Union[SupervisionResult, None]: - - results = list(self._run()) - - if len(results) == 0: - return None - - best_result = results[0] - best_result_score = best_result.get_score() - best_result_mse = best_result.get_weighted_mse() - - for result_id, result in enumerate(results): - result_score = result.get_score() - - get_logger().debug_verbose(f"Result #{result_id + 1} has score {result_score}") - - if result_score > best_result_score: - best_result_score = result_score - best_result_mse = result.get_weighted_mse() - best_result = result - elif result_score == best_result_score: - current_result_mse = result.get_weighted_mse() - if current_result_mse < best_result_mse: - best_result_mse = current_result_mse - best_result = result - - get_logger().debug(f"Best result has score {best_result_score} and MSE {best_result_mse}") - - return best_result - - def get_config(self) -> SupervisorConfig: - return self._config - - def run(self) -> Union[SupervisionResult, None]: - - supervision_result_choice_engine = self.get_template().get_supervision_result() - - get_logger().debug(f"Applying '{supervision_result_choice_engine}' supervision result choice engine.") - - if supervision_result_choice_engine == _SUPERVISION_RESULT_FIRST: - return self._run_first() - - if supervision_result_choice_engine == _SUPERVISION_RESULT_RANDOM: - return self._run_random() - - if supervision_result_choice_engine == _SUPERVISION_RESULT_BEST_MSE: - return self._run_best_mse() - - if supervision_result_choice_engine == _SUPERVISION_RESULT_BEST_SCORE: - return self._run_best_score() - - raise AssertionError("Invalid supervision result") diff --git a/src/officialeye/_internal/supervision/supervisor_config.py b/src/officialeye/_internal/supervision/supervisor_config.py deleted file mode 100644 index 7003613..0000000 --- a/src/officialeye/_internal/supervision/supervisor_config.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Dict - -from officialeye._internal.config.config import Config -from officialeye._internal.error.errors.supervision import ErrSupervisionInvalidEngineConfig - - -class SupervisorConfig(Config): - - def __init__(self, config_dict: Dict[str, any], supervision_engine_id: str, /): - super().__init__(config_dict) - - self._supervision_engine_id = supervision_engine_id - - def _get_invalid_key_error(self, key: str, /): - return ErrSupervisionInvalidEngineConfig( - f"while reading configuration of the '{self._supervision_engine_id}' supervision engine.", - f"Could not find a value for key '{key}'." - ) diff --git a/src/officialeye/_internal/supervision/supervisors/__init__.py b/src/officialeye/_internal/supervision/supervisors/__init__.py deleted file mode 100644 index 7559d46..0000000 --- a/src/officialeye/_internal/supervision/supervisors/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -This module contains all supervisors built into OfficialEye. -""" \ No newline at end of file diff --git a/src/officialeye/_internal/supervision/supervisors/least_squares_regression.py b/src/officialeye/_internal/supervision/supervisors/least_squares_regression.py deleted file mode 100644 index 592d5ae..0000000 --- a/src/officialeye/_internal/supervision/supervisors/least_squares_regression.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Generator - -import numpy as np - -from officialeye._internal.context.context import Context -from officialeye._internal.logger.singleton import get_logger -from officialeye._internal.matching.result import MatchingResult -from officialeye._internal.supervision.result import SupervisionResult -from officialeye._internal.supervision.supervisor import Supervisor - -_IND_A = 0 -_IND_B = 1 -_IND_C = 2 -_IND_D = 3 - - -class LeastSquaresRegressionSupervisor(Supervisor): - - ENGINE_ID = "least_squares_regression" - - def __init__(self, context: Context, template_id: str, kmr: MatchingResult, /): - super().__init__(context, LeastSquaresRegressionSupervisor.ENGINE_ID, template_id, kmr) - - def _run(self) -> Generator[SupervisionResult, None, None]: - - match_count = self._kmr.get_total_match_count() - - for anchor_match in self._kmr.get_matches(): - delta = anchor_match.get_original_template_point() - delta_prime = anchor_match.get_target_point() - - matrix = np.zeros((match_count << 1, 4), dtype=np.float64) - rhs = np.zeros(match_count << 1, dtype=np.float64) - - for i, match in enumerate(self._kmr.get_matches()): - first_constraint_id = i << 1 - second_constraint_id = first_constraint_id + 1 - - s = match.get_original_template_point() - d = match.get_target_point() - - matrix[first_constraint_id][_IND_A] = s[0] - delta[0] - matrix[first_constraint_id][_IND_B] = s[1] - delta[1] - rhs[first_constraint_id] = d[0] - delta_prime[0] - - matrix[second_constraint_id][_IND_C] = s[0] - delta[0] - matrix[second_constraint_id][_IND_D] = s[1] - delta[1] - rhs[second_constraint_id] = d[1] - delta_prime[1] - - regression_matrix = matrix.T @ matrix - regression_matrix = np.linalg.inv(regression_matrix) - rhs_applied = matrix.T @ rhs - x = regression_matrix @ rhs_applied - - transformation_matrix = np.array([ - [x[_IND_A], x[_IND_B]], - [x[_IND_C], x[_IND_D]] - ]) - - _result = SupervisionResult(self.template_id, self._kmr, delta, delta_prime, transformation_matrix) - - get_logger().debug(f"Current MSE: {_result.get_weighted_mse()}") - - yield _result diff --git a/src/officialeye/_internal/supervision/supervisors/orthogonal_regression.py b/src/officialeye/_internal/supervision/supervisors/orthogonal_regression.py deleted file mode 100644 index a14935b..0000000 --- a/src/officialeye/_internal/supervision/supervisors/orthogonal_regression.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import Dict, Generator - -import numpy as np -import z3 - -from officialeye._internal.context.context import Context -from officialeye._internal.error.errors.supervision import ErrSupervisionInvalidEngineConfig -from officialeye._internal.logger.singleton import get_logger -from officialeye._internal.matching.match import Match -from officialeye._internal.matching.result import MatchingResult -from officialeye._internal.supervision.result import SupervisionResult -from officialeye._internal.supervision.supervisor import Supervisor - - -class OrthogonalRegressionSupervisor(Supervisor): - - ENGINE_ID = "orthogonal_regression" - - def __init__(self, context: Context, template_id: str, kmr: MatchingResult, /): - super().__init__(context, OrthogonalRegressionSupervisor.ENGINE_ID, template_id, kmr) - - def _z3_timeout_preprocessor(v: any) -> int: - - v = int(v) - - if v < 1: - raise ErrSupervisionInvalidEngineConfig( - f"while loading the '{OrthogonalRegressionSupervisor.ENGINE_ID}' supervisor.", - f"The `z3_timeout` value ({v}) cannot be negative or zero." - ) - - return v - - self.get_config().set_value_preprocessor("z3_timeout", _z3_timeout_preprocessor) - - self._z3_context = z3.Context() - - # create variables for components of the translation matrix - self._transformation_matrix = np.array([ - [z3.Real("a", ctx=self._z3_context), z3.Real("b", ctx=self._z3_context)], - [z3.Real("c", ctx=self._z3_context), z3.Real("d", ctx=self._z3_context)] - ], dtype=z3.AstRef) - - # keys: matches (instances of Match) - # values: z3 integer variables representing the errors for each match, - # i.e., how consistent the match is with the affine transformation model - self._match_error: Dict[Match, z3.ArithRef] = {} - - for match in self._kmr.get_matches(): - self._match_error[match] = z3.Real(f"e_{match.get_debug_identifier()}", ctx=self._z3_context) - - def _get_consistency_check(self, match: Match, delta: np.ndarray, delta_prime: np.ndarray, /) -> z3.AstRef: - """ - Generates a z3 formula asserting the consistency of the match with the affine linear transformation model. - Consistency does not mean ideal matching of coordinates; rather, the template position with the affine - transformation applied to it, must roughly be equal the target position for consistency to hold - In other words, targetpoint = M * (templatepoint - offset), where offset is a vector and M is a 2x2 matrix - """ - - template_point = match.get_original_template_point() - - assert delta.shape == (2,) - assert delta_prime.shape == (2,) - assert template_point.shape == (2,) - - translated_template_point = self._transformation_matrix @ (template_point - delta) + delta_prime - translated_template_point_x, translated_template_point_y = translated_template_point - - target_point_x, target_point_y = match.get_target_point() - - return z3.And( - translated_template_point_x - target_point_x <= self._match_error[match], - translated_template_point_x - target_point_x >= -self._match_error[match], - translated_template_point_y - target_point_y <= self._match_error[match], - translated_template_point_y - target_point_y >= -self._match_error[match], - ) - - def _run(self) -> Generator[SupervisionResult, None, None]: - - for anchor_match in self._kmr.get_matches(): - delta = anchor_match.get_original_template_point() - delta_prime = anchor_match.get_target_point() - - error_lower_bounds = z3.And(*(self._match_error[match] >= 0 for match in self._kmr.get_matches()), self._z3_context) - total_error = z3.Sum(*(self._match_error[match] for match in self._kmr.get_matches()), self._z3_context) - - solver = z3.Optimize(ctx=self._z3_context) - solver.set("timeout", self.get_config().get("z3_timeout", default=2500)) - - solver.add(error_lower_bounds) - - for match in self._kmr.get_matches(): - solver.add(self._get_consistency_check(match, delta, delta_prime)) - - solver.minimize(total_error) - - _result = solver.check() - - if _result == z3.unsat: - get_logger().debug("Z3 returned unsat.") - return - - if _result == z3.unknown: - get_logger().debug("Z3 returned unknown.") - return - - assert _result == z3.sat - - model = solver.model() - - evaluator = np.vectorize(lambda var: float(model.eval(var, model_completion=True).as_fraction())) # noqa: B023 - - # extract transformation matrix from model - transformation_matrix = evaluator(self._transformation_matrix) - - _result = SupervisionResult(self.template_id, self._kmr, delta, delta_prime, transformation_matrix) - - get_logger().debug(f"Error: {_result.get_weighted_mse()}") - - yield _result diff --git a/src/officialeye/_internal/supervision/visualizer.py b/src/officialeye/_internal/supervision/visualizer.py index dbd40cb..ce4c911 100644 --- a/src/officialeye/_internal/supervision/visualizer.py +++ b/src/officialeye/_internal/supervision/visualizer.py @@ -3,15 +3,14 @@ import cv2 import numpy as np -from officialeye._internal.context.context import Context +from officialeye._internal.context.singleton import get_internal_context from officialeye._internal.supervision.result import SupervisionResult -from officialeye._internal.template.region.keypoint import TemplateKeypoint +from officialeye._internal.template.keypoint import InternalKeypoint class SupervisionResultVisualizer: - def __init__(self, context: Context, result: SupervisionResult, target: cv2.Mat): - self._context = context + def __init__(self, result: SupervisionResult, target: np.ndarray): self._result = result self._target = target @@ -23,9 +22,9 @@ def __init__(self, context: Context, result: SupervisionResult, target: cv2.Mat) for keypoint_id in self._relevant_keypoint_ids) def get_template(self): - return self._context.get_template(self._result.template_id) + return get_internal_context().get_template(self._result.template_id) - def get_padded_keypoint_image(self, keypoint: TemplateKeypoint) -> np.ndarray: + def get_padded_keypoint_image(self, keypoint: InternalKeypoint) -> np.ndarray: assert keypoint.w <= self._palette_width @@ -53,7 +52,7 @@ def get_padded_keypoint_image(self, keypoint: TemplateKeypoint) -> np.ndarray: return keypoint_padded - def render(self) -> cv2.Mat: + def render(self) -> np.ndarray: keypoints_palette = cv2.vconcat([ self.get_padded_keypoint_image(self.get_template().get_keypoint(keypoint_id)) diff --git a/src/officialeye/_internal/template/analyze.py b/src/officialeye/_internal/template/analyze.py deleted file mode 100644 index 5a7fedb..0000000 --- a/src/officialeye/_internal/template/analyze.py +++ /dev/null @@ -1,144 +0,0 @@ -from queue import Queue -from threading import Thread -from typing import List, Tuple, Union - -import cv2 - -from officialeye._internal.context.context import Context -from officialeye._internal.error.error import OEError -from officialeye._internal.error.errors.io import ErrIOInvalidImage -from officialeye._internal.error.errors.supervision import ErrSupervisionCorrespondenceNotFound -from officialeye._internal.error.errors.template import ErrTemplateInvalidConcurrencyConfig -from officialeye._internal.logger.singleton import get_logger -from officialeye._internal.supervision.result import SupervisionResult -from officialeye._internal.template.template import Template - - -class AnalysisWorker(Thread): - - def __init__(self, worker_id: int, queue: Queue, target: cv2.Mat, /): - Thread.__init__(self) - - self.worker_id = worker_id - self.queue = queue - - self._target = target - self._results: List[Tuple[Union[SupervisionResult, None], Union[OEError, None]]] = [] - - def run(self): - - while True: - template: Template = self.queue.get() - - try: - _current_result = template.run_analysis(self._target) - self._results.append((_current_result, None)) - except OEError as err: - self._results.append((None, err)) - finally: - self.queue.task_done() - - def get_successful_results(self): - for result, error in self._results: - if result is None or error is not None: - continue - yield result - - def get_errors(self): - for _, error in self._results: - if error is not None: - yield error - - -def do_analyze(context: Context, target: cv2.Mat, templates: List[Template], /, *, - num_workers: int, interpretation_target: Union[cv2.Mat, None] = None): - - if len(templates) == 0: - # the program should be a noop if there are no templates provided - return - - if interpretation_target is None: - # if not specified, interpret the given image - interpretation_target = target - else: - # there is a custom interpretation target specified. - # it is essential that it has the same shape as the target image. - # for this reason, we should verify this here - if interpretation_target.shape != target.shape: - raise ErrIOInvalidImage( - "while making sure that the target image and the interpretation target images have the same shape.", - f"The shapes mismatch. " - f"The target image has shape {target.shape}, while the interpretation target image has shape {interpretation_target.shape}." - ) - - assert interpretation_target is not None - assert num_workers is not None - - if num_workers < 1: - raise ErrTemplateInvalidConcurrencyConfig( - "while setting up workers for analyzing the target image.", - f"The provided number of workers ({num_workers}) cannot be less than one." - ) - - if num_workers > 0xff: - raise ErrTemplateInvalidConcurrencyConfig( - "while setting up workers for analyzing the target image.", - f"The provided number of workers ({num_workers}) is too high." - ) - - queue = Queue(maxsize=len(templates)) - - workers = [AnalysisWorker(worker_id, queue, target) for worker_id in range(num_workers)] - - for worker in workers: - worker.daemon = True - worker.start() - - for template in templates: - queue.put(template) - - queue.join() - - best_result = None - best_result_score = -1.0 - - # a list containing regular errors that occurred in workers - regular_errors = [] - - for worker in workers: - for result in worker.get_successful_results(): - assert result is not None - - result_score = result.get_score() - if result_score > best_result_score: - best_result_score = result_score - best_result = result - - for error in worker.get_errors(): - assert error is not None - - # we ignore regular errors here, because they may well be simply caused by trying to match - # a given document against an invalid template, which is normal behavior - if not error.is_regular: - raise error - else: - get_logger().debug(f"Worker {worker.worker_id} returned the following non-regular error {error.code_text}:") - get_logger().debug_oe_error(error) - regular_errors.append(error) - - # note: best_result may be None here - - if best_result is None: - - error = ErrSupervisionCorrespondenceNotFound( - "while running supervisor", - "could not establish correspondence of the image with any of the templates provided" - ) - - for worker_error in regular_errors: - error.add_cause(worker_error) - - raise error - - io_driver = context.get_io_driver() - io_driver.handle_supervision_result(interpretation_target, best_result) diff --git a/src/officialeye/_internal/template/external_feature.py b/src/officialeye/_internal/template/external_feature.py new file mode 100644 index 0000000..33fa5ae --- /dev/null +++ b/src/officialeye/_internal/template/external_feature.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable, List + +# noinspection PyProtectedMember +from officialeye._api.template.feature import IFeature +from officialeye._internal.api_implementation import IApiInterfaceImplementation +from officialeye._internal.template.region import ExternalRegion + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.context import Context + + # noinspection PyProtectedMember + from officialeye._api.mutator import IMutator + from officialeye._internal.template.external_template import ExternalTemplate + from officialeye._internal.template.internal_feature import InternalFeature + + +class ExternalFeature(ExternalRegion, IFeature, IApiInterfaceImplementation): + + def __init__(self, internal_feature: InternalFeature, external_template: ExternalTemplate, /): + super().__init__(internal_feature, external_template) + + self._mutators: List[IMutator] = list(internal_feature.get_mutators()) + + def get_mutators(self) -> Iterable[IMutator]: + return self._mutators + + def set_api_context(self, context: Context, /) -> None: + # no methods of this class require any contextual information to work, nothing to do + pass + + def clear_api_context(self) -> None: + # no methods of this class require any contextual information to work, nothing to do + pass diff --git a/src/officialeye/_internal/template/external_interpretation_result.py b/src/officialeye/_internal/template/external_interpretation_result.py new file mode 100644 index 0000000..38daca0 --- /dev/null +++ b/src/officialeye/_internal/template/external_interpretation_result.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict + +from officialeye import Context + +# noinspection PyProtectedMember +from officialeye._api.template.feature import IFeature + +# noinspection PyProtectedMember +from officialeye._api.template.interpretation_result import IInterpretationResult + +# noinspection PyProtectedMember +from officialeye._api.template.template_interface import ITemplate +from officialeye._internal.api_implementation import IApiInterfaceImplementation +from officialeye._internal.template.external_template import ExternalTemplate + +if TYPE_CHECKING: + from officialeye._internal.template.internal_template import InternalTemplate + from officialeye.types import FeatureInterpretation + + +class ExternalInterpretationResult(IInterpretationResult, IApiInterfaceImplementation): + + def __init__(self, template: InternalTemplate, feature_interpretations: Dict[str, FeatureInterpretation], /): + self._template = ExternalTemplate(template) + self._feature_interpretation = feature_interpretations + + @property + def template(self) -> ITemplate: + return self._template + + def get_feature_interpretation(self, feature: IFeature, /) -> FeatureInterpretation: + + if feature.identifier in self._feature_interpretation: + return self._feature_interpretation[feature.identifier] + + return None + + def set_api_context(self, context: Context, /) -> None: + self._template.set_api_context(context) + + def clear_api_context(self) -> None: + self._template.clear_api_context() diff --git a/src/officialeye/_internal/template/external_matching_result.py b/src/officialeye/_internal/template/external_matching_result.py new file mode 100644 index 0000000..cd96bd1 --- /dev/null +++ b/src/officialeye/_internal/template/external_matching_result.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from officialeye._internal.api_implementation import IApiInterfaceImplementation +from officialeye._internal.template.external_template import ExternalTemplate +from officialeye._internal.template.shared_matching_result import SharedMatchingResult + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.context import Context + from officialeye._internal.template.internal_matching_result import InternalMatchingResult + + +class ExternalMatchingResult(SharedMatchingResult, IApiInterfaceImplementation): + """ + Representation of the matching result, designed to be used by the main process. + For this reason, it is essential that this class is picklable. + """ + + def __init__(self, internal_matching_result: InternalMatchingResult, external_template: ExternalTemplate, /): + super().__init__(internal_matching_result.template) + + self._template = external_template + + @property + def template(self) -> ExternalTemplate: + return self._template + + def set_api_context(self, context: Context, /) -> None: + self._template.set_api_context(context) + + def clear_api_context(self) -> None: + self._template.clear_api_context() diff --git a/src/officialeye/_internal/template/external_supervision_result.py b/src/officialeye/_internal/template/external_supervision_result.py new file mode 100644 index 0000000..4a71571 --- /dev/null +++ b/src/officialeye/_internal/template/external_supervision_result.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict + +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.context import Context + +# noinspection PyProtectedMember +from officialeye._api.future import Future + +# noinspection PyProtectedMember +from officialeye._api.image import Image + +# noinspection PyProtectedMember +from officialeye._api.template.match import IMatch + +# noinspection PyProtectedMember +from officialeye._api.template.supervision_result import ISupervisionResult +from officialeye._internal.api.interpret import template_interpret +from officialeye._internal.api_implementation import IApiInterfaceImplementation + +# noinspection PyProtectedMember +from officialeye._internal.template.external_interpretation_result import ExternalInterpretationResult +from officialeye._internal.template.external_matching_result import ExternalMatchingResult +from officialeye._internal.template.external_template import ExternalTemplate + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.image import IImage + from officialeye._internal.template.internal_supervision_result import InternalSupervisionResult + + +class ExternalSupervisionResult(ISupervisionResult, IApiInterfaceImplementation): + + def __init__(self, internal_supervision_result: InternalSupervisionResult, /): + super().__init__() + + self._context: Context | None = None + + self._template_path = internal_supervision_result.template.get_path() + + self._template = ExternalTemplate(internal_supervision_result.template) + self._matching_result = ExternalMatchingResult(internal_supervision_result.matching_result, self._template) + + self._score = internal_supervision_result.score + self._delta = internal_supervision_result.delta + self._delta_prime = internal_supervision_result.delta_prime + self._transformation_matrix = internal_supervision_result.transformation_matrix + + # noinspection PyProtectedMember + self._match_weights: Dict[IMatch, float] = internal_supervision_result.get_match_weights() + + def set_api_context(self, context: Context, /) -> None: + self._context = context + + # propagate the context further down the hierarchy of objects + self._template.set_api_context(context) + self._matching_result.set_api_context(context) + + def clear_api_context(self) -> None: + self._context = None + + self._template.clear_api_context() + self._matching_result.clear_api_context() + + @property + def template(self) -> ExternalTemplate: + return self._template + + @property + def matching_result(self) -> ExternalMatchingResult: + return self._matching_result + + @property + def score(self) -> float: + return self._score + + @property + def delta(self) -> np.ndarray: + return self._delta + + @property + def delta_prime(self) -> np.ndarray: + return self._delta_prime + + @property + def transformation_matrix(self) -> np.ndarray: + return self._transformation_matrix + + def interpret_async(self, /, *, target: IImage) -> Future: + + # TODO: this is hacky, maybe use a more clean approach here? + assert isinstance(target, Image) + + assert self._context is not None, \ + ("The external superivision result has no context information, probably because it has been given to the API user " + "before the context has been initialized in this object via the 'set_api_context' method, which is incorrect behavior.") + + _api_context = self._context + + self.clear_api_context() + + # noinspection PyProtectedMember + return _api_context._submit_task( + template_interpret, + f"Interpreting [b]{self.template.name}[/]...", + self._template_path, + self, + interpretation_target_path=target._path + ) + + def interpret(self, /, **kwargs) -> ExternalInterpretationResult: + future = self.interpret_async(**kwargs) + return future.result() + + def get_match_weight(self, match: IMatch, /) -> float: + + if match in self._match_weights: + return self._match_weights[match] + + return 1.0 diff --git a/src/officialeye/_internal/template/external_template.py b/src/officialeye/_internal/template/external_template.py new file mode 100644 index 0000000..b9d6d46 --- /dev/null +++ b/src/officialeye/_internal/template/external_template.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Iterable, List + +# noinspection PyProtectedMember +from officialeye._api.future import Future + +# noinspection PyProtectedMember +from officialeye._api.image import IImage, Image + +# noinspection PyProtectedMember +from officialeye._api.mutator import IMutator + +# noinspection PyProtectedMember +from officialeye._api.template.template_interface import ITemplate +from officialeye._internal.api.detect import template_detect +from officialeye._internal.api_implementation import IApiInterfaceImplementation + +# noinspection PyProtectedMember +from officialeye._internal.template.external_feature import ExternalFeature +from officialeye._internal.template.keypoint import ExternalKeypoint +from officialeye.error.errors.general import ErrOperationNotSupported + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.context import Context + + # noinspection PyProtectedMember + from officialeye._api.template.supervision_result import ISupervisionResult + from officialeye._internal.template.internal_template import InternalTemplate + + +class ExternalTemplate(ITemplate, IApiInterfaceImplementation): + """ + Representation of a template instance designed to be shared between processes. + It is very important that this class is picklable! + """ + + def __init__(self, template: InternalTemplate, /): + super().__init__() + + self._context: Context | None = None + + self._identifier: str = template.identifier + self._name: str = template.name + self._path: str = template.get_path() + self._source_image_path: str = template.get_source_image_path() + + self._width = template.width + self._height = template.height + + self._keypoints: Dict[str, ExternalKeypoint] = {} + self._features: Dict[str, ExternalFeature] = {} + + for keypoint in template.keypoints: + self._keypoints[keypoint.identifier] = ExternalKeypoint(keypoint, self) + + for feature in template.features: + self._features[feature.identifier] = ExternalFeature(feature, self) + + self._source_mutators: List[IMutator] = [ + mutator for mutator in template.get_source_mutators() + ] + + self._target_mutators: List[IMutator] = [ + mutator for mutator in template.get_target_mutators() + ] + + def set_api_context(self, context: Context, /) -> None: + self._context = context + + for external_keypoint in self.keypoints: + external_keypoint.set_api_context(context) + + for external_feature in self.features: + external_feature.set_api_context(context) + + def clear_api_context(self) -> None: + self._context = None + + for external_keypoint in self.keypoints: + external_keypoint.clear_api_context() + + for external_feature in self.features: + external_feature.clear_api_context() + + def load(self) -> None: + raise ErrOperationNotSupported( + "while accessing an external template instance.", + "The way in which it was accessed is not supported." + ) + + def detect_async(self, /, *, target: IImage) -> Future: + + # TODO: this is hacky, maybe use a more clean approach here? + assert isinstance(target, Image) + + # noinspection PyProtectedMember + return self._context._submit_task( + template_detect, + f"Detecting [b]{self._name}[/]...", + self._path, + target_path=target._path, + ) + + def detect(self, /, **kwargs) -> ISupervisionResult: + future = self.detect_async(**kwargs) + return future.result() + + def get_image(self) -> IImage: + return Image(self._context, path=self._source_image_path) + + def get_mutated_image(self) -> IImage: + img = self.get_image() + img.apply_mutators(*self._source_mutators) + return img + + @property + def identifier(self) -> str: + return self._identifier + + @property + def name(self) -> str: + return self._name + + @property + def width(self) -> int: + return self._width + + @property + def height(self) -> int: + return self._height + + @property + def keypoints(self) -> Iterable[ExternalKeypoint]: + return self._keypoints.values() + + @property + def features(self) -> Iterable[ExternalFeature]: + return self._features.values() + + def get_feature(self, feature_id: str, /) -> ExternalFeature | None: + if feature_id not in self._features: + return None + return self._features[feature_id] + + def get_keypoint(self, keypoint_id: str, /) -> ExternalKeypoint | None: + if keypoint_id not in self._keypoints: + return None + return self._keypoints[keypoint_id] diff --git a/src/officialeye/_internal/template/feature_class/loader.py b/src/officialeye/_internal/template/feature_class/loader.py index 02db50c..410583b 100644 --- a/src/officialeye/_internal/template/feature_class/loader.py +++ b/src/officialeye/_internal/template/feature_class/loader.py @@ -1,13 +1,12 @@ -from officialeye._internal.context.context import Context -from officialeye._internal.error.errors.template import ErrTemplateInvalidFeatureClass from officialeye._internal.template.feature_class.manager import FeatureClassManager +from officialeye.error.errors.template import ErrTemplateInvalidFeatureClass -def load_template_feature_classes(context: Context, feature_classes_dict: dict, template_id: str, /) -> FeatureClassManager: +def load_template_feature_classes(feature_classes_dict: dict, template_id: str, /) -> FeatureClassManager: assert isinstance(feature_classes_dict, dict) - _manager = FeatureClassManager(context, template_id) + _manager = FeatureClassManager(template_id) for class_id in feature_classes_dict: diff --git a/src/officialeye/_internal/template/feature_class/manager.py b/src/officialeye/_internal/template/feature_class/manager.py index 1e9f67f..88a31ba 100644 --- a/src/officialeye/_internal/template/feature_class/manager.py +++ b/src/officialeye/_internal/template/feature_class/manager.py @@ -1,16 +1,15 @@ from typing import Dict -from officialeye._internal.context.context import Context +from officialeye._internal.context.singleton import get_internal_context from officialeye._internal.diffobject.exception import DiffObjectException -from officialeye._internal.error.errors.template import ErrTemplateInvalidFeatureClass from officialeye._internal.template.feature_class.const import IMPLICIT_FEATURE_CLASS_BASE_INSTANCE_ID from officialeye._internal.template.feature_class.feature_class import FeatureClass +from officialeye.error.errors.template import ErrTemplateInvalidFeatureClass class FeatureClassManager: - def __init__(self, context: Context, template_id: str, /): - self._context = context + def __init__(self, template_id: str, /): self._template_id = template_id self._classes: Dict[str, FeatureClass] = { IMPLICIT_FEATURE_CLASS_BASE_INSTANCE_ID: FeatureClass(self, IMPLICIT_FEATURE_CLASS_BASE_INSTANCE_ID, { @@ -26,7 +25,7 @@ def get_class(self, class_id: str, /): return self._classes[class_id] def get_template(self): - return self._context.get_template(self._template_id) + return get_internal_context().get_template(self._template_id) def contains_class(self, class_id: str, /) -> bool: return class_id in self._classes diff --git a/src/officialeye/_internal/template/image.py b/src/officialeye/_internal/template/image.py new file mode 100644 index 0000000..18fbea2 --- /dev/null +++ b/src/officialeye/_internal/template/image.py @@ -0,0 +1,48 @@ +import os +from typing import List + +import cv2 +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.image import IImage + +# noinspection PyProtectedMember +from officialeye._api.mutator import IMutator +from officialeye._internal.context.singleton import get_internal_afi +from officialeye._internal.feedback.verbosity import Verbosity +from officialeye.error.errors.io import ErrIOInvalidPath + + +class InternalImage(IImage): + + def __init__(self, /, *, path: str): + super().__init__() + + self._mutators: List[IMutator] = [] + self._path = path + + def load(self) -> np.ndarray: + + if not os.path.isfile(self._path): + raise ErrIOInvalidPath( + f"while loading image located at '{self._path}'.", + "This path does not refer to a file." + ) + + if not os.access(self._path, os.R_OK): + raise ErrIOInvalidPath( + f"while loading image located at '{self._path}'.", + "The file at this path is not readable." + ) + + img = cv2.imread(self._path, cv2.IMREAD_COLOR) + + for mutator in self._mutators: + get_internal_afi().info(Verbosity.DEBUG, f"InternalImage::load() applies mutator '{mutator}'") + img = mutator.mutate(img) + + return img + + def apply_mutators(self, *mutators: IMutator): + self._mutators += mutators diff --git a/src/officialeye/_internal/template/region/feature.py b/src/officialeye/_internal/template/internal_feature.py similarity index 50% rename from src/officialeye/_internal/template/region/feature.py rename to src/officialeye/_internal/template/internal_feature.py index 474f615..6858762 100644 --- a/src/officialeye/_internal/template/region/feature.py +++ b/src/officialeye/_internal/template/internal_feature.py @@ -1,23 +1,28 @@ -from typing import Dict, Union +from __future__ import annotations -import cv2 +from typing import TYPE_CHECKING, Dict, Iterable, Union -from officialeye._internal.context.context import Context -from officialeye._internal.error.errors.template import ErrTemplateInvalidFeature -from officialeye._internal.interpretation.loader import load_interpretation_method -from officialeye._internal.logger.singleton import get_logger -from officialeye._internal.mutation.loader import load_mutator_from_dict +import numpy as np + +from officialeye import IMutator + +# noinspection PyProtectedMember +from officialeye._api.template.feature import IFeature +from officialeye._internal.context.singleton import get_internal_context from officialeye._internal.template.feature_class.feature_class import FeatureClass -from officialeye._internal.template.feature_class.manager import FeatureClassManager -from officialeye._internal.template.region.region import TemplateRegion +from officialeye._internal.template.region import InternalRegion +from officialeye._internal.template.utils import load_mutator_from_dict +from officialeye.error.errors.template import ErrTemplateInvalidFeature -_FEATURE_RECT_COLOR = (0, 0xff, 0) +if TYPE_CHECKING: + from officialeye._internal.template.feature_class.manager import FeatureClassManager + from officialeye.types import FeatureInterpretation -class TemplateFeature(TemplateRegion): +class InternalFeature(InternalRegion, IFeature): - def __init__(self, context: Context, template_id: str, feature_dict: Dict[str, any], /): - super().__init__(context, template_id, feature_dict) + def __init__(self, template_id: str, feature_dict: Dict[str, any], /): + super().__init__(template_id, feature_dict) if "class" in feature_dict: self._class_id = feature_dict["class"] @@ -25,19 +30,16 @@ def __init__(self, context: Context, template_id: str, feature_dict: Dict[str, a else: self._class_id = None - def visualize(self, img: cv2.Mat, /): - return super()._visualize(img, rect_color=_FEATURE_RECT_COLOR) - def validate_feature_class(self): if self._class_id is None: return - feature_classes: FeatureClassManager = self.get_template().get_feature_classes() + feature_classes: FeatureClassManager = self.template.get_feature_classes() if not feature_classes.contains_class(self._class_id): raise ErrTemplateInvalidFeature( - f"while loading class for feature '{self.region_id}' in template '{self.get_template().template_id}'.", + f"while loading class for feature '{self.identifier}' in template '{self.template.identifier}'.", f"Specified feature class '{self._class_id}' is not defined." ) @@ -45,7 +47,7 @@ def validate_feature_class(self): if feature_class.is_abstract(): raise ErrTemplateInvalidFeature( - f"while loading class for feature '{self.region_id}' in template '{self.get_template().template_id}'.", + f"while loading class for feature '{self.identifier}' in template '{self.template.identifier}'.", f"Cannot instantiate an abstract feature class '{self._class_id}'." ) @@ -55,7 +57,7 @@ def get_feature_class(self) -> Union[FeatureClass, None]: if self._class_id is None: return None - feature_classes: FeatureClassManager = self.get_template().get_feature_classes() + feature_classes: FeatureClassManager = self.template.get_feature_classes() assert feature_classes.contains_class(self._class_id) @@ -65,34 +67,22 @@ def get_feature_class(self) -> Union[FeatureClass, None]: return feature_class - def apply_mutators_to_image(self, img: cv2.Mat, /) -> cv2.Mat: - """ - Takes an image and applies the mutators defined in the corresponding feature class. - - Arguments: - img: The image that should be transformed. - - Returns: - The resulting image. - """ + def get_mutators(self) -> Iterable[IMutator]: feature_class = self.get_feature_class() if feature_class is None: - return img + return [] mutators = feature_class.get_data()["mutators"] assert isinstance(mutators, list) - for mutator_dict in mutators: - mutator = load_mutator_from_dict(mutator_dict) - get_logger().debug(f"Applying mutator '{mutator.mutator_id}'.") - img = mutator.mutate(img) - - return img + return [ + load_mutator_from_dict(mutator_dict) for mutator_dict in mutators + ] - def interpret_image(self, img: cv2.Mat, /) -> any: + def interpret_image(self, img: np.ndarray, /) -> FeatureInterpretation: """ Takes an image and runs the interpretation method defined in the corresponding feature class. Assumes that the feature class is present. @@ -114,6 +104,6 @@ def interpret_image(self, img: cv2.Mat, /) -> any: assert isinstance(interpretation_method_id, str) assert isinstance(interpretation_method_config, dict) - interpretation_method = load_interpretation_method(self._context, interpretation_method_id, interpretation_method_config) + interpretation_method = get_internal_context().get_interpretation(interpretation_method_id, interpretation_method_config) - return interpretation_method.interpret(img, self._template_id, self.region_id) + return interpretation_method.interpret(img, self) diff --git a/src/officialeye/_internal/template/internal_matching_result.py b/src/officialeye/_internal/template/internal_matching_result.py new file mode 100644 index 0000000..a436f3a --- /dev/null +++ b/src/officialeye/_internal/template/internal_matching_result.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, List + +# noinspection PyProtectedMember +from officialeye._api.template.match import IMatch +from officialeye._internal.context.singleton import get_internal_context +from officialeye._internal.template.shared_matching_result import SharedMatchingResult + +if TYPE_CHECKING: + from officialeye._internal.template.internal_template import InternalTemplate + + +class InternalMatchingResult(SharedMatchingResult): + """ + Representation of the matching result, designed to be used by the child process only. + """ + + def __init__(self, template: InternalTemplate, /): + super().__init__(template) + + self._template_id = template.identifier + + # keys: keypoint ids + # values: matches with this keypoint + self._matches_dict: Dict[str, List[IMatch]] = {} + + for keypoint in self.template.keypoints: + self._matches_dict[keypoint.identifier] = [] + + @property + def template(self) -> InternalTemplate: + return get_internal_context().get_template(self._template_id) diff --git a/src/officialeye/_internal/template/internal_supervision_result.py b/src/officialeye/_internal/template/internal_supervision_result.py new file mode 100644 index 0000000..d23291d --- /dev/null +++ b/src/officialeye/_internal/template/internal_supervision_result.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict + +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.future import Future + +# noinspection PyProtectedMember +from officialeye._api.template.interpretation_result import IInterpretationResult + +# noinspection PyProtectedMember +from officialeye._api.template.supervision_result import ISupervisionResult +from officialeye.error.errors.general import ErrOperationNotSupported + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.image import IImage + + # noinspection PyProtectedMember + from officialeye._api.template.match import IMatch + + # noinspection PyProtectedMember + from officialeye._api.template.supervision_result import SupervisionResult + from officialeye._internal.template.internal_matching_result import InternalMatchingResult + from officialeye._internal.template.internal_template import InternalTemplate + + +class InternalSupervisionResult(ISupervisionResult): + + def __init__(self, supervision_result: SupervisionResult, internal_template: InternalTemplate, + internal_matching_result: InternalMatchingResult, /): + self._supervision_result = supervision_result + self._internal_template = internal_template + self._internal_matching_result = internal_matching_result + + @property + def template(self) -> InternalTemplate: + return self._internal_template + + @property + def matching_result(self) -> InternalMatchingResult: + return self._internal_matching_result + + @property + def score(self) -> float: + return self._supervision_result.get_score() + + @property + def delta(self) -> np.ndarray: + return self._supervision_result.delta + + @property + def delta_prime(self) -> np.ndarray: + return self._supervision_result.delta_prime + + @property + def transformation_matrix(self) -> np.ndarray: + return self._supervision_result.transformation_matrix + + def interpret_async(self, /, *, target: IImage) -> Future: + raise ErrOperationNotSupported( + "while accessing an internal supervision result instance.", + "The way in which it was accessed is not supported." + ) + + def interpret(self, /, **kwargs) -> IInterpretationResult: + raise ErrOperationNotSupported( + "while accessing an internal supervision result instance.", + "The way in which it was accessed is not supported." + ) + + def get_match_weights(self) -> Dict[IMatch, float]: + # noinspection PyProtectedMember + return self._supervision_result._match_weights + + def get_match_weight(self, match: IMatch, /) -> float: + + match_weights = self.get_match_weights() + + if match in match_weights: + return match_weights[match] + + return 1.0 diff --git a/src/officialeye/_internal/template/internal_template.py b/src/officialeye/_internal/template/internal_template.py new file mode 100644 index 0000000..e0c92e0 --- /dev/null +++ b/src/officialeye/_internal/template/internal_template.py @@ -0,0 +1,375 @@ +from __future__ import annotations + +import os +import random +from typing import TYPE_CHECKING, Dict, Iterable, List + +import numpy as np + +# noinspection PyProtectedMember +from officialeye._api.future import Future + +# noinspection PyProtectedMember +from officialeye._api.image import IImage + +# noinspection PyProtectedMember +from officialeye._api.template.match import IMatch + +# noinspection PyProtectedMember +from officialeye._api.template.supervisor import ISupervisor + +# noinspection PyProtectedMember +from officialeye._api.template.template import ITemplate +from officialeye._internal.context.singleton import get_internal_afi, get_internal_context + +# noinspection PyProtectedMember +from officialeye._internal.feedback.verbosity import Verbosity +from officialeye._internal.template.feature_class.loader import load_template_feature_classes +from officialeye._internal.template.feature_class.manager import FeatureClassManager +from officialeye._internal.template.image import InternalImage +from officialeye._internal.template.internal_feature import InternalFeature +from officialeye._internal.template.internal_matching_result import InternalMatchingResult +from officialeye._internal.template.internal_supervision_result import InternalSupervisionResult +from officialeye._internal.template.keypoint import InternalKeypoint +from officialeye._internal.template.utils import load_mutator_from_dict +from officialeye._internal.timer import Timer +from officialeye.error.errors.general import ErrInvalidIdentifier, ErrOperationNotSupported +from officialeye.error.errors.supervision import ErrSupervisionCorrespondenceNotFound +from officialeye.error.errors.template import ErrTemplateInvalidFeature, ErrTemplateInvalidKeypoint + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.mutator import IMutator + + # noinspection PyProtectedMember + from officialeye._api.template.matcher import IMatcher + + # noinspection PyProtectedMember + from officialeye._api.template.supervision_result import ISupervisionResult + from officialeye.types import ConfigDict + + +# TODO: refactor this into an enum +_SUPERVISION_RESULT_FIRST = "first" +_SUPERVISION_RESULT_RANDOM = "random" +_SUPERVISION_RESULT_BEST_MSE = "best_mse" +_SUPERVISION_RESULT_BEST_SCORE = "best_score" + + +class InternalTemplate(ITemplate): + + def __init__(self, yaml_dict: Dict[str, any], path_to_template: str, /): + super().__init__() + + self._path_to_template = path_to_template + + self._template_id = yaml_dict["id"] + self._name = yaml_dict["name"] + self._source = yaml_dict["source"] + + self._height, self._width, _ = self.get_image().load().shape + + self._source_mutators: List[IMutator] = [ + load_mutator_from_dict(mutator_dict) for mutator_dict in yaml_dict["mutators"]["source"] + ] + + self._target_mutators: List[IMutator] = [ + load_mutator_from_dict(mutator_dict) for mutator_dict in yaml_dict["mutators"]["target"] + ] + + self._keypoints: Dict[str, InternalKeypoint] = {} + self._features: Dict[str, InternalFeature] = {} + + for keypoint_id in yaml_dict["keypoints"]: + keypoint_dict = yaml_dict["keypoints"][keypoint_id] + keypoint_dict["id"] = keypoint_id + keypoint = InternalKeypoint(self.identifier, keypoint_dict) + + if keypoint.identifier in self._keypoints: + raise ErrTemplateInvalidKeypoint( + f"while initializing keypoint '{keypoint_id}' of template '{self.identifier}'", + f"There is already a keypoint with the same identifier '{keypoint.identifier}'." + ) + + if keypoint.identifier in self._features: + raise ErrTemplateInvalidKeypoint( + f"while initializing keypoint '{keypoint_id}' of template '{self.identifier}'", + f"There is already a feature with the same identifier '{keypoint.identifier}'." + ) + + self._keypoints[keypoint.identifier] = keypoint + + self._matching = yaml_dict["matching"] + self._supervision = yaml_dict["supervision"] + + # load feature classes + self._feature_class_manager = load_template_feature_classes(yaml_dict["feature_classes"], self.identifier) + + # load features + for feature_id in yaml_dict["features"]: + feature_dict = yaml_dict["features"][feature_id] + feature_dict["id"] = feature_id + feature = InternalFeature(self.identifier, feature_dict) + + if feature.identifier in self._keypoints: + raise ErrTemplateInvalidFeature( + f"while initializing feature '{feature_id}' of template '{self.identifier}'", + f"There is already a keypoint with the same identifier '{feature.identifier}'." + ) + + if feature.identifier in self._features: + raise ErrTemplateInvalidFeature( + f"while initializing feature '{feature_id}' of template '{self.identifier}'", + f"There is already a feature with the same identifier '{feature.identifier}'." + ) + + self._features[feature.identifier] = feature + + get_internal_context().add_template(self) + + def get_source_mutators(self) -> Iterable[IMutator]: + return self._source_mutators + + def get_target_mutators(self) -> Iterable[IMutator]: + return self._target_mutators + + def load(self) -> None: + raise ErrOperationNotSupported( + "while accessing an internal template instance.", + "The way in which it was accessed is not supported." + ) + + def detect_async(self, /, *, target: IImage) -> Future: + raise ErrOperationNotSupported( + "while accessing an internal template instance.", + "The way in which it was accessed is not supported." + ) + + def detect(self, /, **kwargs) -> ISupervisionResult: + raise ErrOperationNotSupported( + "while accessing an internal template instance.", + "The way in which it was accessed is not supported." + ) + + def get_source_image_path(self) -> str: + if os.path.isabs(self._source): + return self._source + path_to_template_dir = os.path.dirname(self._path_to_template) + path = os.path.join(path_to_template_dir, self._source) + return os.path.normpath(path) + + def get_image(self) -> IImage: + return InternalImage(path=self.get_source_image_path()) + + def get_mutated_image(self) -> IImage: + img = self.get_image() + img.apply_mutators(*self._source_mutators) + return img + + @property + def identifier(self) -> str: + return self._template_id + + @property + def name(self) -> str: + return self._name + + @property + def width(self) -> int: + return self._width + + @property + def height(self) -> int: + return self._height + + @property + def keypoints(self) -> Iterable[InternalKeypoint]: + for keypoint_id in self._keypoints: + yield self._keypoints[keypoint_id] + + @property + def features(self) -> Iterable[InternalFeature]: + for feature_id in self._features: + yield self._features[feature_id] + + def validate(self): + for feature in self.features: + feature.validate_feature_class() + + def get_matcher(self, /) -> IMatcher: + matcher_id = self._matching["engine"] + matcher_config = self._matching["config"] + return get_internal_context().get_matcher(matcher_id, matcher_config) + + def get_supervisor(self, /) -> ISupervisor: + supervisor_id = self._supervision["engine"] + supervisor_config_generic = self._supervision["config"] + + if supervisor_id in supervisor_config_generic: + supervisor_config: ConfigDict = supervisor_config_generic[supervisor_id] + else: + get_internal_afi().warn(Verbosity.INFO, f"Could not find any configuration entries for the '{supervisor_id}' supervisor.") + supervisor_config: ConfigDict = {} + + return get_internal_context().get_supervisor(supervisor_id, supervisor_config) + + def get_supervision_config(self) -> dict: + return self._supervision["config"] + + def get_feature_classes(self) -> FeatureClassManager: + return self._feature_class_manager + + def get_feature(self, feature_id: str, /) -> InternalFeature | None: + + if feature_id not in self._features: + return None + + return self._features[feature_id] + + def get_keypoint(self, keypoint_id: str, /) -> InternalKeypoint | None: + + if keypoint_id not in self._keypoints: + return None + + return self._keypoints[keypoint_id] + + def get_path(self) -> str: + return self._path_to_template + + def _run_supervisor(self, keypoint_matching_result: InternalMatchingResult, /) -> InternalSupervisionResult | None: + + supervisor = self.get_supervisor() + supervisor.setup(self, keypoint_matching_result) + + supervision_result_choice_engine = self._supervision["result"] + results: List[InternalSupervisionResult] = [ + InternalSupervisionResult(supervision_result, self, keypoint_matching_result) + for supervision_result in supervisor.supervise(self, keypoint_matching_result) + ] + + if len(results) == 0: + return None + + for result in results: + get_internal_afi().info( + Verbosity.INFO_VERBOSE, + f"Got result with score {result.score} and error {result.get_weighted_mse()} from supervisor '{supervisor}'." + ) + + if supervision_result_choice_engine == _SUPERVISION_RESULT_FIRST: + return results[0] + + if supervision_result_choice_engine == _SUPERVISION_RESULT_RANDOM: + return random.choice(results) + + if supervision_result_choice_engine == _SUPERVISION_RESULT_BEST_MSE: + + best_result = results[0] + best_result_mse = best_result.get_weighted_mse() + + for result_id, result in enumerate(results): + result_mse = result.get_weighted_mse() + + get_internal_afi().info(Verbosity.INFO_VERBOSE, f"Result #{result_id + 1} has MSE {result_mse}.") + + if result_mse < best_result_mse: + best_result_mse = result_mse + best_result = result + + get_internal_afi().info(Verbosity.INFO_VERBOSE, f"Best result has MSE {best_result_mse}.") + + return best_result + + if supervision_result_choice_engine == _SUPERVISION_RESULT_BEST_SCORE: + + best_result = results[0] + best_result_mse = best_result.get_weighted_mse() + best_result_score = best_result.score + + for result_id, result in enumerate(results): + result_score = result.score + result_mse = result.get_weighted_mse() + + get_internal_afi().info(Verbosity.INFO_VERBOSE, f"Result #{result_id + 1} has score {result_score}.") + + if result_score > best_result_score or result_score == best_result_score and result_mse < best_result_mse: + best_result_mse = result_mse + best_result_score = result_score + best_result = result + + get_internal_afi().info(Verbosity.INFO_VERBOSE, f"Best result has score {best_result_score} and MSE {best_result_mse}.") + + return best_result + + raise ErrInvalidIdentifier( + "while running supervisor.", + f"Invalid supervision result choice engine '{supervision_result_choice_engine}'." + ) + + def do_detect(self, target: np.ndarray, /) -> InternalSupervisionResult: + # find all patterns in the target image + + # apply mutators to the target image + for mutator in self._target_mutators: + target = mutator.mutate(target) + + get_internal_afi().update_status("Running matching phase...") + + _timer = Timer() + + with _timer: + # start matching + matcher: IMatcher = self.get_matcher() + matcher.setup(self) + + for keypoint in self.keypoints: + get_internal_afi().info(Verbosity.DEBUG, f"Running matcher '{matcher}' for keypoint '{keypoint.identifier}'.") + assert isinstance(keypoint, InternalKeypoint) + matcher.match(keypoint) + + keypoint_matching_result = InternalMatchingResult(self) + + for keypoint in self.keypoints: + for match in matcher.get_matches_for_keypoint(keypoint): + assert isinstance(match, IMatch) + keypoint_matching_result.add_match(match) + + keypoint_matching_result.validate() + assert keypoint_matching_result.get_total_match_count() > 0 + + get_internal_afi().info( + Verbosity.INFO, + f"Matching succeeded in {_timer.get_real_time():.2f} seconds of real time " + f"and {_timer.get_cpu_time():.2f} seconds of CPU time." + ) + + get_internal_afi().update_status("Running supervision phase...") + + with _timer: + # run supervision to obtain correspondence between template and target regions + supervision_result = self._run_supervisor(keypoint_matching_result) + + get_internal_afi().info( + Verbosity.INFO, + f"Supervision succeeded in {_timer.get_real_time():.2f} seconds of real time " + f"and {_timer.get_cpu_time():.2f} seconds of CPU time." + ) + + if supervision_result is None: + raise ErrSupervisionCorrespondenceNotFound( + "while processing a supervision result.", + f"Could not establish correspondence of the image with the '{self.identifier}' template." + ) + + # TODO: visualizations + """ + if get_internal_context().visualization_generation_enabled(): + supervision_result_visualizer = SupervisionResultVisualizer(supervision_result, target) + visualization = supervision_result_visualizer.render() + get_internal_context().export_image(visualization, file_name="matches.png") + """ + + return supervision_result + + def __str__(self): + return f"{self.name} ({self._source}, {len(self._keypoints)} keypoints, {len(self._features)} features)" diff --git a/src/officialeye/_internal/template/keypoint.py b/src/officialeye/_internal/template/keypoint.py new file mode 100644 index 0000000..2942366 --- /dev/null +++ b/src/officialeye/_internal/template/keypoint.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +# noinspection PyProtectedMember +from officialeye._api.template.keypoint import IKeypoint +from officialeye._internal.api_implementation import IApiInterfaceImplementation +from officialeye._internal.template.region import ExternalRegion, InternalRegion +from officialeye.error.errors.template import ErrTemplateInvalidKeypoint + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.context import Context + from officialeye._internal.template.external_template import ExternalTemplate + + +class InternalKeypoint(InternalRegion, IKeypoint): + + def __init__(self, template_id: str, keypoint_dict: dict, /): + super().__init__(template_id, keypoint_dict) + + self._matches_min = keypoint_dict["matches"]["min"] + self._matches_max = keypoint_dict["matches"]["max"] + + if self._matches_max < self._matches_min: + raise ErrTemplateInvalidKeypoint( + f"while loading template keypoint '{self.identifier}'", + f"the lower bound on the match count ({self._matches_min}) exceeds the upper bound ({self._matches_max})" + ) + + if self._matches_min < 0: + raise ErrTemplateInvalidKeypoint( + f"while loading template keypoint '{self.identifier}'", + f"the lower bound on the match count ({self._matches_min}) cannot be negative" + ) + + assert 0 <= self._matches_min <= self._matches_max + + @property + def matches_min(self) -> int: + return self._matches_min + + @property + def matches_max(self) -> int: + return self._matches_max + + +class ExternalKeypoint(ExternalRegion, IKeypoint, IApiInterfaceImplementation): + + def __init__(self, internal_keypoint: InternalKeypoint, external_template: ExternalTemplate, /): + super().__init__(internal_keypoint, external_template) + + self._matches_min = internal_keypoint.matches_min + self._matches_max = internal_keypoint.matches_max + + @property + def matches_min(self) -> int: + return self._matches_min + + @property + def matches_max(self) -> int: + return self._matches_max + + def set_api_context(self, context: Context, /) -> None: + # no methods of this class require any contextual information to work, nothing to do + pass + + def clear_api_context(self) -> None: + pass diff --git a/src/officialeye/_internal/template/region.py b/src/officialeye/_internal/template/region.py new file mode 100644 index 0000000..6c4b4f0 --- /dev/null +++ b/src/officialeye/_internal/template/region.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Dict + +# noinspection PyProtectedMember +from officialeye._api.template.region import IRegion +from officialeye._internal.context.singleton import get_internal_context + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.template.template_interface import ITemplate + from officialeye._internal.template.external_template import ExternalTemplate + from officialeye._internal.template.internal_template import InternalTemplate + + +class SharedRegion(IRegion, ABC): + + def __init__(self, /, *, identifier: str, x: int, y: int, w: int, h: int): + self._identifier = identifier + self._x = x + self._y = y + self._w = w + self._h = h + + @property + def identifier(self) -> str: + return self._identifier + + @property + def x(self) -> int: + return self._x + + @property + def y(self) -> int: + return self._y + + @property + def w(self) -> int: + return self._w + + @property + def h(self) -> int: + return self._h + + +class InternalRegion(SharedRegion, ABC): + + def __init__(self, template_id: str, region_dict: Dict[str, any], /): + super().__init__( + identifier=str(region_dict["id"]), + x=int(region_dict["x"]), + y=int(region_dict["y"]), + w=int(region_dict["w"]), + h=int(region_dict["h"]) + ) + + self._template_id = template_id + + @property + def template(self) -> InternalTemplate: + return get_internal_context().get_template(self._template_id) + + +class ExternalRegion(SharedRegion, ABC): + + def __init__(self, internal_region: InternalRegion, external_template: ExternalTemplate, /): + super().__init__( + identifier=internal_region.identifier, + x=internal_region.x, + y=internal_region.y, + w=internal_region.w, + h=internal_region.h + ) + + self._external_template = external_template + + @property + def template(self) -> ITemplate: + return self._external_template diff --git a/src/officialeye/_internal/template/region/__init__.py b/src/officialeye/_internal/template/region/__init__.py deleted file mode 100644 index a47c1f4..0000000 --- a/src/officialeye/_internal/template/region/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -This module contains the TemplateRegion class and its descendants. -""" \ No newline at end of file diff --git a/src/officialeye/_internal/template/region/keypoint.py b/src/officialeye/_internal/template/region/keypoint.py deleted file mode 100644 index 7645eff..0000000 --- a/src/officialeye/_internal/template/region/keypoint.py +++ /dev/null @@ -1,38 +0,0 @@ -import cv2 - -from officialeye._internal.context.context import Context -from officialeye._internal.error.errors.template import ErrTemplateInvalidKeypoint -from officialeye._internal.template.region.region import TemplateRegion - -_KEYPOINT_RECT_COLOR = (0, 0, 0xff) - - -class TemplateKeypoint(TemplateRegion): - def __init__(self, context: Context, template_id: str, keypoint_dict: dict, /): - super().__init__(context, template_id, keypoint_dict) - - self._matches_min = keypoint_dict["matches"]["min"] - self._matches_max = keypoint_dict["matches"]["max"] - - if self._matches_max < self._matches_min: - raise ErrTemplateInvalidKeypoint( - f"while loading template keypoint '{self.region_id}'", - f"the lower bound on the match count ({self._matches_min}) exceeds the upper bound ({self._matches_max})" - ) - - if self._matches_min < 0: - raise ErrTemplateInvalidKeypoint( - f"while loading template keypoint '{self.region_id}'", - f"the lower bound on the match count ({self._matches_min}) cannot be negative" - ) - - assert 0 <= self._matches_min <= self._matches_max - - def get_matches_min(self) -> int: - return self._matches_min - - def get_matches_max(self) -> int: - return self._matches_max - - def visualize(self, img: cv2.Mat, /): - return super()._visualize(img, rect_color=_KEYPOINT_RECT_COLOR) diff --git a/src/officialeye/_internal/template/region/region.py b/src/officialeye/_internal/template/region/region.py deleted file mode 100644 index a4c7b7a..0000000 --- a/src/officialeye/_internal/template/region/region.py +++ /dev/null @@ -1,76 +0,0 @@ -import abc -from typing import Dict, Tuple - -import cv2 -import numpy as np - -from officialeye._internal.context.context import Context - -_LABEL_COLOR_DEFAULT = (0, 0, 0xff) -_VISUALIZATION_SCALE_COEFF = 1.0 / 1400.0 - - -class TemplateRegion(abc.ABC): - - def __init__(self, context: Context, template_id: str, region_dict: Dict[str, any], /): - self._context = context - self._template_id = template_id - - self.region_id = str(region_dict["id"]) - self.x = int(region_dict["x"]) - self.y = int(region_dict["y"]) - self.w = int(region_dict["w"]) - self.h = int(region_dict["h"]) - - def get_template(self): - return self._context.get_template(self._template_id) - - @abc.abstractmethod - def visualize(self, img: cv2.Mat) -> cv2.Mat: - raise NotImplementedError() - - def get_top_left_vec(self) -> np.ndarray: - return np.array([self.x, self.y]) - - def get_top_right_vec(self) -> np.ndarray: - return np.array([self.x + self.w, self.y]) - - def get_bottom_left_vec(self) -> np.ndarray: - return np.array([self.x, self.y + self.h]) - - def get_bottom_right_vec(self) -> np.ndarray: - return np.array([self.x + self.w, self.y + self.h]) - - def _visualize(self, img: cv2.Mat, /, *, - rect_color: Tuple[int, int, int], label_color=_LABEL_COLOR_DEFAULT) -> cv2.Mat: - img = cv2.rectangle(img, (self.x, self.y), (self.x + self.w, self.y + self.h), rect_color, 4) - label_origin = ( - self.x + int(10 * img.shape[0] * _VISUALIZATION_SCALE_COEFF), - self.y + int(30 * img.shape[0] * _VISUALIZATION_SCALE_COEFF) - ) - font_scale = img.shape[0] * _VISUALIZATION_SCALE_COEFF - img = cv2.putText( - img, - self.region_id, - label_origin, - cv2.FONT_HERSHEY_SIMPLEX, - font_scale, - label_color, - int(2 * img.shape[0] * _VISUALIZATION_SCALE_COEFF), - cv2.LINE_AA - ) - return img - - def to_image(self): - img = self.get_template().load_source_image() - return img[self.y:self.y + self.h, self.x:self.x + self.w] - - def insert_into_image(self, target: np.ndarray, transformed_version: np.ndarray = None): - - assert target.shape[0] == self.get_template().height - assert target.shape[1] == self.get_template().width - - if transformed_version is None: - transformed_version = self.to_image() - - target[self.y: self.y + self.h, self.x: self.x + self.w] = transformed_version diff --git a/src/officialeye/_internal/template/schema/loader.py b/src/officialeye/_internal/template/schema/loader.py index ccc2470..30b1610 100644 --- a/src/officialeye/_internal/template/schema/loader.py +++ b/src/officialeye/_internal/template/schema/loader.py @@ -1,46 +1,50 @@ import strictyaml as yml -from officialeye._internal.context.context import Context -from officialeye._internal.error.errors.template import ErrTemplateInvalidSyntax -from officialeye._internal.logger.singleton import get_logger +from officialeye._internal.context.singleton import get_internal_afi, get_internal_context + +# noinspection PyProtectedMember +from officialeye._internal.feedback.verbosity import Verbosity +from officialeye._internal.template.internal_template import InternalTemplate from officialeye._internal.template.schema.schema import generate_template_schema -from officialeye._internal.template.template import Template +from officialeye.error.errors.template import ErrTemplateInvalidSyntax _oe_template_schema = generate_template_schema() -def _print_error_message(err: yml.StrictYAMLError, template_path: str): +def _strict_yaml_error_to_syntax_error(error: yml.YAMLError, /, *, path: str) -> ErrTemplateInvalidSyntax: + + return ErrTemplateInvalidSyntax( + f"while loading template configuration file at '{path}'.", + "Could not parse the configuration file due to invalid syntax or encoding.", + str(error).replace("", path) + ) + + +def _do_load_template(path: str, /) -> InternalTemplate: + global _oe_template_schema + + with open(path, "r") as fh: + raw_data = fh.read() - get_logger().error("Error ", bold=True, nl=False) + try: + yaml_document = yml.load(raw_data, schema=_oe_template_schema) + except yml.YAMLError as err: + raise _strict_yaml_error_to_syntax_error(err, path=path) from err - if err.context is not None: - get_logger().error(err.context, prefix=False) - else: - get_logger().error("while parsing", prefix=False) + data = yaml_document.data - if err.context_mark is not None and ( - err.problem is None - or err.problem_mark is None - or err.context_mark.name != err.problem_mark.name - or err.context_mark.line != err.problem_mark.line - or err.context_mark.column != err.problem_mark.column - ): - get_logger().error(str(err.context_mark).replace("", template_path)) + template = InternalTemplate(data, path) - if err.problem is not None: - get_logger().error("Problem", bold=True, nl=False) - get_logger().error(f": {err.problem}", prefix=False) + get_internal_afi().info(Verbosity.DEBUG, f"Loaded template: [b]{template}[/]") - if err.problem_mark is not None: - get_logger().error(str(err.problem_mark).replace("", template_path)) + return template -def load_template(context: Context, path: str) -> Template: +def load_template(path: str, /) -> InternalTemplate: """ Loads a template from a file located at the specified path. Arguments: - context: The global officialeye context. path: The path to the YAML template configuration file. Returns: @@ -50,27 +54,12 @@ def load_template(context: Context, path: str) -> Template: OEError: In case there has been an error validating the correctness of the template. """ - global _oe_template_schema - - with open(path, "r") as fh: - raw_data = fh.read() - - try: - yaml_document = yml.load(raw_data, schema=_oe_template_schema) - except yml.StrictYAMLError as err: - _print_error_message(err, path) - exit(4) - except yml.YAMLError as err: - raise ErrTemplateInvalidSyntax( - f"while loading template configuration file at '{path}'.", - "General parsing error. Check the syntax and the encoding of the file." - ) from err + template = get_internal_context().get_template_by_path(path) - data = yaml_document.data - - template = Template(context, data, path) + if template is not None: + get_internal_afi().info(Verbosity.DEBUG, f"Template at path '{path}' has already been loaded and cached, reusing it!") + return template - get_logger().info("Loaded template: ", nl=False) - get_logger().info(str(template), prefix=False, bold=True) + get_internal_afi().info(Verbosity.DEBUG, f"Template at path '{path}' has not yet been loaded, loading it.") - return template + return _do_load_template(path) diff --git a/src/officialeye/_internal/template/shared_matching_result.py b/src/officialeye/_internal/template/shared_matching_result.py new file mode 100644 index 0000000..28ddc47 --- /dev/null +++ b/src/officialeye/_internal/template/shared_matching_result.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Dict, Iterable, List + +# noinspection PyProtectedMember +from officialeye._api.template.match import IMatch + +# noinspection PyProtectedMember +from officialeye._api.template.matching_result import IMatchingResult +from officialeye._internal.context.singleton import get_internal_afi +from officialeye._internal.feedback.verbosity import Verbosity +from officialeye.error.errors.matching import ErrMatchingMatchCountOutOfBounds + +if TYPE_CHECKING: + # noinspection PyProtectedMember + from officialeye._api.template.template_interface import ITemplate + + +class SharedMatchingResult(IMatchingResult, ABC): + """ + This class contains all the logic of the matching result instance, irrespective of whether we use the internal or the external representation. + The parent process uses the internal representation, whereas the external representation is only for the child process. + This class represents the aspects that both representations have in common. + Therefore, it is important that this class operates only on the interface level and is picklable. + """ + + def __init__(self, template: ITemplate, /): + # keys: keypoint ids + # values: matches with this keypoint + self._matches_dict: Dict[str, List[IMatch]] = {} + + for keypoint in template.keypoints: + self._matches_dict[keypoint.identifier] = [] + + def remove_all_matches(self): + self._matches_dict = {} + + def add_match(self, match: IMatch, /): + assert match.keypoint.identifier in self._matches_dict + self._matches_dict[match.keypoint.identifier].append(match) + + def get_all_matches(self) -> Iterable[IMatch]: + for keypoint_id in self._matches_dict: + for match in self._matches_dict[keypoint_id]: + yield match + + def get_total_match_count(self) -> int: + match_count = 0 + for keypoint_id in self._matches_dict: + match_count += len(self._matches_dict[keypoint_id]) + return match_count + + def get_keypoint_ids(self) -> Iterable[str]: + for keypoint_id in self._matches_dict: + yield keypoint_id + + def get_matches_for_keypoint(self, keypoint_id: str, /) -> Iterable[IMatch]: + for match in self._matches_dict[keypoint_id]: + yield match + + def validate(self): + + get_internal_afi().info(Verbosity.DEBUG, "Validating the keypoint matching result.") + + assert len(self._matches_dict) > 0 + + total_match_count = 0 + + # verify that for every keypoint, it has been matched a number of times that is in the desired bounds + for keypoint_id in self._matches_dict: + keypoint = self.template.get_keypoint(keypoint_id) + + keypoint_matches_min = keypoint.matches_min + keypoint_matches_max = keypoint.matches_max + + keypoint_matches_count = len(self._matches_dict[keypoint_id]) + + if keypoint_matches_count < keypoint_matches_min: + raise ErrMatchingMatchCountOutOfBounds( + f"while checking that keypoint '{keypoint_id}' of template '{self.template.identifier}' " + f"has been matched a sufficient number of times", + f"Expected at least {keypoint_matches_min} matches, got {keypoint_matches_count}" + ) + + if keypoint_matches_count > keypoint_matches_max: + + get_internal_afi().info( + Verbosity.INFO_VERBOSE, + f"Keypoint '{keypoint_id}' of template '{self.template.identifier}' has too many matches " + f"(matches: {keypoint_matches_count} max: {keypoint_matches_max}). Cherry-picking the best matches.") + # cherry-pick the best matches + self._matches_dict[keypoint_id] = sorted(self._matches_dict[keypoint_id])[:keypoint_matches_max] + keypoint_matches_count = keypoint_matches_max + + get_internal_afi().info( + Verbosity.INFO_VERBOSE, + f"Keypoint '{keypoint_id}' of template '{self.template.identifier}' has been matched {keypoint_matches_count} times " + f"(min: {keypoint_matches_min} max: {keypoint_matches_max})." + ) + + total_match_count += keypoint_matches_count + + assert total_match_count >= 0 + if total_match_count == 0: + raise ErrMatchingMatchCountOutOfBounds( + f"while checking that there has been at least one match for template '{self.template.identifier}'.", + "There have been no matches." + ) diff --git a/src/officialeye/_internal/template/template.py b/src/officialeye/_internal/template/template.py deleted file mode 100644 index 57403c8..0000000 --- a/src/officialeye/_internal/template/template.py +++ /dev/null @@ -1,272 +0,0 @@ -import os -import time -from typing import Dict, Generator, List, Union - -import cv2 - -from officialeye._internal.context.context import Context -from officialeye._internal.error.errors.io import ErrIOInvalidPath -from officialeye._internal.error.errors.template import ( - ErrTemplateInvalidFeature, - ErrTemplateInvalidKeypoint, - ErrTemplateInvalidMatchingEngine, - ErrTemplateInvalidSupervisionEngine, -) -from officialeye._internal.logger.singleton import get_logger -from officialeye._internal.matching.matcher import Matcher -from officialeye._internal.matching.matchers.sift_flann import SiftFlannMatcher -from officialeye._internal.matching.result import MatchingResult -from officialeye._internal.mutation.loader import load_mutator_from_dict -from officialeye._internal.mutation.mutator import Mutator -from officialeye._internal.supervision.result import SupervisionResult -from officialeye._internal.supervision.supervisors.combinatorial import CombinatorialSupervisor -from officialeye._internal.supervision.supervisors.least_squares_regression import LeastSquaresRegressionSupervisor -from officialeye._internal.supervision.supervisors.orthogonal_regression import OrthogonalRegressionSupervisor -from officialeye._internal.supervision.visualizer import SupervisionResultVisualizer -from officialeye._internal.template.feature_class.loader import load_template_feature_classes -from officialeye._internal.template.feature_class.manager import FeatureClassManager -from officialeye._internal.template.region.feature import TemplateFeature -from officialeye._internal.template.region.keypoint import TemplateKeypoint - - -class Template: - - def __init__(self, context: Context, yaml_dict: Dict[str, any], path_to_template: str, /): - - self._context = context - - self._path_to_template = path_to_template - - self.template_id = yaml_dict["id"] - self.name = yaml_dict["name"] - self._source = yaml_dict["source"] - - self.height, self.width, _ = self.load_source_image().shape - - self._source_mutators: List[Mutator] = [ - load_mutator_from_dict(mutator_dict) for mutator_dict in yaml_dict["mutators"]["source"] - ] - - self._target_mutators: List[Mutator] = [ - load_mutator_from_dict(mutator_dict) for mutator_dict in yaml_dict["mutators"]["target"] - ] - - self._keypoints: Dict[str, TemplateKeypoint] = {} - self._features: Dict[str, TemplateFeature] = {} - - for keypoint_id in yaml_dict["keypoints"]: - keypoint_dict = yaml_dict["keypoints"][keypoint_id] - keypoint_dict["id"] = keypoint_id - keypoint = TemplateKeypoint(self._context, self.template_id, keypoint_dict) - - if keypoint.region_id in self._keypoints: - raise ErrTemplateInvalidKeypoint( - f"while initializing keypoint '{keypoint_id}' of template '{self.template_id}'", - f"There is already a keypoint with the same identifier '{keypoint.region_id}'." - ) - - if keypoint.region_id in self._features: - raise ErrTemplateInvalidKeypoint( - f"while initializing keypoint '{keypoint_id}' of template '{self.template_id}'", - f"There is already a feature with the same identifier '{keypoint.region_id}'." - ) - - self._keypoints[keypoint.region_id] = keypoint - - self._matching = yaml_dict["matching"] - self._supervision = yaml_dict["supervision"] - - # load feature classes - self._feature_class_manager = load_template_feature_classes(self._context, yaml_dict["feature_classes"], self.template_id) - - # load features - for feature_id in yaml_dict["features"]: - feature_dict = yaml_dict["features"][feature_id] - feature_dict["id"] = feature_id - feature = TemplateFeature(self._context, self.template_id, feature_dict) - - if feature.region_id in self._keypoints: - raise ErrTemplateInvalidFeature( - f"while initializing feature '{feature_id}' of template '{self.template_id}'", - f"There is already a keypoint with the same identifier '{feature.region_id}'." - ) - - if feature.region_id in self._features: - raise ErrTemplateInvalidFeature( - f"while initializing feature '{feature_id}' of template '{self.template_id}'", - f"There is already a feature with the same identifier '{feature.region_id}'." - ) - - self._features[feature.region_id] = feature - - self._context.add_template(self) - - def validate(self): - for feature in self.features(): - feature.validate_feature_class() - - def get_matching_engine(self) -> str: - return self._matching["engine"] - - def get_supervision_engine(self) -> str: - return self._supervision["engine"] - - def get_supervision_result(self) -> str: - return self._supervision["result"] - - def get_supervision_config(self) -> dict: - return self._supervision["config"] - - def get_matching_config(self) -> dict: - matching_config = self._matching["config"] - assert isinstance(matching_config, dict) - return matching_config - - def get_feature_classes(self) -> FeatureClassManager: - return self._feature_class_manager - - def _load_keypoint_matcher(self, target_img: cv2.Mat, /) -> Matcher: - - matching_engine = self.get_matching_engine() - - if matching_engine == SiftFlannMatcher.ENGINE_ID: - return SiftFlannMatcher(self._context, self.template_id, target_img) - - raise ErrTemplateInvalidMatchingEngine( - "while loading keypoint matcher", - f"unknown matching engine '{matching_engine}'" - ) - - def features(self) -> Generator[TemplateFeature, None, None]: - for feature_id in self._features: - yield self._features[feature_id] - - def get_feature(self, feature_id: str, /) -> TemplateFeature: - assert feature_id in self._features, "Invalid feature id" - return self._features[feature_id] - - def keypoints(self) -> Generator[TemplateKeypoint, None, None]: - for keypoint_id in self._keypoints: - yield self._keypoints[keypoint_id] - - def get_keypoint(self, keypoint_id: str, /) -> TemplateKeypoint: - assert keypoint_id in self._keypoints, "Invalid keypoint id" - return self._keypoints[keypoint_id] - - def _get_source_image_path(self) -> str: - if os.path.isabs(self._source): - return self._source - path_to_template_dir = os.path.dirname(self._path_to_template) - path = os.path.join(path_to_template_dir, self._source) - return os.path.normpath(path) - - def load_source_image(self) -> cv2.Mat: - - _image_path = self._get_source_image_path() - - if not os.path.isfile(_image_path): - raise ErrIOInvalidPath( - f"while loading template source image of template '{self.template_id}'.", - f"Inferred path '{_image_path}' does not refer to a file." - ) - - if not os.access(_image_path, os.R_OK): - raise ErrIOInvalidPath( - f"while loading template source image of template '{self.template_id}'.", - f"The file at path '{_image_path}' could not be read." - ) - - return cv2.imread(self._get_source_image_path(), cv2.IMREAD_COLOR) - - def _show(self, img: cv2.Mat, /, *, hide_features: bool, hide_keypoints: bool) -> cv2.Mat: - - if not hide_features: - for feature in self.features(): - img = feature.visualize(img) - - if not hide_keypoints: - for keypoint in self.keypoints(): - img = keypoint.visualize(img) - - return img - - def show(self, /, **kwargs) -> cv2.Mat: - - img = self.load_source_image() - - # apply template mutators to the target image - for mutator in self._source_mutators: - get_logger().debug(f"Applying mutator '{mutator.mutator_id}' to the source image of template '{self.template_id}'.") - img = mutator.mutate(img) - - return self._show(img, **kwargs) - - def _load_supervisor(self, kmr: MatchingResult): - superivision_engine = self.get_supervision_engine() - - if superivision_engine == LeastSquaresRegressionSupervisor.ENGINE_ID: - return LeastSquaresRegressionSupervisor(self._context, self.template_id, kmr) - - if superivision_engine == OrthogonalRegressionSupervisor.ENGINE_ID: - return OrthogonalRegressionSupervisor(self._context, self.template_id, kmr) - - if superivision_engine == CombinatorialSupervisor.ENGINE_ID: - return CombinatorialSupervisor(self._context, self.template_id, kmr) - - raise ErrTemplateInvalidSupervisionEngine( - "while loading supervisor", - f"unknown supervision engine '{superivision_engine}'" - ) - - def run_analysis(self, target: cv2.Mat, /) -> Union[SupervisionResult, None]: - # find all patterns in the target image - - _analysis_start_time = time.perf_counter(), time.process_time() - - # apply mutators to the target image - for mutator in self._target_mutators: - get_logger().debug(f"Applying mutator '{mutator.mutator_id}' to input image.") - target = mutator.mutate(target) - - # start matching - matcher = self._load_keypoint_matcher(target) - - for keypoint in self.keypoints(): - keypoint_pattern = keypoint.to_image() - get_logger().debug(f"Running matcher for keypoint '{keypoint.region_id}'.") - matcher.match_keypoint(keypoint_pattern, keypoint.region_id) - - keypoint_matching_result = matcher.match_finish() - - if self._context.visualization_generation_enabled(): - keypoint_matching_result.debug_print() - - keypoint_matching_result.validate() - assert keypoint_matching_result.get_total_match_count() > 0 - - _matching_ended_time = time.perf_counter(), time.process_time() - - get_logger().info(f"Matching succeeded in {_matching_ended_time[0] - _analysis_start_time[0]:.2f} seconds of real time " - f"and {_matching_ended_time[1] - _analysis_start_time[1]:.2f} seconds of CPU time.") - - # run supervision to obtain correspondence between template and target regions - supervisor = self._load_supervisor(keypoint_matching_result) - supervision_result = supervisor.run() - - if supervision_result is None: - return None - - if self._context.visualization_generation_enabled(): - supervision_result_visualizer = SupervisionResultVisualizer(self._context, supervision_result, target) - visualization = supervision_result_visualizer.render() - self._context.export_image(visualization, file_name="matches.png") - - _supervision_ended_time = time.perf_counter(), time.process_time() - - get_logger().info(f"Supervision succeeded in {_supervision_ended_time[0] - _matching_ended_time[0]:.2f} seconds of real time " - f"and {_supervision_ended_time[1] - _matching_ended_time[1]:.2f} seconds of CPU time.") - - return supervision_result - - def __str__(self): - return f"{self.name} ({self._source}, {len(self._keypoints)} keypoints, {len(self._features)} features)" diff --git a/src/officialeye/_internal/template/utils.py b/src/officialeye/_internal/template/utils.py new file mode 100644 index 0000000..1626bce --- /dev/null +++ b/src/officialeye/_internal/template/utils.py @@ -0,0 +1,16 @@ +from typing import Dict + +# noinspection PyProtectedMember +from officialeye._api.mutator import IMutator +from officialeye._internal.context.singleton import get_internal_context + + +def load_mutator_from_dict(mutator_dict: Dict[str, any], /) -> IMutator: + + assert "id" in mutator_dict + + mutator_id = mutator_dict["id"] + + mutator_config = mutator_dict["config"] if "config" in mutator_dict else {} + + return get_internal_context().get_mutator(mutator_id, mutator_config) diff --git a/src/officialeye/_internal/timer.py b/src/officialeye/_internal/timer.py new file mode 100644 index 0000000..c46ca70 --- /dev/null +++ b/src/officialeye/_internal/timer.py @@ -0,0 +1,25 @@ +import time +from typing import Tuple + + +class Timer: + + def __init__(self): + self._start_time: Tuple[float, float] | None = None + self._end_time: Tuple[float, float] | None = None + + def __enter__(self): + self._start_time = time.perf_counter(), time.process_time() + + def __exit__(self, exc_type, exc_val, exc_tb): + self._end_time = time.perf_counter(), time.process_time() + + def get_real_time(self) -> float: + assert self._start_time is not None + assert self._end_time is not None + return self._end_time[0] - self._start_time[0] + + def get_cpu_time(self) -> float: + assert self._start_time is not None + assert self._end_time is not None + return self._end_time[1] - self._start_time[1] diff --git a/src/officialeye/api/__init__.py b/src/officialeye/api/__init__.py deleted file mode 100644 index c49a432..0000000 --- a/src/officialeye/api/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -This module provides a compatibility layer connecting stable OfficialEye's public API with the implementation-specific details that tend to change. -""" \ No newline at end of file diff --git a/src/officialeye/api/context.py b/src/officialeye/api/context.py deleted file mode 100644 index 56b9acc..0000000 --- a/src/officialeye/api/context.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Module represeting the OfficialEye context. -""" \ No newline at end of file diff --git a/src/officialeye/detection.py b/src/officialeye/detection.py new file mode 100644 index 0000000..2158236 --- /dev/null +++ b/src/officialeye/detection.py @@ -0,0 +1,5 @@ +""" +Module providing an API for all OfficialEye's document detection tools. +""" + +# noinspection PyProtectedMember diff --git a/src/officialeye/_internal/error/__init__.py b/src/officialeye/error/__init__.py similarity index 100% rename from src/officialeye/_internal/error/__init__.py rename to src/officialeye/error/__init__.py diff --git a/src/officialeye/_internal/error/codes.py b/src/officialeye/error/codes.py similarity index 95% rename from src/officialeye/_internal/error/codes.py rename to src/officialeye/error/codes.py index ee0bbcf..1b953fe 100644 --- a/src/officialeye/_internal/error/codes.py +++ b/src/officialeye/error/codes.py @@ -1,8 +1,9 @@ ERR_INTERNAL = (1, "INTERNAL") +ERR_GENERAL = (2, "GENERAL") # IO errors ERR_IO_INVALID_SUPERVISION_ENGINE = (101, "INVALID_IO_ENGINE") -ERR_IO_OPERATION_NOT_SUPPORTED_BY_DRIVER = (102, "OPERATION_NOT_SUPPORTED_BY_DRIVER") +ERR_IO_OPERATION_NOT_SUPPORTED_BY_DRIVER = (102, "OPERATION_NOT_SUPPORTED_BY_DRIVER") # TODO: deprecated ERR_IO_INVALID_PATH = (103, "INVALID_PATH") ERR_IO_INVALID_IMAGE = (104, "INVALID_IMAGE") diff --git a/src/officialeye/_internal/error/error.py b/src/officialeye/error/error.py similarity index 94% rename from src/officialeye/_internal/error/error.py rename to src/officialeye/error/error.py index f049763..322c412 100644 --- a/src/officialeye/_internal/error/error.py +++ b/src/officialeye/error/error.py @@ -1,7 +1,8 @@ +from abc import ABC from typing import List -class OEError(Exception): +class OEError(Exception, ABC): """ Base class for all officialeye-related errors. """ @@ -40,6 +41,9 @@ def add_external_cause(self, cause: BaseException, /): def get_external_causes(self) -> List[BaseException]: return self._external_causes + def get_details(self) -> str | None: + return None + def serialize(self) -> dict: causes = [ diff --git a/src/officialeye/_internal/error/errors/__init__.py b/src/officialeye/error/errors/__init__.py similarity index 100% rename from src/officialeye/_internal/error/errors/__init__.py rename to src/officialeye/error/errors/__init__.py diff --git a/src/officialeye/error/errors/general.py b/src/officialeye/error/errors/general.py new file mode 100644 index 0000000..8090a11 --- /dev/null +++ b/src/officialeye/error/errors/general.py @@ -0,0 +1,69 @@ +from officialeye.error.codes import ERR_GENERAL +from officialeye.error.error import OEError +from officialeye.error.modules import ERR_MODULE_GENERAL + + +class ErrGeneral(OEError): + + def __init__(self, while_text: str, problem_text: str, /): + super().__init__(ERR_MODULE_GENERAL, ERR_GENERAL[0], ERR_GENERAL[1], while_text, problem_text, is_regular=False) + + self._init_args = while_text, problem_text + + def __reduce__(self): + return self.__class__, self._init_args + + +class ErrOperationNotSupported(ErrGeneral): + + def __init__(self, while_text: str, problem_text: str, /): + super().__init__(while_text, problem_text) + + self._init_args = while_text, problem_text + + def __reduce__(self): + return self.__class__, self._init_args + + +class ErrInvalidKey(ErrGeneral): + + def __init__(self, while_text: str, problem_text: str, /): + super().__init__(while_text, problem_text) + + self._init_args = while_text, problem_text + + def __reduce__(self): + return self.__class__, self._init_args + + +class ErrInvalidIdentifier(ErrGeneral): + + def __init__(self, while_text: str, problem_text: str, /): + super().__init__(while_text, problem_text) + + self._init_args = while_text, problem_text + + def __reduce__(self): + return self.__class__, self._init_args + + +class ErrObjectNotInitialized(ErrGeneral): + + def __init__(self, while_text: str, problem_text: str, /): + super().__init__(while_text, problem_text) + + self._init_args = while_text, problem_text + + def __reduce__(self): + return self.__class__, self._init_args + + +class ErrInvalidImage(ErrGeneral): + + def __init__(self, while_text: str, problem_text: str, /): + super().__init__(while_text, problem_text) + + self._init_args = while_text, problem_text + + def __reduce__(self): + return self.__class__, self._init_args diff --git a/src/officialeye/error/errors/internal.py b/src/officialeye/error/errors/internal.py new file mode 100644 index 0000000..a63fe34 --- /dev/null +++ b/src/officialeye/error/errors/internal.py @@ -0,0 +1,25 @@ +from officialeye.error.codes import ERR_INTERNAL +from officialeye.error.error import OEError +from officialeye.error.modules import ERR_MODULE_INTERNAL + + +class ErrInternal(OEError): + + def __init__(self, while_text: str, problem_text: str, /): + super().__init__(ERR_MODULE_INTERNAL, ERR_INTERNAL[0], ERR_INTERNAL[1], while_text, problem_text, is_regular=False) + + self._init_args = while_text, problem_text + + def __reduce__(self): + return self.__class__, self._init_args + + +class ErrInvalidState(ErrInternal): + + def __init__(self, while_text: str, problem_text: str, /): + super().__init__(while_text, problem_text) + + self._init_args = while_text, problem_text + + def __reduce__(self): + return self.__class__, self._init_args diff --git a/src/officialeye/_internal/error/errors/io.py b/src/officialeye/error/errors/io.py similarity index 64% rename from src/officialeye/_internal/error/errors/io.py rename to src/officialeye/error/errors/io.py index 3fd511e..ee21c4b 100644 --- a/src/officialeye/_internal/error/errors/io.py +++ b/src/officialeye/error/errors/io.py @@ -1,38 +1,65 @@ -from officialeye._internal.error.codes import ( +from abc import ABC + +from officialeye.error.codes import ( ERR_IO_INVALID_IMAGE, ERR_IO_INVALID_PATH, ERR_IO_INVALID_SUPERVISION_ENGINE, ERR_IO_OPERATION_NOT_SUPPORTED_BY_DRIVER, ) -from officialeye._internal.error.error import OEError -from officialeye._internal.error.modules import ERR_MODULE_IO +from officialeye.error.error import OEError +from officialeye.error.modules import ERR_MODULE_IO -class ErrIO(OEError): +class ErrIO(OEError, ABC): def __init__(self, code: int, code_text: str, while_text: str, problem_text: str, /, *, is_regular: bool = False, **kwargs): super().__init__(ERR_MODULE_IO, code, code_text, while_text, problem_text, is_regular=is_regular) class ErrIOInvalidSupervisionEngine(ErrIO): + # TODO: consider removing this error + def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_IO_INVALID_SUPERVISION_ENGINE[0], ERR_IO_INVALID_SUPERVISION_ENGINE[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrIOOperationNotSupportedByDriver(ErrIO): + def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_IO_OPERATION_NOT_SUPPORTED_BY_DRIVER[0], ERR_IO_OPERATION_NOT_SUPPORTED_BY_DRIVER[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrIOInvalidPath(ErrIO): + def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_IO_INVALID_PATH[0], ERR_IO_INVALID_PATH[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrIOInvalidImage(ErrIO): + def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_IO_INVALID_IMAGE[0], ERR_IO_INVALID_IMAGE[1], while_text, problem_text, **kwargs) + + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args diff --git a/src/officialeye/_internal/error/errors/matching.py b/src/officialeye/error/errors/matching.py similarity index 60% rename from src/officialeye/_internal/error/errors/matching.py rename to src/officialeye/error/errors/matching.py index 6316e67..9cd13a7 100644 --- a/src/officialeye/_internal/error/errors/matching.py +++ b/src/officialeye/error/errors/matching.py @@ -1,9 +1,11 @@ -from officialeye._internal.error.codes import ERR_MATCHING_INVALID_ENGINE_CONFIG, ERR_MATCHING_MATCH_COUNT_OUT_OF_BOUNDS -from officialeye._internal.error.error import OEError -from officialeye._internal.error.modules import ERR_MODULE_MATCHING +from abc import ABC +from officialeye.error.codes import ERR_MATCHING_INVALID_ENGINE_CONFIG, ERR_MATCHING_MATCH_COUNT_OUT_OF_BOUNDS +from officialeye.error.error import OEError +from officialeye.error.modules import ERR_MODULE_MATCHING -class ErrMatching(OEError): + +class ErrMatching(OEError, ABC): def __init__(self, code: int, code_text: str, while_text: str, problem_text: str, /, *, is_regular: bool, **kwargs): super().__init__(ERR_MODULE_MATCHING, code, code_text, while_text, problem_text, is_regular=is_regular) @@ -14,8 +16,18 @@ def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_MATCHING_MATCH_COUNT_OUT_OF_BOUNDS[0], ERR_MATCHING_MATCH_COUNT_OUT_OF_BOUNDS[1], while_text, problem_text, is_regular=True, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrMatchingInvalidEngineConfig(ErrMatching): def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_MATCHING_INVALID_ENGINE_CONFIG[0], ERR_MATCHING_INVALID_ENGINE_CONFIG[1], while_text, problem_text, is_regular=False, **kwargs) + + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args diff --git a/src/officialeye/_internal/error/errors/supervision.py b/src/officialeye/error/errors/supervision.py similarity index 62% rename from src/officialeye/_internal/error/errors/supervision.py rename to src/officialeye/error/errors/supervision.py index 236f167..9cbc838 100644 --- a/src/officialeye/_internal/error/errors/supervision.py +++ b/src/officialeye/error/errors/supervision.py @@ -1,9 +1,11 @@ -from officialeye._internal.error.codes import ERR_SUPERVISION_CORRESPONDENCE_NOT_FOUND, ERR_SUPERVISION_INVALID_ENGINE_CONFIG -from officialeye._internal.error.error import OEError -from officialeye._internal.error.modules import ERR_MODULE_SUPERVISION +from abc import ABC +from officialeye.error.codes import ERR_SUPERVISION_CORRESPONDENCE_NOT_FOUND, ERR_SUPERVISION_INVALID_ENGINE_CONFIG +from officialeye.error.error import OEError +from officialeye.error.modules import ERR_MODULE_SUPERVISION -class ErrSupervision(OEError): + +class ErrSupervision(OEError, ABC): def __init__(self, code: int, code_text: str, while_text: str, problem_text: str, /, *, is_regular: bool, **kwargs): super().__init__(ERR_MODULE_SUPERVISION, code, code_text, while_text, problem_text, is_regular=is_regular) @@ -16,6 +18,11 @@ def __init__(self, while_text: str, problem_text: str, /, **kwargs): ERR_SUPERVISION_CORRESPONDENCE_NOT_FOUND[1], while_text, problem_text, is_regular=True, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrSupervisionInvalidEngineConfig(ErrSupervision): def __init__(self, while_text: str, problem_text: str, /, **kwargs): @@ -23,3 +30,8 @@ def __init__(self, while_text: str, problem_text: str, /, **kwargs): ERR_SUPERVISION_INVALID_ENGINE_CONFIG[0], ERR_SUPERVISION_INVALID_ENGINE_CONFIG[1], while_text, problem_text, is_regular=False, **kwargs) + + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args diff --git a/src/officialeye/_internal/error/errors/template.py b/src/officialeye/error/errors/template.py similarity index 65% rename from src/officialeye/_internal/error/errors/template.py rename to src/officialeye/error/errors/template.py index 101ff66..e26f824 100644 --- a/src/officialeye/_internal/error/errors/template.py +++ b/src/officialeye/error/errors/template.py @@ -1,4 +1,6 @@ -from officialeye._internal.error.codes import ( +from abc import ABC + +from officialeye.error.codes import ( ERR_TEMPLATE_ID_NOT_UNIQUE, ERR_TEMPLATE_INVALID_CONCURRENCY_CONFIG, ERR_TEMPLATE_INVALID_FEATURE, @@ -10,11 +12,11 @@ ERR_TEMPLATE_INVALID_SUPERVISION_ENGINE, ERR_TEMPLATE_INVALID_SYNTAX, ) -from officialeye._internal.error.error import OEError -from officialeye._internal.error.modules import ERR_MODULE_TEMPLATE +from officialeye.error.error import OEError +from officialeye.error.modules import ERR_MODULE_TEMPLATE -class ErrTemplate(OEError): +class ErrTemplate(OEError, ABC): def __init__(self, code: int, code_text: str, while_text: str, problem_text: str, /, *, is_regular: bool = False, **kwargs): super().__init__(ERR_MODULE_TEMPLATE, code, code_text, while_text, problem_text, is_regular=is_regular) @@ -25,56 +27,111 @@ def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_TEMPLATE_INVALID_SUPERVISION_ENGINE[0], ERR_TEMPLATE_INVALID_SUPERVISION_ENGINE[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrTemplateInvalidMatchingEngine(ErrTemplate): def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_TEMPLATE_INVALID_MATCHING_ENGINE[0], ERR_TEMPLATE_INVALID_MATCHING_ENGINE[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrTemplateIdNotUnique(ErrTemplate): def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_TEMPLATE_ID_NOT_UNIQUE[0], ERR_TEMPLATE_ID_NOT_UNIQUE[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrTemplateInvalidKeypoint(ErrTemplate): def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_TEMPLATE_INVALID_KEYPOINT[0], ERR_TEMPLATE_INVALID_KEYPOINT[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrTemplateInvalidFeature(ErrTemplate): def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_TEMPLATE_INVALID_FEATURE[0], ERR_TEMPLATE_INVALID_FEATURE[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrTemplateInvalidConcurrencyConfig(ErrTemplate): def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_TEMPLATE_INVALID_CONCURRENCY_CONFIG[0], ERR_TEMPLATE_INVALID_CONCURRENCY_CONFIG[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrTemplateInvalidSyntax(ErrTemplate): - def __init__(self, while_text: str, problem_text: str, /, **kwargs): + def __init__(self, while_text: str, problem_text: str, yml_error: str | None = None, /, **kwargs): super().__init__( ERR_TEMPLATE_INVALID_SYNTAX[0], ERR_TEMPLATE_INVALID_SYNTAX[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, yml_error, *kwargs + + self.yml_error = yml_error + + def get_details(self) -> str | None: + return self.yml_error + + def __reduce__(self): + return self.__class__, self._init_args + class ErrTemplateInvalidFeatureClass(ErrTemplate): def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_TEMPLATE_INVALID_FEATURE_CLASS[0], ERR_TEMPLATE_INVALID_FEATURE_CLASS[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrTemplateInvalidMutator(ErrTemplate): def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_TEMPLATE_INVALID_MUTATOR[0], ERR_TEMPLATE_INVALID_MUTATOR[1], while_text, problem_text, **kwargs) + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args + class ErrTemplateInvalidInterpretation(ErrTemplate): def __init__(self, while_text: str, problem_text: str, /, **kwargs): super().__init__( ERR_TEMPLATE_INVALID_INTERPRETATION[0], ERR_TEMPLATE_INVALID_INTERPRETATION[1], while_text, problem_text, **kwargs) + + self._init_args = while_text, problem_text, *kwargs + + def __reduce__(self): + return self.__class__, self._init_args diff --git a/src/officialeye/_internal/error/modules.py b/src/officialeye/error/modules.py similarity index 70% rename from src/officialeye/_internal/error/modules.py rename to src/officialeye/error/modules.py index 5f6e596..bf551fe 100644 --- a/src/officialeye/_internal/error/modules.py +++ b/src/officialeye/error/modules.py @@ -1,5 +1,8 @@ +ERR_MODULE_GENERAL = "general" ERR_MODULE_INTERNAL = "internal" ERR_MODULE_IO = "io" ERR_MODULE_MATCHING = "matching" ERR_MODULE_SUPERVISION = "supervision" ERR_MODULE_TEMPLATE = "template" + +# TODO: remove the module system diff --git a/src/officialeye/meta.py b/src/officialeye/meta.py deleted file mode 100644 index a93daf0..0000000 --- a/src/officialeye/meta.py +++ /dev/null @@ -1,12 +0,0 @@ -from officialeye.__version__ import __version__ - -OFFICIALEYE_NAME = "OfficialEye" -OFFICIALEYE_GITHUB = "https://github.com/ZeroBone/OfficialEye" -OFFICIALEYE_VERSION = __version__ -OFFICIALEYE_CLI_LOGO = """ ____ _________ _ __ ______ - / __ \\/ __/ __(_)____(_)___ _/ / / ____/_ _____ - / / / / /_/ /_/ / ___/ / __ `/ / / __/ / / / / _ \\ -/ /_/ / __/ __/ / /__/ / /_/ / / / /___/ /_/ / __/ -\\____/_/ /_/ /_/\\___/_/\\__,_/_/ /_____/\\__, /\\___/ - /____/ -""" diff --git a/src/officialeye/types.py b/src/officialeye/types.py new file mode 100644 index 0000000..62c6b16 --- /dev/null +++ b/src/officialeye/types.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Dict, TypeAlias, Union + +# noinspection PyProtectedMember +from officialeye._api.mutator import IMutator + +# noinspection PyProtectedMember +from officialeye._api.template.interpretation import IInterpretation + +# noinspection PyProtectedMember +from officialeye._api.template.matcher import IMatcher + +# noinspection PyProtectedMember +from officialeye._api.template.supervisor import ISupervisor + +if TYPE_CHECKING: + ConfigValue = Union[str, int, float, Dict[str, "ConfigValue"]] + ConfigDict = Dict[str, ConfigValue] + + MutatorFactory = Callable[[ConfigDict], IMutator] + MatcherFactory = Callable[[ConfigDict], IMatcher] + SupervisorFactory = Callable[[ConfigDict], ISupervisor] + InterpretationFactory = Callable[[ConfigDict], IInterpretation] + + FeatureInterpretation: TypeAlias = dict[str, "FeatureInterpretation"] | list["FeatureInterpretation"] | str | int | float | bool | None diff --git a/src/tests/unit/__init__.py b/src/tests/unit/__init__.py index 9718612..ea41014 100644 --- a/src/tests/unit/__init__.py +++ b/src/tests/unit/__init__.py @@ -1,3 +1,3 @@ """ Unit tests. -""" +""" \ No newline at end of file diff --git a/src/tests/unit/api/__init__.py b/src/tests/unit/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/unit/api/test_basic.py b/src/tests/unit/api/test_basic.py new file mode 100644 index 0000000..3a9c76e --- /dev/null +++ b/src/tests/unit/api/test_basic.py @@ -0,0 +1,57 @@ +import pytest +from officialeye import Context, Template +from officialeye.error.errors.internal import ErrInvalidState + + +def test_context_reenter(): + + with Context() as context: + with pytest.raises(ErrInvalidState): + with context as _: + pass + + with Context() as context: + with pytest.raises(ErrInvalidState): + with context as _: + pass + + +def test_illegal_dispose(): + + with pytest.raises(ErrInvalidState): + with Context() as context: + context.dispose() + + +def test_template_load(): + + with Context() as context: + template = Template(context, path="docs/assets/templates/driver_license_ru_01/driver_license_ru.yml") + assert template.identifier == "driver_license_ru" + assert template.name == "Driver License RU" + assert len([k for k in template.keypoints]) == 6 + assert len([f for f in template.features]) == 15 + + with Context() as context: + template = Template(context, path="docs/assets/templates/driver_license_ru_01/driver_license_ru.yml") + assert len([f for f in template.features]) == 15 + assert template.name == "Driver License RU" + + +def test_image_dimensions(): + + with Context() as context: + template = Template(context, path="docs/assets/templates/driver_license_ru_01/driver_license_ru.yml") + img = template.get_image().load() + h, w, _ = img.shape + assert template.width == w + assert template.height == h + + +def test_mutated_image_dimensions(): + with Context() as context: + template = Template(context, path="docs/assets/templates/driver_license_ru_01/driver_license_ru.yml") + img = template.get_mutated_image().load() + h, w, _ = img.shape + assert template.width == w + assert template.height == h diff --git a/src/tests/unit/api/test_detect.py b/src/tests/unit/api/test_detect.py new file mode 100644 index 0000000..661769d --- /dev/null +++ b/src/tests/unit/api/test_detect.py @@ -0,0 +1,15 @@ +from officialeye import Context, IImage, Image, ISupervisionResult, Template +from officialeye.detection import detect + + +def test_detect(): + + with Context() as context: + template = Template(context, path="docs/assets/templates/driver_license_ru_01/driver_license_ru.yml") + + image: IImage = Image(context, path="docs/assets/templates/driver_license_ru_01/examples/01.jpg") + assert isinstance(image, IImage) + + result = detect(context, template, target=image) + + assert isinstance(result, ISupervisionResult)