Commit: Improve logging
TimKoornstra committed Aug 21, 2024
1 parent 136cc68 commit ce2c6b3
Showing 6 changed files with 151 additions and 49 deletions.
25 changes: 19 additions & 6 deletions api/app.py
@@ -1,5 +1,8 @@
# Imports

# > Standard Libraries
import logging

# > External Libraries
from flask import Flask

@@ -11,21 +14,28 @@
from api.services.model_setup import PredictorGenPageWrapper

from main import setup_logging
from utils.logging_utils import get_logger_name


def create_app():
# Read environment variables
max_queue_size, model_base_path, output_base_path = \
read_environment_variables()

# Capture logging
setup_logging()
logger = logging.getLogger(get_logger_name())
logger.info("Starting Laypa API")

predict_gen_page_wrapper = PredictorGenPageWrapper(model_base_path)

# Read environment variables
logger.info("Initializing environment")
max_queue_size, model_base_path, output_base_path = \
read_environment_variables()
args, executor, queue_size_gauge, images_processed_counter, \
exception_predict_counter = initialize_environment(max_queue_size,
output_base_path)
logger.info("Environment initialized successfully")

# Initialize model
logger.info("Initializing model wrapper")
predict_gen_page_wrapper = PredictorGenPageWrapper(model_base_path)
logger.info("Model wrapper initialized successfully")

app = Flask(__name__)

@@ -37,6 +47,9 @@ def create_app():
app.predict_gen_page_wrapper = predict_gen_page_wrapper
app.output_base_path = output_base_path
app.images_processed_counter = images_processed_counter
app.max_queue_size = max_queue_size

logger.info("Laypa API started successfully")

return app

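For context, a minimal sketch of how this app factory could be served; the wsgi.py module name and port are assumptions, not part of this commit:

```python
# wsgi.py -- hypothetical entry point, a sketch only.
from api.app import create_app

app = create_app()

if __name__ == "__main__":
    # For local debugging; a production setup would point a WSGI server
    # (e.g. gunicorn) at "wsgi:app" instead.
    app.run(host="0.0.0.0", port=5000)
```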
45 changes: 45 additions & 0 deletions api/routes/utils.py
@@ -0,0 +1,45 @@
# Imports

# > Standard Libraries
from pathlib import Path

# > External Libraries
from flask import Request
from werkzeug.datastructures import FileStorage

# > Project Libraries
from api.services.utils import abort_with_info


def extract_request_fields(request: Request,
                           response_info: dict[str, str]) \
        -> tuple[FileStorage, Path, str, str, list[str]]:
try:
identifier = request.form["identifier"]
response_info["identifier"] = identifier
except KeyError:
abort_with_info(400, "Missing identifier in form", response_info)

try:
model_name = request.form["model"]
response_info["model_name"] = model_name
except KeyError:
abort_with_info(400, "Missing model in form", response_info)

try:
whitelist = request.form.getlist("whitelist")
response_info["whitelist"] = whitelist
except KeyError:
abort_with_info(400, "Missing whitelist in form", response_info)

try:
post_file = request.files["image"]
except KeyError:
abort_with_info(400, "Missing image in form", response_info)

if (image_name := post_file.filename) is not None:
image_name = Path(image_name)
response_info["filename"] = str(image_name)
else:
abort_with_info(400, "Missing filename", response_info)

return post_file, image_name, identifier, model_name, whitelist
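A sketch of how extract_request_fields might be wired into a route; the blueprint, endpoint name, and response handling are illustrative assumptions, not part of this commit:

```python
# Hypothetical route wiring -- a sketch only.
from flask import Blueprint, jsonify, request

from api.routes.utils import extract_request_fields

bp = Blueprint("predict", __name__)


@bp.route("/predict", methods=["POST"])
def predict():
    # Any missing form field aborts with a 400 via abort_with_info.
    response_info = {}
    post_file, image_name, identifier, model_name, whitelist = \
        extract_request_fields(request, response_info)
    # ... hand the parsed fields off to the prediction service here ...
    return jsonify(response_info), 202
```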
6 changes: 5 additions & 1 deletion api/services/image_processing.py
@@ -7,9 +7,13 @@
# > External Libraries
import torch

# > Project Libraries
from utils.logging_utils import get_logger_name

logger = logging.getLogger(get_logger_name())


def safe_predict(data, device, predict_gen_page_wrapper) -> Any:
logger = logging.getLogger(__name__)

try:
return predict_gen_page_wrapper.predictor(data, device)
11 changes: 8 additions & 3 deletions api/services/model_setup.py
@@ -11,6 +11,7 @@
from page_xml.xml_regions import XMLRegions
from run import Predictor
from api.setup.initialization import DummyArgs
from utils.logging_utils import get_logger_name


class PredictorGenPageWrapper:
@@ -22,16 +23,18 @@ def __init__(self, model_base_path: Path) -> None:
self.model_name: Optional[str] = None
self.predictor: Optional[Predictor] = None
self.gen_page: Optional[OutputPageXML] = None
self.logger = logging.getLogger(__name__)
self.logger = logging.getLogger(get_logger_name())
self.model_base_path = model_base_path

def setup_model(self, model_name: str, args: DummyArgs):
"""
Create the model and post-processing code
Args:
model_name (str): Model name, used to determine what model to load from models present in base path
args (DummyArgs): Dummy version of command line arguments, to set up config
model_name (str): Model name, used to determine what model to load
from models present in base path
args (DummyArgs): Dummy version of command line arguments, to set
up config
"""
if (
model_name is not None
@@ -41,6 +44,7 @@ def setup_model(self, model_name: str, args: DummyArgs):
):
return

self.logger.info(f"Setting up model: {model_name}")
self.model_name = model_name
model_path = self.model_base_path.joinpath(self.model_name)
config_path = model_path.joinpath("config.yaml")
@@ -63,3 +67,4 @@ def setup_model(self, model_name: str, args: DummyArgs):
)

self.predictor = Predictor(cfg=cfg)
self.logger.info(f"Model {model_name} loaded successfully")
102 changes: 63 additions & 39 deletions api/services/prediction_service.py
@@ -1,60 +1,54 @@
from flask import current_app
# Imports

# > Standard Libraries
from pathlib import Path
from typing import Any, Optional
import time
import logging

# > External Libraries
import numpy as np
import torch
import sys
from flask import current_app, Request

# > Project Libraries
from api.models.response_info import ResponseInfo
from api.services.utils import abort_with_info, check_exception_callback
from api.services.image_processing import safe_predict

sys.path.append(str(Path(__file__).resolve().parent.joinpath("../.."))) # noqa: E402
from api.routes.utils import extract_request_fields
from data.mapper import AugInput
from utils.image_utils import load_image_array_from_bytes
from utils.logging_utils import get_logger_name

logger = logging.getLogger(get_logger_name())


def process_prediction(request: Request) -> ResponseInfo:
"""
Start the prediction process for the given image by submitting it to the
executor
Args:
request (Request): Request object from the API
Returns:
ResponseInfo: Information about the processed image
"""

def process_prediction(request, max_queue_size: int = 1000) -> ResponseInfo:
executor = current_app.executor
args = current_app.args
output_base_path = current_app.output_base_path
predict_gen_page_wrapper = current_app.predict_gen_page_wrapper
images_processed_counter = current_app.images_processed_counter
max_queue_size = current_app.max_queue_size

response_info = ResponseInfo(status_code=500)
current_time = time.strftime(
"%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
response_info["added_time"] = current_time

try:
identifier = request.form["identifier"]
response_info["identifier"] = identifier
except KeyError:
abort_with_info(400, "Missing identifier in form", response_info)

try:
model_name = request.form["model"]
response_info["model_name"] = model_name
except KeyError:
abort_with_info(400, "Missing model in form", response_info)

try:
whitelist = request.form.getlist("whitelist")
response_info["whitelist"] = whitelist
except KeyError:
abort_with_info(400, "Missing whitelist in form", response_info)

try:
post_file = request.files["image"]
except KeyError:
abort_with_info(400, "Missing image in form", response_info)

if (image_name := post_file.filename) is not None:
image_name = Path(image_name)
response_info["filename"] = str(image_name)
else:
abort_with_info(400, "Missing filename", response_info)
post_file, image_name, identifier, model_name, whitelist \
= extract_request_fields(request, response_info)

queue_size = executor._work_queue.qsize()
response_info["added_queue_position"] = queue_size
@@ -70,12 +64,14 @@ def process_prediction(request, max_queue_size: int = 1000) -> ResponseInfo:
500, "Image could not be loaded correctly", response_info)

future = executor.submit(
predict_image, data["image"], data["dpi"], image_name, identifier, model_name, whitelist,
predict_gen_page_wrapper, output_base_path, images_processed_counter, args
predict_image, data["image"], data["dpi"], image_name, identifier,
model_name, whitelist, predict_gen_page_wrapper, output_base_path,
images_processed_counter, args
)
future.add_done_callback(check_exception_callback)

response_info["status_code"] = 202

logger.info(f"Image {image_name} added to queue")
return response_info


@@ -91,6 +87,28 @@ def predict_image(
images_processed_counter,
args: Any
) -> dict[str, Any]:
"""
Run the prediction for the given image
Args:
image (np.ndarray): Image array sent to the model for prediction
dpi (Optional[int]): DPI (dots per inch) of the image
image_path (Path): Path to the image file
identifier (str): Unique identifier for the image
model_name (str): Name of the model to use for prediction
whitelist (list[str]): List of characters to whitelist during prediction
predict_gen_page_wrapper: Wrapper for the Predictor and GenPageXML
output_base_path (Path): Base path for the output
images_processed_counter: Counter for the number of images processed
args (Any): Arguments for the model setup
Raises:
TypeError: If the current GenPageXML is not initialized
TypeError: If the current Predictor is not initialized
Returns:
dict[str, Any]: Information about the processed image
"""
input_args = locals()
try:
predict_gen_page_wrapper.setup_model(args=args, model_name=model_name)
@@ -121,17 +139,23 @@ def predict_image(
output_image = outputs[0]["sem_seg"]

predict_gen_page_wrapper.gen_page.generate_single_page(
output_image, output_path, old_height=outputs[1], old_width=outputs[2]
output_image, output_path, old_height=outputs[1],
old_width=outputs[2]
)
images_processed_counter.inc()

logger.info(f"Prediction complete for {image_path}")
return input_args
except Exception as exception:
# Catch CUDA out of memory errors
if isinstance(exception, torch.cuda.OutOfMemoryError) or (
isinstance(exception, RuntimeError) and "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(
exception)
isinstance(exception, RuntimeError)
and "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(exception)
):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# HACK remove traceback to prevent complete halt of program, not
# sure why this happens
exception = exception.with_traceback(None)

return input_args | {"exception": exception}
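For reference, a client-side sketch of the multipart form this service expects; the endpoint URL and all field values are illustrative assumptions:

```python
# Hypothetical client call -- a sketch; URL and values are assumptions,
# not part of this commit.
import requests

with open("scan.jpg", "rb") as image_file:
    response = requests.post(
        "http://localhost:5000/predict",
        data={
            "identifier": "scan-001",
            "model": "baseline",
            # List values become repeated form fields, read via getlist.
            "whitelist": ["example-entry"],
        },
        files={"image": ("scan.jpg", image_file)},
    )

# 202 means the image was accepted and queued for prediction.
print(response.status_code)
```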
11 changes: 11 additions & 0 deletions api/setup/environment.py
@@ -1,11 +1,17 @@
# Imports

# > Standard Libraries
import logging
import os
from pathlib import Path

# > External Libraries
from utils.logging_utils import get_logger_name


def read_environment_variables():
logger = logging.getLogger(get_logger_name())

try:
max_queue_size_string: str = os.environ["LAYPA_MAX_QUEUE_SIZE"]
model_base_path_string: str = os.environ["LAYPA_MODEL_BASE_PATH"]
@@ -26,4 +32,9 @@ def read_environment_variables():
raise FileNotFoundError(
f"LAYPA_OUTPUT_BASE_PATH: {output_base_path} is not found in the current filesystem")

logger.debug("Running with the following environment variables:")
logger.debug(f"LAYPA_MAX_QUEUE_SIZE: {max_queue_size}")
logger.debug(f"LAYPA_MODEL_BASE_PATH: {model_base_path}")
logger.debug(f"LAYPA_OUTPUT_BASE_PATH: {output_base_path}")

return max_queue_size, model_base_path, output_base_path
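A sketch of configuring these variables before creating the app; the values are placeholders, and both paths must already exist on the filesystem or read_environment_variables raises FileNotFoundError:

```python
# Hypothetical configuration -- values are placeholders, not part of this
# commit. Both directories must exist on the filesystem.
import os

os.environ["LAYPA_MAX_QUEUE_SIZE"] = "128"
os.environ["LAYPA_MODEL_BASE_PATH"] = "/models"
os.environ["LAYPA_OUTPUT_BASE_PATH"] = "/output"

from api.setup.environment import read_environment_variables

max_queue_size, model_base_path, output_base_path = \
    read_environment_variables()
```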
