From 81a28143d3ead7624952588e193286310a092887 Mon Sep 17 00:00:00 2001
From: Roger Yang <80478925+RogerHYang@users.noreply.github.com>
Date: Fri, 24 Jan 2025 08:48:47 -0800
Subject: [PATCH] feat(prompts): REST endpoint to get the latest prompt version
 (#6166)

---
 schemas/openapi.json                          | 64 +++++++++++++++++++
 src/phoenix/server/api/routers/v1/prompts.py  | 24 +++++++
 .../server/api/routers/v1/test_prompts.py     | 14 ++++
 3 files changed, 102 insertions(+)

diff --git a/schemas/openapi.json b/schemas/openapi.json
index c2afb2ea4b..07f4db4f6d 100644
--- a/schemas/openapi.json
+++ b/schemas/openapi.json
@@ -1404,6 +1404,70 @@
           }
         }
       }
+    },
+    "/v1/prompts/{prompt_identifier}/latest": {
+      "get": {
+        "tags": [
+          "prompts"
+        ],
+        "summary": "Get the latest prompt version",
+        "operationId": "getPromptVersionLatest",
+        "parameters": [
+          {
+            "name": "prompt_identifier",
+            "in": "path",
+            "required": true,
+            "schema": {
+              "type": "string",
+              "description": "The identifier of the prompt, i.e. name or ID.",
+              "title": "Prompt Identifier"
+            },
+            "description": "The identifier of the prompt, i.e. name or ID."
+          }
+        ],
+        "responses": {
+          "200": {
+            "description": "Successful Response",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/GetPromptResponseBody"
+                }
+              }
+            }
+          },
+          "403": {
+            "content": {
+              "text/plain": {
+                "schema": {
+                  "type": "string"
+                }
+              }
+            },
+            "description": "Forbidden"
+          },
+          "404": {
+            "content": {
+              "text/plain": {
+                "schema": {
+                  "type": "string"
+                }
+              }
+            },
+            "description": "Not Found"
+          },
+          "422": {
+            "content": {
+              "text/plain": {
+                "schema": {
+                  "type": "string"
+                }
+              }
+            },
+            "description": "Unprocessable Entity"
+          }
+        }
+      }
     }
   },
   "components": {
diff --git a/src/phoenix/server/api/routers/v1/prompts.py b/src/phoenix/server/api/routers/v1/prompts.py
index 3b5248ff8a..6bb17c8536 100644
--- a/src/phoenix/server/api/routers/v1/prompts.py
+++ b/src/phoenix/server/api/routers/v1/prompts.py
@@ -174,6 +174,30 @@ async def get_prompt_version_by_tag_name(
     return _prompt_version_response_body(prompt_version)
 
 
+@router.get(
+    "/prompts/{prompt_identifier}/latest",
+    operation_id="getPromptVersionLatest",
+    summary="Get the latest prompt version",
+    responses=add_errors_to_responses(
+        [
+            HTTP_404_NOT_FOUND,
+            HTTP_422_UNPROCESSABLE_ENTITY,
+        ]
+    ),
+)
+async def get_prompt_version_by_latest(
+    request: Request,
+    prompt_identifier: str = Path(description="The identifier of the prompt, i.e. name or ID."),
+) -> GetPromptResponseBody:
+    stmt = select(models.PromptVersion).order_by(models.PromptVersion.id.desc()).limit(1)
+    stmt = _filter_by_prompt_identifier(stmt, prompt_identifier)
+    async with request.app.state.db() as session:
+        prompt_version: models.PromptVersion = await session.scalar(stmt)
+        if prompt_version is None:
+            raise HTTPException(HTTP_404_NOT_FOUND)
+    return _prompt_version_response_body(prompt_version)
+
+
 class _PromptId(int): ...
 
 
diff --git a/tests/unit/server/api/routers/v1/test_prompts.py b/tests/unit/server/api/routers/v1/test_prompts.py
index c9518bde90..a79f9e557a 100644
--- a/tests/unit/server/api/routers/v1/test_prompts.py
+++ b/tests/unit/server/api/routers/v1/test_prompts.py
@@ -16,6 +16,20 @@
 
 
 class TestPrompts:
+    async def test_get_latest_prompt_version(
+        self,
+        httpx_client: httpx.AsyncClient,
+        db: DbSessionFactory,
+    ) -> None:
+        prompt, prompt_versions = await self._insert_prompt_versions(db)
+        prompt_version = prompt_versions[-1]
+        prompt_id = str(GlobalID(Prompt.__name__, str(prompt.id)))
+        for prompts_identifier in prompt_id, prompt.name.root:
+            url = f"/v1/prompts/{quote_plus(prompts_identifier)}/latest"
+            assert (response := await httpx_client.get(url)).is_success
+            assert isinstance((data := response.json()["data"]), dict)
+            self._compare_prompt_version(data, prompt_version)
+
     async def test_get_prompt_version_by_prompt_version_id(
         self,
         httpx_client: httpx.AsyncClient,