Skip to content

Commit

Permalink
chore: FunctionPrompt fix
Browse files Browse the repository at this point in the history
  • Loading branch information
phil65 committed Nov 30, 2024
1 parent bd6d631 commit 85efc4f
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 66 deletions.
2 changes: 1 addition & 1 deletion src/llmling/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ async def render_prompt(
"""
try:
prompt = self._prompt_registry[name]
return prompt.format(arguments)
return await prompt.format(arguments)
except KeyError as exc:
msg = f"Prompt not found: {name}"
raise exceptions.LLMLingError(msg) from exc
Expand Down
16 changes: 3 additions & 13 deletions src/llmling/prompts/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,15 @@ def create_prompt_from_callable(
)
)

# Create message template
if template_override:
template = template_override
else:
# Create default template from function signature
arg_list = ", ".join(f"{arg.name}={{{arg.name}}}" for arg in arguments)
template = f"Call {name}({arg_list})"

# Create message template. Will be formatted with function result
template = template_override if template_override else "{result}"
# Create prompt messages
messages = [
PromptMessage(
role="system",
content=MessageContent(
type="text",
content=(
f"Function: {name}\n"
f"Description: {description}\n\n"
"Please provide the required arguments."
),
content=f"Content from {name}:\n",
),
),
PromptMessage(role="user", content=MessageContent(type="text", content=template)),
Expand Down
32 changes: 22 additions & 10 deletions src/llmling/prompts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def validate_arguments(self, provided: dict[str, Any]) -> None:
msg = f"Missing required arguments: {', '.join(missing)}"
raise ValueError(msg)

def format(self, arguments: dict[str, Any] | None = None) -> list[PromptMessage]:
async def format(
self, arguments: dict[str, Any] | None = None
) -> list[PromptMessage]:
"""Format prompt messages with arguments.
Args:
Expand All @@ -87,15 +89,25 @@ def format(self, arguments: dict[str, Any] | None = None) -> list[PromptMessage]
args = arguments or {}
self.validate_arguments(args)

# Add defaults for missing optional arguments
format_args = {}
for arg in self.arguments:
if arg.name in args:
format_args[arg.name] = args[arg.name]
elif arg.default is not None:
format_args[arg.name] = arg.default
elif not arg.required:
format_args[arg.name] = "" # Empty string for optional args
# If this is a function prompt, execute it
if self.metadata.get("source") == "function":
try:
from llmling.utils import calling

import_path = self.metadata["import_path"]
result = await calling.execute_callable(import_path, **args)
format_args = {"result": result}
except Exception as exc: # noqa: BLE001
format_args = {"result": f"Error executing function: {exc}"}
else:
# Add default values for optional arguments
format_args = args.copy() # Make a copy to avoid modifying input
for arg in self.arguments:
if arg.name not in format_args:
if arg.default is not None:
format_args[arg.name] = arg.default
elif not arg.required:
format_args[arg.name] = "" # Empty string for optional args

# Format all messages
formatted_messages = []
Expand Down
2 changes: 1 addition & 1 deletion src/llmling/prompts/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def get_messages(
) -> list[PromptMessage]:
"""Get formatted messages for a prompt."""
prompt = self[name]
return prompt.format(arguments or {})
return await prompt.format(arguments or {})

async def get_completions(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/llmling/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ async def handle_get_prompt(
"""Handle prompts/get request."""
try:
prompt = self.runtime.get_prompt(name)
messages = await self.runtime.render_prompt(name, arguments)
messages = await prompt.format(arguments or {}) # Note: now async
mcp_msgs = [conversions.to_mcp_message(msg) for msg in messages]
return GetPromptResult(description=prompt.description, messages=mcp_msgs)
except exceptions.LLMLingError as exc:
Expand Down
14 changes: 8 additions & 6 deletions tests/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def sample_prompt() -> Prompt:
)


def test_prompt_format():
@pytest.mark.asyncio
async def test_prompt_format():
"""Test prompt message formatting."""
prompt = Prompt(
name="test",
Expand All @@ -127,18 +128,19 @@ def test_prompt_format():
)

# Test with all arguments
messages = prompt.format({"name": "Alice", "age": "30"})
messages = await prompt.format({"name": "Alice", "age": "30"})
assert len(messages) == 2 # noqa: PLR2004
assert messages[0].get_text_content() == "Hello Alice"
assert messages[1].get_text_content() == "Age: 30"

# Test with only required arguments
messages = prompt.format({"name": "Bob"})
messages = await prompt.format({"name": "Bob"})
assert messages[0].get_text_content() == "Hello Bob"
assert messages[1].get_text_content() == "Age: "


def test_prompt_validation():
@pytest.mark.asyncio
async def test_prompt_validation():
"""Test prompt argument validation."""
prompt = Prompt(
name="test",
Expand All @@ -149,8 +151,8 @@ def test_prompt_validation():

# Should raise when missing required argument
with pytest.raises(ValueError, match="Missing required argument"):
prompt.format({})
await prompt.format({})

# Should work with required argument
messages = prompt.format({"required_arg": "value"})
messages = await prompt.format({"required_arg": "value"})
assert messages[0].get_text_content() == "Test value"
76 changes: 42 additions & 34 deletions tests/test_prompts_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def example_function(
Returns:
Processed text
"""
return text
result = text
if style == "detailed":
result = f"{result} (detailed)"
if tags:
result = f"{result} [tags: {', '.join(tags)}]"
return result


async def async_function(
Expand All @@ -42,6 +47,8 @@ async def async_function(
Returns:
Processed content
"""
if mode == "upper":
return content.upper()
return content


Expand Down Expand Up @@ -69,7 +76,7 @@ def test_create_prompt_arguments():
assert args["text"].description
assert "input text to process" in args["text"].description.lower()

# Check style argument - should be text type with enum values
# Check style argument
assert args["style"].required is False
assert args["style"].type_hint is typing.Literal["brief", "detailed"]
assert args["style"].default == "brief"
Expand All @@ -96,26 +103,23 @@ def test_create_prompt_async():
assert "Content to process" in description


def test_prompt_formatting():
"""Test that created prompts can be formatted."""
@pytest.mark.asyncio
async def test_prompt_formatting():
"""Test that created prompts format with function results."""
prompt = create_prompt_from_callable(example_function)

# Format with all arguments
messages = prompt.format({
messages = await prompt.format({
"text": "sample",
"style": "brief",
"style": "detailed",
"tags": ["test"],
})
formatted = messages[1].get_text_content()
assert "text=sample" in formatted
assert "style=brief" in formatted
assert "tags=['test']" in formatted
result = messages[1].get_text_content()
assert result == "sample (detailed) [tags: test]"

# Format with only required arguments
messages = prompt.format({"text": "sample"})
formatted = messages[1].get_text_content()
assert "text=sample" in formatted
assert "style=brief" in formatted # Default value
messages = await prompt.format({"text": "sample"})
assert messages[1].get_text_content() == "sample" # Uses default brief style


def test_create_prompt_overrides():
Expand All @@ -124,27 +128,25 @@ def test_create_prompt_overrides():
example_function,
name_override="custom_name",
description_override="Custom description",
template_override="Custom template: {text}",
template_override="Result: {result}",
)

assert prompt.name == "custom_name"
assert prompt.description == "Custom description"
assert prompt.messages[1].content.content == "Result: {result}" # type: ignore

# Test template override
messages = prompt.format({"text": "test"})
assert messages[1].get_text_content() == "Custom template: test"


def test_create_prompt_from_import_path():
@pytest.mark.asyncio
async def test_create_prompt_from_import_path():
"""Test prompt creation from import path."""
prompt = create_prompt_from_callable("llmling.testing.processors.uppercase_text")

assert prompt.name == "uppercase_text"
assert "Convert text to uppercase" in prompt.description

# Test formatting
messages = prompt.format({"text": "test"})
assert "text=test" in messages[1].get_text_content()
# Test execution
messages = await prompt.format({"text": "test"})
assert messages[1].get_text_content() == "TEST"


def test_create_prompt_invalid_import():
Expand All @@ -153,18 +155,14 @@ def test_create_prompt_invalid_import():
create_prompt_from_callable("nonexistent.module.function")


def test_argument_validation():
@pytest.mark.asyncio
async def test_argument_validation():
"""Test argument validation in created prompts."""
prompt = create_prompt_from_callable(example_function)

# Should fail without required argument
with pytest.raises(ValueError, match="Missing required argument"):
prompt.format({})

# Should work with required argument
messages = prompt.format({"text": "test"})
assert len(messages) == 2 # noqa: PLR2004
assert "text=test" in messages[1].get_text_content()
with pytest.raises(ValueError, match="Missing required arguments"):
await prompt.format({}) # Add await here


def test_system_message():
Expand All @@ -173,9 +171,7 @@ def test_system_message():

system_msg = prompt.messages[0]
assert system_msg.role == "system"
content = system_msg.get_text_content()
assert "Function: example_function" in content
assert "Description: Process text with given style" in content
assert "Content from example_function" in system_msg.get_text_content()


def test_prompt_with_completions():
Expand Down Expand Up @@ -203,3 +199,15 @@ def example_func(
# Check completion function
assert args["other"].completion_function is not None
assert args["other"].completion_function("py") == ["python"]


@pytest.mark.asyncio
async def test_async_function_execution():
"""Test that async functions are properly executed."""
prompt = create_prompt_from_callable(async_function)

messages = await prompt.format({"content": "test", "mode": "upper"})
assert messages[1].get_text_content() == "TEST"

messages = await prompt.format({"content": "test"})
assert messages[1].get_text_content() == "test"
2 changes: 2 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from llmling.prompts.models import ExtendedPromptArgument, Prompt, PromptMessage


@pytest.mark.asyncio
async def test_render_prompt(runtime_config):
"""Test prompt rendering through runtime config."""
prompt = Prompt(
Expand All @@ -27,6 +28,7 @@ async def test_render_prompt_not_found(runtime_config):
await runtime_config.render_prompt("nonexistent")


@pytest.mark.asyncio
async def test_render_prompt_validation_error(runtime_config):
"""Test error handling for invalid arguments."""
prompt = Prompt(
Expand Down

0 comments on commit 85efc4f

Please sign in to comment.