Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Latest LiteLLM for development and using pytest-recording #25

Merged
merged 11 commits into from
Sep 11, 2024
Merged
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ repos:
args:
- --word-list=.secrets.allowlist
- --exclude-files=.secrets.baseline$
exclude: tests/cassettes
- repo: https://github.com/jumanjihouse/pre-commit-hooks
rev: 3.0.0
hooks:
Expand All @@ -62,7 +63,7 @@ repos:
additional_dependencies:
- "validate-pyproject-schema-store[all]>=2024.08.19" # For Ruff renaming RUF025 to C420
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.4.6
rev: 0.4.9
hooks:
- id: uv-lock
- repo: https://github.com/pre-commit/mirrors-mypy
Expand All @@ -73,7 +74,7 @@ repos:
- fastapi>=0.109 # Match pyproject.toml
- fhaviary>=0.6 # Match pyproject.toml
- httpx
- litellm>=1.40.9,<=1.40.12 # Match pyproject.toml
- litellm>=1.42.1 # Match pyproject.toml
- numpy
- pydantic~=2.0 # Match pyproject.toml
- tenacity
Expand Down
2 changes: 2 additions & 0 deletions .secrets.allowlist
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
authorization
x-api-key
2 changes: 0 additions & 2 deletions ldp/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
LLMModel,
LLMResult,
MultipleCompletionLLMModel,
process_llm_config,
sum_logprobs,
validate_json_completion,
)
Expand Down Expand Up @@ -35,7 +34,6 @@
"append_to_sys",
"prepend_sys",
"prepend_sys_and_append_sys",
"process_llm_config",
"sum_logprobs",
"validate_json_completion",
]
30 changes: 5 additions & 25 deletions ldp/llms/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,6 @@ def get_supported_openai_params(self) -> list[str] | None:
return litellm.get_supported_openai_params(self.model)


def process_llm_config(llm_config: dict) -> dict:
jamesbraza marked this conversation as resolved.
Show resolved Hide resolved
"""Remove model_type and try to set max_tokens."""
result = llm_config.copy()
result.pop("model_type", None)

if result.get("max_tokens", -1) == -1: # Either max_tokens is missing or it's -1
model = llm_config["model"]
# these are estimates - should probably do something better in the future.
if model.startswith("gpt-4") or (
model.startswith("gpt-3.5") and "0125" in model
):
result["max_tokens"] = 4000
elif "rrr" not in model:
result["max_tokens"] = 2500

return result


def sum_logprobs(choice: litellm.utils.Choices) -> float | None:
"""Calculate the sum of the log probabilities of an LLM completion (a Choices object).

Expand Down Expand Up @@ -154,22 +136,20 @@ async def achat(
self, messages: Iterable[Message], **kwargs
) -> litellm.ModelResponse:
return await litellm.acompletion(
messages=[m.model_dump(exclude_none=True, by_alias=True) for m in messages],
**(process_llm_config(self.config) | kwargs),
messages=[m.model_dump(by_alias=True) for m in messages],
**(self.config | kwargs),
)

async def achat_iter(self, messages: Iterable[Message], **kwargs) -> AsyncGenerator:
return cast(
AsyncGenerator,
await litellm.acompletion(
messages=[
m.model_dump(exclude_none=True, by_alias=True) for m in messages
],
**(process_llm_config(self.config) | kwargs),
messages=[m.model_dump(by_alias=True) for m in messages],
stream=True,
stream_options={
"include_usage": True, # Included to get prompt token counts
},
**(self.config | kwargs),
),
)

Expand Down Expand Up @@ -225,7 +205,7 @@ async def call( # noqa: C901, PLR0915
chat_kwargs["response_format"] = {"type": "json_object"}

# add static configuration to kwargs
chat_kwargs = process_llm_config(self.config) | chat_kwargs
chat_kwargs = self.config | chat_kwargs
n = chat_kwargs.get("n", 1) # number of completions
if n < 1:
raise ValueError("Number of completions (n) must be >= 1.")
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ tests-root = "tests"
check-filenames = true
check-hidden = true
ignore-words-list = "astroid,ser"
skip = "tests/cassettes/*"

[tool.mypy]
# Type-checks the interior of functions without type annotations.
Expand Down Expand Up @@ -199,6 +200,8 @@ disable = [
"too-many-return-statements", # Rely on ruff PLR0911 for this
"too-many-statements", # Rely on ruff PLR0915 for this
"ungrouped-imports", # Rely on ruff I001 for this
"unidiomatic-typecheck", # Rely on ruff E721 for this
"unreachable", # Rely on mypy unreachable for this
"unspecified-encoding", # Don't care to enforce this
"unsubscriptable-object", # Buggy, SEE: https://github.com/PyCQA/pylint/issues/3637
"unsupported-membership-test", # Buggy, SEE: https://github.com/pylint-dev/pylint/issues/3045
Expand Down Expand Up @@ -400,13 +403,14 @@ dev-dependencies = [
"fhaviary[xml]",
"ipython>=8", # Pin to keep recent
"ldp[monitor,nn,server,typing,visualization]",
"litellm>=1.40.9,<=1.40.12", # Pin lower for get_supported_openai_params not requiring custom LLM, upper for https://github.com/BerriAI/litellm/issues/4032
"litellm>=1.42.1", # Pin lower for UnsupportedParamsError fix
"mypy>=1.8", # Pin for mutable-override
"pre-commit~=3.4", # Pin to keep recent
"pydantic~=2.9", # Pydantic 2.9 changed JSON schema exports 'allOf', so ensure tests match
"pylint-pydantic",
"pylint>=3.2", # Pin to keep recent
"pytest-asyncio",
"pytest-recording",
"pytest-rerunfailures",
"pytest-subtests",
"pytest-sugar",
Expand Down
5 changes: 5 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pathlib
from enum import StrEnum


Expand All @@ -6,3 +7,7 @@ class CILLMModelNames(StrEnum):

ANTHROPIC = "claude-3-haiku-20240307" # Cheap and not Anthropic's cutting edge
OPENAI = "gpt-4o-mini-2024-07-18" # Cheap and not OpenAI's cutting edge


TESTS_DIR = pathlib.Path(__file__).parent
CASSETTES_DIR = TESTS_DIR / "cassettes"
Loading