Skip to content

Async pg module w/ connection pooling #398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,4 @@ cython_debug/
src/controlflow/_version.py
all_code.md
all_docs.md
llm_guides.md
llm_guides.md
55 changes: 55 additions & 0 deletions examples/asyncpg-memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import asyncio

import controlflow as cf
from controlflow.memory.async_memory import AsyncMemory
from controlflow.memory.providers.postgres import AsyncPostgresMemory

provider = AsyncPostgresMemory(
database_url="postgresql+psycopg://postgres:postgres@localhost:5432/database",
# embedding_dimension=1536,
# embedding_fn=OpenAIEmbeddings(),
table_name="vector_db_async",
)

# Create a memory module for user preferences
user_preferences = AsyncMemory(
key="user_preferences",
instructions="Store and retrieve user preferences.",
provider=provider,
)

# Create an agent with access to the memory
agent = cf.Agent(memories=[user_preferences])


# Create a flow to ask for the user's favorite color
@cf.flow
async def remember_pet():
return await cf.run_async(
"Ask the user for their favorite animal and store it in memory",
agents=[agent],
interactive=True,
)


# Create a flow to recall the user's favorite color
@cf.flow
async def recall_pet():
return await cf.run_async(
"What is the user's favorite animal?",
agents=[agent],
)


async def main():
print("First flow:")
await remember_pet()

print("\nSecond flow:")
result = await recall_pet()
print(result)
return result


if __name__ == "__main__":
asyncio.run(main())
2 changes: 1 addition & 1 deletion examples/pg-memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from controlflow.memory.providers.postgres import PostgresMemory

provider = PostgresMemory(
database_url="postgresql://postgres:postgres@localhost:5432/your_database",
database_url="postgresql://postgres:postgres@localhost:5432/database",
# embedding_dimension=1536,
# embedding_fn=OpenAIEmbeddings(),
table_name="vector_db",
Expand Down
1 change: 1 addition & 0 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

# functions, utilites, and decorators
from .memory import Memory
from .memory.async_memory import AsyncMemory
from .instructions import instructions
from .decorators import flow, task
from .tools import tool
Expand Down
3 changes: 2 additions & 1 deletion src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from controlflow.llm.models import get_model as get_model_from_string
from controlflow.llm.rules import LLMRules
from controlflow.memory import Memory
from controlflow.memory.async_memory import AsyncMemory
from controlflow.tools.tools import (
Tool,
as_lc_tools,
Expand Down Expand Up @@ -82,7 +83,7 @@ class Agent(ControlFlowModel, abc.ABC):
default=False,
description="If True, the agent is given tools for interacting with a human user.",
)
memories: list[Memory] = Field(
memories: list[Memory] | list[AsyncMemory] = Field(
default=[],
description="A list of memory modules for the agent to use.",
)
Expand Down
5 changes: 4 additions & 1 deletion src/controlflow/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import controlflow.utilities
import controlflow.utilities.logging
from controlflow.llm.models import BaseChatModel
from controlflow.memory.async_memory import AsyncMemoryProvider, get_memory_provider
from controlflow.memory.memory import MemoryProvider, get_memory_provider
from controlflow.utilities.general import ControlFlowModel

Expand Down Expand Up @@ -39,7 +40,9 @@ class Defaults(ControlFlowModel):
model: Optional[Any]
history: History
agent: Agent
memory_provider: Optional[Union[MemoryProvider, str]]
memory_provider: (
Optional[Union[MemoryProvider, str]] | Optional[Union[AsyncMemoryProvider, str]]
)

# add more defaults here
def __repr__(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions src/controlflow/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .memory import Memory
from .async_memory import AsyncMemory
149 changes: 149 additions & 0 deletions src/controlflow/memory/async_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import abc
import re
from typing import Dict, List, Optional, Union

from pydantic import Field, field_validator, model_validator

import controlflow
from controlflow.tools.tools import Tool
from controlflow.utilities.general import ControlFlowModel, unwrap
from controlflow.utilities.logging import get_logger

logger = get_logger("controlflow.memory")


def sanitize_memory_key(key: str) -> str:
# Remove any characters that are not alphanumeric or underscore
return re.sub(r"[^a-zA-Z0-9_]", "", key)


class AsyncMemoryProvider(ControlFlowModel, abc.ABC):
async def configure(self, memory_key: str) -> None:
"""Configure the provider for a specific memory."""
pass

@abc.abstractmethod
async def add(self, memory_key: str, content: str) -> str:
"""Create a new memory and return its ID."""
pass

@abc.abstractmethod
async def delete(self, memory_key: str, memory_id: str) -> None:
"""Delete a memory by its ID."""
pass

@abc.abstractmethod
async def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]:
"""Search for n memories using a string query."""
pass


class AsyncMemory(ControlFlowModel):
"""
A memory module is a partitioned collection of memories that are stored in a
vector database, configured by a MemoryProvider.
"""

key: str
instructions: str = Field(
description="Explain what this memory is for and how it should be used."
)
provider: AsyncMemoryProvider = Field(
default_factory=lambda: controlflow.defaults.memory_provider,
validate_default=True,
)

def __hash__(self) -> int:
return id(self)

@field_validator("provider", mode="before")
@classmethod
def validate_provider(
cls, v: Optional[Union[AsyncMemoryProvider, str]]
) -> AsyncMemoryProvider:
if isinstance(v, str):
return get_memory_provider(v)
if v is None:
raise ValueError(
unwrap(
"""
Memory modules require a MemoryProvider to configure the
underlying vector database. No provider was passed as an
argument, and no default value has been configured.

For more information on configuring a memory provider, see
the [Memory
documentation](https://controlflow.ai/patterns/memory), and
please review the [default provider
guide](https://controlflow.ai/guides/default-memory) for
information on configuring a default provider.

Please note that if you are using ControlFlow for the first
time, this error is expected because ControlFlow does not include
vector dependencies by default.
"""
)
)
return v

@field_validator("key")
@classmethod
def validate_key(cls, v: str) -> str:
sanitized = sanitize_memory_key(v)
if sanitized != v:
raise ValueError(
"Memory key must contain only alphanumeric characters and underscores"
)
return sanitized

async def _configure_provider(self):
await self.provider.configure(self.key)
return self

async def add(self, content: str) -> str:
return await self.provider.add(self.key, content)

async def delete(self, memory_id: str) -> None:
await self.provider.delete(self.key, memory_id)

async def search(self, query: str, n: int = 20) -> Dict[str, str]:
return await self.provider.search(self.key, query, n)

def get_tools(self) -> List[Tool]:
return [
Tool.from_function(
self.add,
name=f"store_memory_{self.key}",
description=f'Create a new memory in Memory: "{self.key}".',
),
Tool.from_function(
self.delete,
name=f"delete_memory_{self.key}",
description=f'Delete a memory by its ID from Memory: "{self.key}".',
),
Tool.from_function(
self.search,
name=f"search_memories_{self.key}",
description=f'Search for memories relevant to a string query in Memory: "{self.key}". Returns a dictionary of memory IDs and their contents.',
),
]


def get_memory_provider(provider: str) -> AsyncMemoryProvider:
logger.debug(f"Loading memory provider: {provider}")

# --- async postgres ---

if provider.startswith("async-postgres"):
try:
import sqlalchemy
except ImportError:
raise ImportError(
"""To use async Postgres as a memory provider, please install the `sqlalchemy, `psycopg-pool`,
`psycopg-binary`, and `psycopg` packages."""
)

import controlflow.memory.providers.postgres as postgres_providers

return postgres_providers.AsyncPostgresMemory()
raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.')
4 changes: 3 additions & 1 deletion src/controlflow/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,12 @@ def get_memory_provider(provider: str) -> MemoryProvider:
import sqlalchemy
except ImportError:
raise ImportError(
"To use Postgres as a memory provider, please install the `sqlalchemy` package."
"""To use Postgres as a memory provider, please install the `sqlalchemy, `psycopg-pool`,
`psycopg-binary`, and `psycopg` `psycopg2-binary` packages."""
)

import controlflow.memory.providers.postgres as postgres_providers

return postgres_providers.PostgresMemory()

raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.')
Loading
Loading