-
Notifications
You must be signed in to change notification settings - Fork 85
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
344 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import controlflow as cf | ||
from controlflow.memory.memory import Memory | ||
from controlflow.memory.providers.postgres import PostgresMemory | ||
|
||
provider = PostgresMemory( | ||
database_url="postgresql://postgres:postgres@localhost:5432/your_database", | ||
# embedding_dimension=1536, | ||
# embedding_fn=OpenAIEmbeddings(), | ||
table_name="vector_db", | ||
) | ||
# Create a memory module for user preferences | ||
user_preferences = cf.Memory( | ||
key="user_preferences", | ||
instructions="Store and retrieve user preferences.", | ||
provider=provider, | ||
) | ||
|
||
# Create an agent with access to the memory | ||
agent = cf.Agent(memories=[user_preferences]) | ||
|
||
|
||
# Create a flow to ask for the user's favorite color | ||
@cf.flow | ||
def remember_color(): | ||
return cf.run( | ||
"Ask the user for their favorite color and store it in memory", | ||
agents=[agent], | ||
interactive=True, | ||
) | ||
|
||
|
||
# Create a flow to recall the user's favorite color | ||
@cf.flow | ||
def recall_color(): | ||
return cf.run( | ||
"What is the user's favorite color?", | ||
agents=[agent], | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
print("First flow:") | ||
remember_color() | ||
|
||
print("\nSecond flow:") | ||
result = recall_color() | ||
print(result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
import uuid | ||
from typing import Callable, Dict, Optional | ||
|
||
import sqlalchemy | ||
from pgvector.sqlalchemy import Vector | ||
from pydantic import Field | ||
from sqlalchemy import Column, String, select, text | ||
from sqlalchemy.dialects.postgresql import ARRAY | ||
from sqlalchemy.exc import ProgrammingError | ||
from sqlalchemy.orm import Session, declarative_base, sessionmaker | ||
from sqlalchemy_utils import create_database, database_exists | ||
|
||
import controlflow | ||
from controlflow.memory.memory import MemoryProvider | ||
|
||
try: | ||
# For embeddings, we can use langchain_openai or any other library: | ||
from langchain_openai import OpenAIEmbeddings | ||
except ImportError: | ||
raise ImportError( | ||
"To use an embedding function similar to LanceDB's default, " | ||
"please install lancedb with: pip install lancedb" | ||
) | ||
|
||
# SQLAlchemy base class for declarative models | ||
Base = declarative_base() | ||
|
||
|
||
class SQLMemoryTable(Base): | ||
""" | ||
A simple declarative model that represents a memory record. | ||
We’ll dynamically set the __tablename__ at runtime. | ||
""" | ||
|
||
__abstract__ = True | ||
id = Column(String, primary_key=True) | ||
text = Column(String) | ||
# Use pgvector for storing embeddings in a Postgres Vector column | ||
# vector = Column(Vector(dim=1536)) # Adjust dimension to match your embedding model | ||
|
||
|
||
class PostgresMemory(MemoryProvider): | ||
""" | ||
A ControlFlow MemoryProvider that stores text + embeddings in PostgreSQL | ||
using SQLAlchemy and pg_vector. Each Memory module gets its own table. | ||
""" | ||
|
||
# Default database URL. You can point this to your actual Postgres instance. | ||
# Requires the pgvector extension installed and the sqlalchemy-pgvector package. | ||
database_url: str = Field( | ||
default="postgresql://user:password@localhost:5432/your_database", | ||
description="SQLAlchemy-compatible database URL to a Postgres instance with pgvector.", | ||
) | ||
table_name: str = Field( | ||
"memory_{key}", | ||
description=""" | ||
Name of the table to store this memory partition. "{key}" will be replaced | ||
by the memory’s key attribute. | ||
""", | ||
) | ||
|
||
embedding_dimension: int = Field( | ||
default=1536, | ||
description="Dimension of the embedding vectors. Match your model's output.", | ||
) | ||
|
||
embedding_fn: Callable = Field( | ||
default_factory=lambda: OpenAIEmbeddings( | ||
model="text-embedding-ada-002", | ||
), | ||
description="A function that turns a string into a vector.", | ||
) | ||
|
||
# Internal: keep a cached Session maker | ||
_SessionLocal: Optional[sessionmaker] = None | ||
|
||
# This dict will map "table_name" -> "model class" | ||
_table_class_cache: Dict[str, Base] = {} | ||
|
||
def configure(self, memory_key: str) -> None: | ||
""" | ||
Configure a SQLAlchemy session and ensure the table for this | ||
memory partition is created if it does not already exist. | ||
""" | ||
engine = sqlalchemy.create_engine(self.database_url) | ||
|
||
# 2) If DB doesn't exist, create it! | ||
if not database_exists(engine.url): | ||
create_database(engine.url) | ||
|
||
with engine.connect() as conn: | ||
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) | ||
conn.commit() | ||
|
||
self._SessionLocal = sessionmaker(bind=engine) | ||
|
||
# Dynamically create a specialized table model for this memory_key | ||
table_name = self.table_name.format(key=memory_key) | ||
|
||
# 1) Check if table already in metadata | ||
if table_name not in Base.metadata.tables: | ||
# 2) Create the dynamic class + table | ||
memory_model = type( | ||
f"SQLMemoryTable_{memory_key}", | ||
(SQLMemoryTable,), | ||
{ | ||
"__tablename__": table_name, | ||
"vector": Column(Vector(dim=self.embedding_dimension)), | ||
}, | ||
) | ||
|
||
try: | ||
Base.metadata.create_all(engine, tables=[memory_model.__table__]) | ||
# Store it in the cache | ||
self._table_class_cache[table_name] = memory_model | ||
except ProgrammingError as e: | ||
raise RuntimeError(f"Failed to create table {table_name}: {e}") | ||
|
||
def _get_session(self) -> Session: | ||
if not self._SessionLocal: | ||
raise RuntimeError( | ||
"Session is not initialized. Make sure to call configure() first." | ||
) | ||
return self._SessionLocal() | ||
|
||
def _get_table(self, memory_key: str) -> Base: | ||
""" | ||
Return a dynamically generated declarative model class | ||
mapped to the memory_{key} table. Each memory partition | ||
has a separate table. | ||
""" | ||
table_name = self.table_name.format(key=memory_key) | ||
|
||
# Return the cached class if already built | ||
if table_name in self._table_class_cache: | ||
return self._table_class_cache[table_name] | ||
|
||
# If for some reason it's not there, create it now (or raise error): | ||
memory_model = type( | ||
f"SQLMemoryTable_{memory_key}", | ||
(SQLMemoryTable,), | ||
{ | ||
"__tablename__": table_name, | ||
"vector": Column(Vector(dim=self.embedding_dimension)), | ||
}, | ||
) | ||
self._table_class_cache[table_name] = memory_model | ||
return memory_model | ||
|
||
def add(self, memory_key: str, content: str) -> str: | ||
""" | ||
Insert a new memory record into the Postgres table, | ||
generating an embedding and storing it in a vector column. | ||
Returns the memory’s ID (uuid). | ||
""" | ||
memory_id = str(uuid.uuid4()) | ||
model_cls = self._get_table(memory_key) | ||
|
||
# Generate an embedding for the content | ||
embedding = self.embedding_fn.embed_query(content) | ||
|
||
with self._get_session() as session: | ||
record = model_cls(id=memory_id, text=content, vector=embedding) | ||
session.add(record) | ||
session.commit() | ||
|
||
return memory_id | ||
|
||
def delete(self, memory_key: str, memory_id: str) -> None: | ||
""" | ||
Delete a memory record by its UUID. | ||
""" | ||
model_cls = self._get_table(memory_key) | ||
|
||
with self._get_session() as session: | ||
session.query(model_cls).filter(model_cls.id == memory_id).delete() | ||
session.commit() | ||
|
||
def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]: | ||
""" | ||
Uses pgvector’s approximate nearest neighbor search with the `<->` operator to find | ||
the top N matching records for the embedded query. Returns a dict of {id: text}. | ||
""" | ||
model_cls = self._get_table(memory_key) | ||
# Generate embedding for the query | ||
query_embedding = self.embedding_fn.embed_query(query) | ||
embedding_col = model_cls.vector | ||
|
||
with self._get_session() as session: | ||
results = session.execute( | ||
select(model_cls.id, model_cls.text) | ||
.order_by(embedding_col.l2_distance(query_embedding)) | ||
.limit(n) | ||
).all() | ||
|
||
return {row.id: row.text for row in results} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters