From 4c4a858bd5a3b01d66e306941ecbfb8fe4d20c3d Mon Sep 17 00:00:00 2001
From: Kebe
Date: Wed, 26 Feb 2025 14:50:28 +0800
Subject: [PATCH] Feat: Supports returning expected error responses to wrong requests

Signed-off-by: Kebe
---
Notes: the hunks below carry review amendments, so the blob ids on the
"index" lines and the commit id above may not match the resulting files.

 python/sglang/srt/entrypoints/http_server.py | 31 ++++++++++++++++++++
 python/sglang/srt/openai_api/protocol.py     | 16 +++++++++-
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 2b2421a376c..f19a17c0a92 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -25,8 +25,11 @@
 import threading
 import time
 from http import HTTPStatus
+from json import JSONDecodeError
 from typing import AsyncIterator, Dict, Optional
 
+from pydantic import ValidationError
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
@@ -57,6 +60,7 @@
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
 from sglang.srt.openai_api.adapter import (
+    create_error_response,
     v1_batches,
     v1_cancel_batch,
     v1_chat_completions,
@@ -94,6 +98,33 @@
 )
 
 
+@app.exception_handler(ValidationError)
+@app.exception_handler(JSONDecodeError)
+async def validation_exception_handler(request: Request, exc: Exception) -> Response:
+    """Answer an unparsable or invalid request with a 400 error response.
+
+    Registered for both ``ValidationError`` and ``JSONDecodeError``, so ``exc``
+    can be an instance of either; only its string form is used.
+    """
+    return create_error_response(
+        message=str(exc),
+        err_type="BadRequestError",
+        status_code=HTTPStatus.BAD_REQUEST,
+    )
+
+
+@app.exception_handler(Exception)
+async def exception_handler(request: Request, exc: Exception) -> Response:
+    """Answer any otherwise-unhandled exception with a 500 error response."""
+    # NOTE(review): str(exc) is sent back to the client; confirm it cannot leak
+    # internal details.
+    return create_error_response(
+        message=str(exc),
+        err_type="InternalServerError",
+        status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+    )
+
+
 # Store global states
 @dataclasses.dataclass
 class _GlobalState:
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 95b34527edb..ca3df3256c9 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -16,7 +16,7 @@
 import time
 from typing import Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Literal
 
 
@@ -325,6 +325,20 @@ class ChatCompletionRequest(BaseModel):
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None
 
+    @model_validator(mode="after")
+    def validate(self):
+        """Reject an empty message list or a single assistant-only message."""
+        # NOTE(review): this name overrides pydantic's deprecated
+        # ``BaseModel.validate`` classmethod; consider renaming it.
+        # messages
+        if not self.messages:
+            raise ValueError("messages cannot be empty")
+        if len(self.messages) == 1 and self.messages[-1].role == "assistant":
+            # for https://github.com/sgl-project/sglang/issues/3579
+            raise ValueError("cannot just include messages that assistant roles")
+        # An "after" model validator must return the validated instance.
+        return self
+
 
 class FunctionResponse(BaseModel):
     """Function response."""