diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py
index 98ca740fc..c784708ee 100644
--- a/src/vllm_router/app.py
+++ b/src/vllm_router/app.py
@@ -204,6 +204,7 @@ def initialize_all(app: FastAPI, args):
         prefill_model_labels=args.prefill_model_labels,
         decode_model_labels=args.decode_model_labels,
         kv_aware_threshold=args.kv_aware_threshold,
+        max_instance_failover_reroute_attempts=args.max_instance_failover_reroute_attempts,
     )
 
     # Initialize feature gates
diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py
index 2786705fc..4d07ef789 100644
--- a/src/vllm_router/parsers/parser.py
+++ b/src/vllm_router/parsers/parser.py
@@ -367,6 +367,13 @@ def parse_args():
         help="The threshold for kv-aware routing.",
     )
 
+    parser.add_argument(
+        "--max-instance-failover-reroute-attempts",
+        type=int,
+        default=0,
+        help="Maximum number of reroute attempts to another instance for a failed request.",
+    )
+
     args = parser.parse_args()
     args = load_initial_config_from_config_file_if_required(parser, args)
 
diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py
index 093872b3f..b933ca2a4 100644
--- a/src/vllm_router/routers/routing_logic.py
+++ b/src/vllm_router/routers/routing_logic.py
@@ -80,6 +80,11 @@ def _qps_routing(
                 ret = url
         return ret
 
+    def set_request_migration(self, max_instance_failover_reroute_attempts):
+        self.max_instance_failover_reroute_attempts = (
+            max_instance_failover_reroute_attempts
+        )
+
     def _update_hash_ring(self, endpoints: List["EndpointInfo"]):
         """
         Update the hash ring with the current list of endpoints.
@@ -466,10 +471,10 @@ def initialize_routing_logic(
 ) -> RoutingInterface:
     if routing_logic == RoutingLogic.ROUND_ROBIN:
         logger.info("Initializing round-robin routing logic")
-        return RoundRobinRouter()
+        router = RoundRobinRouter()
     elif routing_logic == RoutingLogic.SESSION_BASED:
         logger.info(f"Initializing session-based routing logic with kwargs: {kwargs}")
-        return SessionRouter(kwargs.get("session_key"))
+        router = SessionRouter(kwargs.get("session_key"))
     elif routing_logic == RoutingLogic.KVAWARE:
         logger.info("Initializing kvaware routing logic")
         router = KvawareRouter(
@@ -478,17 +483,22 @@
             kwargs.get("kv_aware_threshold"),
         )
         router.start_kv_manager()
-        return router
     elif routing_logic == RoutingLogic.PREFIXAWARE:
         logger.info("Initializing prefix-aware routing logic")
-        return PrefixAwareRouter()
+        router = PrefixAwareRouter()
     elif routing_logic == RoutingLogic.DISAGGREGATED_PREFILL:
         logger.info("Initializing disaggregated prefill routing logic")
-        return DisaggregatedPrefillRouter(
+        router = DisaggregatedPrefillRouter(
             kwargs.get("prefill_model_labels"), kwargs.get("decode_model_labels")
         )
     else:
         raise ValueError(f"Invalid routing logic {routing_logic}")
+    router.set_request_migration(
+        max_instance_failover_reroute_attempts=kwargs.get(
+            "max_instance_failover_reroute_attempts"
+        )
+    )
+    return router
 
 
 def reconfigure_routing_logic(
diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py
index 83e647927..00c98cd3e 100644
--- a/src/vllm_router/services/request_service/request.py
+++ b/src/vllm_router/services/request_service/request.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 # --- Request Processing & Routing ---
+import asyncio
 import json
 import os
 import time
@@ -136,8 +137,57 @@ async def process_request(
     )
 
 
+def perform_service_discovery(
+    request, request_json, request_endpoint, requested_model, error_urls
+):
+    service_discovery = get_service_discovery()
+    endpoints = service_discovery.get_endpoint_info()
+
+    aliases = getattr(service_discovery, "aliases", None)
+    if aliases and requested_model in aliases.keys():
+        requested_model = aliases[requested_model]
+        request_body = replace_model_in_request_body(request_json, requested_model)
+        update_content_length(request, request_body)
+
+    if not request_endpoint:
+        endpoints = list(
+            filter(
+                lambda x: requested_model in x.model_names
+                and not x.sleep
+                and x.url not in error_urls,
+                endpoints,
+            )
+        )
+        engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
+        request_stats = request.app.state.request_stats_monitor.get_request_stats(
+            time.time()
+        )
+    else:
+        endpoints = list(
+            filter(
+                lambda x: requested_model in x.model_names
+                and x.Id == request_endpoint
+                and not x.sleep
+                and x.url not in error_urls,
+                endpoints,
+            )
+        )
+        engine_stats, request_stats = None, None
+
+    if not endpoints:
+        return JSONResponse(
+            status_code=400,
+            content={
+                "error": f"Model {requested_model} not found or vLLM engine is sleeping."
+            },
+        )
+    return endpoints, engine_stats, request_stats
+
+
 async def route_general_request(
-    request: Request, endpoint: str, background_tasks: BackgroundTasks
+    request: Request,
+    endpoint: str,
+    background_tasks: BackgroundTasks,
 ):
     """
     Route the incoming request to the backend server and stream the response back to the client.
@@ -203,96 +253,87 @@ async def route_general_request(
             status_code=400, detail="Request body is not JSON parsable."
         )
 
-    service_discovery = get_service_discovery()
-    endpoints = service_discovery.get_endpoint_info()
+    # Try routing up to (max reroute attempts + 1) times, skipping URLs that already failed.
+    error_urls = set()
+    for _ in range(request.app.state.router.max_instance_failover_reroute_attempts + 1):
+        discovery_result = await asyncio.to_thread(
+            perform_service_discovery,
+            request,
+            request_json,
+            request_endpoint,
+            requested_model,
+            error_urls,
+        )
+        # perform_service_discovery returns a JSONResponse when no endpoint
+        # can serve the request; return it instead of trying to unpack it.
+        if isinstance(discovery_result, JSONResponse):
+            return discovery_result
+        endpoints, engine_stats, request_stats = discovery_result
 
-    aliases = getattr(service_discovery, "aliases", None)
-    if aliases and requested_model in aliases.keys():
-        requested_model = aliases[requested_model]
-        request_body = replace_model_in_request_body(request_json, requested_model)
-        update_content_length(request, request_body)
+        logger.debug(f"Routing request {request_id} for model: {requested_model}")
+        if request_endpoint:
+            server_url = endpoints[0].url
+            logger.debug(
+                f"Routing request {request_id} to engine with Id: {endpoints[0].Id}"
+            )
 
-    if not request_endpoint:
-        endpoints = list(
-            filter(
-                lambda x: requested_model in x.model_names and not x.sleep,
-                endpoints,
+        elif isinstance(request.app.state.router, KvawareRouter) or isinstance(
+            request.app.state.router, PrefixAwareRouter
+        ):
+            server_url = await request.app.state.router.route_request(
+                endpoints, engine_stats, request_stats, request, request_json
             )
-        )
-        engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
-        request_stats = request.app.state.request_stats_monitor.get_request_stats(
-            time.time()
-        )
-    else:
-        endpoints = list(
-            filter(
-                lambda x: requested_model in x.model_names
-                and x.Id == request_endpoint
-                and not x.sleep,
-                endpoints,
+        else:
+            server_url = request.app.state.router.route_request(
+                endpoints, engine_stats, request_stats, request
             )
-        )
 
-    if not endpoints:
-        return JSONResponse(
-            status_code=400,
-            content={
-                "error": f"Model {requested_model} not found or vLLM engine is sleeping."
-            },
+        curr_time = time.time()
+        # Extract actual session ID from request headers for logging
+        session_key = (
+            getattr(request.app.state.router, "session_key", None)
+            if hasattr(request.app.state.router, "session_key")
+            else None
         )
-
-    logger.debug(f"Routing request {request_id} for model: {requested_model}")
-    if request_endpoint:
-        server_url = endpoints[0].url
-        logger.debug(
-            f"Routing request {request_id} to engine with Id: {endpoints[0].Id}"
+        session_id = (
+            request.headers.get(session_key, None) if session_key is not None else None
         )
+        session_id_display = session_id if session_id is not None else "None"
 
-    elif isinstance(request.app.state.router, KvawareRouter) or isinstance(
-        request.app.state.router, PrefixAwareRouter
-    ):
-        server_url = await request.app.state.router.route_request(
-            endpoints, engine_stats, request_stats, request, request_json
+        # Debug logging to help troubleshoot session ID extraction
+        logger.debug(
+            f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
        )
-    else:
-        server_url = request.app.state.router.route_request(
-            endpoints, engine_stats, request_stats, request
+        logger.debug(f"Debug session extraction - Session key config: {session_key}")
+        logger.debug(
+            f"Debug session extraction - Request headers: {dict(request.headers)}"
         )
+        logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")
 
-    curr_time = time.time()
-    # Extract actual session ID from request headers for logging
-    session_key = (
-        getattr(request.app.state.router, "session_key", None)
-        if hasattr(request.app.state.router, "session_key")
-        else None
-    )
-    session_id = (
-        request.headers.get(session_key, None) if session_key is not None else None
-    )
-    session_id_display = session_id if session_id is not None else "None"
-
-    # Debug logging to help troubleshoot session ID extraction
-    logger.debug(
-        f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
-    )
-    logger.debug(f"Debug session extraction - Session key config: {session_key}")
-    logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}")
-    logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")
+        logger.info(
+            f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
+        )
+        error = None
+        try:
+            stream_generator = process_request(
+                request,
+                request_body,
+                server_url,
+                request_id,
+                endpoint,
+                background_tasks,
+            )
+            headers, status = await anext(stream_generator)
+            headers_dict = {key: value for key, value in headers.items()}
+            headers_dict["X-Request-Id"] = request_id
+            # Break out of the loop when the request's stream is fully generated
+            break
+        except Exception as e:
+            error_urls.add(server_url)
+            error = e
 
-    logger.info(
-        f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
-    )
-    stream_generator = process_request(
-        request,
-        request_body,
-        server_url,
-        request_id,
-        endpoint,
-        background_tasks,
-    )
-    headers, status = await anext(stream_generator)
-    headers_dict = {key: value for key, value in headers.items()}
-    headers_dict["X-Request-Id"] = request_id
+    if error:
+        raise error
     return StreamingResponse(
         stream_generator,
         status_code=status,
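
For reviewers, the control flow that the `route_general_request` hunk introduces can be summarized by the standalone sketch below. It is illustrative only and does not reuse the router's real classes: `pick_endpoint` and `send_request` are hypothetical stand-ins for endpoint selection (`route_request`) and proxying (`process_request`), and `max_reroute_attempts` plays the role of `--max-instance-failover-reroute-attempts`.

```python
from __future__ import annotations

import asyncio
import random

# Hypothetical endpoint list standing in for the service-discovery result.
ENDPOINTS = ["http://engine-a:8000", "http://engine-b:8000", "http://engine-c:8000"]


def pick_endpoint(error_urls: set[str]) -> str | None:
    """Stand-in for route_request: pick any endpoint whose URL has not failed yet."""
    candidates = [url for url in ENDPOINTS if url not in error_urls]
    return random.choice(candidates) if candidates else None


async def send_request(url: str) -> str:
    """Stand-in for process_request; one hard-coded engine always fails."""
    if url == "http://engine-b:8000":
        raise ConnectionError(f"{url} is unreachable")
    return f"response from {url}"


async def route_with_failover(max_reroute_attempts: int) -> str:
    """Try up to max_reroute_attempts + 1 endpoints, skipping URLs that already failed."""
    error_urls: set[str] = set()
    error: Exception | None = None
    for _ in range(max_reroute_attempts + 1):
        url = pick_endpoint(error_urls)
        if url is None:
            raise RuntimeError("no healthy endpoint left for this model")
        error = None
        try:
            # On success, leave the retry loop immediately.
            return await send_request(url)
        except Exception as exc:
            # Remember the failing URL so the next attempt picks a different engine.
            error_urls.add(url)
            error = exc
    # All attempts failed: surface the last error, as the diff does with `raise error`.
    raise error


if __name__ == "__main__":
    print(asyncio.run(route_with_failover(max_reroute_attempts=2)))
```

With the flag left at its default of 0, the loop body runs exactly once and any failure propagates to the caller, which matches the router's behaviour before this change; larger values retry the remaining endpoints before giving up.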