diff --git a/src/colin/llm/prompts/classify.md b/src/colin/llm/prompts/classify.md index 6870374..3404384 100644 --- a/src/colin/llm/prompts/classify.md +++ b/src/colin/llm/prompts/classify.md @@ -16,7 +16,7 @@ Classify the content into one or more of the labels above. Return a list of appl Classify the content into exactly one of the labels above. {% endif %} -{% if previous_output %} +{% if previous_output is not none %} ## Previous Output The following classification was made previously. If the content hasn't changed meaningfully, you may respond with UseExisting to maintain stability. diff --git a/src/colin/llm/prompts/complete.md b/src/colin/llm/prompts/complete.md index 585af40..6346cdb 100644 --- a/src/colin/llm/prompts/complete.md +++ b/src/colin/llm/prompts/complete.md @@ -1,6 +1,6 @@ {{ body }} -{% if previous_output %} +{% if previous_output is not none %} ## Previous Output (for reference) {{ previous_output }} diff --git a/src/colin/llm/prompts/extract.md b/src/colin/llm/prompts/extract.md index 82e27b5..ab6a58d 100644 --- a/src/colin/llm/prompts/extract.md +++ b/src/colin/llm/prompts/extract.md @@ -6,7 +6,7 @@ You are extracting specific information from content. ## Task Extract: {{ prompt }} -{% if previous_output %} +{% if previous_output is not none %} ## Previous Output The following was extracted previously. If the content hasn't changed meaningfully, you may respond with UseExisting to maintain stability. @@ -14,4 +14,8 @@ The following was extracted previously. If the content hasn't changed meaningful {% endif %} ## Response -Provide the extracted information. If previous output exists and is still valid, respond with UseExisting instead. +{% if previous_output is not none %} +Provide the extracted information. If previous output is still valid, respond with UseExisting instead. +{% else %} +Provide the extracted information. +{% endif %} diff --git a/src/colin/providers/llm.py b/src/colin/providers/llm.py index c94d99e..3693440 100644 --- a/src/colin/providers/llm.py +++ b/src/colin/providers/llm.py @@ -184,6 +184,11 @@ async def _extract( result = await agent.run(full_prompt) output_text = str(result.output) + # Handle UseExisting signal: LLM returned the literal string + # "UseExisting" to indicate previous output is still valid + if output_text.strip() == "UseExisting" and previous_output is not None: + output_text = previous_output + # Record LLM call for tracking (only on actual execution, not cache hit) if compile_ctx: compile_ctx.add_llm_call( @@ -424,6 +429,11 @@ async def _complete( result = await agent.run(full_prompt) output_text = str(result.output) + # Handle UseExisting signal: LLM returned the literal string + # "UseExisting" to indicate previous output is still valid + if output_text.strip() == "UseExisting" and previous_output is not None: + output_text = previous_output + # Record LLM call for tracking (only on actual execution, not cache hit) if compile_ctx: compile_ctx.add_llm_call( diff --git a/tests/providers/test_llm_provider.py b/tests/providers/test_llm_provider.py index 6c138f7..1470d5f 100644 --- a/tests/providers/test_llm_provider.py +++ b/tests/providers/test_llm_provider.py @@ -5,7 +5,7 @@ from pydantic_ai.models.function import FunctionModel from colin.api.project import ProjectConfig -from colin.compiler.cache import hash_args, set_compile_context +from colin.compiler.cache import _serialize_value, hash_args, set_compile_context from colin.compiler.context import CompileContext from colin.models import DocumentMeta, LLMCall, Manifest from colin.providers.llm import LLMProvider @@ -636,7 +636,6 @@ def capture_prompt(messages, info): provider = LLMProvider(model=FunctionModel(capture_prompt)) # Manifest with previous successful extract call - from colin.compiler.cache import _serialize_value old_content = "Old content that was extracted from" prompt = "key points" @@ -696,6 +695,304 @@ def capture_prompt(messages, info): assert "## Previous Output" in captured_prompts[0] assert "Previous extraction result" in captured_prompts[0] + async def test_extract_use_existing_returns_previous_output(self, tmp_path) -> None: + """Test that when LLM returns 'UseExisting', the previous output is used instead.""" + + def return_use_existing(messages, info): + return ModelResponse(parts=[TextPart(content="UseExisting")]) + + provider = LLMProvider(model=FunctionModel(return_use_existing)) + + old_content = "Old content" + prompt = "summarize" + serialized = _serialize_value(old_content) + input_hash = hash_args((serialized, prompt), {}) + position_id = "extract_1" + call_id = f"llm.extract:{position_id}:{input_hash}" + + manifest = Manifest() + doc_uri = "project://test.md" + doc_meta = DocumentMeta( + uri=doc_uri, + source_hash="abc123", + llm_calls={ + call_id: LLMCall( + call_id=call_id, + position_id=position_id, + config_hash=provider._config_hash, + input_hash=input_hash, + output_hash="out_hash", + output="The document discusses Python testing.", + model="test", + ) + }, + ) + manifest.set_document(doc_uri, doc_meta) + + project_provider = ProjectProvider(base_path=tmp_path) + config = ProjectConfig( + name="test", + project_root=tmp_path, + model_path=tmp_path / "models", + output_path=tmp_path / "output", + manifest_path=tmp_path / ".colin" / "manifest.json", + ) + compile_ctx = CompileContext( + manifest=manifest, + document_uri=doc_uri, + project_provider=project_provider, + config=config, + ) + + set_compile_context(compile_ctx) + try: + result = await provider._extract( + "New content", + prompt, + _position_id=position_id, + ) + finally: + set_compile_context(None) + + # Should return the previous output, not the literal "UseExisting" + assert result == "The document discusses Python testing." + assert result != "UseExisting" + + async def test_extract_use_existing_with_whitespace(self, tmp_path) -> None: + """Test that 'UseExisting' with surrounding whitespace is handled.""" + + def return_use_existing_padded(messages, info): + return ModelResponse(parts=[TextPart(content=" UseExisting \n")]) + + provider = LLMProvider(model=FunctionModel(return_use_existing_padded)) + + old_content = "Old content" + prompt = "summarize" + serialized = _serialize_value(old_content) + input_hash = hash_args((serialized, prompt), {}) + position_id = "extract_1" + call_id = f"llm.extract:{position_id}:{input_hash}" + + manifest = Manifest() + doc_uri = "project://test.md" + doc_meta = DocumentMeta( + uri=doc_uri, + source_hash="abc123", + llm_calls={ + call_id: LLMCall( + call_id=call_id, + position_id=position_id, + config_hash=provider._config_hash, + input_hash=input_hash, + output_hash="out_hash", + output="Previous result", + model="test", + ) + }, + ) + manifest.set_document(doc_uri, doc_meta) + + project_provider = ProjectProvider(base_path=tmp_path) + config = ProjectConfig( + name="test", + project_root=tmp_path, + model_path=tmp_path / "models", + output_path=tmp_path / "output", + manifest_path=tmp_path / ".colin" / "manifest.json", + ) + compile_ctx = CompileContext( + manifest=manifest, + document_uri=doc_uri, + project_provider=project_provider, + config=config, + ) + + set_compile_context(compile_ctx) + try: + result = await provider._extract( + "New content", + prompt, + _position_id=position_id, + ) + finally: + set_compile_context(None) + + assert result == "Previous result" + + async def test_extract_use_existing_without_previous_output_passes_through( + self, tmp_path + ) -> None: + """Test that 'UseExisting' without previous output is returned as-is.""" + + def return_use_existing(messages, info): + return ModelResponse(parts=[TextPart(content="UseExisting")]) + + provider = LLMProvider(model=FunctionModel(return_use_existing)) + + # Empty manifest — no previous output + manifest = Manifest() + doc_uri = "project://test.md" + + project_provider = ProjectProvider(base_path=tmp_path) + config = ProjectConfig( + name="test", + project_root=tmp_path, + model_path=tmp_path / "models", + output_path=tmp_path / "output", + manifest_path=tmp_path / ".colin" / "manifest.json", + ) + compile_ctx = CompileContext( + manifest=manifest, + document_uri=doc_uri, + project_provider=project_provider, + config=config, + ) + + set_compile_context(compile_ctx) + try: + result = await provider._extract( + "Some content", + "summarize", + _position_id="extract_1", + ) + finally: + set_compile_context(None) + + # No previous output, so UseExisting passes through as-is + assert result == "UseExisting" + + async def test_complete_use_existing_returns_previous_output(self, tmp_path) -> None: + """Test that when LLM returns 'UseExisting' for complete, previous output is used.""" + + def return_use_existing(messages, info): + return ModelResponse(parts=[TextPart(content="UseExisting")]) + + provider = LLMProvider(model=FunctionModel(return_use_existing)) + + prompt = "Write a haiku about spring" + input_hash = hash_args((prompt,), {}) + position_id = "llm_1_5" + call_id = f"llm.complete:{position_id}:{input_hash}" + + manifest = Manifest() + doc_uri = "project://test.md" + doc_meta = DocumentMeta( + uri=doc_uri, + source_hash="abc123", + llm_calls={ + call_id: LLMCall( + call_id=call_id, + position_id=position_id, + config_hash=provider._config_hash, + input_hash=input_hash, + output_hash="out_hash", + output=( + "Cherry blossoms fall\n" + "Gentle breeze carries petals\n" + "Spring has come at last" + ), + model="test", + ) + }, + ) + manifest.set_document(doc_uri, doc_meta) + + project_provider = ProjectProvider(base_path=tmp_path) + config = ProjectConfig( + name="test", + project_root=tmp_path, + model_path=tmp_path / "models", + output_path=tmp_path / "output", + manifest_path=tmp_path / ".colin" / "manifest.json", + ) + compile_ctx = CompileContext( + manifest=manifest, + document_uri=doc_uri, + project_provider=project_provider, + config=config, + ) + + set_compile_context(compile_ctx) + try: + result = await provider._complete( + prompt, + _position_id=position_id, + ) + finally: + set_compile_context(None) + + assert ( + result == "Cherry blossoms fall\nGentle breeze carries petals\nSpring has come at last" + ) + assert result != "UseExisting" + + async def test_extract_use_existing_records_previous_output_in_llm_call(self, tmp_path) -> None: + """Test that when UseExisting is resolved, the recorded LLM call stores + the resolved previous output, not the literal 'UseExisting' string.""" + + def return_use_existing(messages, info): + return ModelResponse(parts=[TextPart(content="UseExisting")]) + + provider = LLMProvider(model=FunctionModel(return_use_existing)) + + old_content = "Old content" + prompt = "summarize" + serialized = _serialize_value(old_content) + input_hash = hash_args((serialized, prompt), {}) + position_id = "extract_1" + call_id = f"llm.extract:{position_id}:{input_hash}" + + manifest = Manifest() + doc_uri = "project://test.md" + doc_meta = DocumentMeta( + uri=doc_uri, + source_hash="abc123", + llm_calls={ + call_id: LLMCall( + call_id=call_id, + position_id=position_id, + config_hash=provider._config_hash, + input_hash=input_hash, + output_hash="out_hash", + output="Previous summary", + model="test", + ) + }, + ) + manifest.set_document(doc_uri, doc_meta) + + project_provider = ProjectProvider(base_path=tmp_path) + config = ProjectConfig( + name="test", + project_root=tmp_path, + model_path=tmp_path / "models", + output_path=tmp_path / "output", + manifest_path=tmp_path / ".colin" / "manifest.json", + ) + compile_ctx = CompileContext( + manifest=manifest, + document_uri=doc_uri, + project_provider=project_provider, + config=config, + ) + + set_compile_context(compile_ctx) + try: + await provider._extract( + "New content", + prompt, + _position_id=position_id, + ) + finally: + set_compile_context(None) + + # The new LLM call should record "Previous summary", not "UseExisting" + new_call_id = ( + f"llm.extract:{position_id}:{hash_args((_serialize_value('New content'), prompt), {})}" + ) + recorded_call = compile_ctx.llm_calls[new_call_id] + assert recorded_call.output == "Previous summary" + async def test_classify_receives_previous_output_with_position_id(self, tmp_path) -> None: """Test that _classify receives previous_output when _position_id is provided.""" captured_prompts: list[str] = [] @@ -708,7 +1005,6 @@ def capture_prompt(messages, info): provider = LLMProvider(model=FunctionModel(capture_prompt)) # Manifest with previous successful classify call - from colin.compiler.cache import _serialize_value old_content = "Old content" labels = ["positive", "negative"]