diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index 469fea0..dc4aa1f 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -11,7 +11,6 @@ from colbert import Indexer, IndexUpdater, Searcher, Trainer from colbert.infra import ColBERTConfig, Run, RunConfig from colbert.modeling.checkpoint import Checkpoint - from ragatouille.models.base import LateInteractionModel # TODO: Move all bsize related calcs to `_set_bsize()` @@ -751,7 +750,7 @@ def encode( - encodings.shape[1], encodings.shape[2], ) - ), + ).to(device=encodings.device), ], dim=1, ) @@ -765,7 +764,7 @@ def encode( - doc_masks.shape[1], ), -float("inf"), - ), + ).to(device=encodings.device), ], dim=1, )