Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions memori/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,24 @@ def set_session(self, id):

def recall(self, query: str, limit: int = 5):
    """Run a fact search for ``query`` using this instance's config.

    Args:
        query: Text to search for.
        limit: Forwarded unchanged to ``Recall.search_facts`` (defaults to 5).

    Returns:
        Whatever ``Recall(self.config).search_facts(query, limit)`` returns.
    """
    searcher = Recall(self.config)
    return searcher.search_facts(query, limit)

def set_embedding_model(self, model_name: str, dimension: int | None = None) -> "Memori":
"""Set a custom embedding model for recall.

Args:
model_name: Name of the sentence transformer model to use
dimension: Expected embedding dimension (optional, will be auto-detected)

Returns:
Self for method chaining
"""
from memori.llm._embeddings import get_model_dimension

self.config.embedding_model = model_name

if dimension is not None:
self.config.embedding_dimensions = dimension
else:
self.config.embedding_dimension = get_model_dimension(model_name)

return self
2 changes: 2 additions & 0 deletions memori/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def __init__(self):
self.recall_embeddings_limit = 1000
self.recall_facts_limit = 5
self.recall_relevance_threshold = 0.1
self.embedding_model = "all-mpnet-base-v2" # Default embedding model
self.embedding_dimension = 768 # Expected embedding dimension
self.request_backoff_factor = 1
self.request_num_backoff = 5
self.request_secs_timeout = 5
Expand Down
33 changes: 33 additions & 0 deletions memori/llm/_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,36 @@ async def embed_texts_async(
) -> list[list[float]]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, embed_texts, texts, model)


def get_model_dimension(model_name: str) -> int:
    """Look up the sentence-embedding dimension of a model.

    Args:
        model_name: Name of the sentence transformer model.

    Returns:
        The dimension reported by the model, converted to ``int``. Falls
        back to ``_DEFAULT_DIMENSION`` when the reported value is falsy or
        when loading or querying the model raises any ``Exception``.
    """
    try:
        reported = _get_model(model_name).get_sentence_embedding_dimension()
        if not reported:
            # The model gave no usable dimension (e.g. None or 0).
            return _DEFAULT_DIMENSION
        return int(reported)
    except Exception:
        # Best-effort lookup: any failure yields the default dimension.
        return _DEFAULT_DIMENSION


def validate_embedding_model(config) -> bool:
    """Check that the configured embedding model has the expected dimension.

    Args:
        config: Memori config object; its ``embedding_model`` and
            ``embedding_dimension`` attributes are read.

    Returns:
        True if the dimension returned by ``get_model_dimension`` for
        ``config.embedding_model`` equals ``config.embedding_dimension``,
        False otherwise.
    """
    # The config stores the model name under "embedding_model"; there is no
    # "embedding_model_name" attribute, so reading that raised AttributeError.
    model_name = config.embedding_model
    expected_dim = config.embedding_dimension

    # NOTE(review): get_model_dimension falls back to a default dimension on
    # failure, so a model that cannot be loaded may still compare equal when
    # the expected dimension happens to be that default — confirm acceptable.
    actual_dim = get_model_dimension(model_name)
    return actual_dim == expected_dim
2 changes: 2 additions & 0 deletions memori/memory/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def search_facts(
if limit is None:
limit = self.config.recall_facts_limit

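# NOTE(review): the configured embedding model is read into `model` below but is not passed to embed_texts() — confirm whether it should be.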
model = self.config.embedding_model
query_embedding = embed_texts(query)[0]

facts = []
Expand Down