Skip to content

Commit

Permalink
fix: moves caching of vector store to LCModelComponent level (#3435)
Browse files Browse the repository at this point in the history
* refactor LCModelComponent to use a cached vector store to prevent multiple embeddings
  • Loading branch information
jordanrfrazier authored Aug 21, 2024
1 parent a700ea0 commit 96ca71d
Show file tree
Hide file tree
Showing 18 changed files with 163 additions and 91 deletions.
78 changes: 68 additions & 10 deletions src/backend/base/langflow/base/vectorstores/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from abc import ABC, ABCMeta, abstractmethod
from functools import wraps
from typing import List, cast

from langchain_core.documents import Document
Expand All @@ -10,7 +12,48 @@
from langflow.schema import Data


class LCVectorStoreComponent(Component):
def check_cached_vector_store(f):
    """
    Decorator that memoizes a vector-store builder on the instance.

    The wrapped method's first successful (non-``None``) result is stored on
    ``self._cached_vector_store``; later calls return that stored object
    without invoking the wrapped method again. The returned wrapper is tagged
    with ``_is_cached_vector_store_checked = True`` so other code can detect
    that the decorator was applied.
    """

    @wraps(f)
    def cached_call(self, *args, **kwargs):
        store = self._cached_vector_store
        if store is None:
            # Nothing cached yet: build once and remember the result.
            store = f(self, *args, **kwargs)
            self._cached_vector_store = store
        return store

    # Marker inspected by EnforceCacheDecoratorMeta.
    cached_call._is_cached_vector_store_checked = True
    return cached_call


class EnforceCacheDecoratorMeta(ABCMeta):
    """
    Enforces that abstract methods are declared with @check_cached_vector_store.

    While a class is being created, every entry of its namespace that is
    flagged abstract (``__isabstractmethod__`` is truthy) must carry the
    ``_is_cached_vector_store_checked`` marker set by
    ``check_cached_vector_store``; otherwise class creation fails with
    ``TypeError``.
    """

    def __init__(cls, name, bases, dct):
        # A distinct loop variable keeps the class ``name`` intact for the
        # ``super().__init__`` call below.
        for attr_name, value in dct.items():
            # ``property``, ``classmethod`` and ``staticmethod`` always expose
            # ``__isabstractmethod__`` (False unless abstract), so test the
            # flag's value rather than its mere presence; otherwise ordinary
            # properties/static methods would be rejected.
            if getattr(value, "__isabstractmethod__", False):
                cls._check_method_decorator(attr_name, cls)
        super().__init__(name, bases, dct)

    @staticmethod
    def _check_method_decorator(name, cls):
        """
        Raise ``TypeError`` if attribute ``name`` of ``cls`` lacks the cache marker.

        Args:
            name: Attribute name to look up on ``cls``.
            cls: The class currently being initialised.

        Raises:
            TypeError: If the attribute was not wrapped by ``check_cached_vector_store``.
        """
        method = getattr(cls, name)

        # Check if the method has been marked as decorated by `check_cached_vector_store`
        if not getattr(method, "_is_cached_vector_store_checked", False):
            raise TypeError(f"Concrete implementation of '{name}' must use '@check_cached_vector_store' decorator.")


class LCVectorStoreComponent(Component, ABC, metaclass=EnforceCacheDecoratorMeta):
# Used to ensure a single vector store is built for each run of the flow
_cached_vector_store: VectorStore | None = None

trace_type = "retriever"
outputs = [
Output(
Expand All @@ -32,7 +75,11 @@ class LCVectorStoreComponent(Component):

def _validate_outputs(self):
# At least these three outputs must be defined
required_output_methods = ["build_base_retriever", "search_documents", "build_vector_store"]
required_output_methods = [
"build_base_retriever",
"search_documents",
"build_vector_store",
]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
if method_name not in output_names:
Expand Down Expand Up @@ -75,17 +122,16 @@ def search_with_vector_store(
def cast_vector_store(self) -> VectorStore:
    """
    Return the result of ``build_vector_store()`` typed as ``VectorStore``.

    ``cast`` only informs static type checkers; the object is returned as-is.
    """
    built_store = self.build_vector_store()
    return cast(VectorStore, built_store)

def build_vector_store(self) -> VectorStore:
    """
    Builds the Vector Store object.

    This base implementation only raises; subclasses are expected to override it.

    Raises:
        NotImplementedError: Whenever this base implementation is called.
    """
    raise NotImplementedError("build_vector_store method must be implemented.")

def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
"""
Builds the BaseRetriever object.
"""
vector_store = self.build_vector_store()
if self._cached_vector_store is not None:
vector_store = self._cached_vector_store
else:
vector_store = self.build_vector_store()
self._cached_vector_store = vector_store

if hasattr(vector_store, "as_retriever"):
retriever = vector_store.as_retriever(**self.get_retriever_kwargs())
if self.status is None:
Expand All @@ -103,7 +149,11 @@ def search_documents(self) -> List[Data]:
self.status = ""
return []

vector_store = self.build_vector_store()
if self._cached_vector_store is not None:
vector_store = self._cached_vector_store
else:
vector_store = self.build_vector_store()
self._cached_vector_store = vector_store

logger.debug(f"Search input: {search_query}")
logger.debug(f"Search type: {self.search_type}")
Expand All @@ -120,3 +170,11 @@ def get_retriever_kwargs(self):
Get the retriever kwargs. Implementations can override this method to provide custom retriever kwargs.
"""
return {}

# Decorator order matters: ``abstractmethod`` is applied last (outermost), so the
# abstract flag lands on the wrapper returned by ``check_cached_vector_store``,
# which already carries the ``_is_cached_vector_store_checked`` marker.
@abstractmethod
@check_cached_vector_store
def build_vector_store(self) -> VectorStore:
    """
    Builds the Vector Store object.

    Abstract hook with no real implementation here; concrete components
    provide the actual construction logic.

    Raises:
        NotImplementedError: If this base body is executed.
    """
    raise NotImplementedError("build_vector_store method must be implemented.")
40 changes: 36 additions & 4 deletions src/backend/base/langflow/components/retrievers/CohereRerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,17 @@
from langchain_cohere import CohereRerank

from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.field_typing import Retriever
from langflow.io import DropdownInput, HandleInput, IntInput, MessageTextInput, MultilineInput, SecretStrInput
from langflow.field_typing import Retriever, VectorStore
from langflow.io import (
DropdownInput,
HandleInput,
IntInput,
MessageTextInput,
MultilineInput,
SecretStrInput,
)
from langflow.schema import Data
from langflow.template.field.base import Output


class CohereRerankComponent(LCVectorStoreComponent):
Expand All @@ -33,13 +41,34 @@ class CohereRerankComponent(LCVectorStoreComponent):
),
SecretStrInput(name="api_key", display_name="API Key"),
IntInput(name="top_n", display_name="Top N", value=3),
MessageTextInput(name="user_agent", display_name="User Agent", value="langflow", advanced=True),
MessageTextInput(
name="user_agent",
display_name="User Agent",
value="langflow",
advanced=True,
),
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
]

outputs = [
Output(
display_name="Retriever",
name="base_retriever",
method="build_base_retriever",
),
Output(
display_name="Search Results",
name="search_results",
method="search_documents",
),
]

def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
cohere_reranker = CohereRerank(
cohere_api_key=self.api_key, model=self.model, top_n=self.top_n, user_agent=self.user_agent
cohere_api_key=self.api_key,
model=self.model,
top_n=self.top_n,
user_agent=self.user_agent,
)
retriever = ContextualCompressionRetriever(base_compressor=cohere_reranker, base_retriever=self.retriever)
return cast(Retriever, retriever)
Expand All @@ -50,3 +79,6 @@ async def search_documents(self) -> List[Data]: # type: ignore
data = self.to_data(documents)
self.status = data
return data

def build_vector_store(self) -> VectorStore:
    """Unsupported for this component; always raises ``NotImplementedError``."""
    message = "Cohere Rerank does not support vector stores."
    raise NotImplementedError(message)
19 changes: 18 additions & 1 deletion src/backend/base/langflow/components/retrievers/NvidiaRerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from langchain.retrievers import ContextualCompressionRetriever

from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.field_typing import Retriever
from langflow.field_typing import Retriever, VectorStore
from langflow.io import DropdownInput, HandleInput, MultilineInput, SecretStrInput, StrInput
from langflow.schema import Data
from langflow.schema.dotdict import dotdict
from langflow.template.field.base import Output


class NvidiaRerankComponent(LCVectorStoreComponent):
Expand All @@ -33,6 +34,19 @@ class NvidiaRerankComponent(LCVectorStoreComponent):
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
]

outputs = [
Output(
display_name="Retriever",
name="base_retriever",
method="build_base_retriever",
),
Output(
display_name="Search Results",
name="search_results",
method="search_documents",
),
]

def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "base_url" and field_value:
try:
Expand Down Expand Up @@ -62,3 +76,6 @@ async def search_documents(self) -> List[Data]: # type: ignore
data = self.to_data(documents)
self.status = data
return data

def build_vector_store(self) -> VectorStore:
    """Unsupported for this component; always raises ``NotImplementedError``."""
    message = "NVIDIA Rerank does not support vector stores."
    raise NotImplementedError(message)
21 changes: 4 additions & 17 deletions src/backend/base/langflow/components/vectorstores/AstraDB.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from langchain_core.vectorstores import VectorStore
from loguru import logger

from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers import docs_to_data
from langflow.inputs import DictInput, FloatInput
from langflow.io import (
Expand All @@ -24,8 +23,6 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
name = "AstraDB"
icon: str = "AstraDB"

_cached_vectorstore: VectorStore | None = None

inputs = [
StrInput(
name="collection_name",
Expand Down Expand Up @@ -162,11 +159,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
),
]

def _build_vector_store(self):
# cache the vector store to avoid re-initializing and ingest data again
if self._cached_vectorstore:
return self._cached_vectorstore

@check_cached_vector_store
def build_vector_store(self):
try:
from langchain_astradb import AstraDBVectorStore
from langchain_astradb.utils.astradb import SetupMode
Expand Down Expand Up @@ -229,9 +223,6 @@ def _build_vector_store(self):
raise ValueError(f"Error initializing AstraDBVectorStore: {str(e)}") from e

self._add_documents_to_vector_store(vector_store)

self._cached_vectorstore = vector_store

return vector_store

def _add_documents_to_vector_store(self, vector_store):
Expand Down Expand Up @@ -272,7 +263,7 @@ def _build_search_args(self):
return args

def search_documents(self) -> list[Data]:
vector_store = self._build_vector_store()
vector_store = self.build_vector_store()

logger.debug(f"Search input: {self.search_input}")
logger.debug(f"Search type: {self.search_type}")
Expand Down Expand Up @@ -303,7 +294,3 @@ def get_retriever_kwargs(self):
"search_type": self._map_search_type(),
"search_kwargs": search_args,
}

def build_vector_store(self):
vector_store = self._build_vector_store()
return vector_store
13 changes: 3 additions & 10 deletions src/backend/base/langflow/components/vectorstores/Cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from langchain_community.vectorstores import Cassandra
from loguru import logger

from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.inputs import BoolInput, DictInput, FloatInput
from langflow.io import (
Expand All @@ -25,8 +25,6 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
name = "Cassandra"
icon = "Cassandra"

_cached_vectorstore: Cassandra | None = None

inputs = [
MessageTextInput(
name="database_ref",
Expand Down Expand Up @@ -134,12 +132,8 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
),
]

@check_cached_vector_store
def build_vector_store(self) -> Cassandra:
return self._build_cassandra()

def _build_cassandra(self) -> Cassandra:
if self._cached_vectorstore:
return self._cached_vectorstore
try:
import cassio
from langchain_community.utilities.cassandra import SetupMode
Expand Down Expand Up @@ -215,7 +209,6 @@ def _build_cassandra(self) -> Cassandra:
body_index_options=body_index_options,
setup_mode=setup_mode,
)
self._cached_vectorstore = table
return table

def _map_search_type(self):
Expand All @@ -227,7 +220,7 @@ def _map_search_type(self):
return "similarity"

def search_documents(self) -> List[Data]:
vector_store = self._build_cassandra()
vector_store = self.build_vector_store()

logger.debug(f"Search input: {self.search_query}")
logger.debug(f"Search type: {self.search_type}")
Expand Down
3 changes: 2 additions & 1 deletion src/backend/base/langflow/components/vectorstores/Chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from langchain_chroma.vectorstores import Chroma
from loguru import logger

from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.base.vectorstores.utils import chroma_collection_to_data
from langflow.io import BoolInput, DataInput, DropdownInput, HandleInput, IntInput, StrInput, MultilineInput
from langflow.schema import Data
Expand Down Expand Up @@ -98,6 +98,7 @@ class ChromaVectorStoreComponent(LCVectorStoreComponent):
),
]

@check_cached_vector_store
def build_vector_store(self) -> Chroma:
"""
Builds the Chroma object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from langchain_community.vectorstores import CouchbaseVectorStore

from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import HandleInput, IntInput, StrInput, SecretStrInput, DataInput, MultilineInput
from langflow.schema import Data
Expand Down Expand Up @@ -42,10 +42,8 @@ class CouchbaseVectorStoreComponent(LCVectorStoreComponent):
),
]

@check_cached_vector_store
def build_vector_store(self) -> CouchbaseVectorStore:
return self._build_couchbase()

def _build_couchbase(self) -> CouchbaseVectorStore:
try:
from couchbase.auth import PasswordAuthenticator # type: ignore
from couchbase.cluster import Cluster # type: ignore
Expand Down Expand Up @@ -95,7 +93,7 @@ def _build_couchbase(self) -> CouchbaseVectorStore:
return couchbase_vs

def search_documents(self) -> List[Data]:
vector_store = self._build_couchbase()
vector_store = self.build_vector_store()

if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
docs = vector_store.similarity_search(
Expand Down
3 changes: 2 additions & 1 deletion src/backend/base/langflow/components/vectorstores/FAISS.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from langchain_community.vectorstores import FAISS
from loguru import logger

from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, DataInput, HandleInput, IntInput, MultilineInput, StrInput
from langflow.schema import Data
Expand Down Expand Up @@ -57,6 +57,7 @@ class FaissVectorStoreComponent(LCVectorStoreComponent):
),
]

@check_cached_vector_store
def build_vector_store(self) -> FAISS:
"""
Builds the FAISS object.
Expand Down
Loading

0 comments on commit 96ca71d

Please sign in to comment.