Add Colbert MLX #1866

Open · wants to merge 7 commits into master
6 changes: 4 additions & 2 deletions pyserini/encode/__init__.py
@@ -15,11 +15,13 @@
#

from ._base import DocumentEncoder, QueryEncoder, JsonlCollectionIterator,\
    RepresentationWriter, FaissRepresentationWriter, JsonlRepresentationWriter, PcaEncoder
    RepresentationWriter, FaissRepresentationWriter, JsonlRepresentationWriter, PcaEncoder, \
    MlxDocumentEncoder
from ._ance import AnceEncoder, AnceDocumentEncoder, AnceQueryEncoder
from ._auto import AutoQueryEncoder, AutoDocumentEncoder
from ._dpr import DprDocumentEncoder, DprQueryEncoder
from ._tct_colbert import TctColBertDocumentEncoder, TctColBertQueryEncoder
from ._tct_colbert import TctColBertDocumentEncoder, TctColBertQueryEncoder, \
    MlxTctColBertDocumentEncoder
from ._aggretriever import AggretrieverDocumentEncoder, AggretrieverQueryEncoder
from ._unicoil import UniCoilEncoder, UniCoilDocumentEncoder, UniCoilQueryEncoder
from ._cached_data import CachedDataQueryEncoder
32 changes: 20 additions & 12 deletions pyserini/encode/__main__.py
@@ -21,6 +21,7 @@
from pyserini.encode import DprDocumentEncoder, TctColBertDocumentEncoder, AnceDocumentEncoder, AggretrieverDocumentEncoder, AutoDocumentEncoder, CosDprDocumentEncoder, ClipDocumentEncoder
from pyserini.encode import UniCoilDocumentEncoder
from pyserini.encode import OpenAIDocumentEncoder, OPENAI_API_RETRY_DELAY
from pyserini.encode import MlxTctColBertDocumentEncoder


encoder_class_map = {
@@ -34,9 +35,10 @@
"cosdpr": CosDprDocumentEncoder,
"auto": AutoDocumentEncoder,
"clip": ClipDocumentEncoder,
"mlx_tct_colbert": MlxTctColBertDocumentEncoder,
}

def init_encoder(encoder, encoder_class, device, pooling, l2_norm, prefix, multimodal):
def init_encoder(encoder, encoder_class, device, pooling, l2_norm, prefix, multimodal, use_mlx):
    _encoder_class = encoder_class

    # determine encoder_class
Expand All @@ -57,15 +59,20 @@ def init_encoder(encoder, encoder_class, device, pooling, l2_norm, prefix, multi
            encoder_class = AutoDocumentEncoder

    # prepare arguments to encoder class
    kwargs = dict(model_name=encoder, device=device)
    if (_encoder_class == "sentence-transformers") or ("sentence-transformers" in encoder):
        kwargs.update(dict(pooling='mean', l2_norm=True))
    if (_encoder_class == "contriever") or ("contriever" in encoder):
        kwargs.update(dict(pooling='mean', l2_norm=False))
    if (_encoder_class == "auto"):
        kwargs.update(dict(pooling=pooling, l2_norm=l2_norm, prefix=prefix))
    if (_encoder_class == "clip") or ("clip" in encoder):
        kwargs.update(dict(l2_norm=True, prefix=prefix, multimodal=multimodal))
    if use_mlx:
        kwargs = dict(model_name=encoder)
    else:
        kwargs = dict(model_name=encoder, device=device)
        if (_encoder_class == "sentence-transformers") or ("sentence-transformers" in encoder):
            kwargs.update(dict(pooling='mean', l2_norm=True))
        if (_encoder_class == "contriever") or ("contriever" in encoder):
            kwargs.update(dict(pooling='mean', l2_norm=False))
        if (_encoder_class == "auto"):
            kwargs.update(dict(pooling=pooling, l2_norm=l2_norm, prefix=prefix))
        if (_encoder_class == "clip") or ("clip" in encoder):
            kwargs.update(dict(l2_norm=True, prefix=prefix, multimodal=multimodal))

    print(f'Initializing encoder: {encoder_class.__name__} with kwargs: {kwargs}')
    return encoder_class(**kwargs)


@@ -112,8 +119,9 @@ def parse_args(parser, commands):

    encoder_parser = commands.add_parser('encoder')
    encoder_parser.add_argument('--encoder', type=str, help='encoder name or path', required=True)
    encoder_parser.add_argument('--use-mlx', action='store_true', default=False)
    encoder_parser.add_argument('--encoder-class', type=str, required=False, default=None,
                                choices=["dpr", "bpr", "tct_colbert", "ance", "sentence-transformers", "openai-api", "auto"],
                                choices=["dpr", "bpr", "tct_colbert", "ance", "sentence-transformers", "openai-api", "auto", "mlx_tct_colbert"],
                                help='which query encoder class to use. `default` would infer from the args.encoder')
    encoder_parser.add_argument('--fields', help='fields to encode', nargs='+', default=['text'], required=False)
    encoder_parser.add_argument('--multimodal', action='store_true', default=False)
@@ -132,7 +140,7 @@ def parse_args(parser, commands):

    args = parse_args(parser, commands)
    delimiter = args.input.delimiter.replace("\\n", "\n")  # argparse would add \ prior to the passed '\n\n'
    encoder = init_encoder(args.encoder.encoder, args.encoder.encoder_class, device=args.encoder.device, pooling=args.encoder.pooling, l2_norm=args.encoder.l2_norm, prefix=args.encoder.prefix, multimodal=args.encoder.multimodal)
    encoder = init_encoder(args.encoder.encoder, args.encoder.encoder_class, device=args.encoder.device, pooling=args.encoder.pooling, l2_norm=args.encoder.l2_norm, prefix=args.encoder.prefix, multimodal=args.encoder.multimodal, use_mlx=args.encoder.use_mlx)
    if args.output.to_faiss:
        embedding_writer = FaissRepresentationWriter(args.output.embeddings, dimension=args.encoder.dimension)
    else:
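Taken together, passing `--use-mlx` with `--encoder-class mlx_tct_colbert` routes `init_encoder` to the new class and, because MLX manages Apple-silicon unified memory itself, skips the `device` kwarg entirely. A sketch of what the branch resolves to, assuming the checkpoint name below is a valid BERT-based TCT-ColBERT model on the Hugging Face hub:

    from pyserini.encode import MlxTctColBertDocumentEncoder

    # Equivalent of init_encoder(encoder, 'mlx_tct_colbert', ..., use_mlx=True):
    # the class map resolves the encoder and kwargs collapse to dict(model_name=encoder).
    encoder = MlxTctColBertDocumentEncoder(model_name='castorini/tct_colbert-v2-hnp-msmarco')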
20 changes: 20 additions & 0 deletions pyserini/encode/_base.py
@@ -21,6 +21,8 @@
import numpy as np
from tqdm import tqdm

import mlx.core as mx


class DocumentEncoder:
    def encode(self, texts, **kwargs):
@@ -55,6 +57,24 @@ def encode(self, text, **kwargs):
        embeddings = self.pca_mat.apply_py(embeddings)
        return embeddings

class MlxDocumentEncoder:
    def encode(self, texts, **kwargs):
        pass

    @staticmethod
    def _mean_pooling(last_hidden_state: mx.array, attention_mask: mx.array):
        token_embeddings = last_hidden_state
        input_mask_expanded = mx.expand_dims(attention_mask, -1)
        input_mask_expanded = mx.broadcast_to(input_mask_expanded, token_embeddings.shape).astype(mx.float32)
        sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = mx.clip(input_mask_expanded.sum(axis=1), 1e-9, None)
        return sum_embeddings / sum_mask


class MlxQueryEncoder:
    def encode(self, text, **kwargs):
        pass


class JsonlCollectionIterator:
    def __init__(self, collection_path: str, fields=None, docid_field=None, delimiter="\n"):
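The `_mean_pooling` helper added here follows the standard masked-mean recipe: broadcast the attention mask across the hidden dimension, zero out padding positions, and divide the summed token embeddings by the clipped count of real tokens. A tiny self-contained check of that behavior (the numbers are made up for illustration):

    import mlx.core as mx

    # Two sequences of three tokens each, hidden size 2; the second has one pad token.
    hidden = mx.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                       [[2.0, 2.0], [4.0, 4.0], [9.0, 9.0]]])
    mask = mx.array([[1, 1, 1],
                     [1, 1, 0]])  # trailing 0 marks padding

    expanded = mx.broadcast_to(mx.expand_dims(mask, -1), hidden.shape).astype(mx.float32)
    pooled = mx.sum(hidden * expanded, 1) / mx.clip(expanded.sum(axis=1), 1e-9, None)
    print(pooled)  # [[3, 4], [3, 3]] -- the padded position never enters the mean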
40 changes: 37 additions & 3 deletions pyserini/encode/_tct_colbert.py
@@ -14,14 +14,19 @@
# limitations under the License.
#

import os
from typing import Optional, List

import numpy as np
import torch
if torch.cuda.is_available():
    from torch.cuda.amp import autocast
from transformers import BertModel, BertTokenizer, BertTokenizerFast

from pyserini.encode import DocumentEncoder, QueryEncoder
from transformers import BertModel, BertTokenizer, BertTokenizerFast, BertConfig
from onnxruntime import ExecutionMode, SessionOptions, InferenceSession
import mlx.core as mx
from mlx_transformers.models.bert import BertModel as MlxBertModel

from pyserini.encode import DocumentEncoder, QueryEncoder, MlxDocumentEncoder


class TctColBertDocumentEncoder(DocumentEncoder):
@@ -89,3 +94,32 @@ def encode(self, query: str, **kwargs):
        outputs = self.model(**inputs)
        embeddings = outputs.last_hidden_state.detach().cpu().numpy()
        return np.average(embeddings[:, 4:, :], axis=-2).flatten()


class MlxTctColBertDocumentEncoder(MlxDocumentEncoder):
    def __init__(self, model_name: str, tokenizer_name: Optional[str]=None):
        self.config = BertConfig.from_pretrained(model_name)
        self.model = MlxBertModel(self.config)
        self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name or model_name)

        self.model.from_pretrained(model_name, huggingface_model_architecture="BertModel")

    def encode(self, texts: List[str], titles: Optional[List[str]]=None, fp16: bool=False, max_length: int=512, **kwargs):
        if titles is not None:
            texts = [f'[CLS] [D] {title} {text}' for title, text in zip(titles, texts)]
        else:
            texts = ['[CLS] [D] ' + text for text in texts]
        inputs = self.tokenizer(
            texts,
            max_length=max_length,
            padding="longest",
            truncation=True,
            add_special_tokens=False,
            return_tensors='np'
        )

        inputs = {key: mx.array(v) for key, v in inputs.items()}
        outputs = self.model(**inputs)
        embeddings = self._mean_pooling(outputs.last_hidden_state[:, 4:, :], inputs['attention_mask'][:, 4:])
        return np.array(embeddings)
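As in the PyTorch `TctColBertDocumentEncoder`, the first four tokens (the '[CLS] [D] ' prefix, which the BERT tokenizer splits into four pieces) are sliced off before pooling. A usage sketch, assuming `mlx-transformers` can load the checkpoint's weights (the model name is a placeholder):

    from pyserini.encode import MlxTctColBertDocumentEncoder

    encoder = MlxTctColBertDocumentEncoder('castorini/tct_colbert-v2-hnp-msmarco')
    embeddings = encoder.encode(
        texts=['MLX is an array framework for machine learning on Apple silicon.'],
        titles=['MLX'],
    )
    print(embeddings.shape)  # (1, hidden_size), e.g. (1, 768) for a BERT-base model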
3 changes: 2 additions & 1 deletion pyserini/search/__init__.py
@@ -21,7 +21,7 @@
from ._deprecated import SimpleSearcher, ImpactSearcher, SimpleFusionSearcher

from .faiss import DenseSearchResult, PRFDenseSearchResult, FaissSearcher, BinaryDenseSearcher, QueryEncoder, \
    DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, AggretrieverQueryEncoder, AutoQueryEncoder, ClipQueryEncoder
    DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, AggretrieverQueryEncoder, AutoQueryEncoder, ClipQueryEncoder, MlxTctColBertQueryEncoder
from .faiss import AnceEncoder
from .faiss import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf
from .faiss import OpenAIQueryEncoder
@@ -52,6 +52,7 @@
    'BprQueryEncoder',
    'DkrrDprQueryEncoder',
    'TctColBertQueryEncoder',
    'MlxTctColBertQueryEncoder',
    'AnceEncoder',
    'AnceQueryEncoder',
    'AggretrieverQueryEncoder',
4 changes: 2 additions & 2 deletions pyserini/search/faiss/__init__.py
@@ -15,7 +15,7 @@
#

from ._searcher import DenseSearchResult, PRFDenseSearchResult, FaissSearcher, BinaryDenseSearcher, QueryEncoder, \
    DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, AggretrieverQueryEncoder, OpenAIQueryEncoder, \
    DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, AggretrieverQueryEncoder, OpenAIQueryEncoder, MlxTctColBertQueryEncoder, \
    AutoQueryEncoder, ClipQueryEncoder

from ._model import AnceEncoder
@@ -24,4 +24,4 @@
__all__ = ['DenseSearchResult', 'PRFDenseSearchResult', 'FaissSearcher', 'BinaryDenseSearcher', 'QueryEncoder',
           'DprQueryEncoder', 'BprQueryEncoder', 'DkrrDprQueryEncoder', 'TctColBertQueryEncoder', 'AnceEncoder',
           'AnceQueryEncoder', 'AggretrieverQueryEncoder', 'AutoQueryEncoder', 'DenseVectorAveragePrf', 'DenseVectorRocchioPrf', 'DenseVectorAncePrf',
           'OpenAIQueryEncoder', 'ClipQueryEncoder']
           'OpenAIQueryEncoder', 'ClipQueryEncoder', 'MlxTctColBertQueryEncoder']
34 changes: 20 additions & 14 deletions pyserini/search/faiss/__main__.py
@@ -22,7 +22,7 @@

from pyserini.search import FaissSearcher, BinaryDenseSearcher, TctColBertQueryEncoder, QueryEncoder, \
    DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, AnceQueryEncoder, AggretrieverQueryEncoder, DenseVectorAveragePrf, \
    DenseVectorRocchioPrf, DenseVectorAncePrf, OpenAIQueryEncoder, ClipQueryEncoder
    DenseVectorRocchioPrf, DenseVectorAncePrf, OpenAIQueryEncoder, ClipQueryEncoder, MlxTctColBertQueryEncoder

from pyserini.encode import PcaEncoder, CosDprQueryEncoder, AutoQueryEncoder
from pyserini.query_iterator import get_query_iterator, TopicsFormat
@@ -41,7 +41,7 @@ def define_dsearch_args(parser):
help="Path to Faiss index or name of prebuilt index.")
parser.add_argument('--encoder-class', type=str, metavar='which query encoder class to use. `default` would infer from the args.encoder',
required=False,
choices=["dkrr", "dpr", "bpr", "tct_colbert", "ance", "sentence", "contriever", "auto", "aggretriever", "openai-api", "cosdpr"],
choices=["dkrr", "dpr", "bpr", "tct_colbert", "ance", "sentence", "contriever", "auto", "aggretriever", "openai-api", "cosdpr", "mlx_tct_colbert"],
default=None,
help='which query encoder class to use. `default` would infer from the args.encoder')
parser.add_argument('--encoder', type=str, metavar='path to query encoder checkpoint or encoder name',
@@ -89,9 +89,10 @@ def define_dsearch_args(parser):
                        help='The path or name to ANCE-PRF model checkpoint')
    parser.add_argument('--ef-search', type=int, metavar='efSearch for HNSW index', required=False, default=None,
                        help="Set efSearch for HNSW index")
    parser.add_argument('--use-mlx', action='store_true', default=False)


def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, encoded_queries, device, max_length, pooling, l2_norm, prefix, multimodal=False):
def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, encoded_queries, device, max_length, pooling, l2_norm, prefix, multimodal=False, use_mlx=False):
    encoded_queries_map = {
        'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset',
        'dpr-nq-dev': 'dpr_multi-nq-dev',
@@ -115,6 +116,7 @@ def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, enco
"openai-api": OpenAIQueryEncoder,
"auto": AutoQueryEncoder,
"clip": ClipQueryEncoder,
"mlx_tct_colbert": MlxTctColBertQueryEncoder,
}

    if encoder:
@@ -138,17 +140,21 @@ def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, enco
                encoder_class = AutoQueryEncoder

        # prepare arguments to encoder class
        kwargs = dict(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device, prefix=prefix)
        if (_encoder_class == "sentence") or ("sentence" in encoder):
            kwargs.update(dict(pooling='mean', l2_norm=True))
        if (_encoder_class == "contriever") or ("contriever" in encoder):
            kwargs.update(dict(pooling='mean', l2_norm=False))
        if (_encoder_class == "openai-api") or ("openai" in encoder):
            kwargs.update(dict(max_length=max_length))
        if (_encoder_class == "auto"):
            kwargs.update(dict(pooling=pooling, l2_norm=l2_norm, prefix=prefix))
        if (_encoder_class == "clip") or ("clip" in encoder):
            kwargs.update(dict(l2_norm=True, prefix=prefix, multimodal=multimodal))
        # prepare arguments to encoder class
        if use_mlx:
            kwargs = dict(encoder_dir=encoder)
        else:
            kwargs = dict(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device, prefix=prefix)
            if (_encoder_class == "sentence") or ("sentence" in encoder):
                kwargs.update(dict(pooling='mean', l2_norm=True))
            if (_encoder_class == "contriever") or ("contriever" in encoder):
                kwargs.update(dict(pooling='mean', l2_norm=False))
            if (_encoder_class == "openai-api") or ("openai" in encoder):
                kwargs.update(dict(max_length=max_length))
            if (_encoder_class == "auto"):
                kwargs.update(dict(pooling=pooling, l2_norm=l2_norm, prefix=prefix))
            if (_encoder_class == "clip") or ("clip" in encoder):
                kwargs.update(dict(l2_norm=True, prefix=prefix, multimodal=multimodal))
        return encoder_class(**kwargs)

    if encoded_queries:
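Note that on the `use_mlx` branch only `encoder_dir` is forwarded, so the tokenizer, device, and pooling-related options are quietly ignored when `--use-mlx` is set. The resolved call is simply (placeholder checkpoint):

    from pyserini.search import MlxTctColBertQueryEncoder

    # Equivalent of init_query_encoder(..., use_mlx=True): kwargs = dict(encoder_dir=encoder)
    encoder = MlxTctColBertQueryEncoder(encoder_dir='castorini/tct_colbert-v2-hnp-msmarco')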
37 changes: 35 additions & 2 deletions pyserini/search/faiss/_searcher.py
@@ -28,9 +28,10 @@
import openai
import tiktoken

from transformers import (AutoModel, AutoTokenizer, BertModel, BertTokenizer, BertTokenizerFast,
from transformers import (AutoModel, AutoTokenizer, BertModel, BertConfig, BertTokenizer, BertTokenizerFast,
                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer, RobertaTokenizer)
from transformers.file_utils import is_faiss_available, requires_backends
from mlx_transformers.models.bert import BertModel as MlxBertModel

from pyserini.util import (download_encoded_queries, download_prebuilt_index,
                           get_dense_indexes_info, get_sparse_index)
Expand All @@ -43,6 +44,9 @@
from ...encode import PcaEncoder, CosDprQueryEncoder, ClipEncoder
from ...encode._aggretriever import BERTAggretrieverEncoder, DistlBERTAggretrieverEncoder

import mlx.core as mx
from mlx_transformers.models.bert import BertModel as MlxBertModel

if is_faiss_available():
    import faiss

Expand Down Expand Up @@ -175,6 +179,36 @@ def encode(self, query: str):
        else:
            return super().encode(query)

class MlxTctColBertQueryEncoder(QueryEncoder):

    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
                 encoded_query_dir: str = None, **kwargs):
        super().__init__(encoded_query_dir)
        if encoder_dir:
            self.config = BertConfig.from_pretrained(encoder_dir)
            self.model = MlxBertModel(self.config)
            self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name or encoder_dir)
            self.model.from_pretrained(encoder_dir, huggingface_model_architecture="BertModel")
            self.has_model = True
        if (not self.has_model) and (not self.has_encoded_query):
            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')

    def encode(self, query: str):
        if self.has_model:
            max_length = 36  # hardcode for now
            inputs = self.tokenizer(
                '[CLS] [Q] ' + query + '[MASK]' * max_length,
                max_length=max_length,
                truncation=True,
                add_special_tokens=False,
                return_tensors='np'
            )
            inputs = {key: mx.array(v) for key, v in inputs.items()}
            outputs = self.model(**inputs)
            embeddings = np.array(outputs.last_hidden_state)
            return np.average(embeddings[:, 4:, :], axis=-2).flatten()
        else:
            return super().encode(query)

class DprQueryEncoder(QueryEncoder):

@@ -414,7 +448,6 @@ def encode(self, query: str):
        else:
            return super().encode(query)


@dataclass
class DenseSearchResult:
    docid: str
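As in the PyTorch `TctColBertQueryEncoder`, the query is right-padded with `[MASK]` tokens to a fixed length of 36 and the first four marker tokens are excluded from the average. The encoder then drops into `FaissSearcher` like any other `QueryEncoder`; a minimal retrieval sketch, assuming a local TCT-ColBERT Faiss index at the placeholder path:

    from pyserini.search.faiss import FaissSearcher, MlxTctColBertQueryEncoder

    encoder = MlxTctColBertQueryEncoder(encoder_dir='castorini/tct_colbert-v2-hnp-msmarco')
    searcher = FaissSearcher('indexes/my-tct-colbert-index', encoder)  # hypothetical index path
    for hit in searcher.search('what is mlx?', k=5):
        print(hit.docid, round(hit.score, 4))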
4 changes: 3 additions & 1 deletion pyserini/util.py
@@ -23,8 +23,10 @@
from urllib.error import HTTPError, URLError
from urllib.request import urlretrieve

import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import PreTrainedModel

from pyserini.encoded_query_info import QUERY_INFO
from pyserini.encoded_corpus_info import CORPUS_INFO
@@ -280,4 +282,4 @@ def download_evaluation_script(evaluation_name, force=False, verbose=True, mirro
def get_sparse_index(index_name):
    if index_name not in FAISS_INDEX_INFO:
        raise ValueError(f'Unrecognized index name {index_name}')
    return FAISS_INDEX_INFO[index_name]["texts"]
    return FAISS_INDEX_INFO[index_name]["texts"]
4 changes: 3 additions & 1 deletion requirements.txt
@@ -15,4 +15,6 @@ pyyaml
openai>=1.0.0
tiktoken>=0.4.0
pyarrow>=15.0.0
pillow>=10.2.0
pillow>=10.2.0
mlx>=0.10.0
mlx-transformers==0.1.0