Refactor to remove skip_system from LLMModel.run_prompt (#680)
jamesbraza authored Nov 13, 2024
1 parent 4069d38 commit 9b53817
Showing 6 changed files with 23 additions and 31 deletions.
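
At call sites the change is mechanical: every run_prompt(...) call that previously passed skip_system=True now passes system_prompt=None instead (an empty string has the same effect). A minimal before/after sketch of the migration, where llm_model stands in for any configured LLMModel instance and the prompt/data values are only illustrative:

    # Old API: a boolean flag suppressed the system prompt.
    async def ask_without_system_old(llm_model) -> str:
        result = await llm_model.run_prompt(
            prompt="The {animal} says",
            data={"animal": "duck"},
            skip_system=True,
        )
        return result.text

    # New API: passing None (or "") as the system prompt has the same effect.
    async def ask_without_system_new(llm_model) -> str:
        result = await llm_model.run_prompt(
            prompt="The {animal} says",
            data={"animal": "duck"},
            system_prompt=None,
        )
        return result.text

Callers that want a system prompt keep passing one (or rely on the default); only the "no system prompt" spelling changes.
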
2 changes: 1 addition & 1 deletion paperqa/agents/helpers.py
@@ -61,7 +61,7 @@ async def litellm_get_search_query(
     result = await model.run_prompt(
         prompt=search_prompt,
         data={"question": question, "count": count},
-        skip_system=True,
+        system_prompt=None,
     )
     search_query = result.text
     queries = [s for s in search_query.split("\n") if len(s) > 3]  # noqa: PLR2004

4 changes: 2 additions & 2 deletions paperqa/docs.py
@@ -276,7 +276,7 @@ async def aadd(  # noqa: PLR0912
         result = await llm_model.run_prompt(
             prompt=parse_config.citation_prompt,
             data={"text": texts[0].text},
-            skip_system=True,  # skip system because it's too hesitant to answer
+            system_prompt=None,  # skip system because it's too hesitant to answer
         )
         citation = result.text
         if (
@@ -313,7 +313,7 @@ async def aadd(  # noqa: PLR0912
         result = await llm_model.run_prompt(
             prompt=parse_config.structured_citation_prompt,
             data={"citation": citation},
-            skip_system=True,
+            system_prompt=None,
         )
         # This code below tries to isolate the JSON
         # based on observed messages from LLMs

30 changes: 11 additions & 19 deletions paperqa/llms.py
@@ -328,18 +328,15 @@ async def run_prompt(
         data: dict,
         callbacks: list[Callable] | None = None,
         name: str | None = None,
-        skip_system: bool = False,
-        system_prompt: str = default_system_prompt,
+        system_prompt: str | None = default_system_prompt,
     ) -> LLMResult:
         if self.llm_type is None:
             self.llm_type = self.infer_llm_type()
         if self.llm_type == "chat":
-            return await self._run_chat(
-                prompt, data, callbacks, name, skip_system, system_prompt
-            )
+            return await self._run_chat(prompt, data, callbacks, name, system_prompt)
         if self.llm_type == "completion":
             return await self._run_completion(
-                prompt, data, callbacks, name, skip_system, system_prompt
+                prompt, data, callbacks, name, system_prompt
             )
         raise ValueError(f"Unknown llm_type {self.llm_type!r}.")

@@ -349,8 +346,7 @@ async def _run_chat(
         data: dict,
         callbacks: list[Callable] | None = None,
         name: str | None = None,
-        skip_system: bool = False,
-        system_prompt: str = default_system_prompt,
+        system_prompt: str | None = default_system_prompt,
     ) -> LLMResult:
         """Run a chat prompt.
@@ -359,20 +355,18 @@ async def _run_chat(
             data: Keys for the input variables that will be formatted into prompt.
             callbacks: Optional functions to call with each chunk of the completion.
             name: Optional name for the result.
-            skip_system: Set True to skip the system prompt.
-            system_prompt: System prompt to use.
+            system_prompt: System prompt to use, or None/empty string to not use one.
 
         Returns:
             Result of the chat.
         """
-        system_message_prompt = {"role": "system", "content": system_prompt}
         human_message_prompt = {"role": "user", "content": prompt}
         messages = [
             {"role": m["role"], "content": m["content"].format(**data)}
             for m in (
-                [human_message_prompt]
-                if skip_system
-                else [system_message_prompt, human_message_prompt]
+                [{"role": "system", "content": system_prompt}, human_message_prompt]
+                if system_prompt
+                else [human_message_prompt]
             )
         ]
         result = LLMResult(
@@ -425,8 +419,7 @@ async def _run_completion(
         data: dict,
         callbacks: Iterable[Callable] | None = None,
         name: str | None = None,
-        skip_system: bool = False,
-        system_prompt: str = default_system_prompt,
+        system_prompt: str | None = default_system_prompt,
     ) -> LLMResult:
         """Run a completion prompt.
@@ -435,14 +428,13 @@ async def _run_completion(
             data: Keys for the input variables that will be formatted into prompt.
             callbacks: Optional functions to call with each chunk of the completion.
             name: Optional name for the result.
-            skip_system: Set True to skip the system prompt.
-            system_prompt: System prompt to use.
+            system_prompt: System prompt to use, or None/empty string to not use one.
 
         Returns:
             Result of the completion.
         """
         formatted_prompt: str = (
-            prompt if skip_system else system_prompt + "\n\n" + prompt
+            system_prompt + "\n\n" + prompt if system_prompt else prompt
         ).format(**data)
         result = LLMResult(
             model=self.name,

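The behavior change in paperqa/llms.py is easiest to see in isolation: when system_prompt is None or empty, the chat path sends only the user message and the completion path drops the prefix. A standalone sketch of that logic (not the library's code; DEFAULT_SYSTEM_PROMPT below is a placeholder, not paperqa's actual default; assumes Python 3.10+ for the union syntax):

    # Standalone sketch of the new control flow; not paperqa's implementation.
    DEFAULT_SYSTEM_PROMPT = "Answer in a direct and concise tone."  # placeholder default

    def build_chat_messages(
        prompt: str, data: dict, system_prompt: str | None = DEFAULT_SYSTEM_PROMPT
    ) -> list[dict]:
        """Prepend a system message only when a non-empty system prompt is given."""
        human = {"role": "user", "content": prompt}
        raw = (
            [{"role": "system", "content": system_prompt}, human]
            if system_prompt  # None and "" are both falsy, matching the old skip_system=True
            else [human]
        )
        return [{"role": m["role"], "content": m["content"].format(**data)} for m in raw]

    def build_completion_prompt(
        prompt: str, data: dict, system_prompt: str | None = DEFAULT_SYSTEM_PROMPT
    ) -> str:
        """Prefix the system prompt onto the user prompt only when one is given."""
        return (system_prompt + "\n\n" + prompt if system_prompt else prompt).format(**data)

    # With system_prompt=None only the user message survives:
    # [{'role': 'user', 'content': 'The duck says'}]
    print(build_chat_messages("The {animal} says", {"animal": "duck"}, system_prompt=None))
    print(build_completion_prompt("The {animal} says", {"animal": "duck"}, system_prompt=None))

Because the check is on truthiness rather than a separate flag, an explicitly empty system prompt now behaves the same as None, which is what the updated docstrings call out.
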
6 changes: 3 additions & 3 deletions tests/test_llms.py
@@ -59,7 +59,7 @@ def accum(x) -> None:
     completion = await llm.run_prompt(
         prompt="The {animal} says",
         data={"animal": "duck"},
-        skip_system=True,
+        system_prompt=None,
         callbacks=[accum],
     )
     assert completion.model == "gpt-4o-mini"
@@ -72,7 +72,7 @@ def accum(x) -> None:
     completion = await llm.run_prompt(
         prompt="The {animal} says",
         data={"animal": "duck"},
-        skip_system=True,
+        system_prompt=None,
     )
     assert completion.seconds_to_first_token == 0
     assert completion.seconds_to_last_token > 0
@@ -85,7 +85,7 @@ async def ac(x) -> None:
     completion = await llm.run_prompt(
         prompt="The {animal} says",
         data={"animal": "duck"},
-        skip_system=True,
+        system_prompt=None,
         callbacks=[accum, ac],
     )
     assert completion.cost > 0

8 changes: 4 additions & 4 deletions tests/test_paperqa.py
@@ -423,7 +423,7 @@ def accum(x) -> None:
     completion = await llm.run_prompt(
         prompt="The {animal} says",
         data={"animal": "duck"},
-        skip_system=True,
+        system_prompt=None,
         callbacks=[accum],
     )
     assert completion.seconds_to_first_token > 0
@@ -432,7 +432,7 @@ def accum(x) -> None:
     assert str(completion) == "".join(outputs)
 
     completion = await llm.run_prompt(
-        prompt="The {animal} says", data={"animal": "duck"}, skip_system=True
+        prompt="The {animal} says", data={"animal": "duck"}, system_prompt=None
     )
     assert completion.seconds_to_first_token == 0
     assert completion.seconds_to_last_token > 0
@@ -453,7 +453,7 @@ def accum(x) -> None:
     completion = await llm.run_prompt(
         prompt="The {animal} says",
         data={"animal": "duck"},
-        skip_system=True,
+        system_prompt=None,
         callbacks=[accum],
     )
     assert completion.seconds_to_first_token > 0
@@ -464,7 +464,7 @@ def accum(x) -> None:
     assert completion.cost > 0
 
     completion = await llm.run_prompt(
-        prompt="The {animal} says", data={"animal": "duck"}, skip_system=True
+        prompt="The {animal} says", data={"animal": "duck"}, system_prompt=None
     )
     assert completion.seconds_to_first_token == 0
     assert completion.seconds_to_last_token > 0

4 changes: 2 additions & 2 deletions tests/test_rate_limiter.py
@@ -165,7 +165,7 @@ def accum(x) -> None:
         3,
         prompt="The {animal} says",
         data={"animal": "duck"},
-        skip_system=True,
+        system_prompt=None,
         callbacks=[accum],
     )

@@ -192,7 +192,7 @@ def accum2(x) -> None:
         use_gather=True,
         prompt="The {animal} says",
         data={"animal": "duck"},
-        skip_system=True,
+        system_prompt=None,
         callbacks=[accum2],
     )

