From 1508b9e9b3fefad7a6f05d12ecef3a3378b8d8a6 Mon Sep 17 00:00:00 2001 From: Andrew Zhou <44193474+adrwz@users.noreply.github.com> Date: Mon, 11 Mar 2024 13:59:48 -0400 Subject: [PATCH] Add BM25Encoder.update() which updates encoder values with new documents If we want to use this BM25Encoder in production & host it somewhere (or save it to a .pkl), there's no way to update it with new documents. Adding that function. --- pinecone_text/sparse/bm25_encoder.py | 35 ++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/pinecone_text/sparse/bm25_encoder.py b/pinecone_text/sparse/bm25_encoder.py index fe5ed0c..7d82446 100644 --- a/pinecone_text/sparse/bm25_encoder.py +++ b/pinecone_text/sparse/bm25_encoder.py @@ -98,6 +98,41 @@ def fit(self, corpus: List[str]) -> "BM25Encoder": self.avgdl = sum_doc_len / n_docs return self + def update(self, new_corpus: List[str]) -> "BM25Encoder": + """ + Update BM25 by incorporating new documents into the existing model + + Args: + new_corpus: list of new texts to update BM25 with + """ + if self.doc_freq is None or self.n_docs is None or self.avgdl is None: + raise ValueError("BM25 must be fit before updating") + + sum_doc_len = 0 + doc_freq_counter: Counter = Counter() + + for doc in tqdm(new_corpus): + if not isinstance(doc, str): + raise ValueError("new_corpus must be a list of strings") + + indices, tf = self._tf(doc) + if len(indices) == 0: + continue + self.n_docs += 1 + sum_doc_len += sum(tf) + + # Count the number of documents that contain each token + doc_freq_counter.update(indices) + + # Merge the new document frequencies with the existing ones + for idx, freq in doc_freq_counter.items(): + self.doc_freq[idx] = self.doc_freq.get(idx, 0) + freq + + # Update the average document length + self.avgdl = (self.avgdl * (self.n_docs - len(new_corpus)) + sum_doc_len) / self.n_docs + + return self + def encode_documents( self, texts: Union[str, List[str]] ) -> Union[SparseVector, List[SparseVector]]: