[WIP] Implement support for batch APIs to gather evidence #687

Draft · wants to merge 3 commits into main
Changes from all commits
2 changes: 2 additions & 0 deletions paperqa/__init__.py
@@ -15,6 +15,7 @@
HybridEmbeddingModel,
LiteLLMEmbeddingModel,
LiteLLMModel,
OpenAIBatchLLMModel,
LLMModel,
LLMResult,
NumpyVectorStore,
@@ -38,6 +39,7 @@
"LLMResult",
"LiteLLMEmbeddingModel",
"LiteLLMModel",
"OpenAIBatchLLMModel"
"NumpyVectorStore",
"PQASession",
"QueryRequest",
78 changes: 62 additions & 16 deletions paperqa/docs.py
@@ -40,6 +40,7 @@
LLMResult,
PQASession,
Text,
Context,
set_llm_session_ids,
)
from paperqa.utils import (
Expand All @@ -50,6 +51,8 @@
maybe_is_text,
md5sum,
name_in_text,
extract_score,
strip_citations
)

logger = logging.getLogger(__name__)
@@ -600,23 +603,66 @@ async def aget_evidence(
)

with set_llm_session_ids(session.id):
if evidence_settings.use_batch_in_summary:
    # TODO: Should we implement a `gather_with_batch` function that receives `matches`
    # and returns results, to keep this DRY? (See the sketch after this hunk.)
    data = [
        {
            "question": session.question,
            "citation": f"{m.name}: {m.doc.formatted_citation}",
            "text": m.text,
            "summary_length": answer_config.evidence_summary_length,
            "evidence": m.name,
        }
        for m in matches
    ]

    llm_results = await prompt_runner(
        data,
        callbacks,
    )

    results_data = []
    scores = []
    for r in llm_results:
        try:
            parsed = llm_parse_json(r.text)
            scores.append(parsed.pop("relevance_score"))
            # just in case the question was echoed back
            parsed.pop("question", None)
            results_data.append(parsed)
        except ValueError:
            results_data.append({})
            scores.append(extract_score(r.text))

    results = [
        (
            Context(
                context=strip_citations(llm_result.text),
                text=m,
                model_extra={},
                score=score,
                **r,
            ),
            llm_result,
        )
        for r, m, llm_result, score in zip(results_data, matches, llm_results, scores)
    ]
else:
    results = await gather_with_concurrency(
        answer_config.max_concurrent_requests,
        [
            map_fxn_summary(
                text=m,
                question=session.question,
                prompt_runner=prompt_runner,
                extra_prompt_data={
                    "summary_length": answer_config.evidence_summary_length,
                    "citation": f"{m.name}: {m.doc.formatted_citation}",
                },
                parser=llm_parse_json if prompt_config.use_json else None,
                callbacks=callbacks,
            )
            for m in matches
        ],
    )

for _, llm_result in results:
session.add_tokens(llm_result)
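The TODO in the hunk above asks whether a `gather_with_batch` helper should own this logic. A minimal hedged sketch, with the helper name, signature, and import paths assumed rather than taken from this PR:

# Hypothetical helper sketch; the signature and import paths are assumptions.
from collections.abc import Callable

from paperqa.types import Context, LLMResult, Text
from paperqa.utils import extract_score, llm_parse_json, strip_citations


async def gather_with_batch(
    matches: list[Text],
    question: str,
    prompt_runner: Callable,
    extra_prompt_data: dict[str, str] | None = None,
    callbacks: list[Callable] | None = None,
) -> list[tuple[Context, LLMResult]]:
    """Run one batched summary request over all matches, pairing each Context with its LLMResult."""
    data = [
        {
            "question": question,
            "citation": f"{m.name}: {m.doc.formatted_citation}",
            "text": m.text,
        }
        | (extra_prompt_data or {})
        for m in matches
    ]
    llm_results = await prompt_runner(data, callbacks)

    results_data, scores = [], []
    for result in llm_results:
        try:
            parsed = llm_parse_json(result.text)
            scores.append(parsed.pop("relevance_score"))
            parsed.pop("question", None)  # in case the question was echoed back
            results_data.append(parsed)
        except ValueError:
            results_data.append({})
            scores.append(extract_score(result.text))

    return [
        (
            Context(
                context=strip_citations(llm_result.text),
                text=m,
                model_extra={},
                score=score,
                **extra,
            ),
            llm_result,
        )
        for extra, m, llm_result, score in zip(results_data, matches, llm_results, scores)
    ]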
209 changes: 208 additions & 1 deletion paperqa/llms.py
@@ -19,6 +19,12 @@
from typing import Any, TypeVar, cast

import litellm

import openai
import json
import os
import tempfile

import numpy as np
import tiktoken
from pydantic import (
@@ -325,7 +331,7 @@ def count_tokens(self, text: str) -> int:
async def run_prompt(
self,
prompt: str,
data: dict,
data: dict | list[dict[str, str]],
callbacks: list[Callable] | None = None,
name: str | None = None,
system_prompt: str | None = default_system_prompt,
@@ -753,6 +759,207 @@ def infer_llm_type(self) -> str:
def count_tokens(self, text: str) -> int:
return litellm.token_counter(model=self.name, text=text)

class OpenAIBatchLLMModel(LLMModel):
"""A wrapper around the OpenAI library to use the batch API."""
name: str = "gpt-4o-mini"
config: dict = Field(
default_factory=dict,
description="Configuration dictionary for this model. Currently supported keys are `model` and `max_token`.",
)

def write_jsonl(self,
data: list[dict[str, str]],
filename: str):

batch_template = {
"custom_id": None,
"method": "POST",
"url": self.config.get('endpoint'),
"body": {
"model": None,
"messages": None,
"max_tokens": None
}
}
with open(filename, "w") as f:
for i, d in enumerate(data):
batch_template["custom_id"] = str(i)
batch_template["body"]["model"] = self.config.get('model')
batch_template["body"]["messages"] = d
batch_template["body"]["max_tokens"] = self.config.get('max_tokens')
f.write(json.dumps(batch_template) + "\n")

@rate_limited
async def acomplete(self):
raise NotImplementedError("Only chat models are supported by openAI batch API.")

@rate_limited
async def acomplete_iter(self):
raise NotImplementedError("Async generator not supported for batch calls and nly chat models are supported by openAI batch API.")

async def _run_chat(
self,
prompt: str,
data: list[dict[str,str]],
callbacks: list[Callable] | None = None,
name: str | None = None,
skip_system: bool = False,
Collaborator: I refactored out skip_system in #680, can you propagate that change to here?
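A hedged sketch of how that might look here, assuming #680 replaced skip_system with an optional system prompt (that assumption is not confirmed by this diff):

# Hypothetical shape after dropping `skip_system` (assumption, mirroring the reviewer's note):
# the signature would take `system_prompt: str | None = default_system_prompt`, and the
# message template would depend on whether a system prompt is present:
human_message_prompt = {"role": "user", "content": prompt}
message_template = (
    [{"role": "system", "content": system_prompt}, human_message_prompt]
    if system_prompt
    else [human_message_prompt]
)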

system_prompt: str = default_system_prompt,
) -> list[LLMResult]:
if callbacks:
sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)]
async_callbacks = [f for f in callbacks if is_coroutine_callable(f)]

system_message_prompt = {"role": "system", "content": system_prompt}
human_message_prompt = {"role": "user", "content": prompt}

batch = []
for d in data:
messages = [
{"role": m["role"], "content": m["content"].format(**d)}
for m in (
[human_message_prompt]
if skip_system
else [system_message_prompt, human_message_prompt]
)
]
batch.append(messages)

start_clock = asyncio.get_running_loop().time()
chunks = await self.achat(batch)
batch_time = asyncio.get_running_loop().time() - start_clock

if callbacks:
for chunk in chunks:
await do_callbacks(
async_callbacks, sync_callbacks, chunk.text, name
)

results = [
LLMResult(
model=self.name,
name=name,
prompt=messages,
prompt_count=chunk.prompt_tokens,
text=chunk.text,
completion_count=chunk.completion_tokens,
seconds_to_first_token=batch_time,
seconds_to_last_token=batch_time,
) for messages, chunk in zip(batch, chunks)
]

return results

@rate_limited
async def achat(self,
messages: list[dict[str, str]]
) -> list[Chunk]:
client = openai.OpenAI()

with tempfile.NamedTemporaryFile(suffix=".jsonl", delete=True) as tmp_file:
tmp_filename = tmp_file.name
self.write_jsonl(messages, tmp_filename)
file = client.files.create(
file=open(tmp_filename, "rb"),
purpose="batch"
)

batch = client.batches.create(
input_file_id=file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={
"description": ""
}
)

while batch.status != "completed":
Collaborator: we probably want "completed" and "failed" to be OpenAI enums here rather than free strings.
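One hedged option, given that the OpenAI SDK reports batch.status as a plain string: keep the literals in a small local enum (the enum below is illustrative, not from the SDK):

from enum import Enum

class OpenAIBatchStatus(str, Enum):
    """Illustrative local enum for the terminal batch statuses this loop cares about."""
    COMPLETED = "completed"
    FAILED = "failed"
    EXPIRED = "expired"
    CANCELLED = "cancelled"

# e.g. `while batch.status != OpenAIBatchStatus.COMPLETED: ...`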

batch = client.batches.retrieve(batch.id)
if batch.status == "failed":
raise Exception("Batch failed. \n\nReason: \n" + "\n".join([k.message for k in batch.errors.data]))
await asyncio.sleep(5)
Collaborator: Let's parameterize this waiting, and maybe make the default longer, like 30-second or 1-minute polling? We should probably also add some debug/info logs here to track progress, along with a max timeout that users can set.
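A hedged sketch of parameterized polling with logging and a timeout; the batch_poll_interval and batch_max_wait config keys are illustrative assumptions, not part of this PR:

# Illustrative polling loop; `logger` is assumed to be the module-level logger.
poll_interval = self.config.get("batch_poll_interval", 30)  # hypothetical key, seconds
max_wait = self.config.get("batch_max_wait", 24 * 60 * 60)  # hypothetical key, seconds
start = asyncio.get_running_loop().time()
while batch.status != "completed":
    if batch.status == "failed":
        raise RuntimeError(f"Batch {batch.id} failed:\n{batch.errors}")
    if asyncio.get_running_loop().time() - start > max_wait:
        raise TimeoutError(f"Batch {batch.id} did not complete within {max_wait}s")
    logger.debug("Batch %s status: %s; polling again in %ss", batch.id, batch.status, poll_interval)
    await asyncio.sleep(poll_interval)
    batch = client.batches.retrieve(batch.id)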


responses = client.files.content(batch.output_file_id)
response_lines = responses.read().decode('utf-8').splitlines()
responses = [json.loads(line) for line in response_lines]
sorted_responses = sorted(responses, key=lambda x: int(x["custom_id"]))  # The batch API doesn't guarantee response order

chunks = [
Chunk(
text=response["response"]["body"]["choices"][0]["message"]["content"],
prompt_tokens=response["response"]["body"]["usage"]["prompt_tokens"],
completion_tokens=response["response"]["body"]["usage"]["completion_tokens"],
) for response in sorted_responses
]

return chunks

@rate_limited
async def achat_iter(self):
raise NotImplementedError("Async generator not supported for batch calls. Use achat instead.")

def infer_llm_type(self):
self.config['endpoint'] = "/v1/chat/completions"
return "chat"

def count_tokens(self, text: str) -> int:
return len(text) // 4

async def check_rate_limit(self, token_count: float, **kwargs) -> None:
if "rate_limit" in self.config:
await GLOBAL_LIMITER.try_acquire(
("client", self.name),
self.config["rate_limit"].get(self.name, None),
weight=max(int(token_count), 1),
**kwargs,
)


class AnthropicBatchLLMModel(LLMModel):
# TODO: This class is not implemented yet.

@rate_limited
async def acomplete(self):
raise NotImplementedError("Completion models are not supported yet")

@rate_limited
async def acomplete_iter(self):
raise NotImplementedError("Completion models are not supported yet")

async def _run_chat(self):
    """Process the batch and call the chat completion method."""
...

@rate_limited
async def achat(self, messages):
...

@rate_limited
async def achat_iter(self):
raise NotImplementedError("support to callbacks is not implemented yet")

def infer_llm_type(self):
return "chat" #TODO: Support completion models

def count_tokens(self, text: str) -> int:
return len(text) // 4 #TODO: Check if OpenAI has a method for that. Currently it's not being used. The token usage is directly retrieved from the response.

def __getstate__(self):
# Prevent _router from being pickled, SEE: https://stackoverflow.com/a/2345953
state = super().__getstate__()
state["__dict__"] = state["__dict__"].copy()
state["__dict__"].pop("_router", None)
return state

async def check_rate_limit(self, token_count: float, **kwargs) -> None:
if "rate_limit" in self.config:
await GLOBAL_LIMITER.try_acquire(
("client", self.name),
self.config["rate_limit"].get(self.name, None),
weight=max(int(token_count), 1),
**kwargs,
)


def cosine_similarity(a, b):
norm_product = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
25 changes: 24 additions & 1 deletion paperqa/settings.py
@@ -40,7 +40,7 @@
except ImportError:
HAS_LDP_INSTALLED = False

from paperqa.llms import EmbeddingModel, LiteLLMModel, embedding_model_factory
from paperqa.llms import EmbeddingModel, LiteLLMModel, OpenAIBatchLLMModel, embedding_model_factory
from paperqa.prompts import (
CONTEXT_INNER_PROMPT,
CONTEXT_OUTER_PROMPT,
@@ -564,6 +564,15 @@ def make_default_litellm_model_list_settings(
]
}

def make_default_openai_batch_llm_settings(
    llm: str, temperature: float = 0.0
) -> dict:
    return {
        "model": llm,
        "temperature": temperature,
        "max_tokens": 2048,
    }

class Settings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")
@@ -596,6 +605,10 @@ class Settings(BaseSettings):
" router_kwargs key with router kwargs as values."
),
)
use_batch_in_summary: bool = Field(
default=False,
description="Whether to use batch API for LLMs in summarization",
Collaborator: Can you add a few words on how the batches are actually formed?

Collaborator: Perhaps you can say something like: "Whether to use batch API for LLMs in summarization, which means multiple messages are sent in one API request."

Collaborator (author): It was updated to:

"Whether to use batch API for LLMs in summarization, "
"which means multiple messages are sent in one API request "
"to the LLM provider's batch API. "
"This option is only available for Claude (https://docs.anthropic.com/en/api/creating-message-batches) "
"and OpenAI (https://platform.openai.com/docs/guides/batch) chat models."
)
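For reference, a hedged usage sketch of this flag using names from this PR (the behavior is assumed from get_summary_llm below):

from paperqa import Settings

# Hypothetical usage: route evidence summarization through the OpenAI batch API.
settings = Settings(summary_llm="gpt-4o-mini", use_batch_in_summary=True)
summary_model = settings.get_summary_llm()  # an OpenAIBatchLLMModel when the flag is set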
embedding: str = Field(
default="text-embedding-3-small",
description="Default embedding model for texts",
@@ -780,6 +793,16 @@ def get_llm(self) -> LiteLLMModel:
)

def get_summary_llm(self) -> LiteLLMModel | OpenAIBatchLLMModel:
if self.use_batch_in_summary:
# TODO: support other LLM providers as well.
# TODO: fail fast if the batch API is not supported for the chosen LLM (see the sketch after this hunk).
return OpenAIBatchLLMModel(
name=self.summary_llm,
config=self.summary_llm_config
or make_default_openai_batch_llm_settings(
self.summary_llm, self.temperature
),
)
return LiteLLMModel(
name=self.summary_llm,
config=self.summary_llm_config
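A hedged sketch of the guard suggested by the TODO in get_summary_llm above; the model-name prefix check is an illustrative assumption, not part of this PR:

def _assert_batch_supported(model_name: str) -> None:
    # Hypothetical guard: only OpenAI chat models are wired to the batch path in this PR.
    if not model_name.startswith(("gpt-", "chatgpt-", "o1")):
        raise NotImplementedError(
            f"use_batch_in_summary is not supported for {model_name!r}; "
            "only OpenAI chat models currently use the batch API."
        )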