Merge pull request #8 from zzstoatzz/more-concurrent
async utils
zzstoatzz authored Nov 4, 2024
2 parents e2b8c0b + 949b51c commit e4e5b7a
Showing 7 changed files with 120 additions and 63 deletions.
23 changes: 18 additions & 5 deletions examples/refresh_tpuf/refresh_namespace.py
@@ -1,7 +1,7 @@
# /// script
# dependencies = [
# "prefect",
# "raggy[tpuf]@git+https://github.com/zzstoatzz/raggy@improve-ingest",
# "raggy[tpuf]@git+https://github.com/zzstoatzz/raggy",
# "trafilatura",
# ]
# ///
@@ -100,11 +100,12 @@ async def refresh_tpuf_namespace(
namespace_loaders: list[Loader],
reset: bool = False,
batch_size: int = 100,
max_concurrent: int = 8,
):
"""Flow updating vectorstore with info from the Prefect community."""
documents: list[Document] = [
doc
for future in run_loader.map(quote(namespace_loaders)) # type: ignore
for future in run_loader.map(quote(namespace_loaders))
for doc in future.result()
]

@@ -115,20 +116,32 @@ async def refresh_tpuf_namespace(
await task(tpuf.reset)()
print(f"RESETTING: Deleted all documents from tpuf ns {namespace!r}.")

await task(tpuf.upsert_batched)(documents=documents, batch_size=batch_size)
await task(tpuf.upsert_batched)(
documents=documents, batch_size=batch_size, max_concurrent=max_concurrent
)

print(f"Updated tpuf ns {namespace!r} with {len(documents)} documents.")


@flow(name="Refresh Namespaces", log_prints=True)
async def refresh_tpuf(reset: bool = False, batch_size: int = 100):
async def refresh_tpuf(
reset: bool = False, batch_size: int = 100, test_mode: bool = False
):
for namespace, namespace_loaders in loaders.items():
if test_mode:
namespace = f"TESTING-{namespace}"
await refresh_tpuf_namespace(
namespace, namespace_loaders, reset=reset, batch_size=batch_size
)


if __name__ == "__main__":
import asyncio
import sys

if len(sys.argv) > 1:
test_mode = sys.argv[1] != "prod"
else:
test_mode = True

asyncio.run(refresh_tpuf(reset=True))
asyncio.run(refresh_tpuf(reset=True, test_mode=test_mode))
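The updated entrypoint defaults to test mode unless "prod" is passed as the first CLI argument, and test mode prefixes each namespace with "TESTING-" so a local run never overwrites production data. A minimal sketch of calling the flow directly, with the import path assumed for illustration:

```python
# A hedged sketch: invoking the flow without the CLI wrapper. The import path is
# hypothetical; in the repo this lives at examples/refresh_tpuf/refresh_namespace.py.
import asyncio

from refresh_namespace import refresh_tpuf

# test_mode=True writes to "TESTING-"-prefixed namespaces; pass test_mode=False
# (the CLI's "prod" argument) to update the real ones.
asyncio.run(refresh_tpuf(reset=True, batch_size=100, test_mode=True))
```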
11 changes: 9 additions & 2 deletions src/raggy/documents.py
@@ -4,7 +4,7 @@
from typing import Annotated

from jinja2 import Environment, Template
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from raggy.utilities.ids import generate_prefixed_uuid
from raggy.utilities.text import count_tokens, extract_keywords, hash_text, split_text
@@ -32,11 +32,18 @@ class Document(BaseModel):
text: str = Field(..., description="Document text content.")

embedding: list[float] | None = Field(default=None)
metadata: DocumentMetadata = Field(default_factory=DocumentMetadata)
metadata: DocumentMetadata | dict = Field(default_factory=DocumentMetadata)

tokens: int | None = Field(default=None)
keywords: list[str] = Field(default_factory=list)

@field_validator("metadata", mode="before")
@classmethod
def ensure_metadata(cls, v):
if isinstance(v, dict):
return DocumentMetadata(**v)
return v

@model_validator(mode="after")
def ensure_tokens(self):
if self.tokens is None:
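With the new `field_validator`, a plain dict passed as `metadata` is coerced into `DocumentMetadata` before validation, so callers (like the web loader below) can pass `metadata=dict(...)` directly. A minimal sketch; the `title` and `link` field names are assumptions, since the `DocumentMetadata` schema is not shown in this diff:

```python
# A minimal sketch of the new coercion; "title" and "link" are assumed to be valid
# DocumentMetadata fields.
from raggy.documents import Document, DocumentMetadata

doc = Document(
    text="turbopuffer is a vector database",
    metadata={"title": "Example", "link": "https://example.com"},  # plain dict accepted
)
assert isinstance(doc.metadata, DocumentMetadata)  # coerced by ensure_metadata
```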
60 changes: 31 additions & 29 deletions src/raggy/loaders/web.py
@@ -1,4 +1,3 @@
import asyncio
import re
from typing import Callable, Self
from urllib.parse import urljoin
@@ -11,12 +10,11 @@
import raggy
from raggy.documents import Document, document_to_excerpts
from raggy.loaders.base import Loader, MultiLoader
from raggy.utilities.asyncutils import run_concurrent_tasks
from raggy.utilities.collections import batched

user_agent = UserAgent()

URL_CONCURRENCY = asyncio.Semaphore(30)


def ensure_http(url):
if not url.startswith(("http://", "https://")):
@@ -30,7 +28,6 @@ async def sitemap_search(sitemap_url) -> list[str]:
response.raise_for_status()

soup = BeautifulSoup(response.content, "xml")

return [loc.text for loc in soup.find_all("loc")]


@@ -51,28 +48,33 @@ class URLLoader(WebLoader):
"""

source_type: str = "url"

urls: list[str] = Field(default_factory=list)

async def load(self) -> list[Document]:
headers = await self.get_headers()
async with AsyncClient(
headers=headers, timeout=30, follow_redirects=True
) as client:
documents = await asyncio.gather(
*[self.load_url(u, client) for u in self.urls], return_exceptions=True

async def load_url_task(url):
try:
return await self.load_url(url, client)
except Exception as e:
self.logger.error(e)
return None

documents = await run_concurrent_tasks(
[lambda u=url: load_url_task(u) for url in self.urls], max_concurrent=30
)

final_documents = []
for d in documents:
if isinstance(d, Exception):
self.logger.error(d)
elif d is not None:
final_documents.extend(await document_to_excerpts(d)) # type: ignore
if d is not None:
final_documents.extend(await document_to_excerpts(d))
return final_documents

async def load_url(self, url, client) -> Document | None:
async with URL_CONCURRENCY:
response = await client.get(url, follow_redirects=True)
response = await client.get(url, follow_redirects=True)

if not response.status_code == 200:
self.logger.warning_style(
@@ -84,16 +86,17 @@ async def load_url(self, url, client) -> Document | None:
meta_refresh = soup.find(
"meta", attrs={"http-equiv": re.compile(r"refresh", re.I)}
)
if meta_refresh:
refresh_content = meta_refresh.get("content")
redirect_url_match = re.search(r"url=([\S]+)", refresh_content, re.I)
if redirect_url_match:
redirect_url = redirect_url_match.group(1)
# join base url with relative url
redirect_url = urljoin(str(response.url), redirect_url)
# Now ensure the URL includes the protocol
redirect_url = ensure_http(redirect_url)
response = await client.get(redirect_url, follow_redirects=True)
if meta_refresh and isinstance(meta_refresh, BeautifulSoup.Tag):
content = meta_refresh.get("content", "")
if isinstance(content, str):
redirect_url_match = re.search(r"url=([\S]+)", content, re.I)
if redirect_url_match:
redirect_url = redirect_url_match.group(1)
# join base url with relative url
redirect_url = urljoin(str(response.url), redirect_url)
# Now ensure the URL includes the protocol
redirect_url = ensure_http(redirect_url)
response = await client.get(redirect_url, follow_redirects=True)

document = await self.response_to_document(response)
if document:
@@ -103,6 +106,7 @@ async def response_to_document(self, response: Response) -> Document:
return document

async def response_to_document(self, response: Response) -> Document:
"""Convert an HTTP response to a Document."""
return Document(
text=await self.get_document_text(response),
metadata=dict(
@@ -128,17 +132,15 @@ async def get_document_text(self, response: Response) -> str:

class SitemapLoader(URLLoader):
"""A loader that loads URLs from a sitemap.
Attributes:
include: A list of strings or regular expressions. Only URLs that match one of these will be included.
exclude: A list of strings or regular expressions. URLs that match one of these will be excluded.
url_loader: The loader to use for loading the URLs.
Examples:
Load all URLs from a sitemap:
```python
from raggy.loaders.web import SitemapLoader
loader = SitemapLoader(urls=["https://askmarvin.ai/sitemap.xml"])
loader = SitemapLoader(urls=["https://controlflow.ai/sitemap.xml"])
documents = await loader.load()
print(documents)
```
@@ -147,11 +149,12 @@ class SitemapLoader(URLLoader):
include: list[str | re.Pattern] = Field(default_factory=list)
exclude: list[str | re.Pattern] = Field(default_factory=list)
url_loader: URLLoader = Field(default_factory=HTMLLoader)

url_processor: Callable[[str], str] = lambda x: x # noqa: E731

async def _get_loader(self: Self) -> MultiLoader:
urls = await asyncio.gather(*[self.load_sitemap(url) for url in self.urls])
urls = await run_concurrent_tasks(
[lambda u=url: self.load_sitemap(u) for url in self.urls], max_concurrent=5
)
return MultiLoader(
loaders=[
type(self.url_loader)(urls=url_batch, headers=await self.get_headers()) # type: ignore
@@ -169,7 +172,6 @@ async def load_sitemap(self, url: str) -> list[str]:
def is_included(url: str) -> bool:
if not self.include:
return True

return any(
(isinstance(i, str) and i in url)
or (isinstance(i, re.Pattern) and re.search(i, url))
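The `asyncio.gather` plus module-level semaphore pattern is replaced by `run_concurrent_tasks`, and per-URL failures are now caught inside `load_url_task`, logged, and skipped rather than filtered out of the gathered results afterwards. A minimal sketch of the resulting behavior; the URLs are placeholders and network access is assumed:

```python
# A hedged sketch: unreachable URLs are logged and dropped, so only pages that load
# successfully become documents. At most 30 URLs are fetched at a time.
import asyncio

from raggy.loaders.web import URLLoader

async def main():
    loader = URLLoader(
        urls=[
            "https://example.com",
            "https://definitely-not-a-real-host.invalid",  # logged and skipped
        ]
    )
    documents = await loader.load()
    print(f"loaded {len(documents)} excerpts")

asyncio.run(main())
```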
4 changes: 3 additions & 1 deletion src/raggy/settings.py
@@ -68,7 +68,9 @@ class Settings(BaseSettings):
extra="allow",
validate_assignment=True,
)

max_concurrent_tasks: int = Field(
default=50, gt=3, description="The maximum number of concurrent tasks to run."
)
html_parser: Callable[[str], str] = default_html_parser

log_level: str = Field(
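The new `max_concurrent_tasks` setting (default 50, constrained to be greater than 3) becomes the default concurrency cap used by `run_concurrent_tasks`. Because the settings model declares `validate_assignment=True`, the constraint is also enforced when the value is changed at runtime; a minimal sketch:

```python
# A minimal sketch of the gt=3 constraint; `settings` is the shared instance imported
# elsewhere in this diff as `from raggy.settings import settings`.
from pydantic import ValidationError

from raggy.settings import settings

settings.max_concurrent_tasks = 10  # fine

try:
    settings.max_concurrent_tasks = 2  # rejected: must be greater than 3
except ValidationError as e:
    print(e)
```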
43 changes: 28 additions & 15 deletions src/raggy/utilities/asyncutils.py
@@ -2,6 +2,9 @@
from typing import Any, Callable, TypeVar

import anyio
from anyio import create_task_group, to_thread

from raggy import settings

T = TypeVar("T")

@@ -21,23 +24,33 @@ async def run_sync_in_worker_thread(
__fn: Callable[..., T], *args: Any, **kwargs: Any
) -> T:
"""Runs a sync function in a new worker thread so that the main thread's event loop
is not blocked
is not blocked."""
call = partial(__fn, *args, **kwargs)
return await to_thread.run_sync(
call, cancellable=True, limiter=get_thread_limiter()
)

Unlike the anyio function, this defaults to a cancellable thread and does not allow
passing arguments to the anyio function so users can pass kwargs to their function.

Note that cancellation of threads will not result in interrupted computation, the
thread may continue running — the outcome will just be ignored.
async def run_concurrent_tasks(
tasks: list[Callable],
max_concurrent: int = settings.max_concurrent_tasks,
):
"""Run multiple tasks concurrently with a limit on concurrent execution.
Args:
__fn: The function to run in a worker thread
*args: Positional arguments to pass to the function
**kwargs: Keyword arguments to pass to the function
Returns:
The result of the function
tasks: List of async callables to execute
max_concurrent: Maximum number of tasks to run concurrently
"""
call = partial(__fn, *args, **kwargs)
return await anyio.to_thread.run_sync(
call, cancellable=True, limiter=get_thread_limiter()
)
semaphore = anyio.Semaphore(max_concurrent)
results = []

async def _run_task(task: Callable):
async with semaphore:
result = await task()
results.append(result)

async with create_task_group() as tg:
for task in tasks:
tg.start_soon(_run_task, task)

return results
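`run_concurrent_tasks` takes a list of zero-argument async callables and runs them inside an anyio task group, with an `anyio.Semaphore` bounding how many run at once. Results are appended as tasks complete, so the returned list follows completion order rather than input order. A minimal usage sketch; the `fetch` coroutine is made up for illustration:

```python
# A hedged usage sketch. Default-argument lambdas bind the loop variable eagerly,
# matching the pattern used in web.py and tpuf.py in this commit.
import asyncio

from raggy.utilities.asyncutils import run_concurrent_tasks

async def fetch(i: int) -> int:
    await asyncio.sleep(0.01)
    return i * 2

async def main():
    results = await run_concurrent_tasks(
        [lambda i=i: fetch(i) for i in range(10)],
        max_concurrent=3,  # at most three fetches in flight at once
    )
    print(sorted(results))  # completion order is not guaranteed, so sort for display

asyncio.run(main())
```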
10 changes: 8 additions & 2 deletions src/raggy/vectorstores/chroma.py
@@ -18,6 +18,7 @@
)

from raggy.documents import Document as RaggyDocument
from raggy.documents import DocumentMetadata
from raggy.settings import settings
from raggy.utilities.asyncutils import run_sync_in_worker_thread
from raggy.utilities.embeddings import create_openai_embeddings
@@ -90,7 +91,10 @@ async def add(self, documents: list[RaggyDocument]) -> list[ChromaDocument]:
ids = [doc.id for doc in unique_documents]
texts = [doc.text for doc in unique_documents]
metadatas = [
doc.metadata.model_dump(exclude_none=True) for doc in unique_documents
doc.metadata.model_dump(exclude_none=True)
if isinstance(doc.metadata, DocumentMetadata)
else None
for doc in unique_documents
]

embeddings = await create_openai_embeddings(texts)
@@ -145,7 +149,9 @@ async def upsert(self, documents: list[RaggyDocument]) -> list[ChromaDocument]:
ids=[document.id for document in documents],
documents=[document.text for document in documents],
metadatas=[
document.metadata.model_dump(exclude_none=True) or None
document.metadata.model_dump(exclude_none=True)
if isinstance(document.metadata, DocumentMetadata)
else None
for document in documents
],
embeddings=await create_openai_embeddings(
32 changes: 23 additions & 9 deletions src/raggy/vectorstores/tpuf.py
@@ -11,7 +11,7 @@
from turbopuffer.vectors import VectorResult

from raggy.documents import Document
from raggy.utilities.asyncutils import run_sync_in_worker_thread
from raggy.utilities.asyncutils import run_concurrent_tasks, run_sync_in_worker_thread
from raggy.utilities.embeddings import create_openai_embeddings
from raggy.utilities.text import slice_tokens
from raggy.vectorstores.base import Vectorstore
@@ -27,8 +27,10 @@ class TurboPufferSettings(BaseSettings):
extra="ignore",
)

api_key: SecretStr
default_namespace: str = "raggy"
api_key: SecretStr = Field(
default=..., description="The API key for the TurboPuffer instance."
)
default_namespace: str = Field(default="raggy")

@model_validator(mode="after")
def set_api_key(self):
@@ -151,20 +153,32 @@ async def upsert_batched(
self,
documents: Iterable[Document],
batch_size: int = 100,
max_concurrent: int = 25,
):
"""Upsert documents in batches to avoid memory issues with large datasets.
"""Upsert documents in batches concurrently.
Args:
documents: Iterable of documents to upsert
batch_size: Maximum number of documents to upsert in each batch
batch_size: Maximum number of documents per batch
max_concurrent: Maximum number of concurrent upsert operations
"""
document_list = list(documents)
total_docs = len(document_list)
batches = [
document_list[i : i + batch_size]
for i in range(0, len(document_list), batch_size)
]

for i in range(0, total_docs, batch_size):
batch = document_list[i : i + batch_size]
async def process_batch(batch: list[Document], batch_num: int):
await self.upsert(documents=batch)
print(f"Upserted batch {i//batch_size + 1} ({len(batch)} documents)")
print(
f"Upserted batch {batch_num + 1}/{len(batches)} ({len(batch)} documents)"
)

tasks = [
lambda b=batch, i=i: process_batch(b, i) for i, batch in enumerate(batches)
]

await run_concurrent_tasks(tasks, max_concurrent=max_concurrent)


async def query_namespace(
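`upsert_batched` now slices the documents into batches up front and hands them to `run_concurrent_tasks`, so up to `max_concurrent` batches are upserted at once instead of sequentially. A minimal sketch of the call; the `TurboPuffer` class name and its `namespace` argument are assumptions (the constructor is not shown in this diff), and `TURBOPUFFER_API_KEY` must be available to `TurboPufferSettings`:

```python
# A hedged sketch: 500 documents with batch_size=100 yield 5 batches, and at most 8
# upserts run concurrently. TurboPuffer(namespace=...) is an assumed constructor.
import asyncio

from raggy.documents import Document
from raggy.vectorstores.tpuf import TurboPuffer

async def main():
    docs = [Document(text=f"excerpt {i}") for i in range(500)]
    tpuf = TurboPuffer(namespace="raggy")  # "raggy" mirrors default_namespace above
    await tpuf.upsert_batched(documents=docs, batch_size=100, max_concurrent=8)

asyncio.run(main())
```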
