Skip to content

Commit

Permalink
fix: doc_masks on encodings device (#156)
Browse files Browse the repository at this point in the history
* fix: doc_masks on encodings device

* version bump

* chore: isort
  • Loading branch information
bclavie authored Feb 23, 2024
1 parent 9a6d1f9 commit 2e5f77e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "RAGatouille"
version = "0.0.7post4"
version = "0.0.7post5"
description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts."
authors = ["Benjamin Clavie <ben@clavie.eu>"]
license = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion ragatouille/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.7post4"
__version__ = "0.0.7post5"
from .RAGPretrainedModel import RAGPretrainedModel
from .RAGTrainer import RAGTrainer

Expand Down
6 changes: 4 additions & 2 deletions ragatouille/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,9 @@ def _encode_index_free_documents(
embedded_docs = self.inference_ckpt.docFromText(
documents, bsize=bsize, showprogress=verbose
)[0]
doc_mask = torch.full(embedded_docs.shape[:2], -float("inf"))
doc_mask = torch.full(embedded_docs.shape[:2], -float("inf")).to(
embedded_docs.device
)
return embedded_docs, doc_mask

def rank(
Expand Down Expand Up @@ -771,7 +773,7 @@ def encode(
- doc_masks.shape[1],
),
-float("inf"),
).to(device=encodings.device),
).to(device=doc_masks.device),
],
dim=1,
)
Expand Down

0 comments on commit 2e5f77e

Please sign in to comment.