diff --git a/pyproject.toml b/pyproject.toml index 811c068c9..edb6c1f17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ dependencies = [ "huggingface_hub>=0.20.0", "openai>=2.7.2", "tomli>=2.3.0", - "tomli-w>=1.2.0" + "tomli-w>=1.2.0", + "websockets>=15.0.1", ] [project.optional-dependencies] diff --git a/src/openenv/cli/templates/openenv_env/README.md b/src/openenv/cli/templates/openenv_env/README.md index ef238dfb7..f6a5c0292 100644 --- a/src/openenv/cli/templates/openenv_env/README.md +++ b/src/openenv/cli/templates/openenv_env/README.md @@ -114,6 +114,7 @@ The deployed space includes: - **Web Interface** at `/web` - Interactive UI for exploring the environment - **API Documentation** at `/docs` - Full OpenAPI/Swagger interface - **Health Check** at `/health` - Container health monitoring +- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions ## Environment Details @@ -154,6 +155,61 @@ result = __ENV_NAME__env.step(__ENV_CLASS_NAME__Action(message="Hello!")) Note: When connecting to an existing server, `__ENV_NAME__env.close()` will NOT stop the server. 
+### WebSocket Client for Persistent Sessions + +For long-running episodes or when you need lower latency, use the WebSocket client: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__EnvWS + +# Connect via WebSocket (maintains persistent connection) +with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as env: + result = env.reset() + print(f"Reset: {result.observation.echoed_message}") + # Multiple steps with low latency + for msg in ["Hello", "World", "!"]: + result = env.step(__ENV_CLASS_NAME__Action(message=msg)) + print(f"Echoed: {result.observation.echoed_message}") +``` + +WebSocket advantages: +- **Lower latency**: No HTTP connection overhead per request +- **Persistent session**: Server maintains your environment state +- **Efficient for episodes**: Better for many sequential steps + +### Concurrent WebSocket Sessions + +The server supports multiple concurrent WebSocket connections. To enable this, +modify `server/app.py` to use factory mode: + +```python +# In server/app.py - use factory mode for concurrent sessions +app = create_app( + __ENV_CLASS_NAME__Environment, # Pass class, not instance + __ENV_CLASS_NAME__Action, + __ENV_CLASS_NAME__Observation, + max_concurrent_envs=4, # Allow 4 concurrent sessions +) +``` + +Then multiple clients can connect simultaneously: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__EnvWS +from concurrent.futures import ThreadPoolExecutor + +def run_episode(client_id: int): + with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as env: + result = env.reset() + for i in range(10): + result = env.step(__ENV_CLASS_NAME__Action(message=f"Client {client_id}, step {i}")) + return client_id, result.observation.message_length + +# Run 4 episodes concurrently +with ThreadPoolExecutor(max_workers=4) as executor: + results = list(executor.map(run_episode, range(4))) +``` + ## Development & Testing ### Direct Environment Testing @@ -189,11 +245,11 
@@ __ENV_NAME__/ ├── openenv.yaml # OpenEnv manifest ├── pyproject.toml # Project metadata and dependencies ├── uv.lock # Locked dependencies (generated) -├── client.py # __ENV_CLASS_NAME__Env client implementation +├── client.py # __ENV_CLASS_NAME__Env (HTTP) and __ENV_CLASS_NAME__EnvWS (WebSocket) clients ├── models.py # Action and Observation models └── server/ ├── __init__.py # Server module exports ├── __ENV_NAME___environment.py # Core environment logic - ├── app.py # FastAPI application + ├── app.py # FastAPI application (HTTP + WebSocket endpoints) └── Dockerfile # Container image definition ``` diff --git a/src/openenv/cli/templates/openenv_env/__init__.py b/src/openenv/cli/templates/openenv_env/__init__.py index 656800a55..aed293ba8 100644 --- a/src/openenv/cli/templates/openenv_env/__init__.py +++ b/src/openenv/cli/templates/openenv_env/__init__.py @@ -6,8 +6,12 @@ """__ENV_TITLE_NAME__ Environment - A simple test environment for HTTP server.""" -from .client import __ENV_CLASS_NAME__Env +from .client import __ENV_CLASS_NAME__Env, __ENV_CLASS_NAME__EnvWS from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation -__all__ = ["__ENV_CLASS_NAME__Action", "__ENV_CLASS_NAME__Observation", "__ENV_CLASS_NAME__Env"] - +__all__ = [ + "__ENV_CLASS_NAME__Action", + "__ENV_CLASS_NAME__Observation", + "__ENV_CLASS_NAME__Env", + "__ENV_CLASS_NAME__EnvWS", +] diff --git a/src/openenv/cli/templates/openenv_env/client.py b/src/openenv/cli/templates/openenv_env/client.py index 703b28a85..0775f2536 100644 --- a/src/openenv/cli/templates/openenv_env/client.py +++ b/src/openenv/cli/templates/openenv_env/client.py @@ -5,10 +5,11 @@ # LICENSE file in the root directory of this source tree. """ -__ENV_TITLE_NAME__ Environment HTTP Client. +__ENV_TITLE_NAME__ Environment Clients. -This module provides the client for connecting to a __ENV_TITLE_NAME__ Environment server -over HTTP. 
+This module provides clients for connecting to a __ENV_TITLE_NAME__ Environment server: +- __ENV_CLASS_NAME__Env: HTTP client for request/response interactions +- __ENV_CLASS_NAME__EnvWS: WebSocket client for persistent sessions """ from typing import Any, Dict @@ -16,6 +17,7 @@ from openenv.core.client_types import StepResult from openenv.core.env_server.types import State from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.ws_env_client import WebSocketEnvClient from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation @@ -98,3 +100,90 @@ def _parse_state(self, payload: Dict) -> State: episode_id=payload.get("episode_id"), step_count=payload.get("step_count", 0), ) + + +class __ENV_CLASS_NAME__EnvWS(WebSocketEnvClient[__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation]): + """ + WebSocket client for the __ENV_TITLE_NAME__ Environment. + + This client maintains a persistent WebSocket connection to the environment server, + enabling efficient multi-step interactions with lower latency than HTTP. + Each client instance has its own dedicated environment session on the server. + + Advantages over HTTP client: + - Lower latency for sequential interactions (no connection overhead per request) + - Session state is maintained server-side + - Better suited for long-running episodes + + Example: + >>> # Connect to a running server via WebSocket + >>> with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.echoed_message) + ... + ... result = client.step(__ENV_CLASS_NAME__Action(message="Hello!")) + ... print(result.observation.echoed_message) + + Example with Docker: + >>> # Automatically start container and connect via WebSocket + >>> client = __ENV_CLASS_NAME__EnvWS.from_docker_image("__ENV_NAME__-env:latest") + >>> try: + ... result = client.reset() + ... result = client.step(__ENV_CLASS_NAME__Action(message="Test")) + ... finally: + ... 
client.close() + """ + + def _step_payload(self, action: __ENV_CLASS_NAME__Action) -> Dict: + """ + Convert __ENV_CLASS_NAME__Action to JSON payload for step message. + + Args: + action: __ENV_CLASS_NAME__Action instance + + Returns: + Dictionary representation suitable for JSON encoding + """ + return { + "message": action.message, + } + + def _parse_result(self, payload: Dict) -> StepResult[__ENV_CLASS_NAME__Observation]: + """ + Parse WebSocket response into StepResult[__ENV_CLASS_NAME__Observation]. + + Args: + payload: JSON response data from server + + Returns: + StepResult with __ENV_CLASS_NAME__Observation + """ + obs_data = payload.get("observation", {}) + observation = __ENV_CLASS_NAME__Observation( + echoed_message=obs_data.get("echoed_message", ""), + message_length=obs_data.get("message_length", 0), + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict) -> State: + """ + Parse WebSocket state response into State object. + + Args: + payload: JSON response from state request + + Returns: + State object with episode_id and step_count + """ + return State( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + ) diff --git a/src/openenv/cli/templates/openenv_env/models.py b/src/openenv/cli/templates/openenv_env/models.py index 64010449b..4540d5a29 100644 --- a/src/openenv/cli/templates/openenv_env/models.py +++ b/src/openenv/cli/templates/openenv_env/models.py @@ -10,22 +10,20 @@ The __ENV_NAME__ environment is a simple test environment that echoes back messages. 
""" -from dataclasses import dataclass +from pydantic import Field from openenv.core.env_server.types import Action, Observation -@dataclass(kw_only=True) class __ENV_CLASS_NAME__Action(Action): """Action for the __ENV_TITLE_NAME__ environment - just a message to echo.""" - message: str + message: str = Field(..., description="Message to echo back") -@dataclass(kw_only=True) class __ENV_CLASS_NAME__Observation(Observation): """Observation from the __ENV_TITLE_NAME__ environment - the echoed message.""" - echoed_message: str - message_length: int = 0 + echoed_message: str = Field(default="", description="The echoed message") + message_length: int = Field(default=0, description="Length of the echoed message") diff --git a/src/openenv/cli/templates/openenv_env/pyproject.toml b/src/openenv/cli/templates/openenv_env/pyproject.toml index 55b90113f..4c6b948ff 100644 --- a/src/openenv/cli/templates/openenv_env/pyproject.toml +++ b/src/openenv/cli/templates/openenv_env/pyproject.toml @@ -15,6 +15,8 @@ description = "__ENV_TITLE_NAME__ environment for OpenEnv" requires-python = ">=3.10" dependencies = [ # Core OpenEnv runtime (provides FastAPI server + HTTP client types) + # install from github + # "openenv[core] @ git+https://github.com/meta-pytorch/OpenEnv.git", "openenv[core]>=0.2.0", # Environment-specific dependencies # Add all dependencies needed for your environment here diff --git a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py index e2a9ce0b7..454ea6808 100644 --- a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py +++ b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py @@ -36,6 +36,12 @@ class __ENV_CLASS_NAME__Environment(Environment): >>> print(obs.message_length) # 5 """ + # Enable concurrent WebSocket sessions. + # Set to True if your environment isolates state between instances. 
+ # When True, multiple WebSocket clients can connect simultaneously, each + # getting their own environment instance (when using factory mode in app.py). + SUPPORTS_CONCURRENT_SESSIONS: bool = True + def __init__(self): """Initialize the __ENV_NAME__ environment.""" self._state = State(episode_id=str(uuid4()), step_count=0) diff --git a/src/openenv/cli/templates/openenv_env/server/app.py b/src/openenv/cli/templates/openenv_env/server/app.py index db216fb06..5100b1050 100644 --- a/src/openenv/cli/templates/openenv_env/server/app.py +++ b/src/openenv/cli/templates/openenv_env/server/app.py @@ -8,7 +8,14 @@ FastAPI application for the __ENV_TITLE_NAME__ Environment. This module creates an HTTP server that exposes the __ENV_CLASS_NAME__Environment -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with HTTPEnvClient and WebSocketEnvClient. + +Endpoints: + - POST /reset: Reset the environment + - POST /step: Execute an action + - GET /state: Get current environment state + - GET /schema: Get action/observation schemas + - WS /ws: WebSocket endpoint for persistent sessions Usage: # Development (with auto-reload): @@ -28,18 +35,18 @@ "openenv is required for the web interface. 
Install dependencies with '\n uv sync\n'" ) from e -from __ENV_NAME__.models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation +# Import from local models.py (PYTHONPATH includes /app/env in Docker) +from models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation from .__ENV_NAME___environment import __ENV_CLASS_NAME__Environment -# Create the environment instance -env = __ENV_CLASS_NAME__Environment() # Create the app with web interface and README integration app = create_app( - env, + __ENV_CLASS_NAME__Environment, __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation, env_name="__ENV_NAME__", + max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions ) diff --git a/src/openenv/core/__init__.py b/src/openenv/core/__init__.py index 99507ab55..e9bbf2365 100644 --- a/src/openenv/core/__init__.py +++ b/src/openenv/core/__init__.py @@ -7,13 +7,11 @@ """Core components for agentic environments.""" # Re-export main components from submodules for convenience -from .env_server import * -from .client_types import StepResult +from .env_server import * # noqa: F403 +from . 
import env_server +from .ws_env_client import WebSocketEnvClient from .http_env_client import HTTPEnvClient # Note: MCP module doesn't export anything yet -__all__ = [ - "HTTPEnvClient", - "StepResult", -] +__all__ = ["WebSocketEnvClient", "HTTPEnvClient"] + env_server.__all__ # type: ignore \ No newline at end of file diff --git a/src/openenv/core/client_types.py b/src/openenv/core/client_types.py index 8808e96bf..c7501c656 100644 --- a/src/openenv/core/client_types.py +++ b/src/openenv/core/client_types.py @@ -1,9 +1,10 @@ # Type definitions for EnvTorch from dataclasses import dataclass -from typing import Any, Generic, Optional, TypeVar +from typing import Generic, Optional, TypeVar # Generic type for observations -ObsT = TypeVar("ObsT") # TypeVar for typehinting in IDEs +ObsT = TypeVar("ObsT") +StateT = TypeVar("StateT") @dataclass diff --git a/src/openenv/core/env_server/__init__.py b/src/openenv/core/env_server/__init__.py index 4e1c2d7ac..ed0d41278 100644 --- a/src/openenv/core/env_server/__init__.py +++ b/src/openenv/core/env_server/__init__.py @@ -15,7 +15,33 @@ deserialize_action_with_preprocessing, serialize_observation, ) -from .types import Action, Observation, State, SchemaResponse, HealthResponse +from .types import ( + Action, + Observation, + State, + SchemaResponse, + HealthResponse, + BaseMessage, + WSIncomingMessage, + WSResetMessage, + WSStepMessage, + WSStateMessage, + WSCloseMessage, + WSObservationResponse, + WSStateResponse, + WSErrorResponse, + ConcurrencyConfig, + ServerCapacityStatus, + SessionInfo, +) +from .exceptions import ( + OpenEnvError, + ConcurrencyConfigurationError, + SessionCapacityError, + SessionNotFoundError, + SessionCreationError, + EnvironmentFactoryError, +) from .web_interface import create_web_interface_app, WebInterfaceManager __all__ = [ @@ -30,6 +56,27 @@ "State", "SchemaResponse", "HealthResponse", + # WebSocket message types + "BaseMessage", + "WSIncomingMessage", + "WSResetMessage", + "WSStepMessage", + 
"WSStateMessage", + "WSCloseMessage", + "WSObservationResponse", + "WSStateResponse", + "WSErrorResponse", + # Concurrency types + "ConcurrencyConfig", + "ServerCapacityStatus", + "SessionInfo", + # Exceptions + "OpenEnvError", + "ConcurrencyConfigurationError", + "SessionCapacityError", + "SessionNotFoundError", + "SessionCreationError", + "EnvironmentFactoryError", # Base transforms "CompositeTransform", "NullTransform", diff --git a/src/openenv/core/env_server/exceptions.py b/src/openenv/core/env_server/exceptions.py new file mode 100644 index 000000000..4fb4a6ec8 --- /dev/null +++ b/src/openenv/core/env_server/exceptions.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Custom exceptions for environment server operations.""" + +from typing import Optional + + +class OpenEnvError(Exception): + """Base exception for all OpenEnv errors.""" + + pass + + +class ConcurrencyConfigurationError(OpenEnvError): + """ + Raised when an environment is misconfigured for concurrent sessions. + + This error is raised during server startup when max_concurrent_envs > 1 + is specified for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. + """ + + def __init__( + self, + environment_name: str, + max_concurrent_envs: int, + message: Optional[str] = None, + ): + self.environment_name = environment_name + self.max_concurrent_envs = max_concurrent_envs + + if message is None: + message = ( + f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. " + f"Cannot run with max_concurrent_envs={max_concurrent_envs}. " + f"Either set max_concurrent_envs=1 or ensure the environment " + f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True." 
+ ) + + super().__init__(message) + + +class SessionCapacityError(OpenEnvError): + """ + Raised when the server cannot accept new sessions due to capacity limits. + + This error is raised when a new WebSocket connection is attempted but + the server has already reached max_concurrent_envs active sessions. + """ + + def __init__( + self, + active_sessions: int, + max_sessions: int, + message: Optional[str] = None, + ): + self.active_sessions = active_sessions + self.max_sessions = max_sessions + + if message is None: + message = ( + f"Server at capacity: {active_sessions}/{max_sessions} sessions active. " + f"Cannot accept new connections." + ) + + super().__init__(message) + + +class SessionNotFoundError(OpenEnvError): + """Raised when attempting to access a session that does not exist.""" + + def __init__(self, session_id: str, message: Optional[str] = None): + self.session_id = session_id + + if message is None: + message = f"Session '{session_id}' not found." + + super().__init__(message) + + +class SessionCreationError(OpenEnvError): + """Raised when a session cannot be created.""" + + def __init__(self, reason: str, message: Optional[str] = None): + self.reason = reason + + if message is None: + message = f"Failed to create session: {reason}" + + super().__init__(message) + + +class EnvironmentFactoryError(OpenEnvError): + """Raised when the environment factory fails to create an instance.""" + + def __init__(self, factory_name: str, message: Optional[str] = None): + self.factory_name = factory_name + + if message is None: + message = f"Environment factory '{factory_name}' failed to create instance." + + super().__init__(message) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 7fa7c0f32..b816b3d62 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -8,18 +8,22 @@ HTTP server wrapper for Environment instances. 
This module provides utilities to wrap any Environment subclass and expose it -over HTTP endpoints that HTTPEnvClient can consume. +over HTTP endpoints that HTTPEnvClient can consume. Also supports WebSocket +connections for persistent sessions with multi-environment concurrency. """ from __future__ import annotations import asyncio import inspect +import json import os +import time +import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Type +from typing import Any, Callable, Dict, Optional, Type, Union -from fastapi import Body, FastAPI, HTTPException, status +from fastapi import Body, FastAPI, HTTPException, WebSocket, WebSocketDisconnect, status from pydantic import ValidationError from .interfaces import Environment @@ -39,6 +43,21 @@ EnvironmentMetadata, SchemaResponse, HealthResponse, + WSResetMessage, + WSStepMessage, + WSStateMessage, + WSCloseMessage, + WSObservationResponse, + WSStateResponse, + WSErrorResponse, + ConcurrencyConfig, + ServerCapacityStatus, + SessionInfo, +) +from .exceptions import ( + ConcurrencyConfigurationError, + SessionCapacityError, + EnvironmentFactoryError, ) @@ -47,7 +66,8 @@ class HTTPEnvServer: HTTP server wrapper for Environment instances. This class wraps an Environment and exposes its reset(), step(), and state - methods as HTTP endpoints compatible with HTTPEnvClient. + methods as HTTP endpoints compatible with HTTPEnvClient. Also supports + WebSocket connections for persistent sessions with multi-environment concurrency. 
The server expects: - Action deserialization: Converts JSON dict to Action subclass @@ -56,9 +76,15 @@ class HTTPEnvServer: Example: >>> from core.env_server import HTTPEnvServer >>> from envs.coding_env.server import CodeExecutionEnvironment + >>> from envs.coding_env.models import CodeAction, CodeObservation >>> - >>> env = CodeExecutionEnvironment() - >>> server = HTTPEnvServer(env) + >>> # Pass environment class (factory pattern) + >>> server = HTTPEnvServer( + ... env=CodeExecutionEnvironment, + ... action_cls=CodeAction, + ... observation_cls=CodeObservation, + ... max_concurrent_envs=4, + ... ) >>> >>> # Register routes with FastAPI >>> from fastapi import FastAPI @@ -68,31 +94,128 @@ class HTTPEnvServer: def __init__( self, - env: Environment, + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, ): """ Initialize HTTP server wrapper. Args: - env: The Environment instance to wrap + env: Environment factory (callable) that creates new instances. + Will be called to create a new environment for each WebSocket session. action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum number of concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. + + Raises: + ValueError: If both max_concurrent_envs and concurrency_config are provided. + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. """ - self.env = env + # Validate that env is callable + if not callable(env): + raise TypeError( + f"env must be a callable (class or factory function), got {type(env)}. 
" + f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())." + ) + + self._env_factory: Callable[[], Environment] = env + + # Handle concurrency configuration + if max_concurrent_envs is not None and concurrency_config is not None: + raise ValueError( + "Cannot specify both 'max_concurrent_envs' and 'concurrency_config'. " + "Please use only one method to configure concurrency." + ) + + if concurrency_config is not None: + self._concurrency_config = concurrency_config + elif max_concurrent_envs is not None: + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=max_concurrent_envs, + session_timeout=None, + ) + else: + # Default configuration + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=1, + session_timeout=None, + ) + + self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs + + # Validate concurrency configuration + self._validate_concurrency_safety() + self.action_cls = action_cls self.observation_cls = observation_cls + + # Session management for WebSocket connections + self._sessions: Dict[str, Environment] = {} + self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_info: Dict[str, SessionInfo] = {} + self._session_lock = asyncio.Lock() + # Create thread pool for running sync code in async context - # This is needed for environments using sync libraries (e.g., Playwright sync API) - self._executor = ThreadPoolExecutor(max_workers=1) + # This is needed for environments using sync libraries (e.g., Playwright) + self._executor = ThreadPoolExecutor(max_workers=32) + + def _validate_concurrency_safety(self) -> None: + """ + Validate that the environment supports the configured concurrency level. + + Raises: + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. 
+ """ + if self._max_concurrent_envs <= 1: + return + + if inspect.isclass(self._env_factory): + env_cls = self._env_factory + else: + _temp_env = self._env_factory() + env_cls = type(_temp_env) + _temp_env.close() + del _temp_env + + if not getattr(env_cls, "SUPPORTS_CONCURRENT_SESSIONS", False): + raise ConcurrencyConfigurationError( + environment_name=env_cls.__name__, + max_concurrent_envs=self._max_concurrent_envs, + ) + + def get_capacity_status(self) -> ServerCapacityStatus: + """ + Get the current capacity status of the server. + + Returns: + ServerCapacityStatus with current session counts and availability. + """ + return ServerCapacityStatus.from_counts( + active=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) - async def _run_sync_in_thread_pool(self, func, *args, **kwargs): + async def _run_sync_in_thread_pool( + self, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: """Run a synchronous function in the thread pool executor.""" loop = asyncio.get_event_loop() return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs)) - def _get_valid_kwargs(self, sig, kwargs, skip_params=None): + def _get_valid_kwargs( + self, + sig: inspect.Signature, + kwargs: Dict[str, Any], + skip_params: Optional[set[str]] = None, + ) -> Dict[str, Any]: """Filter kwargs to only include parameters accepted by the function signature.""" if skip_params is None: skip_params = set() @@ -110,6 +233,129 @@ def _get_valid_kwargs(self, sig, kwargs, skip_params=None): return valid_kwargs + async def _create_session(self) -> tuple[str, Environment]: + """ + Create a new WebSocket session with its own environment instance. 
+ + Returns: + Tuple of (session_id, environment) + + Raises: + SessionCapacityError: If max concurrent sessions reached + EnvironmentFactoryError: If the factory fails to create an environment + """ + async with self._session_lock: + if len(self._sessions) >= self._max_concurrent_envs: + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + + session_id = str(uuid.uuid4()) + current_time = time.time() + + try: + env = self._env_factory() + except Exception as e: + factory_name = getattr(self._env_factory, "__name__", str(self._env_factory)) + raise EnvironmentFactoryError(factory_name) from e + + self._sessions[session_id] = env + + self._session_executors[session_id] = ThreadPoolExecutor(max_workers=1) + + # Track session metadata + self._session_info[session_id] = SessionInfo( + session_id=session_id, + created_at=current_time, + last_activity_at=current_time, + step_count=0, + environment_type=type(env).__name__, + ) + + return session_id, env + + async def _destroy_session(self, session_id: str) -> None: + """ + Destroy a WebSocket session and cleanup resources. + + Args: + session_id: The session ID to destroy + """ + async with self._session_lock: + if session_id in self._sessions: + env = self._sessions.pop(session_id) + env.close() + + if session_id in self._session_executors: + executor = self._session_executors.pop(session_id) + executor.shutdown(wait=False) + + self._session_info.pop(session_id, None) + + def _update_session_activity( + self, session_id: str, increment_step: bool = False + ) -> None: + """ + Update session activity timestamp and optionally increment step count. 
+ + Args: + session_id: The session ID to update + increment_step: If True, increment the step count + """ + if session_id in self._session_info: + self._session_info[session_id].last_activity_at = time.time() + if increment_step: + self._session_info[session_id].step_count += 1 + + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: + """ + Get information about a specific session. + + Args: + session_id: The session ID to query + + Returns: + SessionInfo if the session exists, None otherwise + """ + return self._session_info.get(session_id) + + async def _run_in_session_executor( + self, session_id: str, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: + """Run a synchronous function in the session's thread pool executor.""" + executor = self._session_executors.get(session_id, self._executor) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(executor, lambda: func(*args, **kwargs)) + + @property + def active_sessions(self) -> int: + """Return the number of active WebSocket sessions.""" + return len(self._sessions) + + @property + def max_concurrent_envs(self) -> int: + """Return the maximum number of concurrent environments.""" + return self._max_concurrent_envs + + @property + def is_concurrency_safe(self) -> bool: + """Return whether the environment is marked as concurrency safe.""" + import inspect + if inspect.isclass(self._env_factory): + return getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) + else: + _temp_env = self._env_factory() + result = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False) + _temp_env.close() + del _temp_env + return result + + @property + def concurrency_config(self) -> ConcurrencyConfig: + """Return the concurrency configuration.""" + return self._concurrency_config + def register_routes(self, app: FastAPI) -> None: """ Register HTTP routes on a FastAPI application. 
@@ -123,49 +369,64 @@ async def reset_handler( request: ResetRequest = Body(default_factory=ResetRequest), ) -> ResetResponse: """Reset endpoint - returns initial observation.""" - # Handle optional parameters - # Start with all fields from the request, including extra ones - kwargs = request.model_dump(exclude_unset=True) - - # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.reset) - valid_kwargs = self._get_valid_kwargs(sig, kwargs) - - # Run synchronous reset in thread pool to avoid blocking event loop - observation = await self._run_sync_in_thread_pool( - self.env.reset, **valid_kwargs - ) - return ResetResponse(**serialize_observation(observation)) + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True) + + is_async = _env.reset_async.__func__ is not Environment.reset_async + + if is_async: + sig = inspect.signature(_env.reset_async) + else: + sig = inspect.signature(_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, kwargs) + + if is_async: + observation = await _env.reset_async(**valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.reset, **valid_kwargs + ) + return ResetResponse(**serialize_observation(observation)) + finally: + _env.close() # Helper function to handle step endpoint async def step_handler(request: StepRequest) -> StepResponse: """Step endpoint - executes action and returns observation.""" action_data = request.action - # Deserialize action with Pydantic validation try: action = deserialize_action(action_data, self.action_cls) except ValidationError as e: - # Return HTTP 422 with detailed validation errors raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() ) - # Handle optional parameters - # Start with all fields from the request, including extra ones, but exclude 'action' - kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) + _env = self._env_factory() + + try: + kwargs = 
request.model_dump(exclude_unset=True, exclude={"action"}) - # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.step) - valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) + is_async = _env.step_async.__func__ is not Environment.step_async - # Run synchronous step in thread pool to avoid blocking event loop - observation = await self._run_sync_in_thread_pool( - self.env.step, action, **valid_kwargs - ) + if is_async: + sig = inspect.signature(_env.step_async) + else: + sig = inspect.signature(_env.step) + valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) - # Return serialized observation - return StepResponse(**serialize_observation(observation)) + if is_async: + observation = await _env.step_async(action, **valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.step, action, **valid_kwargs + ) + + return StepResponse(**serialize_observation(observation)) + finally: + _env.close() # Register routes using the helpers @app.post( @@ -251,24 +512,36 @@ async def reset( async def step(request: StepRequest) -> StepResponse: return await step_handler(request) - # Configure and register GET endpoints declaratively + def get_state_handler() -> State: + _env = self._env_factory() + try: + return _env.state + finally: + _env.close() + + def get_metadata_handler() -> EnvironmentMetadata: + _env = self._env_factory() + try: + return _env.get_metadata() + finally: + _env.close() + get_endpoints = [ GetEndpointConfig( path="/state", - handler=lambda: self.env.state, + handler=get_state_handler, response_model=State, tag="State Management", summary="Get current environment state", description=""" Retrieve the current internal state of the environment. -This endpoint allows inspection of the environment state without modifying it. The structure of the state object is defined by the environment's State model. 
""", ), GetEndpointConfig( path="/metadata", - handler=self.env.get_metadata, + handler=get_metadata_handler, response_model=EnvironmentMetadata, tag="Environment Info", summary="Get environment metadata", @@ -290,6 +563,7 @@ async def step(request: StepRequest) -> StepResponse: ] register_get_endpoints(app, get_endpoints) + # Register combined schema endpoint @app.get( "/schema", @@ -339,12 +613,171 @@ async def get_schemas() -> SchemaResponse: state=State.model_json_schema(), ) + # Register WebSocket endpoint for persistent sessions + @app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for persistent environment sessions. + + Each WebSocket connection gets its own environment instance. + + Message Protocol: + - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage + - Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse + """ + await websocket.accept() + + session_id = None + session_env = None + + try: + # Create session with dedicated environment + session_id, session_env = await self._create_session() + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + message_dict = json.loads(raw_message) + except json.JSONDecodeError as e: + error_resp = WSErrorResponse( + data={ + "message": f"Invalid JSON: {e}", + "code": "INVALID_JSON", + } + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + msg_type = message_dict.get("type", "") + + try: + match msg_type: + case "reset": + msg = WSResetMessage(**message_dict) + + is_async = session_env.reset_async.__func__ is not Environment.reset_async + + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await session_env.reset_async(**valid_kwargs) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await 
self._run_in_session_executor( + session_id, session_env.reset, **valid_kwargs + ) + + self._update_session_activity(session_id) + + response = WSObservationResponse( + data=serialize_observation(observation) + ) + + case "step": + msg = WSStepMessage(**message_dict) + action = deserialize_action(msg.data, self.action_cls) + + is_async = session_env.step_async.__func__ is not Environment.step_async + + if is_async: + observation = await session_env.step_async(action) + else: + observation = await self._run_in_session_executor( + session_id, session_env.step, action + ) + + self._update_session_activity( + session_id, increment_step=True + ) + + response = WSObservationResponse( + data=serialize_observation(observation) + ) + + case "state": + msg = WSStateMessage(**message_dict) + state = session_env.state + if hasattr(state, "model_dump"): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = WSStateResponse(data=state_data) + + case "close": + msg = WSCloseMessage(**message_dict) + break + + case _: + response = WSErrorResponse( + data={ + "message": f"Unknown message type: {msg_type}", + "code": "UNKNOWN_TYPE", + } + ) + + await websocket.send_text(response.model_dump_json()) + + except ValidationError as e: + error_resp = WSErrorResponse( + data={ + "message": "Invalid message", + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = WSErrorResponse( + data={"message": str(e), "code": "EXECUTION_ERROR"} + ) + await websocket.send_text(error_resp.model_dump_json()) + + except WebSocketDisconnect: + pass + except SessionCapacityError as e: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": "CAPACITY_REACHED", + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except EnvironmentFactoryError as e: + 
error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": "FACTORY_ERROR", + "factory_name": e.factory_name, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = WSErrorResponse( + data={"message": str(e), "code": "SESSION_ERROR"} + ) + await websocket.send_text(error_resp.model_dump_json()) + finally: + if session_id: + await self._destroy_session(session_id) + try: + await websocket.close() + except RuntimeError: + pass + def create_app( - env: Environment, + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ Create a FastAPI application with or without web interface. @@ -353,10 +786,14 @@ def create_app( including README integration for better user experience. Args: - env: The Environment instance to serve + env: Environment factory (callable) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. 
Returns: FastAPI application instance with or without web interface and README integration @@ -373,18 +810,43 @@ def create_app( # Import web interface only when needed from .web_interface import create_web_interface_app - return create_web_interface_app(env, action_cls, observation_cls, env_name) + return create_web_interface_app( + env, + action_cls, + observation_cls, + env_name, + max_concurrent_envs, + concurrency_config, + ) else: # Use standard FastAPI app without web interface - return create_fastapi_app(env, action_cls, observation_cls) + return create_fastapi_app( + env, action_cls, observation_cls, max_concurrent_envs, concurrency_config + ) def create_fastapi_app( - env: Environment, + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: - """Create a FastAPI application with comprehensive documentation.""" + """ + Create a FastAPI application with comprehensive documentation. + + Args: + env: Environment factory (callable) that creates new instances + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. 
+ + Returns: + FastAPI application instance + """ try: from fastapi import FastAPI except ImportError: @@ -452,6 +914,12 @@ def create_fastapi_app( }, ) - server = HTTPEnvServer(env, action_cls, observation_cls) + server = HTTPEnvServer( + env, + action_cls, + observation_cls, + max_concurrent_envs, + concurrency_config=concurrency_config, + ) server.register_routes(app) return app diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index b438cd667..c02ba4a05 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -5,10 +5,14 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Optional, Protocol, TypedDict +from typing import Any, Generic, Optional, Protocol, TypedDict, TypeVar from .types import Action, Observation, State, EnvironmentMetadata +ActT = TypeVar("ActT", bound=Action) +ObsT = TypeVar("ObsT", bound=Observation) +StateT = TypeVar("StateT", bound=State) + class Message(TypedDict): """A message in a conversation. @@ -64,7 +68,7 @@ def decode( ... -class Transform(ABC): +class Transform(ABC, Generic[ObsT]): """Transform observations to add rewards, metrics, or other modifications. Transforms follow the TorchRL pattern where they take an observation @@ -73,7 +77,7 @@ class Transform(ABC): """ @abstractmethod - def __call__(self, observation: Observation) -> Observation: + def __call__(self, observation: ObsT) -> ObsT: """Transform an observation. Args: @@ -85,14 +89,28 @@ def __call__(self, observation: Observation) -> Observation: pass -class Environment(ABC): +class Environment(ABC, Generic[ActT, ObsT, StateT]): """Base class for all environment servers following Gym/Gymnasium API. Args: transform: Optional transform to apply to observations + + Class Attributes: + SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions. 
+ When True, multiple WebSocket connections can each have their own + environment instance (up to max_concurrent_envs). When False (default), + the environment should only be used with a single session at a time. + + Set this to True in your Environment subclass if: + - The environment uses proper session isolation (e.g., unique working dirs) + - No shared mutable state exists between instances + - External resources (databases, APIs) can handle concurrent access """ + + # Class-level flag indicating whether this environment supports concurrent sessions + SUPPORTS_CONCURRENT_SESSIONS: bool = False - def __init__(self, transform: Transform | None = None): + def __init__(self, transform: Optional[Transform[ObsT]] = None): self.transform = transform @abstractmethod @@ -101,23 +119,47 @@ def reset( seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any, - ) -> Observation: + ) -> ObsT: """Reset the environment and return initial observation.""" pass + async def reset_async( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of reset. Default implementation calls sync reset. + + Override to provide true async implementation. + """ + return self.reset(seed=seed, episode_id=episode_id, **kwargs) + @abstractmethod def step( self, - action: Action, + action: ActT, timeout_s: Optional[float] = None, **kwargs: Any, - ) -> Observation: + ) -> ObsT: """Take a step in the environment.""" pass + async def step_async( + self, + action: ActT, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of step. Default implementation calls sync step. + + Override to provide true async implementation. 
+ """ + return self.step(action, timeout_s=timeout_s, **kwargs) + @property @abstractmethod - def state(self) -> State: + def state(self) -> StateT: """Get the current environment state.""" pass @@ -137,8 +179,16 @@ def get_metadata(self) -> EnvironmentMetadata: version="1.0.0", ) - def _apply_transform(self, observation: Observation) -> Observation: + def _apply_transform(self, observation: ObsT) -> ObsT: """Apply transform if one is provided.""" if self.transform is not None: return self.transform(observation) return observation + + def close(self) -> None: + """Clean up resources used by the environment. + + Override this method to implement custom cleanup logic. + Called when the environment is being destroyed or reset. + """ + pass diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index a97a05283..df06592f5 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -80,7 +80,7 @@ def deserialize_action_with_preprocessing( value = [] if isinstance(value, list): try: - import torch + import torch # type: ignore processed_data[key] = torch.tensor(value, dtype=torch.long) except ImportError: diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index c3ee689c0..a22914b73 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -4,8 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, Dict, Optional, Union -from pydantic import BaseModel, Field, ConfigDict +from typing import Any, Dict, Optional, Union, Literal, Annotated +from pydantic import BaseModel, Field, ConfigDict, model_validator # Type aliases @@ -127,6 +127,15 @@ class StepResponse(BaseModel): done: bool = Field(default=False, description="Whether the episode has terminated") +class BaseMessage(BaseModel): + """Base class for WebSocket messages with shared configuration.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + class State(BaseModel): """Base class for environment state. @@ -149,27 +158,17 @@ class State(BaseModel): ) -class CodeExecResult(BaseModel): +class CodeExecResult(BaseMessage): """Result of code execution containing stdout, stderr, and exit code.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - stdout: str = Field(description="Standard output from code execution") stderr: str = Field(description="Standard error from code execution") exit_code: int = Field(description="Exit code from code execution") -class EnvironmentMetadata(BaseModel): +class EnvironmentMetadata(BaseMessage): """Metadata about an environment for documentation and UI purposes.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - name: str = Field(description="Name of the environment") description: str = Field(description="Description of what the environment does") readme_content: Optional[str] = Field( @@ -184,14 +183,9 @@ class EnvironmentMetadata(BaseModel): ) -class SchemaResponse(BaseModel): +class SchemaResponse(BaseMessage): """Response model for the combined schema endpoint.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - action: Dict[str, Any] = Field( description="JSON schema for actions accepted by this environment" ) @@ -203,12 +197,145 @@ class SchemaResponse(BaseModel): ) -class HealthResponse(BaseModel): +class 
HealthResponse(BaseMessage): """Response model for health check endpoint.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, + status: str = Field(description="Health status of the environment server") + + +class WSResetMessage(BaseMessage): + """WebSocket message to reset the environment.""" + + type: Literal["reset"] = Field(default="reset", description="Message type") + data: Dict[str, Any] = Field( + default_factory=dict, + description="Optional reset parameters (seed, episode_id, etc.)", ) - status: str = Field(description="Health status of the environment server") + +class WSStepMessage(BaseMessage): + """WebSocket message to execute a step.""" + + type: Literal["step"] = Field(default="step", description="Message type") + data: Dict[str, Any] = Field( + ..., description="Action data conforming to environment's action schema" + ) + + +class WSStateMessage(BaseMessage): + """WebSocket message to request current state.""" + + type: Literal["state"] = Field(default="state", description="Message type") + + +class WSCloseMessage(BaseMessage): + """WebSocket message to close the session.""" + + type: Literal["close"] = Field(default="close", description="Message type") + + +# Discriminated union for incoming WebSocket messages +WSIncomingMessage = Annotated[ + WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage, + Field(discriminator="type"), +] + + +class WSObservationResponse(BaseModel): + """WebSocket response containing an observation.""" + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="observation", description="Response type") + data: Dict[str, Any] = Field(description="Observation data") + + +class WSStateResponse(BaseModel): + """WebSocket response containing environment state.""" + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="state", description="Response type") + data: Dict[str, Any] = Field(description="State data") + + +class WSErrorResponse(BaseModel): + 
"""WebSocket response for errors.""" + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="error", description="Response type") + data: Dict[str, Any] = Field(description="Error details including message and code") + + +class ConcurrencyConfig(BaseMessage): + """Configuration for concurrent environment sessions.""" + + max_concurrent_envs: int = Field( + default=1, + ge=1, + description="Maximum number of concurrent WebSocket sessions allowed", + ) + session_timeout: Optional[float] = Field( + default=None, + gt=0, + description="Timeout in seconds for inactive sessions. None means no timeout.", + ) + + +class ServerCapacityStatus(BaseMessage): + """Status of server capacity for concurrent sessions.""" + + active_sessions: int = Field( + ge=0, + description="Number of currently active sessions", + ) + max_sessions: int = Field( + ge=1, + description="Maximum number of allowed sessions", + ) + + @model_validator(mode="after") + def check_capacity_bounds(self) -> "ServerCapacityStatus": + if self.active_sessions > self.max_sessions: + raise ValueError( + f"active_sessions ({self.active_sessions}) cannot exceed " + f"max_sessions ({self.max_sessions})" + ) + return self + + @property + def available_slots(self) -> int: + """Number of available session slots.""" + return self.max_sessions - self.active_sessions + + @property + def is_at_capacity(self) -> bool: + """Whether the server has reached maximum capacity.""" + return self.available_slots == 0 + + @classmethod + def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus": + """Create status from active and max session counts.""" + return cls( + active_sessions=active, + max_sessions=max_sessions, + ) + + +class SessionInfo(BaseMessage): + """Information about an active session.""" + + session_id: str = Field(description="Unique identifier for the session") + created_at: float = Field(description="Unix timestamp when the session was created") + last_activity_at: float = 
Field( + description="Unix timestamp of the last activity in the session" + ) + step_count: int = Field( + default=0, + ge=0, + description="Number of steps executed in this session", + ) + environment_type: str = Field( + description="Environment type for this session (e.g. `CodingEnv`)" + ) diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index b370cfa53..703025375 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -14,21 +14,19 @@ from __future__ import annotations import json -from typing import Any, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type, Union from datetime import datetime from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse -from pydantic import BaseModel, Field, ConfigDict +from pydantic import Field from .interfaces import Environment from .serialization import deserialize_action_with_preprocessing, serialize_observation -from .types import Action, Observation, State, EnvironmentMetadata +from .types import Action, Observation, State, EnvironmentMetadata, ConcurrencyConfig, BaseMessage -def load_environment_metadata( - env: Environment, env_name: Optional[str] = None -) -> EnvironmentMetadata: +def load_environment_metadata(env: Environment, env_name: Optional[str] = None) -> EnvironmentMetadata: """ Load environment metadata including README content. 
@@ -98,37 +96,25 @@ def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]: return None -class ActionLog(BaseModel): +class ActionLog(BaseMessage): """Log entry for an action taken.""" - model_config = ConfigDict(extra="forbid", validate_assignment=True) - timestamp: str = Field(description="Timestamp when action was taken") action: Dict[str, Any] = Field(description="Action that was taken") observation: Dict[str, Any] = Field(description="Observation returned from action") - reward: Optional[float] = Field( - default=None, description="Reward received from action" - ) + reward: Optional[float] = Field(default=None, description="Reward received from action") done: bool = Field(description="Whether the episode is done after this action") step_count: int = Field(description="Step count when this action was taken") -class EpisodeState(BaseModel): +class EpisodeState(BaseMessage): """Current episode state for the web interface.""" - model_config = ConfigDict(extra="forbid", validate_assignment=True) - episode_id: Optional[str] = Field(default=None, description="Current episode ID") step_count: int = Field(description="Current step count in episode") - current_observation: Optional[Dict[str, Any]] = Field( - default=None, description="Current observation" - ) - action_logs: List[ActionLog] = Field( - default_factory=list, description="List of action logs" - ) - is_reset: bool = Field( - default=True, description="Whether the episode has been reset" - ) + current_observation: Optional[Dict[str, Any]] = Field(default=None, description="Current observation") + action_logs: List[ActionLog] = Field(default_factory=list, description="List of action logs") + is_reset: bool = Field(default=True, description="Whether the episode has been reset") class WebInterfaceManager: @@ -211,9 +197,7 @@ async def reset_environment(self) -> Dict[str, Any]: async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: """Execute a step in the environment 
and update state.""" # Deserialize action with preprocessing for web interface special cases - action: Action = deserialize_action_with_preprocessing( - action_data, self.action_cls - ) + action: Action = deserialize_action_with_preprocessing(action_data, self.action_cls) # Execute step observation: Observation = self.env.step(action) @@ -251,19 +235,25 @@ def get_state(self) -> Dict[str, Any]: def create_web_interface_app( - env: Environment, + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ Create a FastAPI application with web interface for the given environment. Args: - env: The Environment instance to serve + env: Environment factory (callable) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. 
Returns: FastAPI application instance with web interface @@ -271,13 +261,16 @@ def create_web_interface_app( from .http_server import create_fastapi_app # Create the base environment app - app = create_fastapi_app(env, action_cls, observation_cls) + app = create_fastapi_app(env, action_cls, observation_cls, max_concurrent_envs, concurrency_config) + + # Create a test instance for metadata + env_instance = env() # Load environment metadata - metadata = load_environment_metadata(env, env_name) + metadata = load_environment_metadata(env_instance, env_name) # Create web interface manager - web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) + web_manager = WebInterfaceManager(env_instance, action_cls, observation_cls, metadata) # Add web interface routes @app.get("/web", response_class=HTMLResponse) @@ -309,12 +302,13 @@ async def web_reset(): @app.post("/web/step") async def web_step(request: Dict[str, Any]): """Step endpoint for web interface.""" - # Check if this is a message-based request (chat environment) if "message" in request: message = request["message"] - # Convert message to action using the environment's message_to_action method - action = web_manager.env.message_to_action(message) - action_data = {"tokens": action.tokens.tolist()} + if hasattr(web_manager.env, "message_to_action"): + action = getattr(web_manager.env, "message_to_action")(message) + action_data = {"tokens": action.tokens.tolist()} + else: + action_data = request.get("action", {}) else: action_data = request.get("action", {}) @@ -328,9 +322,7 @@ async def web_state(): return app -def get_web_interface_html( - action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None -) -> str: +def get_web_interface_html(action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None) -> str: """Generate the HTML for the web interface.""" # Check if this is a chat environment by looking for tokens field @@ -339,6 +331,7 @@ def get_web_interface_html( for 
field_name, field_info in action_cls.model_fields.items(): if ( field_name == "tokens" + and field_info.annotation is not None and hasattr(field_info.annotation, "__name__") and "Tensor" in field_info.annotation.__name__ ): @@ -1312,9 +1305,7 @@ def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: return action_fields -def _determine_input_type_from_schema( - field_info: Dict[str, Any], field_name: str -) -> str: +def _determine_input_type_from_schema(field_info: Dict[str, Any], field_name: str) -> str: """Determine the appropriate HTML input type from JSON schema info.""" schema_type = field_info.get("type") @@ -1386,15 +1377,9 @@ def _markdown_to_html(markdown: str) -> str: html_content = html.escape(markdown) # Convert headers - html_content = re.sub( - r"^# (.*?)$", r"

\1

", html_content, flags=re.MULTILINE - ) - html_content = re.sub( - r"^## (.*?)$", r"

\1

", html_content, flags=re.MULTILINE - ) - html_content = re.sub( - r"^### (.*?)$", r"

\1

", html_content, flags=re.MULTILINE - ) + html_content = re.sub(r"^# (.*?)$", r"

\1

", html_content, flags=re.MULTILINE) + html_content = re.sub(r"^## (.*?)$", r"

\1

", html_content, flags=re.MULTILINE) + html_content = re.sub(r"^### (.*?)$", r"

\1

", html_content, flags=re.MULTILINE) # Convert code blocks html_content = re.sub( @@ -1410,12 +1395,8 @@ def _markdown_to_html(markdown: str) -> str: html_content = re.sub(r"\*(.*?)\*", r"\1", html_content) # Convert lists - html_content = re.sub( - r"^- (.*?)$", r"
  • \1
  • ", html_content, flags=re.MULTILINE - ) - html_content = re.sub( - r"(
  • .*
  • )", r"", html_content, flags=re.DOTALL - ) + html_content = re.sub(r"^- (.*?)$", r"
  • \1
  • ", html_content, flags=re.MULTILINE) + html_content = re.sub(r"(
  • .*
  • )", r"", html_content, flags=re.DOTALL) # Convert line breaks html_content = html_content.replace("\n", "
    ") @@ -1423,9 +1404,7 @@ def _markdown_to_html(markdown: str) -> str: return html_content -def _generate_action_interface( - action_fields: List[Dict[str, Any]], is_chat_env: bool -) -> str: +def _generate_action_interface(action_fields: List[Dict[str, Any]], is_chat_env: bool) -> str: """Generate either a chat interface or action form based on environment type.""" if is_chat_env: return _generate_chat_interface() @@ -1549,9 +1528,7 @@ def _generate_single_field(field: Dict[str, Any]) -> str: for choice in choices: selected = "selected" if str(choice) == str(default_value) else "" - options_html.append( - f'' - ) + options_html.append(f'') return f'''
    diff --git a/src/openenv/core/http_env_client.py b/src/openenv/core/http_env_client.py index 007ef6a5f..0f25363d4 100644 --- a/src/openenv/core/http_env_client.py +++ b/src/openenv/core/http_env_client.py @@ -16,7 +16,7 @@ import requests -from .client_types import StepResult +from .client_types import StepResult, StateT from .containers.runtime import LocalDockerProvider if TYPE_CHECKING: @@ -27,7 +27,7 @@ EnvClientT = TypeVar("EnvClientT", bound="HTTPEnvClient") -class HTTPEnvClient(ABC, Generic[ActT, ObsT]): +class HTTPEnvClient(ABC, Generic[ActT, ObsT, StateT]): def __init__( self, base_url: str, @@ -129,17 +129,17 @@ def from_hub( return cls.from_docker_image(image=base_url, provider=provider) @abstractmethod - def _step_payload(self, action: ActT) -> dict: + def _step_payload(self, action: ActT) -> Dict[str, Any]: """Convert an Action object to the JSON body expected by the env server.""" raise NotImplementedError @abstractmethod - def _parse_result(self, payload: dict) -> StepResult[ObsT]: + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: """Convert a JSON response from the env server to StepResult[ObsT].""" raise NotImplementedError @abstractmethod - def _parse_state(self, payload: dict) -> Any: + def _parse_state(self, payload: Dict[str, Any]) -> StateT: """Convert a JSON response from the state endpoint to a State object.""" raise NotImplementedError @@ -203,7 +203,7 @@ def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: r.raise_for_status() return self._parse_result(r.json()) - def state(self) -> Any: + def state(self) -> StateT: """ Get the current environment state from the server. diff --git a/src/openenv/core/utils.py b/src/openenv/core/utils.py new file mode 100644 index 000000000..42e9cee82 --- /dev/null +++ b/src/openenv/core/utils.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
def convert_to_ws_url(url: str) -> str:
    """
    Convert an HTTP/HTTPS URL to a WS/WSS URL.

    Surrounding whitespace and trailing slashes are removed first. The scheme
    is then matched case-insensitively (URI schemes are case-insensitive), so
    ``HTTP://host`` is handled the same way as ``http://host``:

    - ``http://``  -> ``ws://``
    - ``https://`` -> ``wss://``
    - ``ws://`` / ``wss://`` are kept (normalised to lowercase)
    - anything else (including a bare ``host:port``) is prefixed with ``ws://``

    Args:
        url: The URL to convert.

    Returns:
        The converted WebSocket URL, without a trailing slash.
    """
    scheme_map = {"http": "ws", "https": "wss", "ws": "ws", "wss": "wss"}

    ws_url = url.strip().rstrip("/")
    scheme, separator, remainder = ws_url.partition("://")

    # Only treat the prefix as a scheme when "://" is actually present and the
    # prefix is one we know how to translate.
    ws_scheme = scheme_map.get(scheme.lower()) if separator else None
    if ws_scheme is None:
        # No recognised scheme: assume the whole string is host[:port][/path].
        return "ws://" + ws_url
    return f"{ws_scheme}://{remainder}"
class WebSocketEnvClient(ABC, Generic[ActT, ObsT, StateT]):
    """
    WebSocket-based environment client for persistent sessions.

    This client maintains a persistent WebSocket connection to an environment
    server, enabling efficient multi-step interactions. Each client instance
    corresponds to a dedicated environment session on the server.

    Compared to HTTPEnvClient:
    - Lower latency for sequential interactions
    - Session state is maintained server-side
    - Better suited for long-running episodes

    Subclasses implement ``_step_payload``, ``_parse_result`` and
    ``_parse_state`` to translate between typed objects and JSON payloads.

    Example:
        >>> from envs.coding_env.client import CodingEnvWS
        >>>
        >>> # Connect to a server via WebSocket
        >>> with CodingEnvWS(base_url="ws://localhost:8000") as env:
        ...     result = env.reset(seed=42)
        ...     while not result.done:
        ...         action = agent.predict(result.observation)
        ...         result = env.step(action)
    """

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 60.0,
        provider: Optional["ContainerProvider"] = None,
    ):
        """
        Initialize WebSocket client.

        No network activity happens here; the connection is opened lazily by
        ``connect()`` (or by the first message sent).

        Args:
            base_url: Base URL of the environment server (http:// or ws://).
                Will be converted to ws:// if http:// is provided.
            connect_timeout_s: Timeout for establishing WebSocket connection
            message_timeout_s: Timeout for receiving responses to messages
            provider: Optional container provider for lifecycle management
        """
        # Convert HTTP URL to WebSocket URL
        ws_url = convert_to_ws_url(base_url)

        # The server exposes its session endpoint under the /ws path.
        self._ws_url = f"{ws_url}/ws"
        self._connect_timeout = connect_timeout_s
        self._message_timeout = message_timeout_s
        self._provider = provider
        self._ws: Optional[ClientConnection] = None

    def connect(self: WSEnvClientT) -> WSEnvClientT:
        """
        Establish WebSocket connection to the server.

        Calling this on an already-connected client is a no-op.

        Returns:
            self for method chaining

        Raises:
            ConnectionError: If connection cannot be established
        """
        if self._ws is not None:
            return self

        try:
            self._ws = ws_connect(
                self._ws_url,
                open_timeout=self._connect_timeout,
            )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e

        return self

    def disconnect(self) -> None:
        """
        Close the WebSocket connection.

        Sends a ``{"type": "close"}`` message first and then closes the
        socket. Both steps are best effort: failures are ignored so that
        teardown never raises, and the client is always left disconnected.
        """
        if self._ws is None:
            return

        try:
            try:
                # Tell the server we are done before dropping the socket.
                self._send({"type": "close"})
            except Exception:
                pass  # Best effort: the peer may already be gone
            try:
                self._ws.close()
            except Exception:
                pass  # Best effort: nothing useful to do if close fails
        finally:
            self._ws = None

    def _ensure_connected(self) -> None:
        """Ensure WebSocket connection is established."""
        if self._ws is None:
            self.connect()

    def _send(self, message: Dict[str, Any]) -> None:
        """
        Send a message over the WebSocket, connecting first if necessary.

        Args:
            message: JSON-serializable message to send.

        Raises:
            ConnectionError: If the connection cannot be established.
        """
        self._ensure_connected()
        # connect() either sets _ws or raises; an explicit check (rather than
        # an assert) keeps the guard active under ``python -O``.
        if self._ws is None:
            raise ConnectionError(f"Not connected to {self._ws_url}")
        self._ws.send(json.dumps(message))

    def _receive(self) -> Dict[str, Any]:
        """
        Receive and parse a message from the WebSocket.

        Waits up to ``message_timeout_s`` for the next message.

        Returns:
            The decoded JSON message.

        Raises:
            ConnectionError: If called while the client is not connected.
        """
        if self._ws is None:
            raise ConnectionError(f"Not connected to {self._ws_url}")
        raw = self._ws.recv(timeout=self._message_timeout)
        return json.loads(raw)

    def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """
        Send a message and wait for response.

        Args:
            message: JSON-serializable request message.

        Returns:
            The decoded response message.

        Raises:
            RuntimeError: If the server replies with a message of type
                ``"error"``.
        """
        self._send(message)
        response = self._receive()

        # Check for error response
        if response.get("type") == "error":
            # Tolerate a missing or null "data" field in the error frame.
            error_data = response.get("data") or {}
            raise RuntimeError(
                f"Server error: {error_data.get('message', 'Unknown error')} "
                f"(code: {error_data.get('code', 'UNKNOWN')})"
            )

        return response

    @classmethod
    def from_docker_image(
        cls: Type[WSEnvClientT],
        image: str,
        provider: Optional["ContainerProvider"] = None,
        **kwargs: Any,
    ) -> WSEnvClientT:
        """
        Create a WebSocket environment client by spinning up a Docker container.

        If the server never becomes ready or the WebSocket connection fails,
        the container that was just started is stopped again before the error
        is re-raised, so a failed start does not leak a running container.

        Args:
            image: Docker image name to run (e.g., "coding-env:latest")
            provider: Container provider to use (defaults to LocalDockerProvider)
            **kwargs: Additional arguments to pass to provider.start_container()

        Returns:
            Connected WebSocket client instance
        """
        if provider is None:
            provider = LocalDockerProvider()

        # Start container
        base_url = provider.start_container(image, **kwargs)

        try:
            # Wait for server to be ready
            provider.wait_for_ready(base_url)

            # Create and connect client
            client = cls(base_url=base_url, provider=provider)
            client.connect()
        except BaseException:
            # Nobody holds a client that could close() this container, so
            # clean it up here. Cleanup errors are suppressed so the original
            # failure is the one the caller sees.
            try:
                provider.stop_container()
            except Exception:
                pass
            raise

        return client

    @classmethod
    def from_hub(
        cls: Type[WSEnvClientT],
        repo_id: str,
        provider: Optional["ContainerProvider"] = None,
        **kwargs: Any,
    ) -> WSEnvClientT:
        """
        Create a WebSocket client from an image hosted on Hugging Face.

        Builds the image reference
        ``registry.hf.space/<repo_id with "/" replaced by "-">:<tag>`` and
        delegates to ``from_docker_image()``.

        Args:
            repo_id: Hugging Face repository id (e.g. "owner/name").
            provider: Container provider to use (defaults to LocalDockerProvider)
            **kwargs: ``tag`` (default "latest") selects the image tag; all
                remaining arguments are forwarded to ``from_docker_image()``.

        Returns:
            Connected WebSocket client instance
        """
        if provider is None:
            provider = LocalDockerProvider()

        tag = kwargs.pop("tag", "latest")
        base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}"

        return cls.from_docker_image(image=base_url, provider=provider, **kwargs)

    @abstractmethod
    def _step_payload(self, action: ActT) -> Dict[str, Any]:
        """Convert an Action object to the JSON data expected by the env server."""
        raise NotImplementedError

    @abstractmethod
    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]:
        """Convert a JSON response from the env server to StepResult[ObsT]."""
        raise NotImplementedError

    @abstractmethod
    def _parse_state(self, payload: Dict[str, Any]) -> StateT:
        """Convert a JSON response from the state endpoint to a State object."""
        raise NotImplementedError

    def reset(self, **kwargs: Any) -> StepResult[ObsT]:
        """
        Reset the environment with optional parameters.

        Args:
            **kwargs: Optional parameters passed to the environment's reset method.
                Common parameters include:
                - seed: Random seed for reproducibility
                - episode_id: Custom episode identifier

        Returns:
            StepResult containing initial observation

        Raises:
            RuntimeError: If the server responds with an error message.
        """
        message = {
            "type": "reset",
            "data": kwargs,
        }
        response = self._send_and_receive(message)
        return self._parse_result(response.get("data", {}))

    def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]:
        """
        Execute an action in the environment.

        Args:
            action: The action to execute
            **kwargs: Optional parameters (currently ignored for WebSocket)

        Returns:
            StepResult containing observation, reward, and done status

        Raises:
            RuntimeError: If the server responds with an error message.
        """
        message = {
            "type": "step",
            "data": self._step_payload(action),
        }
        response = self._send_and_receive(message)
        return self._parse_result(response.get("data", {}))

    def state(self) -> StateT:
        """
        Get the current environment state from the server.

        Returns:
            State object with environment state information

        Raises:
            RuntimeError: If the server responds with an error message.
        """
        message = {"type": "state"}
        response = self._send_and_receive(message)
        return self._parse_state(response.get("data", {}))

    def close(self) -> None:
        """
        Close the WebSocket connection and clean up resources.

        If this client was created via from_docker_image(), this will also
        stop and remove the associated container.
        """
        try:
            self.disconnect()
        finally:
            # Stop the container even if disconnecting was interrupted.
            if self._provider is not None:
                self._provider.stop_container()

    def __enter__(self: WSEnvClientT) -> WSEnvClientT:
        """Enter context manager, ensuring connection is established."""
        self.connect()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Any,
    ) -> None:
        """Exit context manager, closing connection."""
        self.close()