Encoder model implementations in MLX #1914

Draft · wants to merge 2 commits into base: master
24 changes: 20 additions & 4 deletions pyserini/encode/__init__.py
@@ -14,11 +14,20 @@
# limitations under the License.
#

from ._base import DocumentEncoder, QueryEncoder, JsonlCollectionIterator,\
RepresentationWriter, FaissRepresentationWriter, JsonlRepresentationWriter, PcaEncoder
from ._base import (
    DocumentEncoder,
    MLXDocumentEncoder,
    QueryEncoder,
    MLXQueryEncoder,
    JsonlCollectionIterator,
    RepresentationWriter,
    FaissRepresentationWriter,
    JsonlRepresentationWriter,
    PcaEncoder
)
from ._ance import AnceEncoder, AnceDocumentEncoder, AnceQueryEncoder
from ._auto import AutoQueryEncoder, AutoDocumentEncoder
from ._dpr import DprDocumentEncoder, DprQueryEncoder
from ._dpr import MLXDprDocumentEncoder, DprDocumentEncoder, DprQueryEncoder, MLXDprQueryEncoder
from ._tct_colbert import TctColBertDocumentEncoder, TctColBertQueryEncoder
from ._aggretriever import AggretrieverDocumentEncoder, AggretrieverQueryEncoder
from ._unicoil import UniCoilEncoder, UniCoilDocumentEncoder, UniCoilQueryEncoder
@@ -27,5 +36,12 @@
from ._splade import SpladeQueryEncoder
from ._slim import SlimQueryEncoder
from ._openai import OpenAIDocumentEncoder, OpenAIQueryEncoder, OPENAI_API_RETRY_DELAY
from ._cosdpr import CosDprEncoder, CosDprDocumentEncoder, CosDprQueryEncoder
from ._cosdpr import (
    CosDprEncoder,
    MLXCosDprEncoder,
    CosDprDocumentEncoder,
    MLXCosDprDocumentEncoder,
    CosDprQueryEncoder,
    MLXCosDprQueryEncoder
)
from ._clip import ClipEncoder, ClipDocumentEncoder
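
With these re-exports, the MLX variants resolve from the package root alongside the torch ones. A quick import check (a sketch, assuming the mlx dependency is installed):

# sanity check: the new MLX encoders are exposed next to the existing torch encoders
from pyserini.encode import MLXDocumentEncoder, MLXQueryEncoder, MLXDprQueryEncoder, MLXCosDprQueryEncoder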
16 changes: 16 additions & 0 deletions pyserini/encode/_base.py
@@ -21,6 +21,8 @@
import numpy as np
from tqdm import tqdm

# MLX for Apple silicon; the core array API lives under mlx.core
import mlx.core as mx

class DocumentEncoder:
    def encode(self, texts, **kwargs):
@@ -33,12 +35,26 @@ def _mean_pooling(last_hidden_state, attention_mask):
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

class MLXDocumentEncoder:
    def encode(self, texts, **kwargs):
        pass

    @staticmethod
    def _mean_pooling(last_hidden_state: mx.array, attention_mask: mx.array):
        token_embeddings = last_hidden_state
        # expand the mask over the hidden dimension using MLX's module-level ops
        input_mask_expanded = mx.broadcast_to(mx.expand_dims(attention_mask, -1), token_embeddings.shape).astype(mx.float32)
        sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, axis=1)
        sum_mask = mx.clip(input_mask_expanded.sum(axis=1), a_min=1e-9, a_max=None)
        return sum_embeddings / sum_mask

class QueryEncoder:
    def encode(self, text, **kwargs):
        pass

class MLXQueryEncoder:
    def encode(self, text, **kwargs):
        pass

class PcaEncoder:
    def __init__(self, encoder, pca_model_path):
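For reference, a minimal, self-contained sketch of the masked mean pooling that MLXDocumentEncoder._mean_pooling performs, on toy MLX arrays (shapes chosen only for illustration):

import mlx.core as mx

last_hidden_state = mx.random.normal(shape=(2, 4, 8))    # (batch, seq_len, hidden)
attention_mask = mx.array([[1, 1, 1, 0], [1, 1, 0, 0]])  # (batch, seq_len)

# expand the mask over the hidden dimension, then average only the unmasked tokens
mask = mx.broadcast_to(mx.expand_dims(attention_mask, -1), last_hidden_state.shape).astype(mx.float32)
pooled = mx.sum(last_hidden_state * mask, axis=1) / mx.clip(mask.sum(axis=1), 1e-9, None)
print(pooled.shape)  # (2, 8)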
94 changes: 93 additions & 1 deletion pyserini/encode/_cosdpr.py
@@ -17,9 +17,18 @@
from typing import Optional

import torch
import numpy as np
from transformers import PreTrainedModel, BertConfig, BertModel, BertTokenizer

from pyserini.encode import DocumentEncoder, QueryEncoder
# MLX for Apple silicon
import mlx.core as mx
import mlx.nn as nn

from pyserini.encode import (
    DocumentEncoder,
    MLXDocumentEncoder,
    QueryEncoder,
    MLXQueryEncoder
)


class CosDprEncoder(PreTrainedModel):
@@ -72,6 +81,51 @@ def forward(
        return pooled_output


class MLXCosDprEncoder(PreTrainedModel):
    config_class = BertConfig
    base_model_prefix = 'bert'
    load_tf_weights = None

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.config = config
        self.bert = BertModel(config)
        # projection head implemented with MLX; the backbone remains a transformers BertModel
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.init_weights()

    def _init_weights(self, module):
        """Initialize the weights of the MLX projection head."""
        # MLX has no in-place init ops, so freshly initialized arrays are assigned directly
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight = mx.random.normal(shape=module.weight.shape, scale=self.config.initializer_range)
        if isinstance(module, nn.Linear) and 'bias' in module:
            module.bias = mx.zeros(module.bias.shape)

    def init_weights(self):
        self.bert.init_weights()
        self._init_weights(self.linear)

    def forward(
        self,
        input_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
    ):
        input_shape = input_ids.shape
        if attention_mask is None:
            attention_mask = (
                mx.ones(input_shape)
                if self.bert.config.pad_token_id is None
                else (input_ids != self.bert.config.pad_token_id)
            )
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        pooled_output = sequence_output[:, 0, :]
        # project the [CLS] vector, then L2-normalize it
        pooled_output = self.linear(pooled_output)
        pooled_output = pooled_output / mx.linalg.norm(pooled_output, axis=1, keepdims=True)
        return pooled_output



class CosDprDocumentEncoder(DocumentEncoder):
    def __init__(self, model_name, tokenizer_name=None, device='cuda:0'):
        self.device = device
@@ -114,3 +168,41 @@ def encode(self, query: str, **kwargs):
        inputs.to(self.device)
        embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
        return embeddings.flatten()


class MLXCosDprDocumentEncoder(MLXDocumentEncoder):
    def __init__(self, model_name, tokenizer_name=None):
        # MLX uses unified memory, so no explicit device placement is required
        self.model = MLXCosDprEncoder.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or model_name)

    def encode(self, texts, titles=None, max_length=256, **kwargs):
        if titles is not None:
            texts = [f'{title} {text}' for title, text in zip(titles, texts)]
        inputs = self.tokenizer(
            texts,
            max_length=max_length,
            padding='longest',
            truncation=True,
            add_special_tokens=True,
            return_tensors='pt'
        )
        return np.array(self.model(inputs["input_ids"]), copy=False)


class MLXCosDprQueryEncoder(MLXQueryEncoder):
    def __init__(self, encoder_dir: str, tokenizer_name: str = None, **kwargs):
        self.model = MLXCosDprEncoder.from_pretrained(encoder_dir)
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or encoder_dir)

    def encode(self, query: str, **kwargs):
        inputs = self.tokenizer(
            query,
            add_special_tokens=True,
            return_tensors='pt',
            truncation='only_first',
            padding='longest',
            return_token_type_ids=False,
        )
        embeddings = np.array(self.model(inputs["input_ids"]))
        return embeddings.flatten()
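
A hypothetical end-to-end sketch of the query side; the checkpoint name 'castorini/cosdpr-distil' is the one the torch CosDpr encoders are typically paired with and is assumed here only for illustration:

from pyserini.encode import MLXCosDprQueryEncoder

encoder = MLXCosDprQueryEncoder('castorini/cosdpr-distil')  # assumed checkpoint
embedding = encoder.encode('what is a lobster roll?')
print(embedding.shape)  # 1-D query embedding of size hidden_size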
59 changes: 57 additions & 2 deletions pyserini/encode/_dpr.py
@@ -14,9 +14,19 @@
# limitations under the License.
#

from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer
)

from pyserini.encode import DocumentEncoder, QueryEncoder
from pyserini.encode import (
    DocumentEncoder,
    MLXDocumentEncoder,
    QueryEncoder,
    MLXQueryEncoder
)


class DprDocumentEncoder(DocumentEncoder):
@@ -62,3 +72,48 @@ def encode(self, query: str, **kwargs):
        input_ids.to(self.device)
        embeddings = self.model(input_ids["input_ids"]).pooler_output.detach().cpu().numpy()
        return embeddings.flatten()


class MLXDprDocumentEncoder(MLXDocumentEncoder):
    def __init__(self, model_name, tokenizer_name=None, device='cpu'):
        # default to CPU; CUDA is not available on Apple silicon
        self.device = device
        self.model = DPRContextEncoder.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(tokenizer_name or model_name)

    def encode(self, texts, titles=None, max_length=256, **kwargs):
        if titles:
            inputs = self.tokenizer(
                titles,
                text_pair=texts,
                max_length=max_length,
                padding='longest',
                truncation=True,
                add_special_tokens=True,
                return_tensors='pt'
            )
        else:
            inputs = self.tokenizer(
                texts,
                max_length=max_length,
                padding='longest',
                truncation=True,
                add_special_tokens=True,
                return_tensors='pt'
            )
        inputs.to(self.device)
        return self.model(inputs["input_ids"]).pooler_output.detach().cpu().numpy()


class MLXDprQueryEncoder(MLXQueryEncoder):
    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
        self.device = device
        self.model = DPRQuestionEncoder.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_name or model_name)

    def encode(self, query: str, **kwargs):
        input_ids = self.tokenizer(query, return_tensors='pt')
        input_ids.to(self.device)
        embeddings = self.model(input_ids["input_ids"]).pooler_output.detach().cpu().numpy()
        return embeddings.flatten()
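
A usage sketch for the DPR pair, assuming the standard NQ checkpoints from Hugging Face:

from pyserini.encode import MLXDprQueryEncoder, MLXDprDocumentEncoder

q_encoder = MLXDprQueryEncoder('facebook/dpr-question_encoder-single-nq-base', device='cpu')
d_encoder = MLXDprDocumentEncoder('facebook/dpr-ctx_encoder-single-nq-base', device='cpu')

q_vec = q_encoder.encode('who wrote hamlet?')
d_vecs = d_encoder.encode(['Hamlet is a tragedy by William Shakespeare.'], titles=['Hamlet'])
print(q_vec.shape, d_vecs.shape)  # (768,) and (1, 768) for the base DPR models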
1 change: 1 addition & 0 deletions requirements.txt
@@ -17,3 +17,4 @@ tiktoken>=0.4.0
pyarrow>=15.0.0
pillow>=10.2.0
pybind11>=2.11.0
mlx>=0.14.1
Empty file added tests/test_dpr_mlx.py
Empty file.
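
The test module is left empty in this draft; a minimal sketch of what it could eventually cover (pytest-style, names hypothetical):

import mlx.core as mx
from pyserini.encode import MLXDocumentEncoder


def test_mean_pooling_ignores_padded_positions():
    hidden = mx.ones((1, 3, 4))
    mask = mx.array([[1, 1, 0]])  # last position is padding
    pooled = MLXDocumentEncoder._mean_pooling(hidden, mask)
    assert tuple(pooled.shape) == (1, 4)
    assert mx.allclose(pooled, mx.ones((1, 4))).item()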