diff --git a/fastagency/studio/models/llms/openai.py b/fastagency/studio/models/llms/openai.py index 83462078a..adfb1912b 100644 --- a/fastagency/studio/models/llms/openai.py +++ b/fastagency/studio/models/llms/openai.py @@ -55,7 +55,10 @@ async def create_autogen(cls, model_id: UUID, user_id: UUID, **kwargs: Any) -> s @field_validator("api_key") @classmethod def validate_api_key(cls: Type["OpenAIAPIKey"], value: Any) -> Any: - if not re.match(r"^sk-[a-zA-Z0-9]{20}T3BlbkFJ[a-zA-Z0-9]{20}$", value): + if not re.match( + r"^(sk-(proj-|None-|svcacct-)[A-Za-z0-9_-]+|sk-[a-zA-Z0-9]{20}T3BlbkFJ[a-zA-Z0-9]{20})$", + value, + ): raise ValueError("Invalid OpenAI API Key") return value diff --git a/tests/studio/models/llms/test_openai.py b/tests/studio/models/llms/test_openai.py index 007d75fb2..c018bb6d2 100644 --- a/tests/studio/models/llms/test_openai.py +++ b/tests/studio/models/llms/test_openai.py @@ -19,15 +19,25 @@ def test_import(monkeypatch: pytest.MonkeyPatch) -> None: class TestOpenAIAPIKey: - def test_constructor_success(self) -> None: + @pytest.mark.parametrize( + "openai_api_key", + [ + "sk-sUeBP9asw6GiYHXqtg70T3BlbkFJJuLwJFco90bOpU0Ntest", # pragma: allowlist secret + # OpenAI currently supports three prefixes for API keys: + # project-based API key format + "sk-proj-SomeLengthStringWhichCanHave-and_inItAndTheLengthCanBeChangedAtAnyTime", # pragma: allowlist secret + # user-level API key format + "sk-None-SomeLengthStringWhichCanHave-and_inItAndTheLengthCanBeChangedAtAnyTime", # pragma: allowlist secret + # service account APi key format + "sk-svcacct-SomeLengthStringWhichCanHave-and_inItAndTheLengthCanBeChangedAtAnyTime", # pragma: allowlist secret + ], + ) + def test_constructor_success(self, openai_api_key: str) -> None: api_key = OpenAIAPIKey( - api_key="sk-sUeBP9asw6GiYHXqtg70T3BlbkFJJuLwJFco90bOpU0Ntest", # pragma: allowlist secret + api_key=openai_api_key, name="Hello World!", ) # pragma: allowlist secret - assert ( - api_key.api_key - == "sk-sUeBP9asw6GiYHXqtg70T3BlbkFJJuLwJFco90bOpU0Ntest" # pragma: allowlist secret - ) # pragma: allowlist secret + assert api_key.api_key == openai_api_key # pragma: allowlist secret def test_constructor_failure(self) -> None: with pytest.raises(ValueError, match="Invalid OpenAI API Key"):