Don't throw error if ForwardContext is not available (adapter-hub#730)
Fixes in this PR:
- Fix 'NoneType' error when the ForwardContext is not available (see the sketch after this list).
- Fix Llama gradient checkpointing training.
- Add a quantized training test.

Remaining limitation:
- ForwardContext does not work with gradient checkpointing, i.e. methods
  such as ReFT or prefix tuning currently do not work with gradient
  checkpointing.
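A minimal sketch of that guard (the helper name is hypothetical and the import path is assumed to be the package's public adapters.context module; the flag name follows bottleneck.py):

from adapters.context import ForwardContext

def output_gating_flag() -> bool:
    # ForwardContext.get_context() returns None when the forward pass was not
    # wrapped in a ForwardContext (e.g. inside a gradient-checkpointed re-forward),
    # so fall back to a safe default instead of raising an AttributeError on None.
    context = ForwardContext.get_context()
    return context.output_adapter_gating_scores if context is not None else False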
dainis-boumber committed Aug 30, 2024
1 parent 0071915 commit 29916f8
Showing 4 changed files with 164 additions and 50 deletions.
35 changes: 8 additions & 27 deletions src/adapters/methods/bottleneck.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Mapping, NamedTuple, Optional, Union
from typing import List, Mapping, NamedTuple, Optional, Union

import torch
from torch import nn
@@ -94,28 +94,6 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:

return False

def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool:
# add new adapter
if self.add_adapter(adapter_name, self.layer_idx):
# average weights
avg_state_dict = {}
for name, weight in input_adapters.items():
if name in self.adapters:
module = self.adapters[name]
for k, v in module.state_dict().items():
if k in avg_state_dict:
avg_state_dict[k] += weight * v
else:
avg_state_dict[k] = weight * v
else:
self.delete_adapter(adapter_name) # clean up before raising error
raise ValueError("Adapter {} not found.".format(name))
# load averaged weights
self.adapters[adapter_name].load_state_dict(avg_state_dict)
return True

return False

def add_fusion_layer(self, adapter_names: Union[List, str]):
"""See BertModel.add_fusion_layer"""
adapter_names = adapter_names if isinstance(adapter_names, list) else adapter_names.split(",")
@@ -226,13 +204,15 @@ def mean(self, states: List[BottleneckState], weights: torch.Tensor) -> BottleneckState:
def compose_single(self, adapter_setup: str, state: BottleneckState, lvl: int = 0) -> BottleneckState:
adapter_layer = self.adapters[adapter_setup]
context = ForwardContext.get_context()
output_gating = context.output_adapter_gating_scores if context is not None else False
layer_output = adapter_layer(
state.hidden_states,
residual_input=state.adapter_residual,
output_gating=context.output_adapter_gating_scores,
output_gating=output_gating,
)
hidden_states, up = layer_output[0], layer_output[2]
self._store_gating_score(adapter_setup, layer_output[-1])
if output_gating:
self._store_gating_score(adapter_setup, layer_output[-1])

return state._replace(hidden_states=hidden_states, bottleneck_up=up, last=adapter_setup)

@@ -268,14 +248,15 @@ def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0
up_list = torch.stack([state.bottleneck_up for state in children_states])
up_list = up_list.permute(1, 2, 0, 3)

output_fusion_attns = context.output_adapter_fusion_attentions if context is not None else False
fusion_output = self.adapter_fusion_layer[adapter_setup.name](
query,
up_list,
up_list,
state.adapter_residual,
output_attentions=context.output_adapter_fusion_attentions,
output_attentions=output_fusion_attns,
)
if context.output_adapter_fusion_attentions:
if output_fusion_attns:
hidden_states = fusion_output[0]
self._store_fusion_attentions(adapter_setup.name, fusion_output[-1])
else:
84 changes: 64 additions & 20 deletions src/adapters/models/llama/modeling_llama.py
@@ -29,6 +29,7 @@

from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
@@ -57,6 +58,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -94,8 +96,16 @@ def forward(
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
# >>> END AH Changes <<<

past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
@@ -133,7 +143,7 @@ def forward(

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.reshape(bsz, q_len, -1)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
@@ -158,12 +168,13 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` make"
" sure to use `sdpa` in the mean time, and open an issue at"
" https://github.com/huggingface/transformers"
)

output_attentions = False
@@ -188,7 +199,16 @@ def forward(
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
# >>> END AH Changes <<<

cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
@@ -220,7 +240,7 @@ def forward(
# in fp32. (LlamaRMSNorm handles it correctly)

input_dtype = query_states.dtype
if input_dtype == torch.float32 or key_states.dtype == torch.float32:
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
@@ -230,20 +250,28 @@ def forward(
target_dtype = self.q_proj.weight.dtype

logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
"The input hidden states seems to be silently casted in float32, this might be related to the fact"
" you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)

if not output_attentions:
@@ -264,12 +292,16 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does"
" not support `output_attentions=True`. Falling back to the manual attention implementation, but"
" specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This"
' warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
@@ -279,6 +311,7 @@
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)

bsz, q_len, _ = hidden_states.size()
@@ -298,7 +331,16 @@ def forward(
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
# >>> END AH Changes <<<

cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
@@ -329,8 +371,8 @@ def forward(
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
# inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
Expand All @@ -343,7 +385,7 @@ def forward(
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, -1)

attn_output = self.o_proj(attn_output)

@@ -356,10 +398,11 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -395,6 +438,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.attention_adapters(hidden_states, residual, None)
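Taken together, the changes to modeling_llama.py above are what make Llama adapter training run with gradient checkpointing again. A rough sketch of the setup this targets, assuming placeholder checkpoint and adapter names (per the limitation noted in the commit message, ReFT and prefix tuning still do not work with gradient checkpointing):

import adapters
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint
adapters.init(model)                              # add adapter support to the plain HF model
model.add_adapter("bn_adapter", config="seq_bn")  # a bottleneck adapter
model.train_adapter("bn_adapter")                 # freeze base weights, train only the adapter
model.gradient_checkpointing_enable()             # the combination this PR fixes for Llama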
7 changes: 6 additions & 1 deletion src/adapters/trainer.py
@@ -42,6 +42,9 @@ def __init__(
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
):
if model is not None:
model_quantized = getattr(model, "is_quantized", False)
model.is_quantized = False
super().__init__(
model,
args,
@@ -55,6 +58,8 @@
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
if model is not None:
model.is_quantized = model_quantized

if adapter_names is not None:
self.model.set_active_adapters(adapter_names)
@@ -250,4 +255,4 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl


class Seq2SeqAdapterTrainer(AdapterTrainer, Seq2SeqTrainer):
pass
pass
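The is_quantized toggle in the AdapterTrainer constructor above works around the base Trainer refusing to fine-tune purely quantized models unless they are PEFT-prepared: the flag is hidden while the parent constructor runs and restored afterwards. A rough sketch of the quantized training path the new test presumably exercises (quantization settings, names, and the dataset below are illustrative, not taken from the PR):

import torch
import adapters
from adapters import AdapterTrainer
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    quantization_config=bnb_config,
)
adapters.init(model)
model.add_adapter("lora_adapter", config="lora")
model.train_adapter("lora_adapter")

args = TrainingArguments(output_dir="out", per_device_train_batch_size=1, max_steps=10)
# train_dataset is assumed to be an already tokenized dataset (not shown here).
trainer = AdapterTrainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()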