diff --git a/app/schema.graphql b/app/schema.graphql index 2eea5c20f8..17ab25cb5f 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -286,6 +286,11 @@ input CreateDatasetInput { metadata: JSON } +input CreatePromptLabelInput { + name: Identifier! + description: String = null +} + input CreateSpanAnnotationInput { spanId: GlobalID! name: String! @@ -537,6 +542,10 @@ input DeletePromptInput { promptId: GlobalID! } +input DeletePromptLabelInput { + promptLabelId: GlobalID! +} + type DeletePromptMutationPayload { query: Query! } @@ -1243,6 +1252,11 @@ type Mutation { patchPrompt(input: PatchPromptInput!): Prompt! deletePromptVersionTag(input: DeletePromptVersionTagInput!): PromptVersionTagMutationPayload! setPromptVersionTag(input: SetPromptVersionTagInput!): PromptVersionTagMutationPayload! + createPromptLabel(input: CreatePromptLabelInput!): PromptLabelMutationPayload! + patchPromptLabel(input: PatchPromptLabelInput!): PromptLabelMutationPayload! + deletePromptLabel(input: DeletePromptLabelInput!): PromptLabelMutationPayload! + setPromptLabel(input: SetPromptLabelInput!): PromptLabelMutationPayload! + unsetPromptLabel(input: UnsetPromptLabelInput!): PromptLabelMutationPayload! createSpanAnnotations(input: [CreateSpanAnnotationInput!]!): SpanAnnotationMutationPayload! patchSpanAnnotations(input: [PatchAnnotationInput!]!): SpanAnnotationMutationPayload! deleteSpanAnnotations(input: DeleteAnnotationsInput!): SpanAnnotationMutationPayload! @@ -1315,6 +1329,12 @@ input PatchPromptInput { description: String! } +input PatchPromptLabelInput { + promptLabelId: GlobalID! + name: Identifier = null + description: String = null +} + input PatchUserInput { userId: GlobalID! newRole: UserRoleInput @@ -1495,6 +1515,37 @@ type PromptEdge { node: Prompt! } +type PromptLabel implements Node { + """The Globally Unique ID of this object""" + id: GlobalID! + name: Identifier! + description: String + prompts: [Prompt!]! +} + +"""A connection to a list of items.""" +type PromptLabelConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + + """Contains the nodes in this connection""" + edges: [PromptLabelEdge!]! +} + +"""An edge in a connection.""" +type PromptLabelEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: PromptLabel! +} + +type PromptLabelMutationPayload { + promptLabel: PromptLabel + query: Query! +} + type PromptMessage { role: PromptMessageRole! content: [ContentPart!]! @@ -1618,6 +1669,7 @@ type Query { node(id: GlobalID!): Node! viewer: User prompts(first: Int = 50, last: Int, after: String, before: String): PromptConnection! + promptLabels(first: Int = 50, last: Int, after: String, before: String): PromptLabelConnection! clusters(clusters: [ClusterInput!]!): [Cluster!]! hdbscanClustering( """Event ID of the coordinates""" @@ -1673,6 +1725,11 @@ type Segments { totalCounts: DatasetValues! } +input SetPromptLabelInput { + promptId: GlobalID! + promptLabelId: GlobalID! +} + input SetPromptVersionTagInput { promptVersionId: GlobalID! name: Identifier! @@ -2110,6 +2167,11 @@ type UMAPPoints { contextRetrievals: [Retrieval!]! } +input UnsetPromptLabelInput { + promptId: GlobalID! + promptLabelId: GlobalID! +} + type User implements Node { """The Globally Unique ID of this object""" id: GlobalID! diff --git a/src/phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py b/src/phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py index cd8d069171..33a8b5d55b 100644 --- a/src/phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py +++ b/src/phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py @@ -98,6 +98,10 @@ def upgrade() -> None: nullable=False, index=True, ), + sa.UniqueConstraint( + "prompt_label_id", + "prompt_id", + ), ) op.create_table( diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 86d9cb02cc..e9fb11910d 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -897,6 +897,8 @@ class PromptPromptLabel(Base): ) prompt: Mapped["Prompt"] = relationship("Prompt", back_populates="prompts_prompt_labels") + __table_args__ = (UniqueConstraint("prompt_label_id", "prompt_id"),) + class PromptVersion(Base): __tablename__ = "prompt_versions" diff --git a/src/phoenix/server/api/mutations/__init__.py b/src/phoenix/server/api/mutations/__init__.py index aa2d46754e..48511ce8b1 100644 --- a/src/phoenix/server/api/mutations/__init__.py +++ b/src/phoenix/server/api/mutations/__init__.py @@ -8,6 +8,7 @@ from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin +from phoenix.server.api.mutations.prompt_label_mutations import PromptLabelMutationMixin from phoenix.server.api.mutations.prompt_mutations import PromptMutationMixin from phoenix.server.api.mutations.prompt_version_tag_mutations import PromptVersionTagMutationMixin from phoenix.server.api.mutations.span_annotations_mutations import SpanAnnotationMutationMixin @@ -24,6 +25,7 @@ class Mutation( ProjectMutationMixin, PromptMutationMixin, PromptVersionTagMutationMixin, + PromptLabelMutationMixin, SpanAnnotationMutationMixin, TraceAnnotationMutationMixin, UserMutationMixin, diff --git a/src/phoenix/server/api/mutations/prompt_label_mutations.py b/src/phoenix/server/api/mutations/prompt_label_mutations.py new file mode 100644 index 0000000000..f89d6bae65 --- /dev/null +++ b/src/phoenix/server/api/mutations/prompt_label_mutations.py @@ -0,0 +1,191 @@ +# file: PromptLabelMutations.py + +from typing import Optional + +import strawberry +from sqlalchemy import delete +from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError +from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped] +from strawberry.relay import GlobalID +from strawberry.types import Info + +from phoenix.db import models +from phoenix.db.types.identifier import Identifier as IdentifierModel +from phoenix.server.api.context import Context +from phoenix.server.api.exceptions import Conflict, NotFound +from phoenix.server.api.queries import Query +from phoenix.server.api.types.Identifier import Identifier +from phoenix.server.api.types.node import from_global_id_with_expected_type +from phoenix.server.api.types.Prompt import Prompt +from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label + + +@strawberry.input +class CreatePromptLabelInput: + name: Identifier + description: Optional[str] = None + + +@strawberry.input +class PatchPromptLabelInput: + prompt_label_id: GlobalID + name: Optional[Identifier] = None + description: Optional[str] = None + + +@strawberry.input +class DeletePromptLabelInput: + prompt_label_id: GlobalID + + +@strawberry.input +class SetPromptLabelInput: + prompt_id: GlobalID + prompt_label_id: GlobalID + + +@strawberry.input +class UnsetPromptLabelInput: + prompt_id: GlobalID + prompt_label_id: GlobalID + + +@strawberry.type +class PromptLabelMutationPayload: + prompt_label: Optional["PromptLabel"] + query: "Query" + + +@strawberry.type +class PromptLabelMutationMixin: + @strawberry.mutation + async def create_prompt_label( + self, info: Info[Context, None], input: CreatePromptLabelInput + ) -> PromptLabelMutationPayload: + async with info.context.db() as session: + name = IdentifierModel.model_validate(str(input.name)) + label_orm = models.PromptLabel(name=name, description=input.description) + session.add(label_orm) + + try: + await session.commit() + except (PostgreSQLIntegrityError, SQLiteIntegrityError): + raise Conflict(f"A prompt label named '{name}' already exists.") + + return PromptLabelMutationPayload( + prompt_label=to_gql_prompt_label(label_orm), + query=Query(), + ) + + @strawberry.mutation + async def patch_prompt_label( + self, info: Info[Context, None], input: PatchPromptLabelInput + ) -> PromptLabelMutationPayload: + validated_name = IdentifierModel.model_validate(str(input.name)) if input.name else None + async with info.context.db() as session: + label_id = from_global_id_with_expected_type( + input.prompt_label_id, PromptLabel.__name__ + ) + + label_orm = await session.get(models.PromptLabel, label_id) + if not label_orm: + raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found") + + if validated_name is not None: + label_orm.name = validated_name.root + if input.description is not None: + label_orm.description = input.description + + try: + await session.commit() + except (PostgreSQLIntegrityError, SQLiteIntegrityError): + raise Conflict("Error patching PromptLabel. Possibly a name conflict?") + + return PromptLabelMutationPayload( + prompt_label=to_gql_prompt_label(label_orm), + query=Query(), + ) + + @strawberry.mutation + async def delete_prompt_label( + self, info: Info[Context, None], input: DeletePromptLabelInput + ) -> PromptLabelMutationPayload: + """ + Deletes a PromptLabel (and any crosswalk references). + """ + async with info.context.db() as session: + label_id = from_global_id_with_expected_type( + input.prompt_label_id, PromptLabel.__name__ + ) + stmt = delete(models.PromptLabel).where(models.PromptLabel.id == label_id) + result = await session.execute(stmt) + + if result.rowcount == 0: + raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found") + + await session.commit() + + return PromptLabelMutationPayload( + prompt_label=None, + query=Query(), + ) + + @strawberry.mutation + async def set_prompt_label( + self, info: Info[Context, None], input: SetPromptLabelInput + ) -> PromptLabelMutationPayload: + async with info.context.db() as session: + prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__) + label_id = from_global_id_with_expected_type( + input.prompt_label_id, PromptLabel.__name__ + ) + + crosswalk = models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id) + session.add(crosswalk) + + try: + await session.commit() + except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e: + # The error could be: + # - Unique constraint violation => row already exists + # - Foreign key violation => prompt_id or label_id doesn't exist + raise Conflict("Failed to associate PromptLabel with Prompt.") from e + + label_orm = await session.get(models.PromptLabel, label_id) + if not label_orm: + raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found") + + return PromptLabelMutationPayload( + prompt_label=to_gql_prompt_label(label_orm), + query=Query(), + ) + + @strawberry.mutation + async def unset_prompt_label( + self, info: Info[Context, None], input: UnsetPromptLabelInput + ) -> PromptLabelMutationPayload: + """ + Unsets a PromptLabel from a Prompt by removing the row in the crosswalk. + """ + async with info.context.db() as session: + prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__) + label_id = from_global_id_with_expected_type( + input.prompt_label_id, PromptLabel.__name__ + ) + + stmt = delete(models.PromptPromptLabel).where( + (models.PromptPromptLabel.prompt_id == prompt_id) + & (models.PromptPromptLabel.prompt_label_id == label_id) + ) + result = await session.execute(stmt) + + if result.rowcount == 0: + raise NotFound(f"No association between prompt={prompt_id} and label={label_id}.") + + await session.commit() + + label_orm = await session.get(models.PromptLabel, label_id) + return PromptLabelMutationPayload( + prompt_label=to_gql_prompt_label(label_orm) if label_orm else None, + query=Query(), + ) diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py index ca6b4bd874..ca06902dfa 100644 --- a/src/phoenix/server/api/queries.py +++ b/src/phoenix/server/api/queries.py @@ -61,6 +61,7 @@ from phoenix.server.api.types.Project import Project from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm +from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version from phoenix.server.api.types.SortDir import SortDir from phoenix.server.api.types.Span import Span, to_gql_span @@ -583,6 +584,15 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node: return to_gql_prompt_version(orm_prompt_version) else: raise NotFound(f"Unknown prompt version: {id}") + elif type_name == PromptLabel.__name__: + async with info.context.db() as session: + if not ( + prompt_label := await session.scalar( + select(models.PromptLabel).where(models.PromptLabel.id == node_id) + ) + ): + raise NotFound(f"Unknown prompt label: {id}") + return to_gql_prompt_label(prompt_label) raise NotFound(f"Unknown node type: {type_name}") @strawberry.field @@ -629,6 +639,29 @@ async def prompts( args=args, ) + @strawberry.field + async def prompt_labels( + self, + info: Info[Context, None], + first: Optional[int] = 50, + last: Optional[int] = UNSET, + after: Optional[CursorString] = UNSET, + before: Optional[CursorString] = UNSET, + ) -> Connection[PromptLabel]: + args = ConnectionArgs( + first=first, + after=after if isinstance(after, CursorString) else None, + last=last, + before=before if isinstance(before, CursorString) else None, + ) + async with info.context.db() as session: + prompt_labels = await session.stream_scalars(select(models.PromptLabel)) + data = [to_gql_prompt_label(prompt_label) async for prompt_label in prompt_labels] + return connection_from_list( + data=data, + args=args, + ) + @strawberry.field def clusters( self, diff --git a/src/phoenix/server/api/types/PromptLabel.py b/src/phoenix/server/api/types/PromptLabel.py new file mode 100644 index 0000000000..5e15c6e375 --- /dev/null +++ b/src/phoenix/server/api/types/PromptLabel.py @@ -0,0 +1,41 @@ +from typing import Optional + +import strawberry +from sqlalchemy import select +from strawberry.relay import Node, NodeID +from strawberry.types import Info + +from phoenix.db import models +from phoenix.server.api.context import Context +from phoenix.server.api.types.Identifier import Identifier +from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm + + +@strawberry.type +class PromptLabel(Node): + id_attr: NodeID[int] + name: Identifier + description: Optional[str] = None + + @strawberry.field + async def prompts(self, info: Info[Context, None]) -> list[Prompt]: + async with info.context.db() as session: + statement = ( + select(models.Prompt) + .join( + models.PromptPromptLabel, models.Prompt.id == models.PromptPromptLabel.prompt_id + ) + .where(models.PromptPromptLabel.prompt_label_id == self.id_attr) + ) + return [ + to_gql_prompt_from_orm(prompt_orm) + async for prompt_orm in await session.stream_scalars(statement) + ] + + +def to_gql_prompt_label(label_orm: models.PromptLabel) -> PromptLabel: + return PromptLabel( + id_attr=label_orm.id, + name=Identifier(label_orm.name), + description=label_orm.description, + )