Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(prompts): output schema #6194

Merged
merged 15 commits into from
Feb 3, 2025
26 changes: 23 additions & 3 deletions schemas/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -2286,15 +2286,34 @@
},
"PromptOutputSchema": {
"properties": {
"definition": {
"type": {
"type": "string",
"const": "output-schema-v1",
"title": "Type"
},
"name": {
"type": "string",
"title": "Name"
},
"description": {
"type": "string",
"title": "Description"
},
"schema": {
"type": "object",
"title": "Schema"
},
"extra_parameters": {
"type": "object",
"title": "Definition"
"title": "Extra Parameters"
}
},
"additionalProperties": false,
"type": "object",
"required": [
"definition"
"type",
"name",
"extra_parameters"
],
"title": "PromptOutputSchema"
},
Expand Down Expand Up @@ -2357,6 +2376,7 @@
}
},
"type": "array",
"minItems": 1,
"title": "Tools"
}
},
Expand Down
8 changes: 6 additions & 2 deletions src/phoenix/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from phoenix.server.api.helpers.prompts.models import (
PromptChatTemplateV1,
PromptOutputSchema,
PromptOutputSchemaWrapper,
PromptStringTemplateV1,
PromptTemplate,
PromptTemplateWrapper,
Expand Down Expand Up @@ -165,7 +166,7 @@ def process_result_value(


class _PromptOutputSchema(TypeDecorator[PromptOutputSchema]):
# See # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
cache_ok = True
impl = JSON_

Expand All @@ -177,7 +178,10 @@ def process_bind_param(
def process_result_value(
self, value: Optional[dict[str, Any]], _: Dialect
) -> Optional[PromptOutputSchema]:
return PromptOutputSchema.model_validate(value) if value is not None else None
if value is None:
return None
wrapped_schema = PromptOutputSchemaWrapper.model_validate({"schema": value})
return wrapped_schema.schema_


class ExperimentRunOutput(TypedDict, total=False):
Expand Down
122 changes: 110 additions & 12 deletions src/phoenix/server/api/helpers/prompts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,12 @@ class PromptStringTemplateV1(PromptModel):


class PromptTemplateWrapper(PromptModel):
template: PromptTemplate

"""
Discriminated union types don't have pydantic methods such as
`model_validate`, so a wrapper around the union type is needed.
"""

class PromptOutputSchema(PromptModel):
definition: JSONSchemaObjectDefinition
template: PromptTemplate


class PromptFunctionToolV1(PromptModel):
Expand All @@ -154,19 +155,116 @@ class PromptFunctionToolV1(PromptModel):
extra_parameters: dict[str, Any]


PromptTool: TypeAlias = Annotated[Union[PromptFunctionToolV1], Field(..., discriminator="type")]


class PromptToolsV1(PromptModel):
type: Literal["tools-v1"]
tools: list[Annotated[Union[PromptFunctionToolV1], Field(..., discriminator="type")]]
tools: Annotated[list[PromptTool], Field(..., min_length=1)]


class PromptOpenAIJSONSchema(PromptModel):
    """
    Mirrors the ``json_schema`` payload of OpenAI's structured-output response
    format.

    Based on https://github.com/openai/openai-python/blob/d16e6edde5a155626910b5758a0b939bfedb9ced/src/openai/types/shared/response_format_json_schema.py#L13
    """

    name: str
    # UNDEFINED is a sentinel default — presumably it marks "field not provided"
    # so it can be omitted on serialization; TODO confirm against PromptModel config.
    description: str = UNDEFINED
    schema_: JSONSchemaObjectDefinition = Field(
        ...,
        alias="schema",  # an alias is used to avoid conflict with the pydantic schema class method
    )
    strict: Optional[bool] = UNDEFINED


class PromptOpenAIOutputSchema(PromptModel):
    """
    Mirrors OpenAI's ``response_format`` object for JSON-schema structured
    output (``{"type": "json_schema", "json_schema": {...}}``).

    Based on https://github.com/openai/openai-python/blob/d16e6edde5a155626910b5758a0b939bfedb9ced/src/openai/types/shared/response_format_json_schema.py#L40
    """

    json_schema: PromptOpenAIJSONSchema
    # Literal pins the only accepted value; also serves as the variant tag.
    type: Literal["json_schema"]


class PromptOutputSchema(PromptModel):
    """
    Provider-agnostic representation of an LLM output schema.

    Provider payloads (e.g. OpenAI's ``json_schema`` response format) are
    converted to/from this shape by ``normalize_output_schema`` /
    ``denormalize_output_schema``; provider fields with no normalized
    counterpart travel in ``extra_parameters``.
    """

    # Versioned discriminator used by the wrapper's discriminated union.
    type: Literal["output-schema-v1"]
    name: str
    description: str = UNDEFINED  # UNDEFINED sentinel means "not provided"
    schema_: JSONSchemaObjectDefinition = Field(
        default=UNDEFINED,
        alias="schema",  # an alias is used to avoid conflict with the pydantic schema class method
    )
    extra_parameters: dict[str, Any]


class PromptOutputSchemaWrapper(PromptModel):
    """
    Discriminated union types don't have pydantic methods such as
    `model_validate`, so a wrapper around the union type is needed.

    Validate with ``PromptOutputSchemaWrapper.model_validate({"schema": ...})``
    and read the parsed value back via ``schema_``.
    """

    schema_: Annotated[
        # Single-member union today; written as Union so new schema versions
        # can be added and dispatched on the "type" discriminator.
        Union[PromptOutputSchema],
        Field(
            ...,
            discriminator="type",
            alias="schema",  # avoid conflict with the pydantic schema class method
        ),
    ]


def _openai_to_prompt_output_schema(
    schema: PromptOpenAIOutputSchema,
) -> PromptOutputSchema:
    """
    Convert an OpenAI ``json_schema`` response format into the normalized
    ``PromptOutputSchema`` representation.

    The OpenAI-only ``strict`` flag has no normalized field, so when it was
    provided it is preserved under ``extra_parameters``.
    """
    inner = schema.json_schema
    extras: dict[str, Any] = {}
    strict_flag = inner.strict
    if strict_flag is not UNDEFINED:
        extras["strict"] = strict_flag
    return PromptOutputSchema(
        type="output-schema-v1",
        name=inner.name,
        description=inner.description,
        schema=inner.schema_,
        extra_parameters=extras,
    )


def _prompt_to_openai_output_schema(
    output_schema: PromptOutputSchema,
) -> PromptOpenAIOutputSchema:
    """
    Convert a normalized ``PromptOutputSchema`` back into OpenAI's
    ``json_schema`` response format.

    ``strict`` is recovered from ``extra_parameters`` when present; otherwise
    the UNDEFINED sentinel is passed through so the field stays unset.
    """
    assert output_schema.type == "output-schema-v1"
    payload = PromptOpenAIJSONSchema(
        name=output_schema.name,
        description=output_schema.description,
        schema=output_schema.schema_,
        strict=output_schema.extra_parameters.get("strict", UNDEFINED),
    )
    return PromptOpenAIOutputSchema(type="json_schema", json_schema=payload)


def normalize_output_schema(
    output_schema: dict[str, Any], model_provider: str
) -> PromptOutputSchema:
    """
    Parse a raw provider output-schema dict into a ``PromptOutputSchema``.

    Only OpenAI (case-insensitive) is supported; any other provider raises
    ``ValueError``.
    """
    if model_provider.lower() != "openai":
        raise ValueError(f"Unsupported model provider: {model_provider}")
    return _openai_to_prompt_output_schema(
        PromptOpenAIOutputSchema.model_validate(output_schema)
    )


def _get_tool_definition_model(
model_provider: str,
) -> Optional[Union[type["OpenAIToolDefinition"], type["AnthropicToolDefinition"]]]:
def denormalize_output_schema(
output_schema: PromptOutputSchema, model_provider: str
) -> dict[str, Any]:
if model_provider.lower() == "openai":
return OpenAIToolDefinition
if model_provider.lower() == "anthropic":
return AnthropicToolDefinition
return None
openai_output_schema = _prompt_to_openai_output_schema(output_schema)
return openai_output_schema.model_dump()
raise ValueError(f"Unsupported model provider: {model_provider}")


# OpenAI tool definitions
Expand Down
1 change: 1 addition & 0 deletions src/phoenix/server/api/input_types/PromptVersionInput.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ToolDefinitionInput:

@strawberry.input
class OutputSchemaInput:
version: strawberry.Private[str] = "openai-output-schema-v1"
definition: JSON


Expand Down
27 changes: 20 additions & 7 deletions src/phoenix/server/api/mutations/prompt_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from phoenix.db.types.identifier import Identifier as IdentifierModel
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
from phoenix.server.api.helpers.prompts.models import PromptOutputSchema, normalize_tools
from phoenix.server.api.helpers.prompts.models import (
normalize_output_schema,
normalize_tools,
)
from phoenix.server.api.input_types.PromptVersionInput import (
ChatPromptVersionInput,
to_pydantic_prompt_chat_template_v1,
Expand Down Expand Up @@ -82,11 +85,16 @@ async def create_chat_prompt(
input_prompt_version = input.prompt_version
tool_definitions = [tool.definition for tool in input_prompt_version.tools]
try:
tools = normalize_tools(tool_definitions, input_prompt_version.model_provider)
tools = (
normalize_tools(tool_definitions, input_prompt_version.model_provider)
if tool_definitions
else None
)
template = to_pydantic_prompt_chat_template_v1(input_prompt_version.template)
output_schema = (
PromptOutputSchema.model_validate(
strawberry.asdict(input_prompt_version.output_schema)
normalize_output_schema(
input_prompt_version.output_schema.definition,
input_prompt_version.model_provider,
)
if input_prompt_version.output_schema
else None
Expand Down Expand Up @@ -135,11 +143,16 @@ async def create_chat_prompt_version(
input_prompt_version = input.prompt_version
tool_definitions = [tool.definition for tool in input.prompt_version.tools]
try:
tools = normalize_tools(tool_definitions, input.prompt_version.model_provider)
tools = (
normalize_tools(tool_definitions, input_prompt_version.model_provider)
if tool_definitions
else None
)
template = to_pydantic_prompt_chat_template_v1(input.prompt_version.template)
output_schema = (
PromptOutputSchema.model_validate(
strawberry.asdict(input_prompt_version.output_schema)
normalize_output_schema(
input_prompt_version.output_schema.definition,
input_prompt_version.model_provider,
)
if input_prompt_version.output_schema
else None
Expand Down
6 changes: 0 additions & 6 deletions src/phoenix/server/api/types/OutputSchema.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import strawberry
from strawberry.scalars import JSON

from phoenix.server.api.helpers.prompts.models import PromptOutputSchema


@strawberry.type
class OutputSchema:
"""A JSON schema definition used to guide an LLM's output"""

definition: JSON


def to_gql_output_schema_from_pydantic(pydantic_output_schema: PromptOutputSchema) -> OutputSchema:
return OutputSchema(**pydantic_output_schema.model_dump())
10 changes: 8 additions & 2 deletions src/phoenix/server/api/types/PromptVersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from phoenix.server.api.helpers.prompts.models import (
PromptTemplateFormat,
PromptTemplateType,
denormalize_output_schema,
denormalize_tools,
)
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
Expand All @@ -21,7 +22,7 @@
to_gql_template_from_orm,
)

from .OutputSchema import OutputSchema, to_gql_output_schema_from_pydantic
from .OutputSchema import OutputSchema
from .ToolDefinition import ToolDefinition
from .User import User, to_gql_user

Expand Down Expand Up @@ -112,7 +113,12 @@ def to_gql_prompt_version(
else:
tools = []
output_schema = (
to_gql_output_schema_from_pydantic(prompt_version.output_schema)
OutputSchema(
definition=denormalize_output_schema(
prompt_version.output_schema,
prompt_version.model_provider,
)
)
if prompt_version.output_schema is not None
else None
)
Expand Down
Loading
Loading