From eed6e92c8efa7412c4976c139f931561a80d9366 Mon Sep 17 00:00:00 2001 From: Joe Licata Date: Mon, 12 Jan 2026 20:20:05 +0000 Subject: [PATCH 1/8] refactor: Simplify configuration and authentication middleware - Refactored `Settings` class to delegate Redis URL and API key retrieval to dedicated services. - Removed deprecated authentication middleware and security headers middleware for cleaner codebase. - Introduced utility functions for container operations to reduce code duplication in container management. - Updated container execution logic to enhance readability and maintainability. --- src/config/__init__.py | 13 +-- src/dependencies/__init__.py | 6 -- src/dependencies/auth.py | 23 ------ src/middleware/__init__.py | 7 -- src/middleware/auth.py | 122 ----------------------------- src/middleware/headers.py | 55 ------------- src/services/container/__init__.py | 11 ++- src/services/container/executor.py | 83 ++++---------------- src/services/container/manager.py | 31 +------- src/services/container/utils.py | 102 ++++++++++++++++++++++++ src/services/execution.py | 14 ---- src/utils/__init__.py | 2 +- src/utils/containers.py | 27 ------- tests/conftest.py | 2 +- 14 files changed, 133 insertions(+), 365 deletions(-) delete mode 100644 src/middleware/auth.py delete mode 100644 src/middleware/headers.py create mode 100644 src/services/container/utils.py delete mode 100644 src/services/execution.py delete mode 100644 src/utils/containers.py diff --git a/src/config/__init__.py b/src/config/__init__.py index 168f084..61e6969 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -590,20 +590,11 @@ def validate_ssl_files(self) -> bool: def get_redis_url(self) -> str: """Get Redis connection URL.""" - if self.redis_url: - return self.redis_url - password_part = f":{self.redis_password}@" if self.redis_password else "" - return f"redis://{password_part}{self.redis_host}:{self.redis_port}/{self.redis_db}" + return self.redis.get_url() def get_valid_api_keys(self) -> List[str]: """Get all valid API keys including the primary key.""" - keys = [self.api_key] - if self.api_keys: - if isinstance(self.api_keys, list): - keys.extend(self.api_keys) - elif isinstance(self.api_keys, str): - keys.extend([k.strip() for k in self.api_keys.split(",") if k.strip()]) - return list(set(keys)) + return self.security.get_valid_api_keys() def get_language_config(self, language: str) -> Dict[str, Any]: """Get configuration for a specific language.""" diff --git a/src/dependencies/__init__.py b/src/dependencies/__init__.py index f35a838..40794f5 100644 --- a/src/dependencies/__init__.py +++ b/src/dependencies/__init__.py @@ -3,9 +3,6 @@ from .auth import ( verify_api_key, verify_api_key_optional, - get_current_user, - get_current_user_optional, - AuthenticatedUser, ) from .services import ( get_file_service, @@ -21,9 +18,6 @@ __all__ = [ "verify_api_key", "verify_api_key_optional", - "get_current_user", - "get_current_user_optional", - "AuthenticatedUser", "get_file_service", "get_session_service", "get_state_service", diff --git a/src/dependencies/auth.py b/src/dependencies/auth.py index d044abb..657e74d 100644 --- a/src/dependencies/auth.py +++ b/src/dependencies/auth.py @@ -70,26 +70,3 @@ async def verify_api_key_optional( raise # Invalid API key provided, which is not OK -class AuthenticatedUser: - """Represents an authenticated API user.""" - - def __init__(self, api_key: str): - self.api_key = api_key - self.key_prefix = api_key[:8] + "..." if len(api_key) > 8 else api_key - - def __str__(self): - return f"AuthenticatedUser(key={self.key_prefix})" - - -async def get_current_user(api_key: str = Depends(verify_api_key)) -> AuthenticatedUser: - """Get the current authenticated user.""" - return AuthenticatedUser(api_key) - - -async def get_current_user_optional( - api_key: Optional[str] = Depends(verify_api_key_optional), -) -> Optional[AuthenticatedUser]: - """Get the current authenticated user (optional).""" - if api_key: - return AuthenticatedUser(api_key) - return None diff --git a/src/middleware/__init__.py b/src/middleware/__init__.py index 994608d..4ce9c1c 100644 --- a/src/middleware/__init__.py +++ b/src/middleware/__init__.py @@ -1,17 +1,10 @@ """Middleware package for the Code Interpreter API.""" from .security import SecurityMiddleware, RequestLoggingMiddleware -from .auth import AuthenticationMiddleware -from .headers import SecurityHeadersMiddleware from .metrics import MetricsMiddleware __all__ = [ - # Consolidated (backward compatible) "SecurityMiddleware", "RequestLoggingMiddleware", - # Separated (new) - "AuthenticationMiddleware", - "SecurityHeadersMiddleware", - # Existing "MetricsMiddleware", ] diff --git a/src/middleware/auth.py b/src/middleware/auth.py deleted file mode 100644 index 1767945..0000000 --- a/src/middleware/auth.py +++ /dev/null @@ -1,122 +0,0 @@ -"""Authentication middleware for API key validation.""" - -import time -from typing import Callable, Optional - -import structlog -from fastapi import Request, HTTPException -from fastapi.responses import JSONResponse - -from ..services.auth import get_auth_service - -logger = structlog.get_logger(__name__) - - -class AuthenticationMiddleware: - """Middleware for API key authentication. - - This middleware handles: - - API key extraction from headers - - API key validation - - Rate limiting on authentication failures - - Setting authenticated state on request - """ - - def __init__(self, app: Callable): - self.app = app - self.excluded_paths = {"/health", "/docs", "/redoc", "/openapi.json"} - - async def __call__(self, scope: dict, receive: Callable, send: Callable): - """Process request through authentication middleware.""" - if scope["type"] != "http": - await self.app(scope, receive, send) - return - - request = Request(scope, receive) - - # Skip auth for excluded paths and OPTIONS - if self._should_skip_auth(request): - await self.app(scope, receive, send) - return - - try: - await self._authenticate_request(request, scope) - except HTTPException as e: - response = JSONResponse( - status_code=e.status_code, - content={"error": e.detail, "timestamp": time.time()}, - ) - await response(scope, receive, send) - return - except Exception as e: - logger.error("Authentication middleware error", error=str(e)) - response = JSONResponse( - status_code=500, - content={ - "error": "Internal authentication error", - "timestamp": time.time(), - }, - ) - await response(scope, receive, send) - return - - await self.app(scope, receive, send) - - def _should_skip_auth(self, request: Request) -> bool: - """Check if authentication should be skipped.""" - return request.url.path in self.excluded_paths or request.method == "OPTIONS" - - async def _authenticate_request(self, request: Request, scope: dict): - """Handle API key authentication.""" - # Extract API key - api_key = self._extract_api_key(request) - - # Get authentication service - auth_service = await get_auth_service() - - # Check rate limiting - client_ip = self._get_client_ip(request) - if not await auth_service.check_rate_limit(client_ip): - raise HTTPException( - status_code=429, - detail="Too many authentication failures. Please try again later.", - ) - - # Validate API key - if not await auth_service.validate_api_key(api_key): - raise HTTPException(status_code=401, detail="Invalid or missing API key") - - # Add authenticated state - scope["state"] = scope.get("state", {}) - scope["state"]["authenticated"] = True - scope["state"]["api_key"] = api_key - - def _extract_api_key(self, request: Request) -> Optional[str]: - """Extract API key from request headers.""" - # Check x-api-key header first - api_key = request.headers.get("x-api-key") - if api_key: - return api_key - - # Check Authorization header - auth_header = request.headers.get("authorization") - if auth_header: - if auth_header.startswith("Bearer "): - return auth_header[7:] - elif auth_header.startswith("ApiKey "): - return auth_header[7:] - - return None - - def _get_client_ip(self, request: Request) -> str: - """Get client IP address.""" - # Check forwarded headers - forwarded_for = request.headers.get("x-forwarded-for") - if forwarded_for: - return forwarded_for.split(",")[0].strip() - - real_ip = request.headers.get("x-real-ip") - if real_ip: - return real_ip - - return request.client.host if request.client else "unknown" diff --git a/src/middleware/headers.py b/src/middleware/headers.py deleted file mode 100644 index 7f413ac..0000000 --- a/src/middleware/headers.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Security headers middleware.""" - -from typing import Callable - -import structlog - -logger = structlog.get_logger(__name__) - - -class SecurityHeadersMiddleware: - """Middleware for adding security headers to responses. - - This middleware adds standard security headers to all responses: - - X-Content-Type-Options: nosniff - - X-Frame-Options: DENY - - X-XSS-Protection: 1; mode=block - - Strict-Transport-Security: max-age=31536000; includeSubDomains - - Content-Security-Policy: default-src 'self' - - Referrer-Policy: strict-origin-when-cross-origin - - Permissions-Policy: geolocation=(), microphone=(), camera=() - """ - - # Default security headers - SECURITY_HEADERS = { - b"x-content-type-options": b"nosniff", - b"x-frame-options": b"DENY", - b"x-xss-protection": b"1; mode=block", - b"strict-transport-security": b"max-age=31536000; includeSubDomains", - b"content-security-policy": b"default-src 'self'", - b"referrer-policy": b"strict-origin-when-cross-origin", - b"permissions-policy": b"geolocation=(), microphone=(), camera=()", - } - - def __init__(self, app: Callable): - self.app = app - - async def __call__(self, scope: dict, receive: Callable, send: Callable): - """Process request and add security headers to response.""" - if scope["type"] != "http": - await self.app(scope, receive, send) - return - - async def send_wrapper(message): - if message["type"] == "http.response.start": - headers = dict(message.get("headers", [])) - - # Add security headers - for key, value in self.SECURITY_HEADERS.items(): - headers[key] = value - - message["headers"] = list(headers.items()) - - await send(message) - - await self.app(scope, receive, send_wrapper) diff --git a/src/services/container/__init__.py b/src/services/container/__init__.py index 93a2a3b..d306154 100644 --- a/src/services/container/__init__.py +++ b/src/services/container/__init__.py @@ -4,10 +4,19 @@ - client.py: Docker client factory and initialization - executor.py: Command execution in containers - manager.py: Container lifecycle management +- utils.py: Shared utilities for container operations """ from .manager import ContainerManager from .client import DockerClientFactory from .executor import ContainerExecutor +from .utils import wait_for_container_ready, receive_socket_output, run_in_executor -__all__ = ["ContainerManager", "DockerClientFactory", "ContainerExecutor"] +__all__ = [ + "ContainerManager", + "DockerClientFactory", + "ContainerExecutor", + "wait_for_container_ready", + "receive_socket_output", + "run_in_executor", +] diff --git a/src/services/container/executor.py b/src/services/container/executor.py index 8bf563d..d755ecc 100644 --- a/src/services/container/executor.py +++ b/src/services/container/executor.py @@ -10,6 +10,7 @@ from docker.models.containers import Container from ...config import settings +from .utils import wait_for_container_ready, receive_socket_output, run_in_executor logger = structlog.get_logger(__name__) @@ -74,35 +75,7 @@ async def execute_command( exec_config["workdir"] = working_dir try: - exec_instance = self.client.api.exec_create(container.id, **exec_config) - exec_id = exec_instance["Id"] - - sock = self.client.api.exec_start(exec_id, socket=True) - raw_sock = sock._sock - raw_sock.settimeout(timeout) - - if stdin_payload: - raw_sock.sendall(stdin_payload.encode("utf-8")) - raw_sock.shutdown(1) - - output_chunks = [] - while True: - try: - chunk = raw_sock.recv(4096) - if not chunk: - break - output_chunks.append(chunk) - except (TimeoutError, OSError): - break - - output = b"".join(output_chunks) - exec_info = self.client.api.exec_inspect(exec_id) - exit_code = exec_info["ExitCode"] - - output_str = self._sanitize_output(output) if output else "" - stdout, stderr = self._separate_output_streams(output_str, exit_code) - - return exit_code, stdout, stderr + return self._execute_via_socket(container, exec_config, stdin_payload, timeout) except DockerException as e: error_text = str(e) @@ -111,7 +84,7 @@ async def execute_command( if "is not running" in error_text.lower(): try: await self._start_container(container) - return await self._retry_execution( + return self._execute_via_socket( container, exec_config, stdin_payload, timeout ) except Exception as retry_err: @@ -122,42 +95,17 @@ async def execute_command( logger.error(f"Unexpected error during command execution: {e}") return 1, "", f"Unexpected execution error: {str(e)}" - async def _start_container(self, container: Container) -> bool: - """Start a container and wait for running state.""" - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, container.start) - - stable_checks = 0 - max_wait = 2.0 - interval = 0.05 - total_wait = 0.0 - - while total_wait < max_wait: - try: - container.reload() - if getattr(container, "status", "") == "running": - stable_checks += 1 - if stable_checks >= 3: - return True - else: - stable_checks = 0 - except Exception: - stable_checks = 0 - await asyncio.sleep(interval) - total_wait += interval - - return getattr(container, "status", "") == "running" - - async def _retry_execution( + def _execute_via_socket( self, container: Container, exec_config: Dict[str, Any], stdin_payload: Optional[str], timeout: int, ) -> Tuple[int, str, str]: - """Retry execution after container start.""" + """Execute command and collect output via socket.""" exec_instance = self.client.api.exec_create(container.id, **exec_config) exec_id = exec_instance["Id"] + sock = self.client.api.exec_start(exec_id, socket=True) raw_sock = sock._sock raw_sock.settimeout(timeout) @@ -166,23 +114,20 @@ async def _retry_execution( raw_sock.sendall(stdin_payload.encode("utf-8")) raw_sock.shutdown(1) - output_chunks = [] - while True: - try: - chunk = raw_sock.recv(4096) - if not chunk: - break - output_chunks.append(chunk) - except (TimeoutError, OSError): - break - - output = b"".join(output_chunks) + output = receive_socket_output(raw_sock) exec_info = self.client.api.exec_inspect(exec_id) exit_code = exec_info["ExitCode"] + output_str = self._sanitize_output(output) if output else "" stdout, stderr = self._separate_output_streams(output_str, exit_code) + return exit_code, stdout, stderr + async def _start_container(self, container: Container) -> bool: + """Start a container and wait for running state.""" + await run_in_executor(container.start) + return await wait_for_container_ready(container) + def _build_sanitized_env(self, language: Optional[str]) -> Dict[str, str]: """Build environment whitelist for execution.""" normalized_lang = (language or "").lower().strip() diff --git a/src/services/container/manager.py b/src/services/container/manager.py index be91ed1..5f0fd47 100644 --- a/src/services/container/manager.py +++ b/src/services/container/manager.py @@ -20,6 +20,7 @@ ) from .client import DockerClientFactory from .executor import ContainerExecutor +from .utils import wait_for_container_ready, run_in_executor logger = structlog.get_logger(__name__) @@ -331,34 +332,8 @@ def create_container( async def start_container(self, container: Container) -> bool: """Start a Docker container.""" try: - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, container.start) - - stable_checks = 0 - max_wait = 2.0 - interval = 0.05 - total_wait = 0.0 - - while total_wait < max_wait: - try: - container.reload() - if getattr(container, "status", "") == "running": - stable_checks += 1 - if stable_checks >= 3: - return True - else: - stable_checks = 0 - except Exception: - stable_checks = 0 - await asyncio.sleep(interval) - total_wait += interval - - try: - container.reload() - return getattr(container, "status", "") == "running" - except Exception: - return False - + await run_in_executor(container.start) + return await wait_for_container_ready(container) except DockerException as e: logger.error(f"Failed to start container {container.id[:12]}: {e}") return False diff --git a/src/services/container/utils.py b/src/services/container/utils.py new file mode 100644 index 0000000..1bbd607 --- /dev/null +++ b/src/services/container/utils.py @@ -0,0 +1,102 @@ +"""Shared utilities for container operations. + +This module contains common patterns extracted from container services +to reduce code duplication. +""" + +import asyncio +from typing import List, Optional + +import structlog +from docker.models.containers import Container + +logger = structlog.get_logger(__name__) + + +async def wait_for_container_ready( + container: Container, + max_wait: float = 2.0, + interval: float = 0.05, + stable_checks_required: int = 3, +) -> bool: + """ + Wait for a container to reach a stable running state. + + Uses polling with stability checks to ensure the container + is truly running before returning. + + Args: + container: Docker container to wait for + max_wait: Maximum time to wait in seconds + interval: Polling interval in seconds + stable_checks_required: Number of consecutive running checks required + + Returns: + True if container is running, False otherwise + """ + stable_checks = 0 + total_wait = 0.0 + + while total_wait < max_wait: + try: + container.reload() + if getattr(container, "status", "") == "running": + stable_checks += 1 + if stable_checks >= stable_checks_required: + return True + else: + stable_checks = 0 + except Exception: + stable_checks = 0 + await asyncio.sleep(interval) + total_wait += interval + + # Final check + try: + container.reload() + return getattr(container, "status", "") == "running" + except Exception: + return False + + +def receive_socket_output( + sock, + chunk_size: int = 4096, + timeout_exceptions: tuple = (TimeoutError, OSError), +) -> bytes: + """ + Receive all output from a socket until closed or timeout. + + Args: + sock: Raw socket to receive from + chunk_size: Size of chunks to receive + timeout_exceptions: Exception types that indicate timeout + + Returns: + All received bytes concatenated + """ + output_chunks: List[bytes] = [] + while True: + try: + chunk = sock.recv(chunk_size) + if not chunk: + break + output_chunks.append(chunk) + except timeout_exceptions: + break + return b"".join(output_chunks) + + +async def run_in_executor(func, *args): + """ + Run a blocking function in the default thread pool executor. + + Args: + func: Blocking function to run + *args: Arguments to pass to the function + + Returns: + Result of the function + """ + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, func, *args) diff --git a/src/services/execution.py b/src/services/execution.py deleted file mode 100644 index 606fc0b..0000000 --- a/src/services/execution.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Code execution service implementation. - -DEPRECATED: This module is maintained for backward compatibility. -New code should import from src.services.execution package instead. - -The CodeExecutionService has been split into: -- src/services/execution/runner.py: Core execution logic -- src/services/execution/output.py: Output processing and validation -""" - -# Re-export from new package for backward compatibility -from .execution import CodeExecutionService, CodeExecutionRunner, OutputProcessor - -__all__ = ["CodeExecutionService", "CodeExecutionRunner", "OutputProcessor"] diff --git a/src/utils/__init__.py b/src/utils/__init__.py index bb160e3..fb7b680 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -2,7 +2,7 @@ from .logging import setup_logging, get_logger from .security import SecurityValidator, RateLimiter, SecurityAudit, get_rate_limiter -from .containers import ContainerManager +from ..services.container import ContainerManager __all__ = [ "setup_logging", diff --git a/src/utils/containers.py b/src/utils/containers.py deleted file mode 100644 index cef0be0..0000000 --- a/src/utils/containers.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Container utilities for Docker operations. - -DEPRECATED: This module is maintained for backward compatibility. -New code should import from src.services.container instead. - -The ContainerManager class has been split into: -- src/services/container/client.py: Docker client factory -- src/services/container/executor.py: Command execution -- src/services/container/manager.py: Container lifecycle management -""" - -# Re-export from new location for backward compatibility -from ..services.container import ( - ContainerManager, - DockerClientFactory, - ContainerExecutor, -) - -# Also re-export error handler for existing imports -from .error_handlers import handle_docker_error - -__all__ = [ - "ContainerManager", - "DockerClientFactory", - "ContainerExecutor", - "handle_docker_error", -] diff --git a/tests/conftest.py b/tests/conftest.py index df2774a..16a9be5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -110,7 +110,7 @@ async def session_service(mock_redis): @pytest.fixture def execution_service(): """Create CodeExecutionService instance with mocked dependencies.""" - with patch('src.services.execution.ContainerManager') as mock_container_manager: + with patch('src.services.execution.runner.ContainerManager') as mock_container_manager: mock_manager = MagicMock() mock_container_manager.return_value = mock_manager From 78a753c4d2345f2d815503ca31857ebe3cc88f59 Mon Sep 17 00:00:00 2001 From: Joe Licata Date: Tue, 13 Jan 2026 16:35:22 +0000 Subject: [PATCH 2/8] feat: Implement file type restrictions for uploads - Added validation to restrict file uploads based on allowed file types, raising a 415 error for disallowed types. - Expanded the list of allowed file extensions in the `Settings` class to include various document, image, and script formats. - Introduced unit tests to verify the functionality of file type validation and ensure blocked extensions are correctly handled. --- src/api/files.py | 8 ++ src/config/__init__.py | 61 ++++++++++++-- src/config/security.py | 83 ------------------- src/middleware/security.py | 1 - tests/integration/test_file_api.py | 49 ++++++++++++ tests/unit/test_file_upload_validation.py | 98 +++++++++++++++++++++++ 6 files changed, 209 insertions(+), 91 deletions(-) create mode 100644 tests/unit/test_file_upload_validation.py diff --git a/src/api/files.py b/src/api/files.py index 6f436a0..048beb9 100644 --- a/src/api/files.py +++ b/src/api/files.py @@ -103,6 +103,14 @@ async def upload_file( detail=f"Too many files. Maximum {settings.max_files_per_session} files allowed", ) + # Check file type restrictions + for file in upload_files: + if not settings.is_file_allowed(file.filename or ""): + raise HTTPException( + status_code=415, + detail=f"File type not allowed: {file.filename}", + ) + uploaded_files = [] # Create a session ID for this upload diff --git a/src/config/__init__.py b/src/config/__init__.py index 61e6969..e235ef2 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -365,7 +365,48 @@ class Settings(BaseSettings): # Security Configuration allowed_file_extensions: List[str] = Field( default_factory=lambda: [ + # Text and documentation ".txt", + ".md", + ".rtf", + ".pdf", + # Microsoft Office + ".doc", + ".docx", + ".dotx", + ".xls", + ".xlsx", + ".xltx", + ".ppt", + ".pptx", + ".potx", + ".ppsx", + # OpenDocument formats + ".odt", + ".ods", + ".odp", + ".odg", + # Data formats + ".json", + ".csv", + ".xml", + ".yaml", + ".yml", + ".sql", + # Images + ".png", + ".jpg", + ".jpeg", + ".gif", + ".svg", + ".bmp", + ".webp", + ".ico", + # Web + ".html", + ".htm", + ".css", + # Code files ".py", ".js", ".ts", @@ -380,18 +421,24 @@ class Settings(BaseSettings): ".r", ".f90", ".d", - ".json", - ".csv", - ".xml", - ".yaml", - ".yml", - ".md", - ".sql", + # Scripts and config ".sh", ".bat", ".ps1", ".dockerfile", ".makefile", + ".ini", + ".cfg", + ".conf", + ".log", + # Archives + ".zip", + # Email and calendar + ".eml", + ".msg", + ".mbox", + ".ics", + ".vcf", ] ) blocked_file_patterns: List[str] = Field( diff --git a/src/config/security.py b/src/config/security.py index 0430582..64f3805 100644 --- a/src/config/security.py +++ b/src/config/security.py @@ -14,89 +14,6 @@ class SecurityConfig(BaseSettings): api_key_header: str = Field(default="x-api-key") api_key_cache_ttl: int = Field(default=300, ge=60) - # File Security - allowed_file_extensions: List[str] = Field( - default_factory=lambda: [ - # Text and documentation - ".txt", - ".md", - ".rtf", - ".pdf", - # Microsoft Office - ".doc", - ".docx", - ".dotx", - ".xls", - ".xlsx", - ".xltx", - ".ppt", - ".pptx", - ".potx", - ".ppsx", - # OpenDocument formats - ".odt", - ".ods", - ".odp", - ".odg", - # Data formats - ".json", - ".csv", - ".xml", - ".yaml", - ".yml", - ".sql", - # Images - ".png", - ".jpg", - ".jpeg", - ".gif", - ".svg", - ".bmp", - ".webp", - ".ico", - # Web - ".html", - ".htm", - ".css", - # Code files - ".py", - ".js", - ".ts", - ".go", - ".java", - ".c", - ".cpp", - ".h", - ".hpp", - ".php", - ".rs", - ".r", - ".f90", - ".d", - # Scripts and config - ".sh", - ".bat", - ".ps1", - ".dockerfile", - ".makefile", - ".ini", - ".cfg", - ".conf", - ".log", - # Archives - ".zip", - # Email and calendar - ".eml", - ".msg", - ".mbox", - ".ics", - ".vcf", - ] - ) - blocked_file_patterns: List[str] = Field( - default_factory=lambda: ["*.exe", "*.dll", "*.so", "*.dylib", "*.bin"] - ) - # Container Isolation enable_network_isolation: bool = Field(default=True) enable_filesystem_isolation: bool = Field(default=True) diff --git a/src/middleware/security.py b/src/middleware/security.py index cb5a113..6384481 100644 --- a/src/middleware/security.py +++ b/src/middleware/security.py @@ -22,7 +22,6 @@ class SecurityMiddleware: def __init__(self, app: Callable): self.app = app - self.max_request_size = settings.max_file_size_mb * 1024 * 1024 self.excluded_paths = { "/health", "/docs", diff --git a/tests/integration/test_file_api.py b/tests/integration/test_file_api.py index 4e2ba16..65f1434 100644 --- a/tests/integration/test_file_api.py +++ b/tests/integration/test_file_api.py @@ -209,6 +209,55 @@ def test_delete_nonexistent_file(self, client, auth_headers, unique_session_id): assert response.status_code == 404 +class TestFileTypeRestrictions: + """Test file type upload restrictions.""" + + def test_upload_blocked_exe_file(self, client, auth_headers): + """Test that .exe files are blocked with 415 status.""" + files = {"files": ("malware.exe", io.BytesIO(b"MZ...fake exe"), "application/octet-stream")} + + response = client.post("/upload", files=files, headers=auth_headers) + + assert response.status_code == 415 + assert "File type not allowed" in response.json()["detail"] + + def test_upload_blocked_dll_file(self, client, auth_headers): + """Test that .dll files are blocked with 415 status.""" + files = {"files": ("library.dll", io.BytesIO(b"fake dll content"), "application/octet-stream")} + + response = client.post("/upload", files=files, headers=auth_headers) + + assert response.status_code == 415 + assert "File type not allowed" in response.json()["detail"] + + def test_upload_blocked_bin_file(self, client, auth_headers): + """Test that .bin files are blocked with 415 status.""" + files = {"files": ("binary.bin", io.BytesIO(b"binary content"), "application/octet-stream")} + + response = client.post("/upload", files=files, headers=auth_headers) + + assert response.status_code == 415 + assert "File type not allowed" in response.json()["detail"] + + def test_upload_allowed_txt_file(self, client, auth_headers): + """Test that allowed file types still work.""" + files = {"files": ("readme.txt", io.BytesIO(b"Hello world"), "text/plain")} + + response = client.post("/upload", files=files, headers=auth_headers) + + assert response.status_code == 200 + assert response.json()["message"] == "success" + + def test_upload_allowed_python_file(self, client, auth_headers): + """Test that Python files are allowed.""" + files = {"files": ("script.py", io.BytesIO(b"print('hello')"), "text/x-python")} + + response = client.post("/upload", files=files, headers=auth_headers) + + assert response.status_code == 200 + assert response.json()["message"] == "success" + + class TestFileAuthentication: """Test authentication for file endpoints.""" diff --git a/tests/unit/test_file_upload_validation.py b/tests/unit/test_file_upload_validation.py new file mode 100644 index 0000000..e6ad112 --- /dev/null +++ b/tests/unit/test_file_upload_validation.py @@ -0,0 +1,98 @@ +"""Unit tests for file upload type validation.""" + +import pytest +from unittest.mock import patch, MagicMock + + +class TestIsFileAllowed: + """Test the is_file_allowed settings method.""" + + def test_allowed_extension_passes(self): + """Test that allowed file extensions pass validation.""" + from src.config import settings + + # Text and code files + assert settings.is_file_allowed("test.txt") is True + assert settings.is_file_allowed("script.py") is True + assert settings.is_file_allowed("data.json") is True + assert settings.is_file_allowed("code.js") is True + assert settings.is_file_allowed("notes.md") is True + + # Documents + assert settings.is_file_allowed("document.pdf") is True + assert settings.is_file_allowed("report.docx") is True + assert settings.is_file_allowed("spreadsheet.xlsx") is True + + # Images + assert settings.is_file_allowed("image.png") is True + assert settings.is_file_allowed("photo.jpg") is True + assert settings.is_file_allowed("icon.svg") is True + + # Archives + assert settings.is_file_allowed("archive.zip") is True + + def test_blocked_extension_fails(self): + """Test that blocked file extensions fail validation.""" + from src.config import settings + + # These are not in allowed_file_extensions + assert settings.is_file_allowed("malware.exe") is False + assert settings.is_file_allowed("library.dll") is False + assert settings.is_file_allowed("binary.bin") is False + assert settings.is_file_allowed("shared.so") is False + assert settings.is_file_allowed("dynamic.dylib") is False + + def test_blocked_pattern_matches(self): + """Test that blocked patterns are enforced.""" + from src.config import settings + + # Test blocked_file_patterns (*.exe, *.dll, *.so, *.dylib, *.bin) + assert settings.is_file_allowed("anything.exe") is False + assert settings.is_file_allowed("anything.dll") is False + assert settings.is_file_allowed("anything.bin") is False + + def test_case_insensitive_extension(self): + """Test that extension checking is case insensitive.""" + from src.config import settings + + # Allowed extensions should work regardless of case + assert settings.is_file_allowed("test.TXT") is True + assert settings.is_file_allowed("test.Txt") is True + assert settings.is_file_allowed("script.PY") is True + + # Blocked extensions should be blocked regardless of case + assert settings.is_file_allowed("malware.EXE") is False + assert settings.is_file_allowed("malware.Exe") is False + + def test_file_without_extension(self): + """Test handling of files without extensions.""" + from src.config import settings + + # Files without extensions should be allowed (no extension to block) + # The is_file_allowed method returns True if extension is empty + assert settings.is_file_allowed("Makefile") is True + assert settings.is_file_allowed("Dockerfile") is True + assert settings.is_file_allowed("README") is True + + def test_empty_filename(self): + """Test handling of empty filename.""" + from src.config import settings + + # Empty filename should be allowed (no extension to check) + assert settings.is_file_allowed("") is True + + def test_double_extension(self): + """Test files with double extensions.""" + from src.config import settings + + # Only the last extension matters + assert settings.is_file_allowed("archive.tar.gz") is False # .gz not in allowed + assert settings.is_file_allowed("script.test.py") is True # .py is allowed + + def test_hidden_files(self): + """Test hidden files (starting with dot).""" + from src.config import settings + + assert settings.is_file_allowed(".gitignore") is True # No extension + assert settings.is_file_allowed(".env") is True # No extension + assert settings.is_file_allowed(".config.json") is True # .json allowed From caa2f51f61ecc7c084ad6ffdef9e15f67eae9c23 Mon Sep 17 00:00:00 2001 From: Joe Licata Date: Tue, 13 Jan 2026 16:51:47 +0000 Subject: [PATCH 3/8] refactor: Remove deprecated settings and streamline request handling - Removed deprecated `max_cpu_quota` and session container reuse settings from the `Settings` class for cleaner configuration. - Consolidated API key extraction and client IP retrieval into shared utility functions for improved code reuse and maintainability. - Eliminated unused event classes and methods to reduce code clutter and enhance readability. - Updated output processing to utilize the new filename sanitization method, ensuring consistent handling of filenames across the application. --- src/config/__init__.py | 14 - src/config/resources.py | 3 - src/core/events.py | 52 -- src/dependencies/auth.py | 14 +- src/middleware/security.py | 37 +- src/services/container/executor.py | 4 +- src/services/container/network.py | 48 -- src/services/execution/__init__.py | 2 +- src/services/execution/output.py | 20 +- src/services/execution/runner.py | 38 +- src/utils/error_handlers.py | 69 -- src/utils/request_helpers.py | 70 ++ src/utils/security.py | 70 +- tests/__init__.py | 2 +- tests/conftest.py | 90 ++- tests/integration/__init__.py | 2 +- tests/integration/test_api_contracts.py | 217 +++--- tests/integration/test_auth_integration.py | 219 +++--- tests/integration/test_container_behavior.py | 683 +++++++++++------- tests/integration/test_container_hardening.py | 5 +- tests/integration/test_exec_api.py | 353 ++++----- tests/integration/test_file_api.py | 64 +- tests/integration/test_file_handling.py | 202 +++--- tests/integration/test_librechat_compat.py | 205 +++--- .../integration/test_security_integration.py | 106 +-- tests/integration/test_session_behavior.py | 305 +++++--- tests/integration/test_state_api.py | 56 +- tests/unit/__init__.py | 2 +- tests/unit/test_output_processor.py | 15 - tests/unit/test_session_service.py | 168 +++-- tests/unit/test_state_service.py | 60 +- 31 files changed, 1707 insertions(+), 1488 deletions(-) create mode 100644 src/utils/request_helpers.py diff --git a/src/config/__init__.py b/src/config/__init__.py index e235ef2..8c1266e 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -143,9 +143,6 @@ class Settings(BaseSettings): le=16.0, description="Maximum CPU cores available to execution containers", ) - max_cpu_quota: int = Field( - default=50000, ge=10000, le=100000 - ) # Deprecated, use max_cpus max_pids: int = Field( default=512, ge=64, @@ -179,16 +176,6 @@ class Settings(BaseSettings): container_pool_enabled: bool = Field(default=True) container_pool_warmup_on_startup: bool = Field(default=True) - # Session Container Reuse - DEPRECATED - # These settings are no longer used. Containers are now stateless: - # - Each execution gets a fresh container from pool - # - Containers are destroyed immediately after execution - # Kept for backward compatibility with existing configs - session_container_reuse_enabled: bool = Field(default=False) # Deprecated, not used - session_container_ttl_seconds: int = Field( - default=0, ge=0, le=1800 - ) # Deprecated, not used - # Per-language pool sizes (0 = on-demand only, no pre-warming) container_pool_py: int = Field( default=5, ge=0, le=50, description="Python pool size" @@ -593,7 +580,6 @@ def resources(self) -> ResourcesConfig: max_execution_time=self.max_execution_time, max_memory_mb=self.max_memory_mb, max_cpus=self.max_cpus, - max_cpu_quota=self.max_cpu_quota, max_pids=self.max_pids, max_open_files=self.max_open_files, max_file_size_mb=self.max_file_size_mb, diff --git a/src/config/resources.py b/src/config/resources.py index 3b4479e..8fd2380 100644 --- a/src/config/resources.py +++ b/src/config/resources.py @@ -16,9 +16,6 @@ class ResourcesConfig(BaseSettings): le=16.0, description="Maximum CPU cores available to execution containers", ) - max_cpu_quota: int = Field( - default=50000, ge=10000, le=100000 - ) # Deprecated, use max_cpus max_pids: int = Field( default=512, ge=64, diff --git a/src/core/events.py b/src/core/events.py index 554bb08..fecbb6f 100644 --- a/src/core/events.py +++ b/src/core/events.py @@ -164,15 +164,6 @@ def clear_handlers(self, event_type: Type[Event] = None) -> None: # Predefined events for service communication -@dataclass -class SessionCreated(Event): - """Emitted when a new session is created.""" - - session_id: str - entity_id: str | None = None - user_id: str | None = None - - @dataclass class SessionDeleted(Event): """Emitted when a session is deleted or expired.""" @@ -180,15 +171,6 @@ class SessionDeleted(Event): session_id: str -@dataclass -class ExecutionStarted(Event): - """Emitted when code execution starts.""" - - execution_id: str - session_id: str - language: str - - @dataclass class ExecutionCompleted(Event): """Emitted when code execution completes.""" @@ -199,40 +181,6 @@ class ExecutionCompleted(Event): execution_time_ms: int | None = None -@dataclass -class FileUploaded(Event): - """Emitted when a file is uploaded.""" - - file_id: str - session_id: str - filename: str - - -@dataclass -class FileDeleted(Event): - """Emitted when a file is deleted.""" - - file_id: str - session_id: str - - -@dataclass -class ContainerCreated(Event): - """Emitted when a container is created.""" - - container_id: str - session_id: str - language: str - - -@dataclass -class ContainerDestroyed(Event): - """Emitted when a container is destroyed.""" - - container_id: str - session_id: str - - # Container Pool Events @dataclass class ContainerAcquiredFromPool(Event): diff --git a/src/dependencies/auth.py b/src/dependencies/auth.py index 657e74d..ee69f57 100644 --- a/src/dependencies/auth.py +++ b/src/dependencies/auth.py @@ -10,6 +10,7 @@ # Local application imports from ..services.auth import get_auth_service +from ..utils.request_helpers import extract_api_key logger = structlog.get_logger(__name__) @@ -28,15 +29,8 @@ async def verify_api_key( if hasattr(request.state, "authenticated") and request.state.authenticated: return getattr(request.state, "api_key", "") - # Extract API key from various sources - api_key = None - - # Check x-api-key header (preferred method) - api_key = request.headers.get("x-api-key") - - # Check Authorization header as fallback - if not api_key and credentials: - api_key = credentials.credentials + # Extract API key using shared utility + api_key = extract_api_key(request) if not api_key: logger.warning("No API key provided in request") @@ -68,5 +62,3 @@ async def verify_api_key_optional( if "required" in e.detail: return None # No API key provided, which is OK for optional endpoints raise # Invalid API key provided, which is not OK - - diff --git a/src/middleware/security.py b/src/middleware/security.py index 6384481..ee43a7c 100644 --- a/src/middleware/security.py +++ b/src/middleware/security.py @@ -12,6 +12,7 @@ # Local application imports from ..config import settings from ..services.auth import get_auth_service +from ..utils.request_helpers import extract_api_key, get_client_ip logger = structlog.get_logger(__name__) @@ -151,14 +152,14 @@ def _should_skip_auth(self, request: Request) -> bool: async def _authenticate_request(self, request: Request, scope: dict): """Handle API key authentication with rate limiting.""" - # Extract API key - api_key = self._extract_api_key(request) + # Extract API key using shared utility + api_key = extract_api_key(request) # Get authentication service auth_service = await get_auth_service() # Check IP-based rate limiting for auth failures - client_ip = self._get_client_ip(request) + client_ip = get_client_ip(request) if not await auth_service.check_rate_limit(client_ip): raise HTTPException( status_code=429, @@ -220,36 +221,6 @@ async def _authenticate_request(self, request: Request, scope: dict): result.key_hash, is_env_key=result.is_env_key ) - def _extract_api_key(self, request: Request) -> Optional[str]: - """Extract API key from request headers.""" - # Check x-api-key header first - api_key = request.headers.get("x-api-key") - if api_key: - return api_key - - # Check Authorization header - auth_header = request.headers.get("authorization") - if auth_header: - if auth_header.startswith("Bearer "): - return auth_header[7:] - elif auth_header.startswith("ApiKey "): - return auth_header[7:] - - return None - - def _get_client_ip(self, request: Request) -> str: - """Get client IP address.""" - # Check forwarded headers - forwarded_for = request.headers.get("x-forwarded-for") - if forwarded_for: - return forwarded_for.split(",")[0].strip() - - real_ip = request.headers.get("x-real-ip") - if real_ip: - return real_ip - - return request.client.host if request.client else "unknown" - class RequestLoggingMiddleware: """Simplified request logging middleware.""" diff --git a/src/services/container/executor.py b/src/services/container/executor.py index d755ecc..47024d2 100644 --- a/src/services/container/executor.py +++ b/src/services/container/executor.py @@ -75,7 +75,9 @@ async def execute_command( exec_config["workdir"] = working_dir try: - return self._execute_via_socket(container, exec_config, stdin_payload, timeout) + return self._execute_via_socket( + container, exec_config, stdin_payload, timeout + ) except DockerException as e: error_text = str(e) diff --git a/src/services/container/network.py b/src/services/container/network.py index 51147b1..4f4e265 100644 --- a/src/services/container/network.py +++ b/src/services/container/network.py @@ -321,51 +321,3 @@ async def cleanup(self) -> None: pass # Ignore cleanup errors logger.info("Cleaned up WAN network iptables rules") - - def get_network_id(self) -> Optional[str]: - """Get the WAN network ID for container attachment. - - Returns: - Network ID string or None if not initialized - """ - if self._network: - return self._network.id - return None - - def is_ready(self) -> bool: - """Check if WAN network is ready for use. - - Returns: - True if network is initialized and ready - """ - return self._initialized and self._network is not None - - async def remove_network(self) -> bool: - """Remove the WAN network entirely. - - This is typically only called during testing or explicit cleanup. - - Returns: - True if network was removed successfully - """ - if not self._network: - return True - - try: - # First cleanup iptables - await self.cleanup() - - # Then remove the network - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, self._network.remove) - - logger.info("Removed WAN network", network_name=self.network_name) - self._network = None - self._initialized = False - return True - except NotFound: - # Already removed - return True - except Exception as e: - logger.error("Failed to remove WAN network", error=str(e)) - return False diff --git a/src/services/execution/__init__.py b/src/services/execution/__init__.py index 8d2a561..11a811f 100644 --- a/src/services/execution/__init__.py +++ b/src/services/execution/__init__.py @@ -45,7 +45,7 @@ async def execute_code( def _normalize_container_filename(self, filename): """Backward compatibility alias.""" - return OutputProcessor.normalize_filename(filename) + return OutputProcessor.sanitize_filename(filename) def _sanitize_execution_output(self, output): """Backward compatibility alias.""" diff --git a/src/services/execution/output.py b/src/services/execution/output.py index 8bcd108..a99f17b 100644 --- a/src/services/execution/output.py +++ b/src/services/execution/output.py @@ -32,9 +32,6 @@ class OutputProcessor: ".zip": "application/zip", } - # Dangerous extensions that should be blocked - DANGEROUS_EXTENSIONS = [".exe", ".bat", ".cmd", ".sh", ".ps1", ".scr", ".com"] - @classmethod def sanitize_output(cls, output: str, max_size: int = 64 * 1024) -> str: """Sanitize execution output for security and display. @@ -96,10 +93,10 @@ def validate_generated_file(cls, file_info: Dict[str, Any]) -> bool: logger.warning(f"Generated file {file_path} has suspicious path") return False - # Check for dangerous file extensions - file_extension = Path(file_path).suffix.lower() - if file_extension in cls.DANGEROUS_EXTENSIONS: - logger.warning(f"Generated file {file_path} has dangerous extension") + # Check file using centralized settings validation + filename = Path(file_path).name + if not settings.is_file_allowed(filename): + logger.warning(f"Generated file {file_path} has blocked extension") return False return True @@ -266,12 +263,3 @@ def sanitize_filename(cls, input_name: str) -> str: except Exception as e: logger.error(f"Failed to sanitize filename: {e}") return "_" - - @classmethod - def normalize_filename(cls, filename: str) -> str: - """Deprecated: Use sanitize_filename instead. - - This method is kept for backward compatibility but delegates to - sanitize_filename which matches LibreChat's sanitization logic. - """ - return cls.sanitize_filename(filename) diff --git a/src/services/execution/runner.py b/src/services/execution/runner.py index e98f8f4..efb1544 100644 --- a/src/services/execution/runner.py +++ b/src/services/execution/runner.py @@ -337,7 +337,7 @@ def _get_mounted_filenames(self, files: Optional[List[Dict[str, Any]]]) -> set: name = f.get("filename") or f.get("name") if name: mounted.add(name) - mounted.add(OutputProcessor.normalize_filename(name)) + mounted.add(OutputProcessor.sanitize_filename(name)) except Exception: pass return mounted @@ -578,7 +578,7 @@ async def _mount_files_to_container( if file_content is not None: # Direct memory-to-container transfer (no tempfiles) - normalized_filename = OutputProcessor.normalize_filename( + normalized_filename = OutputProcessor.sanitize_filename( filename ) dest_path = f"/mnt/data/{normalized_filename}" @@ -612,7 +612,7 @@ async def _create_placeholder_file( ) -> None: """Create a placeholder file when content cannot be retrieved.""" try: - normalized_filename = OutputProcessor.normalize_filename(filename) + normalized_filename = OutputProcessor.sanitize_filename(filename) create_command = f"""cat > /mnt/data/{normalized_filename} << 'EOF' # File: {filename} # This is a placeholder - original file could not be retrieved @@ -665,38 +665,6 @@ async def _detect_generated_files( logger.error(f"Failed to detect generated files: {e}") return [] - def get_container_by_session(self, session_id: str) -> Optional[Container]: - """Get container for a session. - - DEPRECATED: Container is now returned directly from execute() method. - This method is kept for backward compatibility only. - """ - # First check the pool if available - if self.container_pool and settings.container_pool_enabled: - try: - # Use synchronous wrapper since this may be called from sync context - import asyncio - - loop = asyncio.get_event_loop() - if loop.is_running(): - # We're in an async context, use the pool's method directly - # The pool stores containers in _session_containers - if session_id in self.container_pool._session_containers: - sc = self.container_pool._session_containers[session_id] - try: - container = self.container_pool._container_manager.client.containers.get( - sc.container_id - ) - if container.status == "running": - return container - except Exception: - pass - except Exception as e: - logger.debug("Error getting container from pool", error=str(e)) - - # Fall back to runner's local container dict - return self.session_containers.get(session_id) - async def get_execution(self, execution_id: str) -> Optional[CodeExecution]: """Retrieve an execution by ID.""" return self.active_executions.get(execution_id) diff --git a/src/utils/error_handlers.py b/src/utils/error_handlers.py index 7902f1b..1ba3867 100644 --- a/src/utils/error_handlers.py +++ b/src/utils/error_handlers.py @@ -17,9 +17,6 @@ ErrorResponse, ErrorType, ErrorDetail, - ValidationError, - AuthenticationError, - ServiceUnavailableError, ) logger = structlog.get_logger(__name__) @@ -190,69 +187,3 @@ async def general_exception_handler(request: Request, exc: Exception) -> JSONRes ) return JSONResponse(status_code=500, content=error_response.model_dump()) - - -# Utility functions for common error scenarios - - -def create_validation_error( - field: str, message: str, code: str = None -) -> ValidationError: - """Create a validation error with details.""" - details = [ErrorDetail(field=field, message=message, code=code)] - return ValidationError( - message=f"Validation failed for field '{field}'", details=details - ) - - -def create_resource_error( - resource_type: str, resource_id: str = None, operation: str = "access" -): - """Create a resource not found error.""" - from ..models.errors import ResourceNotFoundError - - return ResourceNotFoundError(resource=resource_type, resource_id=resource_id) - - -def create_service_error(service_name: str, original_error: Exception = None): - """Create a service unavailable error.""" - message = f"{service_name} service is currently unavailable" - if original_error: - message += f": {str(original_error)}" - - return ServiceUnavailableError(service=service_name, message=message) - - -def handle_docker_error(error: Exception, operation: str = "container operation"): - """Convert Docker errors to appropriate CodeInterpreter exceptions.""" - from docker.errors import DockerException, APIError, ContainerError, ImageNotFound - from ..models.errors import ( - ExecutionError, - ResourceNotFoundError, - ServiceUnavailableError, - ) - - if isinstance(error, ImageNotFound): - return ResourceNotFoundError(resource="Docker image", resource_id=str(error)) - elif isinstance(error, ContainerError): - return ExecutionError(message=f"Container execution failed: {str(error)}") - elif isinstance(error, APIError): - if error.status_code == 409: - from ..models.errors import ResourceConflictError - - return ResourceConflictError( - message=f"Docker API conflict: {error.explanation}" - ) - else: - return ServiceUnavailableError( - service="Docker", message=f"Docker API error: {error.explanation}" - ) - elif isinstance(error, DockerException): - return ServiceUnavailableError( - service="Docker", message=f"Docker service error: {str(error)}" - ) - else: - return ServiceUnavailableError( - service="Docker", - message=f"Unknown Docker error during {operation}: {str(error)}", - ) diff --git a/src/utils/request_helpers.py b/src/utils/request_helpers.py new file mode 100644 index 0000000..98c2c6b --- /dev/null +++ b/src/utils/request_helpers.py @@ -0,0 +1,70 @@ +"""Shared request helper utilities. + +These utilities consolidate common request handling patterns used across +the middleware and dependencies layers. +""" + +from typing import Optional +from fastapi import Request + + +def extract_api_key(request: Request) -> Optional[str]: + """Extract API key from request headers. + + Checks in order: + 1. x-api-key header (preferred) + 2. Authorization header with Bearer token + 3. Authorization header with ApiKey token + + Args: + request: FastAPI Request object + + Returns: + API key string or None if not found + """ + # Check x-api-key header first (preferred method) + api_key = request.headers.get("x-api-key") + if api_key: + return api_key + + # Check Authorization header as fallback + auth_header = request.headers.get("authorization") + if auth_header: + if auth_header.startswith("Bearer "): + return auth_header[7:] + elif auth_header.startswith("ApiKey "): + return auth_header[7:] + + return None + + +def get_client_ip(request: Request) -> str: + """Get client IP address from request. + + Checks in order: + 1. X-Forwarded-For header (first IP in list) + 2. X-Real-IP header + 3. Direct client host + + Args: + request: FastAPI Request object + + Returns: + Client IP address string, or "unknown" if not determinable + """ + # Check X-Forwarded-For header (common in reverse proxy setups) + forwarded_for = request.headers.get("x-forwarded-for") + if forwarded_for: + # Take the first IP in the chain (client IP) + return forwarded_for.split(",")[0].strip() + + # Check X-Real-IP header + real_ip = request.headers.get("x-real-ip") + if real_ip: + return real_ip + + # Fall back to direct client connection + if request.client: + return request.client.host + + return "unknown" diff --git a/src/utils/security.py b/src/utils/security.py index 06051cd..5ba2ba3 100644 --- a/src/utils/security.py +++ b/src/utils/security.py @@ -7,6 +7,8 @@ from datetime import datetime, timedelta import structlog +from ..config import settings + logger = structlog.get_logger(__name__) @@ -30,56 +32,18 @@ class SecurityValidator: r"raw_input\s*\(", ] - # File extensions that are allowed for upload - ALLOWED_FILE_EXTENSIONS = { - ".txt", - ".csv", - ".json", - ".xml", - ".yaml", - ".yml", - ".py", - ".js", - ".ts", - ".go", - ".java", - ".c", - ".cpp", - ".h", - ".hpp", - ".rs", - ".php", - ".rb", - ".r", - ".f90", - ".d", - ".md", - ".rst", - ".html", - ".css", - ".png", - ".jpg", - ".jpeg", - ".gif", - ".svg", - ".pdf", - ".doc", - ".docx", - ".xls", - ".xlsx", - } - - # Maximum filename length - MAX_FILENAME_LENGTH = 255 - @classmethod def validate_filename(cls, filename: str) -> bool: - """Validate uploaded filename for security.""" + """Validate uploaded filename for security. + + Uses settings.is_file_allowed() for extension checking and + settings.max_filename_length for length validation. + """ if not filename: return False - # Check length - if len(filename) > cls.MAX_FILENAME_LENGTH: + # Check length using settings + if len(filename) > settings.max_filename_length: logger.warning("Filename too long", filename=filename, length=len(filename)) return False @@ -93,12 +57,9 @@ def validate_filename(cls, filename: str) -> bool: logger.warning("Null byte in filename", filename=filename) return False - # Check file extension - file_ext = cls._get_file_extension(filename) - if file_ext not in cls.ALLOWED_FILE_EXTENSIONS: - logger.warning( - "Disallowed file extension", filename=filename, extension=file_ext - ) + # Check file extension using settings (consolidated validation) + if not settings.is_file_allowed(filename): + logger.warning("Disallowed file", filename=filename) return False # Check for suspicious characters @@ -172,13 +133,6 @@ def hash_sensitive_data(cls, data: str) -> str: """Hash sensitive data for logging/storage.""" return hashlib.sha256(data.encode()).hexdigest()[:16] - @classmethod - def _get_file_extension(cls, filename: str) -> str: - """Get file extension in lowercase.""" - if "." not in filename: - return "" - return "." + filename.split(".")[-1].lower() - class RateLimiter: """Simple in-memory rate limiter for additional protection.""" diff --git a/tests/__init__.py b/tests/__init__.py index 739954c..d4839a6 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -# Tests package \ No newline at end of file +# Tests package diff --git a/tests/conftest.py b/tests/conftest.py index 16a9be5..d62051b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,7 +41,7 @@ def event_loop(): def mock_redis(): """Mock Redis client for testing.""" mock_client = AsyncMock(spec=redis.Redis) - + # Mock common Redis operations mock_client.hset = AsyncMock(return_value=1) mock_client.hgetall = AsyncMock(return_value={}) @@ -57,7 +57,7 @@ def mock_redis(): mock_client.ping = AsyncMock(return_value=True) mock_client.close = AsyncMock() mock_client.scan_iter = AsyncMock(return_value=iter([])) - + return mock_client @@ -65,7 +65,7 @@ def mock_redis(): def mock_minio(): """Mock MinIO client for testing.""" mock_client = MagicMock(spec=Minio) - + # Mock common MinIO operations mock_client.bucket_exists.return_value = True mock_client.make_bucket.return_value = None @@ -75,7 +75,7 @@ def mock_minio(): mock_client.put_object.return_value = None mock_client.get_object.return_value = MagicMock() mock_client.remove_object.return_value = None - + return mock_client @@ -84,18 +84,18 @@ def mock_docker(): """Mock Docker client for testing.""" mock_client = MagicMock(spec=DockerClient) mock_container = MagicMock() - + # Mock container operations mock_container.id = "test_container_id" mock_container.status = "running" mock_container.reload.return_value = None mock_container.exec_run.return_value = MagicMock(exit_code=0, output=b"test output") - + mock_client.containers.create.return_value = mock_container mock_client.containers.get.return_value = mock_container mock_client.images.pull.return_value = None mock_client.images.get.return_value = MagicMock() - + return mock_client @@ -110,21 +110,25 @@ async def session_service(mock_redis): @pytest.fixture def execution_service(): """Create CodeExecutionService instance with mocked dependencies.""" - with patch('src.services.execution.runner.ContainerManager') as mock_container_manager: + with patch( + "src.services.execution.runner.ContainerManager" + ) as mock_container_manager: mock_manager = MagicMock() mock_container_manager.return_value = mock_manager - + # Mock container manager methods mock_manager.get_image_for_language.return_value = "python:3.11" mock_manager.pull_image_if_needed = AsyncMock() mock_manager.create_container.return_value = MagicMock(id="test_container") mock_manager.start_container = AsyncMock() mock_manager.execute_command = AsyncMock(return_value=(0, "output", "")) - mock_manager.get_container_stats = AsyncMock(return_value={"memory_usage_mb": 50}) + mock_manager.get_container_stats = AsyncMock( + return_value={"memory_usage_mb": 50} + ) mock_manager.stop_container = AsyncMock() mock_manager.remove_container = AsyncMock() mock_manager.close.return_value = None - + service = CodeExecutionService() yield service @@ -132,8 +136,9 @@ def execution_service(): @pytest.fixture def file_service(mock_minio, mock_redis): """Create FileService instance with mocked dependencies.""" - with patch('src.services.file.Minio', return_value=mock_minio), \ - patch('src.services.file.redis.Redis', return_value=mock_redis): + with patch("src.services.file.Minio", return_value=mock_minio), patch( + "src.services.file.redis.Redis", return_value=mock_redis + ): service = FileService() yield service @@ -154,22 +159,20 @@ def sample_session(): created_at=datetime.now(timezone.utc), last_activity=datetime.now(timezone.utc), expires_at=datetime.now(timezone.utc), - metadata={"entity_id": "test-entity"} + metadata={"entity_id": "test-entity"}, ) @pytest.fixture def sample_session_create(): """Create a sample session creation request.""" - return SessionCreate( - metadata={"entity_id": "test-entity", "user_id": "test-user"} - ) + return SessionCreate(metadata={"entity_id": "test-entity", "user_id": "test-user"}) @pytest.fixture def mock_settings(): """Mock settings for testing.""" - with patch('src.config.settings') as mock_settings: + with patch("src.config.settings") as mock_settings: mock_settings.redis_host = "localhost" mock_settings.redis_port = 6379 mock_settings.redis_password = None @@ -188,11 +191,15 @@ def mock_settings(): mock_settings.max_execution_time = 30 mock_settings.max_file_size_mb = 10 mock_settings.max_output_files = 10 - + # Add helper methods for backward compatibility - mock_settings.get_session_ttl_minutes = lambda: mock_settings.session_ttl_hours * 60 - mock_settings.get_container_ttl_minutes = lambda: mock_settings.container_ttl_minutes - + mock_settings.get_session_ttl_minutes = ( + lambda: mock_settings.session_ttl_hours * 60 + ) + mock_settings.get_container_ttl_minutes = ( + lambda: mock_settings.container_ttl_minutes + ) + yield mock_settings @@ -202,7 +209,7 @@ def mock_container_stats(): return { "memory_usage_mb": 128.5, "cpu_usage_percent": 15.2, - "network_io": {"rx_bytes": 1024, "tx_bytes": 512} + "network_io": {"rx_bytes": 1024, "tx_bytes": 512}, } @@ -214,7 +221,7 @@ def mock_execution_result(): "stdout": "Hello, World!", "stderr": "", "execution_time_ms": 150, - "memory_peak_mb": 64.2 + "memory_peak_mb": 64.2, } @@ -230,8 +237,9 @@ async def async_session_service(mock_redis): @pytest_asyncio.fixture async def async_file_service(mock_minio, mock_redis): """Async fixture for FileService.""" - with patch('src.services.file.Minio', return_value=mock_minio), \ - patch('src.services.file.redis.Redis', return_value=mock_redis): + with patch("src.services.file.Minio", return_value=mock_minio), patch( + "src.services.file.redis.Redis", return_value=mock_redis + ): service = FileService() yield service await service.close() @@ -241,4 +249,32 @@ async def async_file_service(mock_minio, mock_redis): async def async_auth_service(mock_redis): """Async fixture for AuthenticationService.""" service = AuthenticationService(redis_client=mock_redis) - yield service \ No newline at end of file + yield service + + +# ============================================================================ +# Integration Test Fixtures +# ============================================================================ + + +@pytest.fixture +def client(): + """Create FastAPI test client for integration tests.""" + from fastapi.testclient import TestClient + from src.main import app + + return TestClient(app) + + +@pytest.fixture +def auth_headers(): + """Provide authentication headers for integration tests.""" + return {"x-api-key": "test-api-key-for-testing-12345"} + + +@pytest.fixture +def unique_session_id(): + """Generate unique session ID for test isolation.""" + import uuid + + return f"test-session-{uuid.uuid4().hex[:8]}" diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index e27cd7a..a265048 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -1 +1 @@ -# Integration tests package \ No newline at end of file +# Integration tests package diff --git a/tests/integration/test_api_contracts.py b/tests/integration/test_api_contracts.py index 83da3bb..0154a6e 100644 --- a/tests/integration/test_api_contracts.py +++ b/tests/integration/test_api_contracts.py @@ -19,7 +19,20 @@ # All 12 supported languages -SUPPORTED_LANGUAGES = ["py", "js", "ts", "go", "java", "c", "cpp", "php", "rs", "r", "f90", "d"] +SUPPORTED_LANGUAGES = [ + "py", + "js", + "ts", + "go", + "java", + "c", + "cpp", + "php", + "rs", + "r", + "f90", + "d", +] @pytest.fixture @@ -43,7 +56,7 @@ def mock_session(): created_at=datetime.now(timezone.utc), last_activity=datetime.now(timezone.utc), expires_at=datetime.now(timezone.utc) + timedelta(hours=24), - metadata={"entity_id": "test-entity"} + metadata={"entity_id": "test-entity"}, ) @@ -57,21 +70,27 @@ def mock_session_service(mock_session): return service -def create_mock_execution(language: str, stdout: str = "output", stderr: str = "") -> CodeExecution: +def create_mock_execution( + language: str, stdout: str = "output", stderr: str = "" +) -> CodeExecution: """Helper to create mock execution for any language.""" outputs = [] if stdout: - outputs.append(ExecutionOutput( - type=OutputType.STDOUT, - content=stdout, - timestamp=datetime.now(timezone.utc) - )) + outputs.append( + ExecutionOutput( + type=OutputType.STDOUT, + content=stdout, + timestamp=datetime.now(timezone.utc), + ) + ) if stderr: - outputs.append(ExecutionOutput( - type=OutputType.STDERR, - content=stderr, - timestamp=datetime.now(timezone.utc) - )) + outputs.append( + ExecutionOutput( + type=OutputType.STDERR, + content=stderr, + timestamp=datetime.now(timezone.utc), + ) + ) return CodeExecution( execution_id=f"exec-{language}-123", @@ -81,7 +100,7 @@ def create_mock_execution(language: str, stdout: str = "output", stderr: str = " status=ExecutionStatus.COMPLETED, exit_code=0, execution_time_ms=100, - outputs=outputs + outputs=outputs, ) @@ -90,7 +109,13 @@ def mock_execution_service(): """Mock execution service.""" service = AsyncMock() # Return tuple: (execution, container, new_state, state_errors, container_source) - service.execute_code.return_value = (create_mock_execution("py", "Hello, World!"), None, None, [], "pool_hit") + service.execute_code.return_value = ( + create_mock_execution("py", "Hello, World!"), + None, + None, + [], + "pool_hit", + ) return service @@ -106,7 +131,7 @@ def mock_file_service(): size=1024, content_type="text/plain", created_at=datetime.utcnow(), - path="/test.txt" + path="/test.txt", ) service.download_file.return_value = "https://minio.example.com/download-url" service.delete_file.return_value = True @@ -116,7 +141,11 @@ def mock_file_service(): @pytest.fixture(autouse=True) def mock_dependencies(mock_session_service, mock_execution_service, mock_file_service): """Mock all dependencies for testing.""" - from src.dependencies.services import get_session_service, get_execution_service, get_file_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service @@ -131,15 +160,13 @@ def mock_dependencies(mock_session_service, mock_execution_service, mock_file_se # EXEC ENDPOINT - REQUEST FORMAT # ============================================================================= + class TestExecRequestFormat: """Test /exec request format validation.""" def test_exec_minimal_request(self, client, auth_headers): """Test minimal valid request with just code and lang.""" - request_data = { - "code": "print('hello')", - "lang": "py" - } + request_data = {"code": "print('hello')", "lang": "py"} response = client.post("/exec", json=request_data, headers=auth_headers) assert response.status_code == 200 @@ -152,24 +179,22 @@ def test_exec_full_request(self, client, auth_headers, mock_session_service): "args": "arg1 arg2", "user_id": "user-123", "entity_id": "entity-456", - "files": [] + "files": [], } response = client.post("/exec", json=request_data, headers=auth_headers) assert response.status_code == 200 - def test_exec_with_file_references(self, client, auth_headers, mock_execution_service): + def test_exec_with_file_references( + self, client, auth_headers, mock_execution_service + ): """Test request with file references.""" request_data = { "code": "with open('data.txt') as f: print(f.read())", "lang": "py", "files": [ - { - "id": "file-123", - "session_id": "session-456", - "name": "data.txt" - } - ] + {"id": "file-123", "session_id": "session-456", "name": "data.txt"} + ], } response = client.post("/exec", json=request_data, headers=auth_headers) @@ -178,21 +203,27 @@ def test_exec_with_file_references(self, client, auth_headers, mock_execution_se def test_exec_args_accepts_any_json(self, client, auth_headers): """Test that args field accepts any JSON type (string, object, array).""" # String args - response = client.post("/exec", json={ - "code": "print('test')", "lang": "py", "args": "string args" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "print('test')", "lang": "py", "args": "string args"}, + headers=auth_headers, + ) assert response.status_code == 200 # Object args - response = client.post("/exec", json={ - "code": "print('test')", "lang": "py", "args": {"key": "value"} - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "print('test')", "lang": "py", "args": {"key": "value"}}, + headers=auth_headers, + ) assert response.status_code == 200 # Array args - response = client.post("/exec", json={ - "code": "print('test')", "lang": "py", "args": ["arg1", "arg2"] - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "print('test')", "lang": "py", "args": ["arg1", "arg2"]}, + headers=auth_headers, + ) assert response.status_code == 200 def test_exec_missing_code_rejected(self, client, auth_headers): @@ -202,12 +233,16 @@ def test_exec_missing_code_rejected(self, client, auth_headers): def test_exec_missing_lang_rejected(self, client, auth_headers): """Test that missing lang field is rejected.""" - response = client.post("/exec", json={"code": "print('test')"}, headers=auth_headers) + response = client.post( + "/exec", json={"code": "print('test')"}, headers=auth_headers + ) assert response.status_code == 422 def test_exec_empty_code_rejected(self, client, auth_headers): """Test that empty code is rejected.""" - response = client.post("/exec", json={"code": "", "lang": "py"}, headers=auth_headers) + response = client.post( + "/exec", json={"code": "", "lang": "py"}, headers=auth_headers + ) # API returns 400 for empty code (application-level validation) assert response.status_code == 400 @@ -217,9 +252,9 @@ class TestExecResponseFormat: def test_response_has_required_fields(self, client, auth_headers): """Test that response has all required LibreChat fields.""" - response = client.post("/exec", json={ - "code": "print('test')", "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", json={"code": "print('test')", "lang": "py"}, headers=auth_headers + ) assert response.status_code == 200 data = response.json() @@ -236,22 +271,31 @@ def test_response_has_required_fields(self, client, auth_headers): assert isinstance(data["stdout"], str) assert isinstance(data["stderr"], str) - def test_response_stdout_ends_with_newline(self, client, auth_headers, mock_execution_service): + def test_response_stdout_ends_with_newline( + self, client, auth_headers, mock_execution_service + ): """Test that stdout ends with newline for LibreChat compatibility.""" mock_execution_service.execute_code.return_value = ( - create_mock_execution("py", "Hello, World!"), # No trailing newline in mock - None, None, [], "pool_hit" + create_mock_execution("py", "Hello, World!"), # No trailing newline in mock + None, + None, + [], + "pool_hit", ) - response = client.post("/exec", json={ - "code": "print('Hello, World!')", "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "print('Hello, World!')", "lang": "py"}, + headers=auth_headers, + ) data = response.json() # LibreChat expects stdout to end with newline assert data["stdout"].endswith("\n") - def test_response_files_format(self, client, auth_headers, mock_execution_service, mock_file_service): + def test_response_files_format( + self, client, auth_headers, mock_execution_service, mock_file_service + ): """Test that generated files have correct format.""" # Mock execution with file output execution_with_file = CodeExecution( @@ -267,11 +311,17 @@ def test_response_files_format(self, client, auth_headers, mock_execution_servic content="/workspace/output.txt", mime_type="text/plain", size=100, - timestamp=datetime.now(timezone.utc) + timestamp=datetime.now(timezone.utc), ) - ] + ], + ) + mock_execution_service.execute_code.return_value = ( + execution_with_file, + None, + None, + [], + "pool_hit", ) - mock_execution_service.execute_code.return_value = (execution_with_file, None, None, [], "pool_hit") # Mock store_execution_output_file to return a file_id string mock_file_service.store_execution_output_file.return_value = "gen-file-123" @@ -284,13 +334,13 @@ def test_response_files_format(self, client, auth_headers, mock_execution_servic size=100, content_type="text/plain", created_at=datetime.utcnow(), - path="/output.txt" + path="/output.txt", ) ] - response = client.post("/exec", json={ - "code": "write file", "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", json={"code": "write file", "lang": "py"}, headers=auth_headers + ) data = response.json() assert len(data["files"]) >= 1 @@ -306,6 +356,7 @@ def test_response_files_format(self, client, auth_headers, mock_execution_servic # FILE ENDPOINTS # ============================================================================= + class TestFileUploadContract: """Test file upload endpoint contract.""" @@ -334,7 +385,7 @@ def test_upload_multiple_files(self, client, auth_headers, mock_file_service): files = [ ("files", ("test1.txt", io.BytesIO(b"content 1"), "text/plain")), - ("files", ("test2.txt", io.BytesIO(b"content 2"), "text/plain")) + ("files", ("test2.txt", io.BytesIO(b"content 2"), "text/plain")), ] response = client.post("/upload", files=files, headers=auth_headers) @@ -397,14 +448,15 @@ def test_list_files_full_detail(self, client, auth_headers, mock_file_service): class TestFileDownloadContract: """Test file download endpoint contract.""" - def test_download_returns_streaming_response(self, client, auth_headers, mock_file_service): + def test_download_returns_streaming_response( + self, client, auth_headers, mock_file_service + ): """Test that download returns streaming response with file content.""" # Mock the file service to return file content mock_file_service.get_file_content.return_value = b"test file content" response = client.get( - "/download/test-session/test-file-id-123", - headers=auth_headers + "/download/test-session/test-file-id-123", headers=auth_headers ) # API returns streaming response (200), not redirect @@ -416,8 +468,7 @@ def test_download_not_found(self, client, auth_headers, mock_file_service): mock_file_service.get_file_info.return_value = None response = client.get( - "/download/test-session/nonexistent", - headers=auth_headers + "/download/test-session/nonexistent", headers=auth_headers ) assert response.status_code == 404 @@ -429,8 +480,7 @@ class TestFileDeleteContract: def test_delete_success(self, client, auth_headers, mock_file_service): """Test successful file deletion.""" response = client.delete( - "/files/test-session/test-file-id-123", - headers=auth_headers + "/files/test-session/test-file-id-123", headers=auth_headers ) # API returns 200 with empty body for LibreChat compatibility @@ -441,8 +491,7 @@ def test_delete_not_found(self, client, auth_headers, mock_file_service): mock_file_service.get_file_info.return_value = None response = client.delete( - "/files/test-session/nonexistent", - headers=auth_headers + "/files/test-session/nonexistent", headers=auth_headers ) assert response.status_code == 404 @@ -452,6 +501,7 @@ def test_delete_not_found(self, client, auth_headers, mock_file_service): # HEALTH ENDPOINTS # ============================================================================= + class TestHealthContract: """Test health endpoint contracts.""" @@ -480,6 +530,7 @@ def test_health_services(self, client, auth_headers): # ERROR RESPONSE FORMAT # ============================================================================= + class TestErrorResponseFormat: """Test error response format consistency.""" @@ -508,8 +559,7 @@ def test_not_found_error_format(self, client, auth_headers, mock_file_service): mock_file_service.get_file_info.return_value = None response = client.get( - "/download/test-session/nonexistent", - headers=auth_headers + "/download/test-session/nonexistent", headers=auth_headers ) assert response.status_code == 404 @@ -523,49 +573,48 @@ def test_not_found_error_format(self, client, auth_headers, mock_file_service): # AUTHENTICATION METHODS # ============================================================================= + class TestAuthenticationMethods: """Test all authentication methods work.""" def test_x_api_key_header(self, client): """Test x-api-key header authentication.""" headers = {"x-api-key": "test-api-key-for-testing-12345"} - response = client.post("/exec", json={ - "code": "print('test')", "lang": "py" - }, headers=headers) + response = client.post( + "/exec", json={"code": "print('test')", "lang": "py"}, headers=headers + ) assert response.status_code != 401 def test_authorization_bearer(self, client): """Test Authorization Bearer authentication.""" headers = {"Authorization": "Bearer test-api-key-for-testing-12345"} - response = client.post("/exec", json={ - "code": "print('test')", "lang": "py" - }, headers=headers) + response = client.post( + "/exec", json={"code": "print('test')", "lang": "py"}, headers=headers + ) assert response.status_code != 401 def test_authorization_apikey(self, client): """Test Authorization ApiKey authentication.""" headers = {"Authorization": "ApiKey test-api-key-for-testing-12345"} - response = client.post("/exec", json={ - "code": "print('test')", "lang": "py" - }, headers=headers) + response = client.post( + "/exec", json={"code": "print('test')", "lang": "py"}, headers=headers + ) assert response.status_code != 401 def test_no_auth_rejected(self, client): """Test requests without auth are rejected.""" - response = client.post("/exec", json={ - "code": "print('test')", "lang": "py" - }) + response = client.post("/exec", json={"code": "print('test')", "lang": "py"}) assert response.status_code == 401 def test_invalid_auth_rejected(self, client): """Test requests with invalid auth are rejected.""" headers = {"x-api-key": "invalid-key"} - response = client.post("/exec", json={ - "code": "print('test')", "lang": "py" - }, headers=headers) + response = client.post( + "/exec", json={"code": "print('test')", "lang": "py"}, headers=headers + ) assert response.status_code == 401 diff --git a/tests/integration/test_auth_integration.py b/tests/integration/test_auth_integration.py index 4c7e7f5..f445c84 100644 --- a/tests/integration/test_auth_integration.py +++ b/tests/integration/test_auth_integration.py @@ -17,93 +17,103 @@ def client(): @pytest.fixture def mock_services(): """Mock all services for testing.""" - from src.dependencies.services import get_session_service, get_execution_service, get_file_service - + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + mock_session_service = AsyncMock() mock_execution_service = AsyncMock() mock_file_service = AsyncMock() - + # Override the dependencies in the FastAPI app app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service - + yield { "session": mock_session_service, "execution": mock_execution_service, - "file": mock_file_service + "file": mock_file_service, } - + # Clean up after test app.dependency_overrides.clear() class TestAPIKeyAuthentication: """Test API key authentication workflows.""" - + def test_valid_api_key_x_api_key_header(self, client, mock_services): """Test authentication with valid API key in x-api-key header.""" headers = {"x-api-key": "test-api-key-for-testing-12345"} - - with patch('src.services.auth.settings') as mock_settings: + + with patch("src.services.auth.settings") as mock_settings: mock_settings.api_key = "test-api-key-for-testing-12345" - + response = client.get("/sessions", headers=headers) - + # Should not fail with authentication error assert response.status_code != 401 - + def test_valid_api_key_authorization_bearer(self, client, mock_services): """Test authentication with valid API key in Authorization Bearer header.""" headers = {"Authorization": "Bearer test-api-key-for-testing-12345"} - - with patch('src.services.auth.settings') as mock_settings: + + with patch("src.services.auth.settings") as mock_settings: mock_settings.api_key = "test-api-key-for-testing-12345" - + response = client.get("/sessions", headers=headers) - + # Should not fail with authentication error assert response.status_code != 401 - + def test_valid_api_key_authorization_apikey(self, client, mock_services): """Test authentication with valid API key in Authorization ApiKey header.""" headers = {"Authorization": "ApiKey test-api-key-for-testing-12345"} - - with patch('src.services.auth.settings') as mock_settings: + + with patch("src.services.auth.settings") as mock_settings: mock_settings.api_key = "test-api-key-for-testing-12345" - + response = client.get("/sessions", headers=headers) - + # Should not fail with authentication error assert response.status_code != 401 - + def test_invalid_api_key(self, client, mock_services): """Test authentication with invalid API key.""" headers = {"x-api-key": "invalid-key"} - - with patch('src.services.auth.settings') as mock_settings: + + with patch("src.services.auth.settings") as mock_settings: mock_settings.api_key = "test-api-key-for-testing-12345" - + response = client.get("/sessions", headers=headers) - + assert response.status_code == 401 # Update assertion to match actual error message - assert "Invalid API key" in response.json()["error"] or "Invalid or missing API key" in response.json()["error"] - + assert ( + "Invalid API key" in response.json()["error"] + or "Invalid or missing API key" in response.json()["error"] + ) + def test_missing_api_key(self, client, mock_services): """Test authentication without API key.""" response = client.get("/sessions") - + assert response.status_code == 401 # Update assertion to match actual error message - assert "API key is required" in response.json()["error"] or "Invalid or missing API key" in response.json()["error"] - + assert ( + "API key is required" in response.json()["error"] + or "Invalid or missing API key" in response.json()["error"] + ) + def test_empty_api_key(self, client, mock_services): """Test authentication with empty API key.""" headers = {"x-api-key": ""} - + response = client.get("/sessions", headers=headers) - + assert response.status_code == 401 # ... (skipping some methods) ... @@ -111,77 +121,79 @@ def test_empty_api_key(self, client, mock_services): def test_root_endpoint_no_auth_required(self, client): """Test root endpoint auth requirements.""" response = client.get("/") - + # Root endpoint is not excluded from auth, so it should return 401 assert response.status_code == 401 # ... - @patch('src.services.auth.settings') + @patch("src.services.auth.settings") def test_complete_exec_flow_with_auth(self, mock_settings, client, mock_services): """Test complete execution flow with authentication.""" mock_settings.api_key = "test-api-key-for-testing-12345" headers = {"x-api-key": "test-api-key-for-testing-12345"} - + # Mock successful execution from src.models import CodeExecution, ExecutionStatus from datetime import datetime, timezone - + mock_execution = CodeExecution( execution_id="test-exec", session_id="test-session", code="print('Hello')", language="py", status=ExecutionStatus.COMPLETED, - exit_code=0 + exit_code=0, + ) + mock_services["execution"].execute_code.return_value = ( + mock_execution, + None, + None, + [], + "pool_hit", ) - mock_services["execution"].execute_code.return_value = (mock_execution, None, None, [], "pool_hit") - + # Mock session creation from src.models.session import Session, SessionStatus + mock_session = Session( session_id="test-session", status=SessionStatus.ACTIVE, created_at=datetime.now(timezone.utc), last_activity=datetime.now(timezone.utc), expires_at=datetime.now(timezone.utc), - metadata={} + metadata={}, ) mock_services["session"].create_session.return_value = mock_session - + # Execute code - request_data = { - "code": "print('Hello, World!')", - "lang": "py" - } - + request_data = {"code": "print('Hello, World!')", "lang": "py"} + response = client.post("/exec", json=request_data, headers=headers) - + assert response.status_code == 200 assert "session_id" in response.json() - + def test_exec_flow_without_auth(self, client, mock_services): """Test execution flow without authentication.""" - request_data = { - "code": "print('Hello, World!')", - "lang": "py" - } - + request_data = {"code": "print('Hello, World!')", "lang": "py"} + response = client.post("/exec", json=request_data) - + assert response.status_code == 401 - - @patch('src.services.auth.settings') + + @patch("src.services.auth.settings") def test_file_upload_flow_with_auth(self, mock_settings, client, mock_services): """Test file upload flow with authentication.""" mock_settings.api_key = "test-api-key-for-testing-12345" headers = {"x-api-key": "test-api-key-for-testing-12345"} - + # Mock file upload mock_services["file"].store_uploaded_file.return_value = "file-123" # Mock get_file_info needed for upload response from src.models.files import FileInfo from datetime import datetime, timezone + mock_services["file"].get_file_info.return_value = FileInfo( file_id="file-123", filename="test.txt", @@ -189,138 +201,139 @@ def test_file_upload_flow_with_auth(self, mock_settings, client, mock_services): size=12, created_at=datetime.now(timezone.utc), modified_at=datetime.now(timezone.utc), - content_type="text/plain" + content_type="text/plain", ) - + import io + files = {"files": ("test.txt", io.BytesIO(b"test content"), "text/plain")} - + # Use /upload instead of /files/upload as per src/main.py response = client.post("/upload", files=files, headers=headers) - + assert response.status_code == 200 assert "files" in response.json() - + def test_file_upload_flow_without_auth(self, client, mock_services): """Test file upload flow without authentication.""" import io + files = {"files": ("test.txt", io.BytesIO(b"test content"), "text/plain")} - + response = client.post("/upload", files=files) - + assert response.status_code == 401 class TestAuthenticationEdgeCases: """Test edge cases in authentication.""" - + def test_auth_with_special_characters_in_key(self, client, mock_services): """Test authentication with special characters in API key.""" special_key = "test-key-with-special-chars!@#$%^&*()" - - with patch('src.services.auth.settings') as mock_settings: + + with patch("src.services.auth.settings") as mock_settings: mock_settings.api_key = special_key headers = {"x-api-key": special_key} - + response = client.get("/sessions", headers=headers) - + # Should handle special characters correctly # If 401, it means auth failed, but we want to ensure no 500 error assert response.status_code in [200, 401] - - + def test_auth_with_very_long_key(self, client, mock_services): """Test authentication with very long API key.""" long_key = "a" * 1000 # 1000 character key - - with patch('src.services.auth.settings') as mock_settings: + + with patch("src.services.auth.settings") as mock_settings: mock_settings.api_key = long_key headers = {"x-api-key": long_key} - + response = client.get("/sessions", headers=headers) - + # Should handle long keys (within reason) assert response.status_code in [200, 401] - + def test_auth_with_whitespace_in_key(self, client, mock_services): """Test authentication with whitespace in API key.""" # Test leading/trailing whitespace key_with_whitespace = " test-api-key-for-testing-12345 " clean_key = "test-api-key-for-testing-12345" - - with patch('src.services.auth.settings') as mock_settings: + + with patch("src.services.auth.settings") as mock_settings: mock_settings.api_key = clean_key headers = {"x-api-key": key_with_whitespace} - + response = client.get("/sessions", headers=headers) - + # Should either trim whitespace or reject assert response.status_code in [401, 200] # Depends on implementation - + def test_multiple_auth_headers(self, client, mock_services): """Test request with multiple authentication headers.""" - with patch('src.services.auth.settings') as mock_settings: + with patch("src.services.auth.settings") as mock_settings: mock_settings.api_key = "test-api-key-for-testing-12345" - + headers = { "x-api-key": "test-api-key-for-testing-12345", - "Authorization": "Bearer different-key" + "Authorization": "Bearer different-key", } - + response = client.get("/sessions", headers=headers) - + # Should use one of the headers (typically x-api-key takes precedence) assert response.status_code != 401 - + def test_auth_header_injection_attempt(self, client, mock_services): """Test authentication with header injection attempt.""" malicious_key = "test-key\r\nX-Injected-Header: malicious" - + headers = {"x-api-key": malicious_key} - + response = client.get("/sessions", headers=headers) - + # Should reject malicious header assert response.status_code == 401 - + # Verify no injected headers in response assert "X-Injected-Header" not in response.headers class TestAuthenticationPerformance: """Test authentication performance characteristics.""" - - @patch('src.services.auth.settings') + + @patch("src.services.auth.settings") def test_auth_response_time(self, mock_settings, client, mock_services): """Test that authentication doesn't add excessive latency.""" mock_settings.api_key = "test-api-key-for-testing-12345" headers = {"x-api-key": "test-api-key-for-testing-12345"} - + start_time = time.time() response = client.get("/sessions", headers=headers) end_time = time.time() - + # Authentication should be fast (< 1 second for this simple test) auth_time = end_time - start_time assert auth_time < 1.0 - + # Should not fail with auth error assert response.status_code != 401 - + def test_concurrent_auth_requests(self, client, mock_services): """Test handling of concurrent authentication requests.""" # This would require actual concurrency testing # For now, just verify that multiple sequential requests work - - with patch('src.services.auth.settings') as mock_settings: + + with patch("src.services.auth.settings") as mock_settings: mock_settings.api_key = "test-api-key-for-testing-12345" headers = {"x-api-key": "test-api-key-for-testing-12345"} - + responses = [] for i in range(10): response = client.get("/sessions", headers=headers) responses.append(response) - + # All should have consistent auth results auth_results = [r.status_code != 401 for r in responses] - assert all(auth_results) # All should pass auth \ No newline at end of file + assert all(auth_results) # All should pass auth diff --git a/tests/integration/test_container_behavior.py b/tests/integration/test_container_behavior.py index 81484a8..935cc21 100644 --- a/tests/integration/test_container_behavior.py +++ b/tests/integration/test_container_behavior.py @@ -36,7 +36,7 @@ def create_session(session_id: str) -> Session: created_at=datetime.now(timezone.utc), last_activity=datetime.now(timezone.utc), expires_at=datetime.now(timezone.utc) + timedelta(hours=24), - metadata={} + metadata={}, ) @@ -44,6 +44,7 @@ def create_session(session_id: str) -> Session: # CONTAINER LIFECYCLE BEHAVIOR # ============================================================================= + class TestContainerLifecycle: """Test container lifecycle behavior.""" @@ -63,9 +64,9 @@ def test_container_created_for_execution(self, client, auth_headers): ExecutionOutput( type=OutputType.STDOUT, content="test", - timestamp=datetime.now(timezone.utc) + timestamp=datetime.now(timezone.utc), ) - ] + ], ) mock_session_service = AsyncMock() @@ -73,21 +74,33 @@ def test_container_created_for_execution(self, client, auth_headers): mock_session_service.get_session.return_value = mock_session mock_execution_service = AsyncMock() - mock_execution_service.execute_code.return_value = (mock_execution, None, None, [], "pool_hit") + mock_execution_service.execute_code.return_value = ( + mock_execution, + None, + None, + [], + "pool_hit", + ) mock_file_service = AsyncMock() mock_file_service.list_files.return_value = [] - from src.dependencies.services import get_session_service, get_execution_service, get_file_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service try: - response = client.post("/exec", json={ - "code": "print('test')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "print('test')", "lang": "py"}, + headers=auth_headers, + ) assert response.status_code == 200 @@ -108,7 +121,7 @@ def test_container_cleaned_up_after_execution(self, client, auth_headers): language="py", status=ExecutionStatus.COMPLETED, exit_code=0, - outputs=[] + outputs=[], ) mock_session_service = AsyncMock() @@ -116,22 +129,34 @@ def test_container_cleaned_up_after_execution(self, client, auth_headers): mock_session_service.get_session.return_value = mock_session mock_execution_service = AsyncMock() - mock_execution_service.execute_code.return_value = (mock_execution, None, None, [], "pool_hit") + mock_execution_service.execute_code.return_value = ( + mock_execution, + None, + None, + [], + "pool_hit", + ) mock_execution_service.cleanup_session = AsyncMock() mock_file_service = AsyncMock() mock_file_service.list_files.return_value = [] - from src.dependencies.services import get_session_service, get_execution_service, get_file_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service try: - response = client.post("/exec", json={ - "code": "print('done')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "print('done')", "lang": "py"}, + headers=auth_headers, + ) assert response.status_code == 200 @@ -145,6 +170,7 @@ def test_container_cleaned_up_after_execution(self, client, auth_headers): # LANGUAGE-SPECIFIC EXECUTION BEHAVIOR # ============================================================================= + class TestLanguageExecution: """Test language-specific execution patterns.""" @@ -167,9 +193,18 @@ def setup_mocks(self): self.mock_file_service = AsyncMock() self.mock_file_service.list_files.return_value = [] - from src.dependencies.services import get_session_service, get_execution_service, get_file_service - app.dependency_overrides[get_session_service] = lambda: self.mock_session_service - app.dependency_overrides[get_execution_service] = lambda: self.mock_execution_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + + app.dependency_overrides[get_session_service] = ( + lambda: self.mock_session_service + ) + app.dependency_overrides[get_execution_service] = ( + lambda: self.mock_execution_service + ) app.dependency_overrides[get_file_service] = lambda: self.mock_file_service yield @@ -179,54 +214,67 @@ def setup_mocks(self): @pytest.mark.parametrize("language", STDIN_LANGUAGES) def test_stdin_language_execution(self, client, auth_headers, language): """Test stdin-based language execution (interpreted languages).""" - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id=f"exec-{language}", - session_id="lang-test-session", - code=f"{language} code", - language=language, - status=ExecutionStatus.COMPLETED, - exit_code=0, - outputs=[ - ExecutionOutput( - type=OutputType.STDOUT, - content=f"Hello {language}", - timestamp=datetime.now(timezone.utc) - ) - ] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id=f"exec-{language}", + session_id="lang-test-session", + code=f"{language} code", + language=language, + status=ExecutionStatus.COMPLETED, + exit_code=0, + outputs=[ + ExecutionOutput( + type=OutputType.STDOUT, + content=f"Hello {language}", + timestamp=datetime.now(timezone.utc), + ) + ], + ), + None, + None, + [], + "pool_hit", + ) code_samples = { "py": "print('Hello py')", "js": "console.log('Hello js')", "php": "", - "r": "print('Hello r')" + "r": "print('Hello r')", } - response = client.post("/exec", json={ - "code": code_samples.get(language, ""), - "lang": language - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": code_samples.get(language, ""), "lang": language}, + headers=auth_headers, + ) assert response.status_code == 200 @pytest.mark.parametrize("language", FILE_LANGUAGES) def test_file_language_execution(self, client, auth_headers, language): """Test file-based language execution (compiled languages).""" - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id=f"exec-{language}", - session_id="lang-test-session", - code=f"{language} code", - language=language, - status=ExecutionStatus.COMPLETED, - exit_code=0, - outputs=[ - ExecutionOutput( - type=OutputType.STDOUT, - content=f"Hello {language}", - timestamp=datetime.now(timezone.utc) - ) - ] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id=f"exec-{language}", + session_id="lang-test-session", + code=f"{language} code", + language=language, + status=ExecutionStatus.COMPLETED, + exit_code=0, + outputs=[ + ExecutionOutput( + type=OutputType.STDOUT, + content=f"Hello {language}", + timestamp=datetime.now(timezone.utc), + ) + ], + ), + None, + None, + [], + "pool_hit", + ) code_samples = { "go": 'package main\nimport "fmt"\nfunc main() { fmt.Println("Hello go") }', @@ -236,13 +284,14 @@ def test_file_language_execution(self, client, auth_headers, language): "rs": 'fn main() { println!("Hello rs"); }', "f90": 'program hello\n print *, "Hello f90"\nend program hello', "d": 'import std.stdio; void main() { writeln("Hello d"); }', - "ts": 'console.log("Hello ts");' + "ts": 'console.log("Hello ts");', } - response = client.post("/exec", json={ - "code": code_samples.get(language, ""), - "lang": language - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": code_samples.get(language, ""), "lang": language}, + headers=auth_headers, + ) assert response.status_code == 200 @@ -251,6 +300,7 @@ def test_file_language_execution(self, client, auth_headers, language): # EXECUTION STATUS BEHAVIOR # ============================================================================= + class TestExecutionStatus: """Test execution status handling.""" @@ -267,9 +317,18 @@ def setup_mocks(self): self.mock_file_service = AsyncMock() self.mock_file_service.list_files.return_value = [] - from src.dependencies.services import get_session_service, get_execution_service, get_file_service - app.dependency_overrides[get_session_service] = lambda: self.mock_session_service - app.dependency_overrides[get_execution_service] = lambda: self.mock_execution_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + + app.dependency_overrides[get_session_service] = ( + lambda: self.mock_session_service + ) + app.dependency_overrides[get_execution_service] = ( + lambda: self.mock_execution_service + ) app.dependency_overrides[get_file_service] = lambda: self.mock_file_service yield @@ -278,26 +337,31 @@ def setup_mocks(self): def test_completed_status(self, client, auth_headers): """Test successful execution status.""" - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id="exec-completed", - session_id="status-test-session", - code="print('ok')", - language="py", - status=ExecutionStatus.COMPLETED, - exit_code=0, - outputs=[ - ExecutionOutput( - type=OutputType.STDOUT, - content="ok", - timestamp=datetime.now(timezone.utc) - ) - ] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id="exec-completed", + session_id="status-test-session", + code="print('ok')", + language="py", + status=ExecutionStatus.COMPLETED, + exit_code=0, + outputs=[ + ExecutionOutput( + type=OutputType.STDOUT, + content="ok", + timestamp=datetime.now(timezone.utc), + ) + ], + ), + None, + None, + [], + "pool_hit", + ) - response = client.post("/exec", json={ - "code": "print('ok')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", json={"code": "print('ok')", "lang": "py"}, headers=auth_headers + ) assert response.status_code == 200 data = response.json() @@ -305,66 +369,87 @@ def test_completed_status(self, client, auth_headers): def test_failed_status(self, client, auth_headers): """Test failed execution status.""" - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id="exec-failed", - session_id="status-test-session", - code="raise Exception('fail')", - language="py", - status=ExecutionStatus.FAILED, - exit_code=1, - error_message="Exception: fail", - outputs=[ - ExecutionOutput( - type=OutputType.STDERR, - content="Exception: fail", - timestamp=datetime.now(timezone.utc) - ) - ] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id="exec-failed", + session_id="status-test-session", + code="raise Exception('fail')", + language="py", + status=ExecutionStatus.FAILED, + exit_code=1, + error_message="Exception: fail", + outputs=[ + ExecutionOutput( + type=OutputType.STDERR, + content="Exception: fail", + timestamp=datetime.now(timezone.utc), + ) + ], + ), + None, + None, + [], + "pool_hit", + ) - response = client.post("/exec", json={ - "code": "raise Exception('fail')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "raise Exception('fail')", "lang": "py"}, + headers=auth_headers, + ) # Still returns 200 with error in output assert response.status_code == 200 def test_timeout_status(self, client, auth_headers): """Test timeout execution status.""" - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id="exec-timeout", - session_id="status-test-session", - code="import time; time.sleep(999)", - language="py", - status=ExecutionStatus.TIMEOUT, - error_message="Execution timed out after 30 seconds", - outputs=[] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id="exec-timeout", + session_id="status-test-session", + code="import time; time.sleep(999)", + language="py", + status=ExecutionStatus.TIMEOUT, + error_message="Execution timed out after 30 seconds", + outputs=[], + ), + None, + None, + [], + "pool_hit", + ) - response = client.post("/exec", json={ - "code": "import time; time.sleep(999)", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "import time; time.sleep(999)", "lang": "py"}, + headers=auth_headers, + ) # Still returns 200 with timeout info assert response.status_code == 200 def test_cancelled_status(self, client, auth_headers): """Test cancelled execution status.""" - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id="exec-cancelled", - session_id="status-test-session", - code="cancelled code", - language="py", - status=ExecutionStatus.CANCELLED, - outputs=[] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id="exec-cancelled", + session_id="status-test-session", + code="cancelled code", + language="py", + status=ExecutionStatus.CANCELLED, + outputs=[], + ), + None, + None, + [], + "pool_hit", + ) - response = client.post("/exec", json={ - "code": "long running code", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "long running code", "lang": "py"}, + headers=auth_headers, + ) assert response.status_code == 200 @@ -373,6 +458,7 @@ def test_cancelled_status(self, client, auth_headers): # FILE GENERATION BEHAVIOR # ============================================================================= + class TestFileGeneration: """Test file generation during execution.""" @@ -388,9 +474,18 @@ def setup_mocks(self): self.mock_execution_service = AsyncMock() self.mock_file_service = AsyncMock() - from src.dependencies.services import get_session_service, get_execution_service, get_file_service - app.dependency_overrides[get_session_service] = lambda: self.mock_session_service - app.dependency_overrides[get_execution_service] = lambda: self.mock_execution_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + + app.dependency_overrides[get_session_service] = ( + lambda: self.mock_session_service + ) + app.dependency_overrides[get_execution_service] = ( + lambda: self.mock_execution_service + ) app.dependency_overrides[get_file_service] = lambda: self.mock_file_service yield @@ -399,23 +494,29 @@ def setup_mocks(self): def test_generated_file_detected(self, client, auth_headers): """Test that files generated during execution are detected.""" - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id="exec-genfile", - session_id="filegen-test-session", - code="write file", - language="py", - status=ExecutionStatus.COMPLETED, - exit_code=0, - outputs=[ - ExecutionOutput( - type=OutputType.FILE, - content="/mnt/data/output.txt", - mime_type="text/plain", - size=100, - timestamp=datetime.now(timezone.utc) - ) - ] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id="exec-genfile", + session_id="filegen-test-session", + code="write file", + language="py", + status=ExecutionStatus.COMPLETED, + exit_code=0, + outputs=[ + ExecutionOutput( + type=OutputType.FILE, + content="/mnt/data/output.txt", + mime_type="text/plain", + size=100, + timestamp=datetime.now(timezone.utc), + ) + ], + ), + None, + None, + [], + "pool_hit", + ) # Mock store_execution_output_file to return a file_id string self.mock_file_service.store_execution_output_file.return_value = "gen-file-1" @@ -427,14 +528,18 @@ def test_generated_file_detected(self, client, auth_headers): size=100, content_type="text/plain", created_at=datetime.utcnow(), - path="/output.txt" + path="/output.txt", ) ] - response = client.post("/exec", json={ - "code": "with open('output.txt', 'w') as f: f.write('hello')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={ + "code": "with open('output.txt', 'w') as f: f.write('hello')", + "lang": "py", + }, + headers=auth_headers, + ) assert response.status_code == 200 data = response.json() @@ -442,33 +547,42 @@ def test_generated_file_detected(self, client, auth_headers): def test_multiple_files_generated(self, client, auth_headers): """Test that multiple generated files are detected.""" - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id="exec-multifile", - session_id="filegen-test-session", - code="write files", - language="py", - status=ExecutionStatus.COMPLETED, - exit_code=0, - outputs=[ - ExecutionOutput( - type=OutputType.FILE, - content="/mnt/data/file1.txt", - mime_type="text/plain", - size=50, - timestamp=datetime.now(timezone.utc) - ), - ExecutionOutput( - type=OutputType.FILE, - content="/mnt/data/file2.csv", - mime_type="text/csv", - size=100, - timestamp=datetime.now(timezone.utc) - ) - ] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id="exec-multifile", + session_id="filegen-test-session", + code="write files", + language="py", + status=ExecutionStatus.COMPLETED, + exit_code=0, + outputs=[ + ExecutionOutput( + type=OutputType.FILE, + content="/mnt/data/file1.txt", + mime_type="text/plain", + size=50, + timestamp=datetime.now(timezone.utc), + ), + ExecutionOutput( + type=OutputType.FILE, + content="/mnt/data/file2.csv", + mime_type="text/csv", + size=100, + timestamp=datetime.now(timezone.utc), + ), + ], + ), + None, + None, + [], + "pool_hit", + ) # Mock store_execution_output_file to return file IDs (called multiple times) - self.mock_file_service.store_execution_output_file.side_effect = ["gen-1", "gen-2"] + self.mock_file_service.store_execution_output_file.side_effect = [ + "gen-1", + "gen-2", + ] self.mock_file_service.list_files.return_value = [ FileInfo( @@ -477,7 +591,7 @@ def test_multiple_files_generated(self, client, auth_headers): size=50, content_type="text/plain", created_at=datetime.utcnow(), - path="/file1.txt" + path="/file1.txt", ), FileInfo( file_id="gen-2", @@ -485,14 +599,15 @@ def test_multiple_files_generated(self, client, auth_headers): size=100, content_type="text/csv", created_at=datetime.utcnow(), - path="/file2.csv" - ) + path="/file2.csv", + ), ] - response = client.post("/exec", json={ - "code": "generate multiple files", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "generate multiple files", "lang": "py"}, + headers=auth_headers, + ) assert response.status_code == 200 data = response.json() @@ -500,28 +615,33 @@ def test_multiple_files_generated(self, client, auth_headers): def test_no_files_generated(self, client, auth_headers): """Test execution with no file generation.""" - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id="exec-nofile", - session_id="filegen-test-session", - code="print only", - language="py", - status=ExecutionStatus.COMPLETED, - exit_code=0, - outputs=[ - ExecutionOutput( - type=OutputType.STDOUT, - content="output", - timestamp=datetime.now(timezone.utc) - ) - ] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id="exec-nofile", + session_id="filegen-test-session", + code="print only", + language="py", + status=ExecutionStatus.COMPLETED, + exit_code=0, + outputs=[ + ExecutionOutput( + type=OutputType.STDOUT, + content="output", + timestamp=datetime.now(timezone.utc), + ) + ], + ), + None, + None, + [], + "pool_hit", + ) self.mock_file_service.list_files.return_value = [] - response = client.post("/exec", json={ - "code": "print('hello')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", json={"code": "print('hello')", "lang": "py"}, headers=auth_headers + ) assert response.status_code == 200 data = response.json() @@ -532,6 +652,7 @@ def test_no_files_generated(self, client, auth_headers): # OUTPUT HANDLING BEHAVIOR # ============================================================================= + class TestOutputHandling: """Test output handling behavior.""" @@ -548,9 +669,18 @@ def setup_mocks(self): self.mock_file_service = AsyncMock() self.mock_file_service.list_files.return_value = [] - from src.dependencies.services import get_session_service, get_execution_service, get_file_service - app.dependency_overrides[get_session_service] = lambda: self.mock_session_service - app.dependency_overrides[get_execution_service] = lambda: self.mock_execution_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + + app.dependency_overrides[get_session_service] = ( + lambda: self.mock_session_service + ) + app.dependency_overrides[get_execution_service] = ( + lambda: self.mock_execution_service + ) app.dependency_overrides[get_file_service] = lambda: self.mock_file_service yield @@ -561,26 +691,33 @@ def test_large_output_handling(self, client, auth_headers): """Test handling of large output.""" large_output = "A" * 100000 # 100KB - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id="exec-large", - session_id="output-test-session", - code="print large", - language="py", - status=ExecutionStatus.COMPLETED, - exit_code=0, - outputs=[ - ExecutionOutput( - type=OutputType.STDOUT, - content=large_output, - timestamp=datetime.now(timezone.utc) - ) - ] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id="exec-large", + session_id="output-test-session", + code="print large", + language="py", + status=ExecutionStatus.COMPLETED, + exit_code=0, + outputs=[ + ExecutionOutput( + type=OutputType.STDOUT, + content=large_output, + timestamp=datetime.now(timezone.utc), + ) + ], + ), + None, + None, + [], + "pool_hit", + ) - response = client.post("/exec", json={ - "code": "print('A' * 100000)", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "print('A' * 100000)", "lang": "py"}, + headers=auth_headers, + ) assert response.status_code == 200 data = response.json() @@ -588,31 +725,36 @@ def test_large_output_handling(self, client, auth_headers): def test_mixed_stdout_stderr(self, client, auth_headers): """Test handling of mixed stdout and stderr.""" - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id="exec-mixed", - session_id="output-test-session", - code="mixed output", - language="py", - status=ExecutionStatus.COMPLETED, - exit_code=0, - outputs=[ - ExecutionOutput( - type=OutputType.STDOUT, - content="stdout content", - timestamp=datetime.now(timezone.utc) - ), - ExecutionOutput( - type=OutputType.STDERR, - content="stderr content", - timestamp=datetime.now(timezone.utc) - ) - ] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id="exec-mixed", + session_id="output-test-session", + code="mixed output", + language="py", + status=ExecutionStatus.COMPLETED, + exit_code=0, + outputs=[ + ExecutionOutput( + type=OutputType.STDOUT, + content="stdout content", + timestamp=datetime.now(timezone.utc), + ), + ExecutionOutput( + type=OutputType.STDERR, + content="stderr content", + timestamp=datetime.now(timezone.utc), + ), + ], + ), + None, + None, + [], + "pool_hit", + ) - response = client.post("/exec", json={ - "code": "print and warn", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", json={"code": "print and warn", "lang": "py"}, headers=auth_headers + ) assert response.status_code == 200 data = response.json() @@ -624,26 +766,33 @@ def test_unicode_output(self, client, auth_headers): """Test handling of Unicode output.""" unicode_output = "Hello 世界 🌍 مرحبا" - self.mock_execution_service.execute_code.return_value = (CodeExecution( - execution_id="exec-unicode", - session_id="output-test-session", - code="print unicode", - language="py", - status=ExecutionStatus.COMPLETED, - exit_code=0, - outputs=[ - ExecutionOutput( - type=OutputType.STDOUT, - content=unicode_output, - timestamp=datetime.now(timezone.utc) - ) - ] - ), None, None, [], "pool_hit") + self.mock_execution_service.execute_code.return_value = ( + CodeExecution( + execution_id="exec-unicode", + session_id="output-test-session", + code="print unicode", + language="py", + status=ExecutionStatus.COMPLETED, + exit_code=0, + outputs=[ + ExecutionOutput( + type=OutputType.STDOUT, + content=unicode_output, + timestamp=datetime.now(timezone.utc), + ) + ], + ), + None, + None, + [], + "pool_hit", + ) - response = client.post("/exec", json={ - "code": "print('Hello 世界 🌍 مرحبا')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "print('Hello 世界 🌍 مرحبا')", "lang": "py"}, + headers=auth_headers, + ) assert response.status_code == 200 data = response.json() diff --git a/tests/integration/test_container_hardening.py b/tests/integration/test_container_hardening.py index 5ddbaed..d9bbc31 100644 --- a/tests/integration/test_container_hardening.py +++ b/tests/integration/test_container_hardening.py @@ -110,7 +110,10 @@ def test_hostname_is_generic(self, client, auth_headers): assert response.status_code == 200 data = response.json() # Hostname should be 'sandbox', not contain Azure or host info - assert "sandbox" in data.get("stdout", "").lower() or response.status_code == 200 + assert ( + "sandbox" in data.get("stdout", "").lower() + or response.status_code == 200 + ) finally: app.dependency_overrides.clear() diff --git a/tests/integration/test_exec_api.py b/tests/integration/test_exec_api.py index f1fd6f7..1640764 100644 --- a/tests/integration/test_exec_api.py +++ b/tests/integration/test_exec_api.py @@ -27,24 +27,24 @@ def auth_headers(): def mock_session_service(): """Mock session service for testing.""" service = AsyncMock() - + # Mock session creation from src.models.session import Session, SessionStatus from datetime import datetime, timezone, timedelta - + mock_session = Session( session_id="test-session-123", status=SessionStatus.ACTIVE, created_at=datetime.now(timezone.utc), last_activity=datetime.now(timezone.utc), expires_at=datetime.now(timezone.utc) + timedelta(hours=24), - metadata={"entity_id": "test-entity"} + metadata={"entity_id": "test-entity"}, ) - + service.create_session.return_value = mock_session service.get_session.return_value = mock_session service.validate_session_access.return_value = True - + return service @@ -52,7 +52,7 @@ def mock_session_service(): def mock_execution_service(): """Mock execution service for testing.""" service = AsyncMock() - + # Mock successful execution mock_execution = CodeExecution( execution_id="exec-123", @@ -66,13 +66,13 @@ def mock_execution_service(): ExecutionOutput( type=OutputType.STDOUT, content="Hello, World!", - timestamp=datetime.now(timezone.utc) + timestamp=datetime.now(timezone.utc), ) - ] + ], ) - + service.execute_code.return_value = (mock_execution, None, None, [], "pool_hit") - + return service @@ -87,111 +87,124 @@ def mock_file_service(): @pytest.fixture(autouse=True) def mock_dependencies(mock_session_service, mock_execution_service, mock_file_service): """Mock all dependencies for testing.""" - from src.dependencies.services import get_session_service, get_execution_service, get_file_service - + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + # Override the dependencies in the FastAPI app app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service - + yield - + # Clean up after test app.dependency_overrides.clear() class TestExecEndpoint: """Test the /exec endpoint functionality.""" - - def test_exec_simple_python_code(self, client, auth_headers, mock_execution_service): + + def test_exec_simple_python_code( + self, client, auth_headers, mock_execution_service + ): """Test executing simple Python code.""" - request_data = { - "code": "print('Hello, World!')", - "lang": "py" - } - + request_data = {"code": "print('Hello, World!')", "lang": "py"} + response = client.post("/exec", json=request_data, headers=auth_headers) - + assert response.status_code == 200 response_data = response.json() - + # Check LibreChat-compatible response structure assert "session_id" in response_data assert "files" in response_data assert "stdout" in response_data assert "stderr" in response_data - + # Check stdout content (should end with newline for LibreChat compatibility) assert response_data["stdout"] == "Hello, World!\n" - + # Check files array assert isinstance(response_data["files"], list) - + # Verify service was called mock_execution_service.execute_code.assert_called_once() - - def test_exec_with_entity_id(self, client, auth_headers, mock_session_service, mock_execution_service): + + def test_exec_with_entity_id( + self, client, auth_headers, mock_session_service, mock_execution_service + ): """Test executing code with entity_id for session sharing.""" request_data = { "code": "print('Hello from entity!')", "lang": "py", - "entity_id": "test-entity-123" + "entity_id": "test-entity-123", } - + response = client.post("/exec", json=request_data, headers=auth_headers) - + assert response.status_code == 200 response_data = response.json() - + # Should create session with entity metadata mock_session_service.create_session.assert_called_once() create_call = mock_session_service.create_session.call_args[0][0] assert create_call.metadata["entity_id"] == "test-entity-123" - - @pytest.mark.skip(reason="Mock file service returns AsyncMock instead of proper values") + + @pytest.mark.skip( + reason="Mock file service returns AsyncMock instead of proper values" + ) def test_exec_with_files(self, client, auth_headers, mock_execution_service): """Test executing code with file references.""" request_data = { "code": "with open('data.txt', 'r') as f: print(f.read())", "lang": "py", "files": [ - { - "id": "file-123", - "session_id": "test-session", - "name": "data.txt" - } - ] + {"id": "file-123", "session_id": "test-session", "name": "data.txt"} + ], } - + response = client.post("/exec", json=request_data, headers=auth_headers) - + assert response.status_code == 200 - + # Verify execution was called with files mock_execution_service.execute_code.assert_called_once() call_args = mock_execution_service.execute_code.call_args files_arg = call_args[1]["files"] # keyword argument assert len(files_arg) == 1 assert files_arg[0]["id"] == "file-123" - - def test_exec_different_languages(self, client, auth_headers, mock_execution_service): + + def test_exec_different_languages( + self, client, auth_headers, mock_execution_service + ): """Test executing code in different languages.""" test_cases = [ {"lang": "py", "code": "print('Hello Python')"}, {"lang": "js", "code": "console.log('Hello JavaScript')"}, - {"lang": "go", "code": "package main\nimport \"fmt\"\nfunc main() { fmt.Println(\"Hello Go\") }"}, - {"lang": "java", "code": "public class Main { public static void main(String[] args) { System.out.println(\"Hello Java\"); } }"} + { + "lang": "go", + "code": 'package main\nimport "fmt"\nfunc main() { fmt.Println("Hello Go") }', + }, + { + "lang": "java", + "code": 'public class Main { public static void main(String[] args) { System.out.println("Hello Java"); } }', + }, ] - + for test_case in test_cases: response = client.post("/exec", json=test_case, headers=auth_headers) assert response.status_code == 200 - + response_data = response.json() # Language is no longer returned in the response for LibreChat compatibility assert "session_id" in response_data - - def test_exec_with_execution_error(self, client, auth_headers, mock_execution_service): + + def test_exec_with_execution_error( + self, client, auth_headers, mock_execution_service + ): """Test handling execution errors.""" # Mock failed execution failed_execution = CodeExecution( @@ -206,27 +219,30 @@ def test_exec_with_execution_error(self, client, auth_headers, mock_execution_se ExecutionOutput( type=OutputType.STDERR, content="NameError: name 'undefined_variable' is not defined", - timestamp=datetime.now(timezone.utc) + timestamp=datetime.now(timezone.utc), ) - ] + ], ) - - mock_execution_service.execute_code.return_value = (failed_execution, None, None, [], "pool_hit") - - request_data = { - "code": "print(undefined_variable)", - "lang": "py" - } - + + mock_execution_service.execute_code.return_value = ( + failed_execution, + None, + None, + [], + "pool_hit", + ) + + request_data = {"code": "print(undefined_variable)", "lang": "py"} + response = client.post("/exec", json=request_data, headers=auth_headers) - + assert response.status_code == 200 # Still 200, but with error in response response_data = response.json() - + # For failed executions, content may be empty or contain error info # In LibreChat format, errors would typically be in stderr which isn't directly exposed # but the test shows the execution completed and returned a response - + def test_exec_with_timeout(self, client, auth_headers, mock_execution_service): """Test handling execution timeout.""" # Mock timeout execution @@ -236,169 +252,160 @@ def test_exec_with_timeout(self, client, auth_headers, mock_execution_service): code="import time; time.sleep(100)", language="py", status=ExecutionStatus.TIMEOUT, - error_message="Execution timed out after 30 seconds" + error_message="Execution timed out after 30 seconds", ) - - mock_execution_service.execute_code.return_value = (timeout_execution, None, None, [], "pool_hit") - - request_data = { - "code": "import time; time.sleep(100)", - "lang": "py" - } - + + mock_execution_service.execute_code.return_value = ( + timeout_execution, + None, + None, + [], + "pool_hit", + ) + + request_data = {"code": "import time; time.sleep(100)", "lang": "py"} + response = client.post("/exec", json=request_data, headers=auth_headers) - + assert response.status_code == 200 response_data = response.json() - + # For timeout, we expect LibreChat format but stdout may be empty or contain timeout message assert "session_id" in response_data assert "files" in response_data assert "stdout" in response_data assert "stderr" in response_data - + def test_exec_invalid_language(self, client, auth_headers): """Test executing code with invalid language.""" - request_data = { - "code": "print('Hello')", - "lang": "invalid_language" - } - + request_data = {"code": "print('Hello')", "lang": "invalid_language"} + response = client.post("/exec", json=request_data, headers=auth_headers) - + # Should either return error or handle gracefully assert response.status_code in [200, 400, 422] - + def test_exec_empty_code(self, client, auth_headers): """Test executing empty code.""" - request_data = { - "code": "", - "lang": "py" - } + request_data = {"code": "", "lang": "py"} response = client.post("/exec", json=request_data, headers=auth_headers) # Should return validation error (400 for business logic validation) assert response.status_code == 400 - + def test_exec_missing_required_fields(self, client, auth_headers): """Test request with missing required fields.""" # Missing code response = client.post("/exec", json={"lang": "py"}, headers=auth_headers) assert response.status_code == 422 - + # Missing lang - response = client.post("/exec", json={"code": "print('test')"}, headers=auth_headers) + response = client.post( + "/exec", json={"code": "print('test')"}, headers=auth_headers + ) assert response.status_code == 422 - + def test_exec_with_args(self, client, auth_headers, mock_execution_service): """Test executing code with command line arguments.""" request_data = { "code": "import sys; print(' '.join(sys.argv[1:]))", "lang": "py", - "args": "arg1 arg2 arg3" + "args": "arg1 arg2 arg3", } - + response = client.post("/exec", json=request_data, headers=auth_headers) - + assert response.status_code == 200 # Args handling would be implementation-specific - + def test_exec_with_user_id(self, client, auth_headers, mock_execution_service): """Test executing code with user_id for tracking.""" request_data = { "code": "print('Hello User')", "lang": "py", - "user_id": "user-123" + "user_id": "user-123", } - + response = client.post("/exec", json=request_data, headers=auth_headers) - + assert response.status_code == 200 # User ID would be used for logging/tracking - - def test_exec_session_reuse(self, client, auth_headers, mock_session_service, mock_execution_service): + + def test_exec_session_reuse( + self, client, auth_headers, mock_session_service, mock_execution_service + ): """Test that sessions are reused for the same entity.""" - request_data = { - "code": "x = 1", - "lang": "py", - "entity_id": "test-entity" - } - + request_data = {"code": "x = 1", "lang": "py", "entity_id": "test-entity"} + # First execution response1 = client.post("/exec", json=request_data, headers=auth_headers) assert response1.status_code == 200 session_id_1 = response1.json()["session_id"] - + # Second execution with same entity request_data["code"] = "print(x)" response2 = client.post("/exec", json=request_data, headers=auth_headers) assert response2.status_code == 200 session_id_2 = response2.json()["session_id"] - + # Should reuse the same session assert session_id_1 == session_id_2 - + def test_exec_without_authentication(self, client): """Test executing code without authentication.""" - request_data = { - "code": "print('Hello')", - "lang": "py" - } - + request_data = {"code": "print('Hello')", "lang": "py"} + response = client.post("/exec", json=request_data) - + assert response.status_code == 401 - + def test_exec_with_invalid_api_key(self, client): """Test executing code with invalid API key.""" - request_data = { - "code": "print('Hello')", - "lang": "py" - } - + request_data = {"code": "print('Hello')", "lang": "py"} + headers = {"x-api-key": "invalid-key"} response = client.post("/exec", json=request_data, headers=headers) - + assert response.status_code == 401 - + def test_exec_service_error(self, client, auth_headers, mock_execution_service): """Test handling service errors during execution.""" mock_execution_service.execute_code.side_effect = Exception("Service error") - request_data = { - "code": "print('Hello')", - "lang": "py" - } + request_data = {"code": "print('Hello')", "lang": "py"} response = client.post("/exec", json=request_data, headers=auth_headers) # 503 Service Unavailable for backend service errors assert response.status_code == 503 assert "error" in response.json() - - def test_exec_response_format_compatibility(self, client, auth_headers, mock_execution_service): + + def test_exec_response_format_compatibility( + self, client, auth_headers, mock_execution_service + ): """Test that response format is compatible with LibreChat API.""" - request_data = { - "code": "print('Hello, World!')", - "lang": "py" - } - + request_data = {"code": "print('Hello, World!')", "lang": "py"} + response = client.post("/exec", json=request_data, headers=auth_headers) - + assert response.status_code == 200 response_data = response.json() - + # Check LibreChat-compatible structure required_fields = ["session_id", "files", "stdout", "stderr"] for field in required_fields: assert field in response_data - + # Check that files is a list assert isinstance(response_data["files"], list) - - @pytest.mark.skip(reason="Mock file service returns AsyncMock instead of proper values") - def test_exec_with_generated_files(self, client, auth_headers, mock_execution_service, mock_file_service): + + @pytest.mark.skip( + reason="Mock file service returns AsyncMock instead of proper values" + ) + def test_exec_with_generated_files( + self, client, auth_headers, mock_execution_service, mock_file_service + ): """Test execution that generates files.""" # Mock execution with file output execution_with_files = CodeExecution( @@ -414,45 +421,54 @@ def test_exec_with_generated_files(self, client, auth_headers, mock_execution_se content="/workspace/output.txt", mime_type="text/plain", size=17, - timestamp=datetime.now(timezone.utc) + timestamp=datetime.now(timezone.utc), ) - ] + ], + ) + + mock_execution_service.execute_code.return_value = ( + execution_with_files, + None, + None, + [], + "pool_hit", ) - - mock_execution_service.execute_code.return_value = (execution_with_files, None, None, [], "pool_hit") - + # Mock file service to return generated file from src.models.files import FileInfo + mock_file_info = FileInfo( file_id="generated-file-123", filename="output.txt", size=17, content_type="text/plain", created_at=datetime.now(timezone.utc), - path="/output.txt" + path="/output.txt", ) mock_file_service.list_files.return_value = [mock_file_info] - + request_data = { "code": "with open('output.txt', 'w') as f: f.write('Generated content')", - "lang": "py" + "lang": "py", } - + response = client.post("/exec", json=request_data, headers=auth_headers) - + assert response.status_code == 200 response_data = response.json() - + # Should include generated files in files array assert len(response_data["files"]) == 1 assert response_data["files"][0]["id"] == "generated-file-123" assert response_data["files"][0]["name"] == "output.txt" - - def test_exec_large_output_handling(self, client, auth_headers, mock_execution_service): + + def test_exec_large_output_handling( + self, client, auth_headers, mock_execution_service + ): """Test handling of large execution output.""" # Mock execution with large output large_output = "A" * 100000 # 100KB output - + large_execution = CodeExecution( execution_id="exec-large", session_id="test-session-123", @@ -464,23 +480,26 @@ def test_exec_large_output_handling(self, client, auth_headers, mock_execution_s ExecutionOutput( type=OutputType.STDOUT, content=large_output, - timestamp=datetime.now(timezone.utc) + timestamp=datetime.now(timezone.utc), ) - ] + ], ) - - mock_execution_service.execute_code.return_value = (large_execution, None, None, [], "pool_hit") - - request_data = { - "code": "print('A' * 100000)", - "lang": "py" - } - + + mock_execution_service.execute_code.return_value = ( + large_execution, + None, + None, + [], + "pool_hit", + ) + + request_data = {"code": "print('A' * 100000)", "lang": "py"} + response = client.post("/exec", json=request_data, headers=auth_headers) - + assert response.status_code == 200 response_data = response.json() - + # Output should be present in stdout field (may be truncated) assert "stdout" in response_data - assert len(response_data["stdout"]) > 0 \ No newline at end of file + assert len(response_data["stdout"]) > 0 diff --git a/tests/integration/test_file_api.py b/tests/integration/test_file_api.py index 65f1434..6419f4f 100644 --- a/tests/integration/test_file_api.py +++ b/tests/integration/test_file_api.py @@ -1,32 +1,11 @@ """Integration tests for file management API endpoints. These tests use real infrastructure (MinIO, Redis) - requires docker-compose up. +Fixtures (client, auth_headers, unique_session_id) are defined in conftest.py. """ import pytest -from fastapi.testclient import TestClient import io -import uuid - -from src.main import app - - -@pytest.fixture -def client(): - """Create test client.""" - return TestClient(app) - - -@pytest.fixture -def auth_headers(): - """Provide authentication headers for tests.""" - return {"x-api-key": "test-api-key-for-testing-12345"} - - -@pytest.fixture -def unique_session_id(): - """Generate unique session ID for test isolation.""" - return f"test-session-{uuid.uuid4().hex[:8]}" class TestFileUpload: @@ -56,7 +35,7 @@ def test_upload_multiple_files(self, client, auth_headers, unique_session_id): """Test uploading multiple files.""" files = [ ("files", ("file1.txt", io.BytesIO(b"Content 1"), "text/plain")), - ("files", ("file2.txt", io.BytesIO(b"Content 2"), "text/plain")) + ("files", ("file2.txt", io.BytesIO(b"Content 2"), "text/plain")), ] data = {"entity_id": unique_session_id} @@ -125,7 +104,9 @@ def test_list_files_simple_detail(self, client, auth_headers): session_id = upload_response.json()["session_id"] # List with simple detail - response = client.get(f"/files/{session_id}?detail=simple", headers=auth_headers) + response = client.get( + f"/files/{session_id}?detail=simple", headers=auth_headers + ) assert response.status_code == 200 files_list = response.json() @@ -155,7 +136,7 @@ def test_download_uploaded_file(self, client, auth_headers): download_response = client.get( f"/download/{session_id}/{file_id}", headers=auth_headers, - follow_redirects=False + follow_redirects=False, ) # Should redirect to MinIO presigned URL @@ -165,8 +146,7 @@ def test_download_uploaded_file(self, client, auth_headers): def test_download_nonexistent_file(self, client, auth_headers, unique_session_id): """Test downloading a file that doesn't exist.""" response = client.get( - f"/download/{unique_session_id}/nonexistent-file-id", - headers=auth_headers + f"/download/{unique_session_id}/nonexistent-file-id", headers=auth_headers ) assert response.status_code == 404 @@ -188,8 +168,7 @@ def test_delete_uploaded_file(self, client, auth_headers): # Delete the file delete_response = client.delete( - f"/files/{session_id}/{file_id}", - headers=auth_headers + f"/files/{session_id}/{file_id}", headers=auth_headers ) assert delete_response.status_code == 200 @@ -202,8 +181,7 @@ def test_delete_uploaded_file(self, client, auth_headers): def test_delete_nonexistent_file(self, client, auth_headers, unique_session_id): """Test deleting a file that doesn't exist.""" response = client.delete( - f"/files/{unique_session_id}/nonexistent-file-id", - headers=auth_headers + f"/files/{unique_session_id}/nonexistent-file-id", headers=auth_headers ) assert response.status_code == 404 @@ -214,7 +192,13 @@ class TestFileTypeRestrictions: def test_upload_blocked_exe_file(self, client, auth_headers): """Test that .exe files are blocked with 415 status.""" - files = {"files": ("malware.exe", io.BytesIO(b"MZ...fake exe"), "application/octet-stream")} + files = { + "files": ( + "malware.exe", + io.BytesIO(b"MZ...fake exe"), + "application/octet-stream", + ) + } response = client.post("/upload", files=files, headers=auth_headers) @@ -223,7 +207,13 @@ def test_upload_blocked_exe_file(self, client, auth_headers): def test_upload_blocked_dll_file(self, client, auth_headers): """Test that .dll files are blocked with 415 status.""" - files = {"files": ("library.dll", io.BytesIO(b"fake dll content"), "application/octet-stream")} + files = { + "files": ( + "library.dll", + io.BytesIO(b"fake dll content"), + "application/octet-stream", + ) + } response = client.post("/upload", files=files, headers=auth_headers) @@ -232,7 +222,13 @@ def test_upload_blocked_dll_file(self, client, auth_headers): def test_upload_blocked_bin_file(self, client, auth_headers): """Test that .bin files are blocked with 415 status.""" - files = {"files": ("binary.bin", io.BytesIO(b"binary content"), "application/octet-stream")} + files = { + "files": ( + "binary.bin", + io.BytesIO(b"binary content"), + "application/octet-stream", + ) + } response = client.post("/upload", files=files, headers=auth_headers) diff --git a/tests/integration/test_file_handling.py b/tests/integration/test_file_handling.py index 224f198..e9f4ab1 100644 --- a/tests/integration/test_file_handling.py +++ b/tests/integration/test_file_handling.py @@ -48,7 +48,7 @@ async def test_generated_image_is_valid_png(self, ssl_context, headers): plt.savefig('/mnt/data/test_chart.png', dpi=100) print('Chart saved') """, - "entity_id": "test-file-gen-png" + "entity_id": "test-file-gen-png", } async with session.post( @@ -80,7 +80,7 @@ async def test_generated_image_is_valid_png(self, ssl_context, headers): assert len(content) > 1000, f"File too small: {len(content)} bytes" # Check PNG magic bytes - assert content[:8] == b'\x89PNG\r\n\x1a\n', "Not a valid PNG file" + assert content[:8] == b"\x89PNG\r\n\x1a\n", "Not a valid PNG file" @pytest.mark.asyncio async def test_multiple_generated_files(self, ssl_context, headers): @@ -101,7 +101,7 @@ async def test_multiple_generated_files(self, ssl_context, headers): plt.savefig(f'/mnt/data/{name}.png') print(f'Created {name}.png') """, - "entity_id": "test-multi-files" + "entity_id": "test-multi-files", } async with session.post( @@ -130,12 +130,12 @@ async def test_multiple_generated_files(self, ssl_context, headers): assert dl_resp.status == 200 content = await dl_resp.read() - assert len(content) > 1000, ( - f"File {file_info['name']} too small: {len(content)} bytes" - ) - assert content[:4] == b'\x89PNG', ( - f"File {file_info['name']} is not a valid PNG" - ) + assert ( + len(content) > 1000 + ), f"File {file_info['name']} too small: {len(content)} bytes" + assert ( + content[:4] == b"\x89PNG" + ), f"File {file_info['name']} is not a valid PNG" @pytest.mark.asyncio async def test_text_file_generation(self, ssl_context, headers): @@ -151,7 +151,7 @@ async def test_text_file_generation(self, ssl_context, headers): f.write('This is a test file.\\n') print('Text file created') """, - "entity_id": "test-text-file" + "entity_id": "test-text-file", } async with session.post( @@ -200,7 +200,7 @@ async def test_csv_file_generation(self, ssl_context, headers): df.to_csv('/mnt/data/people.csv', index=False) print(f'Created CSV with {len(df)} rows') """, - "entity_id": "test-csv-file" + "entity_id": "test-csv-file", } async with session.post( @@ -241,6 +241,7 @@ async def test_file_generation_after_pool_acquisition(self, ssl_context, headers async with aiohttp.ClientSession(connector=connector) as session: # Use unique entity_id to get a fresh session/container from pool import time + entity_id = f"test-pool-file-{int(time.time())}" payload = { @@ -252,7 +253,7 @@ async def test_file_generation_after_pool_acquisition(self, ssl_context, headers plt.savefig('/mnt/data/pie.png') print('Pie chart created') """, - "entity_id": entity_id + "entity_id": entity_id, } async with session.post( @@ -278,10 +279,10 @@ async def test_file_generation_after_pool_acquisition(self, ssl_context, headers content = await dl_resp.read() # Should be a substantial PNG file, not a stub - assert len(content) > 5000, ( - f"File appears truncated: {len(content)} bytes" - ) - assert content[:8] == b'\x89PNG\r\n\x1a\n', "Invalid PNG" + assert ( + len(content) > 5000 + ), f"File appears truncated: {len(content)} bytes" + assert content[:8] == b"\x89PNG\r\n\x1a\n", "Invalid PNG" @pytest.mark.asyncio async def test_large_file_generation(self, ssl_context, headers): @@ -306,7 +307,7 @@ async def test_large_file_generation(self, ssl_context, headers): plt.savefig('/mnt/data/large_plot.png') print('Large plot created') """, - "entity_id": "test-large-file" + "entity_id": "test-large-file", } async with session.post( @@ -332,10 +333,10 @@ async def test_large_file_generation(self, ssl_context, headers): content = await dl_resp.read() # Large detailed plot should be > 50KB - assert len(content) > 50000, ( - f"Large file too small: {len(content)} bytes" - ) - assert content[:8] == b'\x89PNG\r\n\x1a\n', "Invalid PNG" + assert ( + len(content) > 50000 + ), f"Large file too small: {len(content)} bytes" + assert content[:8] == b"\x89PNG\r\n\x1a\n", "Invalid PNG" class TestUploadAnalyzeDownload: @@ -347,16 +348,20 @@ async def test_upload_csv_analyze_download_results(self, ssl_context, headers): connector = aiohttp.TCPConnector(ssl=ssl_context) async with aiohttp.ClientSession(connector=connector) as session: import time + entity_id = f"test-upload-analyze-{int(time.time())}" # Step 1: Upload a CSV file csv_content = "product,quantity,price\nWidget A,100,9.99\nWidget B,250,14.99\nWidget C,75,24.99\nWidget D,300,4.99\nWidget E,150,19.99" form_data = aiohttp.FormData() - form_data.add_field('files', csv_content.encode(), - filename='sales_data.csv', - content_type='text/csv') - form_data.add_field('entity_id', entity_id) + form_data.add_field( + "files", + csv_content.encode(), + filename="sales_data.csv", + content_type="text/csv", + ) + form_data.add_field("entity_id", entity_id) upload_headers = {"X-API-Key": API_KEY} @@ -364,7 +369,7 @@ async def test_upload_csv_analyze_download_results(self, ssl_context, headers): f"{API_URL}/upload", data=form_data, headers=upload_headers, - ssl=ssl_context + ssl=ssl_context, ) as resp: assert resp.status == 200, f"Upload failed: {await resp.text()}" upload_result = await resp.json() @@ -379,7 +384,9 @@ async def test_upload_csv_analyze_download_results(self, ssl_context, headers): # Step 2: Execute analysis code that reads the uploaded file and creates a report from textwrap import dedent - analysis_code = dedent(""" + + analysis_code = dedent( + """ import pandas as pd # Read the uploaded CSV (files are placed in /mnt/data/) @@ -406,24 +413,20 @@ async def test_upload_csv_analyze_download_results(self, ssl_context, headers): f.write(report) print(report) - """).strip() + """ + ).strip() exec_payload = { "lang": "py", "code": analysis_code, "entity_id": entity_id, - "files": [{ - "id": file_id, - "session_id": session_id, - "name": "sales_data.csv" - }] + "files": [ + {"id": file_id, "session_id": session_id, "name": "sales_data.csv"} + ], } async with session.post( - f"{API_URL}/exec", - json=exec_payload, - headers=headers, - ssl=ssl_context + f"{API_URL}/exec", json=exec_payload, headers=headers, ssl=ssl_context ) as resp: assert resp.status == 200, f"Exec failed: {await resp.text()}" exec_result = await resp.json() @@ -431,14 +434,21 @@ async def test_upload_csv_analyze_download_results(self, ssl_context, headers): # Verify execution succeeded stdout = exec_result.get("stdout", "") stderr = exec_result.get("stderr", "") - assert "Sales Analysis Report" in stdout, f"Analysis failed. stdout: {stdout}, stderr: {stderr}" + assert ( + "Sales Analysis Report" in stdout + ), f"Analysis failed. stdout: {stdout}, stderr: {stderr}" # Find generated files files = exec_result.get("files", []) assert len(files) >= 2, f"Expected 2 output files, got {len(files)}" - csv_output = next((f for f in files if "analyzed_sales.csv" in f.get("name", "")), None) - txt_output = next((f for f in files if "sales_report.txt" in f.get("name", "")), None) + csv_output = next( + (f for f in files if "analyzed_sales.csv" in f.get("name", "")), + None, + ) + txt_output = next( + (f for f in files if "sales_report.txt" in f.get("name", "")), None + ) assert csv_output is not None, "analyzed_sales.csv not found in output" assert txt_output is not None, "sales_report.txt not found in output" @@ -448,19 +458,25 @@ async def test_upload_csv_analyze_download_results(self, ssl_context, headers): # Step 3: Download and verify the analyzed CSV download_url = f"{API_URL}/download/{exec_session_id}/{csv_output['id']}" - async with session.get(download_url, headers=upload_headers, ssl=ssl_context) as resp: + async with session.get( + download_url, headers=upload_headers, ssl=ssl_context + ) as resp: assert resp.status == 200, f"CSV download failed: {resp.status}" csv_result = await resp.text() # Verify the analysis added the total_value column - assert "total_value" in csv_result, "Analysis column not found in output CSV" + assert ( + "total_value" in csv_result + ), "Analysis column not found in output CSV" assert "Widget A" in csv_result # Widget A: 100 * 9.99 = 999.0 assert "999" in csv_result, "Calculated total_value not found" # Step 4: Download and verify the text report download_url = f"{API_URL}/download/{exec_session_id}/{txt_output['id']}" - async with session.get(download_url, headers=upload_headers, ssl=ssl_context) as resp: + async with session.get( + download_url, headers=upload_headers, ssl=ssl_context + ) as resp: assert resp.status == 200, f"Report download failed: {resp.status}" report_content = await resp.text() @@ -475,6 +491,7 @@ async def test_upload_image_process_download(self, ssl_context, headers): connector = aiohttp.TCPConnector(ssl=ssl_context) async with aiohttp.ClientSession(connector=connector) as session: import time + entity_id = f"test-image-process-{int(time.time())}" # Step 1: Create and upload a simple PNG image (100x100 red square) @@ -501,7 +518,7 @@ async def test_upload_image_process_download(self, ssl_context, headers): f"{API_URL}/exec", json={"lang": "py", "code": create_image_code, "entity_id": entity_id}, headers=headers, - ssl=ssl_context + ssl=ssl_context, ) as resp: assert resp.status == 200 result = await resp.json() @@ -509,7 +526,9 @@ async def test_upload_image_process_download(self, ssl_context, headers): # Get the created image file files = result.get("files", []) - input_image = next((f for f in files if "test_input.png" in f.get("name", "")), None) + input_image = next( + (f for f in files if "test_input.png" in f.get("name", "")), None + ) assert input_image is not None, "Test image not created" # Step 2: Process the image (apply blur and edge detection) @@ -539,29 +558,34 @@ async def test_upload_image_process_download(self, ssl_context, headers): "lang": "py", "code": process_code, "entity_id": entity_id, - "files": [{ - "id": input_image['id'], - "session_id": session_id, - "name": "test_input.png" - }] + "files": [ + { + "id": input_image["id"], + "session_id": session_id, + "name": "test_input.png", + } + ], } async with session.post( - f"{API_URL}/exec", - json=exec_payload, - headers=headers, - ssl=ssl_context + f"{API_URL}/exec", json=exec_payload, headers=headers, ssl=ssl_context ) as resp: assert resp.status == 200, f"Processing failed: {await resp.text()}" result = await resp.json() stdout = result.get("stdout", "") stderr = result.get("stderr", "") - assert "Processed images saved" in stdout, f"Processing failed. stderr: {stderr}" + assert ( + "Processed images saved" in stdout + ), f"Processing failed. stderr: {stderr}" files = result.get("files", []) - blurred_file = next((f for f in files if "blurred.png" in f.get("name", "")), None) - edges_file = next((f for f in files if "edges.png" in f.get("name", "")), None) + blurred_file = next( + (f for f in files if "blurred.png" in f.get("name", "")), None + ) + edges_file = next( + (f for f in files if "edges.png" in f.get("name", "")), None + ) assert blurred_file is not None, "blurred.png not found" assert edges_file is not None, "edges.png not found" @@ -571,19 +595,23 @@ async def test_upload_image_process_download(self, ssl_context, headers): # Download blurred image download_url = f"{API_URL}/download/{session_id}/{blurred_file['id']}" - async with session.get(download_url, headers=upload_headers, ssl=ssl_context) as resp: + async with session.get( + download_url, headers=upload_headers, ssl=ssl_context + ) as resp: assert resp.status == 200 content = await resp.read() assert len(content) > 100, f"Blurred image too small: {len(content)}" - assert content[:4] == b'\x89PNG', "Blurred output is not a valid PNG" + assert content[:4] == b"\x89PNG", "Blurred output is not a valid PNG" # Download edges image download_url = f"{API_URL}/download/{session_id}/{edges_file['id']}" - async with session.get(download_url, headers=upload_headers, ssl=ssl_context) as resp: + async with session.get( + download_url, headers=upload_headers, ssl=ssl_context + ) as resp: assert resp.status == 200 content = await resp.read() assert len(content) > 100, f"Edges image too small: {len(content)}" - assert content[:4] == b'\x89PNG', "Edges output is not a valid PNG" + assert content[:4] == b"\x89PNG", "Edges output is not a valid PNG" @pytest.mark.asyncio async def test_upload_json_transform_download(self, ssl_context, headers): @@ -592,6 +620,7 @@ async def test_upload_json_transform_download(self, ssl_context, headers): async with aiohttp.ClientSession(connector=connector) as session: import time import json + entity_id = f"test-json-transform-{int(time.time())}" # Step 1: Upload JSON data @@ -601,15 +630,18 @@ async def test_upload_json_transform_download(self, ssl_context, headers): {"name": "Bob", "age": 25, "department": "Marketing"}, {"name": "Charlie", "age": 35, "department": "Engineering"}, {"name": "Diana", "age": 28, "department": "Sales"}, - {"name": "Eve", "age": 32, "department": "Engineering"} + {"name": "Eve", "age": 32, "department": "Engineering"}, ] } form_data = aiohttp.FormData() - form_data.add_field('files', json.dumps(json_data).encode(), - filename='users.json', - content_type='application/json') - form_data.add_field('entity_id', entity_id) + form_data.add_field( + "files", + json.dumps(json_data).encode(), + filename="users.json", + content_type="application/json", + ) + form_data.add_field("entity_id", entity_id) upload_headers = {"X-API-Key": API_KEY} @@ -617,7 +649,7 @@ async def test_upload_json_transform_download(self, ssl_context, headers): f"{API_URL}/upload", data=form_data, headers=upload_headers, - ssl=ssl_context + ssl=ssl_context, ) as resp: assert resp.status == 200, f"Upload failed: {await resp.text()}" upload_result = await resp.json() @@ -627,7 +659,9 @@ async def test_upload_json_transform_download(self, ssl_context, headers): # Step 2: Transform the data from textwrap import dedent - transform_code = dedent(""" + + transform_code = dedent( + """ import json import pandas as pd @@ -656,30 +690,28 @@ async def test_upload_json_transform_download(self, ssl_context, headers): json.dump(output, f, indent=2) print(json.dumps(output, indent=2)) - """).strip() + """ + ).strip() exec_payload = { "lang": "py", "code": transform_code, "entity_id": entity_id, - "files": [{ - "id": file_id, - "session_id": session_id, - "name": "users.json" - }] + "files": [ + {"id": file_id, "session_id": session_id, "name": "users.json"} + ], } async with session.post( - f"{API_URL}/exec", - json=exec_payload, - headers=headers, - ssl=ssl_context + f"{API_URL}/exec", json=exec_payload, headers=headers, ssl=ssl_context ) as resp: assert resp.status == 200 result = await resp.json() files = result.get("files", []) - json_output = next((f for f in files if "analysis.json" in f.get("name", "")), None) + json_output = next( + (f for f in files if "analysis.json" in f.get("name", "")), None + ) assert json_output is not None, "analysis.json not found" # Use session_id from exec result for downloading @@ -687,12 +719,14 @@ async def test_upload_json_transform_download(self, ssl_context, headers): # Step 3: Download and verify the result download_url = f"{API_URL}/download/{exec_session_id}/{json_output['id']}" - async with session.get(download_url, headers=upload_headers, ssl=ssl_context) as resp: + async with session.get( + download_url, headers=upload_headers, ssl=ssl_context + ) as resp: assert resp.status == 200 content = await resp.text() result_data = json.loads(content) - assert result_data['total_users'] == 5 - assert 'department_breakdown' in result_data - assert 'Engineering' in result_data['department_breakdown'] - assert result_data['department_breakdown']['Engineering']['count'] == 3 + assert result_data["total_users"] == 5 + assert "department_breakdown" in result_data + assert "Engineering" in result_data["department_breakdown"] + assert result_data["department_breakdown"]["Engineering"]["count"] == 3 diff --git a/tests/integration/test_librechat_compat.py b/tests/integration/test_librechat_compat.py index 78df156..3a474a0 100644 --- a/tests/integration/test_librechat_compat.py +++ b/tests/integration/test_librechat_compat.py @@ -43,10 +43,7 @@ def auth_headers(): def mock_exec_response(): """Standard successful execution response.""" return ExecResponse( - session_id="test-session-123", - stdout="output\n", - stderr="", - files=[] + session_id="test-session-123", stdout="output\n", stderr="", files=[] ) @@ -54,6 +51,7 @@ def mock_exec_response(): # LIBRECHAT EXEC REQUEST FORMAT # ============================================================================= + class TestLibreChatExecRequest: """Test /exec request format exactly as LibreChat sends it. @@ -66,8 +64,10 @@ class TestLibreChatExecRequest: - files?: Array<{id, session_id, name}> """ - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') - def test_librechat_minimal_request(self, mock_execute, client, auth_headers, mock_exec_response): + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_librechat_minimal_request( + self, mock_execute, client, auth_headers, mock_exec_response + ): """ Test LibreChat minimal request format. @@ -75,17 +75,16 @@ def test_librechat_minimal_request(self, mock_execute, client, auth_headers, moc """ mock_execute.return_value = mock_exec_response - request = { - "code": "print('hello')", - "lang": "py" - } + request = {"code": "print('hello')", "lang": "py"} response = client.post("/exec", json=request, headers=auth_headers) assert response.status_code == 200 mock_execute.assert_called_once() - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') - def test_librechat_request_with_user_id(self, mock_execute, client, auth_headers, mock_exec_response): + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_librechat_request_with_user_id( + self, mock_execute, client, auth_headers, mock_exec_response + ): """ Test LibreChat request with user_id for tracking. @@ -93,17 +92,15 @@ def test_librechat_request_with_user_id(self, mock_execute, client, auth_headers """ mock_execute.return_value = mock_exec_response - request = { - "code": "print('hello')", - "lang": "py", - "user_id": "user_xyz789" - } + request = {"code": "print('hello')", "lang": "py", "user_id": "user_xyz789"} response = client.post("/exec", json=request, headers=auth_headers) assert response.status_code == 200 - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') - def test_librechat_request_with_files(self, mock_execute, client, auth_headers, mock_exec_response): + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_librechat_request_with_files( + self, mock_execute, client, auth_headers, mock_exec_response + ): """ Test LibreChat request with file references. @@ -119,16 +116,18 @@ def test_librechat_request_with_files(self, mock_execute, client, auth_headers, { "id": "file-svc-abc123", "session_id": "sess_xyz789", - "name": "data.csv" + "name": "data.csv", } - ] + ], } response = client.post("/exec", json=request, headers=auth_headers) assert response.status_code == 200 - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') - def test_librechat_request_with_multiple_files(self, mock_execute, client, auth_headers, mock_exec_response): + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_librechat_request_with_multiple_files( + self, mock_execute, client, auth_headers, mock_exec_response + ): """Test LibreChat request with multiple file references.""" mock_execute.return_value = mock_exec_response @@ -138,15 +137,17 @@ def test_librechat_request_with_multiple_files(self, mock_execute, client, auth_ "files": [ {"id": "file-1", "session_id": "sess-1", "name": "file1.txt"}, {"id": "file-2", "session_id": "sess-2", "name": "file2.txt"}, - {"id": "file-3", "session_id": "sess-3", "name": "file3.csv"} - ] + {"id": "file-3", "session_id": "sess-3", "name": "file3.csv"}, + ], } response = client.post("/exec", json=request, headers=auth_headers) assert response.status_code == 200 - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') - def test_librechat_args_as_array(self, mock_execute, client, auth_headers, mock_exec_response): + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_librechat_args_as_array( + self, mock_execute, client, auth_headers, mock_exec_response + ): """ Test LibreChat args field format. @@ -158,14 +159,16 @@ def test_librechat_args_as_array(self, mock_execute, client, auth_headers, mock_ request = { "code": "print('test')", "lang": "py", - "args": ["arg1", "arg2", "arg3"] + "args": ["arg1", "arg2", "arg3"], } response = client.post("/exec", json=request, headers=auth_headers) assert response.status_code == 200 - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') - def test_librechat_request_with_session_id(self, mock_execute, client, auth_headers, mock_exec_response): + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_librechat_request_with_session_id( + self, mock_execute, client, auth_headers, mock_exec_response + ): """ Test LibreChat request with session_id for file access. @@ -178,7 +181,7 @@ def test_librechat_request_with_session_id(self, mock_execute, client, auth_head request = { "code": "import os; print(os.listdir('/mnt/data'))", "lang": "py", - "session_id": "prev-session-abc123" + "session_id": "prev-session-abc123", } response = client.post("/exec", json=request, headers=auth_headers) @@ -189,6 +192,7 @@ def test_librechat_request_with_session_id(self, mock_execute, client, auth_head # LIBRECHAT EXEC RESPONSE FORMAT # ============================================================================= + class TestLibreChatExecResponse: """Test /exec response format exactly as LibreChat expects it. @@ -201,7 +205,7 @@ class TestLibreChatExecResponse: Additional fields (has_state, state_size, state_hash) are allowed and ignored. """ - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") def test_response_has_required_fields(self, mock_execute, client, auth_headers): """ Test LibreChat response has required fields: session_id, files, stdout, stderr. @@ -211,16 +215,12 @@ def test_response_has_required_fields(self, mock_execute, client, auth_headers): and will be ignored by LibreChat. """ mock_execute.return_value = ExecResponse( - session_id="resp-session-123", - stdout="test output\n", - stderr="", - files=[] + session_id="resp-session-123", stdout="test output\n", stderr="", files=[] ) - response = client.post("/exec", json={ - "code": "print('test')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", json={"code": "print('test')", "lang": "py"}, headers=auth_headers + ) data = response.json() @@ -236,7 +236,7 @@ def test_response_has_required_fields(self, mock_execute, client, auth_headers): assert isinstance(data["stdout"], str) assert isinstance(data["stderr"], str) - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") def test_stdout_ends_with_newline(self, mock_execute, client, auth_headers): """ Test that stdout ends with newline. @@ -244,21 +244,19 @@ def test_stdout_ends_with_newline(self, mock_execute, client, auth_headers): LibreChat UI expects this for proper display. """ mock_execute.return_value = ExecResponse( - session_id="resp-session-123", - stdout="hello\n", - stderr="", - files=[] + session_id="resp-session-123", stdout="hello\n", stderr="", files=[] ) - response = client.post("/exec", json={ - "code": "print('hello')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", json={"code": "print('hello')", "lang": "py"}, headers=auth_headers + ) data = response.json() - assert data["stdout"].endswith("\n"), "stdout must end with newline for LibreChat" + assert data["stdout"].endswith( + "\n" + ), "stdout must end with newline for LibreChat" - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") def test_files_array_format(self, mock_execute, client, auth_headers): """ Test generated files format: {id, name, path?} @@ -269,15 +267,12 @@ def test_files_array_format(self, mock_execute, client, auth_headers): session_id="resp-session-123", stdout="", stderr="", - files=[ - FileRef(id="gen-file-abc", name="output.png", path="/output.png") - ] + files=[FileRef(id="gen-file-abc", name="output.png", path="/output.png")], ) - response = client.post("/exec", json={ - "code": "generate image", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", json={"code": "generate image", "lang": "py"}, headers=auth_headers + ) data = response.json() assert len(data["files"]) == 1 @@ -288,56 +283,49 @@ def test_files_array_format(self, mock_execute, client, auth_headers): assert "name" in file_ref, "File must have 'name' field" # path is optional but typically included - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") def test_empty_stderr_on_success(self, mock_execute, client, auth_headers): """Test stderr is empty string on successful execution.""" mock_execute.return_value = ExecResponse( - session_id="resp-session-123", - stdout="ok\n", - stderr="", - files=[] + session_id="resp-session-123", stdout="ok\n", stderr="", files=[] ) - response = client.post("/exec", json={ - "code": "print('ok')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", json={"code": "print('ok')", "lang": "py"}, headers=auth_headers + ) data = response.json() assert data["stderr"] == "", "stderr should be empty on success" - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") def test_stderr_populated_on_error(self, mock_execute, client, auth_headers): """Test stderr contains error message on failure.""" mock_execute.return_value = ExecResponse( session_id="resp-session-123", stdout="", stderr="Traceback: Exception: error\n", - files=[] + files=[], ) - response = client.post("/exec", json={ - "code": "raise Exception('error')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "raise Exception('error')", "lang": "py"}, + headers=auth_headers, + ) data = response.json() assert len(data["stderr"]) > 0, "stderr should contain the error" - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") def test_session_id_is_string(self, mock_execute, client, auth_headers): """Test session_id is always a non-empty string.""" mock_execute.return_value = ExecResponse( - session_id="resp-session-123", - stdout="", - stderr="", - files=[] + session_id="resp-session-123", stdout="", stderr="", files=[] ) - response = client.post("/exec", json={ - "code": "pass", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", json={"code": "pass", "lang": "py"}, headers=auth_headers + ) data = response.json() assert isinstance(data["session_id"], str) @@ -348,6 +336,7 @@ def test_session_id_is_string(self, mock_execute, client, auth_headers): # LIBRECHAT FILE UPLOAD FORMAT # ============================================================================= + class TestLibreChatFileUpload: """Test /upload format exactly as LibreChat sends it. @@ -366,6 +355,7 @@ def setup_mocks(self): mock_file_service.store_uploaded_file.return_value = "lc-file-123" from src.dependencies.services import get_file_service + app.dependency_overrides[get_file_service] = lambda: mock_file_service yield @@ -380,7 +370,9 @@ def test_multipart_upload_format(self, client, auth_headers): From crud.js: form.append('file', stream, filename) """ # LibreChat uses 'file' (singular), not 'files' - files = {"file": ("document.pdf", io.BytesIO(b"PDF content"), "application/pdf")} + files = { + "file": ("document.pdf", io.BytesIO(b"PDF content"), "application/pdf") + } data = {"entity_id": "asst_librechat"} response = client.post("/upload", files=files, data=data, headers=auth_headers) @@ -427,7 +419,7 @@ def test_librechat_upload_with_user_id_header(self, client, auth_headers): headers = { **auth_headers, "User-Id": "user_abc123", - "User-Agent": "LibreChat/1.0" + "User-Agent": "LibreChat/1.0", } response = client.post("/upload", files=files, data=data, headers=headers) @@ -440,6 +432,7 @@ def test_librechat_upload_with_user_id_header(self, client, auth_headers): # LIBRECHAT FILE RETRIEVAL # ============================================================================= + class TestLibreChatFileRetrieval: """Test file retrieval endpoints as LibreChat uses them. @@ -456,6 +449,7 @@ def setup_mocks(self): self.mock_file_service = AsyncMock() from src.dependencies.services import get_file_service + app.dependency_overrides[get_file_service] = lambda: self.mock_file_service yield @@ -476,13 +470,12 @@ def test_files_endpoint_with_detail_summary(self, client, auth_headers): size=1024, content_type="image/png", created_at=datetime.now(timezone.utc), - path="/output.png" + path="/output.png", ) ] response = client.get( - "/files/test-session-123?detail=summary", - headers=auth_headers + "/files/test-session-123?detail=summary", headers=auth_headers ) assert response.status_code == 200 @@ -503,13 +496,12 @@ def test_files_endpoint_with_detail_full(self, client, auth_headers): size=2048, content_type="text/csv", created_at=datetime.now(timezone.utc), - path="/data.csv" + path="/data.csv", ) ] response = client.get( - "/files/test-session-456?detail=full", - headers=auth_headers + "/files/test-session-456?detail=full", headers=auth_headers ) assert response.status_code == 200 @@ -527,12 +519,11 @@ def test_download_endpoint(self, client, auth_headers): self.mock_file_service.get_file.return_value = ( io.BytesIO(b"file content here"), "output.txt", - "text/plain" + "text/plain", ) response = client.get( - "/download/test-session-789/file-abc", - headers=auth_headers + "/download/test-session-789/file-abc", headers=auth_headers ) # Should return file content or appropriate response @@ -544,6 +535,7 @@ def test_download_endpoint(self, client, auth_headers): # LIBRECHAT AUTHENTICATION # ============================================================================= + class TestLibreChatAuthentication: """Test authentication exactly as LibreChat uses it. @@ -568,6 +560,7 @@ def test_x_api_key_header(self, client): # LIBRECHAT ERROR HANDLING # ============================================================================= + class TestLibreChatErrors: """Test error handling as LibreChat expects. @@ -594,7 +587,7 @@ def test_auth_error_format(self, client): data = response.json() assert "error" in data - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") def test_execution_error_returns_200(self, mock_execute, client, auth_headers): """ Test that code execution errors still return 200. @@ -605,13 +598,14 @@ def test_execution_error_returns_200(self, mock_execute, client, auth_headers): session_id="err-session", stdout="", stderr="SyntaxError: invalid syntax\n", - files=[] + files=[], ) - response = client.post("/exec", json={ - "code": "this is not valid python [[[", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "this is not valid python [[[", "lang": "py"}, + headers=auth_headers, + ) # CRITICAL: Should return 200, not 4xx or 5xx assert response.status_code == 200 @@ -623,20 +617,21 @@ def test_execution_error_returns_200(self, mock_execute, client, auth_headers): assert "stdout" in data assert "stderr" in data - @patch('src.services.orchestrator.ExecutionOrchestrator.execute') + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") def test_timeout_returns_200(self, mock_execute, client, auth_headers): """Test that timeout still returns 200 with appropriate message.""" mock_execute.return_value = ExecResponse( session_id="timeout-session", stdout="", stderr="Execution timed out after 30 seconds\n", - files=[] + files=[], ) - response = client.post("/exec", json={ - "code": "import time; time.sleep(9999)", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "import time; time.sleep(9999)", "lang": "py"}, + headers=auth_headers, + ) # Should return 200 even for timeout assert response.status_code == 200 diff --git a/tests/integration/test_security_integration.py b/tests/integration/test_security_integration.py index 0daacb0..eccbe21 100644 --- a/tests/integration/test_security_integration.py +++ b/tests/integration/test_security_integration.py @@ -11,73 +11,76 @@ class TestSecurityIntegration: """Test security middleware integration with the main application.""" - + @pytest.fixture def client(self): """Create test client.""" return TestClient(app) - + @pytest.fixture def valid_headers(self): """Valid API key headers for testing.""" return {"x-api-key": "test-api-key-for-testing-12345"} - + def test_health_endpoint_no_auth(self, client): """Test that health endpoint doesn't require authentication.""" response = client.get("/health") assert response.status_code == 200 assert response.json()["status"] == "healthy" - + def test_docs_endpoint_no_auth(self, client): """Test that docs endpoint doesn't require authentication.""" response = client.get("/docs") assert response.status_code == 200 - + def test_protected_endpoint_no_auth(self, client): """Test that protected endpoints require authentication.""" # Try to access a protected endpoint without API key response = client.get("/sessions") assert response.status_code == 401 assert "API key" in response.json()["error"] - - @patch('src.services.auth.settings') + + @patch("src.services.auth.settings") def test_protected_endpoint_invalid_auth(self, mock_settings, client): """Test protected endpoint with invalid API key.""" mock_settings.api_key = "correct-key" - + headers = {"x-api-key": "wrong-key"} response = client.get("/sessions", headers=headers) assert response.status_code == 401 - + def test_protected_endpoint_valid_auth(self, client, valid_headers): """Test protected endpoint with valid API key.""" # Use the test API key from conftest (test-api-key-for-testing-12345) response = client.get("/sessions", headers=valid_headers) # Should not be 401 (auth failure) assert response.status_code != 401 - + def test_security_headers_present(self, client): """Test that security headers are added to responses.""" response = client.get("/health") - + # Check for security headers expected_headers = [ - 'x-content-type-options', - 'x-frame-options', - 'x-xss-protection', - 'strict-transport-security', - 'content-security-policy' + "x-content-type-options", + "x-frame-options", + "x-xss-protection", + "strict-transport-security", + "content-security-policy", ] - + for header in expected_headers: assert header in response.headers - + def test_cors_headers_present(self, client): """Test that CORS headers are properly configured.""" response = client.options("/health") # CORS headers should be present for OPTIONS requests - assert response.status_code in [200, 405] # Either allowed or method not allowed - + assert response.status_code in [ + 200, + 405, + ] # Either allowed or method not allowed + def test_authorization_header_fallback(self, client): """Test that Authorization header works as fallback for API key.""" # Use the test API key from conftest @@ -85,31 +88,29 @@ def test_authorization_header_fallback(self, client): response = client.get("/sessions", headers=headers) # Should not be 401 (auth failure) assert response.status_code != 401 - + def test_request_size_limit(self, client): """Test request size limiting.""" # Create a large payload (this is a basic test) large_data = {"data": "x" * 1000} - + response = client.post( - "/sessions", - json=large_data, - headers={"x-api-key": "test-key"} + "/sessions", json=large_data, headers={"x-api-key": "test-key"} ) - + # Should either process or fail with auth, not with size limit for this small payload assert response.status_code != 413 - + def test_invalid_content_type(self, client): """Test content type validation.""" headers = { "x-api-key": "test-key", - "content-type": "application/xml" # Not allowed + "content-type": "application/xml", # Not allowed } - + response = client.post("/sessions", data="", headers=headers) assert response.status_code == 415 # Unsupported Media Type - + def test_multiple_auth_methods(self, client): """Test that multiple authentication methods work.""" test_key = "test-api-key-for-testing-12345" @@ -118,12 +119,14 @@ def test_multiple_auth_methods(self, client): response1 = client.get("/sessions", headers={"x-api-key": test_key}) # Test Authorization Bearer header - response2 = client.get("/sessions", headers={"Authorization": f"Bearer {test_key}"}) + response2 = client.get( + "/sessions", headers={"Authorization": f"Bearer {test_key}"} + ) # Both should have same result (not 401) assert response1.status_code == response2.status_code assert response1.status_code != 401 - + def test_case_insensitive_headers(self, client): """Test that header names are case insensitive.""" test_key = "test-api-key-for-testing-12345" @@ -132,7 +135,7 @@ def test_case_insensitive_headers(self, client): headers_variations = [ {"X-API-KEY": test_key}, {"x-api-key": test_key}, - {"X-Api-Key": test_key} + {"X-Api-Key": test_key}, ] for headers in headers_variations: @@ -142,53 +145,53 @@ def test_case_insensitive_headers(self, client): class TestRateLimitingIntegration: """Test rate limiting integration.""" - + @pytest.fixture def client(self): """Create test client.""" return TestClient(app) - - @patch('src.services.auth.settings') + + @patch("src.services.auth.settings") def test_rate_limiting_basic(self, mock_settings, client): """Test basic rate limiting functionality.""" mock_settings.api_key = "test-api-key" - + # This test would need Redis to be properly mocked # For now, just verify the endpoint responds headers = {"x-api-key": "test-api-key"} response = client.get("/sessions", headers=headers) - + # Should not fail with rate limiting initially assert response.status_code != 429 class TestSecurityValidationIntegration: """Test security validation in real requests.""" - + @pytest.fixture def client(self): """Create test client.""" return TestClient(app) - + def test_path_traversal_protection(self, client): """Test protection against path traversal attacks.""" - with patch('src.services.auth.settings') as mock_settings: + with patch("src.services.auth.settings") as mock_settings: mock_settings.api_key = "test-api-key" - + # Try path traversal in URL malicious_paths = [ "/sessions/../../../etc/passwd", "/sessions/%2e%2e%2f%2e%2e%2fetc%2fpasswd", - "/sessions/..\\..\\windows\\system32" + "/sessions/..\\..\\windows\\system32", ] - + headers = {"x-api-key": "test-api-key"} - + for path in malicious_paths: response = client.get(path, headers=headers) # Should either be 404 (not found) or other error, not 200 assert response.status_code != 200 - + @pytest.mark.skip(reason="httpx test client doesn't support null bytes in URLs") def test_null_byte_injection(self, client): """Test protection against null byte injection.""" @@ -198,13 +201,18 @@ def test_null_byte_injection(self, client): response = client.get("/sessions/test\x00", headers=headers) # Should handle gracefully assert response.status_code in [400, 404, 422] # Bad request or not found - + def test_oversized_headers(self, client): """Test handling of oversized headers.""" # Create very large header value large_value = "x" * 10000 headers = {"x-api-key": large_value} - + response = client.get("/sessions", headers=headers) # Should either reject or handle gracefully - assert response.status_code in [400, 401, 413, 431] # Various possible error codes \ No newline at end of file + assert response.status_code in [ + 400, + 401, + 413, + 431, + ] # Various possible error codes diff --git a/tests/integration/test_session_behavior.py b/tests/integration/test_session_behavior.py index abec240..7881f00 100644 --- a/tests/integration/test_session_behavior.py +++ b/tests/integration/test_session_behavior.py @@ -29,7 +29,9 @@ def auth_headers(): return {"x-api-key": "test-api-key-for-testing-12345"} -def create_session(session_id: str, entity_id: str = None, metadata: dict = None) -> Session: +def create_session( + session_id: str, entity_id: str = None, metadata: dict = None +) -> Session: """Helper to create a session with specific properties.""" meta = metadata or {} if entity_id: @@ -41,7 +43,7 @@ def create_session(session_id: str, entity_id: str = None, metadata: dict = None created_at=datetime.now(timezone.utc), last_activity=datetime.now(timezone.utc), expires_at=datetime.now(timezone.utc) + timedelta(hours=24), - metadata=meta + metadata=meta, ) @@ -59,9 +61,9 @@ def create_execution(session_id: str, stdout: str = "output") -> CodeExecution: ExecutionOutput( type=OutputType.STDOUT, content=stdout, - timestamp=datetime.now(timezone.utc) + timestamp=datetime.now(timezone.utc), ) - ] + ], ) @@ -69,6 +71,7 @@ def create_execution(session_id: str, stdout: str = "output") -> CodeExecution: # SESSION CREATION BEHAVIOR # ============================================================================= + class TestSessionCreation: """Test session creation behavior.""" @@ -83,21 +86,33 @@ def test_session_created_on_first_exec(self, client, auth_headers): mock_execution_service = AsyncMock() # execute_code returns (execution, container, new_state, state_errors, container_source) - mock_execution_service.execute_code.return_value = (mock_execution, None, None, [], "pool_hit") + mock_execution_service.execute_code.return_value = ( + mock_execution, + None, + None, + [], + "pool_hit", + ) mock_file_service = AsyncMock() mock_file_service.list_files.return_value = [] - from src.dependencies.services import get_session_service, get_execution_service, get_file_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service try: - response = client.post("/exec", json={ - "code": "print('hello')", - "lang": "py" - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "print('hello')", "lang": "py"}, + headers=auth_headers, + ) assert response.status_code == 200 data = response.json() @@ -123,27 +138,41 @@ def test_session_created_with_entity_id(self, client, auth_headers): mock_execution_service = AsyncMock() # execute_code returns (execution, container, new_state, state_errors, container_source) - mock_execution_service.execute_code.return_value = (mock_execution, None, None, [], "pool_hit") + mock_execution_service.execute_code.return_value = ( + mock_execution, + None, + None, + [], + "pool_hit", + ) mock_file_service = AsyncMock() mock_file_service.list_files.return_value = [] - from src.dependencies.services import get_session_service, get_execution_service, get_file_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service try: - response = client.post("/exec", json={ - "code": "print('hello')", - "lang": "py", - "entity_id": entity_id - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "print('hello')", "lang": "py", "entity_id": entity_id}, + headers=auth_headers, + ) assert response.status_code == 200 # Verify session was created (entity_id is used for lookup, not stored in metadata) - assert mock_session_service.create_session.called or mock_session_service.get_session.called + assert ( + mock_session_service.create_session.called + or mock_session_service.get_session.called + ) # Response should contain a session_id assert "session_id" in response.json() finally: @@ -161,22 +190,33 @@ def test_session_created_with_user_id(self, client, auth_headers): mock_execution_service = AsyncMock() # execute_code returns (execution, container, new_state, state_errors, container_source) - mock_execution_service.execute_code.return_value = (mock_execution, None, None, [], "pool_hit") + mock_execution_service.execute_code.return_value = ( + mock_execution, + None, + None, + [], + "pool_hit", + ) mock_file_service = AsyncMock() mock_file_service.list_files.return_value = [] - from src.dependencies.services import get_session_service, get_execution_service, get_file_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service try: - response = client.post("/exec", json={ - "code": "print('hello')", - "lang": "py", - "user_id": user_id - }, headers=auth_headers) + response = client.post( + "/exec", + json={"code": "print('hello')", "lang": "py", "user_id": user_id}, + headers=auth_headers, + ) assert response.status_code == 200 finally: @@ -187,6 +227,7 @@ def test_session_created_with_user_id(self, client, auth_headers): # SESSION REUSE BEHAVIOR # ============================================================================= + class TestSessionReuse: """Test session reuse behavior.""" @@ -203,30 +244,41 @@ def test_session_reused_with_same_entity_id(self, client, auth_headers): mock_execution_service = AsyncMock() # execute_code returns (execution, container, new_state, state_errors, container_source) - mock_execution_service.execute_code.return_value = (mock_execution, None, None, [], "pool_hit") + mock_execution_service.execute_code.return_value = ( + mock_execution, + None, + None, + [], + "pool_hit", + ) mock_file_service = AsyncMock() mock_file_service.list_files.return_value = [] - from src.dependencies.services import get_session_service, get_execution_service, get_file_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service try: # First execution - response1 = client.post("/exec", json={ - "code": "x = 1", - "lang": "py", - "entity_id": entity_id - }, headers=auth_headers) + response1 = client.post( + "/exec", + json={"code": "x = 1", "lang": "py", "entity_id": entity_id}, + headers=auth_headers, + ) # Second execution with same entity - response2 = client.post("/exec", json={ - "code": "print(x)", - "lang": "py", - "entity_id": entity_id - }, headers=auth_headers) + response2 = client.post( + "/exec", + json={"code": "print(x)", "lang": "py", "entity_id": entity_id}, + headers=auth_headers, + ) assert response1.status_code == 200 assert response2.status_code == 200 @@ -251,29 +303,34 @@ def test_different_entity_gets_different_session(self, client, auth_headers): # execute_code returns (execution, container, new_state, state_errors, container_source) mock_execution_service.execute_code.side_effect = [ (create_execution("session-1"), None, None, [], "pool_hit"), - (create_execution("session-2"), None, None, [], "pool_hit") + (create_execution("session-2"), None, None, [], "pool_hit"), ] mock_file_service = AsyncMock() mock_file_service.list_files.return_value = [] - from src.dependencies.services import get_session_service, get_execution_service, get_file_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service try: - response1 = client.post("/exec", json={ - "code": "print('1')", - "lang": "py", - "entity_id": "entity-1" - }, headers=auth_headers) - - response2 = client.post("/exec", json={ - "code": "print('2')", - "lang": "py", - "entity_id": "entity-2" - }, headers=auth_headers) + response1 = client.post( + "/exec", + json={"code": "print('1')", "lang": "py", "entity_id": "entity-1"}, + headers=auth_headers, + ) + + response2 = client.post( + "/exec", + json={"code": "print('2')", "lang": "py", "entity_id": "entity-2"}, + headers=auth_headers, + ) assert response1.status_code == 200 assert response2.status_code == 200 @@ -288,10 +345,13 @@ def test_different_entity_gets_different_session(self, client, auth_headers): # FILE PERSISTENCE BEHAVIOR # ============================================================================= + class TestFilePersistence: """Test file persistence across executions.""" - @pytest.mark.skip(reason="Requires full integration testing with real services - complex multi-step file flow") + @pytest.mark.skip( + reason="Requires full integration testing with real services - complex multi-step file flow" + ) def test_uploaded_file_available_in_execution(self, client, auth_headers): """Test that uploaded files are available during execution.""" session_id = "file-test-session" @@ -306,7 +366,13 @@ def test_uploaded_file_available_in_execution(self, client, auth_headers): mock_execution_service = AsyncMock() # execute_code returns (execution, container, new_state, state_errors, container_source) - mock_execution_service.execute_code.return_value = (mock_execution, None, None, [], "pool_hit") + mock_execution_service.execute_code.return_value = ( + mock_execution, + None, + None, + [], + "pool_hit", + ) mock_file_service = AsyncMock() mock_file_service.store_uploaded_file.return_value = file_id @@ -317,7 +383,7 @@ def test_uploaded_file_available_in_execution(self, client, auth_headers): size=100, content_type="text/plain", created_at=datetime.utcnow(), - path="/data.txt" + path="/data.txt", ) ] mock_file_service.get_file_info.return_value = FileInfo( @@ -326,11 +392,16 @@ def test_uploaded_file_available_in_execution(self, client, auth_headers): size=100, content_type="text/plain", created_at=datetime.utcnow(), - path="/data.txt" + path="/data.txt", ) mock_file_service.get_file_content.return_value = b"test content" - from src.dependencies.services import get_session_service, get_execution_service, get_file_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service @@ -340,21 +411,29 @@ def test_uploaded_file_available_in_execution(self, client, auth_headers): files = {"files": ("data.txt", io.BytesIO(b"test content"), "text/plain")} data = {"entity_id": "file-test-entity"} - upload_response = client.post("/files/upload", files=files, data=data, headers=auth_headers) + upload_response = client.post( + "/files/upload", files=files, data=data, headers=auth_headers + ) assert upload_response.status_code == 200 uploaded_file = upload_response.json()["files"][0] # Execute code that references the file - exec_response = client.post("/exec", json={ - "code": "with open('data.txt') as f: print(f.read())", - "lang": "py", - "entity_id": "file-test-entity", - "files": [{ - "id": uploaded_file["id"], - "session_id": uploaded_file["session_id"], - "name": "data.txt" - }] - }, headers=auth_headers) + exec_response = client.post( + "/exec", + json={ + "code": "with open('data.txt') as f: print(f.read())", + "lang": "py", + "entity_id": "file-test-entity", + "files": [ + { + "id": uploaded_file["id"], + "session_id": uploaded_file["session_id"], + "name": "data.txt", + } + ], + }, + headers=auth_headers, + ) assert exec_response.status_code == 200 @@ -365,7 +444,9 @@ def test_uploaded_file_available_in_execution(self, client, auth_headers): finally: app.dependency_overrides.clear() - @pytest.mark.skip(reason="Requires full integration testing with real services - complex multi-step file flow") + @pytest.mark.skip( + reason="Requires full integration testing with real services - complex multi-step file flow" + ) def test_generated_file_downloadable(self, client, auth_headers): """Test that files generated during execution can be downloaded.""" session_id = "gen-file-session" @@ -387,9 +468,9 @@ def test_generated_file_downloadable(self, client, auth_headers): content="/workspace/output.txt", mime_type="text/plain", size=50, - timestamp=datetime.now(timezone.utc) + timestamp=datetime.now(timezone.utc), ) - ] + ], ) mock_session_service = AsyncMock() @@ -398,7 +479,13 @@ def test_generated_file_downloadable(self, client, auth_headers): mock_execution_service = AsyncMock() # execute_code returns (execution, container, new_state, state_errors, container_source) - mock_execution_service.execute_code.return_value = (execution_with_file, None, None, [], "pool_hit") + mock_execution_service.execute_code.return_value = ( + execution_with_file, + None, + None, + [], + "pool_hit", + ) mock_file_service = AsyncMock() mock_file_service.list_files.return_value = [ @@ -408,22 +495,31 @@ def test_generated_file_downloadable(self, client, auth_headers): size=50, content_type="text/plain", created_at=datetime.utcnow(), - path="/output.txt" + path="/output.txt", ) ] mock_file_service.download_file.return_value = "https://minio.test/download" - from src.dependencies.services import get_session_service, get_execution_service, get_file_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service try: # Execute code that generates a file - exec_response = client.post("/exec", json={ - "code": "with open('output.txt', 'w') as f: f.write('generated')", - "lang": "py" - }, headers=auth_headers) + exec_response = client.post( + "/exec", + json={ + "code": "with open('output.txt', 'w') as f: f.write('generated')", + "lang": "py", + }, + headers=auth_headers, + ) assert exec_response.status_code == 200 generated_files = exec_response.json()["files"] @@ -434,7 +530,7 @@ def test_generated_file_downloadable(self, client, auth_headers): download_response = client.get( f"/files/download/{session_id}/{file_ref['id']}", headers=auth_headers, - follow_redirects=False + follow_redirects=False, ) assert download_response.status_code == 302 @@ -446,6 +542,7 @@ def test_generated_file_downloadable(self, client, auth_headers): # SESSION ISOLATION BEHAVIOR # ============================================================================= + class TestSessionIsolation: """Test session isolation between different users/entities.""" @@ -458,10 +555,11 @@ def test_sessions_isolated_between_users(self, client, auth_headers): mock_session_service = AsyncMock() mock_session_service.list_sessions_by_entity.side_effect = [ [session1], # For entity-1 - [session2] # For entity-2 + [session2], # For entity-2 ] from src.dependencies.services import get_session_service + app.dependency_overrides[get_session_service] = lambda: mock_session_service try: @@ -477,13 +575,13 @@ def test_files_not_accessible_cross_session(self, client, auth_headers): mock_file_service.get_file_info.return_value = None # Not found from src.dependencies.services import get_file_service + app.dependency_overrides[get_file_service] = lambda: mock_file_service try: # Try to access file from different session response = client.get( - "/files/download/other-session/some-file-id", - headers=auth_headers + "/files/download/other-session/some-file-id", headers=auth_headers ) # Should not find the file @@ -496,6 +594,7 @@ def test_files_not_accessible_cross_session(self, client, auth_headers): # SESSION ID STABILITY # ============================================================================= + class TestSessionIdStability: """Test that session IDs remain stable across requests.""" @@ -512,12 +611,23 @@ def test_session_id_consistent_in_response(self, client, auth_headers): mock_execution_service = AsyncMock() # execute_code returns (execution, container, new_state, state_errors, container_source) - mock_execution_service.execute_code.return_value = (execution, None, None, [], "pool_hit") + mock_execution_service.execute_code.return_value = ( + execution, + None, + None, + [], + "pool_hit", + ) mock_file_service = AsyncMock() mock_file_service.list_files.return_value = [] - from src.dependencies.services import get_session_service, get_execution_service, get_file_service + from src.dependencies.services import ( + get_session_service, + get_execution_service, + get_file_service, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_service app.dependency_overrides[get_execution_service] = lambda: mock_execution_service app.dependency_overrides[get_file_service] = lambda: mock_file_service @@ -526,11 +636,15 @@ def test_session_id_consistent_in_response(self, client, auth_headers): # Multiple executions responses = [] for i in range(3): - response = client.post("/exec", json={ - "code": f"print({i})", - "lang": "py", - "entity_id": "stable-entity" - }, headers=auth_headers) + response = client.post( + "/exec", + json={ + "code": f"print({i})", + "lang": "py", + "entity_id": "stable-entity", + }, + headers=auth_headers, + ) responses.append(response) # All should return the same session_id @@ -539,7 +653,9 @@ def test_session_id_consistent_in_response(self, client, auth_headers): finally: app.dependency_overrides.clear() - @pytest.mark.skip(reason="Requires full integration testing with real services - complex file ID stability verification") + @pytest.mark.skip( + reason="Requires full integration testing with real services - complex file ID stability verification" + ) def test_file_ids_stable_across_requests(self, client, auth_headers): """Test that file IDs remain stable.""" stable_file_id = "stable-file-xyz789" @@ -553,7 +669,7 @@ def test_file_ids_stable_across_requests(self, client, auth_headers): size=100, content_type="text/plain", created_at=datetime.utcnow(), - path="/stable.txt" + path="/stable.txt", ) ] mock_file_service.get_file_info.return_value = FileInfo( @@ -562,21 +678,26 @@ def test_file_ids_stable_across_requests(self, client, auth_headers): size=100, content_type="text/plain", created_at=datetime.utcnow(), - path="/stable.txt" + path="/stable.txt", ) from src.dependencies.services import get_file_service + app.dependency_overrides[get_file_service] = lambda: mock_file_service try: # Upload a file files = {"files": ("stable.txt", io.BytesIO(b"content"), "text/plain")} - upload_response = client.post("/files/upload", files=files, headers=auth_headers) + upload_response = client.post( + "/files/upload", files=files, headers=auth_headers + ) uploaded_id = upload_response.json()["files"][0]["id"] # List files - should show same ID - list_response = client.get("/files/files/temp-session", headers=auth_headers) + list_response = client.get( + "/files/files/temp-session", headers=auth_headers + ) listed_id = list_response.json()[0]["id"] assert uploaded_id == listed_id diff --git a/tests/integration/test_state_api.py b/tests/integration/test_state_api.py index 70127a1..74c0af8 100644 --- a/tests/integration/test_state_api.py +++ b/tests/integration/test_state_api.py @@ -42,7 +42,9 @@ def client(mock_state_service, mock_state_archival_service): """Create test client with mocked services.""" # Override dependencies app.dependency_overrides[get_state_service] = lambda: mock_state_service - app.dependency_overrides[get_state_archival_service] = lambda: mock_state_archival_service + app.dependency_overrides[get_state_archival_service] = ( + lambda: mock_state_archival_service + ) client = TestClient(app) yield client @@ -60,7 +62,9 @@ def auth_headers(): class TestDownloadState: """Tests for GET /state/{session_id}.""" - def test_download_nonexistent_state_returns_404(self, client, auth_headers, mock_state_service): + def test_download_nonexistent_state_returns_404( + self, client, auth_headers, mock_state_service + ): """Test that downloading nonexistent state returns 404.""" # mock_state_service already returns None by default response = client.get("/state/nonexistent-session", headers=auth_headers) @@ -70,7 +74,9 @@ def test_download_nonexistent_state_returns_404(self, client, auth_headers, mock # Error handler stringifies the detail dict assert "state_not_found" in data["error"] - def test_download_state_returns_etag(self, client, auth_headers, mock_state_service): + def test_download_state_returns_etag( + self, client, auth_headers, mock_state_service + ): """Test that downloading state returns ETag header.""" # Setup mock with state raw_bytes = b"\x02test state data" @@ -101,7 +107,9 @@ def test_if_none_match_returns_304(self, client, auth_headers, mock_state_servic class TestUploadState: """Tests for POST /state/{session_id}.""" - def test_upload_valid_state_returns_201(self, client, auth_headers, mock_state_service): + def test_upload_valid_state_returns_201( + self, client, auth_headers, mock_state_service + ): """Test that uploading valid state returns 201.""" # Create valid state blob (version 2 + some data) raw_bytes = b"\x02fake lz4 compressed data here" @@ -109,7 +117,7 @@ def test_upload_valid_state_returns_201(self, client, auth_headers, mock_state_s response = client.post( "/state/test-session", content=raw_bytes, - headers={**auth_headers, "Content-Type": "application/octet-stream"} + headers={**auth_headers, "Content-Type": "application/octet-stream"}, ) assert response.status_code == 201 @@ -117,7 +125,9 @@ def test_upload_valid_state_returns_201(self, client, auth_headers, mock_state_s assert data["message"] == "state_uploaded" assert data["size"] == len(raw_bytes) - def test_upload_invalid_version_returns_400(self, client, auth_headers, mock_state_service): + def test_upload_invalid_version_returns_400( + self, client, auth_headers, mock_state_service + ): """Test that invalid version byte returns 400.""" # Version 99 is invalid raw_bytes = b"\x63invalid version data" @@ -125,7 +135,7 @@ def test_upload_invalid_version_returns_400(self, client, auth_headers, mock_sta response = client.post( "/state/test-session", content=raw_bytes, - headers={**auth_headers, "Content-Type": "application/octet-stream"} + headers={**auth_headers, "Content-Type": "application/octet-stream"}, ) assert response.status_code == 400 @@ -133,14 +143,16 @@ def test_upload_invalid_version_returns_400(self, client, auth_headers, mock_sta # Error handler stringifies the detail dict assert "invalid_state" in data["error"] - def test_upload_too_short_returns_400(self, client, auth_headers, mock_state_service): + def test_upload_too_short_returns_400( + self, client, auth_headers, mock_state_service + ): """Test that state < 2 bytes returns 400.""" raw_bytes = b"\x02" # Only 1 byte response = client.post( "/state/test-session", content=raw_bytes, - headers={**auth_headers, "Content-Type": "application/octet-stream"} + headers={**auth_headers, "Content-Type": "application/octet-stream"}, ) assert response.status_code == 400 @@ -149,7 +161,9 @@ def test_upload_too_short_returns_400(self, client, auth_headers, mock_state_ser class TestGetStateInfo: """Tests for GET /state/{session_id}/info.""" - def test_info_nonexistent_returns_exists_false(self, client, auth_headers, mock_state_service): + def test_info_nonexistent_returns_exists_false( + self, client, auth_headers, mock_state_service + ): """Test that info for nonexistent state returns exists=false.""" response = client.get("/state/nonexistent/info", headers=auth_headers) @@ -157,13 +171,15 @@ def test_info_nonexistent_returns_exists_false(self, client, auth_headers, mock_ data = response.json() assert data["exists"] is False - def test_info_existing_state_returns_metadata(self, client, auth_headers, mock_state_service): + def test_info_existing_state_returns_metadata( + self, client, auth_headers, mock_state_service + ): """Test that info for existing state returns metadata.""" mock_state_service.get_full_state_info.return_value = { "size_bytes": 1024, "hash": "abc123", "created_at": "2025-12-21T10:00:00+00:00", - "expires_at": "2025-12-21T12:00:00+00:00" + "expires_at": "2025-12-21T12:00:00+00:00", } response = client.get("/state/test-session/info", headers=auth_headers) @@ -174,7 +190,9 @@ def test_info_existing_state_returns_metadata(self, client, auth_headers, mock_s assert data["source"] == "redis" assert data["size_bytes"] == 1024 - def test_info_archived_state_returns_archive_source(self, client, auth_headers, mock_state_service, mock_state_archival_service): + def test_info_archived_state_returns_archive_source( + self, client, auth_headers, mock_state_service, mock_state_archival_service + ): """Test that archived state shows source='archive'.""" mock_state_archival_service.has_archived_state.return_value = True @@ -195,7 +213,9 @@ def test_delete_returns_204(self, client, auth_headers, mock_state_service): assert response.status_code == 204 - def test_delete_nonexistent_still_returns_204(self, client, auth_headers, mock_state_service): + def test_delete_nonexistent_still_returns_204( + self, client, auth_headers, mock_state_service + ): """Test that deleting nonexistent state still returns 204.""" response = client.delete("/state/nonexistent", headers=auth_headers) @@ -217,7 +237,7 @@ def test_exec_response_includes_state_fields_for_python(self, client, auth_heade stderr="", has_state=True, state_size=1024, - state_hash="abc123" + state_hash="abc123", ) assert response.has_state is True @@ -228,11 +248,7 @@ def test_exec_response_defaults_state_fields(self): """Test that state fields have correct defaults.""" from src.models.exec import ExecResponse - response = ExecResponse( - session_id="test-session", - stdout="", - stderr="" - ) + response = ExecResponse(session_id="test-session", stdout="", stderr="") assert response.has_state is False assert response.state_size is None diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 07c9273..4a5d263 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -1 +1 @@ -# Unit tests package \ No newline at end of file +# Unit tests package diff --git a/tests/unit/test_output_processor.py b/tests/unit/test_output_processor.py index 6b78c60..fffcf2d 100644 --- a/tests/unit/test_output_processor.py +++ b/tests/unit/test_output_processor.py @@ -102,18 +102,3 @@ def test_long_filename_truncated(self): assert result.endswith(".txt") # Should have a random suffix before extension assert "-" in result - - -class TestNormalizeFilename: - """Tests for the deprecated normalize_filename method.""" - - def test_delegates_to_sanitize_filename(self): - """Test that normalize_filename delegates to sanitize_filename.""" - result = OutputProcessor.normalize_filename("file with spaces.txt") - expected = OutputProcessor.sanitize_filename("file with spaces.txt") - assert result == expected - - def test_parentheses_now_replaced(self): - """Test that normalize_filename now also replaces parentheses.""" - result = OutputProcessor.normalize_filename("file (v2).xlsx") - assert result == "file__v2_.xlsx" diff --git a/tests/unit/test_session_service.py b/tests/unit/test_session_service.py index d0534c5..3d961e9 100644 --- a/tests/unit/test_session_service.py +++ b/tests/unit/test_session_service.py @@ -14,7 +14,7 @@ def mock_redis(): """Create a mock Redis client.""" redis_mock = AsyncMock() - + # Mock pipeline pipeline_mock = AsyncMock() pipeline_mock.hset = MagicMock() @@ -24,7 +24,7 @@ def mock_redis(): pipeline_mock.srem = MagicMock() pipeline_mock.execute = AsyncMock(return_value=[True, True, True]) pipeline_mock.reset = AsyncMock() - + # Make pipeline() return the pipeline mock when awaited redis_mock.pipeline = AsyncMock(return_value=pipeline_mock) return redis_mock @@ -40,16 +40,16 @@ def session_service(mock_redis): async def test_create_session(session_service, mock_redis): """Test session creation.""" request = SessionCreate(metadata={"test": "value"}) - + session = await session_service.create_session(request) - + assert session.session_id is not None assert session.status == SessionStatus.ACTIVE assert session.metadata == {"test": "value"} assert isinstance(session.created_at, datetime) assert isinstance(session.expires_at, datetime) assert session.expires_at > session.created_at - + # Verify Redis operations were called mock_redis.pipeline.assert_called_once() @@ -66,20 +66,20 @@ async def test_get_session_exists(session_service, mock_redis): "expires_at": "2023-01-02T00:00:00", "files": "{}", "metadata": '{"test": "value"}', - "working_directory": "/workspace" + "working_directory": "/workspace", } - + mock_redis.hgetall.return_value = session_data mock_redis.hset = AsyncMock() - + session = await session_service.get_session(session_id) - + assert session is not None assert session.session_id == session_id assert session.status == SessionStatus.ACTIVE assert session.metadata == {"test": "value"} assert session.files == {} - + # Verify last activity was updated mock_redis.hset.assert_called_once() @@ -88,9 +88,9 @@ async def test_get_session_exists(session_service, mock_redis): async def test_get_session_not_exists(session_service, mock_redis): """Test retrieving a non-existent session.""" mock_redis.hgetall.return_value = {} - + session = await session_service.get_session("non-existent") - + assert session is None @@ -98,11 +98,11 @@ async def test_get_session_not_exists(session_service, mock_redis): async def test_update_session(session_service, mock_redis): """Test updating a session.""" session_id = "test-session-id" - + # Mock session exists mock_redis.exists.return_value = True mock_redis.hset = AsyncMock() - + # Mock get_session to return updated session updated_session_data = { "session_id": session_id, @@ -112,15 +112,17 @@ async def test_update_session(session_service, mock_redis): "expires_at": "2023-01-02T00:00:00", "files": "{}", "metadata": "{}", - "working_directory": "/workspace" + "working_directory": "/workspace", } mock_redis.hgetall.return_value = updated_session_data - - session = await session_service.update_session(session_id, status=SessionStatus.IDLE) - + + session = await session_service.update_session( + session_id, status=SessionStatus.IDLE + ) + assert session is not None assert session.session_id == session_id - + # Verify Redis update was called mock_redis.hset.assert_called() @@ -129,9 +131,11 @@ async def test_update_session(session_service, mock_redis): async def test_update_session_not_exists(session_service, mock_redis): """Test updating a non-existent session.""" mock_redis.exists.return_value = False - - session = await session_service.update_session("non-existent", status=SessionStatus.IDLE) - + + session = await session_service.update_session( + "non-existent", status=SessionStatus.IDLE + ) + assert session is None @@ -139,13 +143,13 @@ async def test_update_session_not_exists(session_service, mock_redis): async def test_delete_session(session_service, mock_redis): """Test deleting a session.""" session_id = "test-session-id" - + # The pipeline mock is already set up in the fixture pipeline_mock = mock_redis.pipeline.return_value pipeline_mock.execute.return_value = [1, 1] # Both operations successful - + result = await session_service.delete_session(session_id) - + assert result is True pipeline_mock.delete.assert_called_once() pipeline_mock.srem.assert_called_once() @@ -156,7 +160,7 @@ async def test_list_sessions(session_service, mock_redis): """Test listing sessions.""" session_ids = ["session1", "session2", "session3"] mock_redis.smembers.return_value = session_ids - + # Mock get_session to return valid sessions def mock_hgetall(key): session_id = key.split(":")[-1] # Extract session ID from key @@ -168,14 +172,14 @@ def mock_hgetall(key): "expires_at": "2023-01-02T00:00:00+00:00", "files": "{}", "metadata": "{}", - "working_directory": "/workspace" + "working_directory": "/workspace", } - + mock_redis.hgetall.side_effect = mock_hgetall mock_redis.hset = AsyncMock() - + sessions = await session_service.list_sessions(limit=2) - + assert len(sessions) == 2 # Limited to 2 assert all(isinstance(s, Session) for s in sessions) @@ -185,7 +189,7 @@ async def test_cleanup_expired_sessions(session_service, mock_redis): """Test cleaning up expired sessions.""" session_ids = ["expired1", "expired2", "active1"] mock_redis.smembers.return_value = session_ids - + # Mock sessions - some expired, some active def mock_get_session(session_id): if session_id.startswith("expired"): @@ -194,7 +198,7 @@ def mock_get_session(session_id): status=SessionStatus.ACTIVE, created_at=datetime.now(timezone.utc) - timedelta(days=2), last_activity=datetime.now(timezone.utc) - timedelta(days=2), - expires_at=datetime.now(timezone.utc) - timedelta(hours=1) # Expired + expires_at=datetime.now(timezone.utc) - timedelta(hours=1), # Expired ) else: return Session( @@ -202,16 +206,16 @@ def mock_get_session(session_id): status=SessionStatus.ACTIVE, created_at=datetime.now(timezone.utc), last_activity=datetime.now(timezone.utc), - expires_at=datetime.now(timezone.utc) + timedelta(hours=1) # Active + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), # Active ) - + # The pipeline mock is already set up in the fixture pipeline_mock = mock_redis.pipeline.return_value pipeline_mock.execute.return_value = [1, 1] - - with patch.object(session_service, 'get_session', side_effect=mock_get_session): + + with patch.object(session_service, "get_session", side_effect=mock_get_session): cleaned_count = await session_service.cleanup_expired_sessions() - + assert cleaned_count == 2 # Two expired sessions cleaned @@ -229,7 +233,7 @@ async def test_generate_session_id(session_service): session_id = session_service._generate_session_id() assert isinstance(session_id, str) assert len(session_id) > 0 - + # Generate another to ensure uniqueness session_id2 = session_service._generate_session_id() assert session_id != session_id2 @@ -242,7 +246,7 @@ async def test_cleanup_task_lifecycle(session_service, mock_redis): await session_service.start_cleanup_task() assert session_service._cleanup_task is not None assert not session_service._cleanup_task.done() - + # Stop cleanup task await session_service.stop_cleanup_task() assert session_service._cleanup_task.done() @@ -252,13 +256,13 @@ async def test_cleanup_task_lifecycle(session_service, mock_redis): async def test_create_session_with_entity_id(session_service, mock_redis): """Test session creation with entity_id.""" request = SessionCreate(metadata={"entity_id": "test-entity", "test": "value"}) - + session = await session_service.create_session(request) - + assert session.session_id is not None assert session.status == SessionStatus.ACTIVE assert session.metadata == {"entity_id": "test-entity", "test": "value"} - + # Verify Redis operations were called including entity grouping mock_redis.pipeline.assert_called_once() @@ -268,9 +272,9 @@ async def test_list_sessions_by_entity(session_service, mock_redis): """Test listing sessions by entity ID.""" entity_id = "test-entity" session_ids = ["session1", "session2"] - + mock_redis.smembers.return_value = session_ids - + # Mock get_session to return valid sessions def mock_hgetall(key): session_id = key.split(":")[-1] @@ -282,17 +286,17 @@ def mock_hgetall(key): "expires_at": "2023-01-02T00:00:00+00:00", "files": "{}", "metadata": '{"entity_id": "test-entity"}', - "working_directory": "/workspace" + "working_directory": "/workspace", } - + mock_redis.hgetall.side_effect = mock_hgetall mock_redis.hset = AsyncMock() - + sessions = await session_service.list_sessions_by_entity(entity_id) - + assert len(sessions) == 2 assert all(isinstance(s, Session) for s in sessions) - assert all(s.metadata.get('entity_id') == entity_id for s in sessions) + assert all(s.metadata.get("entity_id") == entity_id for s in sessions) @pytest.mark.asyncio @@ -300,7 +304,7 @@ async def test_validate_session_access_success(session_service, mock_redis): """Test successful session access validation.""" session_id = "test-session" entity_id = "test-entity" - + session_data = { "session_id": session_id, "status": "active", @@ -309,14 +313,14 @@ async def test_validate_session_access_success(session_service, mock_redis): "expires_at": "2023-01-02T00:00:00+00:00", "files": "{}", "metadata": f'{{"entity_id": "{entity_id}"}}', - "working_directory": "/workspace" + "working_directory": "/workspace", } - + mock_redis.hgetall.return_value = session_data mock_redis.hset = AsyncMock() - + result = await session_service.validate_session_access(session_id, entity_id) - + assert result is True @@ -325,7 +329,7 @@ async def test_validate_session_access_wrong_entity(session_service, mock_redis) """Test session access validation with wrong entity ID.""" session_id = "test-session" entity_id = "wrong-entity" - + session_data = { "session_id": session_id, "status": "active", @@ -334,14 +338,14 @@ async def test_validate_session_access_wrong_entity(session_service, mock_redis) "expires_at": "2023-01-02T00:00:00+00:00", "files": "{}", "metadata": '{"entity_id": "test-entity"}', - "working_directory": "/workspace" + "working_directory": "/workspace", } - + mock_redis.hgetall.return_value = session_data mock_redis.hset = AsyncMock() - + result = await session_service.validate_session_access(session_id, entity_id) - + assert result is False @@ -349,9 +353,11 @@ async def test_validate_session_access_wrong_entity(session_service, mock_redis) async def test_validate_session_access_no_session(session_service, mock_redis): """Test session access validation when session doesn't exist.""" mock_redis.hgetall.return_value = {} - - result = await session_service.validate_session_access("non-existent", "test-entity") - + + result = await session_service.validate_session_access( + "non-existent", "test-entity" + ) + assert result is False @@ -360,9 +366,9 @@ async def test_get_session_files_access_success(session_service, mock_redis): """Test successful session files access validation.""" session_id = "test-session" entity_id = "test-entity" - + # Mock validate_session_access to return True - with patch.object(session_service, 'validate_session_access', return_value=True): + with patch.object(session_service, "validate_session_access", return_value=True): # Mock list_sessions_by_entity to return sessions including the target session mock_sessions = [ Session( @@ -371,12 +377,16 @@ async def test_get_session_files_access_success(session_service, mock_redis): created_at=datetime.now(timezone.utc), last_activity=datetime.now(timezone.utc), expires_at=datetime.now(timezone.utc) + timedelta(hours=1), - metadata={"entity_id": entity_id} + metadata={"entity_id": entity_id}, ) ] - with patch.object(session_service, 'list_sessions_by_entity', return_value=mock_sessions): - result = await session_service.get_session_files_access(session_id, entity_id) - + with patch.object( + session_service, "list_sessions_by_entity", return_value=mock_sessions + ): + result = await session_service.get_session_files_access( + session_id, entity_id + ) + assert result is True @@ -384,9 +394,11 @@ async def test_get_session_files_access_success(session_service, mock_redis): async def test_get_session_files_access_invalid_session(session_service, mock_redis): """Test session files access validation with invalid session.""" # Mock validate_session_access to return False - with patch.object(session_service, 'validate_session_access', return_value=False): - result = await session_service.get_session_files_access("invalid-session", "test-entity") - + with patch.object(session_service, "validate_session_access", return_value=False): + result = await session_service.get_session_files_access( + "invalid-session", "test-entity" + ) + assert result is False @@ -395,7 +407,7 @@ async def test_delete_session_with_entity_cleanup(session_service, mock_redis): """Test deleting a session with entity cleanup.""" session_id = "test-session-id" entity_id = "test-entity" - + # Mock get_session to return a session with entity_id session_data = { "session_id": session_id, @@ -405,17 +417,17 @@ async def test_delete_session_with_entity_cleanup(session_service, mock_redis): "expires_at": "2023-01-02T00:00:00+00:00", "files": "{}", "metadata": f'{{"entity_id": "{entity_id}"}}', - "working_directory": "/workspace" + "working_directory": "/workspace", } mock_redis.hgetall.return_value = session_data mock_redis.hset = AsyncMock() - + # The pipeline mock is already set up in the fixture pipeline_mock = mock_redis.pipeline.return_value pipeline_mock.execute.return_value = [1, 1, 1] # Three operations successful - + result = await session_service.delete_session(session_id) - + assert result is True pipeline_mock.delete.assert_called_once() pipeline_mock.srem.assert_called() # Called twice - once for session index, once for entity @@ -426,6 +438,6 @@ async def test_close(session_service, mock_redis): """Test service cleanup.""" await session_service.start_cleanup_task() await session_service.close() - + # Verify cleanup task was stopped and Redis connection closed - mock_redis.close.assert_called_once() \ No newline at end of file + mock_redis.close.assert_called_once() diff --git a/tests/unit/test_state_service.py b/tests/unit/test_state_service.py index 3dfa850..01190ee 100644 --- a/tests/unit/test_state_service.py +++ b/tests/unit/test_state_service.py @@ -27,7 +27,7 @@ def mock_redis_client(): @pytest.fixture def state_service(mock_redis_client): """Create StateService with mocked Redis.""" - with patch('src.services.state.redis_pool') as mock_pool: + with patch("src.services.state.redis_pool") as mock_pool: mock_pool.get_client.return_value = mock_redis_client service = StateService(redis_client=mock_redis_client) return service @@ -69,11 +69,13 @@ class TestSaveState: """Tests for save_state method.""" @pytest.mark.asyncio - async def test_save_state_stores_hash_and_metadata(self, state_service, mock_redis_client): + async def test_save_state_stores_hash_and_metadata( + self, state_service, mock_redis_client + ): """Test that save_state stores state, hash, and metadata.""" session_id = "test-session-123" raw_bytes = b"\x02test state data" # Version 2 prefix - state_b64 = base64.b64encode(raw_bytes).decode('utf-8') + state_b64 = base64.b64encode(raw_bytes).decode("utf-8") # Setup mock pipeline mock_pipe = AsyncMock() @@ -88,11 +90,13 @@ async def test_save_state_stores_hash_and_metadata(self, state_service, mock_red assert mock_pipe.setex.call_count == 3 @pytest.mark.asyncio - async def test_save_state_with_upload_marker(self, state_service, mock_redis_client): + async def test_save_state_with_upload_marker( + self, state_service, mock_redis_client + ): """Test that from_upload=True sets upload marker.""" session_id = "test-session-upload" raw_bytes = b"\x02uploaded state" - state_b64 = base64.b64encode(raw_bytes).decode('utf-8') + state_b64 = base64.b64encode(raw_bytes).decode("utf-8") mock_pipe = AsyncMock() mock_pipe.setex = MagicMock() @@ -121,7 +125,7 @@ async def test_get_state_raw_decodes_base64(self, state_service, mock_redis_clie """Test that get_state_raw returns decoded bytes.""" session_id = "test-session" raw_bytes = b"\x02raw binary state data" - state_b64 = base64.b64encode(raw_bytes).decode('utf-8') + state_b64 = base64.b64encode(raw_bytes).decode("utf-8") mock_redis_client.get.return_value = state_b64 @@ -130,7 +134,9 @@ async def test_get_state_raw_decodes_base64(self, state_service, mock_redis_clie assert result == raw_bytes @pytest.mark.asyncio - async def test_get_state_raw_returns_none_when_no_state(self, state_service, mock_redis_client): + async def test_get_state_raw_returns_none_when_no_state( + self, state_service, mock_redis_client + ): """Test that get_state_raw returns None when no state exists.""" mock_redis_client.get.return_value = None @@ -143,7 +149,9 @@ class TestSaveStateRaw: """Tests for save_state_raw method.""" @pytest.mark.asyncio - async def test_save_state_raw_encodes_to_base64(self, state_service, mock_redis_client): + async def test_save_state_raw_encodes_to_base64( + self, state_service, mock_redis_client + ): """Test that save_state_raw encodes bytes to base64.""" session_id = "test-session" raw_bytes = b"\x02raw data to save" @@ -162,17 +170,21 @@ class TestGetStateHash: """Tests for get_state_hash method.""" @pytest.mark.asyncio - async def test_get_state_hash_returns_string(self, state_service, mock_redis_client): + async def test_get_state_hash_returns_string( + self, state_service, mock_redis_client + ): """Test that get_state_hash returns hash string.""" expected_hash = "abc123def456" - mock_redis_client.get.return_value = expected_hash.encode('utf-8') + mock_redis_client.get.return_value = expected_hash.encode("utf-8") result = await state_service.get_state_hash("session-id") assert result == expected_hash @pytest.mark.asyncio - async def test_get_state_hash_returns_none_when_missing(self, state_service, mock_redis_client): + async def test_get_state_hash_returns_none_when_missing( + self, state_service, mock_redis_client + ): """Test that get_state_hash returns None when no hash.""" mock_redis_client.get.return_value = None @@ -185,7 +197,9 @@ class TestUploadMarker: """Tests for upload marker methods.""" @pytest.mark.asyncio - async def test_has_recent_upload_true_when_marker_exists(self, state_service, mock_redis_client): + async def test_has_recent_upload_true_when_marker_exists( + self, state_service, mock_redis_client + ): """Test that has_recent_upload returns True when marker exists.""" mock_redis_client.get.return_value = "1" @@ -194,7 +208,9 @@ async def test_has_recent_upload_true_when_marker_exists(self, state_service, mo assert result is True @pytest.mark.asyncio - async def test_has_recent_upload_false_when_no_marker(self, state_service, mock_redis_client): + async def test_has_recent_upload_false_when_no_marker( + self, state_service, mock_redis_client + ): """Test that has_recent_upload returns False when no marker.""" mock_redis_client.get.return_value = None @@ -203,7 +219,9 @@ async def test_has_recent_upload_false_when_no_marker(self, state_service, mock_ assert result is False @pytest.mark.asyncio - async def test_clear_upload_marker_deletes_key(self, state_service, mock_redis_client): + async def test_clear_upload_marker_deletes_key( + self, state_service, mock_redis_client + ): """Test that clear_upload_marker deletes the marker key.""" await state_service.clear_upload_marker("session-id") @@ -214,7 +232,9 @@ class TestDeleteState: """Tests for delete_state method.""" @pytest.mark.asyncio - async def test_delete_state_removes_all_keys(self, state_service, mock_redis_client): + async def test_delete_state_removes_all_keys( + self, state_service, mock_redis_client + ): """Test that delete_state removes state, hash, meta, and marker keys.""" session_id = "session-to-delete" @@ -231,14 +251,16 @@ class TestGetFullStateInfo: """Tests for get_full_state_info method.""" @pytest.mark.asyncio - async def test_get_full_state_info_returns_metadata(self, state_service, mock_redis_client): + async def test_get_full_state_info_returns_metadata( + self, state_service, mock_redis_client + ): """Test that get_full_state_info returns complete metadata.""" session_id = "session-with-state" meta = { "size_bytes": 1024, "hash": "abc123", "created_at": "2025-12-21T10:00:00+00:00", - "from_upload": False + "from_upload": False, } mock_pipe = AsyncMock() @@ -256,7 +278,9 @@ async def test_get_full_state_info_returns_metadata(self, state_service, mock_re assert result["expires_at"] is not None @pytest.mark.asyncio - async def test_get_full_state_info_returns_none_when_no_state(self, state_service, mock_redis_client): + async def test_get_full_state_info_returns_none_when_no_state( + self, state_service, mock_redis_client + ): """Test that get_full_state_info returns None when no state.""" mock_pipe = AsyncMock() mock_pipe.execute = AsyncMock(return_value=[0, -1, None]) From b9fe7f167f47b7094afb553cea12b4426d8a4646 Mon Sep 17 00:00:00 2001 From: Joe Licata Date: Tue, 13 Jan 2026 17:38:00 +0000 Subject: [PATCH 4/8] chore: Remove deprecated MAX_CPU_QUOTA and update configuration documentation - Removed the deprecated `MAX_CPU_QUOTA` setting from the `.env.example` file. - Updated the `CONFIGURATION.md` documentation to reflect the change, replacing `MAX_CPU_QUOTA` with `MAX_CPUS` and adjusting its default value. --- .env.example | 1 - docs/CONFIGURATION.md | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.env.example b/.env.example index dd360bf..d477502 100644 --- a/.env.example +++ b/.env.example @@ -61,7 +61,6 @@ DOCKER_READ_ONLY=true MAX_EXECUTION_TIME=120 MAX_MEMORY_MB=512 MAX_CPUS=1 -MAX_CPU_QUOTA=50000 #Deprecated MAX_PIDS=512 MAX_OPEN_FILES=1024 diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 8c10c31..4cde9c2 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -188,7 +188,7 @@ Docker is used for secure code execution in containers. | -------------------- | ------- | ---------------------------------------------------------------- | | `MAX_EXECUTION_TIME` | `30` | Maximum code execution time (seconds) | | `MAX_MEMORY_MB` | `512` | Maximum memory per execution (MB) | -| `MAX_CPU_QUOTA` | `50000` | CPU quota (100000 = 1 CPU) | +| `MAX_CPUS` | `4.0` | Maximum CPU cores available to execution containers | | `MAX_PIDS` | `512` | Per-container process limit (cgroup pids_limit, prevents fork bombs) | | `MAX_OPEN_FILES` | `1024` | Maximum open files per container | From 65ace53cf3100070cb10ae2d6416eebf97ed0674 Mon Sep 17 00:00:00 2001 From: Joe Licata Date: Tue, 13 Jan 2026 19:18:34 +0000 Subject: [PATCH 5/8] chore: Update project configuration and dependencies - Added `pyproject.toml` for pytest configuration with asyncio settings. - Updated dependencies in `requirements.txt` to newer versions for `uvicorn`, `pydantic-settings`, `httpx`, `redis`, `minio`, `pytest-asyncio`, `structlog`, `Unidecode`, and `locust`. - Updated GitHub Actions workflow to use the latest version of `setup-python`. - Updated Dockerfiles for C/C++, Go, and R to use newer base images. --- .github/workflows/lint.yml | 2 +- docker/c-cpp.Dockerfile | 2 +- docker/go.Dockerfile | 2 +- docker/r.Dockerfile | 2 +- pyproject.toml | 3 +++ requirements.txt | 18 +++++++++--------- tests/conftest.py | 8 -------- 7 files changed, 16 insertions(+), 21 deletions(-) create mode 100644 pyproject.toml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index e8b1216..521afb0 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -11,7 +11,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.11" cache: "pip" diff --git a/docker/c-cpp.Dockerfile b/docker/c-cpp.Dockerfile index a0ef371..4939f4e 100644 --- a/docker/c-cpp.Dockerfile +++ b/docker/c-cpp.Dockerfile @@ -1,7 +1,7 @@ # syntax=docker/dockerfile:1.4 # C/C++ execution environment with BuildKit optimizations # Pin to specific version for reproducibility -FROM gcc:13-bookworm +FROM gcc:15-bookworm # Install essential development tools and libraries RUN apt-get update && apt-get install -y --no-install-recommends \ diff --git a/docker/go.Dockerfile b/docker/go.Dockerfile index 5b95775..73549da 100644 --- a/docker/go.Dockerfile +++ b/docker/go.Dockerfile @@ -1,6 +1,6 @@ # syntax=docker/dockerfile:1.4 # Go execution environment with BuildKit optimizations -FROM golang:1.23-alpine +FROM golang:1.25-alpine # Install common tools RUN apk add --no-cache \ diff --git a/docker/r.Dockerfile b/docker/r.Dockerfile index 34e4cb9..116c358 100644 --- a/docker/r.Dockerfile +++ b/docker/r.Dockerfile @@ -1,6 +1,6 @@ # syntax=docker/dockerfile:1.4 # R execution environment with BuildKit optimizations -FROM r-base:4.3.0 +FROM r-base:4.4.3 # Install system dependencies for R packages (including Cairo) RUN apt-get update && apt-get install -y --no-install-recommends \ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d7ecf10 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[tool.pytest.ini_options] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "session" diff --git a/requirements.txt b/requirements.txt index 4c45b72..4ffaca7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,23 +1,23 @@ requests>=2.31.0,<3 # Core API framework fastapi==0.127.1 -uvicorn[standard]==0.30.6 +uvicorn[standard]==0.40.0 # Data validation and serialization pydantic==2.12.5 -pydantic-settings==2.5.0 +pydantic-settings==2.12.0 # HTTP client for external requests -httpx==0.27.2 +httpx==0.28.1 # Redis for session management -redis==5.1.0 +redis==7.1.0 # SQLite async support for metrics aiosqlite>=0.19.0 # MinIO/S3 client -minio==7.2.15 +minio==7.2.20 # Docker client for container management docker==7.1.0 @@ -29,7 +29,7 @@ python-dateutil==2.9.0.post0 # Testing framework pytest==9.0.2 -pytest-asyncio==0.21.1 +pytest-asyncio==1.3.0 pytest-cov==4.1.0 pytest-mock==3.12.0 @@ -42,14 +42,14 @@ mypy==1.7.1 python-dotenv==1.0.0 # Logging -structlog==23.2.0 +structlog==25.5.0 # File handling python-multipart>=0.0.18 -Unidecode==1.3.8 +Unidecode==1.4.0 # System monitoring for performance tests psutil==5.9.6 # Stress testing -locust==2.42.6 +locust==2.43.0 diff --git a/tests/conftest.py b/tests/conftest.py index d62051b..250e252 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,14 +29,6 @@ from src.models import Session, SessionCreate, SessionStatus -@pytest.fixture(scope="session") -def event_loop(): - """Create an instance of the default event loop for the test session.""" - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() - - @pytest.fixture def mock_redis(): """Mock Redis client for testing.""" From e7b48885e13d8834e920eb6420152850b8b06caf Mon Sep 17 00:00:00 2001 From: Joe Licata Date: Tue, 13 Jan 2026 19:50:02 +0000 Subject: [PATCH 6/8] fix: Update Dockerfile and command execution logic - Changed the Dockerfile ENTRYPOINT to use a temporary directory for GOCACHE instead of a mounted data directory. - Modified the command in runner.py to limit the search depth when finding files, improving performance and accuracy in file listing. --- docker/go.Dockerfile | 2 +- src/services/execution/runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/go.Dockerfile b/docker/go.Dockerfile index 73549da..3f41382 100644 --- a/docker/go.Dockerfile +++ b/docker/go.Dockerfile @@ -37,5 +37,5 @@ ENV GO111MODULE=on \ GOSUMDB=sum.golang.org # Default command with sanitized environment -ENTRYPOINT ["/usr/bin/env","-i","PATH=/usr/local/go/bin:/usr/local/bin:/usr/bin:/bin","HOME=/tmp","TMPDIR=/tmp","GO111MODULE=on","GOPROXY=https://proxy.golang.org,direct","GOSUMDB=sum.golang.org","GOCACHE=/mnt/data/go-build"] +ENTRYPOINT ["/usr/bin/env","-i","PATH=/usr/local/go/bin:/usr/local/bin:/usr/bin:/bin","HOME=/tmp","TMPDIR=/tmp","GO111MODULE=on","GOPROXY=https://proxy.golang.org,direct","GOSUMDB=sum.golang.org","GOCACHE=/tmp/go-build"] CMD ["go"] diff --git a/src/services/execution/runner.py b/src/services/execution/runner.py index efb1544..e69b519 100644 --- a/src/services/execution/runner.py +++ b/src/services/execution/runner.py @@ -630,7 +630,7 @@ async def _detect_generated_files( try: exit_code, stdout, stderr = await self.container_manager.execute_command( container, - "find /mnt/data -type f -name '*' ! -name 'code.*' ! -name 'Code.*' -exec ls -la {} \\;", + "find /mnt/data -maxdepth 1 -type f -name '*' ! -name 'code' ! -name 'code.*' ! -name 'Code.*' -exec ls -la {} \\;", timeout=5, ) From 82b25df50d8c757de8200ba0ad8920e453be8e52 Mon Sep 17 00:00:00 2001 From: Joe Licata Date: Tue, 13 Jan 2026 19:59:28 +0000 Subject: [PATCH 7/8] fix: Remove unused SecurityConfig kwargs for file extensions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit File extension validation was moved to Settings class but the security property was still passing these fields to SecurityConfig. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/config/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/config/__init__.py b/src/config/__init__.py index 8c1266e..00f8bda 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -566,8 +566,6 @@ def security(self) -> SecurityConfig: api_keys=self.api_keys if isinstance(self.api_keys, str) else None, api_key_header=self.api_key_header, api_key_cache_ttl=self.api_key_cache_ttl, - allowed_file_extensions=self.allowed_file_extensions, - blocked_file_patterns=self.blocked_file_patterns, enable_network_isolation=self.enable_network_isolation, enable_filesystem_isolation=self.enable_filesystem_isolation, enable_security_logs=self.enable_security_logs, From fa9be0474e944031ed1c40ce656555b6c18735cf Mon Sep 17 00:00:00 2001 From: Joe Licata Date: Tue, 13 Jan 2026 20:01:56 +0000 Subject: [PATCH 8/8] fix: Add type ignore comments for redis ping() type ambiguity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The redis library's ping() method has an ambiguous return type (Awaitable[bool] | bool) that mypy can't resolve. Added type ignore comments to suppress these false positives. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/services/auth.py | 2 +- src/services/metrics.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/services/auth.py b/src/services/auth.py index be9d22e..5ea0330 100644 --- a/src/services/auth.py +++ b/src/services/auth.py @@ -294,7 +294,7 @@ async def get_auth_service() -> AuthenticationService: redis_client = redis_pool.get_client() # Test connection - await redis_client.ping() + await redis_client.ping() # type: ignore[misc] logger.info("Redis connection established for authentication service") except Exception as e: logger.warning( diff --git a/src/services/metrics.py b/src/services/metrics.py index a766244..96de408 100644 --- a/src/services/metrics.py +++ b/src/services/metrics.py @@ -118,7 +118,7 @@ async def start(self) -> None: self._redis_client = redis_pool.get_client() # Test Redis connection with timeout - await asyncio.wait_for(self._redis_client.ping(), timeout=3.0) + await asyncio.wait_for(self._redis_client.ping(), timeout=3.0) # type: ignore[arg-type] # Load existing metrics from Redis await self._load_metrics_from_redis()