Skip to content

Commit

Permalink
feat: PromptLabel gql interface (#6100)
Browse files Browse the repository at this point in the history
* Add prompt label node and mutations

* Add PromptLabel to node interface

* Add prompts resolver on PromptLabel node

* Add prompt label mutations mixin and build gql schema

* Add prompt labels query

* Delete needless comment

* Address feedback

* Build gql schema
  • Loading branch information
anticorrelator authored and mikeldking committed Jan 24, 2025
1 parent 6a673f6 commit d1eda1d
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 0 deletions.
62 changes: 62 additions & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ input CreateDatasetInput {
metadata: JSON
}

input CreatePromptLabelInput {
name: Identifier!
description: String = null
}

input CreateSpanAnnotationInput {
spanId: GlobalID!
name: String!
Expand Down Expand Up @@ -537,6 +542,10 @@ input DeletePromptInput {
promptId: GlobalID!
}

input DeletePromptLabelInput {
promptLabelId: GlobalID!
}

type DeletePromptMutationPayload {
query: Query!
}
Expand Down Expand Up @@ -1243,6 +1252,11 @@ type Mutation {
patchPrompt(input: PatchPromptInput!): Prompt!
deletePromptVersionTag(input: DeletePromptVersionTagInput!): PromptVersionTagMutationPayload!
setPromptVersionTag(input: SetPromptVersionTagInput!): PromptVersionTagMutationPayload!
createPromptLabel(input: CreatePromptLabelInput!): PromptLabelMutationPayload!
patchPromptLabel(input: PatchPromptLabelInput!): PromptLabelMutationPayload!
deletePromptLabel(input: DeletePromptLabelInput!): PromptLabelMutationPayload!
setPromptLabel(input: SetPromptLabelInput!): PromptLabelMutationPayload!
unsetPromptLabel(input: UnsetPromptLabelInput!): PromptLabelMutationPayload!
createSpanAnnotations(input: [CreateSpanAnnotationInput!]!): SpanAnnotationMutationPayload!
patchSpanAnnotations(input: [PatchAnnotationInput!]!): SpanAnnotationMutationPayload!
deleteSpanAnnotations(input: DeleteAnnotationsInput!): SpanAnnotationMutationPayload!
Expand Down Expand Up @@ -1315,6 +1329,12 @@ input PatchPromptInput {
description: String!
}

input PatchPromptLabelInput {
promptLabelId: GlobalID!
name: Identifier = null
description: String = null
}

input PatchUserInput {
userId: GlobalID!
newRole: UserRoleInput
Expand Down Expand Up @@ -1495,6 +1515,37 @@ type PromptEdge {
node: Prompt!
}

type PromptLabel implements Node {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: Identifier!
description: String
prompts: [Prompt!]!
}

"""A connection to a list of items."""
type PromptLabelConnection {
"""Pagination data for this connection"""
pageInfo: PageInfo!

"""Contains the nodes in this connection"""
edges: [PromptLabelEdge!]!
}

"""An edge in a connection."""
type PromptLabelEdge {
"""A cursor for use in pagination"""
cursor: String!

"""The item at the end of the edge"""
node: PromptLabel!
}

type PromptLabelMutationPayload {
promptLabel: PromptLabel
query: Query!
}

type PromptMessage {
role: PromptMessageRole!
content: [ContentPart!]!
Expand Down Expand Up @@ -1618,6 +1669,7 @@ type Query {
node(id: GlobalID!): Node!
viewer: User
prompts(first: Int = 50, last: Int, after: String, before: String): PromptConnection!
promptLabels(first: Int = 50, last: Int, after: String, before: String): PromptLabelConnection!
clusters(clusters: [ClusterInput!]!): [Cluster!]!
hdbscanClustering(
"""Event ID of the coordinates"""
Expand Down Expand Up @@ -1673,6 +1725,11 @@ type Segments {
totalCounts: DatasetValues!
}

input SetPromptLabelInput {
promptId: GlobalID!
promptLabelId: GlobalID!
}

input SetPromptVersionTagInput {
promptVersionId: GlobalID!
name: Identifier!
Expand Down Expand Up @@ -2110,6 +2167,11 @@ type UMAPPoints {
contextRetrievals: [Retrieval!]!
}

input UnsetPromptLabelInput {
promptId: GlobalID!
promptLabelId: GlobalID!
}

type User implements Node {
"""The Globally Unique ID of this object"""
id: GlobalID!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def upgrade() -> None:
nullable=False,
index=True,
),
sa.UniqueConstraint(
"prompt_label_id",
"prompt_id",
),
)

op.create_table(
Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,8 @@ class PromptPromptLabel(Base):
)
prompt: Mapped["Prompt"] = relationship("Prompt", back_populates="prompts_prompt_labels")

__table_args__ = (UniqueConstraint("prompt_label_id", "prompt_id"),)


class PromptVersion(Base):
__tablename__ = "prompt_versions"
Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/server/api/mutations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin
from phoenix.server.api.mutations.prompt_label_mutations import PromptLabelMutationMixin
from phoenix.server.api.mutations.prompt_mutations import PromptMutationMixin
from phoenix.server.api.mutations.prompt_version_tag_mutations import PromptVersionTagMutationMixin
from phoenix.server.api.mutations.span_annotations_mutations import SpanAnnotationMutationMixin
Expand All @@ -24,6 +25,7 @@ class Mutation(
ProjectMutationMixin,
PromptMutationMixin,
PromptVersionTagMutationMixin,
PromptLabelMutationMixin,
SpanAnnotationMutationMixin,
TraceAnnotationMutationMixin,
UserMutationMixin,
Expand Down
191 changes: 191 additions & 0 deletions src/phoenix/server/api/mutations/prompt_label_mutations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# file: PromptLabelMutations.py

from typing import Optional

import strawberry
from sqlalchemy import delete
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
from strawberry.relay import GlobalID
from strawberry.types import Info

from phoenix.db import models
from phoenix.db.types.identifier import Identifier as IdentifierModel
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import Conflict, NotFound
from phoenix.server.api.queries import Query
from phoenix.server.api.types.Identifier import Identifier
from phoenix.server.api.types.node import from_global_id_with_expected_type
from phoenix.server.api.types.Prompt import Prompt
from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label


@strawberry.input
class CreatePromptLabelInput:
name: Identifier
description: Optional[str] = None


@strawberry.input
class PatchPromptLabelInput:
prompt_label_id: GlobalID
name: Optional[Identifier] = None
description: Optional[str] = None


@strawberry.input
class DeletePromptLabelInput:
prompt_label_id: GlobalID


@strawberry.input
class SetPromptLabelInput:
prompt_id: GlobalID
prompt_label_id: GlobalID


@strawberry.input
class UnsetPromptLabelInput:
prompt_id: GlobalID
prompt_label_id: GlobalID


@strawberry.type
class PromptLabelMutationPayload:
prompt_label: Optional["PromptLabel"]
query: "Query"


@strawberry.type
class PromptLabelMutationMixin:
@strawberry.mutation
async def create_prompt_label(
self, info: Info[Context, None], input: CreatePromptLabelInput
) -> PromptLabelMutationPayload:
async with info.context.db() as session:
name = IdentifierModel.model_validate(str(input.name))
label_orm = models.PromptLabel(name=name, description=input.description)
session.add(label_orm)

try:
await session.commit()
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
raise Conflict(f"A prompt label named '{name}' already exists.")

return PromptLabelMutationPayload(
prompt_label=to_gql_prompt_label(label_orm),
query=Query(),
)

@strawberry.mutation
async def patch_prompt_label(
self, info: Info[Context, None], input: PatchPromptLabelInput
) -> PromptLabelMutationPayload:
validated_name = IdentifierModel.model_validate(str(input.name)) if input.name else None
async with info.context.db() as session:
label_id = from_global_id_with_expected_type(
input.prompt_label_id, PromptLabel.__name__
)

label_orm = await session.get(models.PromptLabel, label_id)
if not label_orm:
raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")

if validated_name is not None:
label_orm.name = validated_name.root
if input.description is not None:
label_orm.description = input.description

try:
await session.commit()
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
raise Conflict("Error patching PromptLabel. Possibly a name conflict?")

return PromptLabelMutationPayload(
prompt_label=to_gql_prompt_label(label_orm),
query=Query(),
)

@strawberry.mutation
async def delete_prompt_label(
self, info: Info[Context, None], input: DeletePromptLabelInput
) -> PromptLabelMutationPayload:
"""
Deletes a PromptLabel (and any crosswalk references).
"""
async with info.context.db() as session:
label_id = from_global_id_with_expected_type(
input.prompt_label_id, PromptLabel.__name__
)
stmt = delete(models.PromptLabel).where(models.PromptLabel.id == label_id)
result = await session.execute(stmt)

if result.rowcount == 0:
raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")

await session.commit()

return PromptLabelMutationPayload(
prompt_label=None,
query=Query(),
)

@strawberry.mutation
async def set_prompt_label(
self, info: Info[Context, None], input: SetPromptLabelInput
) -> PromptLabelMutationPayload:
async with info.context.db() as session:
prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
label_id = from_global_id_with_expected_type(
input.prompt_label_id, PromptLabel.__name__
)

crosswalk = models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
session.add(crosswalk)

try:
await session.commit()
except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
# The error could be:
# - Unique constraint violation => row already exists
# - Foreign key violation => prompt_id or label_id doesn't exist
raise Conflict("Failed to associate PromptLabel with Prompt.") from e

label_orm = await session.get(models.PromptLabel, label_id)
if not label_orm:
raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")

return PromptLabelMutationPayload(
prompt_label=to_gql_prompt_label(label_orm),
query=Query(),
)

@strawberry.mutation
async def unset_prompt_label(
self, info: Info[Context, None], input: UnsetPromptLabelInput
) -> PromptLabelMutationPayload:
"""
Unsets a PromptLabel from a Prompt by removing the row in the crosswalk.
"""
async with info.context.db() as session:
prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
label_id = from_global_id_with_expected_type(
input.prompt_label_id, PromptLabel.__name__
)

stmt = delete(models.PromptPromptLabel).where(
(models.PromptPromptLabel.prompt_id == prompt_id)
& (models.PromptPromptLabel.prompt_label_id == label_id)
)
result = await session.execute(stmt)

if result.rowcount == 0:
raise NotFound(f"No association between prompt={prompt_id} and label={label_id}.")

await session.commit()

label_orm = await session.get(models.PromptLabel, label_id)
return PromptLabelMutationPayload(
prompt_label=to_gql_prompt_label(label_orm) if label_orm else None,
query=Query(),
)
33 changes: 33 additions & 0 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from phoenix.server.api.types.Project import Project
from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
from phoenix.server.api.types.SortDir import SortDir
from phoenix.server.api.types.Span import Span, to_gql_span
Expand Down Expand Up @@ -583,6 +584,15 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
return to_gql_prompt_version(orm_prompt_version)
else:
raise NotFound(f"Unknown prompt version: {id}")
elif type_name == PromptLabel.__name__:
async with info.context.db() as session:
if not (
prompt_label := await session.scalar(
select(models.PromptLabel).where(models.PromptLabel.id == node_id)
)
):
raise NotFound(f"Unknown prompt label: {id}")
return to_gql_prompt_label(prompt_label)
raise NotFound(f"Unknown node type: {type_name}")

@strawberry.field
Expand Down Expand Up @@ -629,6 +639,29 @@ async def prompts(
args=args,
)

@strawberry.field
async def prompt_labels(
self,
info: Info[Context, None],
first: Optional[int] = 50,
last: Optional[int] = UNSET,
after: Optional[CursorString] = UNSET,
before: Optional[CursorString] = UNSET,
) -> Connection[PromptLabel]:
args = ConnectionArgs(
first=first,
after=after if isinstance(after, CursorString) else None,
last=last,
before=before if isinstance(before, CursorString) else None,
)
async with info.context.db() as session:
prompt_labels = await session.stream_scalars(select(models.PromptLabel))
data = [to_gql_prompt_label(prompt_label) async for prompt_label in prompt_labels]
return connection_from_list(
data=data,
args=args,
)

@strawberry.field
def clusters(
self,
Expand Down
Loading

0 comments on commit d1eda1d

Please sign in to comment.