Skip to content

Commit

Permalink
fix(prompts): output schema
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Jan 29, 2025
1 parent ad42f06 commit f2ce1a9
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 31 deletions.
10 changes: 7 additions & 3 deletions src/phoenix/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from phoenix.server.api.helpers.prompts.models import (
PromptChatTemplateV1,
PromptOutputSchema,
PromptOutputSchemaWrapper,
PromptStringTemplateV1,
PromptTemplate,
PromptTemplateWrapper,
Expand Down Expand Up @@ -165,19 +166,22 @@ def process_result_value(


class _PromptOutputSchema(TypeDecorator[PromptOutputSchema]):
    """
    Serializes a ``PromptOutputSchema`` (a discriminated union) to and from the
    JSON column it is stored in.

    See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    """

    # Safe to cache: (de)serialization is deterministic and stateless.
    cache_ok = True
    impl = JSON_

    def process_bind_param(
        self, value: Optional[PromptOutputSchema], _: Dialect
    ) -> Optional[dict[str, Any]]:
        # by_alias writes `schema_` under its wire name "schema";
        # exclude_unset keeps optional fields out of the stored JSON.
        return value.model_dump(exclude_unset=True, by_alias=True) if value is not None else None

    def process_result_value(
        self, value: Optional[dict[str, Any]], _: Dialect
    ) -> Optional[PromptOutputSchema]:
        if value is None:
            return None
        # The union type itself has no `model_validate`; validate through the
        # wrapper model, then unwrap the union member.
        wrapped_schema = PromptOutputSchemaWrapper.model_validate({"schema": value})
        return wrapped_schema.schema_


class ExperimentRunOutput(TypedDict, total=False):
Expand Down
80 changes: 55 additions & 25 deletions src/phoenix/server/api/helpers/prompts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,6 @@
JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]]


class Undefined:
    """
    A singleton sentinel standing in for an unset/undefined value.

    Pydantic cannot natively tell the difference between a field that was never
    set and one explicitly set to None, so this marker fills that gap.
    """

    def __new__(cls) -> Any:
        # Lazily create the one shared instance on first construction and hand
        # back the same object on every subsequent call.
        instance = getattr(cls, "_instance", None)
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance


UNDEFINED: Any = Undefined()


class PromptTemplateType(str, Enum):
STRING = "STR"
CHAT = "CHAT"
Expand Down Expand Up @@ -128,11 +112,12 @@ class PromptStringTemplateV1(PromptModel):


class PromptTemplateWrapper(PromptModel):
    """
    Discriminated union types don't have pydantic methods such as
    `model_validate`, so a wrapper around the union type is needed.
    """

    template: PromptTemplate


class PromptToolDefinition(PromptModel):
Expand All @@ -144,6 +129,51 @@ class PromptToolsV1(PromptModel):
tool_definitions: list[PromptToolDefinition]


class PromptOpenAIJSONSchema(PromptModel):
    """
    The inner `json_schema` payload of an OpenAI structured-output response format.

    Based on https://github.com/openai/openai-python/blob/d16e6edde5a155626910b5758a0b939bfedb9ced/src/openai/types/shared/response_format_json_schema.py#L13
    """

    name: str
    # Optional on the wire; `default=None` on a `str`-typed field is this
    # file's convention for "unset" (hence the type: ignore).
    description: str = Field(default=None)  # type: ignore[assignment]
    schema_: JSONSchemaObjectDefinition = Field(
        ...,
        alias="schema",  # an alias is used to avoid conflict with the pydantic schema class method
    )
    strict: Optional[bool] = Field(default=None)


class PromptOpenAIResponseFormatJSONSchema(PromptModel):
    """
    An OpenAI `response_format` payload with `type == "json_schema"`.

    Based on https://github.com/openai/openai-python/blob/d16e6edde5a155626910b5758a0b939bfedb9ced/src/openai/types/shared/response_format_json_schema.py#L40
    """

    json_schema: PromptOpenAIJSONSchema
    type: Literal["json_schema"]


class PromptOpenAIOutputSchema(PromptModel):
    """OpenAI flavor of a versioned prompt output schema."""

    # Discriminator literal used by the PromptOutputSchema union.
    version: Literal["openai-output-schema-v1"]
    definition: PromptOpenAIResponseFormatJSONSchema


# Discriminated union of all supported output-schema versions, dispatched on
# the `version` literal. Only the OpenAI flavor exists today; the Union leaves
# room to add other providers without changing callers.
PromptOutputSchema: TypeAlias = Annotated[
    Union[PromptOpenAIOutputSchema], Field(..., discriminator="version")
]


class PromptOutputSchemaWrapper(PromptModel):
    """
    Discriminated union types don't have pydantic methods such as
    `model_validate`, so a wrapper around the union type is needed.
    """

    schema_: PromptOutputSchema = Field(
        ...,
        alias="schema",  # an alias is used to avoid conflict with the pydantic schema class method
    )


class PromptVersion(PromptModel):
user_id: Optional[int]
description: Optional[str]
Expand Down Expand Up @@ -188,9 +218,9 @@ class OpenAIFunctionDefinition(PromptModel):
"""

name: str
description: str = UNDEFINED
parameters: JSONSchemaObjectDefinition = UNDEFINED
strict: Optional[bool] = UNDEFINED
description: str = Field(default=None) # type: ignore[assignment]
parameters: JSONSchemaObjectDefinition = Field(default=None) # type: ignore[arg-type]
strict: Optional[bool] = Field(default=None)


class OpenAIToolDefinition(PromptModel):
Expand Down Expand Up @@ -218,5 +248,5 @@ class AnthropicToolDefinition(PromptModel):

input_schema: JSONSchemaObjectDefinition
name: str
cache_control: Optional[AnthropicCacheControlEphemeralParam] = UNDEFINED
description: str = UNDEFINED
cache_control: Optional[AnthropicCacheControlEphemeralParam] = Field(default=None)
description: str = Field(default=None) # type: ignore[assignment]
1 change: 1 addition & 0 deletions src/phoenix/server/api/input_types/PromptVersionInput.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ToolDefinitionInput:

@strawberry.input
class OutputSchemaInput:
    # Private: excluded from the GraphQL schema; pins the version tag expected
    # by the server-side discriminated union.
    version: strawberry.Private[str] = "openai-output-schema-v1"
    definition: JSON


Expand Down
3 changes: 2 additions & 1 deletion src/phoenix/server/api/types/OutputSchema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ class OutputSchema:


def to_gql_output_schema_from_pydantic(pydantic_output_schema: PromptOutputSchema) -> OutputSchema:
    """Convert a pydantic PromptOutputSchema into its GraphQL OutputSchema type."""
    definition = pydantic_output_schema.definition
    # by_alias keeps the wire name "schema" (for the aliased `schema_` field);
    # exclude_unset omits fields that were never provided.
    return OutputSchema(definition=definition.model_dump(exclude_unset=True, by_alias=True))
34 changes: 32 additions & 2 deletions tests/unit/server/api/mutations/test_prompt_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,23 @@ class TestPromptMutations:
"invocationParameters": {"temperature": 0.4},
"modelProvider": "openai",
"modelName": "o1-mini",
"outputSchema": {"definition": {"type": "object"}},
"outputSchema": {
"definition": {
"type": "json_schema",
"json_schema": {
"name": "response",
"schema": {
"type": "object",
"properties": {
"foo": {"type": "string"},
},
"required": ["foo"],
"additionalProperties": False,
},
"strict": True,
},
}
},
},
}
},
Expand Down Expand Up @@ -694,7 +710,21 @@ async def test_create_chat_prompt_fails_with_invalid_input(
"modelProvider": "openai",
"modelName": "o1-mini",
"outputSchema": {
"definition": {"type": "object"},
"definition": {
"type": "json_schema",
"json_schema": {
"name": "response",
"schema": {
"type": "object",
"properties": {
"foo": {"type": "string"},
},
"required": ["foo"],
"additionalProperties": False,
},
"strict": True,
},
}
},
},
}
Expand Down

0 comments on commit f2ce1a9

Please sign in to comment.