From 9990ff382d23dbbfd95deb999397dbd5609614cb Mon Sep 17 00:00:00 2001
From: Rishabh <134101578+GitHoobar@users.noreply.github.com>
Date: Wed, 10 Dec 2025 21:30:21 +0530
Subject: [PATCH] feat: add rate limiting for augmentation requests

---
 memori/__init__.py                     | 42 ++++++++++++++++++++++++--
 memori/_config.py                      |  3 ++
 memori/_exceptions.py                  |  8 +++++
 memori/memory/augmentation/_manager.py | 20 ++++++++++++
 memori/memory/augmentation/_runtime.py | 40 ++++++++++++++++++++++++
 5 files changed, 111 insertions(+), 2 deletions(-)

diff --git a/memori/__init__.py b/memori/__init__.py
index 59d95e88..685d6f76 100644
--- a/memori/__init__.py
+++ b/memori/__init__.py
@@ -16,7 +16,7 @@
 import psycopg
 
 from memori._config import Config
-from memori._exceptions import QuotaExceededError
+from memori._exceptions import QuotaExceededError, RateLimitExceededError
 from memori.llm._providers import Agno as LlmProviderAgno
 from memori.llm._providers import Anthropic as LlmProviderAnthropic
 from memori.llm._providers import Google as LlmProviderGoogle
@@ -28,7 +28,7 @@
 from memori.memory.recall import Recall
 from memori.storage import Manager as StorageManager
 
-__all__ = ["Memori", "QuotaExceededError"]
+__all__ = ["Memori", "QuotaExceededError", "RateLimitExceededError"]
 
 
 class LlmRegistry:
@@ -125,3 +125,41 @@ def set_session(self, id):
 
     def recall(self, query: str, limit: int = 5):
         return Recall(self.config).search_facts(query, limit)
+
+    def set_rate_limit(
+        self, max_requests: int = 100, window_seconds: int = 60, enabled: bool = True
+    ) -> "Memori":
+        """Configure rate limiting for augmentation requests.
+
+        Args:
+            max_requests: Maximum number of requests per time window
+            window_seconds: Time window duration in seconds
+            enabled: Whether rate limiting is enabled
+
+        Returns:
+            Self for method chaining
+        """
+        self.config.rate_limit_enabled = enabled
+        self.config.rate_limit_max_requests = max_requests
+        self.config.rate_limit_window_seconds = window_seconds
+        return self
+
+    def get_rate_limit_status(self) -> dict:
+        """Get current rate limit status.
+
+        Returns:
+            Dictionary with keys: enabled, max_requests, window_seconds, remaining
+        """
+        from memori.memory.augmentation._runtime import get_runtime
+
+        runtime = get_runtime()
+        remaining = runtime.rate_limit_state.get_remaining(
+            self.config.rate_limit_max_requests
+        )
+
+        return {
+            "enabled": self.config.rate_limit_enabled,
+            "max_requests": self.config.rate_limit_max_requests,
+            "window_seconds": self.config.rate_limit_window_seconds,
+            "remaining": remaining,
+        }
diff --git a/memori/_config.py b/memori/_config.py
index ea8dcfcf..b09ce21f 100644
--- a/memori/_config.py
+++ b/memori/_config.py
@@ -41,6 +41,9 @@ def __init__(self):
         self.recall_embeddings_limit = 1000
         self.recall_facts_limit = 5
         self.recall_relevance_threshold = 0.1
+        self.rate_limit_enabled = True  # Enable rate limiting for augmentation
+        self.rate_limit_max_requests = 100  # Max requests per time window
+        self.rate_limit_window_seconds = 60  # Time window in seconds
         self.request_backoff_factor = 1
         self.request_num_backoff = 5
         self.request_secs_timeout = 5
diff --git a/memori/_exceptions.py b/memori/_exceptions.py
index d7389041..e4682b9e 100644
--- a/memori/_exceptions.py
+++ b/memori/_exceptions.py
@@ -19,3 +19,11 @@ def __init__(
     ):
         self.message = message
         super().__init__(self.message)
+
+
+class RateLimitExceededError(Exception):
+    """Raised when augmentation rate limit is exceeded."""
+
+    def __init__(self, message: str = "Rate limit exceeded for augmentation requests"):
+        self.message = message
+        super().__init__(self.message)
diff --git a/memori/memory/augmentation/_manager.py b/memori/memory/augmentation/_manager.py
index 4a3e56fc..670b3f7f 100644
--- a/memori/memory/augmentation/_manager.py
+++ b/memori/memory/augmentation/_manager.py
@@ -70,6 +70,21 @@ def start(self, conn: Callable | Any) -> "Manager":
 
         return self
 
+    def _check_rate_limit(self) -> bool:
+        """Check if rate limit allows the request, consuming one slot if so.
+
+        Returns:
+            True if allowed, False if rate limited
+        """
+        if not self.config.rate_limit_enabled:
+            return True
+
+        runtime = get_runtime()
+        return runtime.rate_limit_state.check_and_increment(
+            self.config.rate_limit_max_requests,
+            self.config.rate_limit_window_seconds,
+        )
+
     def enqueue(self, input_data: AugmentationInput) -> "Manager":
         if self._quota_error:
             raise self._quota_error
@@ -77,6 +92,11 @@ def enqueue(self, input_data: AugmentationInput) -> "Manager":
         if not self._active or not self.conn_factory:
             return self
 
+        # Check rate limit before processing
+        if not self._check_rate_limit():
+            from memori._exceptions import RateLimitExceededError
+            raise RateLimitExceededError("Augmentation rate limit exceeded")
+
         runtime = get_runtime()
 
         if not runtime.ready.wait(timeout=RUNTIME_READY_TIMEOUT):
diff --git a/memori/memory/augmentation/_runtime.py b/memori/memory/augmentation/_runtime.py
index 3d0c5605..723449f8 100644
--- a/memori/memory/augmentation/_runtime.py
+++ b/memori/memory/augmentation/_runtime.py
@@ -10,6 +10,45 @@
 
 import asyncio
 import threading
+import time
+
+
+class RateLimitState:
+    """Track fixed-window rate limiting state for augmentation requests."""
+
+    def __init__(self):
+        self.request_count = 0
+        self.window_start = time.monotonic()
+        self.lock = threading.Lock()
+
+    def check_and_increment(self, max_requests: int, window_seconds: int) -> bool:
+        """Check if rate limit allows request and increment counter.
+
+        Args:
+            max_requests: Maximum requests allowed per window
+            window_seconds: Time window in seconds
+
+        Returns:
+            True if request is allowed, False if rate limited
+        """
+        # Hold the lock across reset + check + increment so concurrent
+        # callers cannot both pass the limit check.
+        with self.lock:
+            now = time.monotonic()
+            if now - self.window_start >= window_seconds:
+                # Window expired: start a fresh one.
+                self.window_start = now
+                self.request_count = 0
+
+            if self.request_count >= max_requests:
+                return False
+
+            self.request_count += 1
+            return True
+
+    def get_remaining(self, max_requests: int) -> int:
+        """Get remaining requests in current window (never negative)."""
+        return max(0, max_requests - self.request_count)
 
 
 class AugmentationRuntime:
@@ -21,6 +60,7 @@ def __init__(self):
         self.thread = None
         self.lock = threading.Lock()
         self.started = False
+        self.rate_limit_state = RateLimitState()
 
     def ensure_started(self, max_workers: int):
         with self.lock: