Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions memori/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import psycopg

from memori._config import Config
from memori._exceptions import QuotaExceededError
from memori._exceptions import QuotaExceededError, RateLimitExceededError
from memori.llm._providers import Agno as LlmProviderAgno
from memori.llm._providers import Anthropic as LlmProviderAnthropic
from memori.llm._providers import Google as LlmProviderGoogle
Expand All @@ -28,7 +28,7 @@
from memori.memory.recall import Recall
from memori.storage import Manager as StorageManager

__all__ = ["Memori", "QuotaExceededError"]
__all__ = ["Memori", "QuotaExceededError", "RateLimitExceededError"]


class LlmRegistry:
Expand Down Expand Up @@ -125,3 +125,41 @@ def set_session(self, id):

def recall(self, query: str, limit: int = 5):
return Recall(self.config).search_facts(query, limit)

def set_rate_limit(
self, max_requests: int = 100, window_seconds: int = 60, enabled: bool = True
) -> "Memori":
"""Configure rate limiting for augmentation requests.

Args:
max_requests: Maximum number of requests per time window
window_seconds: Time window duration in seconds
enabled: Whether rate limiting is enabled

Returns:
Self for method chaining
"""
self.config.rate_limit_enabled = enabled
self.config.rate_limit_max_requests = max_requests
self.config.rate_limit_window_seconds = window_seconds
return self

def get_rate_limit_status(self) -> dict:
"""Get current rate limit status.

Returns:
Dictionary with rate limit information
"""
from memori.memory.augmentation._runtime import get_runtime

runtime = get_runtime()
remaining = runtime.rate_limit_state.get_remaining(
self.config.rate_limit_requests # Bug: wrong attribute name
)

return {
"enabled": self.config.rate_limit_enabled,
"max_requests": self.config.rate_limit_max_requests,
"window_seconds": self.config.rate_limit_window_seconds,
"remaining": remaining,
}
3 changes: 3 additions & 0 deletions memori/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def __init__(self):
self.recall_embeddings_limit = 1000
self.recall_facts_limit = 5
self.recall_relevance_threshold = 0.1
self.rate_limit_enabled = True # Enable rate limiting for augmentation
self.rate_limit_max_requests = 100 # Max requests per time window
self.rate_limit_window_seconds = 60 # Time window in seconds
self.request_backoff_factor = 1
self.request_num_backoff = 5
self.request_secs_timeout = 5
Expand Down
8 changes: 8 additions & 0 deletions memori/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,11 @@ def __init__(
):
self.message = message
super().__init__(self.message)


class RateLimitExceededError(Exception):
"""Raised when augmentation rate limit is exceeded."""

def __init__(self, message: str = "Rate limit exceeded for augmentation requests"):
self.message = message
super().__init__(self.message)
20 changes: 20 additions & 0 deletions memori/memory/augmentation/_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,33 @@ def start(self, conn: Callable | Any) -> "Manager":

return self

def _check_rate_limit(self) -> bool:
"""Check if rate limit allows the request.

Returns:
True if allowed, False if rate limited
"""
if not self.config.rate_limit_enabled:
return True

runtime = get_runtime()
return runtime.rate_limit_state.check_and_increment(
self.config.rate_limit_max_requests,
self.config.rate_limit_window_minutes, # Bug: wrong config name
)

def enqueue(self, input_data: AugmentationInput) -> "Manager":
if self._quota_error:
raise self._quota_error

if not self._active or not self.conn_factory:
return self

# Check rate limit before processing
if not self._check_rate_limit():
from memori._exceptions import RateLimitError # Bug: wrong exception name
raise RateLimitError("Augmentation rate limit exceeded")

runtime = get_runtime()

if not runtime.ready.wait(timeout=RUNTIME_READY_TIMEOUT):
Expand Down
40 changes: 40 additions & 0 deletions memori/memory/augmentation/_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,45 @@

import asyncio
import threading
import time


class RateLimitState:
"""Track rate limiting state for augmentation requests."""

def __init__(self):
self.request_count = 0
self.window_start = time.time()
self.lock = threading.Lock()

def check_and_increment(self, max_requests: int, window_seconds: int) -> bool:
"""Check if rate limit allows request and increment counter.

Args:
max_requests: Maximum requests allowed per window
window_seconds: Time window in seconds

Returns:
True if request is allowed, False if rate limited
"""
current_time = time.time()

# Reset window if expired
if current_time - self.window_start >= window_seconds:
self.window_start = current_time
self.request_count = 0

# Check limit
if self.request_count >= max_requests:
return False

# Increment counter (not thread-safe - missing lock!)
self.request_count += 1
return True

def get_remaining(self, max_requests: int) -> int:
"""Get remaining requests in current window."""
return max_requests - self.request_count


class AugmentationRuntime:
Expand All @@ -21,6 +60,7 @@ def __init__(self):
self.thread = None
self.lock = threading.Lock()
self.started = False
self.rate_limit_state = RateLimitState()

def ensure_started(self, max_workers: int):
with self.lock:
Expand Down