chore: Support multi-GPU training via accelerate #5

Open · wants to merge 6 commits into main · Changes from 4 commits

1 change: 1 addition & 0 deletions docs/api/overview.md
@@ -27,6 +27,7 @@
- [train_colbert](../train/train-colbert)
- [train_sparse_embed](../train/train-sparse-embed)
- [train_splade](../train/train-splade)
- [Multi-GPU training via Accelerator](../train/multi-gpu)

## utils

1 change: 1 addition & 0 deletions docs/fine_tune/.pages
@@ -3,4 +3,5 @@ nav:
- colbert.md
- splade.md
- sparse_embed.md
- multi_gpu.md

64 changes: 64 additions & 0 deletions docs/fine_tune/multi_gpu.md

# Multi-GPU

Neural-Cherche supports multi-GPU training via [Accelerator](https://huggingface.co/docs/accelerate/package_reference/accelerator). Every neural-cherche model can be trained on multiple GPUs. Here is a tutorial.

```python
import torch
from accelerate import Accelerator
from datasets import Dataset
from torch.utils.data import DataLoader

from neural_cherche import models, train

if __name__ == "__main__":
    # Wrap the training code in a "__main__" guard to avoid multiprocessing
    # issues when running one process per GPU.
    accelerator = Accelerator()
    save_each_epoch = True

    model = models.SparseEmbed(
        model_name_or_path="distilbert-base-uncased",
        accelerate=True,
        device=accelerator.device,
    ).to(accelerator.device)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

    # Dataset creation using HuggingFace Datasets library.
    dataset = Dataset.from_dict(
        {
            "anchors": ["anchor 1", "anchor 2", "anchor 3", "anchor 4"],
            "positives": ["positive 1", "positive 2", "positive 3", "positive 4"],
            "negatives": ["negative 1", "negative 2", "negative 3", "negative 4"],
        }
    )

    # Convert your dataset to a DataLoader.
    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Wrap model, optimizer, and dataloader in accelerator.
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

    for epoch in range(2):
        for batch in data_loader:
            # Each batch is a dictionary keyed by the dataset columns.
            anchors, positives, negatives = (
                batch["anchors"],
                batch["positives"],
                batch["negatives"],
            )

            loss = train.train_sparse_embed(
                model=model,
                optimizer=optimizer,
                anchor=anchors,
                positive=positives,
                negative=negatives,
                threshold_flops=30,
                accelerator=accelerator,
            )

        if accelerator.is_main_process and save_each_epoch:
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                "checkpoint/epoch" + str(epoch),
            )

    # Save at the end of the training loop
    # We check to make sure that only the main process will export the model
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained("checkpoint")
```
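
The script is meant to be run through the Accelerate CLI so that one process is started per GPU, for example with `accelerate config` once and then `accelerate launch multi_gpu_training.py` (the file name is illustrative). Once training is done, the exported checkpoint can be reloaded on a single device; a minimal sketch, reusing the `checkpoint` directory written by the main process above (queries and documents are illustrative):

```python
from neural_cherche import models

# Reload the checkpoint exported at the end of the multi-GPU run.
model = models.SparseEmbed(
    model_name_or_path="checkpoint",
    device="cpu",
)

# Score a few illustrative query / document pairs with the reloaded model.
scores = model.scores(
    queries=["Sports", "Music"],
    documents=["Sports is great.", "Music is great."],
    batch_size=2,
)
print(scores)
```
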
Owner:
I added a clear example of how to create the dataset using the HuggingFace Datasets library.

@@ -0,0 +1,64 @@
# Multi-GPU (Accelerator)


Training any of the models on multiple GPUs via the Accelerate library is simple. You just need to modify the training loop in a few key ways:

```python
from neural_cherche import models, utils, train
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator


# Wrap in a "__main__" guard to avoid multiprocessing issues
if __name__ == "__main__":
    accelerator = Accelerator()
    device = accelerator.device
    batch_size = 32
    epochs = 2
    save_on_epoch = True

    model = models.SparseEmbed(
        model_name_or_path="distilbert-base-uncased",
        device=device,
    ).to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

    # prepare your dataset -- this example uses a huggingface `datasets` object
    ...

    # Convert the data into a PyTorch dataloader for ease of preparation
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Wrap the model, optimizer, and data loader in the accelerator
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

    for epoch in range(epochs):
        for batch_id, batch_data in enumerate(data_loader):
            # Assuming batch_data is a tuple in the form (anchors, positives, negatives)
            anchors, positives, negatives = batch_data

            loss = train.train_sparse_embed(
                model=model,
                optimizer=optimizer,
                anchor=anchors,
                positive=positives,
                negative=negatives,
                threshold_flops=30,
                accelerator=accelerator,
            )

        if accelerator.is_main_process and save_on_epoch:
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                "checkpoint/epoch" + str(epoch),
            )

    # Save at the end of the training loop
    # We check to make sure that only the main process will export the model
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained("checkpoint", accelerator=True)
```
14 changes: 7 additions & 7 deletions neural_cherche/models/base.py
Owner:
I ran into some trouble with the extra position_ids parameter with the DistilBERT pre-trained checkpoint, but not with the all-mpnet-base-v2 pre-trained checkpoint, so I think it would be cool to keep the legacy code and add an accelerate attribute to the models.

import json
import os
from abc import ABC, abstractmethod

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForMaskedLM, AutoTokenizer


class Base(ABC, torch.nn.Module):
    """Base class from which all models inherit.

    Parameters
    ----------
    model_name_or_path
        Path to the model or the model name.
    device
        Device to use for the model. CPU or CUDA.
    extra_files_to_load
        List of extra files to load.
    accelerate
        Use HuggingFace Accelerate.
    kwargs
        Additional parameters to the model.
    """

    def __init__(
        self,
        model_name_or_path: str,
        device: str = None,
        extra_files_to_load: list[str] = [],
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        """Initialize the model."""
        super(Base, self).__init__()

        if device is not None:
            self.device = device

        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        self.accelerate = accelerate

        os.environ["TRANSFORMERS_CACHE"] = "."
        self.model = AutoModelForMaskedLM.from_pretrained(
            model_name_or_path, cache_dir="./", **kwargs
        ).to(self.device)

        # Download linear layer if exists
        for file in extra_files_to_load:
            try:
                _ = hf_hub_download(model_name_or_path, filename=file, cache_dir=".")
            except:
                pass

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, device=self.device, cache_dir="./", **kwargs
        )

        self.model.config.output_hidden_states = True

        if os.path.exists(model_name_or_path):
            # Local checkpoint
            self.model_folder = model_name_or_path
        else:
            # HuggingFace checkpoint
            model_folder = os.path.join(
                f"models--{model_name_or_path}".replace("/", "--"), "snapshots"
            )
            snapshot = os.listdir(model_folder)[-1]
            self.model_folder = os.path.join(model_folder, snapshot)

        self.query_pad_token = self.tokenizer.mask_token
        self.original_pad_token = self.tokenizer.pad_token

    def _encode_accelerate(self, texts: list[str], **kwargs) -> tuple[torch.Tensor]:
        """Encode sentences with multiples gpus.

        Parameters
        ----------
        texts
            List of sentences to encode.

        References
        ----------
        [Accelerate issue.](https://github.com/huggingface/accelerate/issues/97)
        """
        encoded_input = self.tokenizer(texts, return_tensors="pt", **kwargs).to(
            self.device
        )

        position_ids = (
            torch.arange(0, encoded_input["input_ids"].size(1))
            .expand((len(texts), -1))
            .to(self.device)
        )

        output = self.model(**encoded_input, position_ids=position_ids)
        return output.logits, output.hidden_states[-1]

    def _encode(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode sentences.

        Parameters
        ----------
        texts
            List of sentences to encode.
        """
        if self.accelerate:
            return self._encode_accelerate(texts, **kwargs)

        encoded_input = self.tokenizer.batch_encode_plus(
            texts, return_tensors="pt", **kwargs
        )

        if self.device != "cpu":
            encoded_input = {
                key: value.to(self.device) for key, value in encoded_input.items()
            }

        output = self.model(**encoded_input)
        return output.logits, output.hidden_states[-1]

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Pytorch forward method."""
        pass

    @abstractmethod
    def encode(self, *args, **kwargs):
        """Encode documents."""
        pass

    @abstractmethod
    def scores(self, *args, **kwargs):
        """Compute scores."""
        pass

    @abstractmethod
    def save_pretrained(self, path: str):
        """Save model the model."""
        pass

    def save_tokenizer_accelerate(self, path: str) -> None:
        """Save tokenizer when using accelerate."""
        tokenizer_config = {
            k: v for k, v in self.tokenizer.__dict__.items() if k != "device"
        }
        tokenizer_config_file = os.path.join(path, "tokenizer_config.json")
        with open(tokenizer_config_file, "w", encoding="utf-8") as file:
            json.dump(tokenizer_config, file, ensure_ascii=False, indent=4)

        # dump vocab
        self.tokenizer.save_vocabulary(path)

        # save special tokens
        special_tokens_file = os.path.join(path, "special_tokens_map.json")
        with open(special_tokens_file, "w", encoding="utf-8") as file:
            json.dump(
                self.tokenizer.special_tokens_map,
                file,
                ensure_ascii=False,
                indent=4,
            )

Owner:
Here is the base class, updated with the new save_tokenizer_accelerate method and the accelerate attribute.
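
To make the flow concrete, here is a minimal sketch (not part of the diff) of how the new attribute is meant to be used, reusing the model name from the tutorial above: `accelerate=True` makes `_encode` delegate to `_encode_accelerate`, and `save_pretrained` then goes through `save_tokenizer_accelerate` instead of the tokenizer's own `save_pretrained`:

```python
from accelerate import Accelerator
from neural_cherche import models

accelerator = Accelerator()

# accelerate=True is stored on the Base class; _encode then delegates to
# _encode_accelerate, which builds explicit position_ids for the forward pass.
model = models.SparseEmbed(
    model_name_or_path="distilbert-base-uncased",
    accelerate=True,
    device=accelerator.device,
).to(accelerator.device)

# Because self.accelerate is True, save_pretrained routes the tokenizer
# through save_tokenizer_accelerate, which drops the non-serialisable
# "device" key before dumping tokenizer_config.json.
model.save_pretrained("checkpoint")
```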

@@ -77,16 +77,16 @@ def _encode(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
texts
List of sentences to encode.
"""
encoded_input = self.tokenizer.batch_encode_plus(
texts, return_tensors="pt", **kwargs
encoded_input = self.tokenizer(texts, return_tensors="pt", **kwargs).to(
self.device
)

if self.device != "cpu":
encoded_input = {
key: value.to(self.device) for key, value in encoded_input.items()
}
# Must hardcode position_ids to avoid a bug with accelerate multi-GPU
seq_len = encoded_input["input_ids"].size(1)
position_ids = torch.arange(0, seq_len).expand((len(texts), -1)).to(self.device)

output = self.model(**encoded_input)
# Pass both the inputs and position_ids to the model
output = self.model(**encoded_input, position_ids=position_ids)
return output.logits, output.hidden_states[-1]

@abstractmethod
29 changes: 27 additions & 2 deletions neural_cherche/models/colbert.py
import json
import os

import torch

from .. import utils
from .base import Base

__all__ = ["ColBERT"]


class ColBERT(Base):
    """ColBERT model.

    Parameters
    ----------
    model_name_or_path
        Path to the model or the model name.
    embedding_size
        Size of the embeddings in output of ColBERT model.
    device
        Device to use for the model. CPU or CUDA.
    accelerate
        Use HuggingFace Accelerate.
    kwargs
        Additional parameters to the SentenceTransformer model.

    Examples
    --------
    >>> from neural_cherche import models
    >>> import torch

    >>> _ = torch.manual_seed(42)

    >>> queries = ["Berlin", "Paris", "London"]

    >>> documents = [
    ...     "Berlin is the capital of Germany",
    ...     "Paris is the capital of France and France is in Europe",
    ...     "London is the capital of England",
    ... ]

    >>> encoder = models.ColBERT(
    ...     model_name_or_path="sentence-transformers/all-mpnet-base-v2",
    ...     embedding_size=128,
    ...     max_length_query=32,
    ...     max_length_document=350,
    ... )

    >>> scores = encoder.scores(
    ...    queries=queries,
    ...    documents=documents,
    ... )

    >>> scores
    tensor([22.9325, 19.8296, 20.8019])

    >>> _ = encoder.save_pretrained("checkpoint", accelerate=False)

    >>> encoder = models.ColBERT(
    ...     model_name_or_path="checkpoint",
    ...     embedding_size=64,
    ...     device="cpu",
    ... )

    >>> scores = encoder.scores(
    ...    queries=queries,
    ...    documents=documents,
    ... )

    >>> scores
    tensor([22.9325, 19.8296, 20.8019])

    >>> embeddings = encoder(
    ...     texts=queries,
    ...     query_mode=True
    ... )

    >>> embeddings["embeddings"].shape
    torch.Size([3, 32, 128])

    >>> embeddings = encoder(
    ...     texts=queries,
    ...     query_mode=False
    ... )

    >>> embeddings["embeddings"].shape
    torch.Size([3, 350, 128])

    """

    def __init__(
        self,
        model_name_or_path: str,
        embedding_size: int = 128,
        device: str = None,
        max_length_query: int = 32,
        max_length_document: int = 350,
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        """Initialize the model."""
        super(ColBERT, self).__init__(
            model_name_or_path=model_name_or_path,
            device=device,
            extra_files_to_load=["linear.pt", "metadata.json"],
            accelerate=accelerate,
            **kwargs,
        )

        self.max_length_query = max_length_query
        self.max_length_document = max_length_document
        self.embedding_size = embedding_size

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            linear = torch.load(
                os.path.join(self.model_folder, "linear.pt"), map_location=self.device
            )
            self.embedding_size = linear["weight"].shape[0]
            in_features = linear["weight"].shape[1]
        else:
            with torch.no_grad():
                _, embeddings = self._encode(texts=["test"])
                in_features = embeddings.shape[2]

        self.linear = torch.nn.Linear(
            in_features=in_features,
            out_features=self.embedding_size,
            bias=False,
            device=self.device,
        )

        if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
            with open(os.path.join(self.model_folder, "metadata.json"), "r") as f:
                metadata = json.load(f)
            self.max_length_document = metadata["max_length_document"]
            self.max_length_query = metadata["max_length_query"]

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            self.linear.load_state_dict(linear)

    def encode(
        self,
        texts: list[str],
        truncation: bool = True,
        add_special_tokens: bool = False,
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Encode documents

        Parameters
        ----------
        texts
            List of sentences to encode.
        truncation
            Truncate the inputs.
        add_special_tokens
            Add special tokens.
        max_length
            Maximum length of the inputs.
        """
        with torch.no_grad():
            embeddings = self(
                texts=texts,
                truncation=truncation,
                add_special_tokens=add_special_tokens,
                query_mode=query_mode,
                **kwargs,
            )
        return embeddings

    def forward(
        self,
        texts: list[str],
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Pytorch forward method.

        Parameters
        ----------
        texts
            List of sentences to encode.
        query_mode
            Whether to encode queries or documents.
        """
        suffix = "[Q] " if query_mode else "[D] "

        texts = [suffix + text for text in texts]

        self.tokenizer.pad_token = (
            self.query_pad_token if query_mode else self.original_pad_token
        )

        kwargs = {
            "truncation": True,
            "padding": "max_length",
            "max_length": self.max_length_query
            if query_mode
            else self.max_length_document,
            "add_special_tokens": True,
            **kwargs,
        }

        _, embeddings = self._encode(texts=texts, **kwargs)

        return {
            "embeddings": torch.nn.functional.normalize(
                self.linear(embeddings), p=2, dim=2
            )
        }

    def scores(
        self,
        queries: list[str],
        documents: list[str],
        batch_size: int = 2,
        tqdm_bar: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Score queries and documents.

        Parameters
        ----------
        queries
            List of queries.
        documents
            List of documents.
        batch_size
            Batch size.
        truncation
            Truncate the inputs.
        add_special_tokens
            Add special tokens.
        tqdm_bar
            Show tqdm bar.
        """
        list_scores = []

        for batch_queries, batch_documents in zip(
            utils.batchify(
                X=queries,
                batch_size=batch_size,
                desc="Computing scores.",
                tqdm_bar=tqdm_bar,
            ),
            utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
        ):
            queries_embeddings = self.encode(
                texts=batch_queries,
                query_mode=True,
                **kwargs,
            )

            documents_embeddings = self.encode(
                texts=batch_documents,
                query_mode=False,
                **kwargs,
            )

            late_interactions = torch.einsum(
                "bsh,bth->bst",
                queries_embeddings["embeddings"],
                documents_embeddings["embeddings"],
            )

            late_interactions = torch.max(late_interactions, axis=2).values.sum(axis=1)

            list_scores.append(late_interactions)

        return torch.cat(list_scores, dim=0)

    def save_pretrained(self, path: str) -> "ColBERT":
        """Save model the model.

        Parameters
        ----------
        path
            Path to save the model.
        """
        self.model.save_pretrained(path)
        torch.save(self.linear.state_dict(), os.path.join(path, "linear.pt"))
        self.tokenizer.pad_token = self.original_pad_token
        with open(os.path.join(path, "metadata.json"), "w") as f:
            json.dump(
                {
                    "max_length_query": self.max_length_query,
                    "max_length_document": self.max_length_document,
                },
                f,
            )
        if self.accelerate:
            self.save_tokenizer_accelerate(path=path)
        else:
            self.tokenizer.save_pretrained(path)
        return self

Owner:
ColBERT with the call to the parent class's save_tokenizer_accelerate :)

@@ -268,7 +268,7 @@

return torch.cat(list_scores, dim=0)

def save_pretrained(self, path: str) -> "ColBERT":
def save_pretrained(self, path: str, accelerator: bool = False) -> "ColBERT":
"""Save model the model.

Parameters
@@ -279,7 +279,32 @@ def save_pretrained(self, path: str) -> "ColBERT":
self.model.save_pretrained(path)
torch.save(self.linear.state_dict(), os.path.join(path, "linear.pt"))
self.tokenizer.pad_token = self.original_pad_token
self.tokenizer.save_pretrained(path)
if accelerator:
# Workaround an issue with accelerator. Tokenizer has a key "device"
# which is non serialisable, but not removeable with a basic delattr

# dump config
tokenizer_config = {
k: v for k, v in self.tokenizer.__dict__.items() if k != "device"
}
tokenizer_config_file = os.path.join(path, "tokenizer_config.json")
with open(tokenizer_config_file, "w", encoding="utf-8") as file:
json.dump(tokenizer_config, file, ensure_ascii=False, indent=4)

# dump vocab
self.tokenizer.save_vocabulary(path)

# save special tokens
special_tokens_file = os.path.join(path, "special_tokens_map.json")
with open(special_tokens_file, "w", encoding="utf-8") as file:
json.dump(
self.tokenizer.special_tokens_map,
file,
ensure_ascii=False,
indent=4,
)
else:
self.tokenizer.save_pretrained(path)
with open(os.path.join(path, "metadata.json"), "w") as f:
json.dump(
{
33 changes: 31 additions & 2 deletions neural_cherche/models/sparse_embed.py
import json
import os

import torch

from .. import utils

__all__ = ["SparseEmbed"]

from .splade import Splade


class SparseEmbed(Splade):
    """SparseEmbed model.

    Parameters
    ----------
    model_name_or_path
        Path to the model or the model name. It should be a SentenceTransformer model.
    embedding_size
        Size of the embeddings in the output of the SparseEmbed model.
    kwargs
        Additional parameters to the pre-trained model.

    Examples
    --------
    >>> from neural_cherche import models
    >>> import torch

    >>> _ = torch.manual_seed(42)

    >>> device = "mps"

    >>> model = models.SparseEmbed(
    ...     model_name_or_path="distilbert-base-uncased",
    ...     device=device,
    ... )

    >>> queries_embeddings = model.encode(
    ...     ["Sports", "Music"],
    ... )

    >>> queries_embeddings["activations"].shape
    torch.Size([2, 128])

    >>> queries_embeddings["sparse_activations"].shape
    torch.Size([2, 30522])

    >>> queries_embeddings["embeddings"].shape
    torch.Size([2, 128, 128])

    >>> documents_embeddings = model.encode(
    ...    ["Music is great.", "Sports is great."],
    ...    query_mode=False,
    ... )

    >>> documents_embeddings["activations"].shape
    torch.Size([2, 256])

    >>> documents_embeddings["sparse_activations"].shape
    torch.Size([2, 30522])

    >>> documents_embeddings["embeddings"].shape
    torch.Size([2, 256, 128])

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=1,
    ... )
    tensor([64.2330, 54.0180], device='mps:0')

    >>> _ = model.save_pretrained("checkpoint")

    >>> model = models.SparseEmbed(
    ...     model_name_or_path="checkpoint",
    ...     device="cpu",
    ... )

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=2,
    ... )
    tensor([64.2330, 54.0180])

    References
    ----------
    1. [SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval](https://dl.acm.org/doi/pdf/10.1145/3539618.3592065)

    """

    def __init__(
        self,
        model_name_or_path: str = None,
        embedding_size: int = 128,
        max_length_query: int = 128,
        max_length_document: int = 256,
        device: str = None,
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        super(SparseEmbed, self).__init__(
            model_name_or_path=model_name_or_path,
            device=device,
            extra_files_to_load=["linear.pt", "metadata.json"],
            accelerate=accelerate,
            **kwargs,
        )

        self.embedding_size = embedding_size

        self.softmax = torch.nn.Softmax(dim=2).to(self.device)

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            linear = torch.load(
                os.path.join(self.model_folder, "linear.pt"), map_location=self.device
            )
            self.embedding_size = linear["weight"].shape[0]
            in_features = linear["weight"].shape[1]
        else:
            with torch.no_grad():
                _, embeddings = self._encode(texts=["test"])
                in_features = embeddings.shape[2]

        self.linear = torch.nn.Linear(
            in_features=in_features,
            out_features=self.embedding_size,
            bias=False,
            device=self.device,
        )

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            self.linear.load_state_dict(linear)

        if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
            with open(os.path.join(self.model_folder, "metadata.json"), "r") as file:
                metadata = json.load(file)

            max_length_query = metadata["max_length_query"]
            max_length_document = metadata["max_length_document"]

        self.max_length_query = max_length_query
        self.max_length_document = max_length_document

    def forward(
        self,
        texts: list[str],
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Pytorch forward method.

        Parameters
        ----------
        texts
            List of documents to encode.
        query_mode
            Whether to encode queries or documents.
        """
        suffix = "[Q] " if query_mode else "[D] "

        texts = [suffix + text for text in texts]

        self.tokenizer.pad_token = (
            self.query_pad_token if query_mode else self.original_pad_token
        )

        k_tokens = self.max_length_query if query_mode else self.max_length_document

        logits, embeddings = self._encode(
            texts=texts,
            truncation=True,
            padding="max_length",
            max_length=k_tokens,
            add_special_tokens=True,
            **kwargs,
        )

        activations = self._update_activations(
            **self._get_activation(logits=logits),
            k_tokens=k_tokens,
        )

        attention = self._get_attention(
            logits=logits,
            activations=activations["activations"],
        )

        embeddings = torch.bmm(
            attention,
            embeddings,
        )

        return {
            "embeddings": self.relu(self.linear(embeddings)),
            "sparse_activations": activations["sparse_activations"],
            "activations": activations["activations"],
        }

    def _get_attention(
        self, logits: torch.Tensor, activations: torch.Tensor
    ) -> torch.Tensor:
        """Extract attention scores from MLM logits based on activated tokens."""
        attention = logits.gather(
            dim=2,
            index=torch.stack(
                [
                    torch.stack([token for _ in range(logits.shape[1])])
                    for token in activations
                ]
            ),
        )

        return self.softmax(attention)

    def save_pretrained(
        self,
        path: str,
    ):
        """Save model the model."""
        self.model.save_pretrained(path)
        self.tokenizer.pad_token = self.original_pad_token

        if self.accelerate:
            self.save_tokenizer_accelerate(path)
        else:
            self.tokenizer.save_pretrained(path)
        torch.save(self.linear.state_dict(), os.path.join(path, "linear.pt"))
        with open(os.path.join(path, "metadata.json"), "w") as file:
            json.dump(
                fp=file,
                obj={
                    "max_length_query": self.max_length_query,
                    "max_length_document": self.max_length_document,
                },
                indent=4,
            )

        return self

    def scores(
        self,
        queries: list[str],
        documents: list[str],
        batch_size: int = 32,
        tqdm_bar: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Compute similarity scores between queries and documents."""
        dense_scores = []

        for batch_queries, batch_documents in zip(
            utils.batchify(
                X=queries,
                batch_size=batch_size,
                desc="Computing scores.",
                tqdm_bar=tqdm_bar,
            ),
            utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
        ):
            queries_embeddings = self.encode(
                texts=batch_queries,
                query_mode=True,
                **kwargs,
            )

            documents_embeddings = self.encode(
                texts=batch_documents,
                query_mode=False,
                **kwargs,
            )

            dense_scores.append(
                utils.pairs_dense_scores(
                    queries_activations=queries_embeddings["activations"],
                    documents_activations=documents_embeddings["activations"],
                    queries_embeddings=queries_embeddings["embeddings"],
                    documents_embeddings=documents_embeddings["embeddings"],
                )
            )

        return torch.cat(dense_scores, dim=0)

Owner:
SparseEmbed with the call to the parent class's save_tokenizer_accelerate :)

@@ -212,11 +212,40 @@ def _get_attention(

return self.softmax(attention)

def save_pretrained(self, path: str):
def save_pretrained(
self,
path: str,
accelerator: bool = False,
):
"""Save model the model."""
self.model.save_pretrained(path)
self.tokenizer.pad_token = self.original_pad_token
self.tokenizer.save_pretrained(path)
if accelerator:
# Workaround an issue with accelerator. Tokenizer has a key "device"
# which is non serialisable, but not removeable with a basic delattr

# dump config
tokenizer_config = {
k: v for k, v in self.tokenizer.__dict__.items() if k != "device"
}
tokenizer_config_file = os.path.join(path, "tokenizer_config.json")
with open(tokenizer_config_file, "w", encoding="utf-8") as file:
json.dump(tokenizer_config, file, ensure_ascii=False, indent=4)

# dump vocab
self.tokenizer.save_vocabulary(path)

# save special tokens
special_tokens_file = os.path.join(path, "special_tokens_map.json")
with open(special_tokens_file, "w", encoding="utf-8") as file:
json.dump(
self.tokenizer.special_tokens_map,
file,
ensure_ascii=False,
indent=4,
)
else:
self.tokenizer.save_pretrained(path)
torch.save(self.linear.state_dict(), os.path.join(path, "linear.pt"))
with open(os.path.join(path, "metadata.json"), "w") as file:
json.dump(
46 changes: 36 additions & 10 deletions neural_cherche/models/splade.py
import json
import os
import string

import torch

from .. import utils
from .base import Base

__all__ = ["Splade"]


class Splade(Base):
    """SpladeV1 model.

    Parameters
    ----------
    tokenizer
        HuggingFace Tokenizer.
    model
        HuggingFace AutoModelForMaskedLM.
    kwargs
        Additional parameters to the SentenceTransformer model.

    Examples
    --------
    >>> from neural_cherche import models
    >>> import torch

    >>> _ = torch.manual_seed(42)

    >>> model = models.Splade(
    ...     model_name_or_path="distilbert-base-uncased",
    ...     device="mps",
    ... )

    >>> queries_activations = model.encode(
    ...     ["Sports", "Music"],
    ... )

    >>> documents_activations = model.encode(
    ...    ["Music is great.", "Sports is great."],
    ...    query_mode=False,
    ... )

    >>> queries_activations["sparse_activations"].shape
    torch.Size([2, 30522])

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=1
    ... )
    tensor([318.1384, 271.8006], device='mps:0')

    >>> _ = model.save_pretrained("checkpoint")

    >>> model = models.Splade(
    ...     model_name_or_path="checkpoint",
    ...     device="mps",
    ... )

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=1
    ... )
    tensor([318.1384, 271.8006], device='mps:0')

    References
    ----------
    1. [SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking](https://arxiv.org/abs/2107.05720)

    """

    def __init__(
        self,
        model_name_or_path: str = None,
        device: str = None,
        max_length_query: int = 128,
        max_length_document: int = 256,
        extra_files_to_load: list[str] = ["metadata.json"],
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        super(Splade, self).__init__(
            model_name_or_path=model_name_or_path,
            device=device,
            extra_files_to_load=extra_files_to_load,
            accelerate=accelerate,
            **kwargs,
        )

        self.relu = torch.nn.ReLU().to(self.device)

        if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
            with open(os.path.join(self.model_folder, "metadata.json"), "r") as file:
                metadata = json.load(file)

            max_length_query = metadata["max_length_query"]
            max_length_document = metadata["max_length_document"]

        self.max_length_query = max_length_query
        self.max_length_document = max_length_document

    def encode(
        self,
        texts: list[str],
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Encode documents

        Parameters
        ----------
        texts
            List of documents to encode.
        truncation
            Whether to truncate the documents.
        padding
            Whether to pad the documents.
        max_length
            Maximum length of the documents.
        """
        with torch.no_grad():
            return self(
                texts=texts,
                query_mode=query_mode,
                **kwargs,
            )

    def decode(
        self,
        sparse_activations: torch.Tensor,
        clean_up_tokenization_spaces: bool = False,
        skip_special_tokens: bool = True,
        k_tokens: int = 96,
    ) -> list[str]:
        """Decode activated tokens ids where activated value > 0.

        Parameters
        ----------
        sparse_activations
            Activated tokens.
        clean_up_tokenization_spaces
            Whether to clean up the tokenization spaces.
        skip_special_tokens
            Whether to skip special tokens.
        k_tokens
            Number of tokens to keep.
        """
        activations = self._filter_activations(
            sparse_activations=sparse_activations, k_tokens=k_tokens
        )

        # Decode
        return [
            " ".join(
                activation.translate(str.maketrans("", "", string.punctuation)).split()
            )
            for activation in self.tokenizer.batch_decode(
                activations,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                skip_special_tokens=skip_special_tokens,
            )
        ]

    def forward(
        self,
        texts: list[str],
        query_mode: bool,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Pytorch forward method.

        Parameters
        ----------
        texts
            List of documents to encode.
        query_mode
            Whether to encode queries or documents.
        """
        suffix = "[Q] " if query_mode else "[D] "

        texts = [suffix + text for text in texts]

        self.tokenizer.pad_token = (
            self.query_pad_token if query_mode else self.original_pad_token
        )

        k_tokens = self.max_length_query if query_mode else self.max_length_document

        logits, _ = self._encode(
            texts=texts,
            truncation=True,
            padding="max_length",
            max_length=k_tokens,
            add_special_tokens=True,
            **kwargs,
        )

        activations = self._get_activation(logits=logits)

        activations = self._update_activations(
            **activations,
            k_tokens=k_tokens,
        )

        return {"sparse_activations": activations["sparse_activations"]}

    def save_pretrained(
        self,
        path: str,
    ):
        """Save model the model.

        Parameters
        ----------
        path
            Path to save the model.

        """
        self.model.save_pretrained(path)
        self.tokenizer.pad_token = self.original_pad_token

        if self.accelerate:
            self.save_tokenizer_accelerate(path)
        else:
            self.tokenizer.save_pretrained(path)

        with open(os.path.join(path, "metadata.json"), "w") as file:
            json.dump(
                fp=file,
                obj={
                    "max_length_query": self.max_length_query,
                    "max_length_document": self.max_length_document,
                },
                indent=4,
            )

        return self

    def scores(
        self,
        queries: list[str],
        documents: list[str],
        batch_size: int = 32,
        tqdm_bar: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Compute similarity scores between queries and documents.

        Parameters
        ----------
        queries
            List of queries.
        documents
            List of documents.
        batch_size
            Batch size.
        tqdm_bar
            Show a progress bar.
        """
        sparse_scores = []

        for batch_queries, batch_documents in zip(
            utils.batchify(
                X=queries,
                batch_size=batch_size,
                desc="Computing scores.",
                tqdm_bar=tqdm_bar,
            ),
            utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
        ):
            queries_embeddings = self.encode(
                batch_queries,
                query_mode=True,
                **kwargs,
            )

            documents_embeddings = self.encode(
                batch_documents,
                query_mode=False,
                **kwargs,
            )

            sparse_scores.append(
                torch.sum(
                    queries_embeddings["sparse_activations"]
                    * documents_embeddings["sparse_activations"],
                    axis=1,
                )
            )

        return torch.cat(sparse_scores, dim=0)

    def _get_activation(self, logits: torch.Tensor) -> dict[str, torch.Tensor]:
        """Returns activated tokens."""
        return {"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1)}

    def _filter_activations(
        self, sparse_activations: torch.Tensor, k_tokens: int
    ) -> list[torch.Tensor]:
        """Among the set of activations, select the ones with a score > 0."""
        scores, activations = torch.topk(input=sparse_activations, k=k_tokens, dim=-1)
        return [
            torch.index_select(
                activation, dim=-1, index=torch.nonzero(score, as_tuple=True)[0]
            )
            for score, activation in zip(scores, activations)
        ]

    def _update_activations(
        self, sparse_activations: torch.Tensor, k_tokens: int
    ) -> torch.Tensor:
        """Returns activated tokens."""
        activations = torch.topk(input=sparse_activations, k=k_tokens, dim=1).indices
        zero_tensor = torch.zeros_like(sparse_activations, dtype=int)
        updated_sparse_activations = sparse_activations * zero_tensor.scatter(
            dim=1, index=activations.long(), value=1
        )

        return {
            "activations": activations,
            "sparse_activations": updated_sparse_activations,
        }

Owner:
Splade with the call to the parent class's save_tokenizer_accelerate :)

@@ -206,7 +206,11 @@ def forward(

return {"sparse_activations": activations["sparse_activations"]}

def save_pretrained(self, path: str):
def save_pretrained(
self,
path: str,
accelerator: bool = False,
):
"""Save model the model.

Parameters
@@ -217,7 +221,32 @@ def save_pretrained(self, path: str):
"""
self.model.save_pretrained(path)
self.tokenizer.pad_token = self.original_pad_token
self.tokenizer.save_pretrained(path)
if accelerator:
# Workaround an issue with accelerator. Tokenizer has a key "device"
# which is non serialisable, but not removeable with a basic delattr

# dump config
tokenizer_config = {
k: v for k, v in self.tokenizer.__dict__.items() if k != "device"
}
tokenizer_config_file = os.path.join(path, "tokenizer_config.json")
with open(tokenizer_config_file, "w", encoding="utf-8") as file:
json.dump(tokenizer_config, file, ensure_ascii=False, indent=4)

# dump vocab
self.tokenizer.save_vocabulary(path)

# save special tokens
special_tokens_file = os.path.join(path, "special_tokens_map.json")
with open(special_tokens_file, "w", encoding="utf-8") as file:
json.dump(
self.tokenizer.special_tokens_map,
file,
ensure_ascii=False,
indent=4,
)
else:
self.tokenizer.save_pretrained(path)

with open(os.path.join(path, "metadata.json"), "w") as file:
json.dump(
@@ -306,15 +335,12 @@ def _update_activations(
) -> torch.Tensor:
"""Returns activated tokens."""
activations = torch.topk(input=sparse_activations, k=k_tokens, dim=1).indices

# Set value of max sparse_activations which are not in top k to 0.
sparse_activations = sparse_activations * torch.zeros(
(sparse_activations.shape[0], sparse_activations.shape[1]),
dtype=int,
device=self.device,
).scatter_(dim=1, index=activations.long(), value=1)
zero_tensor = torch.zeros_like(sparse_activations, dtype=int)
updated_sparse_activations = sparse_activations * zero_tensor.scatter(
dim=1, index=activations.long(), value=1
)

return {
"activations": activations,
"sparse_activations": sparse_activations,
"sparse_activations": updated_sparse_activations,
}
6 changes: 5 additions & 1 deletion neural_cherche/train/train_colbert.py
@@ -10,6 +10,7 @@ def train_colbert(
positive: list[str],
negative: list[str],
in_batch_negatives: bool = False,
accelerator=None,
**kwargs,
):
"""Compute the ranking loss and the flops loss for a single step.
@@ -98,7 +99,10 @@ def train_colbert(

loss = losses.Ranking()(**scores)

loss.backward()
if accelerator:
accelerator.backward(loss)
else:
loss.backward()
optimizer.step()
optimizer.zero_grad()

6 changes: 5 additions & 1 deletion neural_cherche/train/train_sparse_embed.py
@@ -16,6 +16,7 @@ def train_sparse_embed(
dense_loss_weight: float = 1.0,
in_batch_negatives: bool = False,
threshold_flops: float = 30,
accelerator=None,
**kwargs,
):
"""Compute the ranking loss and the flops loss for a single step.
@@ -147,7 +148,10 @@ def train_sparse_embed(
+ flops_loss_weight * flops_loss
)

loss.backward()
if accelerator:
accelerator.backward(loss)
else:
loss.backward()
optimizer.step()
optimizer.zero_grad()

6 changes: 5 additions & 1 deletion neural_cherche/train/train_splade.py
@@ -13,6 +13,7 @@ def train_splade(
sparse_loss_weight: float = 1.0,
in_batch_negatives: bool = False,
threshold_flops: float = 30,
accelerator=None,
**kwargs,
):
"""Compute the ranking loss and the flops loss for a single step.
@@ -117,7 +118,10 @@ def train_splade(

loss = sparse_loss_weight * sparse_loss + flops_loss_weight * flops_loss

loss.backward()
if accelerator:
accelerator.backward(loss)
else:
loss.backward()
optimizer.step()
optimizer.zero_grad()
