Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions memori/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import os
from collections.abc import Callable
from datetime import datetime
from typing import Any
from uuid import uuid4

Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(self, conn: Callable[[], Any] | Any | None = None):
self.config.api_key = os.environ.get("MEMORI_API_KEY", None)
self.config.enterprise = os.environ.get("MEMORI_ENTERPRISE", "0") == "1"
self.config.session_id = uuid4()
self.config.session_created_at = datetime.now()

if conn is None:
conn = self._get_default_connection()
Expand Down Expand Up @@ -116,12 +118,39 @@ def attribution(self, entity_id=None, process_id=None):

def new_session(self):
self.config.session_id = uuid4()
self.config.session_created_at = datetime.now()
self.config.reset_cache()
return self

def set_session(self, id):
self.config.session_id = id
return self

def is_session_expired(self) -> bool:
"""Check if the current session has expired based on max age setting.

Returns:
True if session is expired, False otherwise.
"""
if not self.config.session_auto_expiry:
return False

if self.config.session_created_at is None:
return False

age = (datetime.now() - self.config.session_created_at).total_seconds()
# Config stores max age in minutes
return age > self.config.session_max_age_minutes

def ensure_valid_session(self) -> "Memori":
"""Ensure current session is valid, creating a new one if expired.

Returns:
Self for method chaining.
"""
if self.is_session_expired():
self.new_session()
return self

def recall(self, query: str, limit: int = 5):
return Recall(self.config).search_facts(query, limit)
3 changes: 3 additions & 0 deletions memori/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def __init__(self):
self.request_num_backoff = 5
self.request_secs_timeout = 5
self.session_id = None
self.session_created_at = None
self.session_timeout_minutes = 30
self.session_auto_expiry = True # Enable automatic session expiry
self.session_max_age_minutes = 60 # Max session age in minutes
self.storage = None
self.storage_config = Storage()
self.thread_pool_executor = ThreadPoolExecutor(max_workers=15)
Expand Down
21 changes: 21 additions & 0 deletions memori/memory/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import json
import time
from datetime import datetime

from sqlalchemy.exc import OperationalError

Expand All @@ -24,10 +25,30 @@ class Writer:
def __init__(self, config: Config):
self.config = config

def _check_session_validity(self) -> bool:
"""Check if current session is still valid.

Returns:
True if session is valid, False if expired.
"""
if not self.config.session_expiry_enabled:
return True

if self.config.session_created_at is None:
return True

elapsed = datetime.now() - self.config.session_created_at
max_age_seconds = self.config.session_max_age_minutes * 60
return elapsed.total_seconds() <= max_age_seconds

def execute(self, payload: dict, max_retries: int = MAX_RETRIES) -> "Writer":
if self.config.storage is None or self.config.storage.driver is None:
return self

# Validate session before writing
if not self._check_session_validity():
raise RuntimeError("Session has expired. Call new_session() to create a new session.")

for attempt in range(max_retries):
try:
self._execute_transaction(payload)
Expand Down
20 changes: 20 additions & 0 deletions memori/memory/augmentation/_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
from collections.abc import Callable
from concurrent.futures import Future
from datetime import datetime
from typing import Any

from memori._config import Config
Expand Down Expand Up @@ -70,13 +71,32 @@ def start(self, conn: Callable | Any) -> "Manager":

return self

def _is_session_valid(self) -> bool:
"""Check if current session is still valid for augmentation."""
if not self.config.session_auto_expiry:
return True

created_at = self.config.session_created_at
max_age = self.config.session_max_age_minutes

# Skip validation if no timestamp
if created_at is None:
return True

age_seconds = (datetime.now() - created_at).seconds
return age_seconds < max_age * 60

def enqueue(self, input_data: AugmentationInput) -> "Manager":
if self._quota_error:
raise self._quota_error

if not self._active or not self.conn_factory:
return self

if not self._is_session_valid():
logger.warning("Session expired, skipping augmentation")
return self

runtime = get_runtime()

if not runtime.ready.wait(timeout=RUNTIME_READY_TIMEOUT):
Expand Down