-
Notifications
You must be signed in to change notification settings - Fork 86
Add Prompt dataclass with initial methods #562
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: staging
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -199,6 +199,28 @@ def upload_custom_scorer( | |||||||||||||||||||||||
payload, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def prompts_insert(self, payload: PromptInsertRequest) -> PromptInsertResponse: | ||||||||||||||||||||||||
return self._request( | ||||||||||||||||||||||||
"POST", | ||||||||||||||||||||||||
url_for("/prompts/insert/"), | ||||||||||||||||||||||||
payload, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def prompts_fetch( | ||||||||||||||||||||||||
self, name: str, commit_id: Optional[str] = None, tag: Optional[str] = None | ||||||||||||||||||||||||
) -> PromptFetchResponse: | ||||||||||||||||||||||||
query_params = {} | ||||||||||||||||||||||||
query_params["name"] = name | ||||||||||||||||||||||||
if commit_id is not None: | ||||||||||||||||||||||||
query_params["commit_id"] = commit_id | ||||||||||||||||||||||||
if tag is not None: | ||||||||||||||||||||||||
query_params["tag"] = tag | ||||||||||||||||||||||||
return self._request( | ||||||||||||||||||||||||
"GET", | ||||||||||||||||||||||||
url_for("/prompts/fetch/"), | ||||||||||||||||||||||||
query_params, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def projects_resolve( | ||||||||||||||||||||||||
self, payload: ResolveProjectNameRequest | ||||||||||||||||||||||||
) -> ResolveProjectNameResponse: | ||||||||||||||||||||||||
|
@@ -410,6 +432,30 @@ async def upload_custom_scorer( | |||||||||||||||||||||||
payload, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
async def prompts_insert( | ||||||||||||||||||||||||
self, payload: PromptInsertRequest | ||||||||||||||||||||||||
) -> PromptInsertResponse: | ||||||||||||||||||||||||
return await self._request( | ||||||||||||||||||||||||
"POST", | ||||||||||||||||||||||||
url_for("/prompts/insert/"), | ||||||||||||||||||||||||
payload, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
async def prompts_fetch( | ||||||||||||||||||||||||
self, name: str, commit_id: Optional[str] = None, tag: Optional[str] = None | ||||||||||||||||||||||||
) -> PromptFetchResponse: | ||||||||||||||||||||||||
query_params = {} | ||||||||||||||||||||||||
query_params["name"] = name | ||||||||||||||||||||||||
if commit_id is not None: | ||||||||||||||||||||||||
query_params["commit_id"] = commit_id | ||||||||||||||||||||||||
if tag is not None: | ||||||||||||||||||||||||
query_params["tag"] = tag | ||||||||||||||||||||||||
Comment on lines
+447
to
+452
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the synchronous version, the construction of
Suggested change
|
||||||||||||||||||||||||
return await self._request( | ||||||||||||||||||||||||
"GET", | ||||||||||||||||||||||||
url_for("/prompts/fetch/"), | ||||||||||||||||||||||||
query_params, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
async def projects_resolve( | ||||||||||||||||||||||||
self, payload: ResolveProjectNameRequest | ||||||||||||||||||||||||
) -> ResolveProjectNameResponse: | ||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,96 @@ | ||||||||||
from typing import List, Optional | ||||||||||
import os | ||||||||||
from judgeval.api import JudgmentSyncClient | ||||||||||
from judgeval.exceptions import JudgmentAPIError | ||||||||||
from dataclasses import dataclass, field | ||||||||||
import re | ||||||||||
from string import Template | ||||||||||
|
||||||||||
|
||||||||||
def push_prompt( | ||||||||||
name: str, | ||||||||||
prompt: str, | ||||||||||
tags: List[str], | ||||||||||
judgment_api_key: str = os.getenv("JUDGMENT_API_KEY") or "", | ||||||||||
organization_id: str = os.getenv("JUDGMENT_ORG_ID") or "", | ||||||||||
) -> tuple[str, Optional[str]]: | ||||||||||
client = JudgmentSyncClient(judgment_api_key, organization_id) | ||||||||||
try: | ||||||||||
r = client.prompts_insert( | ||||||||||
payload={"name": name, "prompt": prompt, "tags": tags} | ||||||||||
) | ||||||||||
return r["commit_id"], r["parent_commit_id"] | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||
except JudgmentAPIError as e: | ||||||||||
raise JudgmentAPIError( | ||||||||||
status_code=e.status_code, | ||||||||||
detail=f"Failed to save prompt: {e.detail}", | ||||||||||
response=e.response, | ||||||||||
) | ||||||||||
|
||||||||||
|
||||||||||
def fetch_prompt( | ||||||||||
name: str, | ||||||||||
commit_id: Optional[str] = None, | ||||||||||
tag: Optional[str] = None, | ||||||||||
judgment_api_key: str = os.getenv("JUDGMENT_API_KEY") or "", | ||||||||||
organization_id: str = os.getenv("JUDGMENT_ORG_ID") or "", | ||||||||||
): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function
Suggested change
|
||||||||||
client = JudgmentSyncClient(judgment_api_key, organization_id) | ||||||||||
try: | ||||||||||
prompt_config = client.prompts_fetch(name, commit_id, tag) | ||||||||||
return prompt_config | ||||||||||
except JudgmentAPIError as e: | ||||||||||
raise JudgmentAPIError( | ||||||||||
status_code=e.status_code, | ||||||||||
detail=f"Failed to fetch prompt '{name}': {e.detail}", | ||||||||||
response=e.response, | ||||||||||
) | ||||||||||
|
||||||||||
|
||||||||||
@dataclass | ||||||||||
class Prompt: | ||||||||||
name: str | ||||||||||
prompt: str | ||||||||||
tags: List[str] | ||||||||||
commit_id: str | ||||||||||
parent_commit_id: Optional[str] = None | ||||||||||
_template: Template = field(init=False, repr=False) | ||||||||||
|
||||||||||
def __post_init__(self): | ||||||||||
template_str = re.sub(r"\{\{(\w+)\}\}", r"$\1", self.prompt) | ||||||||||
self._template = Template(template_str) | ||||||||||
|
||||||||||
@classmethod | ||||||||||
def create(cls, name: str, prompt: str, tags: Optional[List[str]] = None): | ||||||||||
if not tags: | ||||||||||
tags = [] | ||||||||||
Comment on lines
+65
to
+66
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While
Suggested change
|
||||||||||
commit_id, parent_commit_id = push_prompt(name, prompt, tags) | ||||||||||
return cls( | ||||||||||
name=name, | ||||||||||
prompt=prompt, | ||||||||||
tags=tags, | ||||||||||
commit_id=commit_id, | ||||||||||
parent_commit_id=parent_commit_id, | ||||||||||
) | ||||||||||
|
||||||||||
@classmethod | ||||||||||
def get(cls, name: str, commit_id: Optional[str] = None, tag: Optional[str] = None): | ||||||||||
if commit_id is not None and tag is not None: | ||||||||||
raise ValueError( | ||||||||||
"You cannot fetch a prompt by both commit_id and tag at the same time" | ||||||||||
) | ||||||||||
prompt_config = fetch_prompt(name, commit_id, tag) | ||||||||||
return cls( | ||||||||||
name=prompt_config["name"], | ||||||||||
prompt=prompt_config["prompt"], | ||||||||||
tags=prompt_config["tags"], | ||||||||||
commit_id=prompt_config["commit_id"], | ||||||||||
parent_commit_id=prompt_config["parent_commit_id"], | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||
) | ||||||||||
|
||||||||||
def compile(self, **kwargs) -> str: | ||||||||||
try: | ||||||||||
return self._template.substitute(**kwargs) | ||||||||||
except KeyError as e: | ||||||||||
missing_var = str(e).strip("'") | ||||||||||
raise ValueError(f"Missing required variable: {missing_var}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The construction of
query_params
can be made slightly more concise. While the current implementation is correct, initializing the dictionary with the requiredname
parameter and then conditionally adding optional parameters can improve readability.