forked from langchain-ai/langchain
community: retrievers: added capability for using Product Quantization as one of the retriever. (langchain-ai#22424)

- [ ] **Community**: "Retrievers: Product Quantization"
- [X] This PR adds Product Quantization feature to the retrievers to the Langchain Community. PQ is one of the fastest retrieval methods if the embeddings are rich enough in context due to the concepts of quantization and representation through centroids
  - **Description:** Adding PQ as one of the retrievers
  - **Dependencies:** using the package nanopq for this PR
  - **Twitter handle:** vishnunkumar_
- [X] **Add tests and docs**: If you're adding a new integration, please include
  - [X] Added unit tests for the same in the retrievers.
  - [] Will add an example notebook subsequently
- [X] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ - done the same

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
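To make "quantization and representation through centroids" concrete, here is a minimal pure-Python sketch of the idea: split each vector into `m` sub-vectors, run a small k-means per subspace, and store each vector as a tuple of centroid ids. It is an illustration only, not nanopq's implementation; the helper names (`split`, `kmeans`, `train_and_encode`) and the toy vectors are assumptions made for this example.

```python
import random


def split(vec, m):
    """Split a vector into m equal-length sub-vectors."""
    if len(vec) % m != 0:
        raise ValueError("dimension must be divisible by the number of subspaces")
    step = len(vec) // m
    return [vec[i * step:(i + 1) * step] for i in range(m)]


def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(point, centroids):
    """Index of the centroid closest to point."""
    return min(range(len(centroids)), key=lambda c: sq_dist(point, centroids[c]))


def kmeans(points, k, iters=20, seed=123):
    """Plain Lloyd's k-means; initial centroids are sampled from the points."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]
    for _ in range(iters):
        groups = [[] for _ in range(k)]
        for p in points:
            groups[nearest(p, centroids)].append(p)
        for c, members in enumerate(groups):
            if members:  # keep the old centroid if a cluster went empty
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids


def train_and_encode(vectors, m, k):
    """Return (codebooks, codes): one codebook per subspace, one id tuple per vector."""
    subvectors = [split(v, m) for v in vectors]
    codebooks = [kmeans([sv[j] for sv in subvectors], k) for j in range(m)]
    codes = [
        tuple(nearest(sv[j], codebooks[j]) for j in range(m)) for sv in subvectors
    ]
    return codebooks, codes


# Toy data (hypothetical): four 4-dimensional vectors, 2 subspaces, 2 clusters each.
vectors = [
    [0.0, 0.1, 5.0, 5.1],
    [0.1, 0.0, 9.0, 9.1],
    [4.0, 4.1, 5.1, 5.0],
    [4.1, 4.0, 9.1, 9.0],
]
codebooks, codes = train_and_encode(vectors, m=2, k=2)
print(codes)  # each vector is now just m small integers
```

After encoding, each vector costs `m` small integers instead of its full list of floats, which is where the compression comes from.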
1 parent e395322, commit 9c3c13f
Showing 6 changed files with 306 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "661d5123-8ed2-4504-a846-7df0984e79f9", | ||
"metadata": {}, | ||
"source": [ | ||
"# NanoPQ (Product Quantization)\n", | ||
"\n", | ||
">[Product Quantization algorithm (k-NN)](https://towardsdatascience.com/similarity-search-product-quantization-b2a1a6397701) is a quantization algorithm that compresses database vectors, which makes semantic search practical on large datasets. In a nutshell, each embedding is split into M subspaces, and the sub-vectors in each subspace are clustered. Every sub-vector is then represented by the centroid of the cluster it falls into.\n", | ||
"\n", | ||
"This notebook shows how to use a retriever that performs Product Quantization under the hood, using the [nanopq](https://github.com/matsui528/nanopq) package." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "68794637-c13b-4145-944f-3b0c2f1258f9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%pip install -qU langchain-community spacy nanopq" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "39ecbf50-4623-4ee6-9c8e-fea5da21767e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings\n", | ||
"from langchain_community.retrievers import NanoPQRetriever" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c1ce742a-5085-408a-a2c2-4bae0f605880", | ||
"metadata": {}, | ||
"source": [ | ||
"## Create New Retriever with Texts" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "6c80020e-bc9e-49e8-8f93-5f75fd823738", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"retriever = NanoPQRetriever.from_texts(\n", | ||
" [\"Great world\", \"great words\", \"world\", \"planets of the world\"],\n", | ||
" SpacyEmbeddings(model_name=\"en_core_web_sm\"),\n", | ||
" clusters=2,\n", | ||
" subspace=2,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "743c26c1-0072-4e46-b41b-c28b3f1737c8", | ||
"metadata": {}, | ||
"source": [ | ||
"## Use Retriever\n", | ||
"\n", | ||
"We can now use the retriever!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "f496de2d-9b8f-4f8b-a30f-279ef199259a", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"M: 2, Ks: 2, metric : <class 'numpy.uint8'>, code_dtype: l2\n", | ||
"iter: 20, seed: 123\n", | ||
"Training the subspace: 0 / 2\n", | ||
"Training the subspace: 1 / 2\n", | ||
"Encoding the subspace: 0 / 2\n", | ||
"Encoding the subspace: 1 / 2\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[Document(page_content='world'),\n", | ||
" Document(page_content='Great world'),\n", | ||
" Document(page_content='great words'),\n", | ||
" Document(page_content='planets of the world')]" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"retriever.invoke(\"earth\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "617202a7-e3a6-49a8-b807-4b4d771159d5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.11" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,3 +89,4 @@ upstash-ratelimit>=1.1.0,<2 | |
vdms==0.0.20 | ||
xata>=1.0.0a7,<2 | ||
xmltodict>=0.13.0,<0.14 | ||
nanopq==0.2.1 |
125 changes: 125 additions & 0 deletions
libs/community/langchain_community/retrievers/nanopq.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from __future__ import annotations

import concurrent.futures
from typing import Any, Iterable, List, Optional

import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever


def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
    """
    Create an index of embeddings for a list of contexts.

    Args:
        contexts: List of contexts to embed.
        embeddings: Embeddings model to use.

    Returns:
        Index of embeddings.
    """
    with concurrent.futures.ThreadPoolExecutor() as executor:
        return np.array(list(executor.map(embeddings.embed_query, contexts)))


class NanoPQRetriever(BaseRetriever):
    """`NanoPQ` retriever."""

    embeddings: Embeddings
    """Embeddings model to use."""
    index: Any
    """Index of embeddings."""
    texts: List[str]
    """List of texts to index."""
    metadatas: Optional[List[dict]] = None
    """List of metadatas corresponding with each text."""
    k: int = 4
    """Number of results to return."""
    relevancy_threshold: Optional[float] = None
    """Threshold for relevancy."""
    subspace: int = 4
    """Number of subspaces to create; must evenly divide the embedding dimension."""
    clusters: int = 128
    """Number of clusters (centroids) to create per subspace."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embeddings: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> NanoPQRetriever:
        index = create_index(texts, embeddings)
        return cls(
            embeddings=embeddings,
            index=index,
            texts=texts,
            metadatas=metadatas,
            **kwargs,
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        embeddings: Embeddings,
        **kwargs: Any,
    ) -> NanoPQRetriever:
        texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
        return cls.from_texts(
            texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        # Import lazily so that nanopq stays an optional dependency.
        try:
            from nanopq import PQ
        except ImportError:
            raise ImportError(
                "Could not import nanopq, please install with `pip install nanopq`."
            )

        query_embeds = np.array(self.embeddings.embed_query(query))
        # Note: the codebooks are re-trained on every query.
        try:
            pq = PQ(M=self.subspace, Ks=self.clusters, verbose=True).fit(
                self.index.astype("float32")
            )
        except AssertionError:
            error_message = (
                "Received params: training_sample={training_sample}, "
                "n_cluster={n_clusters}, subspace={subspace}, "
                "embedding_shape={embedding_shape}. nanopq rejected this "
                "combination. Check that embedding_shape is divisible by "
                "subspace and that training_sample exceeds n_cluster."
            ).format(
                training_sample=self.index.shape[0],
                n_clusters=self.clusters,
                subspace=self.subspace,
                embedding_shape=self.index.shape[1],
            )
            raise RuntimeError(error_message)

        # Encode the indexed vectors, then rank them by their approximate
        # (asymmetric) distance to the query.
        index_code = pq.encode(vecs=self.index.astype("float32"))
        dt = pq.dtable(query=query_embeds.astype("float32"))
        dists = dt.adist(codes=index_code)

        sorted_ix = np.argsort(dists)

        top_k_results = [
            Document(
                page_content=self.texts[row],
                metadata=self.metadatas[row] if self.metadatas else {},
            )
            for row in sorted_ix[0 : self.k]
        ]

        return top_k_results
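The `dtable` / `adist` / `argsort` steps in `_get_relevant_documents` amount to an asymmetric distance computation: the query stays uncompressed, while the indexed vectors are represented only by their centroid ids. The following pure-Python sketch shows that idea under simplifying assumptions. The function names (`distance_table`, `asymmetric_distances`, `top_k`), the hand-written codebooks and the codes are hypothetical; this is not how nanopq implements it internally.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def distance_table(query, codebooks):
    """table[j][c] = squared distance from the j-th query sub-vector to centroid c."""
    m = len(codebooks)
    step = len(query) // m
    return [
        [sq_dist(query[j * step:(j + 1) * step], centroid) for centroid in codebooks[j]]
        for j in range(m)
    ]


def asymmetric_distances(table, codes):
    """Approximate distance per encoded vector: one table lookup per subspace."""
    return [sum(table[j][c] for j, c in enumerate(code)) for code in codes]


def top_k(dists, k):
    """Indices of the k smallest distances (ties keep their original order)."""
    return sorted(range(len(dists)), key=dists.__getitem__)[:k]


# Hypothetical codebooks: 2 subspaces, 2 centroids each.
codebooks = [
    [[0.0, 0.0], [4.0, 4.0]],  # subspace 0
    [[5.0, 5.0], [9.0, 9.0]],  # subspace 1
]
# Hypothetical encoded index: one (subspace-0 id, subspace-1 id) pair per vector.
codes = [(0, 0), (0, 1), (1, 0), (1, 1)]

query = [4.0, 4.0, 9.0, 9.0]
table = distance_table(query, codebooks)
dists = asymmetric_distances(table, codes)
ranking = top_k(dists, k=2)
print(dists)    # [64.0, 32.0, 32.0, 0.0]
print(ranking)  # [3, 1]
```

The table is built once per query, so scoring each indexed vector costs only one lookup per subspace rather than a full distance computation.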
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import pytest
from langchain_core.documents import Document

from langchain_community.embeddings import FakeEmbeddings
from langchain_community.retrievers import NanoPQRetriever


class TestNanoPQRetriever:
    @pytest.mark.requires("nanopq")
    def test_from_texts(self) -> None:
        input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
        pq_retriever = NanoPQRetriever.from_texts(
            texts=input_texts, embeddings=FakeEmbeddings(size=100)
        )
        assert len(pq_retriever.texts) == 3

    @pytest.mark.requires("nanopq")
    def test_from_documents(self) -> None:
        input_docs = [
            Document(page_content="I have a pen.", metadata={"page": 1}),
            Document(page_content="Do you have a pen?", metadata={"page": 2}),
            Document(page_content="I have a bag.", metadata={"page": 3}),
        ]
        pq_retriever = NanoPQRetriever.from_documents(
            documents=input_docs, embeddings=FakeEmbeddings(size=100)
        )
        assert pq_retriever.texts == [
            "I have a pen.",
            "Do you have a pen?",
            "I have a bag.",
        ]
        assert pq_retriever.metadatas == [{"page": 1}, {"page": 2}, {"page": 3}]

    # Renamed with the `test_` prefix so that pytest collects it.
    @pytest.mark.requires("nanopq")
    def test_invalid_subspace_error(self) -> None:
        input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
        pq_retriever = NanoPQRetriever.from_texts(
            texts=input_texts, embeddings=FakeEmbeddings(size=43)
        )
        with pytest.raises(RuntimeError):
            pq_retriever.invoke("I have")
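The last test relies on the retriever turning an invalid parameter combination into a `RuntimeError` at query time. As a rough sketch of which combinations are invalid, the hypothetical helper below checks the two conditions that nanopq is assumed here to assert when fitting: the embedding dimension must be divisible by the number of subspaces, and there must be more training vectors than clusters. The function `check_pq_params` is not part of this commit or of nanopq.

```python
def check_pq_params(n_samples, dim, subspace, clusters):
    """Return a list of problems; an empty list means the combination looks valid.

    Assumption: these two checks mirror the assertions nanopq makes in PQ.fit.
    """
    problems = []
    if dim % subspace != 0:
        problems.append(
            f"embedding dimension {dim} is not divisible by subspace={subspace}"
        )
    if clusters >= n_samples:
        problems.append(
            f"clusters={clusters} must be smaller than the {n_samples} indexed texts"
        )
    return problems


# The setup from the last test: 3 texts, 43-dim embeddings, default subspace/clusters.
for problem in check_pq_params(n_samples=3, dim=43, subspace=4, clusters=128):
    print(problem)

# A combination with the same sample and cluster counts as the notebook
# (4 texts, clusters=2, subspace=2); the dimension of 100 is hypothetical.
print(check_pq_params(n_samples=4, dim=100, subspace=2, clusters=2))  # []
```

If the second check holds, the default `clusters=128` would also be rejected for any index with 128 or fewer texts, so small examples need an explicit `clusters` value, as the notebook does.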