[Bart] Move CLS rep extraction from EOS tokens to head classes
calpt committed Dec 21, 2023
1 parent c921726 commit 1c9ecca
Showing 2 changed files with 34 additions and 29 deletions.
48 changes: 31 additions & 17 deletions src/adapters/heads/base.py
@@ -18,7 +18,13 @@
 )
 from transformers.utils import ModelOutput
 
-from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition
+from ..composition import (
+    AdapterCompositionBlock,
+    BatchSplit,
+    Parallel,
+    adjust_tensors_for_parallel,
+    parse_heads_from_composition,
+)
 from ..context import AdapterSetup, ForwardContext
 from ..loading import PredictionHeadLoader
 from ..methods.modeling import Activation_Function_Class
@@ -105,6 +111,21 @@ def get_output_embeddings(self):
     def get_label_names(self):
         return ["labels"]
 
+    def _get_cls_output(self, outputs, **kwargs):
+        if self.config["use_pooler"]:
+            cls_output = kwargs.pop("pooled_output")
+        elif kwargs.get("get_cls_from_eos_tokens", False):
+            x = outputs[0]  # last hidden state
+            eos_mask = kwargs.get("eos_mask")
+            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
+            if len(torch.unique(eos_mask.sum(1))) > 1:
+                raise ValueError("All examples must have the same number of <eos> tokens.")
+            cls_output = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
+        else:
+            cls_output = outputs[0][:, 0]
+
+        return cls_output
+
 
 class ClassificationHead(PredictionHead):
     def __init__(
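
For intuition, here is a minimal standalone sketch (toy shapes and positions, not part of the commit) of how the boolean-mask gather in `_get_cls_output` picks out the hidden state of the last <eos> token in each sequence:

import torch

batch, seq_len, hidden = 2, 5, 4
x = torch.randn(batch, seq_len, hidden)  # stand-in for the last hidden state
eos_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
eos_mask[:, 4] = True  # suppose <eos> sits at position 4 in every sequence
# x[eos_mask, :] flattens all <eos> positions to (num_eos, hidden);
# .view() regroups them per sequence, and [:, -1, :] keeps the last one
cls_output = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
assert cls_output.shape == (batch, hidden)

The torch.unique(eos_mask.sum(1)) guard exists because the .view() call only regroups cleanly when every example contributes the same number of <eos> tokens.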
@@ -134,10 +155,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -205,10 +223,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -271,10 +286,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=None, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         logits = logits.view(-1, self.config["num_choices"])
         loss = None
@@ -476,10 +488,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -800,6 +809,9 @@ def forward_head(
             cls_output (torch.Tensor, optional): The classification output of the model.
             attention_mask (torch.Tensor, optional): The attention mask of the model.
             return_dict (bool): Whether or not to return a ``ModelOutput`` instead of a plain tuple.
+            get_cls_from_eos_tokens (bool):
+                If set to True, retrieve classifier token representations from the last <eos> token in the sequence.
+                Setting to True requires `eos_mask` to be passed as well.
             **kwargs: Additional keyword arguments passed to the forward pass of the head.
         """
         used_head_modules = self._get_used_heads(head_name)
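
As a usage sketch, a model with flexible heads can now delegate EOS-based pooling to the head classes; the call below mirrors the BartAdapterModel change later in this commit (all variables assumed to be in scope of a model forward):

head_outputs = self.forward_head(
    outputs,
    head_name=head,
    attention_mask=attention_mask,
    return_dict=return_dict,
    get_cls_from_eos_tokens=True,
    eos_mask=input_ids.eq(self.config.eos_token_id) if input_ids is not None else None,
    **kwargs,
)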
@@ -846,10 +858,12 @@ def _get_head_input(outputs, cls_out, batch):
                 )
             head_outputs = []
             labels = kwargs.pop("labels", None)
+            eos_mask = kwargs.pop("eos_mask", None)
             for i, head in enumerate(self.active_head):
                 head_module = self.heads[head]
                 batch_idx = range(sum(self.active_head.batch_sizes[:i]), sum(self.active_head.batch_sizes[: i + 1]))
                 kwargs["labels"] = labels[batch_idx] if labels is not None else None
+                kwargs["eos_mask"] = eos_mask[batch_idx] if eos_mask is not None else None
                 head_inputs, head_cls_input = _get_head_input(all_outputs, cls_output, batch_idx)
                 # head_attention = attention_mask[batch_idx] if attention_mask is not None else None
                 head_output = head_module(head_inputs, head_cls_input, attention_mask, return_dict, **kwargs)
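
Under a BatchSplit composition each head only sees its own slice of the batch, so the eos mask has to be sliced with the same batch_idx as the labels. A standalone toy illustration (batch sizes chosen arbitrarily):

import torch

batch_sizes = [2, 1]  # hypothetical BatchSplit over a batch of 3
eos_mask = torch.tensor([[False, True], [False, True], [True, False]])
for i in range(len(batch_sizes)):
    batch_idx = range(sum(batch_sizes[:i]), sum(batch_sizes[: i + 1]))
    per_head_mask = eos_mask[batch_idx]
    print(i, per_head_mask.shape)  # head 0: (2, 2), head 1: (1, 2)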
15 changes: 3 additions & 12 deletions src/adapters/models/bart/adapter_model.py
@@ -10,7 +10,6 @@
 )
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
-from ...composition import adjust_tensors_for_parallel
 from ...heads import (
     ClassificationHead,
     ModelWithFlexibleHeadsAdaptersMixin,
@@ -102,23 +101,15 @@ def forward(
         )
         # required e.g. for prompt tuning in all models
         kwargs["context"] = context
-        # sequence classification based on last token in sequence
-        x = outputs[0]  # last hidden state
-        if input_ids is not None and x.shape[1] == input_ids.shape[1]:
-            eos_mask = input_ids.eq(self.config.eos_token_id)
-            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
-            if len(torch.unique(eos_mask.sum(1))) > 1:
-                raise ValueError("All examples must have the same number of <eos> tokens.")
-            cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
-        else:
-            cls_representation = x
-
         head_outputs = self.forward_head(
             outputs,
             head_name=head,
-            cls_output=cls_representation,
             attention_mask=attention_mask,
             return_dict=return_dict,
+            get_cls_from_eos_tokens=True,
+            # `get_cls_from_eos_tokens` requires passing eos mask
+            eos_mask=input_ids.eq(self.config.eos_token_id) if input_ids is not None else None,
             **kwargs,
         )

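The eos mask itself is just a boolean comparison of the input ids against the model's EOS token id; a minimal sketch with toy ids (assuming eos_token_id == 2):

import torch

input_ids = torch.tensor([[5, 8, 9, 2], [7, 3, 6, 2]])
eos_mask = input_ids.eq(2)  # True exactly at the <eos> positions
# each example has one <eos>, so the head's uniqueness check passes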