Commit: Code simplification
alexcombessie committed Aug 10, 2020
1 parent e2fffa4 commit 1b2a3c8
Showing 8 changed files with 92 additions and 105 deletions.
@@ -1,12 +1,8 @@
 # -*- coding: utf-8 -*-
-import json
 from typing import List, Union, Dict, AnyStr
 from ratelimit import limits, RateLimitException
 from retry import retry
 
-from google.protobuf.json_format import MessageToDict
-from google.api_core.exceptions import GoogleAPIError
-
 from plugin_config_loader import load_plugin_config
 from google_vision_api_client import GoogleCloudVisionAPIWrapper
 from dku_io_utils import generate_path_df, set_column_description
@@ -36,27 +32,17 @@
 def call_api_content_detection(
     max_results: int, row: Dict = None, batch: List[Dict] = None
 ) -> Union[List[Dict], AnyStr]:
-    features = [{"type": c, "max_results": max_results} for c in config["content_categories"]]
-    if config["input_folder_is_gcs"]:
-        image_requests = [
-            api_wrapper.batch_api_gcs_image_request(
-                folder_bucket=config["input_folder_bucket"],
-                folder_root_path=config["input_folder_root_path"],
-                path=row.get(PATH_COLUMN),
-                features=features,
-            )
-            for row in batch
-        ]
-        responses = api_wrapper.client.batch_annotate_images(image_requests)
-        return responses
-    else:
-        image_path = row.get(PATH_COLUMN)
-        with config["input_folder"].get_download_stream(image_path) as stream:
-            image_request = {"image": {"content": stream.read()}, "features": features}
-        response_dict = MessageToDict(api_wrapper.client.annotate_image(image_request))
-        if "error" in response_dict.keys():  # Required as annotate_image does not raise exceptions
-            raise GoogleAPIError(response_dict.get("error", {}).get("message", ""))
-        return json.dumps(response_dict)
+    results = api_wrapper.call_api_annotate_image(
+        row=row,
+        batch=batch,
+        path_column=PATH_COLUMN,
+        folder=config.get("input_folder"),
+        folder_is_gcs=config.get("input_folder_is_gcs"),
+        folder_bucket=config.get("input_folder_bucket"),
+        folder_root_path=config.get("input_folder_root_path"),
+        features=[{"type": c, "max_results": max_results} for c in config.get("content_categories", [])],
+    )
+    return results
 
 
 df = api_parallelizer(
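
Note: the refactored function keeps the calling convention the parallelizer expects: it is invoked with row=... for one image at a time (non-GCS folders, returning a JSON string), or with batch=... for a list of rows (GCS folders, returning the raw batch response). A minimal sketch, assuming PATH_COLUMN is "path" and hypothetical image paths (the real dispatch happens inside api_parallelizer):

    # Per-row mode, used when the input folder is not on GCS:
    single_result = call_api_content_detection(max_results=10, row={"path": "images/001.jpg"})

    # Per-batch mode, used when the input folder is on GCS and batching is supported:
    batch_result = call_api_content_detection(
        max_results=10,
        batch=[{"path": "images/001.jpg"}, {"path": "images/002.jpg"}],
    )
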
38 changes: 12 additions & 26 deletions custom-recipes/google-cloud-vision-cropping/recipe.py
@@ -1,12 +1,9 @@
 # -*- coding: utf-8 -*-
-import json
 from typing import List, Union, Dict, AnyStr
 from ratelimit import limits, RateLimitException
 from retry import retry
 
 from google.cloud import vision
-from google.protobuf.json_format import MessageToDict
-from google.api_core.exceptions import GoogleAPIError
 
 from plugin_config_loader import load_plugin_config
 from google_vision_api_client import GoogleCloudVisionAPIWrapper
@@ -35,29 +32,18 @@
 @retry((RateLimitException, OSError), delay=config["api_quota_period"], tries=5)
 @limits(calls=config["api_quota_rate_limit"], period=config["api_quota_period"])
 def call_api_crop_hints(aspect_ratio: float, row: Dict = None, batch: List[Dict] = None) -> Union[List[Dict], AnyStr]:
-    features = [{"type": vision.enums.Feature.Type.CROP_HINTS}]
-    image_context = {"crop_hints_params": {"aspect_ratios": [aspect_ratio]}}
-    if config["input_folder_is_gcs"]:
-        image_requests = [
-            api_wrapper.batch_api_gcs_image_request(
-                folder_bucket=config["input_folder_bucket"],
-                folder_root_path=config["input_folder_root_path"],
-                path=row.get(PATH_COLUMN),
-                features=features,
-                image_context=image_context,
-            )
-            for row in batch
-        ]
-        responses = api_wrapper.client.batch_annotate_images(image_requests)
-        return responses
-    else:
-        image_path = row.get(PATH_COLUMN)
-        with config["input_folder"].get_download_stream(image_path) as stream:
-            image_request = {"image": {"content": stream.read()}, "features": features, "image_context": image_context}
-        response_dict = MessageToDict(api_wrapper.client.annotate_image(image_request))
-        if "error" in response_dict.keys():  # Required as annotate_image does not raise exceptions
-            raise GoogleAPIError(response_dict.get("error", {}).get("message", ""))
-        return json.dumps(response_dict)
+    results = api_wrapper.call_api_annotate_image(
+        row=row,
+        batch=batch,
+        path_column=PATH_COLUMN,
+        folder=config.get("input_folder"),
+        folder_is_gcs=config.get("input_folder_is_gcs"),
+        folder_bucket=config.get("input_folder_bucket"),
+        folder_root_path=config.get("input_folder_root_path"),
+        features=[{"type": vision.enums.Feature.Type.CROP_HINTS}],
+        image_context={"crop_hints_params": {"aspect_ratios": [aspect_ratio]}},
+    )
+    return results
 
 
 df = api_parallelizer(
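
Note: the two decorators above compose to enforce the API quota. limits() raises RateLimitException once more than api_quota_rate_limit calls land within api_quota_period seconds, and retry() catches that exception (plus OSError) and sleeps for the delay before trying again, up to 5 attempts. A standalone sketch of the same pattern with illustrative numbers, not the plugin's configured values:

    from ratelimit import limits, RateLimitException
    from retry import retry

    @retry((RateLimitException, OSError), delay=60, tries=5)
    @limits(calls=1800, period=60)
    def call_api(payload: dict) -> dict:
        # At most 1800 calls per 60-second window; the excess raises
        # RateLimitException, which retry() absorbs by sleeping 60 seconds
        # before the next attempt (5 attempts maximum).
        return {"echo": payload}  # stand-in for the real API round trip
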
@@ -18,7 +18,7 @@
 # SETUP
 # ==============================================================================
 
-config = load_plugin_config(divide_quota_with_batch_size=False)  # edge case
+config = load_plugin_config(mandatory_output="folder", divide_quota_with_batch_size=False)  # edge case
 column_prefix = "text_api"
 
 api_wrapper = GoogleCloudVisionAPIWrapper(gcp_service_account_key=config["gcp_service_account_key"])
39 changes: 12 additions & 27 deletions custom-recipes/google-cloud-vision-image-text-detection/recipe.py
@@ -1,12 +1,8 @@
 # -*- coding: utf-8 -*-
-import json
 from typing import List, Union, Dict, AnyStr
 from ratelimit import limits, RateLimitException
 from retry import retry
 
-from google.protobuf.json_format import MessageToDict
-from google.api_core.exceptions import GoogleAPIError
-
 from plugin_config_loader import load_plugin_config
 from google_vision_api_client import GoogleCloudVisionAPIWrapper
 from dku_io_utils import generate_path_df, set_column_description
Expand Down Expand Up @@ -36,29 +32,18 @@
def call_api_text_detection(
language_hints: List[AnyStr], row: Dict = None, batch: List[Dict] = None
) -> Union[List[Dict], AnyStr]:
features = [{"type": config["ocr_model"]}]
image_context = {"language_hints": language_hints}
if config["input_folder_is_gcs"]:
image_requests = [
api_wrapper.batch_api_gcs_image_request(
folder_bucket=config["input_folder_bucket"],
folder_root_path=config["input_folder_root_path"],
path=row.get(PATH_COLUMN),
features=features,
image_context=image_context,
)
for row in batch
]
responses = api_wrapper.client.batch_annotate_images(image_requests)
return responses
else:
image_path = row.get(PATH_COLUMN)
with config["input_folder"].get_download_stream(image_path) as stream:
image_request = {"image": {"content": stream.read()}, "features": features, "image_context": image_context}
response_dict = MessageToDict(api_wrapper.client.annotate_image(image_request))
if "error" in response_dict.keys(): # Required as annotate_image does not raise exceptions
raise GoogleAPIError(response_dict.get("error", {}).get("message", ""))
return json.dumps(response_dict)
results = api_wrapper.call_api_annotate_image(
row=row,
batch=batch,
path_column=PATH_COLUMN,
folder=config.get("input_folder"),
folder_is_gcs=config.get("input_folder_is_gcs"),
folder_bucket=config.get("input_folder_bucket"),
folder_root_path=config.get("input_folder_root_path"),
features=[{"type": config.get("ocr_model")}],
image_context={"language_hints": language_hints},
)
return results


df = api_parallelizer(
Expand Down
@@ -1,12 +1,9 @@
 # -*- coding: utf-8 -*-
-import json
 from typing import List, Union, Dict, AnyStr
 from ratelimit import limits, RateLimitException
 from retry import retry
 
 from google.cloud import vision
-from google.protobuf.json_format import MessageToDict
-from google.api_core.exceptions import GoogleAPIError
 
 from plugin_config_loader import load_plugin_config
 from google_vision_api_client import GoogleCloudVisionAPIWrapper
@@ -35,27 +32,17 @@
 @retry((RateLimitException, OSError), delay=config["api_quota_period"], tries=5)
 @limits(calls=config["api_quota_rate_limit"], period=config["api_quota_period"])
 def call_api_moderation(row: Dict = None, batch: List[Dict] = None) -> Union[List[Dict], AnyStr]:
-    features = [{"type": vision.enums.Feature.Type.SAFE_SEARCH_DETECTION}]
-    if config["input_folder_is_gcs"]:
-        image_requests = [
-            api_wrapper.batch_api_gcs_image_request(
-                folder_bucket=config["input_folder_bucket"],
-                folder_root_path=config["input_folder_root_path"],
-                path=row.get(PATH_COLUMN),
-                features=features,
-            )
-            for row in batch
-        ]
-        responses = api_wrapper.client.batch_annotate_images(image_requests)
-        return responses
-    else:
-        image_path = row.get(PATH_COLUMN)
-        with config["input_folder"].get_download_stream(image_path) as stream:
-            image_request = {"image": {"content": stream.read()}, "features": features}
-        response_dict = MessageToDict(api_wrapper.client.annotate_image(image_request))
-        if "error" in response_dict.keys():  # Required as annotate_image does not raise exceptions
-            raise GoogleAPIError(response_dict.get("error", {}).get("message", ""))
-        return json.dumps(response_dict)
+    results = api_wrapper.call_api_annotate_image(
+        row=row,
+        batch=batch,
+        path_column=PATH_COLUMN,
+        folder=config.get("input_folder"),
+        folder_is_gcs=config.get("input_folder_is_gcs"),
+        folder_bucket=config.get("input_folder_bucket"),
+        folder_root_path=config.get("input_folder_root_path"),
+        features=[{"type": vision.enums.Feature.Type.SAFE_SEARCH_DETECTION}],
+    )
+    return results
 
 
 df = api_parallelizer(
4 changes: 2 additions & 2 deletions python-lib/dku_io_utils.py
@@ -4,11 +4,11 @@
 Input/Output plugin utility functions which *REQUIRE* the Dataiku API
 """
 
-import dataiku
-
 from typing import Dict, AnyStr, List, Callable
 
 import pandas as pd
 
+import dataiku
+
 from plugin_io_utils import PATH_COLUMN
 
40 changes: 40 additions & 0 deletions python-lib/google_vision_api_client.py
@@ -15,6 +15,8 @@
 from google.oauth2 import service_account
 from google.protobuf.json_format import MessageToDict
 
+import dataiku
+
 
 class GoogleCloudVisionAPIWrapper:
     """
@@ -102,3 +104,41 @@ def batch_api_response_parser(
             batch[i][api_column_names.error_type] = error_raw.get("code", "")
             batch[i][api_column_names.error_raw] = error_raw
         return batch
+
+    def call_api_annotate_image(
+        self,
+        folder: dataiku.Folder,
+        features: Dict,
+        image_context: Dict = None,
+        row: Dict = None,
+        batch: List[Dict] = None,
+        path_column: AnyStr = "",
+        folder_is_gcs: bool = False,
+        folder_bucket: AnyStr = "",
+        folder_root_path: AnyStr = "",
+    ) -> Union[List[Dict], AnyStr]:
+        if folder_is_gcs:
+            image_requests = [
+                self.batch_api_gcs_image_request(
+                    folder_bucket=folder_bucket,
+                    folder_root_path=folder_root_path,
+                    path=row.get(path_column),
+                    features=features,
+                    image_context=image_context,
+                )
+                for row in batch
+            ]
+            responses = self.client.batch_annotate_images(image_requests)
+            return responses
+        else:
+            image_path = row.get(path_column)
+            with folder.get_download_stream(image_path) as stream:
+                image_request = {
+                    "image": {"content": stream.read()},
+                    "features": features,
+                    "image_context": image_context,
+                }
+            response_dict = MessageToDict(self.client.annotate_image(image_request))
+            if "error" in response_dict.keys():  # Required as annotate_image does not raise exceptions
+                raise GoogleAPIError(response_dict.get("error", {}).get("message", ""))
+            return json.dumps(response_dict)
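
Note: this new method centralizes the branch that the recipes above previously duplicated: GCS folders are annotated with one batch_annotate_images call per batch of rows, while any other folder type streams each image through the Dataiku folder API and calls annotate_image per row, returning a JSON string (the explicit error check is needed because annotate_image reports failures in the response rather than raising). A minimal usage sketch, assuming a hypothetical folder id, path column value, and feature type:

    import dataiku
    from google.cloud import vision
    from google_vision_api_client import GoogleCloudVisionAPIWrapper

    api_wrapper = GoogleCloudVisionAPIWrapper(gcp_service_account_key="<service account JSON>")

    # Per-row path for a non-GCS managed folder:
    response_json = api_wrapper.call_api_annotate_image(
        folder=dataiku.Folder("input_images"),  # hypothetical folder id
        features=[{"type": vision.enums.Feature.Type.LABEL_DETECTION}],  # any supported feature
        row={"path": "images/001.jpg"},  # hypothetical path column value
        path_column="path",
    )

    # Batch path for a GCS input folder, returning the raw batch responses:
    responses = api_wrapper.call_api_annotate_image(
        folder=None,  # unused on the GCS branch
        folder_is_gcs=True,
        folder_bucket="my-bucket",  # hypothetical bucket
        folder_root_path="images",  # hypothetical root path
        features=[{"type": vision.enums.Feature.Type.LABEL_DETECTION}],
        batch=[{"path": "images/001.jpg"}, {"path": "images/002.jpg"}],
        path_column="path",
    )
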
3 changes: 3 additions & 0 deletions python-lib/plugin_config_loader.py
@@ -29,6 +29,7 @@ def load_plugin_config(mandatory_output: AnyStr = "dataset", divide_quota_with_b
     config = {}
     # Input folder configuration
     input_folder_names = get_input_names_for_role("input_folder")
+    assert len(input_folder_names) != 0, "Please specify input folder"
     config["input_folder"] = dataiku.Folder(input_folder_names[0])
     config["api_support_batch"] = False
     config["input_folder_is_gcs"] = config["input_folder"].get_info().get("type", "") == "GCS"
@@ -42,11 +43,13 @@ def load_plugin_config(mandatory_output: AnyStr = "dataset", divide_quota_with_b
     output_dataset_names = get_output_names_for_role("output_dataset")
     config["output_dataset"] = None
     if mandatory_output == "dataset" or len(output_dataset_names) != 0:
+        assert len(output_dataset_names) != 0, "Please specify output dataset"
         config["output_dataset"] = dataiku.Dataset(output_dataset_names[0])
     # Output folder configuration
     output_folder_names = get_output_names_for_role("output_folder")  # optional output
     config["output_folder"] = None
     if mandatory_output == "folder" or len(output_folder_names) != 0:
+        assert len(output_folder_names) != 0, "Please specify output folder"
         config["output_folder"] = dataiku.Folder(output_folder_names[0])
         config["output_folder_is_gcs"] = config["output_folder"].get_info().get("type", "") == "GCS"
         if config["output_folder_is_gcs"]:
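
Note: the three added assertions fail fast with a readable message when a required role is left unplugged, instead of an IndexError on the [0] lookup. Which output assertion is armed depends on the mandatory_output argument each recipe passes, mirroring the calls visible in this commit:

    # Most recipes keep the default and must have an output dataset plugged in:
    config = load_plugin_config()  # mandatory_output="dataset"

    # The text recipe writes images back out, so it requires an output folder instead:
    config = load_plugin_config(mandatory_output="folder", divide_quota_with_batch_size=False)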
