diff --git a/requirements.txt b/requirements.txt index 333667c..a277378 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ -flask==3.0.2 -gunicorn==22.0.0 numpy==1.26.4 editdistance==0.8.1 tensorflow==2.14.1 @@ -15,3 +13,6 @@ xlsxwriter==3.2.0 six Pillow==10.3.0 h5py==3.10.0 +fastapi==0.111.0 +uvicorn==0.30.1 +typing-extensions==4.12.2 \ No newline at end of file diff --git a/src/api/app.py b/src/api/app.py index f18a324..6512187 100644 --- a/src/api/app.py +++ b/src/api/app.py @@ -1,79 +1,142 @@ # Imports # > Standard library +import asyncio +from contextlib import asynccontextmanager +import socket +import multiprocessing as mp + +# > Third-party dependencies +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from uvicorn.config import Config +from uvicorn.server import Server # > Local dependencies -import errors -from routes import main -from app_utils import setup_logging, get_env_variable, start_workers -from simple_security import SimpleSecurity +from app_utils import (setup_logging, get_env_variable, + start_workers, stop_workers) +from routes import create_router -# > Third-party dependencies -from flask import Flask +# Set up logging +logging_level = get_env_variable("LOGGING_LEVEL", "INFO") +logger = setup_logging(logging_level) +# Get Loghi-HTR options from environment variables +logger.info("Getting Loghi-HTR options from environment variables") +batch_size = int(get_env_variable("LOGHI_BATCH_SIZE", "256")) +model_path = get_env_variable("LOGHI_MODEL_PATH") +output_path = get_env_variable("LOGHI_OUTPUT_PATH") +max_queue_size = int(get_env_variable("LOGHI_MAX_QUEUE_SIZE", "10000")) +patience = float(get_env_variable("LOGHI_PATIENCE", "0.5")) -def create_app() -> Flask: - """ - Create and configure a Flask app for image prediction. +# Get GPU options from environment variables +logger.info("Getting GPU options from environment variables") +gpus = get_env_variable("LOGHI_GPUS", "0") - This function initializes a Flask app, sets up necessary configurations, - starts image preparation and batch prediction processes, and returns the - configured app instance. - Returns - ------- - Flask - Configured Flask app instance ready for serving. +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Manage the lifespan of the FastAPI application. + + Parameters + ---------- + app : FastAPI + The FastAPI application instance. - Side Effects - ------------ - - Initializes and starts preparation, prediction, and decoding processes. - - Logs various messages regarding the app and process initialization. 
+ Yields + ------ + None """ + # Create a stop event + stop_event = mp.Event() - # Set up logging - logging_level = get_env_variable("LOGGING_LEVEL", "INFO") - logger = setup_logging(logging_level) - - # Get Loghi-HTR options from environment variables - logger.info("Getting Loghi-HTR options from environment variables") - batch_size = int(get_env_variable("LOGHI_BATCH_SIZE", "256")) - model_path = get_env_variable("LOGHI_MODEL_PATH") - output_path = get_env_variable("LOGHI_OUTPUT_PATH") - max_queue_size = int(get_env_variable("LOGHI_MAX_QUEUE_SIZE", "10000")) - patience = float(get_env_variable("LOGHI_PATIENCE", "0.5")) - - # Get GPU options from environment variables - logger.info("Getting GPU options from environment variables") - gpus = get_env_variable("LOGHI_GPUS", "0") - - # Create Flask app - logger.info("Creating Flask app") - app = Flask(__name__) - - # Register error handler - app.register_error_handler(ValueError, errors.handle_invalid_usage) - app.register_error_handler(405, errors.method_not_allowed) - - # Add security to app - security_config = \ - {"enabled": get_env_variable("SECURITY_ENABLED", "False"), - "key_user_json": get_env_variable("API_KEY_USER_JSON_STRING", "{}")} - security = SimpleSecurity(app, security_config) - logger.info(f"Security enabled: {security.enabled}") - - # Start the worker processes + # Startup: Start the worker processes logger.info("Starting worker processes") workers, queues = start_workers(batch_size, max_queue_size, output_path, - gpus, model_path, patience) + gpus, model_path, patience, stop_event) + # Add request queue and stop event to the app + app.state.request_queue = queues["Request"] + app.state.stop_event = stop_event + app.state.workers = workers + + yield - # Add request queue to the app - app.request_queue = queues["Request"] + # Shutdown: Stop all workers and join them + logger.info("Shutting down worker processes") + stop_workers(app.state.workers, app.state.stop_event) + logger.info("All workers have been stopped and joined") - # Add the workers to the app - app.workers = workers - # Register blueprints - app.register_blueprint(main) +def create_app() -> FastAPI: + """ + Create and configure the FastAPI application. + + Returns + ------- + FastAPI + The configured FastAPI application instance. + """ + app = FastAPI( + title="Loghi-HTR API", + description="API for Loghi-HTR", + lifespan=lifespan + ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Allows all origins + allow_credentials=True, + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers + ) + + # Include the router + router = create_router(app) + app.include_router(router) return app + + +app = create_app() + + +async def run_server(): + """ + Run the FastAPI server. + + Returns + ------- + None + """ + host = get_env_variable("UVICORN_HOST", "127.0.0.1") + port = int(get_env_variable("UVICORN_PORT", "5000")) + + # Attempt to resolve the hostname + try: + socket.gethostbyname(host) + except socket.gaierror: + logger.error( + f"Unable to resolve hostname: {host}. Falling back to localhost.") + host = "127.0.0.1" + + config = Config("app:app", host=host, port=port, workers=1) + server = Server(config=config) + + try: + await server.serve() + except OSError as e: + logger.error(f"Error starting server: {e}") + if e.errno == 98: # Address already in use + logger.error( + f"Port {port} is already in use. 
Try a different port.") + elif e.errno == 13: # Permission denied + logger.error( + f"Permission denied when trying to bind to port {port}. Try a " + "port number > 1024 or run with sudo.") + except Exception as e: + logger.error(f"Unexpected error occurred: {e}") + +if __name__ == "__main__": + asyncio.run(run_server()) diff --git a/src/api/app_utils.py b/src/api/app_utils.py index e165e0b..be2bc5c 100644 --- a/src/api/app_utils.py +++ b/src/api/app_utils.py @@ -1,10 +1,11 @@ # Imports # > Standard library +from typing import List, Optional, Dict +from fastapi import UploadFile, Form, File, HTTPException import logging import multiprocessing as mp import os -from typing import Tuple # > Local dependencies from batch_predictor import batch_prediction_worker @@ -12,7 +13,6 @@ from batch_decoder import batch_decoding_worker # > Third-party dependencies -from flask import request from prometheus_client import Gauge @@ -71,13 +71,32 @@ def setup_logging(level: str = "INFO") -> logging.Logger: return logging.getLogger(__name__) -def extract_request_data() -> Tuple[bytes, str, str, str, list]: +async def extract_request_data( + image: UploadFile = File(...), + group_id: str = Form(...), + identifier: str = Form(...), + model: Optional[str] = Form(None), + whitelist: List[str] = Form([]) +) -> tuple[bytes, str, str, str, list]: """ Extract image and other form data from the current request. + Parameters + ---------- + image : UploadFile + The uploaded image file. + group_id : str + ID of the group from form data. + identifier : str + Identifier from form data. + model : Optional[str] + Location of the model to use for prediction. + whitelist : List[str] + List of classes to whitelist for output. + Returns ------- - tuple of (bytes, str, str, str) + tuple of (bytes, str, str, str, list) image_content : bytes Content of the uploaded image. group_id : str @@ -91,45 +110,32 @@ def extract_request_data() -> Tuple[bytes, str, str, str, list]: Raises ------ - ValueError - If required data (image, group_id, identifier) is missing or if - the image format is invalid. + HTTPException + If required data is missing or if the image format is invalid. """ - - # Extract the uploaded image - image_file = request.files.get('image') - if not image_file: - raise ValueError("No image provided.") - # Validate image format allowed_extensions = {'png', 'jpg', 'jpeg', 'gif'} - if '.' not in image_file.filename or image_file.filename.rsplit('.', 1)[1]\ - .lower() not in allowed_extensions: - raise ValueError( - "Invalid image format. Allowed formats: png, jpg, jpeg, gif") - - image_content = image_file.read() - - # Check if the image content is empty or None - if image_content is None or len(image_content) == 0: - raise ValueError( - "The uploaded image is empty. Please upload a valid image file.") - - # Extract other form data - group_id = request.form.get('group_id') - if not group_id: - raise ValueError("No group_id provided.") - - identifier = request.form.get('identifier') - if not identifier: - raise ValueError("No identifier provided.") - - model = request.form.get('model') - if model: - if not os.path.exists(model): - raise ValueError(f"Model directory {model} does not exist.") - - whitelist = request.form.getlist('whitelist') + file_extension = image.filename.split('.')[-1].lower() + if file_extension not in allowed_extensions: + raise HTTPException( + status_code=400, + detail="Invalid image format. 
Allowed formats: " + f"{', '.join(allowed_extensions)}") + + # Read image content + image_content = await image.read() + + # Check if the image content is empty + if not image_content: + raise HTTPException( + status_code=400, + detail="The uploaded image is empty. Please upload a valid image " + "file.") + + # Validate model path if provided + if model and not os.path.exists(model): + raise HTTPException( + status_code=400, detail=f"Model directory {model} does not exist.") return image_content, group_id, identifier, model, whitelist @@ -177,7 +183,7 @@ def get_env_variable(var_name: str, default_value: str = None) -> str: def start_workers(batch_size: int, max_queue_size: int, output_path: str, gpus: str, model_path: str, - patience: int): + patience: int, stop_event: mp.Event): """ Initializes and starts multiple multiprocessing workers for image processing and prediction. @@ -246,7 +252,7 @@ def start_workers(batch_size: int, max_queue_size: int, target=image_preparation_worker, args=(batch_size, request_queue, prepared_queue, model_path, - patience), + patience, stop_event), name="Image Preparation Process", daemon=True) preparation_process.start() @@ -256,7 +262,7 @@ def start_workers(batch_size: int, max_queue_size: int, prediction_process = mp.Process( target=batch_prediction_worker, args=(prepared_queue, predicted_queue, - output_path, model_path, gpus), + output_path, model_path, stop_event, gpus), name="Batch Prediction Process", daemon=True) prediction_process.start() @@ -265,7 +271,8 @@ def start_workers(batch_size: int, max_queue_size: int, logger.info("Starting batch decoding process") decoding_process = mp.Process( target=batch_decoding_worker, - args=(predicted_queue, model_path, output_path), + args=(predicted_queue, model_path, output_path, + stop_event), name="Batch Decoding Process", daemon=True) decoding_process.start() @@ -283,3 +290,24 @@ def start_workers(batch_size: int, max_queue_size: int, } return workers, queues + + +def stop_workers(workers: Dict[str, mp.Process], stop_event: mp.Event): + """ + Stop all worker processes gracefully. + + Parameters + ---------- + workers : Dict[str, mp.Process] + A dictionary of worker processes with worker names as keys. + stop_event : mp.Event + An event to signal workers to stop. + """ + # Signal all workers to stop + stop_event.set() + + # Wait for all workers to finish + for worker in workers.values(): + logger = logging.getLogger(__name__) + logger.info("Waiting for worker process %s to finish", worker.name) + worker.join() diff --git a/src/api/batch_decoder.py b/src/api/batch_decoder.py index a0a85b8..b6b2db3 100644 --- a/src/api/batch_decoder.py +++ b/src/api/batch_decoder.py @@ -23,7 +23,8 @@ def batch_decoding_worker(predicted_queue: multiprocessing.Queue, model_path: str, - output_path: str) -> None: + output_path: str, + stop_event: multiprocessing.Event) -> None: """ Worker function for batch decoding process. 
@@ -48,9 +49,12 @@ def batch_decoding_worker(predicted_queue: multiprocessing.Queue, total_outputs = 0 try: - while True: - encoded_predictions, batch_groups, batch_identifiers, model, \ - batch_id, batch_metadata = predicted_queue.get() + while not stop_event.is_set(): + try: + encoded_predictions, batch_groups, batch_identifiers, model, \ + batch_id, batch_metadata = predicted_queue.get(timeout=0.1) + except multiprocessing.queues.Empty: + continue # Re-initialize utilities if model has changed if model != model_path: diff --git a/src/api/batch_predictor.py b/src/api/batch_predictor.py index 08b2d02..36d8072 100644 --- a/src/api/batch_predictor.py +++ b/src/api/batch_predictor.py @@ -109,6 +109,7 @@ def batch_prediction_worker(prepared_queue: multiprocessing.Queue, predicted_queue: multiprocessing.Queue, output_path: str, model_path: str, + stop_event: multiprocessing.Event, gpus: str = '0'): """ Worker process for batch prediction on images. @@ -126,6 +127,8 @@ def batch_prediction_worker(prepared_queue: multiprocessing.Queue, Path where predictions should be saved. model_path : str Path to the initial model file. + stop_event : multiprocessing.Event + Event to signal the worker to stop processing. gpus : str, optional IDs of GPUs to be used (comma-separated). Default is '0'. @@ -149,8 +152,11 @@ def batch_prediction_worker(prepared_queue: multiprocessing.Queue, old_model_path = model_path try: - while True: - batch_data = prepared_queue.get() + while not stop_event.is_set(): + try: + batch_data = prepared_queue.get(timeout=0.1) + except multiprocessing.queues.Empty: + continue model_path = batch_data[4] batch_id = batch_data[5] logging.debug("Received batch %s from prepared_queue", batch_id) diff --git a/src/api/image_preparator.py b/src/api/image_preparator.py index ddda421..659682f 100644 --- a/src/api/image_preparator.py +++ b/src/api/image_preparator.py @@ -20,7 +20,8 @@ def image_preparation_worker(batch_size: int, request_queue: multiprocessing.Queue, prepared_queue: multiprocessing.Queue, model_path: str, - patience: float): + patience: float, + stop_event: multiprocessing.Event): """ Worker process to prepare images for batch processing. @@ -60,11 +61,12 @@ def image_preparation_worker(batch_size: int, metadata, whitelist = {}, [] try: - while True: + while not stop_event.is_set(): num_channels, model, metadata, whitelist = \ fetch_and_prepare_images(request_queue, prepared_queue, batch_size, patience, num_channels, - model, metadata, whitelist) + model, metadata, whitelist, + stop_event) except Exception as e: logging.error("Exception in image preparation worker: %s", e) @@ -207,7 +209,8 @@ def fetch_and_prepare_images(request_queue: multiprocessing.Queue, num_channels: int, current_model: str, metadata: dict, - old_whitelist: list) -> (int, str): + old_whitelist: list, + stop_event: multiprocessing.Event) -> (int, str): """ Fetches and prepares images for processing. We pass the current model, the current number of channels, the current metadata, and the current whitelist @@ -232,6 +235,8 @@ def fetch_and_prepare_images(request_queue: multiprocessing.Queue, Metadata for the images. old_whitelist : list Whitelist for the images. + stop_event : multiprocessing.Event + Event to signal the worker process to stop. 
Returns ------- @@ -245,7 +250,7 @@ def fetch_and_prepare_images(request_queue: multiprocessing.Queue, batch_images, batch_groups, batch_identifiers, batch_metadata \ = [], [], [], [] - while True: + while not stop_event.is_set(): try: image, group, identifier, new_model, whitelist = \ request_queue.get(timeout=0.1) @@ -302,6 +307,9 @@ def fetch_and_prepare_images(request_queue: multiprocessing.Queue, if last_image_time is not None and \ (time.time() - last_image_time) >= patience: break + else: + # If the stop event is set, break the loop + return num_channels, current_model, metadata, old_whitelist # Pad and queue the batch pad_and_queue_batch(current_model, batch_images, batch_groups, diff --git a/src/api/routes.py b/src/api/routes.py index 69e6ea4..897c122 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -2,171 +2,157 @@ # > Standard library import datetime -import logging +from typing import List from multiprocessing.queues import Full +# > Third-party dependencies +from fastapi import APIRouter, HTTPException, File, UploadFile, Form, FastAPI +from fastapi.responses import JSONResponse, Response +from prometheus_client import generate_latest, CONTENT_TYPE_LATEST + # > Local dependencies from app_utils import extract_request_data -from simple_security import session_key_required - -# > Third party dependencies -import flask -from flask import Blueprint, jsonify, current_app as app -from prometheus_client import generate_latest - -logger = logging.getLogger(__name__) -main = Blueprint('main', __name__) - -@main.route('/predict', methods=['POST']) -@session_key_required -def predict() -> flask.Response: +def create_router(app: FastAPI) -> APIRouter: """ - Endpoint to receive image data and queue it for prediction. - - Receives a POST request containing an image, group_id, and identifier. - The data is then queued for further processing and prediction. + Create an API router with endpoints for prediction, health check, and + readiness check. - Expected POST data - ------------------ - image : file - The image file to be processed. - group_id : str - The group ID associated with the image. - identifier : str - An identifier for the image. + Parameters + ---------- + app : FastAPI + The FastAPI application instance. Returns ------- - Flask.Response - A JSON response containing a status message, timestamp, group_id, - and identifier. The HTTP status code is 202 (Accepted). - - Side Effects - ------------ - - Logs debug messages regarding the received data and queuing status. - - Adds the received data to the global request queue. - """ - - # Add incoming request to queue - # Here, we're just queuing the raw data. - try: - image_file, group_id, identifier, model, whitelist = extract_request_data() - except ValueError as e: - response = jsonify({ - "status": "error", - "code": 400, - "message": str(e), - "timestamp": datetime.datetime.now().isoformat() - }) - - response.status_code = 400 - logger.error("Error processing request: %s", str(e)) - return response - - logger.debug("Data received: %s, %s", group_id, identifier) - logger.debug("Adding %s to queue", identifier) - logger.debug("Using model %s", model) - logger.debug("Using whitelist %s", whitelist) - - try: - app.request_queue.put((image_file, group_id, identifier, - model, whitelist), - block=False) - except Full: - response = jsonify({ - "status": "error", - "code": 429, - "message": "The server is currently processing a high volume of " - "requests. 
Please try again later.", - "timestamp": datetime.datetime.now().isoformat(), - "group_id": group_id, - "identifier": identifier, - }) - - response.status_code = 429 - - return response - - response = jsonify({ - "status": "Request received", - "code": 202, - "message": "Your request is being processed", - "timestamp": datetime.datetime.now().isoformat(), - "group_id": group_id, - "identifier": identifier, - }) - - response.status_code = 202 - - return response - - -@main.route("/prometheus", methods=["GET"]) -@session_key_required -def prometheus() -> bytes: - """ - Endpoint for getting prometheus statistics - """ - return generate_latest() - - -@main.route("/health", methods=["GET"]) -@session_key_required -def health() -> flask.Response: - """ - Endpoint for getting health status + APIRouter + The API router with the defined endpoints. """ - - for name, worker in app.workers.items(): - if not worker.is_alive(): - logger.error("%s worker is not alive", name) - response = jsonify({ - "status": "unhealthy", - "code": 500, - "message": f"{name} worker is not alive", + router = APIRouter() + + @router.post("/predict") + async def predict( + image: UploadFile = File(...), + group_id: str = Form(...), + identifier: str = Form(...), + model: str = Form(None), + whitelist: List[str] = Form([]), + ): + """ + Handle image prediction requests. + + Parameters + ---------- + image : UploadFile + The image file to be processed. + group_id : str + The group identifier. + identifier : str + The request identifier. + model : str, optional + The model to be used for prediction (default is None). + whitelist : List[str], optional + A list of whitelisted items (default is an empty list). + + Returns + ------- + JSONResponse + A JSON response indicating the status of the request. + """ + try: + data = await extract_request_data( + image, group_id, identifier, model, whitelist) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + try: + app.state.request_queue.put(data, block=False) + except Full: + raise HTTPException( + status_code=429, + detail="The server is currently processing a high volume of " + "requests. Please try again later." + ) + + return JSONResponse( + status_code=202, + content={ + "status": "Request received", + "code": 202, + "message": "Your request is being processed", + "timestamp": datetime.datetime.now().isoformat(), + "group_id": group_id, + "identifier": identifier + } + ) + + @router.get("/health") + async def health(): + """ + Check the health of the application workers. + + Returns + ------- + JSONResponse + A JSON response indicating the health status of the application. + """ + for name, worker in app.state.workers.items(): + if not worker.is_alive(): + return JSONResponse( + status_code=500, + content={ + "status": "unhealthy", + "code": 500, + "message": f"{name} worker is not alive", + "timestamp": datetime.datetime.now().isoformat() + } + ) + + return JSONResponse( + status_code=200, + content={ + "status": "healthy", + "code": 200, + "message": "All workers are alive", "timestamp": datetime.datetime.now().isoformat() - }) - response.status_code = 500 - - return response - - response = jsonify({ - "status": "healthy", - "code": 200, - "message": "All workers are alive", - "timestamp": datetime.datetime.now().isoformat() - }) - response.status_code = 200 - - return response - + } + ) + + @router.get("/ready") + async def ready(): + """ + Check if the request queue is ready to accept new requests. 
+ + Returns + ------- + JSONResponse + A JSON response indicating the readiness status of the request queue. + """ + if app.state.request_queue.full(): + return JSONResponse( + status_code=503, + content={ + "status": "unready", + "code": 503, + "message": "Request queue is full", + "timestamp": datetime.datetime.now().isoformat() + } + ) + + return JSONResponse( + status_code=200, + content={ + "status": "ready", + "code": 200, + "message": "Request queue is not full", + "timestamp": datetime.datetime.now().isoformat() + } + ) -@main.route("/ready", methods=["GET"]) -@session_key_required -def ready() -> flask.Response: - """ - Endpoint for getting readiness status - """ + @router.get("/prometheus") + async def prometheus(): + metrics = generate_latest() + return Response(content=metrics, media_type=CONTENT_TYPE_LATEST) - if app.request_queue.full(): - response = jsonify({ - "status": "unready", - "code": 503, - "message": "Request queue is full", - "timestamp": datetime.datetime.now().isoformat() - }) - response.status_code = 503 - - return response - - response = jsonify({ - "status": "ready", - "code": 200, - "message": "Request queue is not full", - "timestamp": datetime.datetime.now().isoformat() - }) - response.status_code = 200 - - return response + return router diff --git a/src/api/simple_security.py b/src/api/simple_security.py index 0e454b2..4e3bc11 100644 --- a/src/api/simple_security.py +++ b/src/api/simple_security.py @@ -47,7 +47,7 @@ def __init__(self, app: Flask, config: dict): if not app or not config: raise ValueError("App and config must be provided") - app.extensions["security"] = self + # app.extensions["security"] = self self.app = app self.config = config self.enabled = self._security_enabled(config.get("enabled", "false")) diff --git a/src/api/start_local_app.sh b/src/api/start_local_app.sh index 3b1eb0d..9b72049 100755 --- a/src/api/start_local_app.sh +++ b/src/api/start_local_app.sh @@ -1,6 +1,3 @@ -export GUNICORN_RUN_HOST='0.0.0.0:5000' -export GUNICORN_ACCESSLOG='-' - export LOGHI_BATCH_SIZE=300 export LOGHI_MODEL_PATH="/home/tim/Downloads/new_model" export LOGHI_OUTPUT_PATH="/home/tim/Documents/development/loghi-htr/output/" @@ -10,5 +7,4 @@ export LOGHI_PATIENCE=0.5 export LOGGING_LEVEL="INFO" export LOGHI_GPUS="0" -gunicorn -w 1 -b $GUNICORN_RUN_HOST \ - --access-logfile $GUNICORN_ACCESSLOG 'app:create_app()' +uvicorn app:app --host 0.0.0.0 --port 5000
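
Example usage (illustrative only, not part of the patch): a minimal client-side sketch for the new FastAPI `/predict`, `/health`, and `/ready` routes defined in src/api/routes.py. It assumes the API is running locally on port 5000 (as in start_local_app.sh) and that the `requests` package is available on the client side; `requests` is not added to requirements.txt by this changeset, and the file name, group_id, and identifier values below are hypothetical.

# Illustrative client sketch (assumptions: local server on port 5000, `requests` installed client-side).
import requests

API_URL = "http://localhost:5000"  # assumed local deployment, per start_local_app.sh

# Queue an image for prediction. Field names follow the new /predict route:
# `image` (file), `group_id`, `identifier`, optional `model` and `whitelist`.
with open("line_image.png", "rb") as f:  # hypothetical input image
    response = requests.post(
        f"{API_URL}/predict",
        files={"image": ("line_image.png", f, "image/png")},
        data={
            "group_id": "batch-001",      # required form field
            "identifier": "line-0001",    # required form field
            # "model": "/path/to/model",  # optional: alternative model directory
            # "whitelist": ["class_a"],   # optional: repeated form field
        },
        timeout=10,
    )

print(response.status_code)  # 202 when the request is accepted and queued
print(response.json())       # status, message, timestamp, group_id, identifier

# Liveness and readiness probes exposed by the new router.
print(requests.get(f"{API_URL}/health", timeout=5).json())
print(requests.get(f"{API_URL}/ready", timeout=5).json())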