diff --git a/paperqa/__init__.py b/paperqa/__init__.py
index 008b1825..2346f0b0 100644
--- a/paperqa/__init__.py
+++ b/paperqa/__init__.py
@@ -15,6 +15,8 @@
     HybridEmbeddingModel,
     LiteLLMEmbeddingModel,
     LiteLLMModel,
+    AnthropicBatchLLMModel,
+    OpenAIBatchLLMModel,
     LLMModel,
     LLMResult,
     NumpyVectorStore,
@@ -38,6 +40,8 @@
     "LLMResult",
     "LiteLLMEmbeddingModel",
     "LiteLLMModel",
+    "AnthropicBatchLLMModel",
+    "OpenAIBatchLLMModel",
     "NumpyVectorStore",
     "PQASession",
     "QueryRequest",
diff --git a/paperqa/core.py b/paperqa/core.py
index 5ceb0060..786e26dd 100644
--- a/paperqa/core.py
+++ b/paperqa/core.py
@@ -115,3 +115,50 @@ async def map_fxn_summary(
         ),
         llm_result,
     )
+
+
+async def gather_with_batch(
+    matches: list[Text],
+    question: str,
+    prompt_runner: PromptRunner | None,
+    extra_prompt_data: dict[str, str] | None = None,
+    parser: Callable[[str], dict[str, Any]] | None = None,
+    callbacks: list[Callable[[str], None]] | None = None,
+) -> list[tuple[Context, LLMResult]]:
+    """Gather summary results for a batch of texts via a single batch API request."""
+    data = [
+        {
+            "question": question,
+            "citation": m.name + ": " + m.doc.formatted_citation,
+            "text": m.text,
+        }
+        | (extra_prompt_data or {})  # parenthesized: `|` binds tighter than `or`
+        for m in matches
+    ]
+
+    llm_results: list[LLMResult] = []
+    if prompt_runner:
+        llm_results = await prompt_runner(data, callbacks)
+
+    results_data = []
+    scores = []
+    for r in llm_results:
+        try:
+            result_data = parser(r.text) if parser else {}
+            scores.append(result_data.pop("relevance_score"))
+            # just in case the question was echoed back in the parsed JSON
+            result_data.pop("question", None)
+            results_data.append(result_data)
+        except (KeyError, ValueError):
+            results_data.append({})
+            scores.append(extract_score(r.text))
+
+    return [
+        (
+            Context(
+                context=strip_citations(llm_result.text),
+                text=m,
+                score=score,
+                **r,
+            ),
+            llm_result,
+        )
+        for r, m, llm_result, score in zip(results_data, matches, llm_results, scores)
+    ]
diff --git a/paperqa/docs.py b/paperqa/docs.py
index 5d2e1b45..0f5d8832 100644
--- a/paperqa/docs.py
+++ b/paperqa/docs.py
@@ -22,7 +22,11 @@
 )
 from paperqa.clients import DEFAULT_CLIENTS, DocMetadataClient
-from paperqa.core import llm_parse_json, map_fxn_summary
+from paperqa.core import (
+    gather_with_batch,
+    llm_parse_json,
+    map_fxn_summary,
+)
 from paperqa.llms import (
     EmbeddingModel,
     LLMModel,
@@ -600,23 +607,35 @@ async def aget_evidence(
         )
 
         with set_llm_session_ids(session.id):
-            results = await gather_with_concurrency(
-                answer_config.max_concurrent_requests,
-                [
-                    map_fxn_summary(
-                        text=m,
-                        question=session.question,
-                        prompt_runner=prompt_runner,
-                        extra_prompt_data={
-                            "summary_length": answer_config.evidence_summary_length,
-                            "citation": f"{m.name}: {m.doc.formatted_citation}",
-                        },
-                        parser=llm_parse_json if prompt_config.use_json else None,
-                        callbacks=callbacks,
-                    )
-                    for m in matches
-                ],
-            )
+            if evidence_settings.use_batch_in_summary:
+                results = await gather_with_batch(
+                    matches=matches,
+                    question=session.question,
+                    prompt_runner=prompt_runner,
+                    extra_prompt_data={
+                        "summary_length": answer_config.evidence_summary_length,
+                    },
+                    parser=llm_parse_json if prompt_config.use_json else None,
+                    callbacks=callbacks,
+                )
+            else:
+                results = await gather_with_concurrency(
+                    answer_config.max_concurrent_requests,
+                    [
+                        map_fxn_summary(
+                            text=m,
+                            question=session.question,
+                            prompt_runner=prompt_runner,
+                            extra_prompt_data={
+                                "summary_length": answer_config.evidence_summary_length,
+                                "citation": f"{m.name}: {m.doc.formatted_citation}",
+                            },
+                            parser=llm_parse_json if prompt_config.use_json else None,
+                            callbacks=callbacks,
+                        )
+                        for m in matches
+                    ],
+                )
         for _, llm_result in results:
             session.add_tokens(llm_result)
diff --git a/paperqa/llms.py b/paperqa/llms.py
index 13f2424a..c7744645 100644
--- a/paperqa/llms.py
+++ b/paperqa/llms.py
@@ -19,6 +19,12 @@
 from typing import Any, TypeVar, cast
 
+import json
+import logging
+import tempfile
+
 import litellm
 import numpy as np
 import tiktoken
 from pydantic import (
@@ -35,6 +41,8 @@
 from paperqa.types import Embeddable, LLMResult
 from paperqa.utils import is_coroutine_callable
 
+logger = logging.getLogger(__name__)
+
 PromptRunner = Callable[
     [dict, list[Callable[[str], None]] | None, str | None],
     Awaitable[LLMResult],
@@ -69,6 +77,24 @@ class EmbeddingModes(StrEnum):
     QUERY = "query"
 
 
+class OpenAIBatchStatus(StrEnum):
+    COMPLETE = "completed"
+    PROGRESS = "in_progress"
+    SUCCESS = "completed"
+    FAILURE = "failed"
+    EXPIRE = "expired"
+    CANCEL = "cancelled"
+
+
+class AnthropicBatchStatus(StrEnum):
+    COMPLETE = "ended"
+    PROGRESS = "in_progress"
+    SUCCESS = "succeeded"
+    FAILURE = "errored"
+    EXPIRE = "expired"
+    CANCEL = "canceled"
+
+
 # Estimate from OpenAI's FAQ
 # https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
 CHARACTERS_PER_TOKEN_ASSUMPTION: float = 4.0
@@ -325,7 +351,7 @@ def count_tokens(self, text: str) -> int:
     async def run_prompt(
         self,
         prompt: str,
-        data: dict,
+        data: dict | list[dict[str, str]],
         callbacks: list[Callable] | None = None,
         name: str | None = None,
         system_prompt: str | None = default_system_prompt,
@@ -754,6 +780,292 @@ def count_tokens(self, text: str) -> int:
         return litellm.token_counter(model=self.name, text=text)
 
 
+class OpenAIBatchLLMModel(LLMModel):
+    """A wrapper around the OpenAI library to use the batch API."""
+
+    name: str = "gpt-4o-mini"
+    config: dict = Field(
+        default_factory=dict,
+        description=(
+            "Configuration dictionary for this model. Currently supported keys are"
+            " `model`, `max_tokens`, `temperature`, `batch_summary_timelimit`, and"
+            " `batch_polling_interval`."
+        ),
+    )
+    status: type[OpenAIBatchStatus] = Field(
+        default=OpenAIBatchStatus,
+        description="Statuses used to report the state of the batch API request.",
+    )
+
+    def write_jsonl(self, data: list[dict[str, str]], filename: str) -> None:
+        with open(filename, "w") as f:
+            for i, d in enumerate(data):
+                request = {
+                    "custom_id": str(i),
+                    "method": "POST",
+                    "url": self.config.get("endpoint"),
+                    "body": {
+                        "model": self.config.get("model"),
+                        "messages": d,
+                        "temperature": self.config.get("temperature", 0.0),
+                        "max_tokens": self.config.get("max_tokens"),
+                    },
+                }
+                f.write(json.dumps(request) + "\n")
+
+    @rate_limited
+    async def acomplete(self):
+        raise NotImplementedError(
+            "Only chat models are supported by the OpenAI batch API."
+        )
+
+    @rate_limited
+    async def acomplete_iter(self):
+        raise NotImplementedError(
+            "Async generators are not supported for batch calls, and only chat models"
+            " are supported by the OpenAI batch API."
+        )
+
+    async def _run_chat(
+        self,
+        prompt: str,
+        data: list[dict[str, str]],
+        callbacks: list[Callable] | None = None,
+        name: str | None = None,
+        system_prompt: str = default_system_prompt,
+    ) -> list[LLMResult]:
+        if callbacks:
+            sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)]
+            async_callbacks = [f for f in callbacks if is_coroutine_callable(f)]
+
+        human_message_prompt = {"role": "user", "content": prompt}
+
+        batch = []
+        for d in data:
+            messages = [
+                {"role": m["role"], "content": m["content"].format(**d)}
+                for m in (
+                    [{"role": "system", "content": system_prompt}, human_message_prompt]
+                    if system_prompt
+                    else [human_message_prompt]
+                )
+            ]
+            batch.append(messages)
+
+        start_clock = asyncio.get_running_loop().time()
+        chunks = await self.achat(batch)
+        batch_time = asyncio.get_running_loop().time() - start_clock
+
+        if callbacks:
+            for chunk in chunks:
+                await do_callbacks(async_callbacks, sync_callbacks, chunk.text, name)
+
+        return [
+            LLMResult(
+                model=self.name,
+                name=name,
+                prompt=messages,
+                prompt_count=chunk.prompt_tokens,
+                text=chunk.text,
+                completion_count=chunk.completion_tokens,
+                seconds_to_first_token=batch_time,
+                seconds_to_last_token=batch_time,
+            )
+            for messages, chunk in zip(batch, chunks)
+        ]
+
+    @rate_limited
+    async def achat(self, messages: list[dict[str, str]]) -> list[Chunk]:
+        try:
+            import openai
+        except ImportError as exc:
+            raise ImportError(
+                "Please install paper-qa[batch] to use OpenAIBatchLLMModel."
+            ) from exc
+
+        client = openai.AsyncOpenAI()
+
+        with tempfile.NamedTemporaryFile(suffix=".jsonl") as tmp_file:
+            tmp_filename = tmp_file.name
+            self.write_jsonl(messages, tmp_filename)
+            with open(tmp_filename, "rb") as jsonl_file:
+                file = await client.files.create(file=jsonl_file, purpose="batch")
+
+        batch = await client.batches.create(
+            input_file_id=file.id,
+            endpoint="/v1/chat/completions",
+            completion_window="24h",
+            metadata={"description": "paper-qa summary batch"},
+        )
+
+        start_clock = asyncio.get_running_loop().time()
+        while batch.status != self.status.COMPLETE:
+            batch = await client.batches.retrieve(batch.id)
+            if batch.status == self.status.FAILURE:
+                raise RuntimeError(
+                    "Batch failed.\n\nReason:\n"
+                    + "\n".join(k.message for k in batch.errors.data)
+                )
+            if batch.status == self.status.CANCEL:
+                raise RuntimeError("Batch was cancelled.")
+
+            batch_time = asyncio.get_running_loop().time() - start_clock
+            if batch_time > self.config.get("batch_summary_timelimit", 24 * 60 * 60):
+                raise RuntimeError("Batch took too long to complete.")
+
+            logger.info(
+                f"Summary batch status: {batch.status} | Time elapsed: {batch_time}"
+            )
+            await asyncio.sleep(self.config.get("batch_polling_interval", 30))
+
+        responses = await client.files.content(batch.output_file_id)
+        response_lines = responses.read().decode("utf-8").splitlines()
+        parsed_responses = [json.loads(line) for line in response_lines]
+        # The batch API does not guarantee response order, so sort by custom_id.
+        sorted_responses = sorted(parsed_responses, key=lambda x: int(x["custom_id"]))
+
+        return [
+            Chunk(
+                text=response["response"]["body"]["choices"][0]["message"]["content"],
+                prompt_tokens=response["response"]["body"]["usage"]["prompt_tokens"],
+                completion_tokens=response["response"]["body"]["usage"][
+                    "completion_tokens"
+                ],
+            )
+            for response in sorted_responses
+        ]
+
+    @rate_limited
+    async def achat_iter(self):
+        raise NotImplementedError(
+            "Async generators are not supported for batch calls; use achat instead."
+        )
+
+    def infer_llm_type(self):
+        # Sets the batch endpoint as a side effect the first time the model is used.
+        self.config["endpoint"] = "/v1/chat/completions"
+        return "chat"
+
+    def count_tokens(self, text: str) -> int:
+        return len(text) // 4
+
+    async def check_rate_limit(self, token_count: float, **kwargs) -> None:
+        if "rate_limit" in self.config:
+            await GLOBAL_LIMITER.try_acquire(
+                ("client", self.name),
+                self.config["rate_limit"].get(self.name, None),
+                weight=max(int(token_count), 1),
+                **kwargs,
+            )
+
+
+class AnthropicBatchLLMModel(LLMModel):
+    """A wrapper around the anthropic library to use the batch API."""
+
+    name: str = "claude-3-5-sonnet-20241022"
+    config: dict = Field(
+        default_factory=dict,
+        description=(
+            "Configuration dictionary for this model. Currently supported keys are"
+            " `model`, `max_tokens`, `batch_summary_timelimit`, and"
+            " `batch_polling_interval`."
+        ),
+    )
+    status: type[AnthropicBatchStatus] = Field(
+        default=AnthropicBatchStatus,
+        description="Statuses used to report the state of the batch API request.",
+    )
+
+    @rate_limited
+    async def acomplete(self):
+        raise NotImplementedError("Completion models are not supported yet.")
+
+    @rate_limited
+    async def acomplete_iter(self):
+        raise NotImplementedError("Completion models are not supported yet.")
+
+    async def _run_chat(
+        self,
+        prompt: str,
+        data: list[dict[str, str]],
+        callbacks: list[Callable] | None = None,
+        name: str | None = None,
+        system_prompt: str = default_system_prompt,
+    ) -> list[LLMResult]:
+        if callbacks:
+            sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)]
+            async_callbacks = [f for f in callbacks if is_coroutine_callable(f)]
+
+        human_message_prompt = {"role": "user", "content": prompt}
+
+        batch = []
+        for d in data:
+            messages = [
+                {"role": m["role"], "content": m["content"].format(**d)}
+                for m in (
+                    [{"role": "system", "content": system_prompt}, human_message_prompt]
+                    if system_prompt
+                    else [human_message_prompt]
+                )
+            ]
+            batch.append(messages)
+
+        start_clock = asyncio.get_running_loop().time()
+        chunks = await self.achat(batch)
+        batch_time = asyncio.get_running_loop().time() - start_clock
+
+        if callbacks:
+            for chunk in chunks:
+                await do_callbacks(async_callbacks, sync_callbacks, chunk.text, name)
+
+        return [
+            LLMResult(
+                model=self.name,
+                name=name,
+                prompt=messages,
+                prompt_count=chunk.prompt_tokens,
+                text=chunk.text,
+                completion_count=chunk.completion_tokens,
+                seconds_to_first_token=batch_time,
+                seconds_to_last_token=batch_time,
+            )
+            for messages, chunk in zip(batch, chunks)
+        ]
+
+    @rate_limited
+    async def achat(self, messages: list[dict[str, str]]) -> list[Chunk]:
+        try:
+            import anthropic
+            from anthropic.types.beta.message_create_params import (
+                MessageCreateParamsNonStreaming,
+            )
+            from anthropic.types.beta.messages.batch_create_params import Request
+        except ImportError as exc:
+            raise ImportError(
+                "Please install paper-qa[batch] to use AnthropicBatchLLMModel."
+            ) from exc
+
+        client = anthropic.Anthropic()
+
+        requests = []
+        for i, m in enumerate(messages):
+            # Anthropic takes the system prompt as a separate parameter,
+            # not as a message with role "system".
+            system = "\n".join(x["content"] for x in m if x["role"] == "system")
+            chat_messages = [x for x in m if x["role"] != "system"]
+            params = MessageCreateParamsNonStreaming(
+                model=self.config.get("model"),
+                max_tokens=self.config.get("max_tokens"),
+                messages=chat_messages,
+            )
+            if system:
+                params["system"] = system
+            requests.append(Request(custom_id=str(i), params=params))
+
+        batch = client.beta.messages.batches.create(requests=requests)
+
+        start_clock = asyncio.get_running_loop().time()
+        while batch.processing_status != self.status.COMPLETE:
+            batch = client.beta.messages.batches.retrieve(batch.id)
+            batch_time = asyncio.get_running_loop().time() - start_clock
+            if batch_time > self.config.get("batch_summary_timelimit", 24 * 60 * 60):
+                raise RuntimeError("Batch took too long to complete.")
+            logger.info(
+                f"Summary batch status: {batch.processing_status}"
+                f" | Time elapsed: {batch_time}"
+            )
+            await asyncio.sleep(self.config.get("batch_polling_interval", 30))
+
+        responses = client.beta.messages.batches.results(batch.id)
+
+        # TODO: validate the result handling below against the live batch API.
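+        # Hedged sketch for the extraction step: it assumes, per the anthropic SDK's
+        # batch results documentation, that `results()` yields entries exposing a
+        # `custom_id` and a `result` whose `message` carries `content` and `usage`.
+        # Errored or expired entries are not handled in this sketch.
+        sorted_responses = sorted(responses, key=lambda x: int(x.custom_id))
+        chunks = [
+            Chunk(
+                text=response.result.message.content[0].text,
+                prompt_tokens=response.result.message.usage.input_tokens,
+                completion_tokens=response.result.message.usage.output_tokens,
+            )
+            for response in sorted_responses
+        ]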
+        return chunks
+
+    @rate_limited
+    async def achat_iter(self):
+        raise NotImplementedError(
+            "Callbacks are not supported for batch calls; use achat instead."
+        )
+
+    def infer_llm_type(self):
+        return "chat"
+
+    def count_tokens(self, text: str) -> int:
+        return len(text) // 4
+
+    async def check_rate_limit(self, token_count: float, **kwargs) -> None:
+        if "rate_limit" in self.config:
+            await GLOBAL_LIMITER.try_acquire(
+                ("client", self.name),
+                self.config["rate_limit"].get(self.name, None),
+                weight=max(int(token_count), 1),
+                **kwargs,
+            )
+
+
 def cosine_similarity(a, b):
     norm_product = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
     return a @ b.T / norm_product
diff --git a/paperqa/settings.py b/paperqa/settings.py
index 2a0b457a..f0789d72 100644
--- a/paperqa/settings.py
+++ b/paperqa/settings.py
@@ -40,7 +40,13 @@
 except ImportError:
     HAS_LDP_INSTALLED = False
 
-from paperqa.llms import EmbeddingModel, LiteLLMModel, embedding_model_factory
+from paperqa.llms import (
+    AnthropicBatchLLMModel,
+    EmbeddingModel,
+    LiteLLMModel,
+    LLMModel,
+    OpenAIBatchLLMModel,
+    embedding_model_factory,
+)
 from paperqa.prompts import (
     CONTEXT_INNER_PROMPT,
     CONTEXT_OUTER_PROMPT,
@@ -564,6 +570,15 @@ def make_default_litellm_model_list_settings(
         ]
     }
 
+
+def make_default_openai_batch_llm_settings(llm: str, temperature: float = 0.0) -> dict:
+    return {
+        "model": llm,
+        "temperature": temperature,
+        "max_tokens": 2048,
+    }
+
+
 class Settings(BaseSettings):
     model_config = SettingsConfigDict(extra="ignore")
@@ -596,6 +611,24 @@ class Settings(BaseSettings):
             " router_kwargs key with router kwargs as values."
         ),
     )
+    use_batch_in_summary: bool = Field(
+        default=False,
+        description=(
+            "Whether to use a batch API for summarization LLM calls, sending"
+            " multiple messages in one request to the provider's batch API."
+            " This option is only available for Claude"
+            " (https://docs.anthropic.com/en/api/creating-message-batches)"
+            " and OpenAI (https://platform.openai.com/docs/guides/batch) chat models."
+        ),
+    )
+    batch_summary_timelimit: int = Field(
+        default=24 * 60 * 60,
+        description="Time limit in seconds to wait for batch summarization to complete.",
+    )
+    batch_polling_interval: int = Field(
+        default=30,
+        description="Polling interval in seconds while waiting for batch summarization.",
+    )
     embedding: str = Field(
         default="text-embedding-3-small",
         description="Default embedding model for texts",
@@ -780,6 +813,33 @@ def get_llm(self) -> LiteLLMModel:
         )
 
-    def get_summary_llm(self) -> LiteLLMModel:
+    def get_summary_llm(self) -> LLMModel:
+        if self.use_batch_in_summary:
+            config = self.summary_llm_config or make_default_openai_batch_llm_settings(
+                self.summary_llm, self.temperature
+            )
+            # The batch models read their polling settings from the config dict.
+            config.setdefault("batch_summary_timelimit", self.batch_summary_timelimit)
+            config.setdefault("batch_polling_interval", self.batch_polling_interval)
+            if self.summary_llm.startswith("claude-"):
+                return AnthropicBatchLLMModel(name=self.summary_llm, config=config)
+
+            import openai
+
+            client = openai.OpenAI()
+            openai_models = [
+                k.id
+                for k in client.models.list().data
+                if k.owned_by in ("system", "openai")
+            ]
+            if self.summary_llm in openai_models:
+                return OpenAIBatchLLMModel(name=self.summary_llm, config=config)
+
+            raise NotImplementedError(
+                "`use_batch_in_summary` is set to True, but the summary LLM is not"
+                " supported for batch processing.\nEither use a Claude or an OpenAI"
+                " chat model, or set `use_batch_in_summary` to False."
+            )
         return LiteLLMModel(
             name=self.summary_llm,
             config=self.summary_llm_config
diff --git a/pyproject.toml b/pyproject.toml
index 41e07aaf..9a30ca6c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,6 +88,10 @@ typing = [
 zotero = [
     "pyzotero",
 ]
+batch = [
+    "anthropic",
+    "openai",
+]
 
 [project.scripts]
 pqa = "paperqa.agents:main"
diff --git a/tests/test_llms.py b/tests/test_llms.py
index 69bd65c8..ec531a24 100644
--- a/tests/test_llms.py
+++ b/tests/test_llms.py
@@ -10,6 +10,8 @@
     HybridEmbeddingModel,
     LiteLLMEmbeddingModel,
     LiteLLMModel,
+    AnthropicBatchLLMModel,
+    OpenAIBatchLLMModel,
     SentenceTransformerEmbeddingModel,
     SparseEmbeddingModel,
     embedding_model_factory,
@@ -158,6 +160,114 @@ def test_pickling(self, tmp_path: pathlib.Path) -> None:
         assert llm.config == rehydrated_llm.config
         assert llm.router.deployment_names == rehydrated_llm.router.deployment_names
 
+
+class TestOpenAIBatchLLMModel:
+    @pytest.fixture(scope="class")
+    def config(self, request) -> dict[str, Any]:
+        model_name = request.param
+        return {
+            "model": model_name,
+            "temperature": 0.0,
+            "max_tokens": 64,
+            "batch_summary_timelimit": 24 * 60 * 60,
+            "batch_polling_interval": 5,
+        }
+
+    @pytest.mark.parametrize(
+        "config",
+        [
+            pytest.param("gpt-4o-mini", id="chat-model"),
+            pytest.param("gpt-3.5-turbo-instruct", id="completion-model"),
+        ],
+        indirect=True,
+    )
+    @pytest.mark.asyncio
+    async def test_run_prompt(self, config: dict[str, Any], request) -> None:
+        llm = OpenAIBatchLLMModel(name=config["model"], config=config)
+
+        outputs = []
+
+        def accum(x) -> None:
+            outputs.append(x)
+
+        async def ac(x) -> None:
+            pass
+
+        data = [{"animal": "duck"}, {"animal": "dog"}, {"animal": "cat"}]
+
+        if request.node.name == "test_run_prompt[completion-model]":
+            with pytest.raises(Exception) as e_info:
+                await llm.run_prompt(
+                    prompt="The {animal} says",
+                    data=data,
+                )
+            assert "Batch failed" in str(e_info.value)
+            assert "not supported" in str(e_info.value)
+
+        if request.node.name == "test_run_prompt[chat-model]":
+            completion = await llm.run_prompt(
+                prompt="The {animal} says",
+                data=data,
+                callbacks=[accum, ac],
+            )
+
+            assert all(completion[k].model == config["model"] for k in range(len(data)))
+            assert all(
+                completion[k].seconds_to_first_token > 0 for k in range(len(data))
+            )
+            assert all(completion[k].prompt_count > 0 for k in range(len(data)))
+            assert all(completion[k].completion_count > 0 for k in range(len(data)))
+            assert all(
+                completion[k].completion_count <= config["max_tokens"]
+                for k in range(len(data))
+            )
+            assert sum(completion[k].cost for k in range(len(data))) > 0
+            assert all(str(completion[k]) == outputs[k] for k in range(len(data)))
+
+    @pytest.mark.parametrize(
+        "config",
+        [pytest.param("gpt-4o-mini", id="chat-model")],
+        indirect=True,
+    )
+    def test_pickling(self, tmp_path: pathlib.Path, config: dict[str, Any]) -> None:
+        pickle_path = tmp_path / "llm_model.pickle"
+        llm = OpenAIBatchLLMModel(
+            name="gpt-4o-mini",
+            config=config,
+        )
+        with pickle_path.open("wb") as f:
+            pickle.dump(llm, f)
+        with pickle_path.open("rb") as f:
+            rehydrated_llm = pickle.load(f)
+        assert llm.name == rehydrated_llm.name
+        assert llm.config == rehydrated_llm.config
+
+
+class TestAnthropicBatchLLMModel:
+    @pytest.fixture(scope="class")
+    def config(self, request) -> dict[str, Any]:
+        model_name = request.param
+        return {
+            "model": model_name,
+            "temperature": 0.0,
+            "max_tokens": 64,
+            "batch_summary_timelimit": 24 * 60 * 60,
+            "batch_polling_interval": 5,
+        }
+
+    @pytest.mark.vcr
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "config",
+        [pytest.param("claude-3-haiku-20240307", id="chat-model")],
+        indirect=True,
+    )
+    async def test_run_prompt(self, config: dict[str, Any], request) -> None:
+        llm = AnthropicBatchLLMModel(name=config["model"], config=config)
+
+        data = [{"animal": "duck"}, {"animal": "dog"}, {"animal": "cat"}]
+
+        completion = await llm.run_prompt(
+            prompt="The {animal} says",
+            data=data,
+        )
+
+        assert len(completion) == len(data)
+        assert all(completion[k].model == config["model"] for k in range(len(data)))
+        assert all(completion[k].completion_count > 0 for k in range(len(data)))
+
 
 @pytest.mark.asyncio
 async def test_embedding_model_factory_sentence_transformer() -> None:
diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py
index 9dc8dcf3..f5642104 100644
--- a/tests/test_paperqa.py
+++ b/tests/test_paperqa.py
@@ -505,10 +505,17 @@ async def test_docs_lifecycle(subtests: SubTests, stub_data_dir: Path) -> None:
     assert docs.texts
     assert all(t not in docs.texts_index for t in docs.texts)
 
 
-def test_evidence(docs_fixture) -> None:
+@pytest.mark.parametrize(
+    "use_batch",
+    [
+        pytest.param(True, id="using-batch"),
+        pytest.param(False, id="not-using-batch"),
+    ],
+)
+def test_evidence(docs_fixture, use_batch: bool) -> None:
     debug_settings = Settings.from_name("debug")
+    debug_settings.use_batch_in_summary = use_batch
+    if use_batch:
+        debug_settings.summary_llm = "gpt-3.5-turbo"
     evidence = docs_fixture.get_evidence(
         PQASession(question="What does XAI stand for?"),
         settings=debug_settings,
     ).contexts
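
Usage sketch: a minimal example of how batch summarization could be enabled once this patch lands. The document path and question are placeholders, and it assumes paper-qa is installed with the `batch` extra and an OPENAI_API_KEY is available in the environment.

from paperqa import Docs, Settings

settings = Settings(
    summary_llm="gpt-4o-mini",
    use_batch_in_summary=True,  # route evidence summaries through the batch API
    batch_polling_interval=30,  # seconds between status checks
    batch_summary_timelimit=24 * 60 * 60,  # give up after 24 hours
)

docs = Docs()
docs.add("example_paper.pdf")  # placeholder document
session = docs.query("What does XAI stand for?", settings=settings)
print(session.formatted_answer)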