diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index cf064d63..f6640176 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -15,16 +15,16 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - + - name: Set up Python uses: actions/setup-python@v4 with: python-version: '3.10' - + - name: Install dependencies run: | pip install ruff - + - name: Run Ruff check run: | - ruff check . --exclude examples/ + ruff check . --exclude examples/ --line-length 120 diff --git a/pyproject.toml b/pyproject.toml index 0c66e0bc..20d4fa16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,4 +58,4 @@ markers = [ [tool.ruff] line-length = 120 -target-version = "py310" \ No newline at end of file +target-version = "py310" diff --git a/requirements.txt b/requirements.txt index 8569b10f..ad593132 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ better_profanity PyYAML fastapi[all] uvicorn -jinja2 \ No newline at end of file +jinja2 +pytest-asyncio==1.1.0 diff --git a/src/kernelbot/api/api_utils.py b/src/kernelbot/api/api_utils.py index b5cda480..4082108e 100644 --- a/src/kernelbot/api/api_utils.py +++ b/src/kernelbot/api/api_utils.py @@ -1,9 +1,12 @@ +from typing import Any + import requests -from fastapi import HTTPException +from fastapi import HTTPException, UploadFile from kernelbot.env import env from libkernelbot.backend import KernelBackend from libkernelbot.consts import SubmissionMode +from libkernelbot.leaderboard_db import LeaderboardDB from libkernelbot.report import ( Log, MultiProgressReporter, @@ -11,7 +14,10 @@ RunResultReport, Text, ) -from libkernelbot.submission import SubmissionRequest, prepare_submission +from libkernelbot.submission import ( + SubmissionRequest, + prepare_submission, +) async def _handle_discord_oauth(code: str, redirect_uri: str) -> tuple[str, str]: @@ -183,3 +189,103 @@ async def display_report(self, title: str, report: RunResultReport): elif isinstance(part, 
Log): self.long_report += f"\n\n## {part.header}:\n" self.long_report += f"```\n{part.content}```" +# ruff: noqa: C901 +async def to_submit_info( + user_info: Any, + submission_mode: str, + file: UploadFile, + leaderboard_name: str, + gpu_type: str, + db_context: LeaderboardDB, +) -> tuple[SubmissionRequest, SubmissionMode]: # noqa: C901 + user_name = user_info["user_name"] + user_id = user_info["user_id"] + + try: + submission_mode_enum: SubmissionMode = SubmissionMode( + submission_mode.lower() + ) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"Invalid submission mode value: '{submission_mode}'", + ) from None + + if submission_mode_enum in [SubmissionMode.PROFILE]: + raise HTTPException( + status_code=400, + detail="Profile submissions are not currently supported via API", + ) + + allowed_modes = [ + SubmissionMode.TEST, + SubmissionMode.BENCHMARK, + SubmissionMode.LEADERBOARD, + ] + if submission_mode_enum not in allowed_modes: + raise HTTPException( + status_code=400, + detail=f"Submission mode '{submission_mode}' is not supported for this endpoint", + ) + + try: + with db_context as db: + leaderboard_item = db.get_leaderboard(leaderboard_name) + gpus = leaderboard_item.get("gpu_types", []) + if gpu_type not in gpus: + supported_gpus = ", ".join(gpus) if gpus else "None" + raise HTTPException( + status_code=400, + detail=f"GPU type '{gpu_type}' is not supported for " + f"leaderboard '{leaderboard_name}'. Supported GPUs: {supported_gpus}", + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Internal server error while validating leaderboard/GPU: {e}", + ) from e + + try: + submission_content = await file.read() + if not submission_content: + raise HTTPException( + status_code=400, + detail="Empty file submitted. 
Please provide a file with code.", + ) + if len(submission_content) > 1_000_000: + raise HTTPException( + status_code=413, + detail="Submission file is too large (limit: 1MB).", + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Error reading submission file: {e}" + ) from e + + try: + submission_code = submission_content.decode("utf-8") + submission_request = SubmissionRequest( + code=submission_code, + file_name=file.filename or "submission.py", + user_id=user_id, + user_name=user_name, + gpus=[gpu_type], + leaderboard=leaderboard_name, + ) + except UnicodeDecodeError: + raise HTTPException( + status_code=400, + detail="Failed to decode submission file content as UTF-8.", + ) from None + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Internal server error creating submission request: {e}", + ) from e + + return submission_request, submission_mode_enum diff --git a/src/kernelbot/api/main.py b/src/kernelbot/api/main.py index 2848d383..d9d0ae9b 100644 --- a/src/kernelbot/api/main.py +++ b/src/kernelbot/api/main.py @@ -5,25 +5,37 @@ import os import time from dataclasses import asdict -from typing import Annotated, Optional +from typing import Annotated, Any, Optional from fastapi import Depends, FastAPI, Header, HTTPException, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse from libkernelbot.backend import KernelBackend +from libkernelbot.background_submission_manager import BackgroundSubmissionManager from libkernelbot.consts import SubmissionMode -from libkernelbot.leaderboard_db import LeaderboardRankedEntry -from libkernelbot.submission import SubmissionRequest -from libkernelbot.utils import KernelBotError - -from .api_utils import _handle_discord_oauth, _handle_github_oauth, _run_submission +from libkernelbot.db_types import IdentityType +from libkernelbot.leaderboard_db import LeaderboardDB, LeaderboardRankedEntry +from 
libkernelbot.submission import ( + ProcessedSubmissionRequest, + SubmissionRequest, + prepare_submission, +) +from libkernelbot.utils import KernelBotError, setup_logging + +from .api_utils import ( + _handle_discord_oauth, + _handle_github_oauth, + _run_submission, + to_submit_info, +) + +logger = setup_logging(__name__) # yes, we do want ... = Depends() in function signatures # ruff: noqa: B008 app = FastAPI() - def json_serializer(obj): """JSON serializer for objects not serializable by default json code""" if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): @@ -32,6 +44,7 @@ def json_serializer(obj): backend_instance: KernelBackend = None +background_submission_manager: BackgroundSubmissionManager = None _last_action = time.time() _submit_limiter = asyncio.Semaphore(3) @@ -61,6 +74,12 @@ def init_api(_backend_instance: KernelBackend): backend_instance = _backend_instance +def init_background_submission_manager(_manager: BackgroundSubmissionManager): + global background_submission_manager + background_submission_manager = _manager + return background_submission_manager + + @app.exception_handler(KernelBotError) async def kernel_bot_error_handler(req: Request, exc: KernelBotError): return JSONResponse(status_code=exc.http_code, content={"message": str(exc)}) @@ -102,6 +121,50 @@ async def validate_cli_header( return user_info +async def validate_user_header( + x_web_auth_id: Optional[str] = Header(None, alias="X-Web-Auth-Id"), + x_popcorn_cli_id: Optional[str] = Header(None, alias="X-Popcorn-Cli-Id"), + db_context: LeaderboardDB = Depends(get_db), +) -> Any: + """ + Validate either X-Web-Auth-Id or X-Popcorn-Cli-Id and return the associated user id. + Prefers X-Web-Auth-Id if both are provided. 
+ """ + token = x_web_auth_id or x_popcorn_cli_id + if not token: + raise HTTPException( + status_code=400, + detail="Missing X-Web-Auth-Id or X-Popcorn-Cli-Id header", + ) + + if x_web_auth_id: + token = x_web_auth_id + id_type = IdentityType.WEB + elif x_popcorn_cli_id: + token = x_popcorn_cli_id + id_type = IdentityType.CLI + else: + raise HTTPException( + status_code=400, + detail="Missing header must be either X-Web-Auth-Id or X-Popcorn-Cli-Id header", + ) + try: + with db_context as db: + user_info = db.validate_identity(token, id_type) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Database error during validation: {e}", + ) from e + + if not user_info: + raise HTTPException( + status_code=401, + detail="Invalid or unauthorized auth header", + ) + return user_info + + @app.get("/auth/init") async def auth_init(provider: str, db_context=Depends(get_db)) -> dict: if provider not in ["discord", "github"]: @@ -192,14 +255,10 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends raise e except Exception as e: # Catch unexpected errors during OAuth handling - raise HTTPException( - status_code=500, detail=f"Error during {auth_provider} OAuth flow: {e}" - ) from e + raise HTTPException(status_code=500, detail=f"Error during {auth_provider} OAuth flow: {e}") from e if not user_id or not user_name: - raise HTTPException( - status_code=500, detail="Failed to retrieve user ID or username from provider." 
- ) + raise HTTPException(status_code=500,detail="Failed to retrieve user ID or username from provider.",) try: with db_context as db: @@ -209,9 +268,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends db.create_user_from_cli(user_id, user_name, cli_id, auth_provider) except AttributeError as e: - raise HTTPException( - status_code=500, detail=f"Database interface error during update: {e}" - ) from e + raise HTTPException(status_code=500, detail=f"Database interface error during update: {e}") from e except Exception as e: raise HTTPException(status_code=400, detail=f"Database update failed: {e}") from e @@ -223,7 +280,6 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends "is_reset": is_reset, } - async def _stream_submission_response( submission_request: SubmissionRequest, submission_mode_enum: SubmissionMode, @@ -287,7 +343,6 @@ async def _stream_submission_response( except asyncio.CancelledError: pass - @app.post("/{leaderboard_name}/{gpu_type}/{submission_mode}") async def run_submission( # noqa: C901 leaderboard_name: str, @@ -316,94 +371,104 @@ async def run_submission( # noqa: C901 StreamingResponse: A streaming response containing the status and results of the submission. 
""" await simple_rate_limit() - user_name = user_info["user_name"] - user_id = user_info["user_id"] + submission_request, submission_mode_enum = await to_submit_info( + user_info, submission_mode, file, leaderboard_name, gpu_type, db_context + ) + generator = _stream_submission_response( + submission_request=submission_request, + submission_mode_enum=submission_mode_enum, + backend=backend_instance, + ) + return StreamingResponse(generator, media_type="text/event-stream") - try: - submission_mode_enum: SubmissionMode = SubmissionMode(submission_mode.lower()) - except ValueError: - raise HTTPException( - status_code=400, detail=f"Invalid submission mode value: '{submission_mode}'" - ) from None +async def enqueue_background_job( + req: ProcessedSubmissionRequest, + mode: SubmissionMode, + backend: KernelBackend, + manager: BackgroundSubmissionManager, +): - if submission_mode_enum in [SubmissionMode.PROFILE]: - raise HTTPException( - status_code=400, detail="Profile submissions are not currently supported via API" + # pre-create the submission for api returns + with backend.db as db: + sub_id = db.create_submission( + leaderboard=req.leaderboard, + file_name=req.file_name, + code=req.code, + user_id=req.user_id, + time=datetime.datetime.now(), + user_name=req.user_name, ) + job_id = db.upsert_submission_job_status(sub_id, "initial", None) + # put submission request in queue + await manager.enqueue(req, mode, sub_id) + return sub_id,job_id - allowed_modes = [ - SubmissionMode.TEST, - SubmissionMode.BENCHMARK, - SubmissionMode.LEADERBOARD, - ] - if submission_mode_enum not in allowed_modes: - raise HTTPException( - status_code=400, - detail=f"Submission mode '{submission_mode}' is not supported for this endpoint", - ) +@app.post("/submission/{leaderboard_name}/{gpu_type}/{submission_mode}") +async def run_submission_async( + leaderboard_name: str, + gpu_type: str, + submission_mode: str, + file: UploadFile, + user_info: Annotated[dict, Depends(validate_user_header)], 
+ db_context=Depends(get_db), +) -> Any: + """An endpoint that runs a submission on a given leaderboard, runner, and GPU type. - try: - with db_context as db: - leaderboard_item = db.get_leaderboard(leaderboard_name) - gpus = leaderboard_item.get("gpu_types", []) - if gpu_type not in gpus: - supported_gpus = ", ".join(gpus) if gpus else "None" - raise HTTPException( - status_code=400, - detail=f"GPU type '{gpu_type}' is not supported for " - f"leaderboard '{leaderboard_name}'. Supported GPUs: {supported_gpus}", - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Internal server error while validating leaderboard/GPU: {e}", - ) from e + Requires a valid X-Popcorn-Cli-Id or X-Web-Auth-Id header. + Args: + leaderboard_name (str): The name of the leaderboard to run the submission on. + gpu_type (str): The type of GPU to run the submission on. + submission_mode (str): The mode for the submission (test, benchmark, etc.). + file (UploadFile): The file to run the submission on. + user_info (dict): The validated user info obtained from the X-Popcorn-Cli-Id or X-Web-Auth-Id header. + Raises: + HTTPException: If the kernelbot is not initialized, or header/input is invalid. + Returns: + JSONResponse: A JSON response containing job_id and submission_id for the client to poll for status. + """ try: - submission_content = await file.read() - if not submission_content: - raise HTTPException( - status_code=400, detail="Empty file submitted. Please provide a file with code." - ) - if len(submission_content) > 1_000_000: - raise HTTPException( - status_code=413, detail="Submission file is too large (limit: 1MB)." 
+ + await simple_rate_limit() + logger.info(f"Received submission request for {leaderboard_name} {gpu_type} {submission_mode}") + + + # throw error if submission request is invalid + try: + submission_request, submission_mode_enum = await to_submit_info( + user_info, submission_mode, file, leaderboard_name, gpu_type, db_context ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=400, detail=f"Error reading submission file: {e}") from e + req = prepare_submission(submission_request, backend_instance) - try: - submission_code = submission_content.decode("utf-8") - submission_request = SubmissionRequest( - code=submission_code, - file_name=file.filename or "submission.py", - user_id=user_id, - user_name=user_name, - gpus=[gpu_type], - leaderboard=leaderboard_name, - ) - except UnicodeDecodeError: - raise HTTPException( - status_code=400, detail="Failed to decode submission file content as UTF-8." - ) from None - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Internal server error creating submission request: {e}" - ) from e + except Exception as e: + raise HTTPException(status_code=400, detail=f"failed to prepare submission request: {str(e)}") from e - generator = _stream_submission_response( - submission_request=submission_request, - submission_mode_enum=submission_mode_enum, - backend=backend_instance, - ) + # prepare submission request before the submission is started + if not req.gpus or len(req.gpus) != 1: + raise HTTPException(status_code=400, detail="Invalid GPU type") - return StreamingResponse(generator, media_type="text/event-stream") + # put submission request to background manager to run in background + sub_id,job_status_id = await enqueue_background_job( + req, submission_mode_enum, backend_instance, background_submission_manager + ) + + return JSONResponse( + status_code=202, + content={"details":{"id": sub_id, "job_status_id": job_status_id}, "status": "accepted"}, + ) + # Preserve 
FastAPI HTTPException as-is + except HTTPException: + raise + # Your custom sanitized error + except KernelBotError as e: + raise HTTPException(status_code=getattr(e, "http_code", 400), detail=str(e)) from e + # All other unexpected errors → 500 + except Exception as e: + # logger.exception("Unexpected error in run_submission_v2") + logger.error(f"Unexpected error in api submissoin: {e}") + raise HTTPException(status_code=500, detail="Internal server error") from e @app.get("/leaderboards") async def get_leaderboards(db_context=Depends(get_db)): diff --git a/src/kernelbot/main.py b/src/kernelbot/main.py index 20e40248..e0411096 100644 --- a/src/kernelbot/main.py +++ b/src/kernelbot/main.py @@ -4,7 +4,7 @@ import discord import uvicorn -from api.main import app, init_api +from api.main import app, init_api, init_background_submission_manager from cogs.admin_cog import AdminCog from cogs.leaderboard_cog import LeaderboardCog from cogs.misc_cog import BotManagerCog @@ -15,6 +15,7 @@ from libkernelbot import consts from libkernelbot.backend import KernelBackend +from libkernelbot.background_submission_manager import BackgroundSubmissionManager from libkernelbot.launchers import GitHubLauncher, ModalLauncher from libkernelbot.utils import setup_logging @@ -216,6 +217,9 @@ async def start_bot_and_api(debug_mode: bool): bot_instance = ClusterBot(debug_mode=debug_mode) init_api(bot_instance.backend) + m = init_background_submission_manager(BackgroundSubmissionManager(bot_instance.backend)) + # Start manager queue BEFORE serving requests + await m.start() config = uvicorn.Config( app, @@ -225,10 +229,14 @@ async def start_bot_and_api(debug_mode: bool): limit_concurrency=10, ) server = uvicorn.Server(config) - - # we need this as discord and fastapi both run on the same event loop - await asyncio.gather(bot_instance.start_bot(token), server.serve()) - + try: + await asyncio.gather( + bot_instance.start_bot(token), + server.serve(), + ) + finally: + # graceful shutdown + 
await m.stop() def on_unhandled_exception(loop, context): logger.exception("Unhandled exception: %s", context["message"], exc_info=context["exception"]) diff --git a/src/libkernelbot/backend.py b/src/libkernelbot/backend.py index ec1b8769..3d014ed1 100644 --- a/src/libkernelbot/backend.py +++ b/src/libkernelbot/backend.py @@ -48,20 +48,26 @@ def register_launcher(self, launcher: Launcher): self.launcher_map[gpu.value] = launcher async def submit_full( - self, req: ProcessedSubmissionRequest, mode: SubmissionMode, reporter: MultiProgressReporter + self, req: ProcessedSubmissionRequest, mode: SubmissionMode, reporter: MultiProgressReporter, + pre_sub_id: Optional[int] = None ): - with self.db as db: - sub_id = db.create_submission( - leaderboard=req.leaderboard, - file_name=req.file_name, - code=req.code, - user_id=req.user_id, - time=datetime.now(), - user_name=req.user_name, - ) + """ + pre_sub_id is used to pass the submission id which is created beforehand. + """ + if pre_sub_id is not None: + sub_id = pre_sub_id + else: + with self.db as db: + sub_id = db.create_submission( + leaderboard=req.leaderboard, + file_name=req.file_name, + code=req.code, + user_id=req.user_id, + time=datetime.now(), + user_name=req.user_name, + ) selected_gpus = [get_gpu_by_name(gpu) for gpu in req.gpus] - try: tasks = [ self.submit_leaderboard( @@ -98,7 +104,6 @@ async def submit_full( finally: with self.db as db: db.mark_submission_done(sub_id) - return sub_id, results async def submit_leaderboard( # noqa: C901 diff --git a/src/libkernelbot/background_submission_manager.py b/src/libkernelbot/background_submission_manager.py new file mode 100644 index 00000000..0f50b9fa --- /dev/null +++ b/src/libkernelbot/background_submission_manager.py @@ -0,0 +1,287 @@ +import asyncio +import contextlib +import datetime as dt +import logging +from dataclasses import dataclass + +from kernelbot.api.api_utils import MultiProgressReporterAPI +from libkernelbot.backend import KernelBackend +from 
libkernelbot.consts import SubmissionMode +from libkernelbot.submission import ProcessedSubmissionRequest +from libkernelbot.utils import setup_logging + +logger = setup_logging(__name__) + + +@dataclass +class JobItem: + job_id: int + sub_id: int + req: ProcessedSubmissionRequest + mode: SubmissionMode + + +# Periodically update the last heartbeat time for the submission job in submission_job_status table +HEARTBEAT_SEC = 15 # heartbeat interval 15s +# HARD_TIMEOUT_SEC [30 minutes]: if a submission is not completed within the hard timeout, +# it will be marked as failed in submission_job_status table +HARD_TIMEOUT_SEC = 60 * 30 # hard timeout 30 mins + + +class BackgroundSubmissionManager: + """ + This class manages submission in the background. It is responsible for + submitting jobs to the backend, monitoring their progress, and updating the + submission status in the database. + + It is also responsible for managing the workers, starting and stopping them + as needed. By default, the api can support a maximum of 24 submission processes. + + Scale up: scale up to max_workers based on the queue size and the + number of running workers + Scale down: each worker scales down automatically after hitting idle_seconds(hot) + if there is no job in the queue. 
+ """ + + def __init__( + self, + backend: KernelBackend, + min_workers: int = 2, + max_workers: int = 24, + idle_seconds: int = 120, + ): + self.backend = backend + self.queue: asyncio.Queue[JobItem] = asyncio.Queue() + self._workers: list[asyncio.Task] = [] # workers currently running + self._live_tasks: set[ + asyncio.Task + ] = set() # tasks currently processing + self.idle_seconds = ( + idle_seconds # idle_seconds for each worker before scale down + ) + # state variables + + self._state_lock = asyncio.Lock() + self._accepting: bool = False + self.min_workers = min_workers + self.max_workers = max_workers + + async def start(self): + logger.info("[Background Job] starting background submission manager") + async with self._state_lock: + self._accepting = True + need = max(0, self.min_workers - len(self._workers)) + for _ in range(need): + t = asyncio.create_task(self._worker_loop(), name="bg-worker") + async with self._state_lock: + self._workers.append(t) + + async def stop(self): + logger.info( + "[Background Job] stopping background submission manager..." 
+ ) + async with self._state_lock: + self._accepting = False + workers = list(self._workers) + self._workers.clear() + for t in workers: + t.cancel() + for t in workers: + with contextlib.suppress(asyncio.CancelledError): + await t + logger.info( + "[Background Job] ...stopped all background submission manager" + ) + + async def enqueue( + self, + req: ProcessedSubmissionRequest, + mode: SubmissionMode, + sub_id: int, + ) -> tuple[int, int]: + async with self._state_lock: + if not self._accepting: + raise RuntimeError( + "[Background Job] Background Submission Manager is not" + "accepting new jobs right now" + ) + logger.info("enqueueing submission %s", sub_id) + now = dt.datetime.now(dt.timezone.utc) + with self.backend.db as db: + job_id = db.upsert_submission_job_status( + sub_id, + status="pending", + last_heartbeat=now, + ) + await self.queue.put( + JobItem(job_id=job_id, sub_id=sub_id, req=req, mode=mode) + ) + # if we have no workers and it does not hit maximum, start one + await self._autoscale_up() + return job_id, sub_id + + async def _worker_loop(self): + """ + A worker will keep listening to the queue, and process the job in the queue. + If the queue is empty, it will exit after idle_seconds. 
+ Each worker only handles one submission job at a time + """ + try: + while True: + try: + item = await asyncio.wait_for( + self.queue.get(), timeout=self.idle_seconds + ) + logger.info( + "[Background Job][worker %r] pick the submission job `%s`", + id(asyncio.current_task()), + item.sub_id, + ) + except asyncio.TimeoutError: + async with self._state_lock: + me = asyncio.current_task() + if ( + len(self._workers) > self.min_workers + and me in self._workers + ): + try: + self._workers.remove(me) + logger.info( + "[Background Job][worker %r] idle too long," + "scale down; existing workers=%d", + me.get_name() + if hasattr(me, "get_name") + else id(me), + len(self._workers), + ) + except ValueError: + pass + return # scale down: exit + + continue + + t = asyncio.create_task( + self._run_one(item), name=f"submision-job-{item.sub_id}" + ) + + async with self._state_lock: + self._live_tasks.add(t) + try: + await t # wait submission job to finish + finally: + logger.info( + "[Background Job][worker %r] finishes the submission job `%s`", + id(asyncio.current_task()), + item.sub_id, + ) + async with self._state_lock: + self._live_tasks.discard(t) + self.queue.task_done() + except asyncio.CancelledError: + return + + async def _task_done_async(self, tt: asyncio.Task, item: JobItem): + async with self._state_lock: + self._live_tasks.discard(tt) + self.queue.task_done() + await self._autoscale_up() + + async def _run_one(self, item: JobItem): + sub_id = item.sub_id + now = dt.datetime.now(dt.timezone.utc) + + logger.info("[Background Job] start processing submission %s", sub_id) + with self.backend.db as db: + db.upsert_submission_job_status( + sub_id, status="running", last_heartbeat=now + ) + + # heartbeat loop continuously update the last heartbeat time for the submission status + stop_heartbeat = asyncio.Event() + + async def heartbeat(): + while not stop_heartbeat.is_set(): + await asyncio.sleep(HEARTBEAT_SEC) + ts = dt.datetime.now(dt.timezone.utc) + try: + with 
self.backend.db as db: + db.update_heartbeat_if_active(sub_id, ts) + except Exception: + pass + + hb_task = asyncio.create_task(heartbeat(), name=f"hb-{sub_id}") + try: + reporter = MultiProgressReporterAPI() + await asyncio.wait_for( + self.backend.submit_full( + item.req, item.mode, reporter, sub_id + ), + timeout=HARD_TIMEOUT_SEC, + ) + ts = dt.datetime.now(dt.timezone.utc) + logger.info("[Background Job] submission %s succeeded", sub_id) + with self.backend.db as db: + db.upsert_submission_job_status( + sub_id, status="succeeded", last_heartbeat=ts + ) + except asyncio.TimeoutError: + ts = dt.datetime.now(dt.timezone.utc) + with self.backend.db as db: + db.upsert_submission_job_status( + sub_id, + status="timed_out", + last_heartbeat=ts, + error="hard timeout reached", + ) + except Exception as e: + ts = dt.datetime.now(dt.timezone.utc) + logger.error( + "[Background Job] submission job %s failed", + sub_id, + exc_info=True, + ) + try: + with self.backend.db as db: + db.upsert_submission_job_status( + sub_id, + status="failed", + last_heartbeat=ts, + error=str(e), + ) + except Exception: + logger.error( + "[Background Job] Failed to write failed status for submission %s", + sub_id, + ) + finally: + stop_heartbeat.set() + hb_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await hb_task + + async def _autoscale_up(self): + async with self._state_lock: + running = len(self._live_tasks) + workers = len(self._workers) + qsize = self.queue.qsize() + + desired = min( + self.max_workers, max(self.min_workers, running + qsize) + ) + need = desired - workers + to_add = max(0, need) + logger.info( + "[Background Job] autoscale plan: add %d workers " + "(max=%d, busy=%d, active=%s, enqueue=%d)", + to_add, + self.max_workers, + running, + workers, + qsize, + ) + for _ in range(to_add): + logging.info( + "[Background Job] scale up: starting a new worker" + ) + t = asyncio.create_task(self._worker_loop(), name="bg-worker") + self._workers.append(t) diff 
--git a/src/libkernelbot/db_types.py b/src/libkernelbot/db_types.py index e8715fa2..0a03ec52 100644 --- a/src/libkernelbot/db_types.py +++ b/src/libkernelbot/db_types.py @@ -1,10 +1,15 @@ # This file provides TypeDict definitions for the return types we get from database queries import datetime +from enum import Enum from typing import TYPE_CHECKING, List, NotRequired, Optional, TypedDict if TYPE_CHECKING: from libkernelbot.task import LeaderboardTask +class IdentityType(str, Enum): + CLI = "cli" + WEB = "web" + UNKNOWN = "unknown" class LeaderboardItem(TypedDict): id: int diff --git a/src/libkernelbot/leaderboard_db.py b/src/libkernelbot/leaderboard_db.py index 93284fbc..5c9552cc 100644 --- a/src/libkernelbot/leaderboard_db.py +++ b/src/libkernelbot/leaderboard_db.py @@ -5,7 +5,13 @@ import psycopg2 -from libkernelbot.db_types import LeaderboardItem, LeaderboardRankedEntry, RunItem, SubmissionItem +from libkernelbot.db_types import ( + IdentityType, + LeaderboardItem, + LeaderboardRankedEntry, + RunItem, + SubmissionItem, +) from libkernelbot.run_eval import CompileResult, RunResult, SystemInfo from libkernelbot.task import LeaderboardDefinition, LeaderboardTask from libkernelbot.utils import ( @@ -182,7 +188,7 @@ def delete_leaderboard(self, leaderboard_name: str, force: bool = False): WHERE leaderboard.leaderboard.name = %s ) ); -""", + """, (leaderboard_name,), ) self.cursor.execute( @@ -222,6 +228,49 @@ def delete_leaderboard(self, leaderboard_name: str, force: bool = False): logger.exception("Could not delete leaderboard %s.", leaderboard_name, exc_info=e) raise KernelBotError(f"Could not delete leaderboard `{leaderboard_name}`.") from e + def validate_identity( + self, + identifier: str, + id_type: IdentityType, + ) -> Optional[dict[str, str]]: + """ + Validate an identity (CLI or Web) and return {user_id, user_name} if found. + + Args: + identifier: The identifier value (CLI ID or Web Auth ID). 
+ id_type: IdentityType enum (IdentityType.CLI or IdentityType.WEB). + + Returns: + Optional[dict[str, str]]: {"user_id": ..., "user_name": ...} if valid; else None. + """ + where_by_type = { + IdentityType.CLI: ("cli_id = %s AND cli_valid = TRUE", "CLI ID"), + IdentityType.WEB: ("web_auth_id = %s", "WEB AUTH ID"), + } + + where_clause, human_label = where_by_type[id_type] + + try: + self.cursor.execute( + f""" + SELECT id, user_name + FROM leaderboard.user_info + WHERE {where_clause} + """, + (identifier,), + ) + row = self.cursor.fetchone() + return { + "user_id": row[0], + "user_name": row[1], + "id_type":id_type.value + } if row else None + except psycopg2.Error as e: + self.connection.rollback() + logger.exception("Error validating %s %s", human_label, identifier, exc_info=e) + raise KernelBotError(f"Error validating {human_label}") from e + + def create_submission( self, leaderboard: str, @@ -324,6 +373,56 @@ def mark_submission_done( self.connection.rollback() # Ensure rollback if error occurs raise KernelBotError("Error while finalizing submission") from e + def update_heartbeat_if_active(self, sub_id: int, ts: datetime.datetime) -> None: + try: + self.cursor.execute( + """ + UPDATE leaderboard.submission_job_status + SET last_heartbeat = %s, + updated_at = %s + WHERE submission_id = %s + AND status IN ('pending','running') + """, + (ts, ts, sub_id), + ) + self.connection.commit() + except psycopg2.Error as e: + self.connection.rollback() + logger.error("Failed to upsert submission job status. 
sub_id: '%s'", sub_id, exc_info=e) + raise KernelBotError("Error updating job status") from e + + + def upsert_submission_job_status( + self, + sub_id: int, + status: str | None = None, + error: str | None = None, + last_heartbeat: datetime.datetime | None = None, + ) -> int: + try: + self.cursor.execute( + """ + INSERT INTO leaderboard.submission_job_status AS s + (submission_id, status, error, last_heartbeat) + VALUES + (%s, %s, %s, %s) + ON CONFLICT (submission_id) DO UPDATE + SET + status = COALESCE(EXCLUDED.status, s.status), + error = COALESCE(EXCLUDED.error, s.error), + last_heartbeat = COALESCE(EXCLUDED.last_heartbeat, s.last_heartbeat) + RETURNING id; + """, + (sub_id, status, error, last_heartbeat), + ) + job_id = self.cursor.fetchone()[0] + self.connection.commit() + return int(job_id) + except psycopg2.Error as e: + self.connection.rollback() + logger.error("Failed to upsert submission job status. sub_id: '%s'", sub_id, exc_info=e) + raise KernelBotError("Error updating job status") from e + def create_submission_run( self, submission: int, diff --git a/src/migrations/20250822_01_UtXzl-website-submission.py b/src/migrations/20250822_01_UtXzl-website-submission.py new file mode 100644 index 00000000..a4f3c4eb --- /dev/null +++ b/src/migrations/20250822_01_UtXzl-website-submission.py @@ -0,0 +1,29 @@ +""" +website_submission +""" + +from yoyo import step + +__depends__ = {'20250728_01_Q3jso-fix-code-table'} + +# noqa: C901 +steps = [ + step( + "ALTER TABLE leaderboard.user_info " + "ADD COLUMN IF NOT EXISTS web_auth_id VARCHAR(255) DEFAULT NULL;" + ), + step(""" + CREATE TABLE IF NOT EXISTS leaderboard.submission_job_status ( + id SERIAL PRIMARY KEY, + submission_id INTEGER NOT NULL + REFERENCES leaderboard.submission(id) + ON DELETE CASCADE, + status VARCHAR(255) DEFAULT NULL, -- status of the job + error TEXT DEFAULT NULL, -- error details if failed + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), -- creation timestamp + last_heartbeat TIMESTAMPTZ 
DEFAULT NULL, -- updated periodically by worker + CONSTRAINT uq_submission_job_status_submission_id + UNIQUE (submission_id) -- one-to-one with submission + ); + """), +] diff --git a/tests/conftest.py b/tests/conftest.py index 1a049250..f296bbc9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,31 @@ +import os import subprocess import time from pathlib import Path import pytest +REQUIRED = { + "DISCORD_TOKEN": "dummy", + "GITHUB_TOKEN": "dummy", + "GITHUB_REPO": "dummy", +} + +for k, v in REQUIRED.items(): + os.environ.setdefault(k, v) + +@pytest.fixture(scope="session", autouse=True) +def _restore_env(): + old = {k: os.environ.get(k) for k in REQUIRED} + try: + yield + finally: + for k, v in old.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + DATABASE_URL = "postgresql://postgres:postgres@localhost:5433/clusterdev" diff --git a/tests/test_background_submission_manager.py b/tests/test_background_submission_manager.py new file mode 100644 index 00000000..ac038972 --- /dev/null +++ b/tests/test_background_submission_manager.py @@ -0,0 +1,148 @@ +import asyncio +import datetime +from unittest import mock + +import pytest + +from libkernelbot.background_submission_manager import BackgroundSubmissionManager +from libkernelbot.consts import SubmissionMode +from libkernelbot.submission import ProcessedSubmissionRequest + + +@pytest.fixture +def mock_backend(): + backend = mock.Mock() + backend.accepts_jobs = True + + # Mock database context manager + db_context = mock.Mock() + backend.db = db_context + db_context.__enter__ = mock.Mock(return_value=db_context) + db_context.__exit__ = mock.Mock(return_value=None) + + # Default mock responses + mock_task = mock.Mock() + db_context.get_leaderboard.return_value = { + "task": mock_task, + "secret_seed": 12345, + "deadline": datetime.datetime.now() + datetime.timedelta(days=1), + "name": "test_board", + } + db_context.get_leaderboard_gpu_types.return_value = ["A100", "V100"] + + 
return backend + + +def get_req(i: int) -> ProcessedSubmissionRequest: + return ProcessedSubmissionRequest( + leaderboard="lb", + task="dummy_task", + secret_seed=12345, + task_gpus=["A100"], + file_name=f"f{i}.py", + code="print('hi')", + user_id=1, + user_name="tester", + gpus=None, + ) + + +@pytest.mark.asyncio +async def test_enqueue_and_run_job(mock_backend): + # mock upsert/update + db_context = mock_backend.db + db_context.upsert_submission_job_status = mock.Mock( + side_effect=lambda *a, **k: a[0] + ) + db_context.update_heartbeat_if_active = mock.Mock() + + # mock submit_full + async def fake_submit_full(req, mode, reporter, sub_id): + await asyncio.sleep(0.01) # simulate a long-running job + return None, None + + mock_backend.submit_full = fake_submit_full + + manager = BackgroundSubmissionManager( + mock_backend, min_workers=1, max_workers=2, idle_seconds=0.1 + ) + await manager.start() + + # create a fake submission request + job_id, sub_id = await manager.enqueue(get_req(1), SubmissionMode.TEST, sub_id=42) + assert job_id == 42 + + # wait for the queue is clear + await manager.queue.join() + await asyncio.sleep(0.05) + + # check db status + assert ( + mock.call(42, status="pending", last_heartbeat=mock.ANY) + in db_context.upsert_submission_job_status.call_args_list + ) + assert ( + mock.call(42, status="running", last_heartbeat=mock.ANY) + in db_context.upsert_submission_job_status.call_args_list + ) + assert ( + mock.call(42, status="succeeded", last_heartbeat=mock.ANY) + in db_context.upsert_submission_job_status.call_args_list + ) + + await manager.stop() + + +@pytest.mark.asyncio +async def test_stop_rejects_new_jobs(mock_backend): + db_context = mock_backend.db + db_context.upsert_submission_job_status = mock.Mock(return_value=1) + db_context.update_heartbeat_if_active = mock.Mock() + mock_backend.submit_full = mock.AsyncMock() + + manager = BackgroundSubmissionManager( + mock_backend, min_workers=1, max_workers=1, idle_seconds=0.1 + ) + await 
manager.start() + await manager.stop() + + req = get_req(1) + with pytest.raises(RuntimeError): + await manager.enqueue(req, SubmissionMode.TEST, 99) + + +@pytest.mark.asyncio +async def test_scale_up_and_down(mock_backend): + db_context = mock_backend.db + db_context.upsert_submission_job_status = mock.Mock( + side_effect=lambda *a, **k: a[0] + ) + db_context.update_heartbeat_if_active = mock.Mock() + + async def fake_submit_full(req, mode, reporter, sub_id): + await asyncio.sleep(0.05) + return None, None + + mock_backend.submit_full = fake_submit_full + + manager = BackgroundSubmissionManager( + mock_backend, min_workers=1, max_workers=3, idle_seconds=0.2 + ) + await manager.start() + + # send multiple request to scale up + for i in range(6): + await manager.enqueue( + get_req(i), + SubmissionMode.TEST, + sub_id=i + 1, + ) + + await manager.queue.join() + + # idle timeout + await asyncio.sleep(manager.idle_seconds + 0.1) + + async with manager._state_lock: + assert len(manager._workers) == manager.min_workers + await manager.stop() diff --git a/tests/test_leaderboard_db.py b/tests/test_leaderboard_db.py index 741515a8..753c88f0 100644 --- a/tests/test_leaderboard_db.py +++ b/tests/test_leaderboard_db.py @@ -6,6 +6,7 @@ from test_report import sample_compile_result, sample_run_result, sample_system_info from libkernelbot import leaderboard_db +from libkernelbot.db_types import IdentityType from libkernelbot.utils import KernelBotError @@ -367,6 +368,44 @@ def test_leaderboard_submission_ranked(database, submit_leaderboard): }, ] +def test_validate_identity_web_auth_happy_path(database, submit_leaderboard): + with database as db: + db.cursor.execute( + """ + INSERT INTO leaderboard.user_info (id, user_name, web_auth_id) + VALUES (%s, %s, %s) + """, + ("1234", "sara_jojo","2345" ), + ) + user_info = db.validate_identity("2345",IdentityType.WEB) + assert user_info["user_id"] =="1234" + assert user_info["user_name"] =="sara_jojo" + assert user_info["id_type"] 
==IdentityType.WEB.value + +def test_validate_identity_web_auth_not_found(database, submit_leaderboard): + with database as db: + db.cursor.execute( + """ + INSERT INTO leaderboard.user_info (id, user_name) + VALUES (%s, %s) + """, + ("1234", "sara_jojo"), + ) + user_info = db.validate_identity("2345",IdentityType.WEB) + assert user_info is None + +def test_validate_identity_web_auth_missing(database, submit_leaderboard): + with database as db: + db.cursor.execute( + """ + INSERT INTO leaderboard.user_info (id, user_name) + VALUES (%s, %s) + """, + ("1234", "sara_jojo"), + ) + res = db.validate_identity("2345",IdentityType.WEB) + assert res is None + def test_leaderboard_submission_deduplication(database, submit_leaderboard): """validate that identical submission codes are added just once""" diff --git a/tests/test_submission.py b/tests/test_submission.py index 5dc1fd11..a7cdfb74 100644 --- a/tests/test_submission.py +++ b/tests/test_submission.py @@ -106,7 +106,7 @@ def test_get_popcorn_directives_valid(): assert result == {"gpus": ["a100", "v100"], "leaderboard": "My_Board"} # Extra whitespace - code_whitespace = """#!POPCORN gpu A100 V100 + code_whitespace = """#!POPCORN gpu A100 V100 #!POPCORN leaderboard my_board """ # noqa: W291 result = submission._get_popcorn_directives(code_whitespace) assert result == {"gpus": ["A100", "V100"], "leaderboard": "my_board"} diff --git a/tests/test_validate_user_header.py b/tests/test_validate_user_header.py new file mode 100644 index 00000000..54ab484d --- /dev/null +++ b/tests/test_validate_user_header.py @@ -0,0 +1,70 @@ +from typing import Optional + +import pytest +from fastapi import HTTPException + +from kernelbot.api.main import validate_user_header +from libkernelbot.db_types import IdentityType + + +class DummyDBCtx: + def __init__(self, to_return=None, to_raise: Optional[Exception] = None): + self.to_return = to_return + self.to_raise = to_raise + self.seen = {} + + def __enter__(self): + return self + + def 
__exit__(self, exc_type, exc, tb): + return False + + def validate_identity(self, token, id_type): + self.seen["token"] = token + self.seen["id_type"] = id_type + if self.to_raise: + raise self.to_raise + return self.to_return + +@pytest.mark.asyncio +async def test_cli_header_success(): + test_db = DummyDBCtx(to_return={"user_id": "u2", "user_name": "bob"}) + res = await validate_user_header( + x_web_auth_id=None, + x_popcorn_cli_id="clitok", + db_context=test_db, + ) + assert res["user_id"] == "u2" + assert test_db.seen["id_type"] == IdentityType.CLI + +@pytest.mark.asyncio +async def test_both_headers_prefers_web(): + test_db = DummyDBCtx(to_return={"user_id": "u3", "user_name": "c"}) + _ = await validate_user_header( + x_web_auth_id="webtok", + x_popcorn_cli_id="clitok", + db_context=test_db, + ) + assert test_db.seen["token"] == "webtok" + assert test_db.seen["id_type"] == IdentityType.WEB + +@pytest.mark.asyncio +async def test_missing_header_400(): + test_db = DummyDBCtx() + with pytest.raises(HTTPException) as ei: + await validate_user_header(None, None, test_db) + assert ei.value.status_code == 400 + +@pytest.mark.asyncio +async def test_db_error_500(): + test_db = DummyDBCtx(to_raise=RuntimeError("boom")) + with pytest.raises(HTTPException) as ei: + await validate_user_header("webtok", None, test_db) + assert ei.value.status_code == 500 + +@pytest.mark.asyncio +async def test_unauthorized_401(): + test_db = DummyDBCtx(to_return=None) + with pytest.raises(HTTPException) as ei: + await validate_user_header("webtok", None, test_db) + assert ei.value.status_code == 401