From b3531ce899b0a11e563a885d7fcfe3a798e8fe05 Mon Sep 17 00:00:00 2001
From: Ferdinand Schlatt
Date: Wed, 17 Jul 2024 15:13:48 +0200
Subject: [PATCH] black

---
 lightning_ir/bi_encoder/config.py | 16 ++-----
 lightning_ir/bi_encoder/model.py  | 74 ++++++++-----------------------
 2 files changed, 22 insertions(+), 68 deletions(-)

diff --git a/lightning_ir/bi_encoder/config.py b/lightning_ir/bi_encoder/config.py
index 0095da6..0b4f7cd 100644
--- a/lightning_ir/bi_encoder/config.py
+++ b/lightning_ir/bi_encoder/config.py
@@ -44,9 +44,7 @@ def __init__(
         attend_to_query_expanded_tokens: bool = False,
         query_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
         query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
-        query_aggregation_function: Literal[
-            "sum", "mean", "max", "harmonic_mean"
-        ] = "sum",
+        query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
         doc_expansion: bool = False,
         attend_to_doc_expanded_tokens: bool = False,
         doc_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
@@ -116,9 +114,7 @@ def to_dict(self) -> Dict[str, Any]:
             output.pop("doc_mask_scoring_tokens")
         return output

-    def save_pretrained(
-        self, save_directory: str | PathLike, push_to_hub: bool = False, **kwargs
-    ):
+    def save_pretrained(self, save_directory: str | PathLike, push_to_hub: bool = False, **kwargs):
         with open(os.path.join(save_directory, "mask_scoring_tokens.json"), "w") as f:
             json.dump(
                 {
@@ -133,13 +129,9 @@ def save_pretrained(
     def get_config_dict(
         cls, pretrained_model_name_or_path: str | PathLike, **kwargs
     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        config_dict, kwargs = super().get_config_dict(
-            pretrained_model_name_or_path, **kwargs
-        )
+        config_dict, kwargs = super().get_config_dict(pretrained_model_name_or_path, **kwargs)
         mask_scoring_tokens = None
-        mask_scoring_tokens_path = os.path.join(
-            pretrained_model_name_or_path, "mask_scoring_tokens.json"
-        )
+        mask_scoring_tokens_path = os.path.join(pretrained_model_name_or_path, "mask_scoring_tokens.json")
         if os.path.exists(mask_scoring_tokens_path):
             with open(mask_scoring_tokens_path) as f:
                 mask_scoring_tokens = json.load(f)
diff --git a/lightning_ir/bi_encoder/model.py b/lightning_ir/bi_encoder/model.py
index e44f796..9607939 100644
--- a/lightning_ir/bi_encoder/model.py
+++ b/lightning_ir/bi_encoder/model.py
@@ -47,22 +47,15 @@ def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
         self.query_mask_scoring_input_ids: torch.Tensor | None = None
         self.doc_mask_scoring_input_ids: torch.Tensor | None = None
         for sequence in ("query", "doc"):
-            mask_scoring_tokens = getattr(
-                self.config, f"{sequence}_mask_scoring_tokens"
-            )
+            mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens")
             if mask_scoring_tokens is None:
                 continue
             if mask_scoring_tokens == "punctuation":
                 mask_scoring_tokens = list(punctuation)
             try:
-                tokenizer = self.config.__class__.tokenizer_class.from_pretrained(
-                    self.config.name_or_path
-                )
+                tokenizer = self.config.__class__.tokenizer_class.from_pretrained(self.config.name_or_path)
             except OSError:
-                raise ValueError(
-                    "Can't use token scoring masking if the checkpoint does not "
-                    "have a tokenizer."
-                )
+                raise ValueError("Can't use token scoring masking if the checkpoint does not " "have a tokenizer.")
             setattr(
                 self,
                 f"{sequence}_mask_scoring_input_ids",
@@ -140,9 +133,7 @@ def _encode(
         pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None,
         mask_scoring_input_ids: torch.Tensor | None = None,
     ) -> BiEncoderEmbedding:
-        embeddings = self.backbone_forward(
-            input_ids, attention_mask, token_type_ids
-        ).last_hidden_state
+        embeddings = self.backbone_forward(input_ids, attention_mask, token_type_ids).last_hidden_state
         if self.projection is not None:
             embeddings = self.projection(embeddings)
         embeddings = self.sparsification(embeddings, self.config.sparsification)
@@ -158,9 +149,7 @@ def _encode(
         )
         return BiEncoderEmbedding(embeddings, scoring_mask)

-    def query_scoring_mask(
-        self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None
-    ) -> torch.Tensor:
+    def query_scoring_mask(self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None) -> torch.Tensor:
         return self._scoring_mask(
             input_ids,
             attention_mask,
@@ -169,9 +158,7 @@ def query_scoring_mask(
             mask_scoring_input_ids=self.config.query_mask_scoring_input_ids,
         )

-    def doc_scoring_mask(
-        self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None
-    ) -> torch.Tensor:
+    def doc_scoring_mask(self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None) -> torch.Tensor:
         return self._scoring_mask(
             input_ids,
             attention_mask,
@@ -203,9 +190,7 @@ def _scoring_mask(
             scoring_mask = torch.ones(shape, dtype=torch.bool, device=device)
         scoring_mask = scoring_mask.bool()
         if mask_scoring_input_ids is not None and input_ids is not None:
-            ignore_mask = (
-                input_ids[..., None].eq(mask_scoring_input_ids.to(device)).any(-1)
-            )
+            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(device)).any(-1)
             scoring_mask = scoring_mask & ~ignore_mask
         return scoring_mask

@@ -215,9 +200,7 @@ def score(
         doc_embeddings: BiEncoderEmbedding,
         num_docs: Sequence[int] | int | None = None,
     ) -> torch.Tensor:
-        scores = self.scoring_function.score(
-            query_embeddings, doc_embeddings, num_docs=num_docs
-        )
+        scores = self.scoring_function.score(query_embeddings, doc_embeddings, num_docs=num_docs)
         return scores


@@ -230,13 +213,9 @@ def batch(
     def batch_similarity_function(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         if x.shape[0] <= BATCH_SIZE:
             return similarity_function(x, y)
-        out = torch.zeros(
-            x.shape[0], x.shape[1], y.shape[2], device=x.device, dtype=x.dtype
-        )
+        out = torch.zeros(x.shape[0], x.shape[1], y.shape[2], device=x.device, dtype=x.dtype)
         for i in range(0, x.shape[0], BATCH_SIZE):
-            out[i : i + BATCH_SIZE] = similarity_function(
-                x[i : i + BATCH_SIZE], y[i : i + BATCH_SIZE]
-            )
+            out[i : i + BATCH_SIZE] = similarity_function(x[i : i + BATCH_SIZE], y[i : i + BATCH_SIZE])
         return out

     return batch_similarity_function
@@ -253,9 +232,7 @@ def __init__(self, config: BiEncoderConfig) -> None:
         elif self.config.similarity_function == "dot":
             self.similarity_function = self.dot_similarity
         else:
-            raise ValueError(
-                f"Unknown similarity function {self.config.similarity_function}"
-            )
+            raise ValueError(f"Unknown similarity function {self.config.similarity_function}")
         self.query_aggregation_function = self.config.query_aggregation_function

     def compute_similarity(
@@ -270,9 +247,7 @@ def compute_similarity(
         # doc_tensor = doc_tensor.cuda().half()

         # TODO compute similarity only for non-masked values
-        similarity = self.similarity_function(
-            query_embeddings.embeddings, doc_embeddings.embeddings
-        )
+        similarity = self.similarity_function(query_embeddings.embeddings, doc_embeddings.embeddings)
         return similarity

     @staticmethod
@@ -300,16 +275,11 @@ def parse_num_docs(
         if isinstance(num_docs, int):
             num_docs = [num_docs] * batch_size
         if isinstance(num_docs, list):
-            if (
-                sum(num_docs) != doc_embeddings.embeddings.shape[0]
-                or len(num_docs) != batch_size
-            ):
+            if sum(num_docs) != doc_embeddings.embeddings.shape[0] or len(num_docs) != batch_size:
                 raise ValueError("Num docs does not match doc embeddings")
         if num_docs is None:
             if doc_embeddings.embeddings.shape[0] % batch_size != 0:
-                raise ValueError(
-                    "Docs are not evenly distributed in batch, but no num_docs provided"
-                )
+                raise ValueError("Docs are not evenly distributed in batch, but no num_docs provided")
             num_docs = [doc_embeddings.embeddings.shape[0] // batch_size] * batch_size
         return torch.tensor(num_docs, device=query_embeddings.embeddings.device)

@@ -328,17 +298,13 @@ def expand_doc_embeddings(
         embeddings: BiEncoderEmbedding,
         num_docs: torch.Tensor,
     ) -> BiEncoderEmbedding:
-        return BiEncoderEmbedding(
-            embeddings.embeddings.unsqueeze(1), embeddings.scoring_mask.unsqueeze(1)
-        )
+        return BiEncoderEmbedding(embeddings.embeddings.unsqueeze(1), embeddings.scoring_mask.unsqueeze(1))

     def aggregate(
         self,
         scores: torch.Tensor,
         mask: torch.Tensor | None,
-        query_aggregation_function: (
-            Literal["max", "sum", "mean", "harmonic_mean"] | None
-        ),
+        query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"] | None,
         dim: int,
     ) -> torch.Tensor:
         if query_aggregation_function is None:
@@ -358,9 +324,7 @@ def aggregate(
         else:
             num_non_masked = mask.sum(dim, keepdim=True)
         if query_aggregation_function == "mean":
-            return torch.where(
-                num_non_masked == 0, 0, scores.sum(dim, keepdim=True) / num_non_masked
-            )
+            return torch.where(num_non_masked == 0, 0, scores.sum(dim, keepdim=True) / num_non_masked)
         if query_aggregation_function == "harmonic_mean":
             return torch.where(
                 num_non_masked == 0,
@@ -380,7 +344,5 @@ def score(
         doc_embeddings = self.expand_doc_embeddings(doc_embeddings, num_docs_t)
         similarity = self.compute_similarity(query_embeddings, doc_embeddings)
         scores = self.aggregate(similarity, doc_embeddings.scoring_mask, "max", -1)
-        scores = self.aggregate(
-            scores, query_embeddings.scoring_mask, self.query_aggregation_function, -2
-        )
+        scores = self.aggregate(scores, query_embeddings.scoring_mask, self.query_aggregation_function, -2)
         return scores[..., 0, 0]