diff --git a/api/app.py b/api/app.py
index d352a56..b9604e7 100644
--- a/api/app.py
+++ b/api/app.py
@@ -1,5 +1,8 @@
 # Imports
 
+# > Standard Libraries
+import logging
+
 # > External Libraries
 from flask import Flask
 
@@ -11,21 +14,28 @@
 from api.services.model_setup import PredictorGenPageWrapper
 from main import setup_logging
+from utils.logging_utils import get_logger_name
 
 
 def create_app():
-    # Read environment variables
-    max_queue_size, model_base_path, output_base_path = \
-        read_environment_variables()
-
     # Capture logging
     setup_logging()
+    logger = logging.getLogger(get_logger_name())
+    logger.info("Starting Laypa API")
 
-    predict_gen_page_wrapper = PredictorGenPageWrapper(model_base_path)
-
+    # Read environment variables
+    logger.info("Initializing environment")
+    max_queue_size, model_base_path, output_base_path = \
+        read_environment_variables()
     args, executor, queue_size_gauge, images_processed_counter, \
        exception_predict_counter = initialize_environment(max_queue_size, output_base_path)
+    logger.info("Environment initialized successfully")
+
+    # Initialize model
+    logger.info("Initializing model wrapper")
+    predict_gen_page_wrapper = PredictorGenPageWrapper(model_base_path)
+    logger.info("Model wrapper initialized successfully")
 
     app = Flask(__name__)
@@ -37,6 +47,9 @@ def create_app():
     app.predict_gen_page_wrapper = predict_gen_page_wrapper
     app.output_base_path = output_base_path
     app.images_processed_counter = images_processed_counter
+    app.max_queue_size = max_queue_size
+
+    logger.info("Laypa API started successfully")
 
     return app
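Every module below replaces per-module `logging.getLogger(__name__)` loggers with `logging.getLogger(get_logger_name())`, so all messages flow through the one logger that `setup_logging()` configures. The helper itself is not part of this diff; a minimal sketch of what it might look like, assuming a fixed application-wide logger name (the `"laypa"` value is a guess):

```python
# utils/logging_utils.py -- not shown in this diff; a hypothetical minimal version
def get_logger_name() -> str:
    # A single shared logger name means the handlers attached once by
    # setup_logging() apply to log records emitted from every module
    return "laypa"
```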
diff --git a/api/routes/utils.py b/api/routes/utils.py
new file mode 100644
index 0000000..1fc5da6
--- /dev/null
+++ b/api/routes/utils.py
@@ -0,0 +1,46 @@
+# Imports
+
+# > Standard Libraries
+from pathlib import Path
+
+# > External Libraries
+from flask import Request
+from werkzeug.datastructures import FileStorage
+
+# > Project Libraries
+from api.services.utils import abort_with_info
+
+
+def extract_request_fields(request: Request,
+                           response_info: dict[str, str]) \
+        -> tuple[FileStorage, Path, str, str, list[str]]:
+    try:
+        identifier = request.form["identifier"]
+        response_info["identifier"] = identifier
+    except KeyError:
+        abort_with_info(400, "Missing identifier in form", response_info)
+
+    try:
+        model_name = request.form["model"]
+        response_info["model_name"] = model_name
+    except KeyError:
+        abort_with_info(400, "Missing model in form", response_info)
+
+    # form.getlist never raises KeyError; it returns an empty list instead
+    whitelist = request.form.getlist("whitelist")
+    response_info["whitelist"] = whitelist
+    if not whitelist:
+        abort_with_info(400, "Missing whitelist in form", response_info)
+
+    try:
+        post_file = request.files["image"]
+    except KeyError:
+        abort_with_info(400, "Missing image in form", response_info)
+
+    if (image_name := post_file.filename) is not None:
+        image_name = Path(image_name)
+        response_info["filename"] = str(image_name)
+    else:
+        abort_with_info(400, "Missing filename", response_info)
+
+    return post_file, image_name, identifier, model_name, whitelist
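`extract_request_fields` now carries the form validation that used to live inline in `process_prediction`. A quick way to exercise it in isolation is Flask's test request context; a sketch with illustrative field values, assuming the `api` package and its dependencies are importable:

```python
from io import BytesIO

from flask import Flask, request

from api.routes.utils import extract_request_fields

app = Flask(__name__)

# A (stream, filename) tuple becomes an uploaded file in the test request;
# a list value becomes repeated form keys
data = {
    "identifier": "scan-001",            # example value
    "model": "baseline",                 # example model name
    "whitelist": ["config", "weights"],  # read via request.form.getlist
    "image": (BytesIO(b"fake-bytes"), "page_0001.jpg"),
}

with app.test_request_context(method="POST", data=data):
    response_info = {}
    post_file, image_name, identifier, model_name, whitelist = \
        extract_request_fields(request, response_info)
    assert str(image_name) == "page_0001.jpg"
    assert whitelist == ["config", "weights"]
```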
diff --git a/api/services/image_processing.py b/api/services/image_processing.py
index 9a0f230..77f0a42 100644
--- a/api/services/image_processing.py
+++ b/api/services/image_processing.py
@@ -7,9 +7,13 @@
 # > External Libraries
 import torch
 
+# > Project Libraries
+from utils.logging_utils import get_logger_name
+
+logger = logging.getLogger(get_logger_name())
+
 
 def safe_predict(data, device, predict_gen_page_wrapper) -> Any:
-    logger = logging.getLogger(__name__)
     try:
         return predict_gen_page_wrapper.predictor(data, device)
diff --git a/api/services/model_setup.py b/api/services/model_setup.py
index e7fbf87..654b3b7 100644
--- a/api/services/model_setup.py
+++ b/api/services/model_setup.py
@@ -11,6 +11,7 @@
 from page_xml.xml_regions import XMLRegions
 from run import Predictor
 from api.setup.initialization import DummyArgs
+from utils.logging_utils import get_logger_name
 
 
 class PredictorGenPageWrapper:
@@ -22,7 +23,7 @@ def __init__(self, model_base_path: Path) -> None:
         self.model_name: Optional[str] = None
         self.predictor: Optional[Predictor] = None
         self.gen_page: Optional[OutputPageXML] = None
-        self.logger = logging.getLogger(__name__)
+        self.logger = logging.getLogger(get_logger_name())
         self.model_base_path = model_base_path
 
     def setup_model(self, model_name: str, args: DummyArgs):
@@ -30,8 +31,10 @@ def setup_model(self, model_name: str, args: DummyArgs):
         """
         Create the model and post-processing code
 
         Args:
-            model_name (str): Model name, used to determine what model to load from models present in base path
-            args (DummyArgs): Dummy version of command line arguments, to set up config
+            model_name (str): Model name, used to determine what model to load
+                from models present in base path
+            args (DummyArgs): Dummy version of command line arguments, to set
+                up config
         """
         if (
             model_name is not None
@@ -41,6 +44,7 @@ def setup_model(self, model_name: str, args: DummyArgs):
         ):
             return
 
+        self.logger.info(f"Setting up model: {model_name}")
         self.model_name = model_name
         model_path = self.model_base_path.joinpath(self.model_name)
         config_path = model_path.joinpath("config.yaml")
@@ -63,3 +67,4 @@ def setup_model(self, model_name: str, args: DummyArgs):
         )
 
         self.predictor = Predictor(cfg=cfg)
+        self.logger.info(f"Model {model_name} loaded successfully")
diff --git a/api/services/prediction_service.py b/api/services/prediction_service.py
index 6dd81b6..3afa93d 100644
--- a/api/services/prediction_service.py
+++ b/api/services/prediction_service.py
@@ -1,60 +1,54 @@
-from flask import current_app
+# Imports
+
+# > Standard Libraries
 from pathlib import Path
 from typing import Any, Optional
 import time
+import logging
+
+# > External Libraries
 import numpy as np
 import torch
-import sys
+from flask import current_app, Request
 
+# > Project Libraries
 from api.models.response_info import ResponseInfo
 from api.services.utils import abort_with_info, check_exception_callback
 from api.services.image_processing import safe_predict
-
-sys.path.append(str(Path(__file__).resolve().parent.joinpath("../..")))  # noqa: E402
+from api.routes.utils import extract_request_fields
 from data.mapper import AugInput
 from utils.image_utils import load_image_array_from_bytes
+from utils.logging_utils import get_logger_name
+
+logger = logging.getLogger(get_logger_name())
+
 
+def process_prediction(request: Request) -> ResponseInfo:
+    """
+    Start the prediction process for the given image by submitting it to the
+    executor
+
+    Args:
+        request (Request): Request object from the API
+
+    Returns:
+        ResponseInfo: Information about the processed image
+    """
-def process_prediction(request, max_queue_size: int = 1000) -> ResponseInfo:
     executor = current_app.executor
     args = current_app.args
     output_base_path = current_app.output_base_path
     predict_gen_page_wrapper = current_app.predict_gen_page_wrapper
     images_processed_counter = current_app.images_processed_counter
+    max_queue_size = current_app.max_queue_size
 
     response_info = ResponseInfo(status_code=500)
 
     current_time = time.strftime(
         "%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
     response_info["added_time"] = current_time
 
-    try:
-        identifier = request.form["identifier"]
-        response_info["identifier"] = identifier
-    except KeyError:
-        abort_with_info(400, "Missing identifier in form", response_info)
-
-    try:
-        model_name = request.form["model"]
-        response_info["model_name"] = model_name
-    except KeyError:
-        abort_with_info(400, "Missing model in form", response_info)
-
-    try:
-        whitelist = request.form.getlist("whitelist")
-        response_info["whitelist"] = whitelist
-    except KeyError:
-        abort_with_info(400, "Missing whitelist in form", response_info)
-
-    try:
-        post_file = request.files["image"]
-    except KeyError:
-        abort_with_info(400, "Missing image in form", response_info)
-
-    if (image_name := post_file.filename) is not None:
-        image_name = Path(image_name)
-        response_info["filename"] = str(image_name)
-    else:
-        abort_with_info(400, "Missing filename", response_info)
+    post_file, image_name, identifier, model_name, whitelist \
+        = extract_request_fields(request, response_info)
 
     queue_size = executor._work_queue.qsize()
     response_info["added_queue_position"] = queue_size
@@ -70,12 +64,14 @@ def process_prediction(request, max_queue_size: int = 1000) -> ResponseInfo:
             500, "Image could not be loaded correctly", response_info)
 
     future = executor.submit(
-        predict_image, data["image"], data["dpi"], image_name, identifier, model_name, whitelist,
-        predict_gen_page_wrapper, output_base_path, images_processed_counter, args
+        predict_image, data["image"], data["dpi"], image_name, identifier,
+        model_name, whitelist, predict_gen_page_wrapper, output_base_path,
+        images_processed_counter, args
     )
     future.add_done_callback(check_exception_callback)
-
     response_info["status_code"] = 202
+
+    logger.info(f"Image {image_name} added to queue")
 
     return response_info
@@ -91,6 +87,28 @@ def predict_image(
     images_processed_counter,
     args: Any
 ) -> dict[str, Any]:
+    """
+    Run the prediction for the given image
+
+    Args:
+        image (np.ndarray): Image array sent to the model for prediction
+        dpi (Optional[int]): DPI (dots per inch) of the image
+        image_path (Path): Path to the image file
+        identifier (str): Unique identifier for the image
+        model_name (str): Name of the model to use for prediction
+        whitelist (list[str]): List of characters to whitelist during prediction
+        predict_gen_page_wrapper: Wrapper for the Predictor and GenPageXML
+        output_base_path (Path): Base path for the output
+        images_processed_counter: Counter for the number of images processed
+        args (Any): Arguments for the model setup
+
+    Raises:
+        TypeError: If the current GenPageXML is not initialized
+        TypeError: If the current Predictor is not initialized
+
+    Returns:
+        dict[str, Any]: Information about the processed image
+    """
     input_args = locals()
     try:
         predict_gen_page_wrapper.setup_model(args=args, model_name=model_name)
@@ -121,17 +139,23 @@ def predict_image(
         output_image = outputs[0]["sem_seg"]
 
         predict_gen_page_wrapper.gen_page.generate_single_page(
-            output_image, output_path, old_height=outputs[1], old_width=outputs[2]
+            output_image, output_path, old_height=outputs[1],
+            old_width=outputs[2]
         )
         images_processed_counter.inc()
+
+        logger.info(f"Prediction complete for {image_path}")
         return input_args
     except Exception as exception:
+        # Catch CUDA out of memory errors
         if isinstance(exception, torch.cuda.OutOfMemoryError) or (
-            isinstance(exception, RuntimeError) and "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(
-                exception)
+            isinstance(exception, RuntimeError)
+            and "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(exception)
         ):
             torch.cuda.empty_cache()
             torch.cuda.reset_peak_memory_stats()
+        # HACK: remove the traceback to prevent a complete halt of the
+        # program; it is unclear why this happens
         exception = exception.with_traceback(None)
         return input_args | {"exception": exception}
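For reference, the form contract that `process_prediction` (via `extract_request_fields`) expects from a client: one `identifier`, one `model`, one or more `whitelist` entries, and an `image` file. A hypothetical client call, assuming the route that wraps `process_prediction` is mounted at `/predict` (the route itself is outside this diff):

```python
import requests  # third-party HTTP client, used here for illustration

response = requests.post(
    "http://localhost:5000/predict",  # assumed host, port, and route
    data={
        "identifier": "scan-001",
        "model": "baseline",
        # a list value is sent as repeated form keys, which arrive
        # server-side via request.form.getlist("whitelist")
        "whitelist": ["config", "weights"],
    },
    files={"image": open("page_0001.jpg", "rb")},
)
print(response.status_code)  # 202 once the image is queued
```

Note that the 202 is set as soon as the future is submitted; prediction failures surface later through `check_exception_callback` and the `"exception"` entry returned by `predict_image`.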
diff --git a/api/setup/environment.py b/api/setup/environment.py
index cbd8f0a..bb40a51 100644
--- a/api/setup/environment.py
+++ b/api/setup/environment.py
@@ -1,11 +1,17 @@
 # Imports
 
 # > Standard Libraries
+import logging
 import os
 from pathlib import Path
 
+# > Project Libraries
+from utils.logging_utils import get_logger_name
+
 
 def read_environment_variables():
+    logger = logging.getLogger(get_logger_name())
+
     try:
         max_queue_size_string: str = os.environ["LAYPA_MAX_QUEUE_SIZE"]
         model_base_path_string: str = os.environ["LAYPA_MODEL_BASE_PATH"]
@@ -26,4 +32,9 @@ def read_environment_variables():
         raise FileNotFoundError(
             f"LAYPA_OUTPUT_BASE_PATH: {output_base_path} is not found in the current filesystem")
 
+    logger.debug("Running with the following environment variables:")
+    logger.debug(f"LAYPA_MAX_QUEUE_SIZE: {max_queue_size}")
+    logger.debug(f"LAYPA_MODEL_BASE_PATH: {model_base_path}")
+    logger.debug(f"LAYPA_OUTPUT_BASE_PATH: {output_base_path}")
+
     return max_queue_size, model_base_path, output_base_path
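Taken together with `create_app()`, the three variables above are everything needed to bring the service up locally. A minimal run sketch, assuming the repository root is on `PYTHONPATH`; the values are examples only:

```python
import os

# Example values; the model and output paths must be existing directories,
# or read_environment_variables raises FileNotFoundError
os.environ["LAYPA_MAX_QUEUE_SIZE"] = "128"
os.environ["LAYPA_MODEL_BASE_PATH"] = "/models"
os.environ["LAYPA_OUTPUT_BASE_PATH"] = "/output"

from api.app import create_app

app = create_app()
app.run(host="0.0.0.0", port=5000)  # Flask's built-in development server
```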