Commit b3531ce
black
fschlatt committed Jul 17, 2024
1 parent 0da2569 commit b3531ce
Showing 2 changed files with 22 additions and 68 deletions.
16 changes: 4 additions & 12 deletions lightning_ir/bi_encoder/config.py
@@ -44,9 +44,7 @@ def __init__(
attend_to_query_expanded_tokens: bool = False,
query_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
- query_aggregation_function: Literal[
- "sum", "mean", "max", "harmonic_mean"
- ] = "sum",
+ query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
doc_expansion: bool = False,
attend_to_doc_expanded_tokens: bool = False,
doc_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
@@ -116,9 +114,7 @@ def to_dict(self) -> Dict[str, Any]:
output.pop("doc_mask_scoring_tokens")
return output

- def save_pretrained(
- self, save_directory: str | PathLike, push_to_hub: bool = False, **kwargs
- ):
+ def save_pretrained(self, save_directory: str | PathLike, push_to_hub: bool = False, **kwargs):
with open(os.path.join(save_directory, "mask_scoring_tokens.json"), "w") as f:
json.dump(
{
@@ -133,13 +129,9 @@ def save_pretrained(
def get_config_dict(
cls, pretrained_model_name_or_path: str | PathLike, **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
- config_dict, kwargs = super().get_config_dict(
- pretrained_model_name_or_path, **kwargs
- )
+ config_dict, kwargs = super().get_config_dict(pretrained_model_name_or_path, **kwargs)
mask_scoring_tokens = None
- mask_scoring_tokens_path = os.path.join(
- pretrained_model_name_or_path, "mask_scoring_tokens.json"
- )
+ mask_scoring_tokens_path = os.path.join(pretrained_model_name_or_path, "mask_scoring_tokens.json")
if os.path.exists(mask_scoring_tokens_path):
with open(mask_scoring_tokens_path) as f:
mask_scoring_tokens = json.load(f)
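The two overrides above keep the mask scoring tokens next to the regular config on disk: save_pretrained writes mask_scoring_tokens.json into the save directory, and get_config_dict reads it back when the checkpoint is loaded. A minimal sketch of that round trip, assuming BiEncoderConfig is importable from the module in this diff and accepts the keyword arguments shown in its signature (the directory name is illustrative):

```python
import os

from lightning_ir.bi_encoder.config import BiEncoderConfig  # import path assumed from the file path above

# Mask punctuation tokens when scoring query embeddings.
config = BiEncoderConfig(query_mask_scoring_tokens="punctuation")

os.makedirs("my-bi-encoder", exist_ok=True)
config.save_pretrained("my-bi-encoder")  # also dumps mask_scoring_tokens.json

# On load, get_config_dict merges the JSON file back into the config dict.
config_dict, _ = BiEncoderConfig.get_config_dict("my-bi-encoder")
print(config_dict.get("query_mask_scoring_tokens"))
```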
74 changes: 18 additions & 56 deletions lightning_ir/bi_encoder/model.py
@@ -47,22 +47,15 @@ def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
self.query_mask_scoring_input_ids: torch.Tensor | None = None
self.doc_mask_scoring_input_ids: torch.Tensor | None = None
for sequence in ("query", "doc"):
- mask_scoring_tokens = getattr(
- self.config, f"{sequence}_mask_scoring_tokens"
- )
+ mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens")
if mask_scoring_tokens is None:
continue
if mask_scoring_tokens == "punctuation":
mask_scoring_tokens = list(punctuation)
try:
- tokenizer = self.config.__class__.tokenizer_class.from_pretrained(
- self.config.name_or_path
- )
+ tokenizer = self.config.__class__.tokenizer_class.from_pretrained(self.config.name_or_path)
except OSError:
- raise ValueError(
- "Can't use token scoring masking if the checkpoint does not "
- "have a tokenizer."
- )
+ raise ValueError("Can't use token scoring masking if the checkpoint does not " "have a tokenizer.")
setattr(
self,
f"{sequence}_mask_scoring_input_ids",
@@ -140,9 +133,7 @@ def _encode(
pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None,
mask_scoring_input_ids: torch.Tensor | None = None,
) -> BiEncoderEmbedding:
- embeddings = self.backbone_forward(
- input_ids, attention_mask, token_type_ids
- ).last_hidden_state
+ embeddings = self.backbone_forward(input_ids, attention_mask, token_type_ids).last_hidden_state
if self.projection is not None:
embeddings = self.projection(embeddings)
embeddings = self.sparsification(embeddings, self.config.sparsification)
@@ -158,9 +149,7 @@ def _encode(
)
return BiEncoderEmbedding(embeddings, scoring_mask)

- def query_scoring_mask(
- self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None
- ) -> torch.Tensor:
+ def query_scoring_mask(self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None) -> torch.Tensor:
return self._scoring_mask(
input_ids,
attention_mask,
@@ -169,9 +158,7 @@ def query_scoring_mask(
mask_scoring_input_ids=self.config.query_mask_scoring_input_ids,
)

- def doc_scoring_mask(
- self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None
- ) -> torch.Tensor:
+ def doc_scoring_mask(self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None) -> torch.Tensor:
return self._scoring_mask(
input_ids,
attention_mask,
@@ -203,9 +190,7 @@ def _scoring_mask(
scoring_mask = torch.ones(shape, dtype=torch.bool, device=device)
scoring_mask = scoring_mask.bool()
if mask_scoring_input_ids is not None and input_ids is not None:
- ignore_mask = (
- input_ids[..., None].eq(mask_scoring_input_ids.to(device)).any(-1)
- )
+ ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(device)).any(-1)
scoring_mask = scoring_mask & ~ignore_mask
return scoring_mask
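As a toy illustration of the masking step condensed above, positions whose input id appears in mask_scoring_input_ids are excluded from scoring; the ids below are made up:

```python
import torch

input_ids = torch.tensor([[101, 2009, 1012, 102]])  # e.g. [CLS] "it" "." [SEP]
mask_scoring_input_ids = torch.tensor([1012])        # ids to ignore, e.g. punctuation

scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids).any(-1)
scoring_mask = scoring_mask & ~ignore_mask
print(scoring_mask)  # tensor([[ True,  True, False,  True]])
```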

@@ -215,9 +200,7 @@ def score(
doc_embeddings: BiEncoderEmbedding,
num_docs: Sequence[int] | int | None = None,
) -> torch.Tensor:
- scores = self.scoring_function.score(
- query_embeddings, doc_embeddings, num_docs=num_docs
- )
+ scores = self.scoring_function.score(query_embeddings, doc_embeddings, num_docs=num_docs)
return scores


@@ -230,13 +213,9 @@ def batch(
def batch_similarity_function(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
if x.shape[0] <= BATCH_SIZE:
return similarity_function(x, y)
- out = torch.zeros(
- x.shape[0], x.shape[1], y.shape[2], device=x.device, dtype=x.dtype
- )
+ out = torch.zeros(x.shape[0], x.shape[1], y.shape[2], device=x.device, dtype=x.dtype)
for i in range(0, x.shape[0], BATCH_SIZE):
- out[i : i + BATCH_SIZE] = similarity_function(
- x[i : i + BATCH_SIZE], y[i : i + BATCH_SIZE]
- )
+ out[i : i + BATCH_SIZE] = similarity_function(x[i : i + BATCH_SIZE], y[i : i + BATCH_SIZE])
return out

return batch_similarity_function
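The wrapper condensed above chunks the similarity computation along the leading batch dimension to bound peak memory. A self-contained sketch of the same pattern, with an illustrative BATCH_SIZE and a plain batched matrix product standing in for the similarity function (both are assumptions, not values from the library):

```python
import torch

BATCH_SIZE = 1024  # illustrative only

def chunked_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x: (n, q, d), y: (n, d, t) -> (n, q, t), computed BATCH_SIZE rows at a time
    if x.shape[0] <= BATCH_SIZE:
        return x @ y
    out = torch.zeros(x.shape[0], x.shape[1], y.shape[2], device=x.device, dtype=x.dtype)
    for i in range(0, x.shape[0], BATCH_SIZE):
        out[i : i + BATCH_SIZE] = x[i : i + BATCH_SIZE] @ y[i : i + BATCH_SIZE]
    return out

scores = chunked_similarity(torch.randn(4096, 8, 16), torch.randn(4096, 16, 32))
assert scores.shape == (4096, 8, 32)
```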
@@ -253,9 +232,7 @@ def __init__(self, config: BiEncoderConfig) -> None:
elif self.config.similarity_function == "dot":
self.similarity_function = self.dot_similarity
else:
- raise ValueError(
- f"Unknown similarity function {self.config.similarity_function}"
- )
+ raise ValueError(f"Unknown similarity function {self.config.similarity_function}")
self.query_aggregation_function = self.config.query_aggregation_function

def compute_similarity(
@@ -270,9 +247,7 @@ def compute_similarity(
# doc_tensor = doc_tensor.cuda().half()

# TODO compute similarity only for non-masked values
- similarity = self.similarity_function(
- query_embeddings.embeddings, doc_embeddings.embeddings
- )
+ similarity = self.similarity_function(query_embeddings.embeddings, doc_embeddings.embeddings)
return similarity

@staticmethod
@@ -300,16 +275,11 @@ def parse_num_docs(
if isinstance(num_docs, int):
num_docs = [num_docs] * batch_size
if isinstance(num_docs, list):
- if (
- sum(num_docs) != doc_embeddings.embeddings.shape[0]
- or len(num_docs) != batch_size
- ):
+ if sum(num_docs) != doc_embeddings.embeddings.shape[0] or len(num_docs) != batch_size:
raise ValueError("Num docs does not match doc embeddings")
if num_docs is None:
if doc_embeddings.embeddings.shape[0] % batch_size != 0:
- raise ValueError(
- "Docs are not evenly distributed in batch, but no num_docs provided"
- )
+ raise ValueError("Docs are not evenly distributed in batch, but no num_docs provided")
num_docs = [doc_embeddings.embeddings.shape[0] // batch_size] * batch_size
return torch.tensor(num_docs, device=query_embeddings.embeddings.device)
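A small worked example of the num_docs contract enforced above (the numbers are illustrative): per-query document counts must sum to the total number of document embeddings, and without num_docs the documents must split evenly across the query batch.

```python
batch_size, total_docs = 2, 5

num_docs = [2, 3]
assert sum(num_docs) == total_docs and len(num_docs) == batch_size  # accepted

# With num_docs=None, 5 documents cannot be split evenly over 2 queries,
# so parse_num_docs would raise instead of guessing a distribution.
assert total_docs % batch_size != 0
```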

@@ -328,17 +298,13 @@ def expand_doc_embeddings(
embeddings: BiEncoderEmbedding,
num_docs: torch.Tensor,
) -> BiEncoderEmbedding:
- return BiEncoderEmbedding(
- embeddings.embeddings.unsqueeze(1), embeddings.scoring_mask.unsqueeze(1)
- )
+ return BiEncoderEmbedding(embeddings.embeddings.unsqueeze(1), embeddings.scoring_mask.unsqueeze(1))

def aggregate(
self,
scores: torch.Tensor,
mask: torch.Tensor | None,
- query_aggregation_function: (
- Literal["max", "sum", "mean", "harmonic_mean"] | None
- ),
+ query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"] | None,
dim: int,
) -> torch.Tensor:
if query_aggregation_function is None:
@@ -358,9 +324,7 @@ def aggregate(
else:
num_non_masked = mask.sum(dim, keepdim=True)
if query_aggregation_function == "mean":
- return torch.where(
- num_non_masked == 0, 0, scores.sum(dim, keepdim=True) / num_non_masked
- )
+ return torch.where(num_non_masked == 0, 0, scores.sum(dim, keepdim=True) / num_non_masked)
if query_aggregation_function == "harmonic_mean":
return torch.where(
num_non_masked == 0,
@@ -380,7 +344,5 @@ def score(
doc_embeddings = self.expand_doc_embeddings(doc_embeddings, num_docs_t)
similarity = self.compute_similarity(query_embeddings, doc_embeddings)
scores = self.aggregate(similarity, doc_embeddings.scoring_mask, "max", -1)
- scores = self.aggregate(
- scores, query_embeddings.scoring_mask, self.query_aggregation_function, -2
- )
+ scores = self.aggregate(scores, query_embeddings.scoring_mask, self.query_aggregation_function, -2)
return scores[..., 0, 0]
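For context, the score() method condensed at the end of this diff aggregates token-level similarities twice: a max over document tokens followed by the configured aggregation over query tokens. The following toy sketch mirrors that order with made-up shapes and a plain dot-product similarity; it is not the library's API:

```python
import torch

query_emb = torch.randn(2, 1, 4, 16)  # (batch, 1, query_tokens, dim), shapes assumed
doc_emb = torch.randn(2, 1, 6, 16)    # (batch, 1, doc_tokens, dim)

# Token-level similarity between every query token and every document token.
similarity = torch.einsum("bnqd,bntd->bnqt", query_emb, doc_emb)  # (2, 1, 4, 6)

# "max" aggregation over document tokens, then "sum" over query tokens.
per_query_token = similarity.max(-1, keepdim=True).values  # (2, 1, 4, 1)
scores = per_query_token.sum(-2, keepdim=True)              # (2, 1, 1, 1)
print(scores[..., 0, 0])                                     # one score per query-document pair
```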
