Use seq. classification head in T5 tests.
Move used heads retrieval to new method.
calpt committed Sep 16, 2023
1 parent 0a3b96b commit 5065d27
Showing 7 changed files with 61 additions and 86 deletions.
43 changes: 26 additions & 17 deletions src/adapters/heads/base.py
@@ -730,6 +730,27 @@ def delete_head(self, head_name: str):
if self.active_head == head_name:
self.active_head = None

def _get_used_heads(self, head_name: str = None):
if head_name:
used_heads = [head_name]
# together with context, check if we have heads at all to allow for models without heads
elif len(self.heads) > 0 and AdapterSetup.get_context_head_setup():
used_heads = AdapterSetup.get_context_head_setup()
if isinstance(used_heads, str):
used_heads = [used_heads]
elif self._active_heads:
used_heads = self._active_heads
else:
return []

head_modules = []
for head in used_heads:
if head not in self.heads:
raise ValueError("Unknown head_name '{}'".format(head))
head_modules.append(self.heads[head])

return head_modules

def forward_head(
self, all_outputs, head_name=None, cls_output=None, attention_mask=None, return_dict=False, **kwargs
):
@@ -750,16 +771,8 @@ def forward_head(
return_dict (bool): Whether or not to return a ``ModelOutput`` instead of a plain tuple.
**kwargs: Additional keyword arguments passed to the forward pass of the head.
"""
if head_name:
used_heads = [head_name]
# together with context, check if we have heads at all to allow for models without heads
elif len(self.heads) > 0 and AdapterSetup.get_context_head_setup():
used_heads = AdapterSetup.get_context_head_setup()
if isinstance(used_heads, str):
used_heads = [used_heads]
elif self._active_heads:
used_heads = self._active_heads
else:
used_head_modules = self._get_used_heads(head_name)
if len(used_head_modules) == 0:
logger.debug("No prediction head is used.")
return all_outputs

@@ -787,9 +800,6 @@ def _get_head_input(outputs, cls_out, batch):
if inv_adapter:
kwargs["invertible_adapter"] = inv_adapter

for head in used_heads:
if head not in self.heads:
raise ValueError("Unknown head_name '{}'".format(head))
if isinstance(self.active_head, BatchSplit):
if sum(self.active_head.batch_sizes) != all_outputs[0].size()[0]:
raise ValueError(
@@ -830,14 +840,13 @@ def _get_head_input(outputs, cls_out, batch):
else None
)
return_output = MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
elif len(used_heads) > 1:
elif len(used_head_modules) > 1:
head_outputs = []
for head in used_heads:
head_module = self.heads[head]
for head_module in used_head_modules:
head_outputs.append(head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs))
return_output = MultiHeadOutput(head_outputs=head_outputs)
else:
head_module = self.heads[used_heads[0]]
head_module = used_head_modules[0]
return_output = head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs)

if isinstance(return_output, ModelOutput):
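
For reference, a minimal sketch (not part of this commit) of the resolution order the new _get_used_heads helper implements: an explicit head_name wins, then a head set via the AdapterSetup context, then the currently active head(s). The checkpoint and head names below are placeholders.

    from adapters import AutoAdapterModel

    model = AutoAdapterModel.from_pretrained("bert-base-uncased")
    model.add_classification_head("cls_a", num_labels=2)
    model.add_classification_head("cls_b", num_labels=2)
    model.active_head = "cls_a"

    # An explicit name takes precedence over the active head.
    heads = model._get_used_heads("cls_b")   # -> [module of "cls_b"]
    # Without a name, the AdapterSetup context head is checked first, then the active head(s).
    heads = model._get_used_heads()          # -> [module of "cls_a"] here
    # An unknown name raises ValueError("Unknown head_name ..."); no usable head returns [].
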
32 changes: 17 additions & 15 deletions src/adapters/models/t5/adapter_model.py
@@ -73,9 +73,14 @@ def forward(
**kwargs
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
if decoder_input_ids is None and decoder_inputs_embeds is None:
# Check if we're using a LM head
if labels is not None and any([isinstance(head, Seq2SeqLMHead) for head in self._get_used_heads(head)]):
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
else:
# decoder_input_ids from input_ids if no decoder_input_ids are provided
decoder_input_ids = self._shift_right(input_ids)

model_output = self.transformer(
input_ids=input_ids,
@@ -121,18 +126,15 @@ def forward(
else:
cls_representation = sequence_output

if head or self.active_head:
kwargs["labels"] = labels
head_outputs = self.forward_head(
model_output,
head_name=head,
cls_output=cls_representation,
return_dict=return_dict,
**kwargs,
)
return head_outputs
else:
return model_output
kwargs["labels"] = labels
head_outputs = self.forward_head(
model_output,
head_name=head,
cls_output=cls_representation,
return_dict=return_dict,
**kwargs,
)
return head_outputs

# Copied from T5ForConditionalGeneration
def prepare_inputs_for_generation(
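
A short sketch (not from the diff) of the behavioral change for non-LM heads: when neither decoder_input_ids nor decoder_inputs_embeds is given and no Seq2SeqLMHead is in use, decoder inputs are now derived by shifting input_ids instead of the labels. The checkpoint and head name are placeholders.

    import torch
    from adapters import AutoAdapterModel

    model = AutoAdapterModel.from_pretrained("t5-small")
    model.add_classification_head("topic", num_labels=2)  # a non-LM head

    input_ids = torch.tensor([[37, 1782, 19, 207, 1]])  # arbitrary token ids ending in </s>
    labels = torch.tensor([[1]])  # class indices; these are no longer passed through _shift_right

    # With only the classification head in use, the decoder input falls back to
    # shifting input_ids; T5's _shift_right prepends decoder_start_token_id (0).
    print(model._shift_right(input_ids))  # tensor([[   0,   37, 1782,   19,  207]])
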
18 changes: 11 additions & 7 deletions src/adapters/models/t5/modeling_t5.py
@@ -292,7 +292,8 @@ def forward(
raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
if self.embed_tokens is None:
raise ValueError("You have to initialize the model with valid token embeddings")
inputs_embeds = self.embed_tokens(input_ids)

batch_size, seq_length = input_shape
@@ -301,7 +302,8 @@
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

if use_cache is True:
assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
if not self.is_decoder:
raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
@@ -330,6 +332,13 @@
else:
encoder_extended_attention_mask = None

if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
@@ -369,11 +378,6 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

def create_custom_forward(module):
def custom_forward(*inputs):
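
The use_cache guard for gradient checkpointing now runs once before the block loop (via logger.warning_once) rather than being re-evaluated inside every iteration. A stdlib-only sketch of the pattern, using plain logging.warning since warning_once is a transformers-internal helper:

    import logging

    logger = logging.getLogger(__name__)

    def run_blocks(blocks, hidden_states, *, use_cache, gradient_checkpointing, training):
        # Resolve the conflict once, up front, instead of per layer.
        if gradient_checkpointing and training and use_cache:
            logger.warning(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        present_key_values = () if use_cache else None
        for block in blocks:
            hidden_states, present = block(hidden_states)
            if use_cache:
                present_key_values = present_key_values + (present,)
        return hidden_states, present_key_values
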
4 changes: 2 additions & 2 deletions tests_adapters/composition/test_parallel.py
@@ -234,7 +234,7 @@ def run_parallel_training_equivalent_to_single(self, adapter_config):
dataset = []
for i in range(3):
input_data = self.get_input_samples(config=model.config)
if isinstance(model, T5AdapterModel) or isinstance(model, BertGenerationAdapterModel):
if isinstance(model, BertGenerationAdapterModel):
input_data["labels"] = torch.randint(0, 2, (3, 64))
else:
input_data["labels"] = torch.randint(0, 2, (3, 1))
@@ -291,7 +291,7 @@ def test_parallel_training_single_forward_pass(self):
self.assertTrue(torch.equal(v, state_dict[k.replace(b1, b2)]))

input_data = self.get_input_samples(config=model.config)
if isinstance(model, T5AdapterModel) or isinstance(model, BertGenerationAdapterModel):
if isinstance(model, BertGenerationAdapterModel):
input_data["labels"] = torch.randint(0, 2, (3, 64), device=torch_device)
else:
input_data["labels"] = torch.randint(0, 2, (3, 1), device=torch_device)
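
With a sequence-classification head, the T5 test inputs no longer need seq2seq-style label sequences and take the same default as the other models. A sketch of the resulting shapes (the constants mirror the test above; everything else is assumed):

    import torch

    batch_size, seq_len, num_labels = 3, 64, 2
    input_data = {
        "input_ids": torch.randint(5, 100, (batch_size, seq_len)),
        "labels": torch.randint(0, num_labels, (batch_size, 1)),  # class ids of shape (3, 1)
    }
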
4 changes: 2 additions & 2 deletions tests_adapters/methods/test_adapter_common.py
@@ -19,7 +19,7 @@
SeqBnInvConfig,
)
from adapters.heads.language_modeling import CausalLMHead
from transformers import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
from transformers import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, CLIPConfig
from transformers.testing_utils import require_torch, torch_device

from .base import AdapterMethodBaseTestMixin, create_twin_models
@@ -148,7 +148,7 @@ def test_get_adapter(self):
n_layers = len(list(model.iter_layers()))
if model.config.is_encoder_decoder:
n_prefix_layers = 3
elif model.config.is_composition:
elif model.config.is_composition or isinstance(model.config, CLIPConfig):
n_prefix_layers = 2
else:
n_prefix_layers = 1
3 changes: 2 additions & 1 deletion tests_adapters/methods/test_prefix_tuning.py
@@ -1,6 +1,7 @@
import torch

from adapters import ADAPTER_MODEL_MAPPING, AutoAdapterModel, PrefixTuningConfig
from transformers import CLIPConfig
from transformers.testing_utils import require_torch, torch_device

from .base import AdapterMethodBaseTestMixin
@@ -24,7 +25,7 @@ def test_get_prefix_tuning(self):
model = self.get_model()
if model.config.is_encoder_decoder:
n_prefix_layers = 3
elif model.config.is_composition:
elif model.config.is_composition or isinstance(model.config, CLIPConfig):
n_prefix_layers = 2
else:
n_prefix_layers = 1
43 changes: 1 addition & 42 deletions tests_adapters/test_t5.py
@@ -1,8 +1,6 @@
import unittest

from datasets import load_dataset

from transformers import AutoTokenizer, T5Config
from transformers import T5Config
from transformers.testing_utils import require_torch

from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -38,45 +36,6 @@ class T5AdapterTestBase(AdapterTestBase):
)
tokenizer_name = "t5-base"

def dataset(self, tokenizer=None):
# setup tokenizer
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

def preprocess_function(examples):
inputs = examples["document"]
targets = examples["summary"]
inputs = ["Summarize: " + inp for inp in inputs]
model_inputs = tokenizer(inputs, padding=True, truncation=True)

# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, padding=True, truncation=True)

# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]

model_inputs["labels"] = labels["input_ids"]
return model_inputs

data_args = {
"task_name": "xsum",
"path": "./hf_transformers/tests/fixtures/tests_samples/xsum/sample.json",
}
dataset = load_dataset("json", data_files=data_args["path"])
train_dataset = dataset["train"]
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
desc="Running tokenizer on train dataset",
)
return train_dataset


@require_torch
class T5AdapterTest(
