From 85c47c6d41785f410cace743d65e64f99c5644e8 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 4 Dec 2025 23:01:27 +0530 Subject: [PATCH 01/27] Add WebSocket support for environment interactions and enhance HTTP server capabilities - Introduced WebSocketEnvClient for persistent sessions with multi-step interactions. - Updated HTTPEnvServer to support WebSocket connections and manage multiple concurrent environments. - Added WebSocket message types and responses for better communication. - Enhanced Environment interface with concurrency safety attributes. --- src/openenv/core/__init__.py | 2 + src/openenv/core/env_server/__init__.py | 26 +- src/openenv/core/env_server/http_server.py | 293 +++++++++++++++++++- src/openenv/core/env_server/interfaces.py | 14 + src/openenv/core/env_server/types.py | 68 +++++ src/openenv/core/ws_env_client.py | 305 +++++++++++++++++++++ 6 files changed, 692 insertions(+), 16 deletions(-) create mode 100644 src/openenv/core/ws_env_client.py diff --git a/src/openenv/core/__init__.py b/src/openenv/core/__init__.py index 99507ab55..3592ead53 100644 --- a/src/openenv/core/__init__.py +++ b/src/openenv/core/__init__.py @@ -10,10 +10,12 @@ from .env_server import * from .client_types import StepResult from .http_env_client import HTTPEnvClient +from .ws_env_client import WebSocketEnvClient # Note: MCP module doesn't export anything yet __all__ = [ "HTTPEnvClient", + "WebSocketEnvClient", "StepResult", ] diff --git a/src/openenv/core/env_server/__init__.py b/src/openenv/core/env_server/__init__.py index 4e1c2d7ac..92ebbeb2d 100644 --- a/src/openenv/core/env_server/__init__.py +++ b/src/openenv/core/env_server/__init__.py @@ -15,7 +15,22 @@ deserialize_action_with_preprocessing, serialize_observation, ) -from .types import Action, Observation, State, SchemaResponse, HealthResponse +from .types import ( + Action, + Observation, + State, + SchemaResponse, + HealthResponse, + # WebSocket message types + 
WSMessage, + WSResetMessage, + WSStepMessage, + WSStateMessage, + WSCloseMessage, + WSObservationResponse, + WSStateResponse, + WSErrorResponse, +) from .web_interface import create_web_interface_app, WebInterfaceManager __all__ = [ @@ -30,6 +45,15 @@ "State", "SchemaResponse", "HealthResponse", + # WebSocket message types + "WSMessage", + "WSResetMessage", + "WSStepMessage", + "WSStateMessage", + "WSCloseMessage", + "WSObservationResponse", + "WSStateResponse", + "WSErrorResponse", # Base transforms "CompositeTransform", "NullTransform", diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 7fa7c0f32..41cc32315 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -8,18 +8,21 @@ HTTP server wrapper for Environment instances. This module provides utilities to wrap any Environment subclass and expose it -over HTTP endpoints that HTTPEnvClient can consume. +over HTTP endpoints that HTTPEnvClient can consume. Also supports WebSocket +connections for persistent sessions with multi-environment concurrency. """ from __future__ import annotations import asyncio import inspect +import json import os +import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Type +from typing import Any, Callable, Dict, Optional, Type, Union -from fastapi import Body, FastAPI, HTTPException, status +from fastapi import Body, FastAPI, HTTPException, WebSocket, WebSocketDisconnect, status from pydantic import ValidationError from .interfaces import Environment @@ -39,6 +42,13 @@ EnvironmentMetadata, SchemaResponse, HealthResponse, + WSResetMessage, + WSStepMessage, + WSStateMessage, + WSCloseMessage, + WSObservationResponse, + WSStateResponse, + WSErrorResponse, ) @@ -47,7 +57,8 @@ class HTTPEnvServer: HTTP server wrapper for Environment instances. 
This class wraps an Environment and exposes its reset(), step(), and state - methods as HTTP endpoints compatible with HTTPEnvClient. + methods as HTTP endpoints compatible with HTTPEnvClient. Also supports + WebSocket connections for persistent sessions with multi-environment concurrency. The server expects: - Action deserialization: Converts JSON dict to Action subclass @@ -57,9 +68,16 @@ class HTTPEnvServer: >>> from core.env_server import HTTPEnvServer >>> from envs.coding_env.server import CodeExecutionEnvironment >>> + >>> # Single environment (backward compatible) >>> env = CodeExecutionEnvironment() >>> server = HTTPEnvServer(env) >>> + >>> # Factory pattern for concurrent sessions + >>> server = HTTPEnvServer( + ... env=CodeExecutionEnvironment, # Pass class, not instance + ... max_concurrent_envs=4, + ... ) + >>> >>> # Register routes with FastAPI >>> from fastapi import FastAPI >>> app = FastAPI() @@ -68,21 +86,50 @@ class HTTPEnvServer: def __init__( self, - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], + env: Union[Environment, Callable[[], Environment], Type[Environment]], + action_cls: Type[Action] = None, + observation_cls: Type[Observation] = None, + max_concurrent_envs: int = 1, ): """ Initialize HTTP server wrapper. Args: - env: The Environment instance to wrap + env: The Environment instance, factory callable, or class to wrap. + - If an instance is provided, it's used directly (single-env mode) + - If a callable/class is provided, it's called to create new + environments for each WebSocket session (factory mode) action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum number of concurrent WebSocket sessions. + Only applies when env is a factory. Default is 1. 
""" - self.env = env + self._env_factory: Optional[Callable[[], Environment]] = None + self._max_concurrent_envs = max_concurrent_envs + + # Determine if env is an instance or factory + if isinstance(env, Environment): + # Single instance mode (backward compatible) + self.env = env + self._env_factory = None + elif callable(env): + # Factory mode - env is a class or callable + self._env_factory = env + # Create a single instance for HTTP endpoints (backward compat) + self.env = env() + else: + raise TypeError( + f"env must be an Environment instance or callable, got {type(env)}" + ) + self.action_cls = action_cls self.observation_cls = observation_cls + + # Session management for WebSocket connections + self._sessions: Dict[str, Environment] = {} + self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_lock = asyncio.Lock() + # Create thread pool for running sync code in async context # This is needed for environments using sync libraries (e.g., Playwright sync API) self._executor = ThreadPoolExecutor(max_workers=1) @@ -110,6 +157,80 @@ def _get_valid_kwargs(self, sig, kwargs, skip_params=None): return valid_kwargs + async def _create_session(self) -> tuple[str, Environment]: + """ + Create a new WebSocket session with its own environment instance. 
+ + Returns: + Tuple of (session_id, environment) + + Raises: + RuntimeError: If max concurrent sessions reached or no factory available + """ + async with self._session_lock: + if len(self._sessions) >= self._max_concurrent_envs: + raise RuntimeError( + f"Maximum concurrent environments ({self._max_concurrent_envs}) reached" + ) + + if self._env_factory is None: + # Single instance mode - use shared env (limited concurrency) + if self._sessions: + raise RuntimeError( + "Single instance mode: only one WebSocket session allowed" + ) + session_id = str(uuid.uuid4()) + self._sessions[session_id] = self.env + else: + # Factory mode - create new environment + session_id = str(uuid.uuid4()) + env = self._env_factory() + self._sessions[session_id] = env + + # Create dedicated executor for this session + self._session_executors[session_id] = ThreadPoolExecutor(max_workers=1) + + return session_id, self._sessions[session_id] + + async def _destroy_session(self, session_id: str) -> None: + """ + Destroy a WebSocket session and cleanup resources. 
+ + Args: + session_id: The session ID to destroy + """ + async with self._session_lock: + if session_id in self._sessions: + env = self._sessions.pop(session_id) + # Call close() if environment has it + if hasattr(env, 'close') and callable(env.close): + try: + env.close() + except Exception: + pass # Best effort cleanup + + if session_id in self._session_executors: + executor = self._session_executors.pop(session_id) + executor.shutdown(wait=False) + + async def _run_in_session_executor( + self, session_id: str, func: Callable, *args, **kwargs + ) -> Any: + """Run a synchronous function in the session's thread pool executor.""" + executor = self._session_executors.get(session_id, self._executor) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(executor, lambda: func(*args, **kwargs)) + + @property + def active_sessions(self) -> int: + """Return the number of active WebSocket sessions.""" + return len(self._sessions) + + @property + def max_concurrent_envs(self) -> int: + """Return the maximum number of concurrent environments.""" + return self._max_concurrent_envs + def register_routes(self, app: FastAPI) -> None: """ Register HTTP routes on a FastAPI application. @@ -339,12 +460,141 @@ async def get_schemas() -> SchemaResponse: state=State.model_json_schema(), ) + # Register WebSocket endpoint for persistent sessions + @app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for persistent environment sessions. + + Each WebSocket connection gets its own environment instance (when using + factory mode) or shares the single instance (backward compatible mode). 
+ + Message Protocol: + - Client sends: {"type": "reset|step|state|close", "data": {...}} + - Server responds: {"type": "observation|state|error", "data": {...}} + """ + await websocket.accept() + + session_id = None + session_env = None + + try: + # Create session with dedicated environment + session_id, session_env = await self._create_session() + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + message = json.loads(raw_message) + except json.JSONDecodeError as e: + error_resp = WSErrorResponse( + data={"message": f"Invalid JSON: {e}", "code": "INVALID_JSON"} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + msg_type = message.get("type", "") + msg_data = message.get("data", {}) + + try: + if msg_type == "reset": + # Handle reset + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, msg_data) + + observation = await self._run_in_session_executor( + session_id, session_env.reset, **valid_kwargs + ) + + response = WSObservationResponse( + data=serialize_observation(observation) + ) + await websocket.send_text(response.model_dump_json()) + + elif msg_type == "step": + # Handle step + if not msg_data: + error_resp = WSErrorResponse( + data={"message": "Missing action data", "code": "MISSING_ACTION"} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + # Deserialize action with Pydantic validation + try: + action = deserialize_action(msg_data, self.action_cls) + except ValidationError as e: + error_resp = WSErrorResponse( + data={"message": str(e), "code": "VALIDATION_ERROR", "errors": e.errors()} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + observation = await self._run_in_session_executor( + session_id, session_env.step, action + ) + + response = WSObservationResponse( + data=serialize_observation(observation) + ) + await websocket.send_text(response.model_dump_json()) + + elif msg_type == 
"state": + # Handle state request + state = session_env.state + if hasattr(state, 'model_dump'): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = WSStateResponse(data=state_data) + await websocket.send_text(response.model_dump_json()) + + elif msg_type == "close": + # Client requested close + break + + else: + error_resp = WSErrorResponse( + data={"message": f"Unknown message type: {msg_type}", "code": "UNKNOWN_TYPE"} + ) + await websocket.send_text(error_resp.model_dump_json()) + + except Exception as e: + error_resp = WSErrorResponse( + data={"message": str(e), "code": "EXECUTION_ERROR"} + ) + await websocket.send_text(error_resp.model_dump_json()) + + except WebSocketDisconnect: + pass # Client disconnected normally + except RuntimeError as e: + # Could not create session (max concurrent reached) + try: + error_resp = WSErrorResponse( + data={"message": str(e), "code": "SESSION_ERROR"} + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception: + pass + finally: + # Cleanup session + if session_id: + await self._destroy_session(session_id) + try: + await websocket.close() + except Exception: + pass + def create_app( - env: Environment, + env: Union[Environment, Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, + max_concurrent_envs: int = 1, ) -> FastAPI: """ Create a FastAPI application with or without web interface. @@ -353,10 +603,11 @@ def create_app( including README integration for better user experience. 
Args: - env: The Environment instance to serve + env: The Environment instance, factory callable, or class to serve action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) Returns: FastAPI application instance with or without web interface and README integration @@ -376,15 +627,27 @@ def create_app( return create_web_interface_app(env, action_cls, observation_cls, env_name) else: # Use standard FastAPI app without web interface - return create_fastapi_app(env, action_cls, observation_cls) + return create_fastapi_app(env, action_cls, observation_cls, max_concurrent_envs) def create_fastapi_app( - env: Environment, + env: Union[Environment, Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], + max_concurrent_envs: int = 1, ) -> FastAPI: - """Create a FastAPI application with comprehensive documentation.""" + """ + Create a FastAPI application with comprehensive documentation. 
+ + Args: + env: The Environment instance, factory callable, or class to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) + + Returns: + FastAPI application instance + """ try: from fastapi import FastAPI except ImportError: @@ -452,6 +715,6 @@ def create_fastapi_app( }, ) - server = HTTPEnvServer(env, action_cls, observation_cls) + server = HTTPEnvServer(env, action_cls, observation_cls, max_concurrent_envs) server.register_routes(app) return app diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index b438cd667..196e7ac82 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -90,7 +90,21 @@ class Environment(ABC): Args: transform: Optional transform to apply to observations + + Class Attributes: + CONCURRENCY_SAFE: Whether this environment supports concurrent sessions. + When True, multiple WebSocket connections can each have their own + environment instance (up to max_concurrent_envs). When False (default), + the environment should only be used with a single session at a time. 
+ + Set this to True in your Environment subclass if: + - The environment uses proper session isolation (e.g., unique working dirs) + - No shared mutable state exists between instances + - External resources (databases, APIs) can handle concurrent access """ + + # Class-level flag indicating whether this environment supports concurrent sessions + CONCURRENCY_SAFE: bool = False def __init__(self, transform: Transform | None = None): self.transform = transform diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index c3ee689c0..765d6382d 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -212,3 +212,71 @@ class HealthResponse(BaseModel): ) status: str = Field(description="Health status of the environment server") + +class WSMessage(BaseModel): + """Base class for WebSocket messages.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + type: str = Field(description="Message type identifier") + + +class WSResetMessage(WSMessage): + """WebSocket message to reset the environment.""" + + type: str = Field(default="reset", description="Message type") + data: Dict[str, Any] = Field( + default_factory=dict, + description="Optional reset parameters (seed, episode_id, etc.)", + ) + + +class WSStepMessage(WSMessage): + """WebSocket message to execute a step.""" + + type: str = Field(default="step", description="Message type") + data: Dict[str, Any] = Field( + ..., description="Action data conforming to environment's action schema" + ) + + +class WSStateMessage(WSMessage): + """WebSocket message to request current state.""" + + type: str = Field(default="state", description="Message type") + + +class WSCloseMessage(WSMessage): + """WebSocket message to close the session.""" + + type: str = Field(default="close", description="Message type") + + +class WSObservationResponse(BaseModel): + """WebSocket response containing an observation.""" + + model_config = 
ConfigDict(extra="forbid") + + type: str = Field(default="observation", description="Response type") + data: Dict[str, Any] = Field(description="Observation data") + + +class WSStateResponse(BaseModel): + """WebSocket response containing environment state.""" + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="state", description="Response type") + data: Dict[str, Any] = Field(description="State data") + + +class WSErrorResponse(BaseModel): + """WebSocket response for errors.""" + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="error", description="Response type") + data: Dict[str, Any] = Field(description="Error details including message and code") diff --git a/src/openenv/core/ws_env_client.py b/src/openenv/core/ws_env_client.py new file mode 100644 index 000000000..c6f054e85 --- /dev/null +++ b/src/openenv/core/ws_env_client.py @@ -0,0 +1,305 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +WebSocket-based environment client for persistent sessions. + +This module provides a WebSocket client that maintains a persistent connection +to an environment server, enabling efficient multi-step interactions without +the overhead of HTTP request/response cycles. 
+""" + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar + +from .client_types import StepResult +from .containers.runtime import LocalDockerProvider + +if TYPE_CHECKING: + from .containers.runtime import ContainerProvider + from websockets.sync.client import ClientConnection + +try: + import websockets + from websockets.sync.client import connect as ws_connect +except ImportError: + websockets = None # type: ignore + ws_connect = None # type: ignore + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") +WSEnvClientT = TypeVar("WSEnvClientT", bound="WebSocketEnvClient") + + +class WebSocketEnvClient(ABC, Generic[ActT, ObsT]): + """ + WebSocket-based environment client for persistent sessions. + + This client maintains a persistent WebSocket connection to an environment + server, enabling efficient multi-step interactions. Each client instance + corresponds to a dedicated environment session on the server. + + Compared to HTTPEnvClient: + - Lower latency for sequential interactions + - Session state is maintained server-side + - Better suited for long-running episodes + + Example: + >>> from envs.coding_env.client import CodingEnvWS + >>> + >>> # Connect to a server via WebSocket + >>> with CodingEnvWS(base_url="ws://localhost:8000") as env: + ... result = env.reset(seed=42) + ... while not result.done: + ... action = agent.predict(result.observation) + ... result = env.step(action) + """ + + def __init__( + self, + base_url: str, + connect_timeout_s: float = 10.0, + message_timeout_s: float = 60.0, + provider: Optional["ContainerProvider"] = None, + ): + """ + Initialize WebSocket client. + + Args: + base_url: Base URL of the environment server (http:// or ws://). + Will be converted to ws:// if http:// is provided. 
+ connect_timeout_s: Timeout for establishing WebSocket connection + message_timeout_s: Timeout for receiving responses to messages + provider: Optional container provider for lifecycle management + """ + if websockets is None: + raise ImportError( + "websockets library is required for WebSocketEnvClient. " + "Install with: pip install websockets" + ) + + # Convert HTTP URL to WebSocket URL + ws_url = base_url.rstrip("/") + if ws_url.startswith("http://"): + ws_url = "ws://" + ws_url[7:] + elif ws_url.startswith("https://"): + ws_url = "wss://" + ws_url[8:] + elif not ws_url.startswith("ws://") and not ws_url.startswith("wss://"): + ws_url = "ws://" + ws_url + + self._ws_url = f"{ws_url}/ws" + self._connect_timeout = connect_timeout_s + self._message_timeout = message_timeout_s + self._provider = provider + self._ws: Optional[ClientConnection] = None + + def connect(self) -> "WebSocketEnvClient": + """ + Establish WebSocket connection to the server. + + Returns: + self for method chaining + + Raises: + ConnectionError: If connection cannot be established + """ + if self._ws is not None: + return self + + try: + self._ws = ws_connect( + self._ws_url, + open_timeout=self._connect_timeout, + ) + except Exception as e: + raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e + + return self + + def disconnect(self) -> None: + """Close the WebSocket connection.""" + if self._ws is not None: + try: + # Send close message + self._send({"type": "close"}) + except Exception: + pass # Best effort + try: + self._ws.close() + except Exception: + pass + self._ws = None + + def _ensure_connected(self) -> None: + """Ensure WebSocket connection is established.""" + if self._ws is None: + self.connect() + + def _send(self, message: Dict[str, Any]) -> None: + """Send a message over the WebSocket.""" + self._ensure_connected() + assert self._ws is not None + self._ws.send(json.dumps(message)) + + def _receive(self) -> Dict[str, Any]: + """Receive and parse a 
message from the WebSocket.""" + assert self._ws is not None + raw = self._ws.recv(timeout=self._message_timeout) + return json.loads(raw) + + def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Send a message and wait for response.""" + self._send(message) + response = self._receive() + + # Check for error response + if response.get("type") == "error": + error_data = response.get("data", {}) + raise RuntimeError( + f"Server error: {error_data.get('message', 'Unknown error')} " + f"(code: {error_data.get('code', 'UNKNOWN')})" + ) + + return response + + @classmethod + def from_docker_image( + cls: Type[WSEnvClientT], + image: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> WSEnvClientT: + """ + Create a WebSocket environment client by spinning up a Docker container. + + Args: + image: Docker image name to run (e.g., "coding-env:latest") + provider: Container provider to use (defaults to LocalDockerProvider) + **kwargs: Additional arguments to pass to provider.start_container() + + Returns: + Connected WebSocket client instance + """ + if provider is None: + provider = LocalDockerProvider() + + # Start container + base_url = provider.start_container(image, **kwargs) + + # Wait for server to be ready + provider.wait_for_ready(base_url) + + # Create and connect client + client = cls(base_url=base_url, provider=provider) + client.connect() + + return client + + @classmethod + def from_hub( + cls: Type[WSEnvClientT], + repo_id: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> WSEnvClientT: + """ + Create a WebSocket client by pulling from a Hugging Face model hub. 
+ """ + if provider is None: + provider = LocalDockerProvider() + + tag = kwargs.pop("tag", "latest") + base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" + + return cls.from_docker_image(image=base_url, provider=provider, **kwargs) + + @abstractmethod + def _step_payload(self, action: ActT) -> dict: + """Convert an Action object to the JSON data expected by the env server.""" + raise NotImplementedError + + @abstractmethod + def _parse_result(self, payload: dict) -> StepResult[ObsT]: + """Convert a JSON response from the env server to StepResult[ObsT].""" + raise NotImplementedError + + @abstractmethod + def _parse_state(self, payload: dict) -> Any: + """Convert a JSON response from the state endpoint to a State object.""" + raise NotImplementedError + + def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment with optional parameters. + + Args: + **kwargs: Optional parameters passed to the environment's reset method. + Common parameters include: + - seed: Random seed for reproducibility + - episode_id: Custom episode identifier + + Returns: + StepResult containing initial observation + """ + message = { + "type": "reset", + "data": kwargs, + } + response = self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment. + + Args: + action: The action to execute + **kwargs: Optional parameters (currently ignored for WebSocket) + + Returns: + StepResult containing observation, reward, and done status + """ + message = { + "type": "step", + "data": self._step_payload(action), + } + response = self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + def state(self) -> Any: + """ + Get the current environment state from the server. 
+ + Returns: + State object with environment state information + """ + message = {"type": "state"} + response = self._send_and_receive(message) + return self._parse_state(response.get("data", {})) + + def close(self) -> None: + """ + Close the WebSocket connection and clean up resources. + + If this client was created via from_docker_image(), this will also + stop and remove the associated container. + """ + self.disconnect() + + if self._provider is not None: + self._provider.stop_container() + + def __enter__(self) -> "WebSocketEnvClient": + """Enter context manager, ensuring connection is established.""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager, closing connection.""" + self.close() From e0a063d5833c5ff421bdf4368539adb131ad8b55 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 4 Dec 2025 23:43:09 +0530 Subject: [PATCH 02/27] impl concurrency management and session handling --- src/openenv/core/env_server/__init__.py | 23 ++- src/openenv/core/env_server/exceptions.py | 105 ++++++++++++ src/openenv/core/env_server/http_server.py | 176 +++++++++++++++++++-- src/openenv/core/env_server/types.py | 96 +++++++++++ 4 files changed, 384 insertions(+), 16 deletions(-) create mode 100644 src/openenv/core/env_server/exceptions.py diff --git a/src/openenv/core/env_server/__init__.py b/src/openenv/core/env_server/__init__.py index 92ebbeb2d..e1014540e 100644 --- a/src/openenv/core/env_server/__init__.py +++ b/src/openenv/core/env_server/__init__.py @@ -21,7 +21,6 @@ State, SchemaResponse, HealthResponse, - # WebSocket message types WSMessage, WSResetMessage, WSStepMessage, @@ -30,6 +29,17 @@ WSObservationResponse, WSStateResponse, WSErrorResponse, + ConcurrencyConfig, + ServerCapacityStatus, + SessionInfo, +) +from .exceptions import ( + OpenEnvError, + ConcurrencyConfigurationError, + SessionCapacityError, + SessionNotFoundError, + SessionCreationError, + 
EnvironmentFactoryError, ) from .web_interface import create_web_interface_app, WebInterfaceManager @@ -54,6 +64,17 @@ "WSObservationResponse", "WSStateResponse", "WSErrorResponse", + # Concurrency types + "ConcurrencyConfig", + "ServerCapacityStatus", + "SessionInfo", + # Exceptions + "OpenEnvError", + "ConcurrencyConfigurationError", + "SessionCapacityError", + "SessionNotFoundError", + "SessionCreationError", + "EnvironmentFactoryError", # Base transforms "CompositeTransform", "NullTransform", diff --git a/src/openenv/core/env_server/exceptions.py b/src/openenv/core/env_server/exceptions.py new file mode 100644 index 000000000..41a8235bb --- /dev/null +++ b/src/openenv/core/env_server/exceptions.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Custom exceptions for environment server operations.""" + +from typing import Optional + + +class OpenEnvError(Exception): + """Base exception for all OpenEnv errors.""" + + pass + + +class ConcurrencyConfigurationError(OpenEnvError): + """ + Raised when an environment is misconfigured for concurrent sessions. + + This error is raised during server startup when max_concurrent_envs > 1 + is specified for an environment that is not marked as CONCURRENCY_SAFE. + """ + + def __init__( + self, + environment_name: str, + max_concurrent_envs: int, + message: Optional[str] = None, + ): + self.environment_name = environment_name + self.max_concurrent_envs = max_concurrent_envs + + if message is None: + message = ( + f"Environment '{environment_name}' is not marked as CONCURRENCY_SAFE. " + f"Cannot run with max_concurrent_envs={max_concurrent_envs}. " + f"Either set max_concurrent_envs=1 or ensure the environment " + f"properly isolates session state and set CONCURRENCY_SAFE=True." 
+ ) + + super().__init__(message) + + +class SessionCapacityError(OpenEnvError): + """ + Raised when the server cannot accept new sessions due to capacity limits. + + This error is raised when a new WebSocket connection is attempted but + the server has already reached max_concurrent_envs active sessions. + """ + + def __init__( + self, + active_sessions: int, + max_sessions: int, + message: Optional[str] = None, + ): + self.active_sessions = active_sessions + self.max_sessions = max_sessions + + if message is None: + message = ( + f"Server at capacity: {active_sessions}/{max_sessions} sessions active. " + f"Cannot accept new connections." + ) + + super().__init__(message) + + +class SessionNotFoundError(OpenEnvError): + """Raised when attempting to access a session that does not exist.""" + + def __init__(self, session_id: str, message: Optional[str] = None): + self.session_id = session_id + + if message is None: + message = f"Session '{session_id}' not found." + + super().__init__(message) + + +class SessionCreationError(OpenEnvError): + """Raised when a session cannot be created.""" + + def __init__(self, reason: str, message: Optional[str] = None): + self.reason = reason + + if message is None: + message = f"Failed to create session: {reason}" + + super().__init__(message) + + +class EnvironmentFactoryError(OpenEnvError): + """Raised when the environment factory fails to create an instance.""" + + def __init__(self, factory_name: str, cause: Exception): + self.factory_name = factory_name + self.cause = cause + + message = f"Environment factory '{factory_name}' failed to create instance: {cause}" + + super().__init__(message) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 41cc32315..50eaac13d 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -49,6 +49,14 @@ WSObservationResponse, WSStateResponse, WSErrorResponse, + ConcurrencyConfig, + 
ServerCapacityStatus, + SessionInfo, +) +from .exceptions import ( + ConcurrencyConfigurationError, + SessionCapacityError, + EnvironmentFactoryError, ) @@ -90,6 +98,7 @@ def __init__( action_cls: Type[Action] = None, observation_cls: Type[Observation] = None, max_concurrent_envs: int = 1, + skip_concurrency_check: bool = False, ): """ Initialize HTTP server wrapper. @@ -103,9 +112,19 @@ def __init__( observation_cls: The Observation subclass this environment returns max_concurrent_envs: Maximum number of concurrent WebSocket sessions. Only applies when env is a factory. Default is 1. + skip_concurrency_check: If True, skip concurrency safety validation. + Use with caution for advanced users who understand + the isolation requirements. + + Raises: + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as CONCURRENCY_SAFE. """ self._env_factory: Optional[Callable[[], Environment]] = None self._max_concurrent_envs = max_concurrent_envs + self._skip_concurrency_check = skip_concurrency_check or os.getenv( + "OPENENV_SKIP_CONCURRENCY_CHECK", "" + ).lower() in ("1", "true", "yes") # Determine if env is an instance or factory if isinstance(env, Environment): @@ -116,24 +135,67 @@ def __init__( # Factory mode - env is a class or callable self._env_factory = env # Create a single instance for HTTP endpoints (backward compat) - self.env = env() + try: + self.env = env() + except Exception as e: + factory_name = getattr(env, "__name__", str(env)) + raise EnvironmentFactoryError(factory_name, e) from e else: raise TypeError( f"env must be an Environment instance or callable, got {type(env)}" ) + # Validate concurrency configuration + self._validate_concurrency_safety() + self.action_cls = action_cls self.observation_cls = observation_cls # Session management for WebSocket connections self._sessions: Dict[str, Environment] = {} self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_info: Dict[str, SessionInfo] = 
{} self._session_lock = asyncio.Lock() # Create thread pool for running sync code in async context # This is needed for environments using sync libraries (e.g., Playwright sync API) self._executor = ThreadPoolExecutor(max_workers=1) + def _validate_concurrency_safety(self) -> None: + """ + Validate that the environment supports the configured concurrency level. + + Raises: + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as CONCURRENCY_SAFE. + """ + if self._max_concurrent_envs <= 1: + return + + if self._skip_concurrency_check: + return + + is_concurrency_safe = getattr(self.env, "CONCURRENCY_SAFE", False) + + if not is_concurrency_safe: + env_name = type(self.env).__name__ + raise ConcurrencyConfigurationError( + environment_name=env_name, + max_concurrent_envs=self._max_concurrent_envs, + ) + + def get_capacity_status(self) -> ServerCapacityStatus: + """ + Get the current capacity status of the server. + + Returns: + ServerCapacityStatus with current session counts and availability. 
+ """ + return ServerCapacityStatus.from_counts( + active=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + async def _run_sync_in_thread_pool(self, func, *args, **kwargs): """Run a synchronous function in the thread pool executor.""" loop = asyncio.get_event_loop() @@ -165,32 +227,53 @@ async def _create_session(self) -> tuple[str, Environment]: Tuple of (session_id, environment) Raises: - RuntimeError: If max concurrent sessions reached or no factory available + SessionCapacityError: If max concurrent sessions reached + EnvironmentFactoryError: If the factory fails to create an environment """ + import time + async with self._session_lock: if len(self._sessions) >= self._max_concurrent_envs: - raise RuntimeError( - f"Maximum concurrent environments ({self._max_concurrent_envs}) reached" + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, ) + session_id = str(uuid.uuid4()) + current_time = time.time() + if self._env_factory is None: # Single instance mode - use shared env (limited concurrency) if self._sessions: - raise RuntimeError( - "Single instance mode: only one WebSocket session allowed" + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=1, + message="Single instance mode: only one WebSocket session allowed", ) - session_id = str(uuid.uuid4()) - self._sessions[session_id] = self.env + env = self.env else: # Factory mode - create new environment - session_id = str(uuid.uuid4()) - env = self._env_factory() - self._sessions[session_id] = env + try: + env = self._env_factory() + except Exception as e: + factory_name = getattr(self._env_factory, "__name__", str(self._env_factory)) + raise EnvironmentFactoryError(factory_name, e) from e + + self._sessions[session_id] = env # Create dedicated executor for this session self._session_executors[session_id] = ThreadPoolExecutor(max_workers=1) - return session_id, self._sessions[session_id] + # Track session 
metadata + self._session_info[session_id] = SessionInfo( + session_id=session_id, + created_at=current_time, + last_activity_at=current_time, + step_count=0, + environment_type=type(env).__name__, + ) + + return session_id, env async def _destroy_session(self, session_id: str) -> None: """ @@ -212,7 +295,37 @@ async def _destroy_session(self, session_id: str) -> None: if session_id in self._session_executors: executor = self._session_executors.pop(session_id) executor.shutdown(wait=False) + + # Remove session metadata + self._session_info.pop(session_id, None) + def _update_session_activity(self, session_id: str, increment_step: bool = False) -> None: + """ + Update session activity timestamp and optionally increment step count. + + Args: + session_id: The session ID to update + increment_step: If True, increment the step count + """ + import time + + if session_id in self._session_info: + self._session_info[session_id].last_activity_at = time.time() + if increment_step: + self._session_info[session_id].step_count += 1 + + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: + """ + Get information about a specific session. + + Args: + session_id: The session ID to query + + Returns: + SessionInfo if the session exists, None otherwise + """ + return self._session_info.get(session_id) + async def _run_in_session_executor( self, session_id: str, func: Callable, *args, **kwargs ) -> Any: @@ -231,6 +344,11 @@ def max_concurrent_envs(self) -> int: """Return the maximum number of concurrent environments.""" return self._max_concurrent_envs + @property + def is_concurrency_safe(self) -> bool: + """Return whether the environment is marked as concurrency safe.""" + return getattr(self.env, "CONCURRENCY_SAFE", False) + def register_routes(self, app: FastAPI) -> None: """ Register HTTP routes on a FastAPI application. 
@@ -508,6 +626,8 @@ async def websocket_endpoint(websocket: WebSocket): session_id, session_env.reset, **valid_kwargs ) + self._update_session_activity(session_id) + response = WSObservationResponse( data=serialize_observation(observation) ) @@ -536,6 +656,8 @@ async def websocket_endpoint(websocket: WebSocket): session_id, session_env.step, action ) + self._update_session_activity(session_id, increment_step=True) + response = WSObservationResponse( data=serialize_observation(observation) ) @@ -569,9 +691,33 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_text(error_resp.model_dump_json()) except WebSocketDisconnect: - pass # Client disconnected normally - except RuntimeError as e: - # Could not create session (max concurrent reached) + pass + except SessionCapacityError as e: + try: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": "CAPACITY_REACHED", + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception: + pass + except EnvironmentFactoryError as e: + try: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": "FACTORY_ERROR", + "factory_name": e.factory_name, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception: + pass + except Exception as e: try: error_resp = WSErrorResponse( data={"message": str(e), "code": "SESSION_ERROR"} diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 765d6382d..39074595f 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -280,3 +280,99 @@ class WSErrorResponse(BaseModel): type: str = Field(default="error", description="Response type") data: Dict[str, Any] = Field(description="Error details including message and code") + + +class ConcurrencySafetyLevel(str): + """ + Classification of environment concurrency safety. 
+ + Environments are classified based on their ability to safely handle + multiple concurrent sessions within a single container. + """ + + UNSAFE = "unsafe" + SAFE = "safe" + + +class ConcurrencyConfig(BaseModel): + """Configuration for concurrent environment sessions.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + max_concurrent_envs: int = Field( + default=1, + ge=1, + le=1000, + description="Maximum number of concurrent WebSocket sessions allowed", + ) + session_timeout_seconds: Optional[float] = Field( + default=None, + gt=0, + description="Timeout in seconds for inactive sessions. None means no timeout.", + ) + reject_on_capacity: bool = Field( + default=True, + description="If True, reject new connections when at capacity. If False, queue them.", + ) + + +class ServerCapacityStatus(BaseModel): + """Status of server capacity for concurrent sessions.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + active_sessions: int = Field( + ge=0, + description="Number of currently active sessions", + ) + max_sessions: int = Field( + ge=1, + description="Maximum number of allowed sessions", + ) + available_slots: int = Field( + ge=0, + description="Number of available session slots", + ) + is_at_capacity: bool = Field( + description="Whether the server has reached maximum capacity", + ) + + @classmethod + def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus": + """Create status from active and max session counts.""" + available = max(0, max_sessions - active) + return cls( + active_sessions=active, + max_sessions=max_sessions, + available_slots=available, + is_at_capacity=active >= max_sessions, + ) + + +class SessionInfo(BaseModel): + """Information about an active session.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + session_id: str = Field(description="Unique identifier for the session") + created_at: float = 
Field(description="Unix timestamp when the session was created") + last_activity_at: float = Field( + description="Unix timestamp of the last activity in the session" + ) + step_count: int = Field( + default=0, + ge=0, + description="Number of steps executed in this session", + ) + environment_type: str = Field( + description="Type name of the environment class for this session" + ) From 95563b0afdeb8806d37ded906544ddc9f6aceaad Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Sat, 6 Dec 2025 09:57:43 +0100 Subject: [PATCH 03/27] add async to http server --- src/openenv/core/env_server/http_server.py | 50 +++++++++++++--------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 7fa7c0f32..d301fa7e9 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -84,8 +84,14 @@ def __init__( self.action_cls = action_cls self.observation_cls = observation_cls # Create thread pool for running sync code in async context - # This is needed for environments using sync libraries (e.g., Playwright sync API) - self._executor = ThreadPoolExecutor(max_workers=1) + # This is needed for environments using sync libraries (e.g., Playwright) + # Configurable via OPENENV_THREAD_POOL_SIZE (default: 32) + pool_size = int(os.getenv("OPENENV_THREAD_POOL_SIZE", "32")) + self._executor = ThreadPoolExecutor(max_workers=pool_size) + + # Check if environment has async methods for better concurrency + self._has_step_async = hasattr(env, "step_async") and asyncio.iscoroutinefunction(env.step_async) + self._has_reset_async = hasattr(env, "reset_async") and asyncio.iscoroutinefunction(env.reset_async) async def _run_sync_in_thread_pool(self, func, *args, **kwargs): """Run a synchronous function in the thread pool executor.""" @@ -99,9 +105,7 @@ def _get_valid_kwargs(self, sig, kwargs, skip_params=None): valid_kwargs = {} - has_kwargs = any( - p.kind 
== inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() - ) + has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) for k, v in kwargs.items(): if k in sig.parameters or has_kwargs: @@ -128,13 +132,17 @@ async def reset_handler( kwargs = request.model_dump(exclude_unset=True) # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.reset) + if self._has_reset_async: + sig = inspect.signature(self.env.reset_async) + else: + sig = inspect.signature(self.env.reset) valid_kwargs = self._get_valid_kwargs(sig, kwargs) - # Run synchronous reset in thread pool to avoid blocking event loop - observation = await self._run_sync_in_thread_pool( - self.env.reset, **valid_kwargs - ) + # Use async method if available for better concurrency + if self._has_reset_async: + observation = await self.env.reset_async(**valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool(self.env.reset, **valid_kwargs) return ResetResponse(**serialize_observation(observation)) # Helper function to handle step endpoint @@ -147,22 +155,24 @@ async def step_handler(request: StepRequest) -> StepResponse: action = deserialize_action(action_data, self.action_cls) except ValidationError as e: # Return HTTP 422 with detailed validation errors - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() - ) + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors()) # Handle optional parameters # Start with all fields from the request, including extra ones, but exclude 'action' kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.step) + if self._has_step_async: + sig = inspect.signature(self.env.step_async) + else: + sig = inspect.signature(self.env.step) valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) - # Run synchronous step in 
thread pool to avoid blocking event loop - observation = await self._run_sync_in_thread_pool( - self.env.step, action, **valid_kwargs - ) + # Use async method if available for better concurrency + if self._has_step_async: + observation = await self.env.step_async(action, **valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool(self.env.step, action, **valid_kwargs) # Return serialized observation return StepResponse(**serialize_observation(observation)) @@ -388,9 +398,7 @@ def create_fastapi_app( try: from fastapi import FastAPI except ImportError: - raise ImportError( - "FastAPI is required. Install with: pip install fastapi uvicorn" - ) + raise ImportError("FastAPI is required. Install with: pip install fastapi uvicorn") app = FastAPI( title="OpenEnv Environment HTTP API", From 3601357a9727c75f7a805c6b1364118884ce7ae8 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Mon, 8 Dec 2025 01:40:40 +0530 Subject: [PATCH 04/27] concurrency config --- src/openenv/core/__init__.py | 13 +- src/openenv/core/env_server/http_server.py | 138 ++++++++++++++++--- src/openenv/core/env_server/serialization.py | 2 +- 3 files changed, 123 insertions(+), 30 deletions(-) diff --git a/src/openenv/core/__init__.py b/src/openenv/core/__init__.py index 3592ead53..93ae09786 100644 --- a/src/openenv/core/__init__.py +++ b/src/openenv/core/__init__.py @@ -7,15 +7,10 @@ """Core components for agentic environments.""" # Re-export main components from submodules for convenience -from .env_server import * -from .client_types import StepResult -from .http_env_client import HTTPEnvClient -from .ws_env_client import WebSocketEnvClient +from .env_server import * # noqa: F403 +from .env_server import __all__ as _env_server_all + # Note: MCP module doesn't export anything yet -__all__ = [ - "HTTPEnvClient", - "WebSocketEnvClient", - "StepResult", -] +__all__ = list(_env_server_all) \ No newline at end of file diff --git 
a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 517809655..8dd144987 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -99,6 +99,7 @@ def __init__( observation_cls: Type[Observation] = None, max_concurrent_envs: int = 1, skip_concurrency_check: bool = False, + concurrency_config: Optional[ConcurrencyConfig] = None, ): """ Initialize HTTP server wrapper. @@ -112,16 +113,33 @@ def __init__( observation_cls: The Observation subclass this environment returns max_concurrent_envs: Maximum number of concurrent WebSocket sessions. Only applies when env is a factory. Default is 1. + If concurrency_config is provided, this parameter is ignored. skip_concurrency_check: If True, skip concurrency safety validation. Use with caution for advanced users who understand the isolation requirements. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + If provided, overrides max_concurrent_envs and allows + configuration of session timeout and capacity behavior. Raises: ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an environment that is not marked as CONCURRENCY_SAFE. 
""" self._env_factory: Optional[Callable[[], Environment]] = None - self._max_concurrent_envs = max_concurrent_envs + + # Handle concurrency configuration + if concurrency_config is not None: + self._concurrency_config = concurrency_config + self._max_concurrent_envs = concurrency_config.max_concurrent_envs + else: + # Use legacy parameters + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=max_concurrent_envs, + session_timeout_seconds=None, + reject_on_capacity=True, + ) + self._max_concurrent_envs = max_concurrent_envs + self._skip_concurrency_check = skip_concurrency_check or os.getenv( "OPENENV_SKIP_CONCURRENCY_CHECK", "" ).lower() in ("1", "true", "yes") @@ -238,10 +256,18 @@ async def _create_session(self) -> tuple[str, Environment]: async with self._session_lock: if len(self._sessions) >= self._max_concurrent_envs: - raise SessionCapacityError( - active_sessions=len(self._sessions), - max_sessions=self._max_concurrent_envs, - ) + if self._concurrency_config.reject_on_capacity: + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + else: + # TODO: Implement queuing mechanism when reject_on_capacity=False + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, + message="Session queuing not yet implemented", + ) session_id = str(uuid.uuid4()) current_time = time.time() @@ -353,6 +379,11 @@ def is_concurrency_safe(self) -> bool: """Return whether the environment is marked as concurrency safe.""" return getattr(self.env, "CONCURRENCY_SAFE", False) + @property + def concurrency_config(self) -> ConcurrencyConfig: + """Return the concurrency configuration.""" + return self._concurrency_config + def register_routes(self, app: FastAPI) -> None: """ Register HTTP routes on a FastAPI application. 
@@ -539,6 +570,25 @@ async def step(request: StepRequest) -> StepResponse: ] register_get_endpoints(app, get_endpoints) + # Register concurrency config endpoint + @app.get( + "/concurrency", + response_model=ConcurrencyConfig, + tags=["Environment Info"], + summary="Get concurrency configuration", + description=""" +Get the current concurrency configuration for this server. + +Returns information about: +- **max_concurrent_envs**: Maximum number of concurrent WebSocket sessions +- **session_timeout_seconds**: Timeout for inactive sessions (None if no timeout) +- **reject_on_capacity**: Whether to reject or queue connections at capacity + """, + ) + async def get_concurrency_config() -> ConcurrencyConfig: + """Return concurrency configuration.""" + return self._concurrency_config + # Register combined schema endpoint @app.get( "/schema", @@ -598,8 +648,8 @@ async def websocket_endpoint(websocket: WebSocket): factory mode) or shares the single instance (backward compatible mode). Message Protocol: - - Client sends: {"type": "reset|step|state|close", "data": {...}} - - Server responds: {"type": "observation|state|error", "data": {...}} + - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage + - Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse """ await websocket.accept() @@ -615,7 +665,7 @@ async def websocket_endpoint(websocket: WebSocket): raw_message = await websocket.receive_text() try: - message = json.loads(raw_message) + message_dict = json.loads(raw_message) except json.JSONDecodeError as e: error_resp = WSErrorResponse( data={"message": f"Invalid JSON: {e}", "code": "INVALID_JSON"} @@ -623,14 +673,23 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_text(error_resp.model_dump_json()) continue - msg_type = message.get("type", "") - msg_data = message.get("data", {}) + msg_type = message_dict.get("type", "") try: if msg_type == "reset": + # Parse and validate reset message + try: + 
msg = WSResetMessage(**message_dict) + except ValidationError as e: + error_resp = WSErrorResponse( + data={"message": "Invalid reset message", "code": "VALIDATION_ERROR", "errors": e.errors()} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + # Handle reset sig = inspect.signature(session_env.reset) - valid_kwargs = self._get_valid_kwargs(sig, msg_data) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) observation = await self._run_in_session_executor( session_id, session_env.reset, **valid_kwargs @@ -644,17 +703,19 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_text(response.model_dump_json()) elif msg_type == "step": - # Handle step - if not msg_data: + # Parse and validate step message + try: + msg = WSStepMessage(**message_dict) + except ValidationError as e: error_resp = WSErrorResponse( - data={"message": "Missing action data", "code": "MISSING_ACTION"} + data={"message": "Invalid step message", "code": "VALIDATION_ERROR", "errors": e.errors()} ) await websocket.send_text(error_resp.model_dump_json()) continue # Deserialize action with Pydantic validation try: - action = deserialize_action(msg_data, self.action_cls) + action = deserialize_action(msg.data, self.action_cls) except ValidationError as e: error_resp = WSErrorResponse( data={"message": str(e), "code": "VALIDATION_ERROR", "errors": e.errors()} @@ -674,6 +735,16 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_text(response.model_dump_json()) elif msg_type == "state": + # Parse and validate state message + try: + msg = WSStateMessage(**message_dict) + except ValidationError as e: + error_resp = WSErrorResponse( + data={"message": "Invalid state message", "code": "VALIDATION_ERROR", "errors": e.errors()} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + # Handle state request state = session_env.state if hasattr(state, 'model_dump'): @@ -685,6 +756,16 @@ async def 
websocket_endpoint(websocket: WebSocket): await websocket.send_text(response.model_dump_json()) elif msg_type == "close": + # Parse and validate close message + try: + msg = WSCloseMessage(**message_dict) + except ValidationError as e: + error_resp = WSErrorResponse( + data={"message": "Invalid close message", "code": "VALIDATION_ERROR", "errors": e.errors()} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + # Client requested close break @@ -751,6 +832,7 @@ def create_app( observation_cls: Type[Observation], env_name: Optional[str] = None, max_concurrent_envs: int = 1, + concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ Create a FastAPI application with or without web interface. @@ -763,7 +845,10 @@ def create_app( action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading - max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). + Ignored if concurrency_config is provided. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + If provided, overrides max_concurrent_envs. 
Returns: FastAPI application instance with or without web interface and README integration @@ -780,10 +865,16 @@ def create_app( # Import web interface only when needed from .web_interface import create_web_interface_app - return create_web_interface_app(env, action_cls, observation_cls, env_name) + return create_web_interface_app( + env, action_cls, observation_cls, env_name, + max_concurrent_envs, concurrency_config + ) else: # Use standard FastAPI app without web interface - return create_fastapi_app(env, action_cls, observation_cls, max_concurrent_envs) + return create_fastapi_app( + env, action_cls, observation_cls, + max_concurrent_envs, concurrency_config + ) def create_fastapi_app( @@ -791,6 +882,7 @@ def create_fastapi_app( action_cls: Type[Action], observation_cls: Type[Observation], max_concurrent_envs: int = 1, + concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ Create a FastAPI application with comprehensive documentation. @@ -799,7 +891,10 @@ def create_fastapi_app( env: The Environment instance, factory callable, or class to serve action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns - max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). + Ignored if concurrency_config is provided. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + If provided, overrides max_concurrent_envs. 
Returns: FastAPI application instance @@ -869,6 +964,9 @@ def create_fastapi_app( }, ) - server = HTTPEnvServer(env, action_cls, observation_cls, max_concurrent_envs) + server = HTTPEnvServer( + env, action_cls, observation_cls, + max_concurrent_envs, concurrency_config=concurrency_config + ) server.register_routes(app) return app diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index a97a05283..df06592f5 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -80,7 +80,7 @@ def deserialize_action_with_preprocessing( value = [] if isinstance(value, list): try: - import torch + import torch # type: ignore processed_data[key] = torch.tensor(value, dtype=torch.long) except ImportError: From 600acb41e952525bbb564ae3fbeb8559f3131694 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Mon, 8 Dec 2025 18:59:59 +0530 Subject: [PATCH 05/27] chore: add websockets to pyproject.toml --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 811c068c9..edb6c1f17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ dependencies = [ "huggingface_hub>=0.20.0", "openai>=2.7.2", "tomli>=2.3.0", - "tomli-w>=1.2.0" + "tomli-w>=1.2.0", + "websockets>=15.0.1", ] [project.optional-dependencies] From a98851a2e5a1ce12b13595f95aa632f2c19f0fd4 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 14:45:35 +0100 Subject: [PATCH 06/27] add concurrency safe pram --- .../openenv_env/server/__ENV_NAME___environment.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py index e2a9ce0b7..72db6472f 100644 --- a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py +++ 
b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py @@ -36,6 +36,12 @@ class __ENV_CLASS_NAME__Environment(Environment): >>> print(obs.message_length) # 5 """ + # Enable concurrent WebSocket sessions. + # Set to True if your environment isolates state between instances. + # When True, multiple WebSocket clients can connect simultaneously, each + # getting their own environment instance (when using factory mode in app.py). + CONCURRENCY_SAFE: bool = True + def __init__(self): """Initialize the __ENV_NAME__ environment.""" self._state = State(episode_id=str(uuid4()), step_count=0) From 8197d6f29c1f3dd6a8b7abdc364c69cd33354429 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 14:45:54 +0100 Subject: [PATCH 07/27] use factory in template app --- .../cli/templates/openenv_env/server/app.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/server/app.py b/src/openenv/cli/templates/openenv_env/server/app.py index db216fb06..87e3db6dc 100644 --- a/src/openenv/cli/templates/openenv_env/server/app.py +++ b/src/openenv/cli/templates/openenv_env/server/app.py @@ -8,7 +8,14 @@ FastAPI application for the __ENV_TITLE_NAME__ Environment. This module creates an HTTP server that exposes the __ENV_CLASS_NAME__Environment -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with HTTPEnvClient and WebSocketEnvClient. 
+ +Endpoints: + - POST /reset: Reset the environment + - POST /step: Execute an action + - GET /state: Get current environment state + - GET /schema: Get action/observation schemas + - WS /ws: WebSocket endpoint for persistent sessions Usage: # Development (with auto-reload): @@ -31,15 +38,14 @@ from __ENV_NAME__.models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation from .__ENV_NAME___environment import __ENV_CLASS_NAME__Environment -# Create the environment instance -env = __ENV_CLASS_NAME__Environment() # Create the app with web interface and README integration app = create_app( - env, + __ENV_CLASS_NAME__Environment, __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation, env_name="__ENV_NAME__", + max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions ) From f72b6dad63275127b536c6448daa7f9a4730d4c5 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 14:46:16 +0100 Subject: [PATCH 08/27] us WS in client --- .../cli/templates/openenv_env/client.py | 95 ++++++++++++++++++- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/client.py b/src/openenv/cli/templates/openenv_env/client.py index 703b28a85..0775f2536 100644 --- a/src/openenv/cli/templates/openenv_env/client.py +++ b/src/openenv/cli/templates/openenv_env/client.py @@ -5,10 +5,11 @@ # LICENSE file in the root directory of this source tree. """ -__ENV_TITLE_NAME__ Environment HTTP Client. +__ENV_TITLE_NAME__ Environment Clients. -This module provides the client for connecting to a __ENV_TITLE_NAME__ Environment server -over HTTP. 
+This module provides clients for connecting to a __ENV_TITLE_NAME__ Environment server: +- __ENV_CLASS_NAME__Env: HTTP client for request/response interactions +- __ENV_CLASS_NAME__EnvWS: WebSocket client for persistent sessions """ from typing import Any, Dict @@ -16,6 +17,7 @@ from openenv.core.client_types import StepResult from openenv.core.env_server.types import State from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.ws_env_client import WebSocketEnvClient from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation @@ -98,3 +100,90 @@ def _parse_state(self, payload: Dict) -> State: episode_id=payload.get("episode_id"), step_count=payload.get("step_count", 0), ) + + +class __ENV_CLASS_NAME__EnvWS(WebSocketEnvClient[__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation]): + """ + WebSocket client for the __ENV_TITLE_NAME__ Environment. + + This client maintains a persistent WebSocket connection to the environment server, + enabling efficient multi-step interactions with lower latency than HTTP. + Each client instance has its own dedicated environment session on the server. + + Advantages over HTTP client: + - Lower latency for sequential interactions (no connection overhead per request) + - Session state is maintained server-side + - Better suited for long-running episodes + + Example: + >>> # Connect to a running server via WebSocket + >>> with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.echoed_message) + ... + ... result = client.step(__ENV_CLASS_NAME__Action(message="Hello!")) + ... print(result.observation.echoed_message) + + Example with Docker: + >>> # Automatically start container and connect via WebSocket + >>> client = __ENV_CLASS_NAME__EnvWS.from_docker_image("__ENV_NAME__-env:latest") + >>> try: + ... result = client.reset() + ... result = client.step(__ENV_CLASS_NAME__Action(message="Test")) + ... finally: + ... 
client.close() + """ + + def _step_payload(self, action: __ENV_CLASS_NAME__Action) -> Dict: + """ + Convert __ENV_CLASS_NAME__Action to JSON payload for step message. + + Args: + action: __ENV_CLASS_NAME__Action instance + + Returns: + Dictionary representation suitable for JSON encoding + """ + return { + "message": action.message, + } + + def _parse_result(self, payload: Dict) -> StepResult[__ENV_CLASS_NAME__Observation]: + """ + Parse WebSocket response into StepResult[__ENV_CLASS_NAME__Observation]. + + Args: + payload: JSON response data from server + + Returns: + StepResult with __ENV_CLASS_NAME__Observation + """ + obs_data = payload.get("observation", {}) + observation = __ENV_CLASS_NAME__Observation( + echoed_message=obs_data.get("echoed_message", ""), + message_length=obs_data.get("message_length", 0), + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict) -> State: + """ + Parse WebSocket state response into State object. 
+ + Args: + payload: JSON response from state request + + Returns: + State object with episode_id and step_count + """ + return State( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + ) From 26b1148eab604a566c00b29a651a2a0a7bed2fb5 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 14:46:22 +0100 Subject: [PATCH 09/27] expose ws classes --- src/openenv/cli/templates/openenv_env/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/__init__.py b/src/openenv/cli/templates/openenv_env/__init__.py index 656800a55..aed293ba8 100644 --- a/src/openenv/cli/templates/openenv_env/__init__.py +++ b/src/openenv/cli/templates/openenv_env/__init__.py @@ -6,8 +6,12 @@ """__ENV_TITLE_NAME__ Environment - A simple test environment for HTTP server.""" -from .client import __ENV_CLASS_NAME__Env +from .client import __ENV_CLASS_NAME__Env, __ENV_CLASS_NAME__EnvWS from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation -__all__ = ["__ENV_CLASS_NAME__Action", "__ENV_CLASS_NAME__Observation", "__ENV_CLASS_NAME__Env"] - +__all__ = [ + "__ENV_CLASS_NAME__Action", + "__ENV_CLASS_NAME__Observation", + "__ENV_CLASS_NAME__Env", + "__ENV_CLASS_NAME__EnvWS", +] From 1ddd8d8537f29c8360b255bcb0200c7a6395a0b7 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 14:46:49 +0100 Subject: [PATCH 10/27] add websocket examples to template readme --- .../cli/templates/openenv_env/README.md | 60 ++++++++++++++++++- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/README.md b/src/openenv/cli/templates/openenv_env/README.md index ef238dfb7..f6a5c0292 100644 --- a/src/openenv/cli/templates/openenv_env/README.md +++ b/src/openenv/cli/templates/openenv_env/README.md @@ -114,6 +114,7 @@ The deployed space includes: - **Web Interface** at `/web` - Interactive UI for exploring the environment - **API 
Documentation** at `/docs` - Full OpenAPI/Swagger interface - **Health Check** at `/health` - Container health monitoring +- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions ## Environment Details @@ -154,6 +155,61 @@ result = __ENV_NAME__env.step(__ENV_CLASS_NAME__Action(message="Hello!")) Note: When connecting to an existing server, `__ENV_NAME__env.close()` will NOT stop the server. +### WebSocket Client for Persistent Sessions + +For long-running episodes or when you need lower latency, use the WebSocket client: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__EnvWS + +# Connect via WebSocket (maintains persistent connection) +with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as env: + result = env.reset() + print(f"Reset: {result.observation.echoed_message}") + # Multiple steps with low latency + for msg in ["Hello", "World", "!"]: + result = env.step(__ENV_CLASS_NAME__Action(message=msg)) + print(f"Echoed: {result.observation.echoed_message}") +``` + +WebSocket advantages: +- **Lower latency**: No HTTP connection overhead per request +- **Persistent session**: Server maintains your environment state +- **Efficient for episodes**: Better for many sequential steps + +### Concurrent WebSocket Sessions + +The server supports multiple concurrent WebSocket connections. 
To enable this, +modify `server/app.py` to use factory mode: + +```python +# In server/app.py - use factory mode for concurrent sessions +app = create_app( + __ENV_CLASS_NAME__Environment, # Pass class, not instance + __ENV_CLASS_NAME__Action, + __ENV_CLASS_NAME__Observation, + max_concurrent_envs=4, # Allow 4 concurrent sessions +) +``` + +Then multiple clients can connect simultaneously: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__EnvWS +from concurrent.futures import ThreadPoolExecutor + +def run_episode(client_id: int): + with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as env: + result = env.reset() + for i in range(10): + result = env.step(__ENV_CLASS_NAME__Action(message=f"Client {client_id}, step {i}")) + return client_id, result.observation.message_length + +# Run 4 episodes concurrently +with ThreadPoolExecutor(max_workers=4) as executor: + results = list(executor.map(run_episode, range(4))) +``` + ## Development & Testing ### Direct Environment Testing @@ -189,11 +245,11 @@ __ENV_NAME__/ ├── openenv.yaml # OpenEnv manifest ├── pyproject.toml # Project metadata and dependencies ├── uv.lock # Locked dependencies (generated) -├── client.py # __ENV_CLASS_NAME__Env client implementation +├── client.py # __ENV_CLASS_NAME__Env (HTTP) and __ENV_CLASS_NAME__EnvWS (WebSocket) clients ├── models.py # Action and Observation models └── server/ ├── __init__.py # Server module exports ├── __ENV_NAME___environment.py # Core environment logic - ├── app.py # FastAPI application + ├── app.py # FastAPI application (HTTP + WebSocket endpoints) └── Dockerfile # Container image definition ``` From 7138716eef49164e612e637fff40576d850762de Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 15:23:18 +0100 Subject: [PATCH 11/27] add note to toml for github install --- src/openenv/cli/templates/openenv_env/pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/src/openenv/cli/templates/openenv_env/pyproject.toml b/src/openenv/cli/templates/openenv_env/pyproject.toml index 55b90113f..4c6b948ff 100644 --- a/src/openenv/cli/templates/openenv_env/pyproject.toml +++ b/src/openenv/cli/templates/openenv_env/pyproject.toml @@ -15,6 +15,8 @@ description = "__ENV_TITLE_NAME__ environment for OpenEnv" requires-python = ">=3.10" dependencies = [ # Core OpenEnv runtime (provides FastAPI server + HTTP client types) + # install from github + # "openenv[core] @ git+https://github.com/meta-pytorch/OpenEnv.git", "openenv[core]>=0.2.0", # Environment-specific dependencies # Add all dependencies needed for your environment here From 438f96647c63c317f55abf3992fbcd9930209a83 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Wed, 10 Dec 2025 22:34:18 +0530 Subject: [PATCH 12/27] refactor: enforce env factory usage and drop instance mode --- src/openenv/core/env_server/http_server.py | 76 ++++++-------------- src/openenv/core/env_server/web_interface.py | 21 ++++-- 2 files changed, 36 insertions(+), 61 deletions(-) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 8dd144987..fd25739b2 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -76,13 +76,9 @@ class HTTPEnvServer: >>> from core.env_server import HTTPEnvServer >>> from envs.coding_env.server import CodeExecutionEnvironment >>> - >>> # Single environment (backward compatible) - >>> env = CodeExecutionEnvironment() - >>> server = HTTPEnvServer(env) - >>> - >>> # Factory pattern for concurrent sessions + >>> # Pass environment class (factory pattern) >>> server = HTTPEnvServer( - ... env=CodeExecutionEnvironment, # Pass class, not instance + ... env=CodeExecutionEnvironment, ... max_concurrent_envs=4, ... 
) >>> @@ -94,9 +90,9 @@ class HTTPEnvServer: def __init__( self, - env: Union[Environment, Callable[[], Environment], Type[Environment]], - action_cls: Type[Action] = None, - observation_cls: Type[Observation] = None, + env: Union[Callable[[], Environment], Type[Environment]], + action_cls: Type[Action], + observation_cls: Type[Observation], max_concurrent_envs: int = 1, skip_concurrency_check: bool = False, concurrency_config: Optional[ConcurrencyConfig] = None, @@ -105,14 +101,11 @@ def __init__( Initialize HTTP server wrapper. Args: - env: The Environment instance, factory callable, or class to wrap. - - If an instance is provided, it's used directly (single-env mode) - - If a callable/class is provided, it's called to create new - environments for each WebSocket session (factory mode) + env: Environment factory (callable or class) that creates new instances. + Will be called to create a new environment for each WebSocket session. action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns - max_concurrent_envs: Maximum number of concurrent WebSocket sessions. - Only applies when env is a factory. Default is 1. + max_concurrent_envs: Maximum number of concurrent WebSocket sessions (default: 1). If concurrency_config is provided, this parameter is ignored. skip_concurrency_check: If True, skip concurrency safety validation. Use with caution for advanced users who understand @@ -125,7 +118,14 @@ def __init__( ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an environment that is not marked as CONCURRENCY_SAFE. """ - self._env_factory: Optional[Callable[[], Environment]] = None + # Validate that env is callable + if not callable(env): + raise TypeError( + f"env must be a callable (class or factory function), got {type(env)}. " + f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())." 
+ ) + + self._env_factory: Callable[[], Environment] = env # Handle concurrency configuration if concurrency_config is not None: @@ -144,24 +144,7 @@ def __init__( "OPENENV_SKIP_CONCURRENCY_CHECK", "" ).lower() in ("1", "true", "yes") - # Determine if env is an instance or factory - if isinstance(env, Environment): - # Single instance mode (backward compatible) - self.env = env - self._env_factory = None - elif callable(env): - # Factory mode - env is a class or callable - self._env_factory = env - # Create a single instance for HTTP endpoints (backward compat) - try: - self.env = env() - except Exception as e: - factory_name = getattr(env, "__name__", str(env)) - raise EnvironmentFactoryError(factory_name, e) from e - else: - raise TypeError( - f"env must be an Environment instance or callable, got {type(env)}" - ) + self.env = env() # Validate concurrency configuration self._validate_concurrency_safety() @@ -272,22 +255,7 @@ async def _create_session(self) -> tuple[str, Environment]: session_id = str(uuid.uuid4()) current_time = time.time() - if self._env_factory is None: - # Single instance mode - use shared env (limited concurrency) - if self._sessions: - raise SessionCapacityError( - active_sessions=len(self._sessions), - max_sessions=1, - message="Single instance mode: only one WebSocket session allowed", - ) - env = self.env - else: - # Factory mode - create new environment - try: - env = self._env_factory() - except Exception as e: - factory_name = getattr(self._env_factory, "__name__", str(self._env_factory)) - raise EnvironmentFactoryError(factory_name, e) from e + env = self._env_factory() self._sessions[session_id] = env @@ -827,7 +795,7 @@ async def websocket_endpoint(websocket: WebSocket): def create_app( - env: Union[Environment, Callable[[], Environment], Type[Environment]], + env: Union[Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, @@ -841,7 +809,7 @@ def 
create_app( including README integration for better user experience. Args: - env: The Environment instance, factory callable, or class to serve + env: Environment factory (callable or class) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading @@ -878,7 +846,7 @@ def create_app( def create_fastapi_app( - env: Union[Environment, Callable[[], Environment], Type[Environment]], + env: Union[Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], max_concurrent_envs: int = 1, @@ -888,7 +856,7 @@ def create_fastapi_app( Create a FastAPI application with comprehensive documentation. Args: - env: The Environment instance, factory callable, or class to serve + env: Environment factory (callable or class) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). 
diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index b370cfa53..52ce4a113 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -14,7 +14,7 @@ from __future__ import annotations import json -from typing import Any, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type, Union from datetime import datetime from fastapi import FastAPI, WebSocket, WebSocketDisconnect @@ -23,7 +23,7 @@ from .interfaces import Environment from .serialization import deserialize_action_with_preprocessing, serialize_observation -from .types import Action, Observation, State, EnvironmentMetadata +from .types import Action, Observation, State, EnvironmentMetadata, ConcurrencyConfig def load_environment_metadata( @@ -251,19 +251,23 @@ def get_state(self) -> Dict[str, Any]: def create_web_interface_app( - env: Environment, + env: Union[Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, + max_concurrent_envs: int = 1, + concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ Create a FastAPI application with web interface for the given environment. 
Args: - env: The Environment instance to serve + env: Environment factory (callable or class) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings Returns: FastAPI application instance with web interface @@ -271,13 +275,16 @@ def create_web_interface_app( from .http_server import create_fastapi_app # Create the base environment app - app = create_fastapi_app(env, action_cls, observation_cls) + app = create_fastapi_app(env, action_cls, observation_cls, max_concurrent_envs, concurrency_config) + + # Create a test instance for metadata + env_instance = env() # Load environment metadata - metadata = load_environment_metadata(env, env_name) + metadata = load_environment_metadata(env_instance, env_name) # Create web interface manager - web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) + web_manager = WebInterfaceManager(env_instance, action_cls, observation_cls, metadata) # Add web interface routes @app.get("/web", response_class=HTMLResponse) From 7319be0aa2e3a382366fcc18601fdff259c02097 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Wed, 10 Dec 2025 22:37:47 +0530 Subject: [PATCH 13/27] refactor(ws): replace WSMessage with typed BaseMessage + discriminated WSIncomingMessage --- src/openenv/core/env_server/__init__.py | 6 +++-- src/openenv/core/env_server/types.py | 32 +++++++++++++++---------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/openenv/core/env_server/__init__.py b/src/openenv/core/env_server/__init__.py index e1014540e..ed0d41278 100644 --- a/src/openenv/core/env_server/__init__.py +++ b/src/openenv/core/env_server/__init__.py @@ -21,7 +21,8 @@ State, 
SchemaResponse, HealthResponse, - WSMessage, + BaseMessage, + WSIncomingMessage, WSResetMessage, WSStepMessage, WSStateMessage, @@ -56,7 +57,8 @@ "SchemaResponse", "HealthResponse", # WebSocket message types - "WSMessage", + "BaseMessage", + "WSIncomingMessage", "WSResetMessage", "WSStepMessage", "WSStateMessage", diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 39074595f..279726f6d 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, Literal, Annotated from pydantic import BaseModel, Field, ConfigDict @@ -213,46 +213,52 @@ class HealthResponse(BaseModel): status: str = Field(description="Health status of the environment server") -class WSMessage(BaseModel): - """Base class for WebSocket messages.""" + +class BaseMessage(BaseModel): + """Base class for WebSocket messages with shared configuration.""" model_config = ConfigDict( extra="forbid", validate_assignment=True, ) - type: str = Field(description="Message type identifier") - -class WSResetMessage(WSMessage): +class WSResetMessage(BaseMessage): """WebSocket message to reset the environment.""" - type: str = Field(default="reset", description="Message type") + type: Literal["reset"] = Field(default="reset", description="Message type") data: Dict[str, Any] = Field( default_factory=dict, description="Optional reset parameters (seed, episode_id, etc.)", ) -class WSStepMessage(WSMessage): +class WSStepMessage(BaseMessage): """WebSocket message to execute a step.""" - type: str = Field(default="step", description="Message type") + type: Literal["step"] = Field(default="step", description="Message type") data: Dict[str, Any] = Field( ..., description="Action data conforming to environment's action 
schema" ) -class WSStateMessage(WSMessage): +class WSStateMessage(BaseMessage): """WebSocket message to request current state.""" - type: str = Field(default="state", description="Message type") + type: Literal["state"] = Field(default="state", description="Message type") -class WSCloseMessage(WSMessage): +class WSCloseMessage(BaseMessage): """WebSocket message to close the session.""" - type: str = Field(default="close", description="Message type") + type: Literal["close"] = Field(default="close", description="Message type") + + +# Discriminated union for incoming WebSocket messages +WSIncomingMessage = Annotated[ + WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage, + Field(discriminator="type") +] class WSObservationResponse(BaseModel): From 561f9023b73eda7bf303c65326c2468bf4562848 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Wed, 10 Dec 2025 22:38:09 +0530 Subject: [PATCH 14/27] refactor: remove redundant ConcurrencySafetyLevel --- src/openenv/core/env_server/types.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 279726f6d..3c7d18b05 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -288,18 +288,6 @@ class WSErrorResponse(BaseModel): data: Dict[str, Any] = Field(description="Error details including message and code") -class ConcurrencySafetyLevel(str): - """ - Classification of environment concurrency safety. - - Environments are classified based on their ability to safely handle - multiple concurrent sessions within a single container. 
- """ - - UNSAFE = "unsafe" - SAFE = "safe" - - class ConcurrencyConfig(BaseModel): """Configuration for concurrent environment sessions.""" From c90cca06b614c242d770b2741044a03e093b6dc2 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 20:11:08 +0100 Subject: [PATCH 15/27] update web interface --- src/openenv/core/env_server/web_interface.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index b370cfa53..d1b527f14 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -255,6 +255,8 @@ def create_web_interface_app( action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, + max_concurrent_envs: int = 1, + concurrency_config: Optional[Any] = None, ) -> FastAPI: """ Create a FastAPI application with web interface for the given environment. @@ -264,14 +266,21 @@ def create_web_interface_app( action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). + Ignored if concurrency_config is provided. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + If provided, overrides max_concurrent_envs. 
Returns: FastAPI application instance with web interface """ from .http_server import create_fastapi_app - # Create the base environment app - app = create_fastapi_app(env, action_cls, observation_cls) + # Create the base environment app with concurrency settings + app = create_fastapi_app( + env, action_cls, observation_cls, + max_concurrent_envs, concurrency_config + ) # Load environment metadata metadata = load_environment_metadata(env, env_name) From f57b36f615061184374cafab290eaedf631d4a32 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 20:18:26 +0100 Subject: [PATCH 16/27] make web interface compatible with websockets --- src/openenv/core/env_server/web_interface.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index d1b527f14..404abba35 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -262,7 +262,7 @@ def create_web_interface_app( Create a FastAPI application with web interface for the given environment. 
Args: - env: The Environment instance to serve + env: The Environment instance, factory callable, or class to serve action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading @@ -282,11 +282,22 @@ def create_web_interface_app( max_concurrent_envs, concurrency_config ) + # If env is a class/factory, instantiate it for the web interface + # (the HTTPEnvServer in create_fastapi_app handles this separately) + if isinstance(env, Environment): + env_instance = env + elif callable(env): + env_instance = env() + else: + raise TypeError( + f"env must be an Environment instance or callable, got {type(env)}" + ) + # Load environment metadata - metadata = load_environment_metadata(env, env_name) + metadata = load_environment_metadata(env_instance, env_name) # Create web interface manager - web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) + web_manager = WebInterfaceManager(env_instance, action_cls, observation_cls, metadata) # Add web interface routes @app.get("/web", response_class=HTMLResponse) From bd2a1636a1376ceccab0b38c8ae04ffee1650329 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 20:19:06 +0100 Subject: [PATCH 17/27] format --- src/openenv/core/env_server/web_interface.py | 69 +++++--------------- 1 file changed, 17 insertions(+), 52 deletions(-) diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index 404abba35..119845177 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -26,9 +26,7 @@ from .types import Action, Observation, State, EnvironmentMetadata -def load_environment_metadata( - env: Environment, env_name: Optional[str] = None -) -> EnvironmentMetadata: +def load_environment_metadata(env: Environment, env_name: Optional[str] = None) -> EnvironmentMetadata: """ Load environment 
metadata including README content. @@ -106,9 +104,7 @@ class ActionLog(BaseModel): timestamp: str = Field(description="Timestamp when action was taken") action: Dict[str, Any] = Field(description="Action that was taken") observation: Dict[str, Any] = Field(description="Observation returned from action") - reward: Optional[float] = Field( - default=None, description="Reward received from action" - ) + reward: Optional[float] = Field(default=None, description="Reward received from action") done: bool = Field(description="Whether the episode is done after this action") step_count: int = Field(description="Step count when this action was taken") @@ -120,15 +116,9 @@ class EpisodeState(BaseModel): episode_id: Optional[str] = Field(default=None, description="Current episode ID") step_count: int = Field(description="Current step count in episode") - current_observation: Optional[Dict[str, Any]] = Field( - default=None, description="Current observation" - ) - action_logs: List[ActionLog] = Field( - default_factory=list, description="List of action logs" - ) - is_reset: bool = Field( - default=True, description="Whether the episode has been reset" - ) + current_observation: Optional[Dict[str, Any]] = Field(default=None, description="Current observation") + action_logs: List[ActionLog] = Field(default_factory=list, description="List of action logs") + is_reset: bool = Field(default=True, description="Whether the episode has been reset") class WebInterfaceManager: @@ -211,9 +201,7 @@ async def reset_environment(self) -> Dict[str, Any]: async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: """Execute a step in the environment and update state.""" # Deserialize action with preprocessing for web interface special cases - action: Action = deserialize_action_with_preprocessing( - action_data, self.action_cls - ) + action: Action = deserialize_action_with_preprocessing(action_data, self.action_cls) # Execute step observation: Observation = 
self.env.step(action) @@ -277,10 +265,7 @@ def create_web_interface_app( from .http_server import create_fastapi_app # Create the base environment app with concurrency settings - app = create_fastapi_app( - env, action_cls, observation_cls, - max_concurrent_envs, concurrency_config - ) + app = create_fastapi_app(env, action_cls, observation_cls, max_concurrent_envs, concurrency_config) # If env is a class/factory, instantiate it for the web interface # (the HTTPEnvServer in create_fastapi_app handles this separately) @@ -289,9 +274,7 @@ def create_web_interface_app( elif callable(env): env_instance = env() else: - raise TypeError( - f"env must be an Environment instance or callable, got {type(env)}" - ) + raise TypeError(f"env must be an Environment instance or callable, got {type(env)}") # Load environment metadata metadata = load_environment_metadata(env_instance, env_name) @@ -348,9 +331,7 @@ async def web_state(): return app -def get_web_interface_html( - action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None -) -> str: +def get_web_interface_html(action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None) -> str: """Generate the HTML for the web interface.""" # Check if this is a chat environment by looking for tokens field @@ -1332,9 +1313,7 @@ def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: return action_fields -def _determine_input_type_from_schema( - field_info: Dict[str, Any], field_name: str -) -> str: +def _determine_input_type_from_schema(field_info: Dict[str, Any], field_name: str) -> str: """Determine the appropriate HTML input type from JSON schema info.""" schema_type = field_info.get("type") @@ -1406,15 +1385,9 @@ def _markdown_to_html(markdown: str) -> str: html_content = html.escape(markdown) # Convert headers - html_content = re.sub( - r"^# (.*?)$", r"

<h1>\1</h1>

", html_content, flags=re.MULTILINE - ) - html_content = re.sub( - r"^## (.*?)$", r"

<h2>\1</h2>

", html_content, flags=re.MULTILINE - ) - html_content = re.sub( - r"^### (.*?)$", r"

<h3>\1</h3>

", html_content, flags=re.MULTILINE - ) + html_content = re.sub(r"^# (.*?)$", r"

<h1>\1</h1>

", html_content, flags=re.MULTILINE) + html_content = re.sub(r"^## (.*?)$", r"

<h2>\1</h2>

", html_content, flags=re.MULTILINE) + html_content = re.sub(r"^### (.*?)$", r"

<h3>\1</h3>

", html_content, flags=re.MULTILINE) # Convert code blocks html_content = re.sub( @@ -1430,12 +1403,8 @@ def _markdown_to_html(markdown: str) -> str: html_content = re.sub(r"\*(.*?)\*", r"\1", html_content) # Convert lists - html_content = re.sub( - r"^- (.*?)$", r"
<li>\1</li>", html_content, flags=re.MULTILINE - ) - html_content = re.sub( - r"(<li>.*</li>)", r"<ul>\1</ul>
    ", html_content, flags=re.DOTALL - ) + html_content = re.sub(r"^- (.*?)$", r"
<li>\1</li>", html_content, flags=re.MULTILINE) + html_content = re.sub(r"(<li>.*</li>)", r"<ul>\1</ul>
    ", html_content, flags=re.DOTALL) # Convert line breaks html_content = html_content.replace("\n", "<br>
    ") @@ -1443,9 +1412,7 @@ def _markdown_to_html(markdown: str) -> str: return html_content -def _generate_action_interface( - action_fields: List[Dict[str, Any]], is_chat_env: bool -) -> str: +def _generate_action_interface(action_fields: List[Dict[str, Any]], is_chat_env: bool) -> str: """Generate either a chat interface or action form based on environment type.""" if is_chat_env: return _generate_chat_interface() @@ -1569,9 +1536,7 @@ def _generate_single_field(field: Dict[str, Any]) -> str: for choice in choices: selected = "selected" if str(choice) == str(default_value) else "" - options_html.append( - f'' - ) + options_html.append(f'') return f'''
    From 3e116f8c0526a361ea22db977ca5cb1be0b9c5b5 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 21:22:38 +0100 Subject: [PATCH 18/27] relative imports in template --- src/openenv/cli/templates/openenv_env/server/app.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/openenv/cli/templates/openenv_env/server/app.py b/src/openenv/cli/templates/openenv_env/server/app.py index 87e3db6dc..5100b1050 100644 --- a/src/openenv/cli/templates/openenv_env/server/app.py +++ b/src/openenv/cli/templates/openenv_env/server/app.py @@ -35,7 +35,8 @@ "openenv is required for the web interface. Install dependencies with '\n uv sync\n'" ) from e -from __ENV_NAME__.models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation +# Import from local models.py (PYTHONPATH includes /app/env in Docker) +from models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation from .__ENV_NAME___environment import __ENV_CLASS_NAME__Environment From 25b7cfaf26e62a6495121482983adf285c00f21a Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 21:22:55 +0100 Subject: [PATCH 19/27] use pydantic in template --- src/openenv/cli/templates/openenv_env/models.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/models.py b/src/openenv/cli/templates/openenv_env/models.py index 64010449b..4540d5a29 100644 --- a/src/openenv/cli/templates/openenv_env/models.py +++ b/src/openenv/cli/templates/openenv_env/models.py @@ -10,22 +10,20 @@ The __ENV_NAME__ environment is a simple test environment that echoes back messages. 
""" -from dataclasses import dataclass +from pydantic import Field from openenv.core.env_server.types import Action, Observation -@dataclass(kw_only=True) class __ENV_CLASS_NAME__Action(Action): """Action for the __ENV_TITLE_NAME__ environment - just a message to echo.""" - message: str + message: str = Field(..., description="Message to echo back") -@dataclass(kw_only=True) class __ENV_CLASS_NAME__Observation(Observation): """Observation from the __ENV_TITLE_NAME__ environment - the echoed message.""" - echoed_message: str - message_length: int = 0 + echoed_message: str = Field(default="", description="The echoed message") + message_length: int = Field(default=0, description="Length of the echoed message") From 8f23dc42f175bcf6d7e9c774c91590d03b7be87b Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:29:07 +0530 Subject: [PATCH 20/27] rename to session_timeout --- src/openenv/core/env_server/http_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index fd25739b2..bc2a09040 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -135,7 +135,7 @@ def __init__( # Use legacy parameters self._concurrency_config = ConcurrencyConfig( max_concurrent_envs=max_concurrent_envs, - session_timeout_seconds=None, + session_timeout=None, reject_on_capacity=True, ) self._max_concurrent_envs = max_concurrent_envs @@ -549,7 +549,7 @@ async def step(request: StepRequest) -> StepResponse: Returns information about: - **max_concurrent_envs**: Maximum number of concurrent WebSocket sessions -- **session_timeout_seconds**: Timeout for inactive sessions (None if no timeout) +- **session_timeout**: Timeout in seconds for inactive sessions (None if no timeout) - **reject_on_capacity**: Whether to reject or queue connections at capacity """, ) From 
0d56b834c142295795e1e410018892b16adb69ba Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:30:13 +0530 Subject: [PATCH 21/27] ConcurrencyConfig, ServerCapacityStatus, and SessionInfo inherit from BaseMessage --- src/openenv/core/env_server/types.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 3c7d18b05..0821437fc 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -288,21 +288,16 @@ class WSErrorResponse(BaseModel): data: Dict[str, Any] = Field(description="Error details including message and code") -class ConcurrencyConfig(BaseModel): +class ConcurrencyConfig(BaseMessage): """Configuration for concurrent environment sessions.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - max_concurrent_envs: int = Field( default=1, ge=1, le=1000, description="Maximum number of concurrent WebSocket sessions allowed", ) - session_timeout_seconds: Optional[float] = Field( + session_timeout: Optional[float] = Field( default=None, gt=0, description="Timeout in seconds for inactive sessions. 
None means no timeout.", @@ -313,14 +308,9 @@ class ConcurrencyConfig(BaseModel): ) -class ServerCapacityStatus(BaseModel): +class ServerCapacityStatus(BaseMessage): """Status of server capacity for concurrent sessions.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - active_sessions: int = Field( ge=0, description="Number of currently active sessions", @@ -349,14 +339,9 @@ def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus": ) -class SessionInfo(BaseModel): +class SessionInfo(BaseMessage): """Information about an active session.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - session_id: str = Field(description="Unique identifier for the session") created_at: float = Field(description="Unix timestamp when the session was created") last_activity_at: float = Field( From 9cd2aacbba661f971152fcb17ea892fc1040a0a1 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:47:59 +0530 Subject: [PATCH 22/27] message classes to inherit from BaseMessage for shared config --- src/openenv/core/env_server/types.py | 48 ++++++++-------------------- 1 file changed, 14 insertions(+), 34 deletions(-) diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 0821437fc..4d0cacb70 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -127,6 +127,15 @@ class StepResponse(BaseModel): done: bool = Field(default=False, description="Whether the episode has terminated") +class BaseMessage(BaseModel): + """Base class for WebSocket messages with shared configuration.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + class State(BaseModel): """Base class for environment state. 
@@ -149,27 +158,17 @@ class State(BaseModel): ) -class CodeExecResult(BaseModel): +class CodeExecResult(BaseMessage): """Result of code execution containing stdout, stderr, and exit code.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - stdout: str = Field(description="Standard output from code execution") stderr: str = Field(description="Standard error from code execution") exit_code: int = Field(description="Exit code from code execution") -class EnvironmentMetadata(BaseModel): +class EnvironmentMetadata(BaseMessage): """Metadata about an environment for documentation and UI purposes.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - name: str = Field(description="Name of the environment") description: str = Field(description="Description of what the environment does") readme_content: Optional[str] = Field( @@ -184,14 +183,9 @@ class EnvironmentMetadata(BaseModel): ) -class SchemaResponse(BaseModel): +class SchemaResponse(BaseMessage): """Response model for the combined schema endpoint.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - action: Dict[str, Any] = Field( description="JSON schema for actions accepted by this environment" ) @@ -203,26 +197,12 @@ class SchemaResponse(BaseModel): ) -class HealthResponse(BaseModel): +class HealthResponse(BaseMessage): """Response model for health check endpoint.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - status: str = Field(description="Health status of the environment server") -class BaseMessage(BaseModel): - """Base class for WebSocket messages with shared configuration.""" - - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - - class WSResetMessage(BaseMessage): """WebSocket message to reset the environment.""" @@ -257,7 +237,7 @@ class WSCloseMessage(BaseMessage): # Discriminated union for incoming WebSocket messages WSIncomingMessage = Annotated[ 
WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage, - Field(discriminator="type") + Field(discriminator="type"), ] From 77a8c832bbe68a3a2e9d2f7528bc97219c4725f0 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 11 Dec 2025 22:19:45 +0530 Subject: [PATCH 23/27] refactor: rename CONCURRENCY_SAFE to SUPPORTS_CONCURRENT_SESSIONS --- .../server/__ENV_NAME___environment.py | 2 +- src/openenv/core/env_server/exceptions.py | 6 +- src/openenv/core/env_server/http_server.py | 283 +++++++++++------- src/openenv/core/env_server/interfaces.py | 2 +- src/openenv/core/env_server/types.py | 33 +- 5 files changed, 197 insertions(+), 129 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py index 72db6472f..454ea6808 100644 --- a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py +++ b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py @@ -40,7 +40,7 @@ class __ENV_CLASS_NAME__Environment(Environment): # Set to True if your environment isolates state between instances. # When True, multiple WebSocket clients can connect simultaneously, each # getting their own environment instance (when using factory mode in app.py). - CONCURRENCY_SAFE: bool = True + SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__(self): """Initialize the __ENV_NAME__ environment.""" diff --git a/src/openenv/core/env_server/exceptions.py b/src/openenv/core/env_server/exceptions.py index 41a8235bb..a16715721 100644 --- a/src/openenv/core/env_server/exceptions.py +++ b/src/openenv/core/env_server/exceptions.py @@ -20,7 +20,7 @@ class ConcurrencyConfigurationError(OpenEnvError): Raised when an environment is misconfigured for concurrent sessions. This error is raised during server startup when max_concurrent_envs > 1 - is specified for an environment that is not marked as CONCURRENCY_SAFE. 
+ is specified for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. """ def __init__( @@ -34,10 +34,10 @@ def __init__( if message is None: message = ( - f"Environment '{environment_name}' is not marked as CONCURRENCY_SAFE. " + f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. " f"Cannot run with max_concurrent_envs={max_concurrent_envs}. " f"Either set max_concurrent_envs=1 or ensure the environment " - f"properly isolates session state and set CONCURRENCY_SAFE=True." + f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True." ) super().__init__(message) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index bc2a09040..3752bb50a 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -20,7 +20,7 @@ import os import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Type, Union, cast from fastapi import Body, FastAPI, HTTPException, WebSocket, WebSocketDisconnect, status from pydantic import ValidationError @@ -113,10 +113,10 @@ def __init__( concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. If provided, overrides max_concurrent_envs and allows configuration of session timeout and capacity behavior. - + Raises: ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an - environment that is not marked as CONCURRENCY_SAFE. + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. """ # Validate that env is callable if not callable(env): @@ -124,9 +124,9 @@ def __init__( f"env must be a callable (class or factory function), got {type(env)}. " f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())." 
) - + self._env_factory: Callable[[], Environment] = env - + # Handle concurrency configuration if concurrency_config is not None: self._concurrency_config = concurrency_config @@ -139,51 +139,63 @@ def __init__( reject_on_capacity=True, ) self._max_concurrent_envs = max_concurrent_envs - + self._skip_concurrency_check = skip_concurrency_check or os.getenv( "OPENENV_SKIP_CONCURRENCY_CHECK", "" ).lower() in ("1", "true", "yes") - + self.env = env() - + # Validate concurrency configuration self._validate_concurrency_safety() - + self.action_cls = action_cls self.observation_cls = observation_cls - + # Session management for WebSocket connections self._sessions: Dict[str, Environment] = {} self._session_executors: Dict[str, ThreadPoolExecutor] = {} self._session_info: Dict[str, SessionInfo] = {} self._session_lock = asyncio.Lock() - + # Create thread pool for running sync code in async context # This is needed for environments using sync libraries (e.g., Playwright) # Configurable via OPENENV_THREAD_POOL_SIZE (default: 32) pool_size = int(os.getenv("OPENENV_THREAD_POOL_SIZE", "32")) self._executor = ThreadPoolExecutor(max_workers=pool_size) - # Check if environment has async methods for better concurrency - self._has_step_async = hasattr(env, "step_async") and asyncio.iscoroutinefunction(env.step_async) - self._has_reset_async = hasattr(env, "reset_async") and asyncio.iscoroutinefunction(env.reset_async) + self._reset_async: Optional[Callable[..., Awaitable[Observation]]] = None + if hasattr(self.env, "reset_async"): + reset_method = getattr(self.env, "reset_async") + if asyncio.iscoroutinefunction(reset_method): + self._reset_async = cast( + Callable[..., Awaitable[Observation]], reset_method + ) + + self._step_async: Optional[Callable[..., Awaitable[Observation]]] = None + if hasattr(self.env, "step_async"): + step_method = getattr(self.env, "step_async") + if asyncio.iscoroutinefunction(step_method): + self._step_async = cast( + Callable[..., 
Awaitable[Observation]], step_method + ) def _validate_concurrency_safety(self) -> None: """ Validate that the environment supports the configured concurrency level. - + Raises: ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an - environment that is not marked as CONCURRENCY_SAFE. + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. """ if self._max_concurrent_envs <= 1: return - + if self._skip_concurrency_check: return - - is_concurrency_safe = getattr(self.env, "CONCURRENCY_SAFE", False) - + + is_concurrency_safe = getattr(self.env, "SUPPORTS_CONCURRENT_SESSIONS", False) + if not is_concurrency_safe: env_name = type(self.env).__name__ raise ConcurrencyConfigurationError( @@ -194,7 +206,7 @@ def _validate_concurrency_safety(self) -> None: def get_capacity_status(self) -> ServerCapacityStatus: """ Get the current capacity status of the server. - + Returns: ServerCapacityStatus with current session counts and availability. """ @@ -203,19 +215,28 @@ def get_capacity_status(self) -> ServerCapacityStatus: max_sessions=self._max_concurrent_envs, ) - async def _run_sync_in_thread_pool(self, func, *args, **kwargs): + async def _run_sync_in_thread_pool( + self, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: """Run a synchronous function in the thread pool executor.""" loop = asyncio.get_event_loop() return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs)) - def _get_valid_kwargs(self, sig, kwargs, skip_params=None): + def _get_valid_kwargs( + self, + sig: inspect.Signature, + kwargs: Dict[str, Any], + skip_params: Optional[set[str]] = None, + ) -> Dict[str, Any]: """Filter kwargs to only include parameters accepted by the function signature.""" if skip_params is None: skip_params = set() valid_kwargs = {} - has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + has_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) for 
k, v in kwargs.items(): if k in sig.parameters or has_kwargs: @@ -227,16 +248,16 @@ def _get_valid_kwargs(self, sig, kwargs, skip_params=None): async def _create_session(self) -> tuple[str, Environment]: """ Create a new WebSocket session with its own environment instance. - + Returns: Tuple of (session_id, environment) - + Raises: SessionCapacityError: If max concurrent sessions reached EnvironmentFactoryError: If the factory fails to create an environment """ import time - + async with self._session_lock: if len(self._sessions) >= self._max_concurrent_envs: if self._concurrency_config.reject_on_capacity: @@ -251,17 +272,16 @@ async def _create_session(self) -> tuple[str, Environment]: max_sessions=self._max_concurrent_envs, message="Session queuing not yet implemented", ) - + session_id = str(uuid.uuid4()) current_time = time.time() - + env = self._env_factory() - + self._sessions[session_id] = env - - # Create dedicated executor for this session + self._session_executors[session_id] = ThreadPoolExecutor(max_workers=1) - + # Track session metadata self._session_info[session_id] = SessionInfo( session_id=session_id, @@ -270,73 +290,74 @@ async def _create_session(self) -> tuple[str, Environment]: step_count=0, environment_type=type(env).__name__, ) - + return session_id, env - + async def _destroy_session(self, session_id: str) -> None: """ Destroy a WebSocket session and cleanup resources. 
- + Args: session_id: The session ID to destroy """ async with self._session_lock: if session_id in self._sessions: env = self._sessions.pop(session_id) - # Call close() if environment has it - if hasattr(env, 'close') and callable(env.close): + if hasattr(env, "close") and callable(getattr(env, "close")): try: - env.close() + getattr(env, "close")() except Exception: - pass # Best effort cleanup - + pass + if session_id in self._session_executors: executor = self._session_executors.pop(session_id) executor.shutdown(wait=False) - + # Remove session metadata self._session_info.pop(session_id, None) - - def _update_session_activity(self, session_id: str, increment_step: bool = False) -> None: + + def _update_session_activity( + self, session_id: str, increment_step: bool = False + ) -> None: """ Update session activity timestamp and optionally increment step count. - + Args: session_id: The session ID to update increment_step: If True, increment the step count """ import time - + if session_id in self._session_info: self._session_info[session_id].last_activity_at = time.time() if increment_step: self._session_info[session_id].step_count += 1 - + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: """ Get information about a specific session. 
- + Args: session_id: The session ID to query - + Returns: SessionInfo if the session exists, None otherwise """ return self._session_info.get(session_id) async def _run_in_session_executor( - self, session_id: str, func: Callable, *args, **kwargs - ) -> Any: + self, session_id: str, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: """Run a synchronous function in the session's thread pool executor.""" executor = self._session_executors.get(session_id, self._executor) loop = asyncio.get_event_loop() return await loop.run_in_executor(executor, lambda: func(*args, **kwargs)) - + @property def active_sessions(self) -> int: """Return the number of active WebSocket sessions.""" return len(self._sessions) - + @property def max_concurrent_envs(self) -> int: """Return the maximum number of concurrent environments.""" @@ -345,7 +366,7 @@ def max_concurrent_envs(self) -> int: @property def is_concurrency_safe(self) -> bool: """Return whether the environment is marked as concurrency safe.""" - return getattr(self.env, "CONCURRENCY_SAFE", False) + return getattr(self.env, "SUPPORTS_CONCURRENT_SESSIONS", False) @property def concurrency_config(self) -> ConcurrencyConfig: @@ -369,18 +390,18 @@ async def reset_handler( # Start with all fields from the request, including extra ones kwargs = request.model_dump(exclude_unset=True) - # Pass arguments only if environment accepts them - if self._has_reset_async: - sig = inspect.signature(self.env.reset_async) + if self._reset_async: + sig = inspect.signature(self._reset_async) else: sig = inspect.signature(self.env.reset) valid_kwargs = self._get_valid_kwargs(sig, kwargs) - # Use async method if available for better concurrency - if self._has_reset_async: - observation = await self.env.reset_async(**valid_kwargs) + if self._reset_async: + observation = await self._reset_async(**valid_kwargs) else: - observation = await self._run_sync_in_thread_pool(self.env.reset, **valid_kwargs) + observation = await 
self._run_sync_in_thread_pool( + self.env.reset, **valid_kwargs + ) return ResetResponse(**serialize_observation(observation)) # Helper function to handle step endpoint @@ -393,24 +414,26 @@ async def step_handler(request: StepRequest) -> StepResponse: action = deserialize_action(action_data, self.action_cls) except ValidationError as e: # Return HTTP 422 with detailed validation errors - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors()) + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() + ) # Handle optional parameters # Start with all fields from the request, including extra ones, but exclude 'action' kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) - # Pass arguments only if environment accepts them - if self._has_step_async: - sig = inspect.signature(self.env.step_async) + if self._step_async: + sig = inspect.signature(self._step_async) else: sig = inspect.signature(self.env.step) valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) - # Use async method if available for better concurrency - if self._has_step_async: - observation = await self.env.step_async(action, **valid_kwargs) + if self._step_async: + observation = await self._step_async(action, **valid_kwargs) else: - observation = await self._run_sync_in_thread_pool(self.env.step, action, **valid_kwargs) + observation = await self._run_sync_in_thread_pool( + self.env.step, action, **valid_kwargs + ) # Return serialized observation return StepResponse(**serialize_observation(observation)) @@ -611,38 +634,41 @@ async def get_schemas() -> SchemaResponse: async def websocket_endpoint(websocket: WebSocket): """ WebSocket endpoint for persistent environment sessions. - + Each WebSocket connection gets its own environment instance (when using factory mode) or shares the single instance (backward compatible mode). 
- + Message Protocol: - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage - Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse """ await websocket.accept() - + session_id = None session_env = None - + try: # Create session with dedicated environment session_id, session_env = await self._create_session() - + while True: # Receive message from client raw_message = await websocket.receive_text() - + try: message_dict = json.loads(raw_message) except json.JSONDecodeError as e: error_resp = WSErrorResponse( - data={"message": f"Invalid JSON: {e}", "code": "INVALID_JSON"} + data={ + "message": f"Invalid JSON: {e}", + "code": "INVALID_JSON", + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + msg_type = message_dict.get("type", "") - + try: if msg_type == "reset": # Parse and validate reset message @@ -650,105 +676,130 @@ async def websocket_endpoint(websocket: WebSocket): msg = WSResetMessage(**message_dict) except ValidationError as e: error_resp = WSErrorResponse( - data={"message": "Invalid reset message", "code": "VALIDATION_ERROR", "errors": e.errors()} + data={ + "message": "Invalid reset message", + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + # Handle reset sig = inspect.signature(session_env.reset) valid_kwargs = self._get_valid_kwargs(sig, msg.data) - + observation = await self._run_in_session_executor( session_id, session_env.reset, **valid_kwargs ) - + self._update_session_activity(session_id) - + response = WSObservationResponse( data=serialize_observation(observation) ) await websocket.send_text(response.model_dump_json()) - + elif msg_type == "step": # Parse and validate step message try: msg = WSStepMessage(**message_dict) except ValidationError as e: error_resp = WSErrorResponse( - data={"message": "Invalid step message", "code": "VALIDATION_ERROR", "errors": e.errors()} + data={ + "message": 
"Invalid step message", + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + # Deserialize action with Pydantic validation try: action = deserialize_action(msg.data, self.action_cls) except ValidationError as e: error_resp = WSErrorResponse( - data={"message": str(e), "code": "VALIDATION_ERROR", "errors": e.errors()} + data={ + "message": str(e), + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + observation = await self._run_in_session_executor( session_id, session_env.step, action ) - - self._update_session_activity(session_id, increment_step=True) - + + self._update_session_activity( + session_id, increment_step=True + ) + response = WSObservationResponse( data=serialize_observation(observation) ) await websocket.send_text(response.model_dump_json()) - + elif msg_type == "state": # Parse and validate state message try: msg = WSStateMessage(**message_dict) except ValidationError as e: error_resp = WSErrorResponse( - data={"message": "Invalid state message", "code": "VALIDATION_ERROR", "errors": e.errors()} + data={ + "message": "Invalid state message", + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + # Handle state request state = session_env.state - if hasattr(state, 'model_dump'): + if hasattr(state, "model_dump"): state_data = state.model_dump() else: state_data = dict(state) if state else {} - + response = WSStateResponse(data=state_data) await websocket.send_text(response.model_dump_json()) - + elif msg_type == "close": # Parse and validate close message try: msg = WSCloseMessage(**message_dict) except ValidationError as e: error_resp = WSErrorResponse( - data={"message": "Invalid close message", "code": "VALIDATION_ERROR", "errors": e.errors()} + data={ + "message": "Invalid close message", + "code": "VALIDATION_ERROR", + 
"errors": e.errors(), + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + # Client requested close break - + else: error_resp = WSErrorResponse( - data={"message": f"Unknown message type: {msg_type}", "code": "UNKNOWN_TYPE"} + data={ + "message": f"Unknown message type: {msg_type}", + "code": "UNKNOWN_TYPE", + } ) await websocket.send_text(error_resp.model_dump_json()) - + except Exception as e: error_resp = WSErrorResponse( data={"message": str(e), "code": "EXECUTION_ERROR"} ) await websocket.send_text(error_resp.model_dump_json()) - + except WebSocketDisconnect: pass except SessionCapacityError as e: @@ -834,14 +885,17 @@ def create_app( from .web_interface import create_web_interface_app return create_web_interface_app( - env, action_cls, observation_cls, env_name, - max_concurrent_envs, concurrency_config + env, + action_cls, + observation_cls, + env_name, + max_concurrent_envs, + concurrency_config, ) else: # Use standard FastAPI app without web interface return create_fastapi_app( - env, action_cls, observation_cls, - max_concurrent_envs, concurrency_config + env, action_cls, observation_cls, max_concurrent_envs, concurrency_config ) @@ -854,7 +908,7 @@ def create_fastapi_app( ) -> FastAPI: """ Create a FastAPI application with comprehensive documentation. - + Args: env: Environment factory (callable or class) that creates new instances action_cls: The Action subclass this environment expects @@ -863,14 +917,16 @@ def create_fastapi_app( Ignored if concurrency_config is provided. concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. If provided, overrides max_concurrent_envs. - + Returns: FastAPI application instance """ try: from fastapi import FastAPI except ImportError: - raise ImportError("FastAPI is required. Install with: pip install fastapi uvicorn") + raise ImportError( + "FastAPI is required. 
Install with: pip install fastapi uvicorn" + ) app = FastAPI( title="OpenEnv Environment HTTP API", @@ -933,8 +989,11 @@ def create_fastapi_app( ) server = HTTPEnvServer( - env, action_cls, observation_cls, - max_concurrent_envs, concurrency_config=concurrency_config + env, + action_cls, + observation_cls, + max_concurrent_envs, + concurrency_config=concurrency_config, ) server.register_routes(app) return app diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index 196e7ac82..f147589d3 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -104,7 +104,7 @@ class Environment(ABC): """ # Class-level flag indicating whether this environment supports concurrent sessions - CONCURRENCY_SAFE: bool = False + SUPPORTS_CONCURRENT_SESSIONS: bool = False def __init__(self, transform: Transform | None = None): self.transform = transform diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 4d0cacb70..8993d280c 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Any, Dict, Optional, Union, Literal, Annotated -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, Field, ConfigDict, model_validator # Type aliases @@ -299,23 +299,32 @@ class ServerCapacityStatus(BaseMessage): ge=1, description="Maximum number of allowed sessions", ) - available_slots: int = Field( - ge=0, - description="Number of available session slots", - ) - is_at_capacity: bool = Field( - description="Whether the server has reached maximum capacity", - ) + + @model_validator(mode="after") + def check_capacity_bounds(self) -> "ServerCapacityStatus": + if self.active_sessions > self.max_sessions: + raise ValueError( + f"active_sessions ({self.active_sessions}) cannot exceed " + f"max_sessions ({self.max_sessions})" + ) + return self + + @property + def available_slots(self) -> int: + """Number of available session slots.""" + return self.max_sessions - self.active_sessions + + @property + def is_at_capacity(self) -> bool: + """Whether the server has reached maximum capacity.""" + return self.available_slots == 0 @classmethod def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus": """Create status from active and max session counts.""" - available = max(0, max_sessions - active) return cls( active_sessions=active, max_sessions=max_sessions, - available_slots=available, - is_at_capacity=active >= max_sessions, ) @@ -333,5 +342,5 @@ class SessionInfo(BaseMessage): description="Number of steps executed in this session", ) environment_type: str = Field( - description="Type name of the environment class for this session" + description="Environment type for this session (e.g. 
`CodingEnv`)" ) From 86a222da891403f4088be524952b803c9be64c7b Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Fri, 12 Dec 2025 21:37:08 +0530 Subject: [PATCH 24/27] refactor: core types for better inference and fix async detection - Genericize `Environment`, `HTTPEnvClient`, and `WebSocketEnvClient` with `ActT`, `ObsT`, and `StateT` to improve type inference in IDEs. - Update client methods to use `Dict[str, Any]` for stricter typing of JSON payloads. - Remove conditional `websockets` import in `ws_env_client.py` and simplify connection logic. - Fix async method detection in `HTTPEnvServer` to correctly handle factory functions and avoid unnecessary instantiation during server initialization.
--- src/openenv/core/client_types.py | 5 +- src/openenv/core/env_server/http_server.py | 236 +++++++++++---------- src/openenv/core/env_server/interfaces.py | 56 ++++- src/openenv/core/http_env_client.py | 12 +- src/openenv/core/ws_env_client.py | 25 +-- 5 files changed, 186 insertions(+), 148 deletions(-) diff --git a/src/openenv/core/client_types.py b/src/openenv/core/client_types.py index 8808e96bf..c7501c656 100644 --- a/src/openenv/core/client_types.py +++ b/src/openenv/core/client_types.py @@ -1,9 +1,10 @@ # Type definitions for EnvTorch from dataclasses import dataclass -from typing import Any, Generic, Optional, TypeVar +from typing import Generic, Optional, TypeVar # Generic type for observations -ObsT = TypeVar("ObsT") # TypeVar for typehinting in IDEs +ObsT = TypeVar("ObsT") +StateT = TypeVar("StateT") @dataclass diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 3752bb50a..56b73b3fa 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -20,7 +20,7 @@ import os import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Any, Awaitable, Callable, Dict, Optional, Type, Union, cast +from typing import Any, Callable, Dict, Optional, Type, Union from fastapi import Body, FastAPI, HTTPException, WebSocket, WebSocketDisconnect, status from pydantic import ValidationError @@ -75,10 +75,13 @@ class HTTPEnvServer: Example: >>> from core.env_server import HTTPEnvServer >>> from envs.coding_env.server import CodeExecutionEnvironment + >>> from envs.coding_env.models import CodeAction, CodeObservation >>> >>> # Pass environment class (factory pattern) >>> server = HTTPEnvServer( ... env=CodeExecutionEnvironment, + ... action_cls=CodeAction, + ... observation_cls=CodeObservation, ... max_concurrent_envs=4, ...
) >>> @@ -144,8 +147,6 @@ def __init__( "OPENENV_SKIP_CONCURRENCY_CHECK", "" ).lower() in ("1", "true", "yes") - self.env = env() - # Validate concurrency configuration self._validate_concurrency_safety() @@ -164,22 +165,6 @@ def __init__( pool_size = int(os.getenv("OPENENV_THREAD_POOL_SIZE", "32")) self._executor = ThreadPoolExecutor(max_workers=pool_size) - self._reset_async: Optional[Callable[..., Awaitable[Observation]]] = None - if hasattr(self.env, "reset_async"): - reset_method = getattr(self.env, "reset_async") - if asyncio.iscoroutinefunction(reset_method): - self._reset_async = cast( - Callable[..., Awaitable[Observation]], reset_method - ) - - self._step_async: Optional[Callable[..., Awaitable[Observation]]] = None - if hasattr(self.env, "step_async"): - step_method = getattr(self.env, "step_async") - if asyncio.iscoroutinefunction(step_method): - self._step_async = cast( - Callable[..., Awaitable[Observation]], step_method - ) - def _validate_concurrency_safety(self) -> None: """ Validate that the environment supports the configured concurrency level. 
@@ -194,10 +179,17 @@ def _validate_concurrency_safety(self) -> None: if self._skip_concurrency_check: return - is_concurrency_safe = getattr(self.env, "SUPPORTS_CONCURRENT_SESSIONS", False) + if inspect.isclass(self._env_factory): + is_concurrency_safe = getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) + env_name = self._env_factory.__name__ + else: + _temp_env = self._env_factory() + is_concurrency_safe = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False) + env_name = type(_temp_env).__name__ + _temp_env.close() + del _temp_env if not is_concurrency_safe: - env_name = type(self.env).__name__ raise ConcurrencyConfigurationError( environment_name=env_name, max_concurrent_envs=self._max_concurrent_envs, @@ -303,17 +295,12 @@ async def _destroy_session(self, session_id: str) -> None: async with self._session_lock: if session_id in self._sessions: env = self._sessions.pop(session_id) - if hasattr(env, "close") and callable(getattr(env, "close")): - try: - getattr(env, "close")() - except Exception: - pass + env.close() if session_id in self._session_executors: executor = self._session_executors.pop(session_id) executor.shutdown(wait=False) - # Remove session metadata self._session_info.pop(session_id, None) def _update_session_activity( @@ -366,7 +353,15 @@ def max_concurrent_envs(self) -> int: @property def is_concurrency_safe(self) -> bool: """Return whether the environment is marked as concurrency safe.""" - return getattr(self.env, "SUPPORTS_CONCURRENT_SESSIONS", False) + import inspect + if inspect.isclass(self._env_factory): + return getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) + else: + _temp_env = self._env_factory() + result = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False) + _temp_env.close() + del _temp_env + return result @property def concurrency_config(self) -> ConcurrencyConfig: @@ -386,57 +381,64 @@ async def reset_handler( request: ResetRequest = Body(default_factory=ResetRequest), ) -> 
ResetResponse: """Reset endpoint - returns initial observation.""" - # Handle optional parameters - # Start with all fields from the request, including extra ones - kwargs = request.model_dump(exclude_unset=True) - - if self._reset_async: - sig = inspect.signature(self._reset_async) - else: - sig = inspect.signature(self.env.reset) - valid_kwargs = self._get_valid_kwargs(sig, kwargs) - - if self._reset_async: - observation = await self._reset_async(**valid_kwargs) - else: - observation = await self._run_sync_in_thread_pool( - self.env.reset, **valid_kwargs - ) - return ResetResponse(**serialize_observation(observation)) + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True) + + is_async = _env.reset_async.__func__ is not Environment.reset_async + + if is_async: + sig = inspect.signature(_env.reset_async) + else: + sig = inspect.signature(_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, kwargs) + + if is_async: + observation = await _env.reset_async(**valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.reset, **valid_kwargs + ) + return ResetResponse(**serialize_observation(observation)) + finally: + _env.close() # Helper function to handle step endpoint async def step_handler(request: StepRequest) -> StepResponse: """Step endpoint - executes action and returns observation.""" action_data = request.action - # Deserialize action with Pydantic validation try: action = deserialize_action(action_data, self.action_cls) except ValidationError as e: - # Return HTTP 422 with detailed validation errors raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() ) - # Handle optional parameters - # Start with all fields from the request, including extra ones, but exclude 'action' - kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) - - if self._step_async: - sig = inspect.signature(self._step_async) - else: - sig = inspect.signature(self.env.step) - 
valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) - - if self._step_async: - observation = await self._step_async(action, **valid_kwargs) - else: - observation = await self._run_sync_in_thread_pool( - self.env.step, action, **valid_kwargs - ) + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) + + is_async = _env.step_async.__func__ is not Environment.step_async + + if is_async: + sig = inspect.signature(_env.step_async) + else: + sig = inspect.signature(_env.step) + valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) + + if is_async: + observation = await _env.step_async(action, **valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.step, action, **valid_kwargs + ) - # Return serialized observation - return StepResponse(**serialize_observation(observation)) + return StepResponse(**serialize_observation(observation)) + finally: + _env.close() # Register routes using the helpers @app.post( @@ -522,24 +524,36 @@ async def reset( async def step(request: StepRequest) -> StepResponse: return await step_handler(request) - # Configure and register GET endpoints declaratively + def get_state_handler() -> State: + _env = self._env_factory() + try: + return _env.state + finally: + _env.close() + + def get_metadata_handler() -> EnvironmentMetadata: + _env = self._env_factory() + try: + return _env.get_metadata() + finally: + _env.close() + get_endpoints = [ GetEndpointConfig( path="/state", - handler=lambda: self.env.state, + handler=get_state_handler, response_model=State, tag="State Management", summary="Get current environment state", description=""" Retrieve the current internal state of the environment. -This endpoint allows inspection of the environment state without modifying it. The structure of the state object is defined by the environment's State model. 
""", ), GetEndpointConfig( path="/metadata", - handler=self.env.get_metadata, + handler=get_metadata_handler, response_model=EnvironmentMetadata, tag="Environment Info", summary="Get environment metadata", @@ -686,12 +700,18 @@ async def websocket_endpoint(websocket: WebSocket): continue # Handle reset - sig = inspect.signature(session_env.reset) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) + is_async = session_env.reset_async.__func__ is not Environment.reset_async - observation = await self._run_in_session_executor( - session_id, session_env.reset, **valid_kwargs - ) + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await session_env.reset_async(**valid_kwargs) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await self._run_in_session_executor( + session_id, session_env.reset, **valid_kwargs + ) self._update_session_activity(session_id) @@ -729,9 +749,14 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_text(error_resp.model_dump_json()) continue - observation = await self._run_in_session_executor( - session_id, session_env.step, action - ) + is_async = session_env.step_async.__func__ is not Environment.step_async + + if is_async: + observation = await session_env.step_async(action) + else: + observation = await self._run_in_session_executor( + session_id, session_env.step, action + ) self._update_session_activity( session_id, increment_step=True @@ -803,46 +828,33 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass except SessionCapacityError as e: - try: - error_resp = WSErrorResponse( - data={ - "message": str(e), - "code": "CAPACITY_REACHED", - "active_sessions": e.active_sessions, - "max_sessions": e.max_sessions, - } - ) - await websocket.send_text(error_resp.model_dump_json()) - except Exception: - pass + error_resp = 
WSErrorResponse( + data={ + "message": str(e), + "code": "CAPACITY_REACHED", + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + } + ) + await websocket.send_text(error_resp.model_dump_json()) except EnvironmentFactoryError as e: - try: - error_resp = WSErrorResponse( - data={ - "message": str(e), - "code": "FACTORY_ERROR", - "factory_name": e.factory_name, - } - ) - await websocket.send_text(error_resp.model_dump_json()) - except Exception: - pass + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": "FACTORY_ERROR", + "factory_name": e.factory_name, + } + ) + await websocket.send_text(error_resp.model_dump_json()) except Exception as e: - try: - error_resp = WSErrorResponse( - data={"message": str(e), "code": "SESSION_ERROR"} - ) - await websocket.send_text(error_resp.model_dump_json()) - except Exception: - pass + error_resp = WSErrorResponse( + data={"message": str(e), "code": "SESSION_ERROR"} + ) + await websocket.send_text(error_resp.model_dump_json()) finally: - # Cleanup session if session_id: await self._destroy_session(session_id) - try: - await websocket.close() - except Exception: - pass + await websocket.close() def create_app( diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index f147589d3..03f1ddb21 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -5,10 +5,14 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Optional, Protocol, TypedDict +from typing import Any, Generic, Optional, Protocol, TypedDict, TypeVar from .types import Action, Observation, State, EnvironmentMetadata +ActT = TypeVar("ActT", bound=Action) +ObsT = TypeVar("ObsT", bound=Observation) +StateT = TypeVar("StateT", bound=State) + class Message(TypedDict): """A message in a conversation. @@ -64,7 +68,7 @@ def decode( ... 
-class Transform(ABC): +class Transform(ABC, Generic[ObsT]): """Transform observations to add rewards, metrics, or other modifications. Transforms follow the TorchRL pattern where they take an observation @@ -73,7 +77,7 @@ class Transform(ABC): """ @abstractmethod - def __call__(self, observation: Observation) -> Observation: + def __call__(self, observation: ObsT) -> ObsT: """Transform an observation. Args: @@ -85,7 +89,7 @@ def __call__(self, observation: Observation) -> Observation: pass -class Environment(ABC): +class Environment(ABC, Generic[ActT, ObsT, StateT]): """Base class for all environment servers following Gym/Gymnasium API. Args: @@ -106,7 +110,7 @@ class Environment(ABC): # Class-level flag indicating whether this environment supports concurrent sessions SUPPORTS_CONCURRENT_SESSIONS: bool = False - def __init__(self, transform: Transform | None = None): + def __init__(self, transform: Optional[Transform[ObsT]] = None): self.transform = transform @abstractmethod @@ -115,23 +119,47 @@ def reset( seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any, - ) -> Observation: + ) -> ObsT: """Reset the environment and return initial observation.""" pass + async def reset_async( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of reset. Default implementation calls sync reset. + + Override to provide true async implementation. + """ + return self.reset(seed=seed, episode_id=episode_id, **kwargs) + @abstractmethod def step( self, - action: Action, + action: ActT, timeout_s: Optional[float] = None, **kwargs: Any, - ) -> Observation: + ) -> ObsT: """Take a step in the environment.""" pass + async def step_async( + self, + action: ActT, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of step. Default implementation calls sync step. + + Override to provide true async implementation. 
+ """ + return self.step(action, timeout_s=timeout_s, **kwargs) + @property @abstractmethod - def state(self) -> State: + def state(self) -> StateT: """Get the current environment state.""" pass @@ -151,8 +179,16 @@ def get_metadata(self) -> EnvironmentMetadata: version="1.0.0", ) - def _apply_transform(self, observation: Observation) -> Observation: + def _apply_transform(self, observation: ObsT) -> ObsT: """Apply transform if one is provided.""" if self.transform is not None: return self.transform(observation) return observation + + def close(self) -> None: + """Clean up resources used by the environment. + + Override this method to implement custom cleanup logic. + Called when the environment is being destroyed or reset. + """ + pass diff --git a/src/openenv/core/http_env_client.py b/src/openenv/core/http_env_client.py index 007ef6a5f..0f25363d4 100644 --- a/src/openenv/core/http_env_client.py +++ b/src/openenv/core/http_env_client.py @@ -16,7 +16,7 @@ import requests -from .client_types import StepResult +from .client_types import StepResult, StateT from .containers.runtime import LocalDockerProvider if TYPE_CHECKING: @@ -27,7 +27,7 @@ EnvClientT = TypeVar("EnvClientT", bound="HTTPEnvClient") -class HTTPEnvClient(ABC, Generic[ActT, ObsT]): +class HTTPEnvClient(ABC, Generic[ActT, ObsT, StateT]): def __init__( self, base_url: str, @@ -129,17 +129,17 @@ def from_hub( return cls.from_docker_image(image=base_url, provider=provider) @abstractmethod - def _step_payload(self, action: ActT) -> dict: + def _step_payload(self, action: ActT) -> Dict[str, Any]: """Convert an Action object to the JSON body expected by the env server.""" raise NotImplementedError @abstractmethod - def _parse_result(self, payload: dict) -> StepResult[ObsT]: + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: """Convert a JSON response from the env server to StepResult[ObsT].""" raise NotImplementedError @abstractmethod - def _parse_state(self, payload: dict) -> Any: + def 
_parse_state(self, payload: Dict[str, Any]) -> StateT: """Convert a JSON response from the state endpoint to a State object.""" raise NotImplementedError @@ -203,7 +203,7 @@ def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: r.raise_for_status() return self._parse_result(r.json()) - def state(self) -> Any: + def state(self) -> StateT: """ Get the current environment state from the server. diff --git a/src/openenv/core/ws_env_client.py b/src/openenv/core/ws_env_client.py index c6f054e85..6c1d6a4ab 100644 --- a/src/openenv/core/ws_env_client.py +++ b/src/openenv/core/ws_env_client.py @@ -18,26 +18,21 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar -from .client_types import StepResult +from .client_types import StepResult, StateT from .containers.runtime import LocalDockerProvider if TYPE_CHECKING: from .containers.runtime import ContainerProvider from websockets.sync.client import ClientConnection -try: - import websockets - from websockets.sync.client import connect as ws_connect -except ImportError: - websockets = None # type: ignore - ws_connect = None # type: ignore +from websockets.sync.client import connect as ws_connect ActT = TypeVar("ActT") ObsT = TypeVar("ObsT") WSEnvClientT = TypeVar("WSEnvClientT", bound="WebSocketEnvClient") -class WebSocketEnvClient(ABC, Generic[ActT, ObsT]): +class WebSocketEnvClient(ABC, Generic[ActT, ObsT, StateT]): """ WebSocket-based environment client for persistent sessions. @@ -78,12 +73,6 @@ def __init__( message_timeout_s: Timeout for receiving responses to messages provider: Optional container provider for lifecycle management """ - if websockets is None: - raise ImportError( - "websockets library is required for WebSocketEnvClient. 
" - "Install with: pip install websockets" - ) - # Convert HTTP URL to WebSocket URL ws_url = base_url.rstrip("/") if ws_url.startswith("http://"): @@ -220,17 +209,17 @@ def from_hub( return cls.from_docker_image(image=base_url, provider=provider, **kwargs) @abstractmethod - def _step_payload(self, action: ActT) -> dict: + def _step_payload(self, action: ActT) -> Dict[str, Any]: """Convert an Action object to the JSON data expected by the env server.""" raise NotImplementedError @abstractmethod - def _parse_result(self, payload: dict) -> StepResult[ObsT]: + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: """Convert a JSON response from the env server to StepResult[ObsT].""" raise NotImplementedError @abstractmethod - def _parse_state(self, payload: dict) -> Any: + def _parse_state(self, payload: Dict[str, Any]) -> StateT: """Convert a JSON response from the state endpoint to a State object.""" raise NotImplementedError @@ -272,7 +261,7 @@ def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: response = self._send_and_receive(message) return self._parse_result(response.get("data", {})) - def state(self) -> Any: + def state(self) -> StateT: """ Get the current environment state from the server. 
From e95f8b14b9e61100cba7722cd9a984dd7bb72e80 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Fri, 12 Dec 2025 22:20:39 +0530 Subject: [PATCH 25/27] fix: concurrency handling and improve exception messages --- src/openenv/core/__init__.py | 7 ++- src/openenv/core/env_server/exceptions.py | 5 +- src/openenv/core/env_server/http_server.py | 61 +++++++++++--------- src/openenv/core/env_server/interfaces.py | 2 +- src/openenv/core/env_server/types.py | 1 - src/openenv/core/env_server/web_interface.py | 8 ++- 6 files changed, 45 insertions(+), 39 deletions(-) diff --git a/src/openenv/core/__init__.py b/src/openenv/core/__init__.py index 93ae09786..e9bbf2365 100644 --- a/src/openenv/core/__init__.py +++ b/src/openenv/core/__init__.py @@ -8,9 +8,10 @@ # Re-export main components from submodules for convenience from .env_server import * # noqa: F403 -from .env_server import __all__ as _env_server_all - +from . import env_server +from .ws_env_client import WebSocketEnvClient +from .http_env_client import HTTPEnvClient # Note: MCP module doesn't export anything yet -__all__ = list(_env_server_all) \ No newline at end of file +__all__ = ["WebSocketEnvClient", "HTTPEnvClient"] + env_server.__all__ # type: ignore \ No newline at end of file diff --git a/src/openenv/core/env_server/exceptions.py b/src/openenv/core/env_server/exceptions.py index a16715721..23fed6567 100644 --- a/src/openenv/core/env_server/exceptions.py +++ b/src/openenv/core/env_server/exceptions.py @@ -96,10 +96,9 @@ def __init__(self, reason: str, message: Optional[str] = None): class EnvironmentFactoryError(OpenEnvError): """Raised when the environment factory fails to create an instance.""" - def __init__(self, factory_name: str, cause: Exception): + def __init__(self, factory_name: str): self.factory_name = factory_name - self.cause = cause - message = f"Environment factory '{factory_name}' failed to create instance: {cause}" + message = f"Environment factory 
'{factory_name}' failed to create instance." super().__init__(message) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 56b73b3fa..604600f79 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -96,8 +96,7 @@ def __init__( env: Union[Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], - max_concurrent_envs: int = 1, - skip_concurrency_check: bool = False, + max_concurrent_envs: Optional[int] = None, concurrency_config: Optional[ConcurrencyConfig] = None, ): """ @@ -108,16 +107,13 @@ def __init__( Will be called to create a new environment for each WebSocket session. action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns - max_concurrent_envs: Maximum number of concurrent WebSocket sessions (default: 1). - If concurrency_config is provided, this parameter is ignored. - skip_concurrency_check: If True, skip concurrency safety validation. - Use with caution for advanced users who understand - the isolation requirements. + max_concurrent_envs: Maximum number of concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. - If provided, overrides max_concurrent_envs and allows - configuration of session timeout and capacity behavior. + Mutually exclusive with max_concurrent_envs. Raises: + ValueError: If both max_concurrent_envs and concurrency_config are provided. ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. 
""" @@ -131,21 +127,29 @@ def __init__( self._env_factory: Callable[[], Environment] = env # Handle concurrency configuration + if max_concurrent_envs is not None and concurrency_config is not None: + raise ValueError( + "Cannot specify both 'max_concurrent_envs' and 'concurrency_config'. " + "Please use only one method to configure concurrency." + ) + if concurrency_config is not None: self._concurrency_config = concurrency_config - self._max_concurrent_envs = concurrency_config.max_concurrent_envs - else: - # Use legacy parameters + elif max_concurrent_envs is not None: self._concurrency_config = ConcurrencyConfig( max_concurrent_envs=max_concurrent_envs, session_timeout=None, reject_on_capacity=True, ) - self._max_concurrent_envs = max_concurrent_envs + else: + # Default configuration + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=1, + session_timeout=None, + reject_on_capacity=True, + ) - self._skip_concurrency_check = skip_concurrency_check or os.getenv( - "OPENENV_SKIP_CONCURRENCY_CHECK", "" - ).lower() in ("1", "true", "yes") + self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs # Validate concurrency configuration self._validate_concurrency_safety() @@ -176,9 +180,6 @@ def _validate_concurrency_safety(self) -> None: if self._max_concurrent_envs <= 1: return - if self._skip_concurrency_check: - return - if inspect.isclass(self._env_factory): is_concurrency_safe = getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) env_name = self._env_factory.__name__ @@ -268,7 +269,11 @@ async def _create_session(self) -> tuple[str, Environment]: session_id = str(uuid.uuid4()) current_time = time.time() - env = self._env_factory() + try: + env = self._env_factory() + except Exception as e: + factory_name = getattr(self._env_factory, "__name__", str(self._env_factory)) + raise EnvironmentFactoryError(factory_name) from e self._sessions[session_id] = env @@ -862,7 +867,7 @@ def create_app( action_cls: Type[Action], 
observation_cls: Type[Observation], env_name: Optional[str] = None, - max_concurrent_envs: int = 1, + max_concurrent_envs: Optional[int] = None, concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ @@ -876,10 +881,10 @@ def create_app( action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading - max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). - Ignored if concurrency_config is provided. + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. - If provided, overrides max_concurrent_envs. + Mutually exclusive with max_concurrent_envs. Returns: FastAPI application instance with or without web interface and README integration @@ -915,7 +920,7 @@ def create_fastapi_app( env: Union[Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], - max_concurrent_envs: int = 1, + max_concurrent_envs: Optional[int] = None, concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ @@ -925,10 +930,10 @@ def create_fastapi_app( env: Environment factory (callable or class) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns - max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). - Ignored if concurrency_config is provided. + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. - If provided, overrides max_concurrent_envs. + Mutually exclusive with max_concurrent_envs. 
Returns: FastAPI application instance diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index 03f1ddb21..c02ba4a05 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -96,7 +96,7 @@ class Environment(ABC, Generic[ActT, ObsT, StateT]): transform: Optional transform to apply to observations Class Attributes: - CONCURRENCY_SAFE: Whether this environment supports concurrent sessions. + SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions. When True, multiple WebSocket connections can each have their own environment instance (up to max_concurrent_envs). When False (default), the environment should only be used with a single session at a time. diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 8993d280c..273994479 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -274,7 +274,6 @@ class ConcurrencyConfig(BaseMessage): max_concurrent_envs: int = Field( default=1, ge=1, - le=1000, description="Maximum number of concurrent WebSocket sessions allowed", ) session_timeout: Optional[float] = Field( diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index be55b9146..5711d0ef0 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -239,7 +239,7 @@ def create_web_interface_app( action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, - max_concurrent_envs: int = 1, + max_concurrent_envs: Optional[int] = None, concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ @@ -250,8 +250,10 @@ def create_web_interface_app( action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading - 
max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) - concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. Returns: FastAPI application instance with web interface From 05e6da08dc6276a603db925652ae9c78d718fe91 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Sat, 13 Dec 2025 23:10:52 +0530 Subject: [PATCH 26/27] chore: clean up exception handling and remove unused concurrency config field --- src/openenv/core/env_server/exceptions.py | 13 +- src/openenv/core/env_server/http_server.py | 237 +++++++-------------- src/openenv/core/env_server/types.py | 4 - 3 files changed, 80 insertions(+), 174 deletions(-) diff --git a/src/openenv/core/env_server/exceptions.py b/src/openenv/core/env_server/exceptions.py index 23fed6567..4fb4a6ec8 100644 --- a/src/openenv/core/env_server/exceptions.py +++ b/src/openenv/core/env_server/exceptions.py @@ -31,7 +31,7 @@ def __init__( ): self.environment_name = environment_name self.max_concurrent_envs = max_concurrent_envs - + if message is None: message = ( f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. " @@ -39,7 +39,7 @@ def __init__( f"Either set max_concurrent_envs=1 or ensure the environment " f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True." 
) - + super().__init__(message) @@ -96,9 +96,10 @@ def __init__(self, reason: str, message: Optional[str] = None): class EnvironmentFactoryError(OpenEnvError): """Raised when the environment factory fails to create an instance.""" - def __init__(self, factory_name: str): + def __init__(self, factory_name: str, message: Optional[str] = None): self.factory_name = factory_name - - message = f"Environment factory '{factory_name}' failed to create instance." - + + if message is None: + message = f"Environment factory '{factory_name}' failed to create instance." + super().__init__(message) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 604600f79..1b1797cc7 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -18,6 +18,7 @@ import inspect import json import os +import time import uuid from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Dict, Optional, Type, Union @@ -139,14 +140,12 @@ def __init__( self._concurrency_config = ConcurrencyConfig( max_concurrent_envs=max_concurrent_envs, session_timeout=None, - reject_on_capacity=True, ) else: # Default configuration self._concurrency_config = ConcurrencyConfig( max_concurrent_envs=1, session_timeout=None, - reject_on_capacity=True, ) self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs @@ -165,9 +164,7 @@ def __init__( # Create thread pool for running sync code in async context # This is needed for environments using sync libraries (e.g., Playwright) - # Configurable via OPENENV_THREAD_POOL_SIZE (default: 32) - pool_size = int(os.getenv("OPENENV_THREAD_POOL_SIZE", "32")) - self._executor = ThreadPoolExecutor(max_workers=pool_size) + self._executor = ThreadPoolExecutor(max_workers=32) def _validate_concurrency_safety(self) -> None: """ @@ -181,18 +178,16 @@ def _validate_concurrency_safety(self) -> None: return if inspect.isclass(self._env_factory): - 
is_concurrency_safe = getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) - env_name = self._env_factory.__name__ + env_cls = self._env_factory else: _temp_env = self._env_factory() - is_concurrency_safe = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False) - env_name = type(_temp_env).__name__ + env_cls = type(_temp_env) _temp_env.close() del _temp_env - if not is_concurrency_safe: + if not getattr(env_cls, "SUPPORTS_CONCURRENT_SESSIONS", False): raise ConcurrencyConfigurationError( - environment_name=env_name, + environment_name=env_cls.__name__, max_concurrent_envs=self._max_concurrent_envs, ) @@ -249,22 +244,12 @@ async def _create_session(self) -> tuple[str, Environment]: SessionCapacityError: If max concurrent sessions reached EnvironmentFactoryError: If the factory fails to create an environment """ - import time - async with self._session_lock: if len(self._sessions) >= self._max_concurrent_envs: - if self._concurrency_config.reject_on_capacity: - raise SessionCapacityError( - active_sessions=len(self._sessions), - max_sessions=self._max_concurrent_envs, - ) - else: - # TODO: Implement queuing mechanism when reject_on_capacity=False - raise SessionCapacityError( - active_sessions=len(self._sessions), - max_sessions=self._max_concurrent_envs, - message="Session queuing not yet implemented", - ) + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) session_id = str(uuid.uuid4()) current_time = time.time() @@ -318,8 +303,6 @@ def _update_session_activity( session_id: The session ID to update increment_step: If True, increment the step count """ - import time - if session_id in self._session_info: self._session_info[session_id].last_activity_at = time.time() if increment_step: @@ -580,24 +563,6 @@ def get_metadata_handler() -> EnvironmentMetadata: ] register_get_endpoints(app, get_endpoints) - # Register concurrency config endpoint - @app.get( - "/concurrency", - 
response_model=ConcurrencyConfig, - tags=["Environment Info"], - summary="Get concurrency configuration", - description=""" -Get the current concurrency configuration for this server. - -Returns information about: -- **max_concurrent_envs**: Maximum number of concurrent WebSocket sessions -- **session_timeout**: Timeout in seconds for inactive sessions (None if no timeout) -- **reject_on_capacity**: Whether to reject or queue connections at capacity - """, - ) - async def get_concurrency_config() -> ConcurrencyConfig: - """Return concurrency configuration.""" - return self._concurrency_config # Register combined schema endpoint @app.get( @@ -654,8 +619,7 @@ async def websocket_endpoint(websocket: WebSocket): """ WebSocket endpoint for persistent environment sessions. - Each WebSocket connection gets its own environment instance (when using - factory mode) or shares the single instance (backward compatible mode). + Each WebSocket connection gets its own environment instance. Message Protocol: - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage @@ -689,141 +653,83 @@ async def websocket_endpoint(websocket: WebSocket): msg_type = message_dict.get("type", "") try: - if msg_type == "reset": - # Parse and validate reset message - try: + match msg_type: + case "reset": msg = WSResetMessage(**message_dict) - except ValidationError as e: - error_resp = WSErrorResponse( - data={ - "message": "Invalid reset message", - "code": "VALIDATION_ERROR", - "errors": e.errors(), - } - ) - await websocket.send_text(error_resp.model_dump_json()) - continue - - # Handle reset - is_async = session_env.reset_async.__func__ is not Environment.reset_async - - if is_async: - sig = inspect.signature(session_env.reset_async) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) - observation = await session_env.reset_async(**valid_kwargs) - else: - sig = inspect.signature(session_env.reset) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) - observation = 
await self._run_in_session_executor( - session_id, session_env.reset, **valid_kwargs - ) - self._update_session_activity(session_id) + is_async = session_env.reset_async.__func__ is not Environment.reset_async - response = WSObservationResponse( - data=serialize_observation(observation) - ) - await websocket.send_text(response.model_dump_json()) + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await session_env.reset_async(**valid_kwargs) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await self._run_in_session_executor( + session_id, session_env.reset, **valid_kwargs + ) - elif msg_type == "step": - # Parse and validate step message - try: - msg = WSStepMessage(**message_dict) - except ValidationError as e: - error_resp = WSErrorResponse( - data={ - "message": "Invalid step message", - "code": "VALIDATION_ERROR", - "errors": e.errors(), - } + self._update_session_activity(session_id) + + response = WSObservationResponse( + data=serialize_observation(observation) ) - await websocket.send_text(error_resp.model_dump_json()) - continue - # Deserialize action with Pydantic validation - try: + case "step": + msg = WSStepMessage(**message_dict) action = deserialize_action(msg.data, self.action_cls) - except ValidationError as e: - error_resp = WSErrorResponse( - data={ - "message": str(e), - "code": "VALIDATION_ERROR", - "errors": e.errors(), - } - ) - await websocket.send_text(error_resp.model_dump_json()) - continue - is_async = session_env.step_async.__func__ is not Environment.step_async + is_async = session_env.step_async.__func__ is not Environment.step_async - if is_async: - observation = await session_env.step_async(action) - else: - observation = await self._run_in_session_executor( - session_id, session_env.step, action - ) + if is_async: + observation = await session_env.step_async(action) + else: 
+ observation = await self._run_in_session_executor( + session_id, session_env.step, action + ) - self._update_session_activity( - session_id, increment_step=True - ) + self._update_session_activity( + session_id, increment_step=True + ) - response = WSObservationResponse( - data=serialize_observation(observation) - ) - await websocket.send_text(response.model_dump_json()) + response = WSObservationResponse( + data=serialize_observation(observation) + ) - elif msg_type == "state": - # Parse and validate state message - try: + case "state": msg = WSStateMessage(**message_dict) - except ValidationError as e: - error_resp = WSErrorResponse( - data={ - "message": "Invalid state message", - "code": "VALIDATION_ERROR", - "errors": e.errors(), - } - ) - await websocket.send_text(error_resp.model_dump_json()) - continue - - # Handle state request - state = session_env.state - if hasattr(state, "model_dump"): - state_data = state.model_dump() - else: - state_data = dict(state) if state else {} - - response = WSStateResponse(data=state_data) - await websocket.send_text(response.model_dump_json()) - - elif msg_type == "close": - # Parse and validate close message - try: + state = session_env.state + if hasattr(state, "model_dump"): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = WSStateResponse(data=state_data) + + case "close": msg = WSCloseMessage(**message_dict) - except ValidationError as e: - error_resp = WSErrorResponse( + break + + case _: + response = WSErrorResponse( data={ - "message": "Invalid close message", - "code": "VALIDATION_ERROR", - "errors": e.errors(), + "message": f"Unknown message type: {msg_type}", + "code": "UNKNOWN_TYPE", } ) - await websocket.send_text(error_resp.model_dump_json()) - continue - # Client requested close - break - - else: - error_resp = WSErrorResponse( - data={ - "message": f"Unknown message type: {msg_type}", - "code": "UNKNOWN_TYPE", - } - ) - await 
websocket.send_text(error_resp.model_dump_json()) + await websocket.send_text(response.model_dump_json()) + except ValidationError as e: + error_resp = WSErrorResponse( + data={ + "message": "Invalid message", + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } + ) + await websocket.send_text(error_resp.model_dump_json()) except Exception as e: error_resp = WSErrorResponse( data={"message": str(e), "code": "EXECUTION_ERROR"} @@ -859,7 +765,10 @@ async def websocket_endpoint(websocket: WebSocket): finally: if session_id: await self._destroy_session(session_id) - await websocket.close() + try: + await websocket.close() + except RuntimeError: + pass def create_app( diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 273994479..a22914b73 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -281,10 +281,6 @@ class ConcurrencyConfig(BaseMessage): gt=0, description="Timeout in seconds for inactive sessions. None means no timeout.", ) - reject_on_capacity: bool = Field( - default=True, - description="If True, reject new connections when at capacity. 
If False, queue them.", - ) class ServerCapacityStatus(BaseMessage): From d52850f646f97292ea9435bc1748c6f3ce2ad91b Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Sun, 14 Dec 2025 23:16:09 +0530 Subject: [PATCH 27/27] refactor: simplify environment factory type annotations and add utility function for URL conversion --- src/openenv/core/env_server/http_server.py | 12 ++++----- src/openenv/core/env_server/web_interface.py | 4 +-- src/openenv/core/utils.py | 26 ++++++++++++++++++++ src/openenv/core/ws_env_client.py | 9 ++----- 4 files changed, 36 insertions(+), 15 deletions(-) create mode 100644 src/openenv/core/utils.py diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 1b1797cc7..b816b3d62 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -94,7 +94,7 @@ class HTTPEnvServer: def __init__( self, - env: Union[Callable[[], Environment], Type[Environment]], + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], max_concurrent_envs: Optional[int] = None, @@ -104,7 +104,7 @@ def __init__( Initialize HTTP server wrapper. Args: - env: Environment factory (callable or class) that creates new instances. + env: Environment factory (callable) that creates new instances. Will be called to create a new environment for each WebSocket session. action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns @@ -772,7 +772,7 @@ async def websocket_endpoint(websocket: WebSocket): def create_app( - env: Union[Callable[[], Environment], Type[Environment]], + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, @@ -786,7 +786,7 @@ def create_app( including README integration for better user experience. 
Args: - env: Environment factory (callable or class) that creates new instances + env: Environment factory (callable) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading @@ -826,7 +826,7 @@ def create_app( def create_fastapi_app( - env: Union[Callable[[], Environment], Type[Environment]], + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], max_concurrent_envs: Optional[int] = None, @@ -836,7 +836,7 @@ def create_fastapi_app( Create a FastAPI application with comprehensive documentation. Args: - env: Environment factory (callable or class) that creates new instances + env: Environment factory (callable) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns max_concurrent_envs: Maximum concurrent WebSocket sessions. diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index 5711d0ef0..703025375 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -235,7 +235,7 @@ def get_state(self) -> Dict[str, Any]: def create_web_interface_app( - env: Union[Callable[[], Environment], Type[Environment]], + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, @@ -246,7 +246,7 @@ def create_web_interface_app( Create a FastAPI application with web interface for the given environment. 
Args: - env: Environment factory (callable or class) that creates new instances + env: Environment factory (callable) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading diff --git a/src/openenv/core/utils.py b/src/openenv/core/utils.py new file mode 100644 index 000000000..42e9cee82 --- /dev/null +++ b/src/openenv/core/utils.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Utility functions for OpenEnv core.""" + +def convert_to_ws_url(url: str) -> str: + """ + Convert an HTTP/HTTPS URL to a WS/WSS URL. + + Args: + url: The URL to convert. + + Returns: + The converted WebSocket URL. + """ + ws_url = url.rstrip("/") + if ws_url.startswith("http://"): + ws_url = "ws://" + ws_url[7:] + elif ws_url.startswith("https://"): + ws_url = "wss://" + ws_url[8:] + elif not ws_url.startswith("ws://") and not ws_url.startswith("wss://"): + ws_url = "ws://" + ws_url + return ws_url diff --git a/src/openenv/core/ws_env_client.py b/src/openenv/core/ws_env_client.py index 6c1d6a4ab..efa829f64 100644 --- a/src/openenv/core/ws_env_client.py +++ b/src/openenv/core/ws_env_client.py @@ -20,6 +20,7 @@ from .client_types import StepResult, StateT from .containers.runtime import LocalDockerProvider +from .utils import convert_to_ws_url if TYPE_CHECKING: from .containers.runtime import ContainerProvider @@ -74,13 +75,7 @@ def __init__( provider: Optional container provider for lifecycle management """ # Convert HTTP URL to WebSocket URL - ws_url = base_url.rstrip("/") - if ws_url.startswith("http://"): - ws_url = "ws://" + ws_url[7:] - elif ws_url.startswith("https://"): - ws_url = "wss://" + ws_url[8:] - elif not ws_url.startswith("ws://") and not 
ws_url.startswith("wss://"): - ws_url = "ws://" + ws_url + ws_url = convert_to_ws_url(base_url) self._ws_url = f"{ws_url}/ws" self._connect_timeout = connect_timeout_s