diff --git a/wren-ai-service/src/config.py b/wren-ai-service/src/config.py index 9d5880e78c..51a3709278 100644 --- a/wren-ai-service/src/config.py +++ b/wren-ai-service/src/config.py @@ -55,6 +55,7 @@ class Settings(BaseSettings): so we set it to 1_000_000, which is a large number """, ) + max_ask_timeout: int = Field(default=300, description="Maximum timeout for the ask query in seconds.") # user guide config is_oss: bool = Field(default=True) diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index cae34c40a2..97cdea1e76 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -149,6 +149,7 @@ def create_service_container( max_histories=settings.max_histories, enable_column_pruning=settings.enable_column_pruning, max_sql_correction_retries=settings.max_sql_correction_retries, + max_ask_timeout=settings.max_ask_timeout, **query_cache, ), chart_service=services.ChartService( diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 65fd0d6d5b..9a71f1ef6b 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -29,6 +29,7 @@ class AskRequest(BaseRequest): enable_column_pruning: bool = False use_dry_plan: bool = False allow_dry_plan_fallback: bool = True + timeout: float = Field(default=30.0, description="Timeout for the ask query in seconds.") class AskResponse(BaseModel): @@ -144,6 +145,7 @@ def __init__( enable_column_pruning: bool = False, max_sql_correction_retries: int = 3, max_histories: int = 5, + max_ask_timeout: int = 300, maxsize: int = 1_000_000, ttl: int = 120, ): @@ -160,6 +162,7 @@ def __init__( self._enable_column_pruning = enable_column_pruning self._max_histories = max_histories self._max_sql_correction_retries = max_sql_correction_retries + self._max_ask_timeout = max_ask_timeout def _is_stopped(self, query_id: str, container: dict): if ( @@ -175,6 +178,25 @@ async def ask( self, ask_request: AskRequest, **kwargs, + ): + timeout = min(ask_request.timeout, self._max_ask_timeout) + try: + await asyncio.wait_for(self._ask(ask_request, **kwargs), timeout=timeout) + except asyncio.TimeoutError: + logger.warning(f"ask pipeline - TIMEOUT: {ask_request.query_id}") + self._ask_results[ask_request.query_id] = AskResultResponse( + status="failed", + error=AskError( + code="OTHERS", + message="Query timed out", + ), + trace_id=kwargs.get("trace_id"), + ) + + async def _ask( + self, + ask_request: AskRequest, + **kwargs, ): trace_id = kwargs.get("trace_id") results = {