From 676ff74c685c7be9f76ec4419754d1ba9269710e Mon Sep 17 00:00:00 2001 From: Ben Tucker Date: Thu, 2 Jan 2025 15:36:43 -0600 Subject: [PATCH] Add support for custom template types to be provided by plugins - Enhanced the Template class to allow dynamic retrieval of template classes based on type. - Introduced a new hook for registering additional template types. - Updated the load_template function to handle custom template types and raise appropriate errors for unknown types. --- docs/plugins/plugin-hooks.md | 29 ++++++ docs/templates.md | 49 ++++++++++ llm/cli.py | 23 +++-- llm/hookspecs.py | 9 ++ llm/templates.py | 23 ++++- tests/test_templates.py | 178 +++++++++++++++++++++++++++++++++-- 6 files changed, 291 insertions(+), 20 deletions(-) diff --git a/docs/plugins/plugin-hooks.md b/docs/plugins/plugin-hooks.md index 0f38cd64..eb9a2ad8 100644 --- a/docs/plugins/plugin-hooks.md +++ b/docs/plugins/plugin-hooks.md @@ -59,3 +59,32 @@ This demonstrates how to register a model with both sync and async versions, and The {ref}`model plugin tutorial ` describes how to use this hook in detail. Asynchronous models {ref}`are described here `. +(register-template-types)= +## register_template_types() + +This hook allows plugins to register custom template types that can be used in prompt templates. + +```python +from llm import Template, hookimpl + +class CustomTemplate(Template): + type: str = "custom" + + def evaluate(self, input: str, params=None): + # Custom processing here + prompt, system = super().evaluate(input, params) + return f"CUSTOM: {prompt}", system + + def stringify(self): + # Custom string representation for llm templates list + return f"custom template: {self.prompt}" + +@hookimpl +def register_template_types(): + return { + "custom": CustomTemplate + } +``` + +Custom template types can modify how prompts are processed and how they appear in template listings. See {ref}`custom-template-types` for more details. 
+ diff --git a/docs/templates.md index 69121597..69dcb410 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -213,3 +213,52 @@ Example: llm -t roast 'How are you today?' ``` > I'm doing great but with your boring questions, I must admit, I've seen more life in a cemetery. + +(custom-template-types)= +### Custom template types + +Plugins can register custom template types that provide additional functionality. These templates are identified by a `type:` key in their YAML configuration. + +For example, a plugin might provide a custom template type that adds special formatting or processing to the prompts: + +```yaml +type: custom +prompt: Hello $input +system: Be helpful +``` + +Custom template types can customize how they appear in the template list by implementing a `stringify` method. This allows them to provide a more descriptive or formatted representation of their configuration when users run `llm templates list`. + +To create a custom template type in a plugin: + +1. Create a class that inherits from `Template` +2. Set a `type` attribute to identify your template type +3. Override methods like `evaluate` to customize behavior +4. Optionally implement `stringify` to control how the template appears in listings +5. Register your template type using the `register_template_types` hook + +For details on implementing the plugin hook, see {ref}`register_template_types() <register-template-types>`. 
+ +Example plugin implementation: + +```python +from llm import Template, hookimpl + +class CustomTemplate(Template): + type: str = "custom" + + def evaluate(self, input: str, params=None): + # Custom processing here + prompt, system = super().evaluate(input, params) + return f"CUSTOM: {prompt}", system + + def stringify(self): + # Custom string representation + return f"custom template: {self.prompt}" + +@hookimpl +def register_template_types(): + return { + "custom": CustomTemplate + } +``` diff --git a/llm/cli.py b/llm/cli.py index f4a9d32c..2e0b1219 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1155,14 +1155,18 @@ def templates_list(): for file in path.glob("*.yaml"): name = file.stem template = load_template(name) - text = [] - if template.system: - text.append(f"system: {template.system}") - if template.prompt: - text.append(f" prompt: {template.prompt}") + if hasattr(template, "stringify"): + text = template.stringify() else: - text = [template.prompt if template.prompt else ""] - pairs.append((name, "".join(text).replace("\n", " "))) + text = [] + if template.system: + text.append(f"system: {template.system}") + if template.prompt: + text.append(f" prompt: {template.prompt}") + else: + text = [template.prompt if template.prompt else ""] + text = "".join(text) + pairs.append((name, text.replace("\n", " "))) try: max_name_len = max(len(p[0]) for p in pairs) except ValueError: @@ -1875,11 +1879,14 @@ def load_template(name): return Template(name=name, prompt=loaded) loaded["name"] = name try: - return Template(**loaded) + template_class = Template.get_template_class(loaded.get("type")) + return template_class(**loaded) except pydantic.ValidationError as ex: msg = "A validation error occurred:\n" msg += render_errors(ex.errors()) raise click.ClickException(msg) + except ValueError as ex: + raise click.ClickException(str(ex)) def get_history(chat_id): diff --git a/llm/hookspecs.py b/llm/hookspecs.py index e7f806be..1073142c 100644 --- a/llm/hookspecs.py +++ 
b/llm/hookspecs.py @@ -18,3 +18,12 @@ def register_models(register): @hookspec def register_embedding_models(register): "Register additional model instances that can be used for embedding" + + +@hookspec +def register_template_types(): + """Register additional template types that can be used for prompt templates. + + Returns: + dict: A dictionary mapping template type names to template classes + """ diff --git a/llm/templates.py index b540fad1..b8ff9c2a 100644 --- a/llm/templates.py +++ b/llm/templates.py @@ -1,10 +1,12 @@ from pydantic import BaseModel import string -from typing import Optional, Any, Dict, List, Tuple +from typing import Optional, Any, Dict, List, Tuple, Type +from .plugins import pm class Template(BaseModel): name: str + type: Optional[str] = None prompt: Optional[str] = None system: Optional[str] = None model: Optional[str] = None @@ -13,11 +15,28 @@ class Template(BaseModel): extract: Optional[bool] = None class Config: - extra = "forbid" + extra = "allow" class MissingVariables(Exception): pass + @classmethod + def get_template_class(cls, type: Optional[str]) -> Type["Template"]: + """Get the template class for a given type.""" + if not type: + return cls + + # Get registered template types from plugins + template_types = {} + for hook_result in pm.hook.register_template_types(): + if hook_result: + template_types.update(hook_result) + + if type not in template_types: + raise ValueError(f"Unknown template type: {type}") + + return template_types[type] + def evaluate( self, input: str, params: Optional[Dict[str, Any]] = None ) -> Tuple[Optional[str], Optional[str]]: diff --git a/tests/test_templates.py index e66005c4..495565a3 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,13 +1,55 @@ from click.testing import CliRunner +import click import json -from llm import Template -from llm.cli import cli +from llm import Template, hookimpl +from llm.cli import cli, 
load_template +from llm.plugins import pm import os from unittest import mock import pytest import yaml +class CustomTemplate(Template): + """A custom template type for testing.""" + type: str = "custom" + + def evaluate(self, input: str, params=None): + prompt, system = super().evaluate(input, params) + if prompt: + prompt = f"CUSTOM: {prompt}" + if system: + system = f"CUSTOM: {system}" + return prompt, system + + def stringify(self): + parts = [] + if self.prompt: + parts.append(f"custom prompt: {self.prompt}") + if self.system: + parts.append(f"custom system: {self.system}") + return " ".join(parts) + + +class MockPlugin: + __name__ = "MockPlugin" + + @hookimpl + def register_template_types(self): + return { + "custom": CustomTemplate + } + + +@pytest.fixture +def register_custom_template(monkeypatch): + pm.register(MockPlugin(), name="mock-plugin") + try: + yield + finally: + pm.unregister(name="mock-plugin") + + @pytest.mark.parametrize( "prompt,system,defaults,params,expected_prompt,expected_system,expected_error", ( @@ -42,6 +84,72 @@ def test_template_evaluate( assert system == expected_system +@pytest.mark.parametrize( + "template_yaml,expected_type,expected_prompt,expected_system,expected_error", + ( + ( + """ + type: custom + prompt: Hello $input + system: Be helpful + """, + CustomTemplate, + "CUSTOM: Hello world", + "CUSTOM: Be helpful", + None, + ), + ( + """ + type: unknown + prompt: Hello $input + """, + None, + None, + None, + "Unknown template type: unknown", + ), + ( + "Hello $input", + Template, + "Hello world", + None, + None, + ), + ( + """ + prompt: Hello $input + system: Be helpful + """, + Template, + "Hello world", + "Be helpful", + None, + ), + ), +) +def test_template_types( + register_custom_template, + templates_path, + template_yaml, + expected_type, + expected_prompt, + expected_system, + expected_error, +): + (templates_path / "test.yaml").write_text(template_yaml, "utf-8") + if expected_error: + with 
pytest.raises(click.ClickException, match=expected_error): + load_template("test") + else: + template = load_template("test") + assert isinstance(template, expected_type) + if expected_type == CustomTemplate: + assert template.type == "custom" + prompt, system = template.evaluate("world") + assert prompt == expected_prompt + assert system == expected_system + + def test_templates_list_no_templates_found(): runner = CliRunner() result = runner.invoke(cli, ["templates", "list"]) @@ -50,6 +158,7 @@ def test_templates_list_no_templates_found(): @pytest.mark.parametrize("args", (["templates", "list"], ["templates"])) +@pytest.mark.usefixtures("register_custom_template") def test_templates_list(templates_path, args): (templates_path / "one.yaml").write_text("template one", "utf-8") (templates_path / "two.yaml").write_text("template two", "utf-8") @@ -63,17 +172,22 @@ def test_templates_list(templates_path, args): "system: summarize this\nprompt: $input", "utf-8" ) (templates_path / "sys.yaml").write_text("system: Summarize this", "utf-8") + (templates_path / "custom.yaml").write_text( + "type: custom\nprompt: Hello $input", "utf-8" + ) runner = CliRunner() result = runner.invoke(cli, args) assert result.exit_code == 0 - assert result.output == ( - "both : system: summarize this prompt: $input\n" - "four : this one has newlines in it\n" - "one : template one\n" - "sys : system: Summarize this\n" - "three : template three is very long template three is very long template thre...\n" - "two : template two\n" - ) + lines = result.output.strip().split("\n") + assert len(lines) == 7 + assert lines[0] == "both : system: summarize this prompt: $input" + assert lines[1] == "custom : custom prompt: Hello $input" + assert lines[2] == "four : this one has newlines in it" + assert lines[3] == "one : template one" + assert lines[4] == "sys : system: Summarize this" + assert lines[5].startswith("three : template three is very long template three is very long template") + assert 
lines[5].endswith("...") + assert lines[6] == "two : template two" @pytest.mark.parametrize( @@ -192,3 +306,47 @@ def test_template_basic( assert result.exit_code == 1 assert result.output.strip() == expected_error mocked_openai_chat.reset() + + +@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"}) +@pytest.mark.parametrize( + "template,extra_args,expected_model,expected_input,expected_error", + ( + ( + "type: custom\nprompt: 'Say $hello'", + ["-p", "hello", "Blah"], + "gpt-4o-mini", + "CUSTOM: Say Blah", + None, + ), + ), +) +def test_template_basic_custom( + register_custom_template, + templates_path, + mocked_openai_chat, + template, + extra_args, + expected_model, + expected_input, + expected_error, +): + (templates_path / "template.yaml").write_text(template, "utf-8") + runner = CliRunner() + result = runner.invoke( + cli, + ["--no-stream", "-t", "template", "Input text"] + extra_args, + catch_exceptions=False, + ) + if expected_error is None: + assert result.exit_code == 0 + last_request = mocked_openai_chat.get_requests()[-1] + assert json.loads(last_request.content) == { + "model": expected_model, + "messages": [{"role": "user", "content": expected_input}], + "stream": False, + } + else: + assert result.exit_code == 1 + assert result.output.strip() == expected_error + mocked_openai_chat.reset()