chore: Support multi-GPU training via accelerate #5

Open · wants to merge 6 commits into main
1 change: 1 addition & 0 deletions docs/api/overview.md
```diff
@@ -27,6 +27,7 @@
 - [train_colbert](../train/train-colbert)
 - [train_sparse_embed](../train/train-sparse-embed)
 - [train_splade](../train/train-splade)
+- [Multi-GPU training via Accelerator](../train/multi-gpu)

 ## utils
```
1 change: 1 addition & 0 deletions docs/fine_tune/.pages
```diff
@@ -3,4 +3,5 @@ nav:
 - colbert.md
 - splade.md
 - sparse_embed.md
+- multi_gpu.md
```
72 changes: 72 additions & 0 deletions docs/fine_tune/multi_gpu.md

# Multi-GPU

Neural-Cherche offers partial support for multi-GPU training via [Accelerator](https://huggingface.co/docs/accelerate/package_reference/accelerator): every neural-cherche model can be trained on multiple GPUs in most circumstances, although support is not yet complete. Here is a tutorial.

```python
import torch
from accelerate import Accelerator
from datasets import Dataset
from torch.utils.data import DataLoader

from neural_cherche import models, train

if __name__ == "__main__":
    # Guard the script with __main__ so accelerate's worker processes do not re-run it on import.
    accelerator = Accelerator()
    save_each_epoch = True

    model = models.SparseEmbed(
        model_name_or_path="distilbert-base-uncased",
        accelerate=True,
        device=accelerator.device,
    ).to(accelerator.device)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

    # Dataset creation using HuggingFace Datasets library.
    dataset = Dataset.from_dict(
        {
            "anchors": ["anchor 1", "anchor 2", "anchor 3", "anchor 4"],
            "positives": ["positive 1", "positive 2", "positive 3", "positive 4"],
            "negatives": ["negative 1", "negative 2", "negative 3", "negative 4"],
        }
    )

    # Convert your dataset to a DataLoader.
    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Wrap model, optimizer, and dataloader in accelerator.
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

    for epoch in range(2):
        for batch in data_loader:
            # Each batch is a dict with "anchors", "positives" and "negatives" keys.
            anchors, positives, negatives = (
                batch["anchors"],
                batch["positives"],
                batch["negatives"],
            )

            loss = train.train_sparse_embed(
                model=model,
                optimizer=optimizer,
                anchor=anchors,
                positive=positives,
                negative=negatives,
                threshold_flops=30,
                accelerator=accelerator,
            )

        if accelerator.is_main_process and save_each_epoch:
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                "checkpoint/epoch" + str(epoch),
            )

    # Save at the end of the training loop
    # We check to make sure that only the main process will export the model
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained("checkpoint")

Owner: I added a clear example of how to create the dataset using the HuggingFace Datasets library.
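
To run the tutorial on several GPUs, the script above is typically started once per process, for example with the `accelerate launch` CLI. A minimal in-Python alternative is sketched below, assuming the training loop is wrapped in a `main()` function; the function name and `num_processes=2` are illustrative, not part of neural-cherche:

```python
from accelerate import notebook_launcher


def main():
    # The Accelerator() training loop from the tutorial above goes here.
    ...


if __name__ == "__main__":
    # Spawn one process per GPU; set num_processes to the number of GPUs.
    notebook_launcher(main, args=(), num_processes=2)
```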

58 changes: 58 additions & 0 deletions neural_cherche/models/base.py
Owner: I got some trouble with the position_ids extra parameter with the DistilBERT pre-trained checkpoint, but not with the all-mpnet-base-v2 checkpoint, so I think it would be cool to keep the legacy code and add an accelerate attribute to the models.


```python
import json
import os
from abc import ABC, abstractmethod

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForMaskedLM, AutoTokenizer


class Base(ABC, torch.nn.Module):
    """Base class from which all models inherit.

    Parameters
    ----------
    model_name_or_path
        Path to the model or the model name.
    device
        Device to use for the model. CPU or CUDA.
    extra_files_to_load
        List of extra files to load.
    accelerate
        Use HuggingFace Accelerate.
    kwargs
        Additional parameters to the model.
    """

    def __init__(
        self,
        model_name_or_path: str,
        device: str = None,
        extra_files_to_load: list[str] | None = None,
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        """Initialize the model."""
        super(Base, self).__init__()

        if device is not None:
            self.device = device

        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        self.accelerate = accelerate

        os.environ["TRANSFORMERS_CACHE"] = "."
        self.model = AutoModelForMaskedLM.from_pretrained(
            model_name_or_path, cache_dir="./", **kwargs
        ).to(self.device)

        # Download the linear layer if it exists on the HuggingFace Hub.
        for file in extra_files_to_load or []:
            try:
                _ = hf_hub_download(model_name_or_path, filename=file, cache_dir=".")
            except Exception:
                # The extra file is optional; skip checkpoints that do not ship it.
                pass

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, device=self.device, cache_dir="./", **kwargs
        )

        self.model.config.output_hidden_states = True

        if os.path.exists(model_name_or_path):
            # Local checkpoint
            self.model_folder = model_name_or_path
        else:
            # HuggingFace checkpoint
            model_folder = os.path.join(
                f"models--{model_name_or_path}".replace("/", "--"), "snapshots"
            )
            snapshot = os.listdir(model_folder)[-1]
            self.model_folder = os.path.join(model_folder, snapshot)

        self.query_pad_token = self.tokenizer.mask_token
        self.original_pad_token = self.tokenizer.pad_token

    def _encode_accelerate(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode sentences with multiples gpus.

        Parameters
        ----------
        texts
            List of sentences to encode.

        References
        ----------
        [Accelerate issue.](https://github.com/huggingface/accelerate/issues/97)
        """
        encoded_input = self.tokenizer(texts, return_tensors="pt", **kwargs).to(
            self.device
        )

        position_ids = (
            torch.arange(0, encoded_input["input_ids"].size(1))
            .expand((len(texts), -1))
            .to(self.device)
        )

        output = self.model(**encoded_input, position_ids=position_ids)
        return output.logits, output.hidden_states[-1]

    def _encode(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode sentences.

        Parameters
        ----------
        texts
            List of sentences to encode.
        """
        if self.accelerate:
            return self._encode_accelerate(texts, **kwargs)

        encoded_input = self.tokenizer.batch_encode_plus(
            texts, return_tensors="pt", **kwargs
        )

        if self.device != "cpu":
            encoded_input = {
                key: value.to(self.device) for key, value in encoded_input.items()
            }

        output = self.model(**encoded_input)
        return output.logits, output.hidden_states[-1]

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Pytorch forward method."""
        pass

    @abstractmethod
    def encode(self, *args, **kwargs):
        """Encode documents."""
        pass

    @abstractmethod
    def scores(self, *args, **kwargs):
        """Compute scores."""
        pass

    @abstractmethod
    def save_pretrained(self, path: str):
        """Save model the model."""
        pass

    def save_tokenizer_accelerate(self, path: str) -> None:
        """Save tokenizer when using accelerate."""
        tokenizer_config = {
            k: v for k, v in self.tokenizer.__dict__.items() if k != "device"
        }
        tokenizer_config_file = os.path.join(path, "tokenizer_config.json")
        with open(tokenizer_config_file, "w", encoding="utf-8") as file:
            json.dump(tokenizer_config, file, ensure_ascii=False, indent=4)

        # dump vocab
        self.tokenizer.save_vocabulary(path)

        # save special tokens
        special_tokens_file = os.path.join(path, "special_tokens_map.json")
        with open(special_tokens_file, "w", encoding="utf-8") as file:
            json.dump(
                self.tokenizer.special_tokens_map,
                file,
                ensure_ascii=False,
                indent=4,
            )
```

Owner: Here is the base class, updated with the new save_tokenizer_accelerate method and an accelerate attribute.
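
Since save_tokenizer_accelerate writes the standard tokenizer files (tokenizer_config.json, the vocabulary, and special_tokens_map.json), a checkpoint exported this way should reload through the usual HuggingFace entry point. A minimal sketch, assuming a checkpoint was saved to "checkpoint" as in the tutorial:

```python
from transformers import AutoTokenizer

# Reload a tokenizer exported by save_tokenizer_accelerate; the files it
# writes follow the standard HuggingFace tokenizer layout.
tokenizer = AutoTokenizer.from_pretrained("checkpoint")
```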

16 changes: 11 additions & 5 deletions neural_cherche/models/colbert.py

```python
import json
import os

import torch

from .. import utils
from .base import Base

__all__ = ["ColBERT"]


class ColBERT(Base):
    """ColBERT model.

    Parameters
    ----------
    model_name_or_path
        Path to the model or the model name.
    embedding_size
        Size of the embeddings in output of ColBERT model.
    device
        Device to use for the model. CPU or CUDA.
    accelerate
        Use HuggingFace Accelerate.
    kwargs
        Additional parameters passed to the pre-trained model.

    Examples
    --------
    >>> from neural_cherche import models
    >>> import torch

    >>> _ = torch.manual_seed(42)

    >>> queries = ["Berlin", "Paris", "London"]

    >>> documents = [
    ...     "Berlin is the capital of Germany",
    ...     "Paris is the capital of France and France is in Europe",
    ...     "London is the capital of England",
    ... ]

    >>> encoder = models.ColBERT(
    ...     model_name_or_path="sentence-transformers/all-mpnet-base-v2",
    ...     embedding_size=128,
    ...     max_length_query=32,
    ...     max_length_document=350,
    ... )

    >>> scores = encoder.scores(
    ...    queries=queries,
    ...    documents=documents,
    ... )

    >>> scores
    tensor([22.9325, 19.8296, 20.8019])

    >>> _ = encoder.save_pretrained("checkpoint", accelerate=False)

    >>> encoder = models.ColBERT(
    ...     model_name_or_path="checkpoint",
    ...     embedding_size=64,
    ...     device="cpu",
    ... )

    >>> scores = encoder.scores(
    ...    queries=queries,
    ...    documents=documents,
    ... )

    >>> scores
    tensor([22.9325, 19.8296, 20.8019])

    >>> embeddings = encoder(
    ...     texts=queries,
    ...     query_mode=True
    ... )

    >>> embeddings["embeddings"].shape
    torch.Size([3, 32, 128])

    >>> embeddings = encoder(
    ...     texts=queries,
    ...     query_mode=False
    ... )

    >>> embeddings["embeddings"].shape
    torch.Size([3, 350, 128])

    """

    def __init__(
        self,
        model_name_or_path: str,
        embedding_size: int = 128,
        device: str = None,
        max_length_query: int = 32,
        max_length_document: int = 350,
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        """Initialize the model."""
        super(ColBERT, self).__init__(
            model_name_or_path=model_name_or_path,
            device=device,
            extra_files_to_load=["linear.pt", "metadata.json"],
            accelerate=accelerate,
            **kwargs,
        )

        self.max_length_query = max_length_query
        self.max_length_document = max_length_document
        self.embedding_size = embedding_size

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            linear = torch.load(
                os.path.join(self.model_folder, "linear.pt"), map_location=self.device
            )
            self.embedding_size = linear["weight"].shape[0]
            in_features = linear["weight"].shape[1]
        else:
            with torch.no_grad():
                _, embeddings = self._encode(texts=["test"])
                in_features = embeddings.shape[2]

        self.linear = torch.nn.Linear(
            in_features=in_features,
            out_features=self.embedding_size,
            bias=False,
            device=self.device,
        )

        if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
            with open(os.path.join(self.model_folder, "metadata.json"), "r") as f:
                metadata = json.load(f)
            self.max_length_document = metadata["max_length_document"]
            self.max_length_query = metadata["max_length_query"]

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            self.linear.load_state_dict(linear)

    def encode(
        self,
        texts: list[str],
        truncation: bool = True,
        add_special_tokens: bool = False,
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Encode documents

        Parameters
        ----------
        texts
            List of sentences to encode.
        truncation
            Truncate the inputs.
        add_special_tokens
            Add special tokens.
        max_length
            Maximum length of the inputs.
        """
        with torch.no_grad():
            embeddings = self(
                texts=texts,
                truncation=truncation,
                add_special_tokens=add_special_tokens,
                query_mode=query_mode,
                **kwargs,
            )
        return embeddings

    def forward(
        self,
        texts: list[str],
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Pytorch forward method.

        Parameters
        ----------
        texts
            List of sentences to encode.
        query_mode
            Whether to encode queries or documents.
        """
        suffix = "[Q] " if query_mode else "[D] "

        texts = [suffix + text for text in texts]

        self.tokenizer.pad_token = (
            self.query_pad_token if query_mode else self.original_pad_token
        )

        kwargs = {
            "truncation": True,
            "padding": "max_length",
            "max_length": self.max_length_query
            if query_mode
            else self.max_length_document,
            "add_special_tokens": True,
            **kwargs,
        }

        _, embeddings = self._encode(texts=texts, **kwargs)

        return {
            "embeddings": torch.nn.functional.normalize(
                self.linear(embeddings), p=2, dim=2
            )
        }

    def scores(
        self,
        queries: list[str],
        documents: list[str],
        batch_size: int = 2,
        tqdm_bar: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Score queries and documents.

        Parameters
        ----------
        queries
            List of queries.
        documents
            List of documents.
        batch_size
            Batch size.
        truncation
            Truncate the inputs.
        add_special_tokens
            Add special tokens.
        tqdm_bar
            Show tqdm bar.
        """
        list_scores = []

        for batch_queries, batch_documents in zip(
            utils.batchify(
                X=queries,
                batch_size=batch_size,
                desc="Computing scores.",
                tqdm_bar=tqdm_bar,
            ),
            utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
        ):
            queries_embeddings = self.encode(
                texts=batch_queries,
                query_mode=True,
                **kwargs,
            )

            documents_embeddings = self.encode(
                texts=batch_documents,
                query_mode=False,
                **kwargs,
            )

            late_interactions = torch.einsum(
                "bsh,bth->bst",
                queries_embeddings["embeddings"],
                documents_embeddings["embeddings"],
            )

            late_interactions = torch.max(late_interactions, axis=2).values.sum(axis=1)

            list_scores.append(late_interactions)

        return torch.cat(list_scores, dim=0)

    def save_pretrained(self, path: str) -> "ColBERT":
        """Save model the model.

        Parameters
        ----------
        path
            Path to save the model.
        """
        self.model.save_pretrained(path)
        torch.save(self.linear.state_dict(), os.path.join(path, "linear.pt"))
        self.tokenizer.pad_token = self.original_pad_token
        with open(os.path.join(path, "metadata.json"), "w") as f:
            json.dump(
                {
                    "max_length_query": self.max_length_query,
                    "max_length_document": self.max_length_document,
                },
                f,
            )
        if self.accelerate:
            self.save_tokenizer_accelerate(path=path)
        else:
            self.tokenizer.save_pretrained(path)
        return self
```
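
The einsum in scores implements ColBERT's MaxSim late interaction: every query token is matched against its best document token, and the per-token maxima are summed. A minimal sketch with random tensors (the shapes are illustrative):

```python
import torch

batch_size, query_len, doc_len, dim = 2, 32, 350, 128
queries_embeddings = torch.randn(batch_size, query_len, dim)
documents_embeddings = torch.randn(batch_size, doc_len, dim)

# Pairwise token similarities, shape (batch, query_token, document_token).
late_interactions = torch.einsum(
    "bsh,bth->bst", queries_embeddings, documents_embeddings
)

# MaxSim: best document token per query token, summed over query tokens.
scores = torch.max(late_interactions, dim=2).values.sum(dim=1)
print(scores.shape)  # torch.Size([2])
```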

Owner: ColBERT with the call to the save_tokenizer_accelerate parent-class method :)

13 changes: 11 additions & 2 deletions neural_cherche/models/sparse_embed.py

```python
import json
import os

import torch

from .. import utils

__all__ = ["SparseEmbed"]

from .splade import Splade


class SparseEmbed(Splade):
    """SparseEmbed model.

    Parameters
    ----------
    model_name_or_path
        Path to the model or the model name. It should be a SentenceTransformer model.
    embedding_size
        Size of the embeddings in output of the SparseEmbed model.
    kwargs
        Additional parameters to the pre-trained model.

    Examples
    --------
    >>> from neural_cherche import models
    >>> import torch

    >>> _ = torch.manual_seed(42)

    >>> device = "mps"

    >>> model = models.SparseEmbed(
    ...     model_name_or_path="distilbert-base-uncased",
    ...     device=device,
    ... )

    >>> queries_embeddings = model.encode(
    ...     ["Sports", "Music"],
    ... )

    >>> queries_embeddings["activations"].shape
    torch.Size([2, 128])

    >>> queries_embeddings["sparse_activations"].shape
    torch.Size([2, 30522])

    >>> queries_embeddings["embeddings"].shape
    torch.Size([2, 128, 128])

    >>> documents_embeddings = model.encode(
    ...    ["Music is great.", "Sports is great."],
    ...    query_mode=False,
    ... )

    >>> documents_embeddings["activations"].shape
    torch.Size([2, 256])

    >>> documents_embeddings["sparse_activations"].shape
    torch.Size([2, 30522])

    >>> documents_embeddings["embeddings"].shape
    torch.Size([2, 256, 128])

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=1,
    ... )
    tensor([64.2330, 54.0180], device='mps:0')

    >>> _ = model.save_pretrained("checkpoint")

    >>> model = models.SparseEmbed(
    ...     model_name_or_path="checkpoint",
    ...     device="cpu",
    ... )

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=2,
    ... )
    tensor([64.2330, 54.0180])

    References
    ----------
    1. [SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval](https://dl.acm.org/doi/pdf/10.1145/3539618.3592065)

    """

    def __init__(
        self,
        model_name_or_path: str = None,
        embedding_size: int = 128,
        max_length_query: int = 128,
        max_length_document: int = 256,
        device: str = None,
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        super(SparseEmbed, self).__init__(
            model_name_or_path=model_name_or_path,
            device=device,
            extra_files_to_load=["linear.pt", "metadata.json"],
            accelerate=accelerate,
            **kwargs,
        )

        self.embedding_size = embedding_size

        self.softmax = torch.nn.Softmax(dim=2).to(self.device)

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            linear = torch.load(
                os.path.join(self.model_folder, "linear.pt"), map_location=self.device
            )
            self.embedding_size = linear["weight"].shape[0]
            in_features = linear["weight"].shape[1]
        else:
            with torch.no_grad():
                _, embeddings = self._encode(texts=["test"])
                in_features = embeddings.shape[2]

        self.linear = torch.nn.Linear(
            in_features=in_features,
            out_features=self.embedding_size,
            bias=False,
            device=self.device,
        )

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            self.linear.load_state_dict(linear)

        if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
            with open(os.path.join(self.model_folder, "metadata.json"), "r") as file:
                metadata = json.load(file)

            max_length_query = metadata["max_length_query"]
            max_length_document = metadata["max_length_document"]

        self.max_length_query = max_length_query
        self.max_length_document = max_length_document

    def forward(
        self,
        texts: list[str],
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Pytorch forward method.

        Parameters
        ----------
        texts
            List of documents to encode.
        query_mode
            Whether to encode queries or documents.
        """
        suffix = "[Q] " if query_mode else "[D] "

        texts = [suffix + text for text in texts]

        self.tokenizer.pad_token = (
            self.query_pad_token if query_mode else self.original_pad_token
        )

        k_tokens = self.max_length_query if query_mode else self.max_length_document

        logits, embeddings = self._encode(
            texts=texts,
            truncation=True,
            padding="max_length",
            max_length=k_tokens,
            add_special_tokens=True,
            **kwargs,
        )

        activations = self._update_activations(
            **self._get_activation(logits=logits),
            k_tokens=k_tokens,
        )

        attention = self._get_attention(
            logits=logits,
            activations=activations["activations"],
        )

        embeddings = torch.bmm(
            attention,
            embeddings,
        )

        return {
            "embeddings": self.relu(self.linear(embeddings)),
            "sparse_activations": activations["sparse_activations"],
            "activations": activations["activations"],
        }

    def _get_attention(
        self, logits: torch.Tensor, activations: torch.Tensor
    ) -> torch.Tensor:
        """Extract attention scores from MLM logits based on activated tokens."""
        attention = logits.gather(
            dim=2,
            index=torch.stack(
                [
                    torch.stack([token for _ in range(logits.shape[1])])
                    for token in activations
                ]
            ),
        )

        return self.softmax(attention)

    def save_pretrained(
        self,
        path: str,
    ):
        """Save model the model."""
        self.model.save_pretrained(path)
        self.tokenizer.pad_token = self.original_pad_token

        if self.accelerate:
            self.save_tokenizer_accelerate(path)
        else:
            self.tokenizer.save_pretrained(path)
        torch.save(self.linear.state_dict(), os.path.join(path, "linear.pt"))
        with open(os.path.join(path, "metadata.json"), "w") as file:
            json.dump(
                fp=file,
                obj={
                    "max_length_query": self.max_length_query,
                    "max_length_document": self.max_length_document,
                },
                indent=4,
            )

        return self

    def scores(
        self,
        queries: list[str],
        documents: list[str],
        batch_size: int = 32,
        tqdm_bar: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Compute similarity scores between queries and documents."""
        dense_scores = []

        for batch_queries, batch_documents in zip(
            utils.batchify(
                X=queries,
                batch_size=batch_size,
                desc="Computing scores.",
                tqdm_bar=tqdm_bar,
            ),
            utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
        ):
            queries_embeddings = self.encode(
                texts=batch_queries,
                query_mode=True,
                **kwargs,
            )

            documents_embeddings = self.encode(
                texts=batch_documents,
                query_mode=False,
                **kwargs,
            )

            dense_scores.append(
                utils.pairs_dense_scores(
                    queries_activations=queries_embeddings["activations"],
                    documents_activations=documents_embeddings["activations"],
                    queries_embeddings=queries_embeddings["embeddings"],
                    documents_embeddings=documents_embeddings["embeddings"],
                )
            )

        return torch.cat(dense_scores, dim=0)
```

Owner: SparseEmbed with the call to the save_tokenizer_accelerate parent-class method :)

26 changes: 16 additions & 10 deletions neural_cherche/models/splade.py

```python
import json
import os
import string

import torch

from .. import utils
from .base import Base

__all__ = ["Splade"]


class Splade(Base):
    """SpladeV1 model.

    Parameters
    ----------
    model_name_or_path
        Path to the model or the model name.
    device
        Device to use for the model. CPU or CUDA.
    max_length_query
        Maximum length of the queries.
    max_length_document
        Maximum length of the documents.
    accelerate
        Use HuggingFace Accelerate.
    kwargs
        Additional parameters passed to the pre-trained model.

    Examples
    --------
    >>> from neural_cherche import models
    >>> import torch

    >>> _ = torch.manual_seed(42)

    >>> model = models.Splade(
    ...     model_name_or_path="distilbert-base-uncased",
    ...     device="mps",
    ... )

    >>> queries_activations = model.encode(
    ...     ["Sports", "Music"],
    ... )

    >>> documents_activations = model.encode(
    ...    ["Music is great.", "Sports is great."],
    ...    query_mode=False,
    ... )

    >>> queries_activations["sparse_activations"].shape
    torch.Size([2, 30522])

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=1
    ... )
    tensor([318.1384, 271.8006], device='mps:0')

    >>> _ = model.save_pretrained("checkpoint")

    >>> model = models.Splade(
    ...     model_name_or_path="checkpoint",
    ...     device="mps",
    ... )

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=1
    ... )
    tensor([318.1384, 271.8006], device='mps:0')

    References
    ----------
    1. [SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking](https://arxiv.org/abs/2107.05720)

    """

    def __init__(
        self,
        model_name_or_path: str = None,
        device: str = None,
        max_length_query: int = 128,
        max_length_document: int = 256,
        extra_files_to_load: list[str] = ["metadata.json"],
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        super(Splade, self).__init__(
            model_name_or_path=model_name_or_path,
            device=device,
            extra_files_to_load=extra_files_to_load,
            accelerate=accelerate,
            **kwargs,
        )

        self.relu = torch.nn.ReLU().to(self.device)

        if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
            with open(os.path.join(self.model_folder, "metadata.json"), "r") as file:
                metadata = json.load(file)

            max_length_query = metadata["max_length_query"]
            max_length_document = metadata["max_length_document"]

        self.max_length_query = max_length_query
        self.max_length_document = max_length_document

    def encode(
        self,
        texts: list[str],
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Encode documents

        Parameters
        ----------
        texts
            List of documents to encode.
        truncation
            Whether to truncate the documents.
        padding
            Whether to pad the documents.
        max_length
            Maximum length of the documents.
        """
        with torch.no_grad():
            return self(
                texts=texts,
                query_mode=query_mode,
                **kwargs,
            )

    def decode(
        self,
        sparse_activations: torch.Tensor,
        clean_up_tokenization_spaces: bool = False,
        skip_special_tokens: bool = True,
        k_tokens: int = 96,
    ) -> list[str]:
        """Decode activated tokens ids where activated value > 0.

        Parameters
        ----------
        sparse_activations
            Activated tokens.
        clean_up_tokenization_spaces
            Whether to clean up the tokenization spaces.
        skip_special_tokens
            Whether to skip special tokens.
        k_tokens
            Number of tokens to keep.
        """
        activations = self._filter_activations(
            sparse_activations=sparse_activations, k_tokens=k_tokens
        )

        # Decode
        return [
            " ".join(
                activation.translate(str.maketrans("", "", string.punctuation)).split()
            )
            for activation in self.tokenizer.batch_decode(
                activations,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                skip_special_tokens=skip_special_tokens,
            )
        ]

    def forward(
        self,
        texts: list[str],
        query_mode: bool,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Pytorch forward method.

        Parameters
        ----------
        texts
            List of documents to encode.
        query_mode
            Whether to encode queries or documents.
        """
        suffix = "[Q] " if query_mode else "[D] "

        texts = [suffix + text for text in texts]

        self.tokenizer.pad_token = (
            self.query_pad_token if query_mode else self.original_pad_token
        )

        k_tokens = self.max_length_query if query_mode else self.max_length_document

        logits, _ = self._encode(
            texts=texts,
            truncation=True,
            padding="max_length",
            max_length=k_tokens,
            add_special_tokens=True,
            **kwargs,
        )

        activations = self._get_activation(logits=logits)

        activations = self._update_activations(
            **activations,
            k_tokens=k_tokens,
        )

        return {"sparse_activations": activations["sparse_activations"]}

    def save_pretrained(
        self,
        path: str,
    ):
        """Save model the model.

        Parameters
        ----------
        path
            Path to save the model.

        """
        self.model.save_pretrained(path)
        self.tokenizer.pad_token = self.original_pad_token

        if self.accelerate:
            self.save_tokenizer_accelerate(path)
        else:
            self.tokenizer.save_pretrained(path)

        with open(os.path.join(path, "metadata.json"), "w") as file:
            json.dump(
                fp=file,
                obj={
                    "max_length_query": self.max_length_query,
                    "max_length_document": self.max_length_document,
                },
                indent=4,
            )

        return self

    def scores(
        self,
        queries: list[str],
        documents: list[str],
        batch_size: int = 32,
        tqdm_bar: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Compute similarity scores between queries and documents.

        Parameters
        ----------
        queries
            List of queries.
        documents
            List of documents.
        batch_size
            Batch size.
        tqdm_bar
            Show a progress bar.
        """
        sparse_scores = []

        for batch_queries, batch_documents in zip(
            utils.batchify(
                X=queries,
                batch_size=batch_size,
                desc="Computing scores.",
                tqdm_bar=tqdm_bar,
            ),
            utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
        ):
            queries_embeddings = self.encode(
                batch_queries,
                query_mode=True,
                **kwargs,
            )

            documents_embeddings = self.encode(
                batch_documents,
                query_mode=False,
                **kwargs,
            )

            sparse_scores.append(
                torch.sum(
                    queries_embeddings["sparse_activations"]
                    * documents_embeddings["sparse_activations"],
                    axis=1,
                )
            )

        return torch.cat(sparse_scores, dim=0)

    def _get_activation(self, logits: torch.Tensor) -> dict[str, torch.Tensor]:
        """Returns activated tokens."""
        return {"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1)}

    def _filter_activations(
        self, sparse_activations: torch.Tensor, k_tokens: int
    ) -> list[torch.Tensor]:
        """Among the set of activations, select the ones with a score > 0."""
        scores, activations = torch.topk(input=sparse_activations, k=k_tokens, dim=-1)
        return [
            torch.index_select(
                activation, dim=-1, index=torch.nonzero(score, as_tuple=True)[0]
            )
            for score, activation in zip(scores, activations)
        ]

    def _update_activations(
        self, sparse_activations: torch.Tensor, k_tokens: int
    ) -> torch.Tensor:
        """Returns activated tokens."""
        activations = torch.topk(input=sparse_activations, k=k_tokens, dim=1).indices
        zero_tensor = torch.zeros_like(sparse_activations, dtype=int)
        updated_sparse_activations = sparse_activations * zero_tensor.scatter(
            dim=1, index=activations.long(), value=1
        )

        return {
            "activations": activations,
            "sparse_activations": updated_sparse_activations,
        }
```

Owner: Splade with the call to the save_tokenizer_accelerate parent-class method :)
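
The core of the model is _get_activation together with the sparse dot product in scores: logits are turned into one non-negative weight per vocabulary term by max-pooling log(1 + ReLU(x)) over the sequence, and a query-document pair is scored as the dot product of the two sparse vectors. A minimal sketch with random logits (the shapes are illustrative):

```python
import torch

batch_size, seq_len, vocab_size = 2, 128, 30522
query_logits = torch.randn(batch_size, seq_len, vocab_size)
document_logits = torch.randn(batch_size, seq_len, vocab_size)

relu = torch.nn.ReLU()

# One weight per vocabulary term, max-pooled over sequence positions.
query_activations = torch.amax(torch.log1p(relu(query_logits)), dim=1)
document_activations = torch.amax(torch.log1p(relu(document_logits)), dim=1)

# Score each query-document pair as a sparse dot product.
scores = torch.sum(query_activations * document_activations, dim=1)
print(scores.shape)  # torch.Size([2])
```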

6 changes: 5 additions & 1 deletion neural_cherche/train/train_colbert.py
```diff
@@ -10,6 +10,7 @@ def train_colbert(
     positive: list[str],
     negative: list[str],
     in_batch_negatives: bool = False,
+    accelerator=None,
     **kwargs,
 ):
     """Compute the ranking loss and the flops loss for a single step.
@@ -98,7 +99,10 @@

     loss = losses.Ranking()(**scores)

-    loss.backward()
+    if accelerator:
+        accelerator.backward(loss)
+    else:
+        loss.backward()
     optimizer.step()
     optimizer.zero_grad()
```
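
For reference, a minimal sketch of one multi-GPU ColBERT step built on this diff; it assumes train_colbert keeps the same model/optimizer/anchor/positive/negative keyword interface as train_sparse_embed in the tutorial above:

```python
import torch
from accelerate import Accelerator

from neural_cherche import models, train

accelerator = Accelerator()

model = models.ColBERT(
    model_name_or_path="sentence-transformers/all-mpnet-base-v2",
    accelerate=True,
    device=accelerator.device,
).to(accelerator.device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
model, optimizer = accelerator.prepare(model, optimizer)

# One training step on a toy triple; accelerator.backward drives the
# backward pass instead of loss.backward.
loss = train.train_colbert(
    model=model,
    optimizer=optimizer,
    anchor=["anchor 1"],
    positive=["positive 1"],
    negative=["negative 1"],
    accelerator=accelerator,
)
```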
6 changes: 5 additions & 1 deletion neural_cherche/train/train_sparse_embed.py
```diff
@@ -16,6 +16,7 @@ def train_sparse_embed(
     dense_loss_weight: float = 1.0,
     in_batch_negatives: bool = False,
     threshold_flops: float = 30,
+    accelerator=None,
     **kwargs,
 ):
     """Compute the ranking loss and the flops loss for a single step.
@@ -147,7 +148,10 @@
         + flops_loss_weight * flops_loss
     )

-    loss.backward()
+    if accelerator:
+        accelerator.backward(loss)
+    else:
+        loss.backward()
     optimizer.step()
     optimizer.zero_grad()
```