Skip to content

Commit

Permalink
fix(prompts): output schema (#6194)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored and mikeldking committed Feb 6, 2025
1 parent f9f72a3 commit 6da69ac
Show file tree
Hide file tree
Showing 9 changed files with 704 additions and 106 deletions.
26 changes: 23 additions & 3 deletions schemas/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -2286,15 +2286,34 @@
},
"PromptOutputSchema": {
"properties": {
"definition": {
"type": {
"type": "string",
"const": "output-schema-v1",
"title": "Type"
},
"name": {
"type": "string",
"title": "Name"
},
"description": {
"type": "string",
"title": "Description"
},
"schema": {
"type": "object",
"title": "Schema"
},
"extra_parameters": {
"type": "object",
"title": "Definition"
"title": "Extra Parameters"
}
},
"additionalProperties": false,
"type": "object",
"required": [
"definition"
"type",
"name",
"extra_parameters"
],
"title": "PromptOutputSchema"
},
Expand Down Expand Up @@ -2357,6 +2376,7 @@
}
},
"type": "array",
"minItems": 1,
"title": "Tools"
}
},
Expand Down
8 changes: 6 additions & 2 deletions src/phoenix/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from phoenix.server.api.helpers.prompts.models import (
PromptChatTemplateV1,
PromptOutputSchema,
PromptOutputSchemaWrapper,
PromptStringTemplateV1,
PromptTemplate,
PromptTemplateWrapper,
Expand Down Expand Up @@ -165,7 +166,7 @@ def process_result_value(


class _PromptOutputSchema(TypeDecorator[PromptOutputSchema]):
# See # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
cache_ok = True
impl = JSON_

Expand All @@ -177,7 +178,10 @@ def process_bind_param(
def process_result_value(
self, value: Optional[dict[str, Any]], _: Dialect
) -> Optional[PromptOutputSchema]:
return PromptOutputSchema.model_validate(value) if value is not None else None
if value is None:
return None
wrapped_schema = PromptOutputSchemaWrapper.model_validate({"schema": value})
return wrapped_schema.schema_


class ExperimentRunOutput(TypedDict, total=False):
Expand Down
122 changes: 110 additions & 12 deletions src/phoenix/server/api/helpers/prompts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,12 @@ class PromptStringTemplateV1(PromptModel):


class PromptTemplateWrapper(PromptModel):
template: PromptTemplate

"""
Discriminated union types don't have pydantic methods such as
`model_validate`, so a wrapper around the union type is needed.
"""

class PromptOutputSchema(PromptModel):
definition: JSONSchemaObjectDefinition
template: PromptTemplate


class PromptFunctionToolV1(PromptModel):
Expand All @@ -154,19 +155,116 @@ class PromptFunctionToolV1(PromptModel):
extra_parameters: dict[str, Any]


PromptTool: TypeAlias = Annotated[Union[PromptFunctionToolV1], Field(..., discriminator="type")]


class PromptToolsV1(PromptModel):
type: Literal["tools-v1"]
tools: list[Annotated[Union[PromptFunctionToolV1], Field(..., discriminator="type")]]
tools: Annotated[list[PromptTool], Field(..., min_length=1)]


class PromptOpenAIJSONSchema(PromptModel):
    """
    Mirrors OpenAI's JSON-schema payload inside a ``response_format`` object.

    Based on https://github.com/openai/openai-python/blob/d16e6edde5a155626910b5758a0b939bfedb9ced/src/openai/types/shared/response_format_json_schema.py#L13
    """

    # Identifier for the schema as reported to OpenAI.
    name: str
    # Optional; UNDEFINED sentinel marks "not provided" (distinct from empty string).
    description: str = UNDEFINED
    schema_: JSONSchemaObjectDefinition = Field(
        ...,
        alias="schema",  # an alias is used to avoid conflict with the pydantic schema class method
    )
    # OpenAI strict-mode flag; UNDEFINED means the caller did not set it.
    strict: Optional[bool] = UNDEFINED


class PromptOpenAIOutputSchema(PromptModel):
    """
    OpenAI ``response_format`` wrapper around a JSON schema.

    Based on https://github.com/openai/openai-python/blob/d16e6edde5a155626910b5758a0b939bfedb9ced/src/openai/types/shared/response_format_json_schema.py#L40
    """

    json_schema: PromptOpenAIJSONSchema
    # Fixed tag required by OpenAI's response_format for JSON-schema outputs.
    type: Literal["json_schema"]


class PromptOutputSchema(PromptModel):
    """
    Normalized, provider-agnostic representation of an LLM output schema.

    Provider payloads are converted to/from this form by
    ``normalize_output_schema`` / ``denormalize_output_schema``.
    """

    # Version discriminator; also used by PromptOutputSchemaWrapper's union.
    type: Literal["output-schema-v1"]
    name: str
    description: str = UNDEFINED
    schema_: JSONSchemaObjectDefinition = Field(
        default=UNDEFINED,
        alias="schema",  # an alias is used to avoid conflict with the pydantic schema class method
    )
    # Provider-specific settings that don't map onto the common fields
    # (e.g. OpenAI's "strict" flag).
    extra_parameters: dict[str, Any]


class PromptOutputSchemaWrapper(PromptModel):
    """
    Discriminated union types don't have pydantic methods such as
    `model_validate`, so a wrapper around the union type is needed.
    """

    # Single-member union today; the "type" discriminator leaves room for
    # future schema versions without breaking validation.
    schema_: Annotated[
        Union[PromptOutputSchema],
        Field(
            ...,
            discriminator="type",
            alias="schema",  # avoid conflict with the pydantic schema class method
        ),
    ]


def _openai_to_prompt_output_schema(
    schema: PromptOpenAIOutputSchema,
) -> PromptOutputSchema:
    """Convert an OpenAI response-format payload into the normalized Phoenix form."""
    openai_json_schema = schema.json_schema
    extras: dict[str, Any] = {}
    strict_setting = openai_json_schema.strict
    if strict_setting is not UNDEFINED:
        # "strict" is OpenAI-specific, so it rides along in extra_parameters.
        extras["strict"] = strict_setting
    return PromptOutputSchema(
        type="output-schema-v1",
        name=openai_json_schema.name,
        description=openai_json_schema.description,
        schema=openai_json_schema.schema_,
        extra_parameters=extras,
    )


def _prompt_to_openai_output_schema(
    output_schema: PromptOutputSchema,
) -> PromptOpenAIOutputSchema:
    """Denormalize a Phoenix output schema back into OpenAI's response-format shape."""
    assert output_schema.type == "output-schema-v1"
    # "strict" was stashed in extra_parameters during normalization; fall back
    # to the UNDEFINED sentinel when it was never set.
    json_schema = PromptOpenAIJSONSchema(
        name=output_schema.name,
        description=output_schema.description,
        schema=output_schema.schema_,
        strict=output_schema.extra_parameters.get("strict", UNDEFINED),
    )
    return PromptOpenAIOutputSchema(type="json_schema", json_schema=json_schema)


def normalize_output_schema(
    output_schema: dict[str, Any], model_provider: str
) -> PromptOutputSchema:
    """
    Parse a provider-specific output-schema dict into the normalized model.

    Raises:
        ValueError: if the provider has no registered normalization.
    """
    if model_provider.lower() != "openai":
        raise ValueError(f"Unsupported model provider: {model_provider}")
    return _openai_to_prompt_output_schema(
        PromptOpenAIOutputSchema.model_validate(output_schema)
    )


def _get_tool_definition_model(
model_provider: str,
) -> Optional[Union[type["OpenAIToolDefinition"], type["AnthropicToolDefinition"]]]:
def denormalize_output_schema(
output_schema: PromptOutputSchema, model_provider: str
) -> dict[str, Any]:
if model_provider.lower() == "openai":
return OpenAIToolDefinition
if model_provider.lower() == "anthropic":
return AnthropicToolDefinition
return None
openai_output_schema = _prompt_to_openai_output_schema(output_schema)
return openai_output_schema.model_dump()
raise ValueError(f"Unsupported model provider: {model_provider}")


# OpenAI tool definitions
Expand Down
1 change: 1 addition & 0 deletions src/phoenix/server/api/input_types/PromptVersionInput.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ToolDefinitionInput:

@strawberry.input
class OutputSchemaInput:
version: strawberry.Private[str] = "openai-output-schema-v1"
definition: JSON


Expand Down
27 changes: 20 additions & 7 deletions src/phoenix/server/api/mutations/prompt_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from phoenix.db.types.identifier import Identifier as IdentifierModel
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
from phoenix.server.api.helpers.prompts.models import PromptOutputSchema, normalize_tools
from phoenix.server.api.helpers.prompts.models import (
normalize_output_schema,
normalize_tools,
)
from phoenix.server.api.input_types.PromptVersionInput import (
ChatPromptVersionInput,
to_pydantic_prompt_chat_template_v1,
Expand Down Expand Up @@ -82,11 +85,16 @@ async def create_chat_prompt(
input_prompt_version = input.prompt_version
tool_definitions = [tool.definition for tool in input_prompt_version.tools]
try:
tools = normalize_tools(tool_definitions, input_prompt_version.model_provider)
tools = (
normalize_tools(tool_definitions, input_prompt_version.model_provider)
if tool_definitions
else None
)
template = to_pydantic_prompt_chat_template_v1(input_prompt_version.template)
output_schema = (
PromptOutputSchema.model_validate(
strawberry.asdict(input_prompt_version.output_schema)
normalize_output_schema(
input_prompt_version.output_schema.definition,
input_prompt_version.model_provider,
)
if input_prompt_version.output_schema
else None
Expand Down Expand Up @@ -135,11 +143,16 @@ async def create_chat_prompt_version(
input_prompt_version = input.prompt_version
tool_definitions = [tool.definition for tool in input.prompt_version.tools]
try:
tools = normalize_tools(tool_definitions, input.prompt_version.model_provider)
tools = (
normalize_tools(tool_definitions, input_prompt_version.model_provider)
if tool_definitions
else None
)
template = to_pydantic_prompt_chat_template_v1(input.prompt_version.template)
output_schema = (
PromptOutputSchema.model_validate(
strawberry.asdict(input_prompt_version.output_schema)
normalize_output_schema(
input_prompt_version.output_schema.definition,
input_prompt_version.model_provider,
)
if input_prompt_version.output_schema
else None
Expand Down
6 changes: 0 additions & 6 deletions src/phoenix/server/api/types/OutputSchema.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import strawberry
from strawberry.scalars import JSON

from phoenix.server.api.helpers.prompts.models import PromptOutputSchema


@strawberry.type
class OutputSchema:
"""A JSON schema definition used to guide an LLM's output"""

definition: JSON


def to_gql_output_schema_from_pydantic(pydantic_output_schema: PromptOutputSchema) -> OutputSchema:
return OutputSchema(**pydantic_output_schema.model_dump())
10 changes: 8 additions & 2 deletions src/phoenix/server/api/types/PromptVersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from phoenix.server.api.helpers.prompts.models import (
PromptTemplateFormat,
PromptTemplateType,
denormalize_output_schema,
denormalize_tools,
)
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
Expand All @@ -21,7 +22,7 @@
to_gql_template_from_orm,
)

from .OutputSchema import OutputSchema, to_gql_output_schema_from_pydantic
from .OutputSchema import OutputSchema
from .ToolDefinition import ToolDefinition
from .User import User, to_gql_user

Expand Down Expand Up @@ -112,7 +113,12 @@ def to_gql_prompt_version(
else:
tools = []
output_schema = (
to_gql_output_schema_from_pydantic(prompt_version.output_schema)
OutputSchema(
definition=denormalize_output_schema(
prompt_version.output_schema,
prompt_version.model_provider,
)
)
if prompt_version.output_schema is not None
else None
)
Expand Down
Loading

0 comments on commit 6da69ac

Please sign in to comment.