Embeddings index checkpointing, closes #695
davidmezzetti committed Feb 19, 2025
1 parent 8a3ed27 commit a1a7d34
Showing 8 changed files with 168 additions and 18 deletions.
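
In practice, the new parameter is passed directly to index and upsert. A minimal usage sketch — the model path, data and checkpoint directory below are illustrative, not part of the commit:

from txtai import Embeddings

# Illustrative corpus
data = [
    "US tops 5 million confirmed virus cases",
    "Maine man wins $1M from $25 lottery ticket",
]

embeddings = Embeddings(path="sentence-transformers/nli-mpnet-base-v2", content=True)

# Vectors are spooled to the checkpoint directory as they are generated.
# If this run is interrupted, re-running with the same checkpoint directory
# resumes from the already-computed batches instead of re-vectorizing everything.
embeddings.index(data, checkpoint="/tmp/embeddings.checkpoint")
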
12 changes: 7 additions & 5 deletions src/python/txtai/embeddings/base.py
@@ -102,20 +102,21 @@ def score(self, documents):
if self.isweighted():
self.scoring.index(Stream(self)(documents))

def index(self, documents, reindex=False):
def index(self, documents, reindex=False, checkpoint=None):
"""
Builds an embeddings index. This method overwrites an existing index.
Args:
documents: iterable of (id, data, tags), (id, data) or data
reindex: if this is a reindex operation in which case database creation is skipped, defaults to False
checkpoint: optional checkpoint directory, enables indexing restart
"""

# Initialize index
self.initindex(reindex)

# Create transform and stream
transform = Transform(self, Action.REINDEX if reindex else Action.INDEX)
transform = Transform(self, Action.REINDEX if reindex else Action.INDEX, checkpoint)
stream = Stream(self, Action.REINDEX if reindex else Action.INDEX)

with tempfile.NamedTemporaryFile(mode="wb", suffix=".npy") as buffer:
@@ -153,23 +154,24 @@ def index(self, documents, reindex=False):
if self.graph:
self.graph.index(Search(self, indexonly=True), Ids(self), self.batchsimilarity)

def upsert(self, documents):
def upsert(self, documents, checkpoint=None):
"""
Runs an embeddings upsert operation. If the index exists, new data is
appended to the index, existing data is updated. If the index doesn't exist,
this method runs a standard index operation.
Args:
documents: iterable of (id, data, tags), (id, data) or data
checkpoint: optional checkpoint directory, enables indexing restart
"""

# Run standard insert if index doesn't exist or it has no records
if not self.count():
self.index(documents)
self.index(documents, checkpoint=checkpoint)
return

# Create transform and stream
transform = Transform(self, Action.UPSERT)
transform = Transform(self, Action.UPSERT, checkpoint=checkpoint)
stream = Stream(self, Action.UPSERT)

with tempfile.NamedTemporaryFile(mode="wb", suffix=".npy") as buffer:
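
A corresponding upsert sketch — upsert falls back to index(documents, checkpoint=checkpoint) when the index is empty and otherwise hands the checkpoint to the Transform. Setup values are again illustrative:

from txtai import Embeddings

embeddings = Embeddings(path="sentence-transformers/nli-mpnet-base-v2", content=True)
embeddings.index(["first document"], checkpoint="/tmp/embeddings.checkpoint")

# Appends/updates records; an interrupted upsert re-run with the same
# checkpoint directory reuses the spooled vectors
embeddings.upsert([(100, "second document", None)], checkpoint="/tmp/embeddings.checkpoint")
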
11 changes: 7 additions & 4 deletions src/python/txtai/embeddings/index/transform.py
@@ -14,17 +14,19 @@ class Transform:
Executes a transform. Processes a stream of documents, loads batches into enabled data stores and vectorizes documents.
"""

def __init__(self, embeddings, action):
def __init__(self, embeddings, action, checkpoint=None):
"""
Creates a new transform.
Args:
embeddings: embeddings instance
action: index action
checkpoint: optional checkpoint directory, enables indexing restart
"""

self.embeddings = embeddings
self.action = action
self.checkpoint = checkpoint

# Alias embeddings attributes
self.config = embeddings.config
@@ -91,7 +93,7 @@ def vectors(self, documents, buffer):
"""

# Consume stream and transform documents to vectors
ids, dimensions, batches, stream = self.model.index(self.stream(documents), self.batch)
ids, dimensions, batches, stream = self.model.index(self.stream(documents), self.batch, self.checkpoint)

# Check that embeddings are available and load as a memmap
embeddings = None
@@ -108,8 +110,9 @@
embeddings[x : x + batch.shape[0]] = batch
x += batch.shape[0]

# Remove temporary file
os.remove(stream)
# Remove temporary file (if checkpointing is disabled)
if not self.checkpoint:
os.remove(stream)

return (ids, dimensions, embeddings)

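
The spool format underneath is plain NumPy: batches are appended to a single file with repeated np.save calls, then read back one np.load at a time and copied into a contiguous buffer, which is roughly what the loop above does with its memmap. A standalone sketch with made-up file names and shapes:

import numpy as np

# Append three batches to one spool file with repeated np.save calls
with open("/tmp/spool.npy", "wb") as output:
    for _ in range(3):
        np.save(output, np.random.rand(500, 384).astype(np.float32))

# Read the batches back sequentially and copy into a contiguous memmap buffer
buffer = np.memmap("/tmp/buffer.npy", dtype=np.float32, shape=(1500, 384), mode="w+")
x = 0
with open("/tmp/spool.npy", "rb") as stream:
    for _ in range(3):
        batch = np.load(stream)
        buffer[x : x + batch.shape[0]] = batch
        x += batch.shape[0]
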
1 change: 1 addition & 0 deletions src/python/txtai/vectors/__init__.py
@@ -9,5 +9,6 @@
from .litellm import LiteLLM
from .llama import LlamaCpp
from .m2v import Model2Vec
from .recovery import Recovery
from .sbert import STVectors
from .words import WordVectors
62 changes: 55 additions & 7 deletions src/python/txtai/vectors/base.py
@@ -2,12 +2,17 @@
Vectors module
"""

import json
import os
import tempfile
import uuid

import numpy as np

from ..pipeline import Tokenizer

from .recovery import Recovery


class Vectors:
"""
@@ -100,39 +105,44 @@ def load(self, path):

return model

def index(self, documents, batchsize=500):
def index(self, documents, batchsize=500, checkpoint=None):
"""
Converts a list of documents to a temporary file with embeddings arrays. Returns a tuple of document ids,
number of dimensions and temporary file with embeddings.
Args:
documents: list of (id, data, tags)
batchsize: index batch size
checkpoint: optional checkpoint directory, enables indexing restart
Returns:
(ids, dimensions, stream)
"""

ids, dimensions, batches, stream = [], None, 0, None

# Generate recovery config if checkpoint is set
vectorsid = self.vectorsid() if checkpoint else None
recovery = Recovery(checkpoint, vectorsid) if checkpoint else None

# Convert all documents to embedding arrays, stream embeddings to disk to control memory usage
with tempfile.NamedTemporaryFile(mode="wb", suffix=".npy", delete=False) as output:
with self.spool(checkpoint, vectorsid) as output:
stream = output.name
batch = []
for document in documents:
batch.append(document)

if len(batch) == batchsize:
# Convert batch to embeddings
uids, dimensions = self.batch(batch, output)
uids, dimensions = self.batch(batch, output, recovery)
ids.extend(uids)
batches += 1

batch = []

# Final batch
if batch:
uids, dimensions = self.batch(batch, output)
uids, dimensions = self.batch(batch, output, recovery)
ids.extend(uids)
batches += 1

@@ -180,13 +190,50 @@ def batchtransform(self, documents, category=None):

return self.vectorize(documents)

def batch(self, documents, output):
def vectorsid(self):
"""
Generates vectors uid for this vectors instance.
Returns:
vectors uid
"""

# Select config options that determine uniqueness
select = ["path", "method", "tokenizer", "maxlength", "tokenize", "instructions", "dimensionality", "quantize"]
config = {k: v for k, v in self.config.items() if k in select}
config.update(self.config.get("vectors", {}))

# Generate a deterministic UUID
return str(uuid.uuid5(uuid.NAMESPACE_DNS, json.dumps(config, sort_keys=True)))

def spool(self, checkpoint, vectorsid):
"""
Opens a spool file for queuing generated vectors.
Args:
checkpoint: optional checkpoint directory, enables indexing restart
vectorsid: vectors uid for current configuration
Returns:
vectors spool file
"""

# Spool to vectors checkpoint file
if checkpoint:
os.makedirs(checkpoint, exist_ok=True)
return open(f"{checkpoint}/{vectorsid}", "wb")

# Spool to temporary file
return tempfile.NamedTemporaryFile(mode="wb", suffix=".npy", delete=False)

def batch(self, documents, output, recovery):
"""
Builds a batch of embeddings.
Args:
documents: list of documents used to build embeddings
output: output temp file to store embeddings
recovery: optional recovery instance
Returns:
(ids, dimensions) list of ids and number of dimensions in embeddings
@@ -197,8 +244,9 @@ def batch(self, documents, output):
documents = [self.prepare(data, "data") for _, data, _ in documents]
dimensions = None

# Build embeddings
embeddings = self.vectorize(documents)
# Attempt to read embeddings from a recovery file
embeddings = recovery() if recovery else None
embeddings = self.vectorize(documents) if embeddings is None else embeddings
if embeddings is not None:
dimensions = embeddings.shape[1]
np.save(output, embeddings)
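
Because the checkpoint file name comes from uuid5 over the vector settings, a spool is only reused when the configuration that produced it matches. A standalone sketch of that property — the config keys here are a subset chosen for illustration:

import json
import uuid

def vectorsid(config):
    # Keep only settings that change the generated vectors
    select = ["path", "method", "maxlength", "dimensionality", "quantize"]
    config = {k: v for k, v in config.items() if k in select}

    # uuid5 is deterministic: same settings always map to the same file name
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, json.dumps(config, sort_keys=True)))

a = vectorsid({"path": "model-a", "maxlength": 512})
b = vectorsid({"path": "model-a", "maxlength": 512})
c = vectorsid({"path": "model-b", "maxlength": 512})
assert a == b and a != c
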
58 changes: 58 additions & 0 deletions src/python/txtai/vectors/recovery.py
@@ -0,0 +1,58 @@
"""
Recovery module
"""

import os
import shutil

import numpy as np


class Recovery:
"""
Vector embeddings recovery. This class handles streaming embeddings from a vector checkpoint file.
"""

def __init__(self, checkpoint, vectorsid):
"""
Creates a Recovery instance.
Args:
checkpoint: checkpoint directory
vectorsid: vectors uid for current configuration
"""

self.spool, self.path = None, None

# Get unique file id
path = f"{checkpoint}/{vectorsid}"
if os.path.exists(path):
# Generate recovery path
self.path = f"{checkpoint}/recovery"

# Copy current checkpoint to recovery
shutil.copyfile(path, self.path)

# Open file and return
# pylint: disable=R1732
self.spool = open(self.path, "rb")

def __call__(self):
"""
Reads and returns the next batch of embeddings.
Returns:
batch of embeddings
"""

try:
return np.load(self.spool) if self.spool else None
except EOFError:
# End of spool file, cleanup
self.spool.close()
os.remove(self.path)

# Clear parameters
self.spool, self.path = None, None

return None
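
Given the class above, a short sketch of how a checkpoint is drained on restart: each call returns the next spooled batch until the file is exhausted, then cleans up and returns None. Paths and the vectors id are illustrative:

import os

import numpy as np

from txtai.vectors import Recovery

checkpoint, vectorsid = "/tmp/checkpoint", "vectors-id"

# Simulate a prior interrupted run with two spooled batches
os.makedirs(checkpoint, exist_ok=True)
with open(f"{checkpoint}/{vectorsid}", "wb") as output:
    np.save(output, np.zeros((2, 4), dtype=np.float32))
    np.save(output, np.ones((2, 4), dtype=np.float32))

recovery = Recovery(checkpoint, vectorsid)
print(recovery().sum())  # 0.0 - first batch
print(recovery().sum())  # 8.0 - second batch
print(recovery())        # None - spool exhausted, recovery file removed
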
2 changes: 1 addition & 1 deletion src/python/txtai/vectors/words.py
@@ -151,7 +151,7 @@ def encode(self, data):

return np.array(embeddings, dtype=np.float32)

def index(self, documents, batchsize=500):
def index(self, documents, batchsize=500, checkpoint=None):
# Derive number of parallel processes
parallel = self.config.get("parallel", True)
parallel = os.cpu_count() if parallel and isinstance(parallel, bool) else int(parallel)
18 changes: 18 additions & 0 deletions test/python/testdatabase/testrdbms.py
@@ -96,6 +96,24 @@ def testAutoId(self):
result = embeddings.search(self.data[4], 1)[0]
self.assertEqual(len(result["id"]), 36)

def testCheckpoint(self):
"""
Test embeddings index checkpoints
"""

# Checkpoint directory
checkpoint = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.checkpoint")

# Save embeddings checkpoint
self.embeddings.index(self.data, checkpoint=checkpoint)

# Reindex with checkpoint
self.embeddings.index(self.data, checkpoint=checkpoint)

# Search for best match
result = self.embeddings.search("feel good story", 1)[0]
self.assertEqual(result["text"], self.data[4])

def testColumns(self):
"""
Test custom text/object columns
22 changes: 21 additions & 1 deletion test/python/testvectors/testvectors.py
@@ -2,11 +2,13 @@
Vectors module tests
"""

import os
import tempfile
import unittest

import numpy as np

from txtai.vectors import Vectors
from txtai.vectors import Vectors, Recovery


class TestVectors(unittest.TestCase):
@@ -46,3 +48,21 @@ def testNormalize(self):
# Test both data arrays are the same and changed from original
self.assertTrue(np.allclose(data1, data2))
self.assertFalse(np.allclose(data1, original))

def testRecovery(self):
"""
Test vectors recovery failure
"""

# Checkpoint directory
checkpoint = os.path.join(tempfile.gettempdir(), "recovery")
os.makedirs(checkpoint, exist_ok=True)

# Create empty file
# pylint: disable=R1732
f = open(os.path.join(checkpoint, "id"), "w", encoding="utf-8")
f.close()

# Create the recovery instance with an empty checkpoint file
recovery = Recovery(checkpoint, "id")
self.assertIsNone(recovery())
