Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into 78-rename-base-url-to…
Browse files Browse the repository at this point in the history
…-base-url
  • Loading branch information
harishmohanraj committed Aug 26, 2024
2 parents 81a5f01 + 24932e4 commit 07b56dc
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 7 deletions.
15 changes: 13 additions & 2 deletions fastagency/studio/models/llms/azure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Annotated, Any, Dict, Literal
from typing import Annotated, Any, Dict, Literal, Type
from uuid import UUID

from pydantic import AfterValidator, BaseModel, Field, HttpUrl
from pydantic import AfterValidator, BaseModel, Field, HttpUrl, field_validator
from pydantic_core import PydanticCustomError
from typing_extensions import TypeAlias

from ..base import Model
Expand Down Expand Up @@ -48,6 +49,9 @@ class UrlModel(BaseModel):
url: URL


BASE_URL_ERROR_MESSAGE = "The Base URL contains curly braces, indicating a placeholder. Please replace the entire placeholder, including the curly braces, with your actual Azure resource name."


@register("llm")
class AzureOAI(Model):
model: Annotated[
Expand Down Expand Up @@ -88,6 +92,13 @@ class AzureOAI(Model):
),
] = 0.8

@field_validator("base_url")
@classmethod
def validate_base_url(cls: Type["AzureOAI"], value: Any) -> Any:
if "{" in value or "}" in value:
raise PydanticCustomError("invalid_base_url", BASE_URL_ERROR_MESSAGE)
return value

@classmethod
async def create_autogen(
cls, model_id: UUID, user_id: UUID, **kwargs: Any
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ server = [

# dev dependencies
devdocs = [
"mkdocs-material==9.5.32",
"mkdocs-material==9.5.33",
"mkdocs-static-i18n==1.2.3",
"mdx-include==1.4.2",
"mkdocstrings[python]==0.25.2",
Expand All @@ -96,7 +96,7 @@ lint = [
"types-Pygments",
"types-docutils",
"mypy==1.11.1",
"ruff==0.6.1",
"ruff==0.6.2",
"pyupgrade-directories==0.3.0",
"bandit==1.7.9",
"semgrep==1.85.0",
Expand All @@ -106,7 +106,7 @@ lint = [
test-core = [
"coverage[toml]==7.6.1",
"pytest==8.3.2",
"pytest-asyncio==0.23.8",
"pytest-asyncio==0.24.0",
"dirty-equals==0.7.1.post0",
"pytest-rerunfailures==14.0",
]
Expand All @@ -116,7 +116,7 @@ testing = [
"fastagency[test-core]",
"fastagency[server]", # Uvicorn is needed for testing
"pydantic-settings==2.4.0",
"PyYAML==6.0.1",
"PyYAML==6.0.2",
"watchfiles==0.23.0",
"email-validator==2.2.0",
]
Expand Down
37 changes: 36 additions & 1 deletion tests/studio/models/llms/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from typing import Any, Dict

import pytest
from pydantic import ValidationError

from fastagency.studio.helpers import create_autogen, get_model_by_ref
from fastagency.studio.models.base import ObjectReference
from fastagency.studio.models.llms.azure import AzureOAI, AzureOAIAPIKey
from fastagency.studio.models.llms.azure import (
BASE_URL_ERROR_MESSAGE,
AzureOAI,
AzureOAIAPIKey,
UrlModel,
)


def test_import(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down Expand Up @@ -45,6 +51,35 @@ async def test_azure_constructor(
}
assert model.model_dump() == expected

@pytest.mark.parametrize(
"base_url",
[
"https://{your-resource-name.openai.azure.com",
"https://your-resource-name}.openai.azure.com",
"https://{your-resource-name}.openai.azure.com",
],
)
@pytest.mark.db
@pytest.mark.asyncio
async def test_azure_constructor_with_invalid_base_url(
self, azure_oai_gpt35_ref: ObjectReference, base_url: str
) -> None:
# create data
model = await get_model_by_ref(azure_oai_gpt35_ref)
assert isinstance(model, AzureOAI)

# Construct a new AzureOAI model with the invalid base_url
with pytest.raises(ValidationError, match=BASE_URL_ERROR_MESSAGE):
AzureOAI(
name=model.name,
model=model.model,
api_key=model.api_key,
base_url=UrlModel(url=base_url).url,
api_type=model.api_type,
api_version=model.api_version,
temperature=model.temperature,
)

def test_azure_model_schema(self) -> None:
schema = AzureOAI.model_json_schema()
expected = {
Expand Down

0 comments on commit 07b56dc

Please sign in to comment.