Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
patcher9 authored Jan 9, 2025
2 parents 909bc21 + 59d6c90 commit d6df627
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 13 deletions.
47 changes: 47 additions & 0 deletions examples/pg-memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import controlflow as cf
from controlflow.memory.memory import Memory
from controlflow.memory.providers.postgres import PostgresMemory

# Memory provider that persists memories in a Postgres database.
provider = PostgresMemory(
    database_url="postgresql://postgres:postgres@localhost:5432/your_database",
    # Optional overrides (uncomment to customise the embedding setup):
    # embedding_dimension=1536,
    # embedding_fn=OpenAIEmbeddings(),
    # No "{key}" placeholder here, so the formatted table name is always "vector_db".
    table_name="vector_db",
)
# Create a memory module for user preferences, stored through the provider above
user_preferences = cf.Memory(
    key="user_preferences",
    instructions="Store and retrieve user preferences.",
    provider=provider,
)

# Create an agent with access to the memory
agent = cf.Agent(memories=[user_preferences])


# Flow 1: have the agent ask for the user's favorite color and save it to memory
@cf.flow
def remember_color():
    # interactive=True is passed because the objective requires asking the user
    objective = "Ask the user for their favorite color and store it in memory"
    outcome = cf.run(objective, agents=[agent], interactive=True)
    return outcome


# Flow 2: have the same agent answer from what it stored earlier
@cf.flow
def recall_color():
    question = "What is the user's favorite color?"
    answer = cf.run(question, agents=[agent])
    return answer


if __name__ == "__main__":
    # Run the storing flow first, then the recalling flow, printing the
    # value returned by the second one.
    print("First flow:")
    remember_color()

    print("\nSecond flow:")
    print(recall_color())
12 changes: 12 additions & 0 deletions src/controlflow/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,16 @@ def get_memory_provider(provider: str) -> MemoryProvider:

return lance_providers.LanceMemory()

# --- Postgres ---
elif provider.startswith("postgres"):
try:
import sqlalchemy
except ImportError:
raise ImportError(
"To use Postgres as a memory provider, please install the `sqlalchemy` package."
)

import controlflow.memory.providers.postgres as postgres_providers

return postgres_providers.PostgresMemory()
raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.')
197 changes: 197 additions & 0 deletions src/controlflow/memory/providers/postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import uuid
from typing import Callable, Dict, Optional

import sqlalchemy
from pgvector.sqlalchemy import Vector
from pydantic import Field
from sqlalchemy import Column, String, select, text
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from sqlalchemy_utils import create_database, database_exists

import controlflow
from controlflow.memory.memory import MemoryProvider

try:
    # PostgresMemory builds its default embedding_fn from OpenAIEmbeddings,
    # so the module cannot be used without langchain_openai.
    from langchain_openai import OpenAIEmbeddings
except ImportError as exc:
    # Name the package that is actually missing and keep the original cause.
    raise ImportError(
        "To use the default OpenAI embedding function for PostgresMemory, "
        "please install langchain-openai with: pip install langchain-openai"
    ) from exc

# SQLAlchemy base class for declarative models
Base = declarative_base()


class SQLMemoryTable(Base):
    """
    Abstract declarative model describing a single memory record.

    The class is marked ``__abstract__`` so it is never mapped to a table
    itself; concrete subclasses are created at runtime with their own
    ``__tablename__``.
    """

    __abstract__ = True
    # Record identifier (primary key).
    id = Column(String, primary_key=True)
    # Raw text content of the memory.
    text = Column(String)
    # The pgvector embedding column is not declared here: PostgresMemory adds a
    # ``vector`` column to each dynamically created subclass so that its
    # dimension can follow the provider's ``embedding_dimension`` setting.


class PostgresMemory(MemoryProvider):
    """
    A ControlFlow MemoryProvider that stores text + embeddings in PostgreSQL
    using SQLAlchemy and pgvector. Each Memory module gets its own table.

    ``configure(memory_key)`` must be called before ``add``, ``delete`` or
    ``search``: it creates the session factory and makes sure the backing
    table exists.
    """

    # Default database URL. You can point this to your actual Postgres instance.
    # The server must be able to provide the pgvector extension, and the
    # `pgvector` Python package must be installed (for pgvector.sqlalchemy).
    database_url: str = Field(
        default="postgresql://user:password@localhost:5432/your_database",
        description="SQLAlchemy-compatible database URL to a Postgres instance with pgvector.",
    )
    table_name: str = Field(
        "memory_{key}",
        description="""
            Name of the table to store this memory partition. "{key}" will be replaced
            by the memory’s key attribute.
            """,
    )

    embedding_dimension: int = Field(
        default=1536,
        description="Dimension of the embedding vectors. Match your model's output.",
    )

    # The object stored here is used through its ``embed_query(str)`` method.
    embedding_fn: Callable = Field(
        default_factory=lambda: OpenAIEmbeddings(
            model="text-embedding-ada-002",
        ),
        description="A function that turns a string into a vector.",
    )

    # Internal: keep a cached Session maker (set by configure()).
    _SessionLocal: Optional[sessionmaker] = None

    # This dict will map "table_name" -> "model class"
    _table_class_cache: Dict[str, Base] = {}

    def configure(self, memory_key: str) -> None:
        """
        Configure a SQLAlchemy session and ensure the table for this
        memory partition is created if it does not already exist.

        Steps: create the database if it is missing, enable the ``vector``
        extension, build the session factory, then create the table for
        ``memory_key``.

        Raises:
            RuntimeError: if the table could not be created.
        """
        engine = sqlalchemy.create_engine(self.database_url)

        # Create the database itself when the URL points at a missing one.
        if not database_exists(engine.url):
            create_database(engine.url)

        with engine.connect() as conn:
            conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            conn.commit()

        self._SessionLocal = sessionmaker(bind=engine)

        table_name = self.table_name.format(key=memory_key)
        memory_model = self._get_table(memory_key)

        # Always run create_all: it checks for the table first, so it is a
        # no-op when the table exists. Skipping it just because the table is
        # known to Base.metadata would leave the table uncreated when this
        # provider targets a different database than an earlier one.
        try:
            Base.metadata.create_all(engine, tables=[memory_model.__table__])
        except ProgrammingError as e:
            raise RuntimeError(f"Failed to create table {table_name}: {e}") from e

    def _get_session(self) -> Session:
        """
        Return a new Session from the configured session factory.

        Raises:
            RuntimeError: if configure() has not been called yet.
        """
        if not self._SessionLocal:
            raise RuntimeError(
                "Session is not initialized. Make sure to call configure() first."
            )
        return self._SessionLocal()

    def _get_table(self, memory_key: str) -> Base:
        """
        Return the declarative model class mapped to this memory key's
        table. Each memory partition has a separate table.

        Lookup order: this provider's cache, then any class already mapped
        to the same table name on ``Base``, and only then a newly built
        class. Re-declaring a table that ``Base.metadata`` already knows
        raises inside SQLAlchemy, so an existing mapping must be reused.
        NOTE: a reused class keeps the vector dimension it was created with.
        """
        table_name = self.table_name.format(key=memory_key)

        # Return the cached class if already built
        cached = self._table_class_cache.get(table_name)
        if cached is not None:
            return cached

        memory_model = None
        if table_name in Base.metadata.tables:
            # Another provider instance (or an earlier call) mapped this table.
            for mapper in Base.registry.mappers:
                if getattr(mapper.local_table, "name", None) == table_name:
                    memory_model = mapper.class_
                    break

        if memory_model is None:
            # Build a concrete subclass of the abstract model for this table.
            memory_model = type(
                f"SQLMemoryTable_{memory_key}",
                (SQLMemoryTable,),
                {
                    "__tablename__": table_name,
                    "vector": Column(Vector(dim=self.embedding_dimension)),
                },
            )

        self._table_class_cache[table_name] = memory_model
        return memory_model

    def add(self, memory_key: str, content: str) -> str:
        """
        Insert a new memory record into the Postgres table,
        generating an embedding and storing it in a vector column.
        Returns the memory’s ID (uuid).
        """
        memory_id = str(uuid.uuid4())
        model_cls = self._get_table(memory_key)

        # Generate an embedding for the content
        embedding = self.embedding_fn.embed_query(content)

        with self._get_session() as session:
            record = model_cls(id=memory_id, text=content, vector=embedding)
            session.add(record)
            session.commit()

        return memory_id

    def delete(self, memory_key: str, memory_id: str) -> None:
        """
        Delete a memory record by its UUID.
        """
        model_cls = self._get_table(memory_key)

        with self._get_session() as session:
            session.query(model_cls).filter(model_cls.id == memory_id).delete()
            session.commit()

    def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]:
        """
        Embed ``query`` and return the ``n`` stored records closest to it by
        L2 distance (pgvector's `<->` operator). Returns a dict of {id: text}.

        This code creates no vector index, so unless one has been added to
        the table separately the search is exact rather than approximate.
        """
        model_cls = self._get_table(memory_key)
        # Generate embedding for the query
        query_embedding = self.embedding_fn.embed_query(query)
        embedding_col = model_cls.vector

        with self._get_session() as session:
            results = session.execute(
                select(model_cls.id, model_cls.text)
                .order_by(embedding_col.l2_distance(query_embedding))
                .limit(n)
            ).all()

        return {row.id: row.text for row in results}
24 changes: 16 additions & 8 deletions src/controlflow/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,27 @@ def from_function(
):
name = name or fn.__name__
description = description or fn.__doc__ or ""

signature = inspect.signature(fn)
try:
parameters = TypeAdapter(fn).json_schema()
except PydanticSchemaGenerationError:
raise ValueError(
f'Could not generate a schema for tool "{name}". '
"Tool functions must have type hints that are compatible with Pydantic."
)

# If parameters are provided in kwargs, use those instead of generating them
if "parameters" in kwargs:
parameters = kwargs.pop("parameters") # Custom parameters are respected
else:
try:
parameters = TypeAdapter(fn).json_schema()
except PydanticSchemaGenerationError:
raise ValueError(
f'Could not generate a schema for tool "{name}". '
"Tool functions must have type hints that are compatible with Pydantic."
)

# load parameter descriptions
if include_param_descriptions:
for param in signature.parameters.values():
# ensure we only try to add descriptions for parameters that exist in the schema
if param.name not in parameters.get("properties", {}):
continue

# handle Annotated type hints
if typing.get_origin(param.annotation) is Annotated:
param_description = " ".join(
Expand Down
77 changes: 72 additions & 5 deletions tests/tools/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
import controlflow
from controlflow.agents.agent import Agent
from controlflow.llm.messages import ToolMessage
from controlflow.tools.tools import (
Tool,
handle_tool_call,
tool,
)
from controlflow.tools.tools import Tool, handle_tool_call, tool


@pytest.mark.parametrize("style", ["decorator", "class"])
Expand Down Expand Up @@ -170,6 +166,77 @@ def add(a: int, b: float) -> float:
elif style == "decorator":
tool(add)

def test_custom_parameters(self, style):
    """A caller-supplied ``parameters`` schema replaces the generated one."""

    def add(a: int, b: float):
        return a + b

    custom_params = {
        "type": "object",
        "properties": {
            "x": {"type": "number", "description": "Custom parameter"},
            "y": {"type": "string"},
        },
        "required": ["x"],
    }

    # Pick the construction path under test, then build the tool once.
    if style == "class":
        build = Tool.from_function
    elif style == "decorator":
        build = tool
    tool_obj = build(add, parameters=custom_params)

    assert tool_obj.parameters == custom_params

    # The function's own argument names must not appear in the schema.
    properties = tool_obj.parameters["properties"]
    assert "a" not in properties
    assert "b" not in properties
    assert properties["x"]["description"] == "Custom parameter"

def test_custom_parameters_with_annotations(self, style):
    """Annotated descriptions are applied to custom schemas when names match."""

    def process(x: Annotated[float, "The x value"], y: str):
        return x

    custom_params = {
        "type": "object",
        "properties": {"x": {"type": "number"}, "y": {"type": "string"}},
        "required": ["x"],
    }

    # Pick the construction path under test, then build the tool once.
    if style == "class":
        build = Tool.from_function
    elif style == "decorator":
        build = tool
    tool_obj = build(process, parameters=custom_params)

    # Only the annotated parameter gains a description.
    properties = tool_obj.parameters["properties"]
    assert properties["x"]["description"] == "The x value"
    assert "description" not in properties["y"]

def test_custom_parameters_ignore_descriptions(self, style):
    """With include_param_descriptions=False a custom schema gets no descriptions."""

    def process(x: Annotated[float, "The x value"], y: str):
        return x

    custom_params = {
        "type": "object",
        "properties": {"x": {"type": "number"}, "y": {"type": "string"}},
        "required": ["x"],
    }

    # Pick the construction path under test, then build the tool once.
    if style == "class":
        build = Tool.from_function
    elif style == "decorator":
        build = tool
    tool_obj = build(
        process, parameters=custom_params, include_param_descriptions=False
    )

    # Neither property should have been given a description.
    properties = tool_obj.parameters["properties"]
    assert "description" not in properties["x"]
    assert "description" not in properties["y"]
assert "description" not in tool_obj.parameters["properties"]["y"]


class TestToolFunctions:
def test_non_serializable_return_value(self):
Expand Down

0 comments on commit d6df627

Please sign in to comment.