From 9b53817561b4e5c561d7c89b927671bf73ee255f Mon Sep 17 00:00:00 2001
From: James Braza
Date: Wed, 13 Nov 2024 14:04:00 -0800
Subject: [PATCH] Refactor to remove `skip_system` from `LLMModel.run_prompt` (#680)

---
 paperqa/agents/helpers.py  |  2 +-
 paperqa/docs.py            |  4 ++--
 paperqa/llms.py            | 30 +++++++++++-------------------
 tests/test_llms.py         |  6 +++---
 tests/test_paperqa.py      |  8 ++++----
 tests/test_rate_limiter.py |  4 ++--
 6 files changed, 23 insertions(+), 31 deletions(-)

diff --git a/paperqa/agents/helpers.py b/paperqa/agents/helpers.py
index 933a861b..062e367f 100644
--- a/paperqa/agents/helpers.py
+++ b/paperqa/agents/helpers.py
@@ -61,7 +61,7 @@ async def litellm_get_search_query(
     result = await model.run_prompt(
         prompt=search_prompt,
         data={"question": question, "count": count},
-        skip_system=True,
+        system_prompt=None,
     )
     search_query = result.text
     queries = [s for s in search_query.split("\n") if len(s) > 3]  # noqa: PLR2004
diff --git a/paperqa/docs.py b/paperqa/docs.py
index ea7fa607..51a79794 100644
--- a/paperqa/docs.py
+++ b/paperqa/docs.py
@@ -276,7 +276,7 @@ async def aadd(  # noqa: PLR0912
             result = await llm_model.run_prompt(
                 prompt=parse_config.citation_prompt,
                 data={"text": texts[0].text},
-                skip_system=True,  # skip system because it's too hesitant to answer
+                system_prompt=None,  # skip system because it's too hesitant to answer
             )
             citation = result.text
             if (
@@ -313,7 +313,7 @@ async def aadd(  # noqa: PLR0912
             result = await llm_model.run_prompt(
                 prompt=parse_config.structured_citation_prompt,
                 data={"citation": citation},
-                skip_system=True,
+                system_prompt=None,
             )
             # This code below tries to isolate the JSON
             # based on observed messages from LLMs
diff --git a/paperqa/llms.py b/paperqa/llms.py
index ac2092f5..13f2424a 100644
--- a/paperqa/llms.py
+++ b/paperqa/llms.py
@@ -328,18 +328,15 @@ async def run_prompt(
         data: dict,
         callbacks: list[Callable] | None = None,
         name: str | None = None,
-        skip_system: bool = False,
-        system_prompt: str = default_system_prompt,
+        system_prompt: str | None = default_system_prompt,
     ) -> LLMResult:
         if self.llm_type is None:
             self.llm_type = self.infer_llm_type()
         if self.llm_type == "chat":
-            return await self._run_chat(
-                prompt, data, callbacks, name, skip_system, system_prompt
-            )
+            return await self._run_chat(prompt, data, callbacks, name, system_prompt)
         if self.llm_type == "completion":
             return await self._run_completion(
-                prompt, data, callbacks, name, skip_system, system_prompt
+                prompt, data, callbacks, name, system_prompt
             )
         raise ValueError(f"Unknown llm_type {self.llm_type!r}.")
@@ -349,8 +346,7 @@ async def _run_chat(
         prompt: str,
         data: dict,
         callbacks: list[Callable] | None = None,
         name: str | None = None,
-        skip_system: bool = False,
-        system_prompt: str = default_system_prompt,
+        system_prompt: str | None = default_system_prompt,
     ) -> LLMResult:
         """Run a chat prompt.

         Args:
             prompt: Prompt to use.
             data: Keys for the input variables that will be formatted into prompt.
             callbacks: Optional functions to call with each chunk of the completion.
             name: Optional name for the result.
-            skip_system: Set True to skip the system prompt.
-            system_prompt: System prompt to use.
+            system_prompt: System prompt to use, or None/empty string to not use one.

         Returns:
             Result of the chat.
""" - system_message_prompt = {"role": "system", "content": system_prompt} human_message_prompt = {"role": "user", "content": prompt} messages = [ {"role": m["role"], "content": m["content"].format(**data)} for m in ( - [human_message_prompt] - if skip_system - else [system_message_prompt, human_message_prompt] + [{"role": "system", "content": system_prompt}, human_message_prompt] + if system_prompt + else [human_message_prompt] ) ] result = LLMResult( @@ -425,8 +419,7 @@ async def _run_completion( data: dict, callbacks: Iterable[Callable] | None = None, name: str | None = None, - skip_system: bool = False, - system_prompt: str = default_system_prompt, + system_prompt: str | None = default_system_prompt, ) -> LLMResult: """Run a completion prompt. @@ -435,14 +428,13 @@ async def _run_completion( data: Keys for the input variables that will be formatted into prompt. callbacks: Optional functions to call with each chunk of the completion. name: Optional name for the result. - skip_system: Set True to skip the system prompt. - system_prompt: System prompt to use. + system_prompt: System prompt to use, or None/empty string to not use one. Returns: Result of the completion. """ formatted_prompt: str = ( - prompt if skip_system else system_prompt + "\n\n" + prompt + system_prompt + "\n\n" + prompt if system_prompt else prompt ).format(**data) result = LLMResult( model=self.name, diff --git a/tests/test_llms.py b/tests/test_llms.py index 9cf827bd..69bd65c8 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -59,7 +59,7 @@ def accum(x) -> None: completion = await llm.run_prompt( prompt="The {animal} says", data={"animal": "duck"}, - skip_system=True, + system_prompt=None, callbacks=[accum], ) assert completion.model == "gpt-4o-mini" @@ -72,7 +72,7 @@ def accum(x) -> None: completion = await llm.run_prompt( prompt="The {animal} says", data={"animal": "duck"}, - skip_system=True, + system_prompt=None, ) assert completion.seconds_to_first_token == 0 assert completion.seconds_to_last_token > 0 @@ -85,7 +85,7 @@ async def ac(x) -> None: completion = await llm.run_prompt( prompt="The {animal} says", data={"animal": "duck"}, - skip_system=True, + system_prompt=None, callbacks=[accum, ac], ) assert completion.cost > 0 diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index d4095923..74d56a29 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -423,7 +423,7 @@ def accum(x) -> None: completion = await llm.run_prompt( prompt="The {animal} says", data={"animal": "duck"}, - skip_system=True, + system_prompt=None, callbacks=[accum], ) assert completion.seconds_to_first_token > 0 @@ -432,7 +432,7 @@ def accum(x) -> None: assert str(completion) == "".join(outputs) completion = await llm.run_prompt( - prompt="The {animal} says", data={"animal": "duck"}, skip_system=True + prompt="The {animal} says", data={"animal": "duck"}, system_prompt=None ) assert completion.seconds_to_first_token == 0 assert completion.seconds_to_last_token > 0 @@ -453,7 +453,7 @@ def accum(x) -> None: completion = await llm.run_prompt( prompt="The {animal} says", data={"animal": "duck"}, - skip_system=True, + system_prompt=None, callbacks=[accum], ) assert completion.seconds_to_first_token > 0 @@ -464,7 +464,7 @@ def accum(x) -> None: assert completion.cost > 0 completion = await llm.run_prompt( - prompt="The {animal} says", data={"animal": "duck"}, skip_system=True + prompt="The {animal} says", data={"animal": "duck"}, system_prompt=None ) assert completion.seconds_to_first_token == 0 assert 
diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py
index 1c299e49..0b73773a 100644
--- a/tests/test_rate_limiter.py
+++ b/tests/test_rate_limiter.py
@@ -165,7 +165,7 @@ def accum(x) -> None:
         3,
         prompt="The {animal} says",
         data={"animal": "duck"},
-        skip_system=True,
+        system_prompt=None,
         callbacks=[accum],
     )
@@ -192,7 +192,7 @@ def accum2(x) -> None:
         use_gather=True,
         prompt="The {animal} says",
         data={"animal": "duck"},
-        skip_system=True,
+        system_prompt=None,
         callbacks=[accum2],
     )
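
For callers, the migration is mechanical: wherever `skip_system=True` was passed, pass `system_prompt=None` instead (an empty string also suppresses the system prompt, since the new code branches on truthiness), while passing a string keeps the old behavior of sending a system message. Below is a minimal sketch of the new call pattern; it is not part of the patch, and it assumes paperqa's `LiteLLMModel` with the "gpt-4o-mini" model name used in the tests above. Any `LLMModel` subclass exposing `run_prompt` should work the same way.

# Migration sketch (illustrative only, not part of the patch).
# Assumes paperqa's LiteLLMModel and an OpenAI model name; swap in
# whichever LLMModel subclass you actually construct.
import asyncio

from paperqa.llms import LiteLLMModel


async def main() -> None:
    llm = LiteLLMModel(name="gpt-4o-mini")

    # Previously: skip_system=True. Now: system_prompt=None omits the system
    # message (chat models) or the prepended system text (completion models).
    completion = await llm.run_prompt(
        prompt="The {animal} says",
        data={"animal": "duck"},
        system_prompt=None,
    )
    print(completion.text)

    # Passing a string (or omitting the argument to keep the default
    # system prompt) still sends a system message, as before.
    completion = await llm.run_prompt(
        prompt="The {animal} says",
        data={"animal": "duck"},
        system_prompt="Answer in one short sentence.",
    )
    print(completion.text)


if __name__ == "__main__":
    asyncio.run(main())

Collapsing the two knobs into a single optional `system_prompt` also removes the ambiguous combination the old signature allowed, where `skip_system=True` could be passed alongside a custom `system_prompt`.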