Skip to content

Commit

Permalink
fix: getting context window sizes of models without prefixes (#994)
Browse files Browse the repository at this point in the history
* fix: getting context window sizes of models without prefixes

* feat: limit split counts to 1
  • Loading branch information
elisalimli authored May 13, 2024
1 parent 0a2806e commit 4bf763b
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions libs/superagent/app/memory/buffer_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,23 @@
from app.memory.message import BaseMessage

DEFAULT_TOKEN_LIMIT_RATIO = 0.75
DEFAULT_TOKEN_LIMIT = 3000
DEFAULT_TOKEN_LIMIT = 3072


def get_context_window(model: str) -> int:
    """Return the maximum input-token count for *model*.

    Looks the model up in litellm's ``model_cost`` table. Some models are
    referenced without their provider prefix but point at the same model
    (e.g. ``claude-3-haiku-20240307`` vs ``anthropic/claude-3-haiku-20240307``),
    so on a miss we retry with the provider prefix stripped.

    Args:
        model: Model identifier, optionally prefixed with ``provider/``.

    Returns:
        The model's ``max_input_tokens``, or ``DEFAULT_TOKEN_LIMIT`` when the
        model is unknown under both names.
    """
    max_input_tokens = model_cost.get(model, {}).get("max_input_tokens")

    if not max_input_tokens:
        # Split at most once so model names containing further slashes
        # (e.g. "org/family/model") keep their remainder intact.
        model_parts = model.split("/", 1)
        if len(model_parts) > 1:
            model_without_prefix = model_parts[1]
            max_input_tokens = model_cost.get(model_without_prefix, {}).get(
                "max_input_tokens"
            )

    # Bug fix: the original only applied the DEFAULT_TOKEN_LIMIT fallback
    # inside the prefix-split branch, so an unknown model WITHOUT a "/"
    # returned None and callers crashed on `None * DEFAULT_TOKEN_LIMIT_RATIO`.
    # Fall back unconditionally instead.
    return max_input_tokens or DEFAULT_TOKEN_LIMIT


class BufferMemory(BaseMemory):
Expand All @@ -21,8 +37,9 @@ def __init__(
self.memory_store = memory_store
self.tokenizer_fn = tokenizer_fn
self.model = model
context_window = model_cost.get(self.model, {}).get("max_input_tokens")
self.context_window = max_tokens or context_window * DEFAULT_TOKEN_LIMIT_RATIO
self.context_window = (
max_tokens or get_context_window(model=model) * DEFAULT_TOKEN_LIMIT_RATIO
)

def add_message(self, message: BaseMessage) -> None:
    """Append *message* to the underlying memory store (delegates directly)."""
    self.memory_store.add_message(message)
Expand Down

0 comments on commit 4bf763b

Please sign in to comment.