Skip to content

Commit

Permalink
Add BM25Encoder.update() which updates encoder values with new documents
Browse files Browse the repository at this point in the history
If we want to use a fitted BM25Encoder in production and host it somewhere (or save it to a .pkl), there is currently no way to update it with new documents. This commit adds an update() method that does so.
  • Loading branch information
adrwz authored Mar 11, 2024
1 parent 0eb00a2 commit 1508b9e
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions pinecone_text/sparse/bm25_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,41 @@ def fit(self, corpus: List[str]) -> "BM25Encoder":
self.avgdl = sum_doc_len / n_docs
return self

def update(self, new_corpus: List[str]) -> "BM25Encoder":
    """
    Update BM25 by incorporating new documents into the existing model

    Documents for which ``self._tf`` yields no token indices are skipped and
    do not count towards ``n_docs`` or ``avgdl``. The encoder state is only
    modified after the whole of ``new_corpus`` has been processed, so a
    failure part-way through leaves the encoder unchanged.

    Args:
        new_corpus: list of new texts to update BM25 with

    Returns:
        The encoder itself, so calls can be chained.

    Raises:
        ValueError: if the encoder has not been fit yet, or if any element
            of ``new_corpus`` is not a string.
    """
    if self.doc_freq is None or self.n_docs is None or self.avgdl is None:
        raise ValueError("BM25 must be fit before updating")

    # Accumulate into locals first; self.* is only touched once every
    # document has been validated and tokenized.
    n_new_docs = 0
    sum_doc_len = 0
    doc_freq_counter: Counter = Counter()

    for doc in tqdm(new_corpus):
        if not isinstance(doc, str):
            raise ValueError("new_corpus must be a list of strings")

        indices, tf = self._tf(doc)
        if len(indices) == 0:
            continue
        n_new_docs += 1
        sum_doc_len += sum(tf)

        # Count the number of documents that contain each token
        doc_freq_counter.update(indices)

    if n_new_docs == 0:
        # Nothing usable was added; statistics stay exactly as they were.
        return self

    # Merge the new document frequencies with the existing ones
    for idx, freq in doc_freq_counter.items():
        self.doc_freq[idx] = self.doc_freq.get(idx, 0) + freq

    # Recover the previous total length from the previous count *before*
    # bumping n_docs. Using len(new_corpus) here would be wrong whenever
    # some of the new documents were skipped above.
    prev_total_len = self.avgdl * self.n_docs
    self.n_docs += n_new_docs
    self.avgdl = (prev_total_len + sum_doc_len) / self.n_docs

    return self

def encode_documents(
self, texts: Union[str, List[str]]
) -> Union[SparseVector, List[SparseVector]]:
Expand Down

0 comments on commit 1508b9e

Please sign in to comment.