Skip to content

Commit

Permalink
Feat: Supports returning expected error responses to wrong requests
Browse files Browse the repository at this point in the history
Signed-off-by: Kebe <mail@kebe7jun.com>
  • Loading branch information
kebe7jun committed Feb 26, 2025
1 parent 3dc9ff3 commit 642fc5e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
25 changes: 25 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@
import threading
import time
from http import HTTPStatus
from json import JSONDecodeError
from typing import AsyncIterator, Dict, Optional

from pydantic import ValidationError

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

Expand Down Expand Up @@ -57,6 +60,7 @@
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.adapter import (
create_error_response,
v1_batches,
v1_cancel_batch,
v1_chat_completions,
Expand Down Expand Up @@ -94,6 +98,27 @@
)


@app.exception_handler(ValidationError)
@app.exception_handler(JSONDecodeError)
async def validation_exception_handler(
    request: Request, exc: ValidationError
) -> Response:
    """Turn a malformed-request exception into a 400 error response.

    Registered for both ``ValidationError`` and ``JSONDecodeError``, so
    ``exc`` may be either one even though the annotation names only the
    former.

    Args:
        request: The incoming request (unused here).
        exc: The exception raised while handling the request.

    Returns:
        The result of ``create_error_response`` with the exception text as
        the message, error type ``"BadRequestError"`` and status 400.
    """
    detail = str(exc)
    return create_error_response(
        message=detail,
        err_type="BadRequestError",
        status_code=HTTPStatus.BAD_REQUEST,
    )


@app.exception_handler(Exception)
async def exception_handler(request: Request, exc: Exception) -> Response:
    """Turn an exception handled at the ``Exception`` level into a 500 response.

    Args:
        request: The incoming request (unused here).
        exc: The exception raised while handling the request.

    Returns:
        The result of ``create_error_response`` with the exception text as
        the message, error type ``"InternalServerError"`` and status 500.
    """
    detail = str(exc)
    return create_error_response(
        message=detail,
        err_type="InternalServerError",
        status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
    )


# Store global states
@dataclasses.dataclass
class _GlobalState:
Expand Down
11 changes: 10 additions & 1 deletion python/sglang/srt/openai_api/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import time
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Literal


Expand Down Expand Up @@ -325,6 +325,15 @@ class ChatCompletionRequest(BaseModel):
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None

@model_validator(mode="after")
def validate(self):
    """Check cross-field constraints once the fields have been parsed.

    Returns:
        This instance, unchanged. Pydantic expects an ``after`` model
        validator to return the validated instance.

    Raises:
        ValueError: If ``messages`` is empty, or if the role of the last
            message is ``"assistant"``.
    """
    # NOTE(review): the name ``validate`` shadows pydantic's deprecated
    # ``BaseModel.validate`` classmethod; consider renaming -- kept here so
    # the existing interface is unchanged.

    # messages
    if not self.messages:
        raise ValueError("messages cannot be empty")
    if self.messages[-1].role == "assistant":
        # for https://github.com/sgl-project/sglang/issues/3579
        raise ValueError("last message role must not be assistant")
    # Without an explicit return the validator yields None instead of the
    # validated model.
    return self


class FunctionResponse(BaseModel):
"""Function response."""
Expand Down

0 comments on commit 642fc5e

Please sign in to comment.