diff --git a/custom-recipes/google-cloud-vision-content-detection-labeling/recipe.py b/custom-recipes/google-cloud-vision-content-detection-labeling/recipe.py
index 829de14..3bb4832 100644
--- a/custom-recipes/google-cloud-vision-content-detection-labeling/recipe.py
+++ b/custom-recipes/google-cloud-vision-content-detection-labeling/recipe.py
@@ -1,12 +1,8 @@
 # -*- coding: utf-8 -*-
-import json
 from typing import List, Union, Dict, AnyStr
 
 from ratelimit import limits, RateLimitException
 from retry import retry
-from google.protobuf.json_format import MessageToDict
-from google.api_core.exceptions import GoogleAPIError
-
 from plugin_config_loader import load_plugin_config
 from google_vision_api_client import GoogleCloudVisionAPIWrapper
 from dku_io_utils import generate_path_df, set_column_description
@@ -36,27 +32,17 @@
 def call_api_content_detection(
     max_results: int, row: Dict = None, batch: List[Dict] = None
 ) -> Union[List[Dict], AnyStr]:
-    features = [{"type": c, "max_results": max_results} for c in config["content_categories"]]
-    if config["input_folder_is_gcs"]:
-        image_requests = [
-            api_wrapper.batch_api_gcs_image_request(
-                folder_bucket=config["input_folder_bucket"],
-                folder_root_path=config["input_folder_root_path"],
-                path=row.get(PATH_COLUMN),
-                features=features,
-            )
-            for row in batch
-        ]
-        responses = api_wrapper.client.batch_annotate_images(image_requests)
-        return responses
-    else:
-        image_path = row.get(PATH_COLUMN)
-        with config["input_folder"].get_download_stream(image_path) as stream:
-            image_request = {"image": {"content": stream.read()}, "features": features}
-        response_dict = MessageToDict(api_wrapper.client.annotate_image(image_request))
-        if "error" in response_dict.keys():  # Required as annotate_image does not raise exceptions
-            raise GoogleAPIError(response_dict.get("error", {}).get("message", ""))
-        return json.dumps(response_dict)
+    results = api_wrapper.call_api_annotate_image(
+        row=row,
+        batch=batch,
+        path_column=PATH_COLUMN,
+        folder=config.get("input_folder"),
+        folder_is_gcs=config.get("input_folder_is_gcs"),
+        folder_bucket=config.get("input_folder_bucket"),
+        folder_root_path=config.get("input_folder_root_path"),
+        features=[{"type": c, "max_results": max_results} for c in config.get("content_categories", [])],
+    )
+    return results
 
 
 df = api_parallelizer(
diff --git a/custom-recipes/google-cloud-vision-cropping/recipe.py b/custom-recipes/google-cloud-vision-cropping/recipe.py
index 952fa83..b403b99 100644
--- a/custom-recipes/google-cloud-vision-cropping/recipe.py
+++ b/custom-recipes/google-cloud-vision-cropping/recipe.py
@@ -1,12 +1,9 @@
 # -*- coding: utf-8 -*-
-import json
 from typing import List, Union, Dict, AnyStr
 
 from ratelimit import limits, RateLimitException
 from retry import retry
 from google.cloud import vision
-from google.protobuf.json_format import MessageToDict
-from google.api_core.exceptions import GoogleAPIError
 
 from plugin_config_loader import load_plugin_config
 from google_vision_api_client import GoogleCloudVisionAPIWrapper
@@ -35,29 +32,18 @@
 @retry((RateLimitException, OSError), delay=config["api_quota_period"], tries=5)
 @limits(calls=config["api_quota_rate_limit"], period=config["api_quota_period"])
 def call_api_crop_hints(aspect_ratio: float, row: Dict = None, batch: List[Dict] = None) -> Union[List[Dict], AnyStr]:
-    features = [{"type": vision.enums.Feature.Type.CROP_HINTS}]
-    image_context = {"crop_hints_params": {"aspect_ratios": [aspect_ratio]}}
-    if config["input_folder_is_gcs"]:
-        image_requests = [
-            api_wrapper.batch_api_gcs_image_request(
-                folder_bucket=config["input_folder_bucket"],
-                folder_root_path=config["input_folder_root_path"],
-                path=row.get(PATH_COLUMN),
-                features=features,
-                image_context=image_context,
-            )
-            for row in batch
-        ]
-        responses = api_wrapper.client.batch_annotate_images(image_requests)
-        return responses
-    else:
-        image_path = row.get(PATH_COLUMN)
-        with config["input_folder"].get_download_stream(image_path) as stream:
-            image_request = {"image": {"content": stream.read()}, "features": features, "image_context": image_context}
-        response_dict = MessageToDict(api_wrapper.client.annotate_image(image_request))
-        if "error" in response_dict.keys():  # Required as annotate_image does not raise exceptions
-            raise GoogleAPIError(response_dict.get("error", {}).get("message", ""))
-        return json.dumps(response_dict)
+    results = api_wrapper.call_api_annotate_image(
+        row=row,
+        batch=batch,
+        path_column=PATH_COLUMN,
+        folder=config.get("input_folder"),
+        folder_is_gcs=config.get("input_folder_is_gcs"),
+        folder_bucket=config.get("input_folder_bucket"),
+        folder_root_path=config.get("input_folder_root_path"),
+        features=[{"type": vision.enums.Feature.Type.CROP_HINTS}],
+        image_context={"crop_hints_params": {"aspect_ratios": [aspect_ratio]}},
+    )
+    return results
 
 
 df = api_parallelizer(
diff --git a/custom-recipes/google-cloud-vision-document-text-detection/recipe.py b/custom-recipes/google-cloud-vision-document-text-detection/recipe.py
index ba6fffe..d9bd77e 100644
--- a/custom-recipes/google-cloud-vision-document-text-detection/recipe.py
+++ b/custom-recipes/google-cloud-vision-document-text-detection/recipe.py
@@ -18,7 +18,7 @@
 # SETUP
 # ==============================================================================
 
-config = load_plugin_config(divide_quota_with_batch_size=False)  # edge case
+config = load_plugin_config(mandatory_output="folder", divide_quota_with_batch_size=False)  # edge case
 column_prefix = "text_api"
 
 api_wrapper = GoogleCloudVisionAPIWrapper(gcp_service_account_key=config["gcp_service_account_key"])
diff --git a/custom-recipes/google-cloud-vision-image-text-detection/recipe.py b/custom-recipes/google-cloud-vision-image-text-detection/recipe.py
index 40ad2e1..f848d5c 100644
--- a/custom-recipes/google-cloud-vision-image-text-detection/recipe.py
+++ b/custom-recipes/google-cloud-vision-image-text-detection/recipe.py
@@ -1,12 +1,8 @@
 # -*- coding: utf-8 -*-
-import json
 from typing import List, Union, Dict, AnyStr
 
 from ratelimit import limits, RateLimitException
 from retry import retry
-from google.protobuf.json_format import MessageToDict
-from google.api_core.exceptions import GoogleAPIError
-
 from plugin_config_loader import load_plugin_config
 from google_vision_api_client import GoogleCloudVisionAPIWrapper
 from dku_io_utils import generate_path_df, set_column_description
@@ -36,29 +32,18 @@
 def call_api_text_detection(
     language_hints: List[AnyStr], row: Dict = None, batch: List[Dict] = None
 ) -> Union[List[Dict], AnyStr]:
-    features = [{"type": config["ocr_model"]}]
-    image_context = {"language_hints": language_hints}
-    if config["input_folder_is_gcs"]:
-        image_requests = [
-            api_wrapper.batch_api_gcs_image_request(
-                folder_bucket=config["input_folder_bucket"],
-                folder_root_path=config["input_folder_root_path"],
-                path=row.get(PATH_COLUMN),
-                features=features,
-                image_context=image_context,
-            )
-            for row in batch
-        ]
-        responses = api_wrapper.client.batch_annotate_images(image_requests)
-        return responses
-    else:
-        image_path = row.get(PATH_COLUMN)
-        with config["input_folder"].get_download_stream(image_path) as stream:
-            image_request = {"image": {"content": stream.read()}, "features": features, "image_context": image_context}
-        response_dict = MessageToDict(api_wrapper.client.annotate_image(image_request))
-        if "error" in response_dict.keys():  # Required as annotate_image does not raise exceptions
-            raise GoogleAPIError(response_dict.get("error", {}).get("message", ""))
-        return json.dumps(response_dict)
+    results = api_wrapper.call_api_annotate_image(
+        row=row,
+        batch=batch,
+        path_column=PATH_COLUMN,
+        folder=config.get("input_folder"),
+        folder_is_gcs=config.get("input_folder_is_gcs"),
+        folder_bucket=config.get("input_folder_bucket"),
+        folder_root_path=config.get("input_folder_root_path"),
+        features=[{"type": config.get("ocr_model")}],
+        image_context={"language_hints": language_hints},
+    )
+    return results
 
 
 df = api_parallelizer(
diff --git a/custom-recipes/google-cloud-vision-unsafe-content-moderation/recipe.py b/custom-recipes/google-cloud-vision-unsafe-content-moderation/recipe.py
index 570a89b..280df6d 100644
--- a/custom-recipes/google-cloud-vision-unsafe-content-moderation/recipe.py
+++ b/custom-recipes/google-cloud-vision-unsafe-content-moderation/recipe.py
@@ -1,12 +1,9 @@
 # -*- coding: utf-8 -*-
-import json
 from typing import List, Union, Dict, AnyStr
 
 from ratelimit import limits, RateLimitException
 from retry import retry
 from google.cloud import vision
-from google.protobuf.json_format import MessageToDict
-from google.api_core.exceptions import GoogleAPIError
 
 from plugin_config_loader import load_plugin_config
 from google_vision_api_client import GoogleCloudVisionAPIWrapper
@@ -35,27 +32,17 @@
 @retry((RateLimitException, OSError), delay=config["api_quota_period"], tries=5)
 @limits(calls=config["api_quota_rate_limit"], period=config["api_quota_period"])
 def call_api_moderation(row: Dict = None, batch: List[Dict] = None) -> Union[List[Dict], AnyStr]:
-    features = [{"type": vision.enums.Feature.Type.SAFE_SEARCH_DETECTION}]
-    if config["input_folder_is_gcs"]:
-        image_requests = [
-            api_wrapper.batch_api_gcs_image_request(
-                folder_bucket=config["input_folder_bucket"],
-                folder_root_path=config["input_folder_root_path"],
-                path=row.get(PATH_COLUMN),
-                features=features,
-            )
-            for row in batch
-        ]
-        responses = api_wrapper.client.batch_annotate_images(image_requests)
-        return responses
-    else:
-        image_path = row.get(PATH_COLUMN)
-        with config["input_folder"].get_download_stream(image_path) as stream:
-            image_request = {"image": {"content": stream.read()}, "features": features}
-        response_dict = MessageToDict(api_wrapper.client.annotate_image(image_request))
-        if "error" in response_dict.keys():  # Required as annotate_image does not raise exceptions
-            raise GoogleAPIError(response_dict.get("error", {}).get("message", ""))
-        return json.dumps(response_dict)
+    results = api_wrapper.call_api_annotate_image(
+        row=row,
+        batch=batch,
+        path_column=PATH_COLUMN,
+        folder=config.get("input_folder"),
+        folder_is_gcs=config.get("input_folder_is_gcs"),
+        folder_bucket=config.get("input_folder_bucket"),
+        folder_root_path=config.get("input_folder_root_path"),
+        features=[{"type": vision.enums.Feature.Type.SAFE_SEARCH_DETECTION}],
+    )
+    return results
 
 
 df = api_parallelizer(
diff --git a/python-lib/dku_io_utils.py b/python-lib/dku_io_utils.py
index a844e7a..66e62d4 100644
--- a/python-lib/dku_io_utils.py
+++ b/python-lib/dku_io_utils.py
@@ -4,11 +4,11 @@ Input/Output plugin utility functions which *REQUIRE*
 the Dataiku API
 """
 
-import dataiku
 from typing import Dict, AnyStr, List, Callable
-
 import pandas as pd
 
+import dataiku
+
 from plugin_io_utils import PATH_COLUMN
diff --git a/python-lib/google_vision_api_client.py b/python-lib/google_vision_api_client.py
index df4adfd..b24790a 100644
--- a/python-lib/google_vision_api_client.py
+++ b/python-lib/google_vision_api_client.py
@@ -15,6 +15,8 @@
 from google.oauth2 import service_account
 from google.protobuf.json_format import MessageToDict
 
+import dataiku
+
 
 class GoogleCloudVisionAPIWrapper:
     """
@@ -102,3 +104,41 @@ def batch_api_response_parser(
                 batch[i][api_column_names.error_type] = error_raw.get("code", "")
                 batch[i][api_column_names.error_raw] = error_raw
         return batch
+
+    def call_api_annotate_image(
+        self,
+        folder: dataiku.Folder,
+        features: List[Dict],
+        image_context: Dict = None,
+        row: Dict = None,
+        batch: List[Dict] = None,
+        path_column: AnyStr = "",
+        folder_is_gcs: bool = False,
+        folder_bucket: AnyStr = "",
+        folder_root_path: AnyStr = "",
+    ) -> Union[List[Dict], AnyStr]:
+        if folder_is_gcs:
+            image_requests = [
+                self.batch_api_gcs_image_request(
+                    folder_bucket=folder_bucket,
+                    folder_root_path=folder_root_path,
+                    path=row.get(path_column),
+                    features=features,
+                    image_context=image_context,
+                )
+                for row in batch
+            ]
+            responses = self.client.batch_annotate_images(image_requests)
+            return responses
+        else:
+            image_path = row.get(path_column)
+            with folder.get_download_stream(image_path) as stream:
+                image_request = {
+                    "image": {"content": stream.read()},
+                    "features": features,
+                    "image_context": image_context,
+                }
+            response_dict = MessageToDict(self.client.annotate_image(image_request))
+            if "error" in response_dict.keys():  # Required as annotate_image does not raise exceptions
+                raise GoogleAPIError(response_dict.get("error", {}).get("message", ""))
+            return json.dumps(response_dict)
diff --git a/python-lib/plugin_config_loader.py b/python-lib/plugin_config_loader.py
index c67c481..80e4cca 100644
--- a/python-lib/plugin_config_loader.py
+++ b/python-lib/plugin_config_loader.py
@@ -29,6 +29,7 @@ def load_plugin_config(mandatory_output: AnyStr = "dataset", divide_quota_with_b
     config = {}
     # Input folder configuration
     input_folder_names = get_input_names_for_role("input_folder")
+    assert len(input_folder_names) != 0, "Please specify input folder"
     config["input_folder"] = dataiku.Folder(input_folder_names[0])
     config["api_support_batch"] = False
     config["input_folder_is_gcs"] = config["input_folder"].get_info().get("type", "") == "GCS"
@@ -42,11 +43,13 @@
     output_dataset_names = get_output_names_for_role("output_dataset")
     config["output_dataset"] = None
     if mandatory_output == "dataset" or len(output_dataset_names) != 0:
+        assert len(output_dataset_names) != 0, "Please specify output dataset"
         config["output_dataset"] = dataiku.Dataset(output_dataset_names[0])
     # Output folder configuration
     output_folder_names = get_output_names_for_role("output_folder")  # optional output
     config["output_folder"] = None
     if mandatory_output == "folder" or len(output_folder_names) != 0:
+        assert len(output_folder_names) != 0, "Please specify output folder"
         config["output_folder"] = dataiku.Folder(output_folder_names[0])
         config["output_folder_is_gcs"] = config["output_folder"].get_info().get("type", "") == "GCS"
         if config["output_folder_is_gcs"]:
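
A minimal sketch of how the factored-out `GoogleCloudVisionAPIWrapper.call_api_annotate_image` helper can be exercised on its own (e.g. from a DSS notebook) after this refactor. The folder name `"input_images"`, the `"path"` column, the sample file path and the project-variable lookup for the service account key are illustrative assumptions, not part of the patch.

```python
# Hypothetical usage sketch of the factored-out helper (not part of the patch).
import dataiku
from google.cloud import vision

from google_vision_api_client import GoogleCloudVisionAPIWrapper

folder = dataiku.Folder("input_images")  # assumed managed folder containing test images
api_wrapper = GoogleCloudVisionAPIWrapper(
    gcp_service_account_key=dataiku.get_custom_variables().get("gcp_service_account_key", "")
)

# Single-image (non-GCS) code path: the helper downloads the file from the folder,
# calls annotate_image and returns the response serialized as a JSON string,
# exactly as the recipes did inline before this change.
response_json = api_wrapper.call_api_annotate_image(
    row={"path": "/images/cat.jpg"},  # illustrative path column and file
    path_column="path",
    folder=folder,
    folder_is_gcs=False,
    features=[{"type": vision.enums.Feature.Type.LABEL_DETECTION, "max_results": 10}],
)
print(response_json)
```

With a GCS-backed folder, the same call with `folder_is_gcs=True`, a `batch` of rows and the bucket/root-path arguments goes through `batch_api_gcs_image_request` and `batch_annotate_images` instead, as shown in the method body above.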