diff --git a/.gitignore b/.gitignore index 9ad896e97..e9155ff15 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ __pycache__ /appwrapper.yaml *.egg-info/ build/ +*.log diff --git a/README.md b/README.md index 9da39662d..3a579e203 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ labels = [[-100, -100, -100, 4, 5, 0], [-100, -100, 8, 0]] # this will throw a warning saying that the model is of gpt_bigcode class # ignore the warning -model = GPTBaseForCausalLM.from_pretrained(, use_padding_free_transformer=True).cuda() +model = GPTBaseForCausalLM.from_pretrained().cuda() with enable_kernels([Kernel.flash_attention_2]): loss = model(input_ids=input_ids, labels=labels).loss diff --git a/configs/distillation-example.yml b/configs/distillation-example.yml index 05a697969..ed80a4c70 100644 --- a/configs/distillation-example.yml +++ b/configs/distillation-example.yml @@ -23,7 +23,6 @@ model_args: model_class: AutoModelForCausalLM model_name: ibm/PowerLM-3b efficient_initialization: false - use_padding_free_transformer: false teacher_args: model_class: AutoModelForCausalLM diff --git a/configs/finetuning-example.yml b/configs/finetuning-example.yml index 2c99e9005..74bf993c9 100644 --- a/configs/finetuning-example.yml +++ b/configs/finetuning-example.yml @@ -25,7 +25,6 @@ model_args: # padding free transformer needs a gpt_base model. # To convert granite models to this class and convert back after training, # take a look at the readme of this repo - use_padding_free_transformer: false random_args: # for replication of experiment (however, flash attention is non-deterministic so replication generally won't work) diff --git a/configs/pretraining-examples/dense/pretrain-1.yml b/configs/pretraining-examples/dense/pretrain-1.yml index 17d127514..7bb861374 100644 --- a/configs/pretraining-examples/dense/pretrain-1.yml +++ b/configs/pretraining-examples/dense/pretrain-1.yml @@ -120,7 +120,6 @@ model_args: intermediate_size: 3072 add_bias: true position_embedding_type: learned_absolute - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/pretraining-examples/dense/pretrain-2.yml b/configs/pretraining-examples/dense/pretrain-2.yml index 371efc271..63b8dd9d3 100644 --- a/configs/pretraining-examples/dense/pretrain-2.yml +++ b/configs/pretraining-examples/dense/pretrain-2.yml @@ -125,7 +125,6 @@ model_args: intermediate_size: 3072 add_bias: true position_embedding_type: learned_absolute - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/pretraining-examples/dense/pretrain-3.yml b/configs/pretraining-examples/dense/pretrain-3.yml index ff138a2cb..2a10695c0 100644 --- a/configs/pretraining-examples/dense/pretrain-3.yml +++ b/configs/pretraining-examples/dense/pretrain-3.yml @@ -138,7 +138,6 @@ model_args: intermediate_size: 3072 add_bias: true position_embedding_type: learned_absolute - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/pretraining-examples/dense/pretrain-tpu.yml b/configs/pretraining-examples/dense/pretrain-tpu.yml index 57660cd9a..e5fdba16e 100644 --- a/configs/pretraining-examples/dense/pretrain-tpu.yml +++ b/configs/pretraining-examples/dense/pretrain-tpu.yml @@ -139,7 +139,6 @@ model_args: intermediate_size: 3072 add_bias: true position_embedding_type: learned_absolute - # use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/research/cross-layer-attention/base.yml 
b/configs/research/cross-layer-attention/base.yml index 86dd56916..2688a968c 100644 --- a/configs/research/cross-layer-attention/base.yml +++ b/configs/research/cross-layer-attention/base.yml @@ -249,7 +249,6 @@ model_args: activation_function: swiglu intermediate_size: 8192 efficient_initialization: false - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/research/cross-layer-attention/cla.yml b/configs/research/cross-layer-attention/cla.yml index 867e7c232..a5ee1f63c 100644 --- a/configs/research/cross-layer-attention/cla.yml +++ b/configs/research/cross-layer-attention/cla.yml @@ -282,7 +282,6 @@ model_args: activation_function: swiglu intermediate_size: 8192 efficient_initialization: false - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/1b-base.yml b/configs/research/ladder-residual/1b-base.yml index 29cee68f1..9ddee7075 100644 --- a/configs/research/ladder-residual/1b-base.yml +++ b/configs/research/ladder-residual/1b-base.yml @@ -278,7 +278,6 @@ model_args: activation_function: swiglu intermediate_size: 4096 efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/1b-ladder.yml b/configs/research/ladder-residual/1b-ladder.yml index f3fb1f56a..66f787d69 100644 --- a/configs/research/ladder-residual/1b-ladder.yml +++ b/configs/research/ladder-residual/1b-ladder.yml @@ -278,7 +278,6 @@ model_args: activation_function: swiglu intermediate_size: 4096 efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/1b-parallel.yml b/configs/research/ladder-residual/1b-parallel.yml index 4959b0378..c843e57fe 100644 --- a/configs/research/ladder-residual/1b-parallel.yml +++ b/configs/research/ladder-residual/1b-parallel.yml @@ -278,7 +278,6 @@ model_args: activation_function: swiglu intermediate_size: 4096 efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/3b-base.yml b/configs/research/ladder-residual/3b-base.yml index be1bccb68..50ddee10c 100644 --- a/configs/research/ladder-residual/3b-base.yml +++ b/configs/research/ladder-residual/3b-base.yml @@ -238,7 +238,6 @@ model_args: - mlp_type: MLP activation_function: swiglu efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/3b-ladder.yml b/configs/research/ladder-residual/3b-ladder.yml index 81b485e73..10c2fac95 100644 --- a/configs/research/ladder-residual/3b-ladder.yml +++ b/configs/research/ladder-residual/3b-ladder.yml @@ -238,7 +238,6 @@ model_args: - mlp_type: MLP activation_function: swiglu efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/3b-parallel.yml b/configs/research/ladder-residual/3b-parallel.yml index fb14fc5b8..9db78544d 100644 --- a/configs/research/ladder-residual/3b-parallel.yml +++ b/configs/research/ladder-residual/3b-parallel.yml @@ -238,7 +238,6 @@ model_args: - mlp_type: MLP activation_function: swiglu efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining diff --git a/lm_engine/arguments.py b/lm_engine/arguments.py index 73720cbbc..1a3bc6985 
100644 --- a/lm_engine/arguments.py +++ b/lm_engine/arguments.py @@ -48,8 +48,6 @@ class ModelArgs(BaseArgs): model_class: str = None # trust remote code for models that are not directly supported by HuggingFace yet trust_remote_code: bool = False - # whether to use padding free transformer: https://huggingface.co/blog/mayank-mishra/padding-free-transformer - use_padding_free_transformer: bool = False # use lower memory to initialize model efficient_initialization: bool = False # whether to reset attention masks for pretraining diff --git a/lm_engine/data/__init__.py b/lm_engine/data/__init__.py index 8dc4f9fad..8c8ad82a7 100644 --- a/lm_engine/data/__init__.py +++ b/lm_engine/data/__init__.py @@ -134,7 +134,6 @@ def get_finetuning_dataloader( use_output=use_output, loss_mask=args.training_parameters.loss_mask, eos_token_id=tokenizer.eos_token_id, - use_padding_free_transformer=args.model_args.use_padding_free_transformer, pad_to_multiple_of=ProcessGroupManager.get_tensor_parallel_world_size(), ), ) diff --git a/lm_engine/data/utils.py b/lm_engine/data/utils.py index f686acd32..26d185964 100644 --- a/lm_engine/data/utils.py +++ b/lm_engine/data/utils.py @@ -8,7 +8,52 @@ import torch from ..enums import LossMask -from ..hf_models import convert_padding_free_lists_to_tensors + + +def _check_list_type(list_of_list: list[list[int | float]] | None, error_message: str) -> None: + if list_of_list is None: + return + + assert isinstance(list_of_list, list), error_message + assert isinstance(list_of_list[0], list), error_message + + +def _flatten_and_convert_to_tensors(x: list[int], device: torch.device) -> torch.Tensor: + y = [] + for sequence in x: + y.extend(sequence) + + return torch.tensor(y, device=device) + + +def _convert_padding_free_lists_to_tensors( + input_ids: list[list[int]] | None = None, + position_ids: list[list[int]] | None = None, + labels: list[list[int]] | None = None, + device: torch.device = None, +) -> tuple[torch.Tensor | int]: + + # check input types are correct + error_message = "{variable} should be of type List[List[{dtype}]]" + _check_list_type(input_ids, error_message.format(variable="input_ids", dtype="int")) + _check_list_type(position_ids, error_message.format(variable="position_ids", dtype="int")) + _check_list_type(labels, error_message.format(variable="labels", dtype="int")) + + # prepare inputs for the model + seqlens = torch.tensor([0] + [len(x) for x in input_ids], device=device) + cu_seqlens = seqlens.cumsum(dim=-1).to(torch.int32) + max_seqlen = seqlens.max().item() + + if position_ids is None: + position_ids = [list(range(len(x))) for x in input_ids] + position_ids = _flatten_and_convert_to_tensors(position_ids, device) + + input_ids = _flatten_and_convert_to_tensors(input_ids, device) + + if labels is not None: + labels = _flatten_and_convert_to_tensors(labels, device) + + return input_ids, position_ids, labels, cu_seqlens, max_seqlen def collate_fn( @@ -16,7 +61,6 @@ def collate_fn( use_output: bool, loss_mask: LossMask, eos_token_id: int, - use_padding_free_transformer: bool, labels_mask_value: int = -100, pad_to_multiple_of: int = 1, device: torch.device = None, @@ -38,64 +82,40 @@ def collate_fn( device = torch.cuda.current_device() if device is None else device - if use_padding_free_transformer: - input_ids = inputs - attention_mask = None - - if loss_mask == LossMask.output_only: - labels = [ - [labels_mask_value] * (len(array_in) - len(array_out)) + array_out - for array_in, array_out in zip(inputs, outputs) - ] - elif loss_mask == 
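As a reference for the list inputs handled above, a torch-only sketch of the packing arithmetic that the new `_convert_padding_free_lists_to_tensors` helper performs (values hand-computed for two sequences of lengths 3 and 2; the helper itself is module-private, so this only mirrors its logic):

```python
import torch

input_ids = [[11, 12, 13], [21, 22]]  # two variable-length sequences

# same arithmetic as _convert_padding_free_lists_to_tensors
seqlens = torch.tensor([0] + [len(x) for x in input_ids])
cu_seqlens = seqlens.cumsum(dim=-1).to(torch.int32)  # tensor([0, 3, 5], dtype=torch.int32)
max_seqlen = seqlens.max().item()                    # 3

position_ids = torch.tensor([i for x in input_ids for i in range(len(x))])  # [0, 1, 2, 0, 1]
packed_input_ids = torch.tensor([t for x in input_ids for t in x])          # [11, 12, 13, 21, 22]
```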
LossMask.no_mask: - labels = inputs - else: - raise ValueError(f"unexpected loss_mask ({loss_mask})") - - tokens_to_add = 0 - if pad_to_multiple_of > 1: - total_tokens = sum([len(array) for array in input_ids]) - tokens_to_add = (math.ceil(total_tokens / pad_to_multiple_of) * pad_to_multiple_of) - total_tokens - - # we pad the last example in the batch on the right - # NOTE this can be done since the attention is causal - input_ids[-1].extend([eos_token_id] * tokens_to_add) - labels[-1].extend([labels_mask_value] * tokens_to_add) - - input_ids, position_ids, _, labels, cu_seqlens, max_seqlen = convert_padding_free_lists_to_tensors( - input_ids=input_ids, labels=labels, device=device - ) - - result = { - "input_ids": input_ids, - "position_ids": position_ids, - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - } - if labels is not None: - result["labels"] = labels + input_ids = inputs + + if loss_mask == LossMask.output_only: + labels = [ + [labels_mask_value] * (len(array_in) - len(array_out)) + array_out + for array_in, array_out in zip(inputs, outputs) + ] + elif loss_mask == LossMask.no_mask: + labels = inputs else: - max_length = max(list(map(len, inputs))) - if pad_to_multiple_of > 1: - max_length = math.ceil(max_length / pad_to_multiple_of) * pad_to_multiple_of - - input_ids = [[eos_token_id] * (max_length - len(array)) + array for array in inputs] - attention_mask = [[0] * (max_length - len(array)) + [1] * len(array) for array in inputs] - - if outputs is not None: - if loss_mask == LossMask.output_only: - labels = [[labels_mask_value] * (max_length - len(array)) + array for array in outputs] - elif loss_mask == LossMask.no_mask: - labels = inputs - else: - raise ValueError(f"unexpected loss_mask ({loss_mask})") - - result = { - "input_ids": torch.tensor(input_ids, device=device), - "attention_mask": torch.tensor(attention_mask, device=device), - } - if labels is not None: - result["labels"] = torch.tensor(labels, device=device) + raise ValueError(f"unexpected loss_mask ({loss_mask})") + + tokens_to_add = 0 + if pad_to_multiple_of > 1: + total_tokens = sum([len(array) for array in input_ids]) + tokens_to_add = (math.ceil(total_tokens / pad_to_multiple_of) * pad_to_multiple_of) - total_tokens + + # we pad the last example in the batch on the right + # NOTE this can be done since the attention is causal + input_ids[-1].extend([eos_token_id] * tokens_to_add) + labels[-1].extend([labels_mask_value] * tokens_to_add) + + input_ids, position_ids, _, labels, cu_seqlens, max_seqlen = _convert_padding_free_lists_to_tensors( + input_ids=input_ids, labels=labels, device=device + ) + + result = { + "input_ids": input_ids, + "position_ids": position_ids, + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + } + if labels is not None: + result["labels"] = labels return result diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py index 72a0e2133..ef7933bee 100644 --- a/lm_engine/hf_models/__init__.py +++ b/lm_engine/hf_models/__init__.py @@ -2,8 +2,10 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** +from .cache import disable_generation_cache from .config import CommonConfig from .loss import get_autoregressive_language_modeling_loss, is_aux_loss_zero +from .mask import AttentionMaskInfo from .mixins import CausalLMOutputWithPast, PipelineParallelInput, PipelineParallelOutput from .model_conversion import export_to_huggingface, import_from_huggingface from .models import ( @@ -30,7 +32,6 @@ ) from .register_hf import 
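The `pad_to_multiple_of` handling in the new `collate_fn` pads only the last example, on the right, which is safe because attention is causal; the padded positions get `-100` labels so the loss ignores them. A minimal sketch of that arithmetic with made-up sizes (`pad_to_multiple_of` would normally be the tensor-parallel world size):

```python
import math

eos_token_id, labels_mask_value, pad_to_multiple_of = 0, -100, 4

input_ids = [[11, 12, 13], [21, 22]]   # 5 tokens in total
labels = [[-100, 12, 13], [-100, 22]]

total_tokens = sum(len(x) for x in input_ids)
tokens_to_add = math.ceil(total_tokens / pad_to_multiple_of) * pad_to_multiple_of - total_tokens  # 3

input_ids[-1].extend([eos_token_id] * tokens_to_add)    # [21, 22, 0, 0, 0]
labels[-1].extend([labels_mask_value] * tokens_to_add)  # [-100, 22, -100, -100, -100]
```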
get_model_parallel_class, is_custom_model, register_model_classes from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts -from .utils import convert_padding_free_lists_to_tensors, disable_generation_cache register_model_classes() diff --git a/lm_engine/hf_models/cache/__init__.py b/lm_engine/hf_models/cache/__init__.py index 7bbb5a15e..4b3b9431f 100644 --- a/lm_engine/hf_models/cache/__init__.py +++ b/lm_engine/hf_models/cache/__init__.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Iterable +from typing import Any, Iterable import torch @@ -53,3 +53,22 @@ def get_seq_length(self, layer_idx: int = 0) -> int: def reorder_cache(self, beam_idx: torch.Tensor) -> None: for cache in self.cache: cache.reorder_cache(beam_idx) + + +_IS_GENERATION_CACHE_ENABLED: bool = True + + +class disable_generation_cache: + def __enter__(self) -> Any: + global _IS_GENERATION_CACHE_ENABLED + self.original = _IS_GENERATION_CACHE_ENABLED + + _IS_GENERATION_CACHE_ENABLED = False + + def __exit__(self, exception_type, exception_value, exception_traceback) -> Any: + global _IS_GENERATION_CACHE_ENABLED + _IS_GENERATION_CACHE_ENABLED = self.original + + +def is_generation_cache_enabled() -> bool: + return _IS_GENERATION_CACHE_ENABLED diff --git a/lm_engine/hf_models/loss.py b/lm_engine/hf_models/loss.py index c3cd0a257..d76b0443b 100644 --- a/lm_engine/hf_models/loss.py +++ b/lm_engine/hf_models/loss.py @@ -14,6 +14,7 @@ from ..enums import Kernel from ..kernels import is_kernel_allowed from ..utils import ProcessGroupManager, is_xma_available +from .mask import AttentionMaskInfo if is_xma_available(): @@ -23,10 +24,9 @@ def get_autoregressive_language_modeling_loss( lm_logits: torch.Tensor, labels: torch.Tensor, + attention_mask_info: AttentionMaskInfo, hidden_states: torch.Tensor | None = None, vocab_weight: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - use_padding_free_transformer: bool = False, reduction: str = "mean", shift_logits_and_labels: bool = True, tensor_parallel_enabled: bool = False, @@ -40,15 +40,13 @@ def get_autoregressive_language_modeling_loss( labels = labels[..., 1:] - if use_padding_free_transformer: - if shift_logits_and_labels: - assert cu_seqlens is not None + if shift_logits_and_labels: + cu_seqlens = attention_mask_info.get_cu_seqlens() + if cu_seqlens is not None: # this is needed so that the last token of current example doesn't predict first token of next example drop_loss_positions = cu_seqlens[1:-1] - 1 labels[drop_loss_positions] = -100 - else: - assert cu_seqlens is None if is_kernel_allowed(Kernel.fused_linear_cross_entropy): assert lm_logits is None diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py new file mode 100644 index 000000000..9f5f191a5 --- /dev/null +++ b/lm_engine/hf_models/mask.py @@ -0,0 +1,254 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +from ..enums import Kernel +from ..kernels import is_kernel_allowed +from ..utils import is_fma_available + + +if is_fma_available(): + from fma import KernelBackend, pack_sequence, unpack_sequence + + +_ERROR_MESSAGE = "code is not supposed to reach here" + + +# NOTE using dataclass here since pydantic doesn't work with torch.compile +@dataclass +class AttentionMaskInfo: + batch_size: int | None = None + 
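The boundary masking in `get_autoregressive_language_modeling_loss` runs after `labels = labels[..., 1:]`, so position `cu_seqlens[i] - 1` in the shifted labels is exactly where the last token of document i would otherwise be trained to predict the first token of document i + 1. A hand-worked example for two packed documents of lengths 3 and 2:

```python
import torch

cu_seqlens = torch.tensor([0, 3, 5])
labels = torch.tensor([11, 12, 13, 21, 22])   # packed labels

labels = labels[..., 1:].clone()              # [12, 13, 21, 22]
drop_loss_positions = cu_seqlens[1:-1] - 1    # tensor([2])
labels[drop_loss_positions] = -100            # [12, 13, -100, 22] -> no cross-document prediction
```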
cu_seqlens: torch.Tensor | None = None + max_seqlen: int | None = None + attention_mask: torch.Tensor | None = None + device: torch.device | None = None + mask_value: torch.Tensor | None = None + causal_mask: torch.Tensor | None = None + _mask_value: torch.Tensor | None = None + + def __post_init__(self) -> None: + self._has_cu_seqlens = self.cu_seqlens is not None + self._has_attention_mask = self.attention_mask is not None + + if self.batch_size is not None: + assert self.max_seqlen is not None + assert not self.has_cu_seqlens() + assert not self.has_attention_mask() + elif self.cu_seqlens is not None: + assert self.batch_size is None + assert self.max_seqlen is not None + assert not self.has_attention_mask() + + self.device = self.cu_seqlens.device + elif self.has_attention_mask(): + assert self.batch_size is None + assert not self.has_cu_seqlens() + assert self.max_seqlen is None + + self.device = self.attention_mask.device + + assert self.device is not None + + def has_cu_seqlens(self) -> bool: + return self._has_cu_seqlens + + def has_attention_mask(self) -> bool: + return self._has_attention_mask + + def has_padding(self) -> bool: + return self.has_cu_seqlens() or self.has_attention_mask() + + def get_batch_size(self) -> int: + if self.batch_size is None: + if self.has_cu_seqlens(): + self.batch_size = self.cu_seqlens.size(0) - 1 + elif self.has_attention_mask(): + self.batch_size = self.attention_mask.size(0) + else: + raise NotImplementedError(_ERROR_MESSAGE) + + return self.batch_size + + def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor | None: + if self.has_cu_seqlens(): + return self.cu_seqlens + + if return_none_allowed: + return None + + if self.has_attention_mask(): + seqlens = self.attention_mask.sum(dim=-1, dtype=torch.int32) + self.cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) + self.max_seqlen = seqlens.max().item() + else: + B = self.get_batch_size() + S = self.get_max_seqlen() + + self.cu_seqlens = torch.arange(0, B * S + 1, S, dtype=torch.int32, device=self.device) + + return self.cu_seqlens + + def get_max_seqlen(self, return_none_allowed: bool = True) -> int | None: + if self.has_cu_seqlens(): + assert self.max_seqlen is not None + return self.max_seqlen + + if return_none_allowed: + return None + + if self.max_seqlen is None: + # this will cache the max_seqlen but causes synchronization with CPU + self.get_cu_seqlens(False) + + if self.max_seqlen is None: + raise NotImplementedError(_ERROR_MESSAGE) + + return self.max_seqlen + + def get_attention_mask(self, return_none_allowed: bool = True) -> torch.Tensor | None: + if self.has_attention_mask(): + return self.attention_mask + + if return_none_allowed: + return None + + B = self.get_batch_size() + S = self.get_max_seqlen() + + if self.has_cu_seqlens(): + self.attention_mask = self.unpack_sequence( + inputs=torch.ones_like(self.get_cu_seqlens(), device=self.device, dtype=torch.int32), + output_shape=(B, S), + ) + else: + self.attention_mask = torch.ones(B, S, device=self.device, dtype=torch.int32) + + return self.attention_mask + + def get_position_ids(self) -> torch.Tensor: + if self.has_cu_seqlens() or self.has_attention_mask(): + attention_mask = self.get_attention_mask(False) + position_ids = attention_mask.cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + else: + B = self.get_batch_size() + S = self.get_max_seqlen(False) + + position_ids = torch.arange(0, S, device=self.device) + position_ids = position_ids[None, ...].expand(B, -1) + + 
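`AttentionMaskInfo` accepts exactly one of three descriptions of a batch, as enforced in `__post_init__`: a rectangular unpadded batch (`batch_size` + `max_seqlen` + `device`), a packed ragged batch (`cu_seqlens` + `max_seqlen`), or a padded batch (`attention_mask`); everything else is derived lazily and cached. A construction sketch, assuming the class is used via its re-export from `lm_engine.hf_models` (added in the updated `__init__.py`):

```python
import torch
from lm_engine.hf_models import AttentionMaskInfo

device = torch.device("cuda")

# 1. rectangular batch with no padding
info = AttentionMaskInfo(batch_size=2, max_seqlen=3, device=device)

# 2. packed ragged batch: documents of lengths 3 and 2
info = AttentionMaskInfo(cu_seqlens=torch.tensor([0, 3, 5], dtype=torch.int32, device=device), max_seqlen=3)

# 3. padded batch described by a 0/1 attention mask
info = AttentionMaskInfo(attention_mask=torch.tensor([[1, 1, 1], [1, 1, 0]], device=device))

info.get_cu_seqlens(False)  # tensor([0, 3, 5], dtype=torch.int32), derived from the mask
info.get_position_ids()     # per-token positions, returned packed for padded/ragged batches
```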
position_ids = self.pack_sequence(position_ids) + + return position_ids + + def get_causal_mask( + self, query_length: int, return_none_allowed: bool = True, dtype: torch.dtype | None = None + ) -> torch.Tensor | None: + if self.causal_mask is not None: + return self.causal_mask + + if self.has_cu_seqlens() or self.has_attention_mask(): + attention_mask = self.get_attention_mask() + elif return_none_allowed: + return None + + Q = query_length + K = attention_mask.size(1) + L = K - Q + + if Q > 1: + causal_mask = torch.empty((Q, K), dtype=torch.bool, device=self.device) + causal_mask[:, L:] = torch.tril(torch.ones(Q, K, dtype=torch.bool, device=self.device)) + + if L > 0: + causal_mask[:, :L] = True + + causal_mask = causal_mask[None, ...] + causal_mask = causal_mask & attention_mask[:, None, ...].to(torch.bool) + elif Q == 1: + causal_mask = attention_mask[:, None, ...].to(dtype=torch.bool, device=self.device) + else: + raise NotImplementedError(_ERROR_MESSAGE) + + causal_mask = causal_mask[:, None, ...] + causal_mask = torch.where(causal_mask, ~causal_mask, self._get_mask_value(attention_mask.device, dtype)) + + # this is needed to prevent NaN since SDPA + # see issue: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = causal_mask * ~torch.all( + causal_mask == self._get_mask_value(self.device, dtype), dim=-1, keepdim=True + ) + + self.causal_mask = causal_mask + + return self.causal_mask + + def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: + if inputs is None: + return None + + is_tensor = isinstance(inputs, torch.Tensor) + if is_tensor: + inputs = [inputs] + + if self.has_cu_seqlens() or self.has_attention_mask(): + kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) else KernelBackend.torch + cu_seqlens = self.get_cu_seqlens(False) + + inputs = pack_sequence( + inputs=inputs, + cu_seqlens=cu_seqlens, + total_tokens=cu_seqlens[-1].item(), + kernel_backend_forward=kernel_backend, + kernel_backend_backward=kernel_backend, + ) + else: + inputs = [i.flatten(0, 1) for i in inputs] + + if is_tensor: + inputs = inputs[0] + + return inputs + + def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: + if inputs is None: + return None + + is_tensor = isinstance(inputs, torch.Tensor) + if is_tensor: + inputs = [inputs] + + B = self.get_batch_size() + S = self.get_max_seqlen(False) + + if self.has_cu_seqlens() or self.has_attention_mask(): + kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch + + inputs = unpack_sequence( + inputs=inputs, + cu_seqlens=self.get_cu_seqlens(False), + batch_size=B, + sequence_length=S, + kernel_backend_forward=kernel_backend, + kernel_backend_backward=kernel_backend, + ) + else: + inputs = [i.reshape(B, S, *i.size()[1:]) for i in inputs] + + if is_tensor: + inputs = inputs[0] + + return inputs + + def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
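`pack_sequence` and `unpack_sequence` convert between the padded `(batch, seqlen, ...)` layout and the packed `(total_tokens, ...)` layout using `cu_seqlens`, dispatching to the `fma` kernels when enabled. A torch-only sketch of the intended semantics for a right-padded batch (hand-rolled purely for illustration; the real methods call `fma.pack_sequence` / `fma.unpack_sequence`):

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])               # (B=2, S=3), 5 valid tokens
x = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)   # (B, S, H)

seqlens = attention_mask.sum(dim=-1)            # [3, 2]
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))   # [0, 3, 5]

# pack: keep only the valid tokens -> (total_tokens, H)
packed = torch.cat([x[b, : seqlens[b]] for b in range(x.size(0))], dim=0)

# unpack: scatter back into a zero-padded (B, S, H) tensor
unpacked = torch.zeros_like(x)
for b in range(x.size(0)):
    unpacked[b, : seqlens[b]] = packed[cu_seqlens[b] : cu_seqlens[b + 1]]
```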
+ if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: + self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + return self.mask_value diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 3be5b9568..e392270a2 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -8,12 +8,10 @@ import torch.nn as nn from transformers import GenerationConfig, PreTrainedModel -from ....enums import Kernel -from ....kernels import is_kernel_allowed from ...cache import GenerationCache from ...config import CommonConfig +from ...mask import AttentionMaskInfo from ...modeling_utils import ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function -from ...utils import convert_padding_free_lists_to_tensors, is_generation_cache_enabled from ..modeling_outputs import BaseModelOutputWithPast from .layer import Block @@ -37,54 +35,13 @@ def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixi assert self.config_class is not None self.generation_config = GenerationConfig.from_model_config(self.config) - self.use_padding_free_transformer = kwargs.get("use_padding_free_transformer", False) self._tied_word_embeddings = config.tie_word_embeddings - self._has_mamba2 = any([block.sequence_mixer_type == "mamba2" for block in self.config.sequence_mixer_blocks]) def _init_weights(self, module: nn.Module) -> None: if hasattr(module, "reset_parameters"): module.reset_parameters() - # FIXME typing - def prepare_inputs_for_model( - self, - input_ids: torch.Tensor | list[list[int]] | None, - position_ids: torch.Tensor | list[list[int]] | None, - labels: torch.Tensor | list[list[int]] | None, - cu_seqlens: torch.Tensor | None, - max_seqlen: int | None, - past_key_values: tuple[tuple[torch.Tensor]], - attention_mask: torch.Tensor | None, - use_cache: bool, - ) -> tuple[torch.Tensor]: - if self.use_padding_free_transformer: - if isinstance(input_ids, list): - # this is managed internally - error_message = ( - "{variable} should not be passed for flash attention when using List[List[int]] " - "input types attention mask logic is handled internally" - ) - assert cu_seqlens is None, error_message.format(variable="cu_seqlens") - assert max_seqlen is None, error_message.format(variable="max_seqlen") - assert attention_mask is None, error_message.format(variable="attention_mask") - - input_ids, position_ids, labels, cu_seqlens, max_seqlen = convert_padding_free_lists_to_tensors( - input_ids=input_ids, position_ids=position_ids, labels=labels, device=torch.cuda.current_device() - ) - else: - assert ( - cu_seqlens is not None - ), "cu_seqlens needs to be specified when using tensor inputs with padding_free transformer" - assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert max_seqlen is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" - - if use_cache or past_key_values is not None: - raise NotImplementedError("KV caching is not supported with padding_free transformer") - - return input_ids, position_ids, labels, cu_seqlens, max_seqlen - class BaseModelMixin(PreTrainedModelMixin): mask_value = None @@ -106,12 +63,7 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.embedding_dropout = ( nn.Identity() if config.embedding_dropout == 0 else 
nn.Dropout(config.embedding_dropout) ) - self.h = nn.ModuleList( - [ - self.layer_class(config, use_padding_free_transformer=self.use_padding_free_transformer, layer_idx=i) - for i in range(config.num_layers) - ] - ) + self.h = nn.ModuleList([self.layer_class(config, layer_idx=i) for i in range(config.num_layers)]) self.ln_f = get_normalization_function( config.normalization_function, self.embed_dim, eps=config.layer_norm_epsilon ) @@ -126,53 +78,20 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: def forward( self, - input_ids: torch.Tensor | None = None, + input_ids: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, - use_cache: bool | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> BaseModelOutputWithPast: - ( - use_cache, - hidden_states, - causal_mask, - position_ids, - rope_cos_sin, - past_key_values, - ) = self._prepare_a_bunch_of_stuff( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - if is_generation_cache_enabled(): - past_key_values = ( - GenerationCache(self.config) if use_cache and past_key_values is None else past_key_values - ) - - mamba_mask = None - mamba_mask_computed = False - - for sequence_mixer_type, block in zip(self.sequence_mixer_block_types, self.h): - is_linear_layer = sequence_mixer_type in ["mamba2", "rnn", "gru"] - - if is_linear_layer and not mamba_mask_computed: - mamba_mask = self._get_mamba_mask(attention_mask, past_key_values) - mamba_mask_computed = True + hidden_states = self._get_initial_hidden_state(input_ids, position_ids) + rope_cos_sin = self._get_rope_cos_sin(attention_mask_info, position_ids, dtype=hidden_states.dtype) - hidden_states = block( + for block in self.h: + hidden_states: torch.Tensor = block( hidden_states, + attention_mask_info=attention_mask_info, past_key_values=past_key_values, - attention_mask=mamba_mask if is_linear_layer else causal_mask, rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) hidden_states = self.ln_f(hidden_states) @@ -195,55 +114,14 @@ def _get_position_ids( return position_ids def _get_rope_cos_sin( - self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype + self, attention_mask_info: AttentionMaskInfo, position_ids: torch.Tensor, dtype: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: if self.position_embedding_type == "rope": - cos, sin = self.rope(key_length, dtype=dtype) + cos, sin = self.rope(attention_mask_info.get_max_seqlen(False), dtype=dtype) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) return cos, sin - def _prepare_causal_attention_mask( - self, - attention_mask: torch.Tensor | None, - batch_size: int, - query_length: int, - key_length: int, - device: torch.device, - ) -> torch.Tensor: - past_length = key_length - query_length - - if query_length > 1: - # (query_length, key_length) - causal_mask = torch.empty((query_length, key_length), dtype=torch.bool, device=device) - causal_mask[:, past_length:] = torch.tril( - torch.ones(query_length, query_length, dtype=torch.bool, device=device) - ) - - if past_length > 0: - causal_mask[:, :past_length] = True - - # (query_length, key_length) -> (1, query_length, key_length) - causal_mask = causal_mask.unsqueeze(0) - - if attention_mask 
is None: - # (1, query_length, key_length) -> (batch_size, query_length, key_length) - causal_mask = causal_mask.expand(batch_size, -1, -1) - else: - # (1, query_length, key_length) & (batch_size, 1, key_length) -> (batch_size, query_length, key_length) - causal_mask = causal_mask & attention_mask.unsqueeze(1).to(torch.bool) - else: - if attention_mask is None: - # (batch_size, query_length, key_length) - causal_mask = torch.ones(batch_size, query_length, key_length, dtype=torch.bool, device=device) - else: - # (batch_size, query_length, key_length) - causal_mask = attention_mask.unsqueeze(1).to(dtype=torch.bool, device=device) - - causal_mask = causal_mask.unsqueeze(1) - - return causal_mask - def _get_initial_hidden_state(self, input_ids: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor: hidden_state = self.wte(input_ids) @@ -257,67 +135,6 @@ def _get_initial_hidden_state(self, input_ids: torch.Tensor, position_ids: torch return hidden_state - def _prepare_a_bunch_of_stuff( - self, - input_ids: torch.Tensor | None = None, - past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - use_cache: bool | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - ) -> tuple[bool, torch.Tensor, torch.Tensor, torch.Tensor | None, GenerationCache | None]: - if use_cache is None: - use_cache = False if self.use_padding_free_transformer else self.config.use_cache - - input_shape = input_ids.size() - - # special handling for padding free transformer with list inputs - if self.use_padding_free_transformer: - # for flash attention, there is no padding and we do packing - # so, input_ids is of shape (s1 + s2 + ... + sb) - batch_size = cu_seqlens.shape[0] - 1 - else: - batch_size = input_shape[0] - - if self.use_padding_free_transformer: - assert position_ids is not None, ( - "GPTBaseModel needs position_ids from outside when using flash attention with List[List[int]] " - "inputs" - ) - - past_length = None - query_length = None - key_length = None - if self.use_padding_free_transformer: - key_length = max_seqlen.item() if isinstance(max_seqlen, torch.Tensor) else max_seqlen - else: - past_length = 0 if past_key_values is None else past_key_values.get_seq_length() - query_length = input_shape[-1] - key_length = past_length + query_length - - if position_ids is None: - position_ids = self._get_position_ids( - attention_mask, past_length, query_length, key_length, input_ids.device - ) - - hidden_states = self._get_initial_hidden_state(input_ids, position_ids) - - rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) - - attention_mask = self._get_maybe_causal_mask( - attention_mask, batch_size, query_length, key_length, hidden_states.dtype, input_ids.device - ) - - return ( - use_cache, - hidden_states, - attention_mask, - position_ids, - rope_cos_sin, - past_key_values, - ) - def _setup_positional_encoding(self) -> None: max_position_embeddings = self.config.max_position_embeddings @@ -343,42 +160,6 @@ def _setup_positional_encoding(self) -> None: else: raise NotImplementedError() - def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: - # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
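The per-model mask construction removed from `dense/base.py` is now centralized in `AttentionMaskInfo.get_causal_mask`, which combines a lower-triangular mask with the padding mask and converts it into an additive float mask for SDPA. A hand-worked torch-only equivalent for a padded batch with no KV cache (`query_length == key_length == 3`):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])   # second sequence has one pad token
Q = K = 3
dtype = torch.float32

causal = torch.tril(torch.ones(Q, K, dtype=torch.bool))        # (Q, K)
mask = causal[None, ...] & attention_mask[:, None, :].bool()    # (B, Q, K)
mask = mask[:, None, ...]                                       # (B, 1, Q, K), broadcast over heads

mask_value = torch.finfo(dtype).min
additive = torch.where(mask, torch.tensor(0.0, dtype=dtype), torch.tensor(mask_value, dtype=dtype))

# fully-masked rows are zeroed to avoid NaNs in SDPA (pytorch/pytorch#110213)
additive = additive * ~torch.all(additive == mask_value, dim=-1, keepdim=True)
```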
- if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: - self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) - return self.mask_value - - def _get_maybe_causal_mask( - self, - attention_mask: torch.Tensor | None, - batch_size: int, - query_length: int, - key_length: int, - dtype: torch.dtype, - device: torch.device, - ) -> torch.Tensor: - if not (is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3)): - # we use the causal/non-causal argument of SDPA for attention in this case - if attention_mask is not None: - attention_mask = self._prepare_causal_attention_mask( - attention_mask, batch_size, query_length, key_length, device - ) - - attention_mask = torch.where( - attention_mask, - ~attention_mask, - self._get_mask_value(attention_mask.device, dtype), - ) - - # this is needed to prevent NaN since SDPA - # see issue: https://github.com/pytorch/pytorch/issues/110213 - attention_mask = attention_mask * ~torch.all( - attention_mask == self._get_mask_value(attention_mask.device, dtype), dim=-1, keepdim=True - ) - - return attention_mask - def _get_mamba_mask( self, attention_mask: torch.Tensor | None, past_key_values: GenerationCache ) -> torch.Tensor | None: diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index e854fce65..ff84dfc15 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -9,13 +9,12 @@ from ...cache import GenerationCache from ...config import CommonConfig +from ...mask import AttentionMaskInfo from ...modeling_utils import get_mlp_block, get_normalization_function, get_sequence_mixer class Block(nn.Module): - def __init__( - self, config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int | None = None - ) -> Block: + def __init__(self, config: CommonConfig, layer_idx: int | None = None) -> Block: super().__init__() hidden_size = config.hidden_size @@ -25,83 +24,59 @@ def __init__( self.ln_1 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) - self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) + self.sequence_mixer = get_sequence_mixer(config, True, layer_idx) self.ln_2 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) - self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx - ) + self.mlp_block = get_mlp_block(config, layer_idx=layer_idx) def forward( self, - hidden_states: torch.Tensor, + x: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - - hidden_states = self._sequence_mixer_forward( - hidden_states=hidden_states, - past_key_values=past_key_values, - attention_mask=attention_mask, - rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + r = x + x = self.ln_1(x) + + x = self._sequence_mixer_forward( + x=x, past_key_values=past_key_values, attention_mask_info=attention_mask_info, rope_cos_sin=rope_cos_sin ) if self.m_residual is not None: - hidden_states = hidden_states * self.m_residual + x = x * 
self.m_residual - hidden_states = hidden_states + residual + x = x + r - residual = hidden_states - hidden_states = self.ln_2(hidden_states) + r = x + x = self.ln_2(x) - hidden_states = self.mlp_block(hidden_states) + x = self.mlp_block(x) if self.m_residual is not None: - hidden_states = hidden_states * self.m_residual + x = x * self.m_residual - hidden_states = hidden_states + residual + x = x + r - return hidden_states + return x def _sequence_mixer_forward( self, - hidden_states: torch.Tensor, + x: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: if self.sequence_mixer_type in ["softmax_attention", "multihead_latent_attention"]: - hidden_states = self.sequence_mixer( - hidden_states, - past_key_values=past_key_values, - attention_mask=attention_mask, - rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + x = self.sequence_mixer( + x, past_key_values=past_key_values, attention_mask_info=attention_mask_info, rope_cos_sin=rope_cos_sin ) elif self.sequence_mixer_type in ["causal_convolution", "mamba2"]: - hidden_states = self.sequence_mixer( - hidden_states, cache_params=past_key_values, attention_mask=attention_mask - ) + x = self.sequence_mixer(x, cache_params=past_key_values, attention_mask=attention_mask) elif self.sequence_mixer_type in ["gru", "rnn"]: - hidden_states = self.sequence_mixer( - hidden_states, - cache_params=past_key_values, - attention_mask=attention_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) + x = self.sequence_mixer(x=x, attention_mask_info=attention_mask_info, cache_params=past_key_values) else: raise ValueError(f"unexpected sequence_mixer_type ({self.sequence_mixer_type})") - return hidden_states + return x diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index ad4883430..5e535927e 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -10,9 +10,10 @@ from ....enums import Kernel from ....kernels import is_kernel_allowed -from ...cache import GenerationCache +from ...cache import GenerationCache, is_generation_cache_enabled from ...config import CommonConfig from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero +from ...mask import AttentionMaskInfo from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear from ..modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .base import PreTrainedModelMixin @@ -56,42 +57,40 @@ def set_output_embeddings(self, new_embeddings: ParameterizedLinear) -> None: def forward( self, - input_ids: torch.Tensor | list[list[int]] | None = None, + input_ids: torch.Tensor | None = None, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | list[list[int]] | None = None, - inputs_embeds: torch.Tensor | list[list[float]] | None = None, - labels: torch.Tensor | list[list[int]] | None = None, + position_ids: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + labels: torch.Tensor | None = None, use_cache: bool | None = None, return_dict: bool = True, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, + attention_mask_info: AttentionMaskInfo | None = None, reduction: 
str = "mean", ) -> CausalLMOutputWithPast: assert return_dict assert inputs_embeds is None - input_ids, position_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( - input_ids=input_ids, - position_ids=position_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - attention_mask=attention_mask, - use_cache=use_cache, - ) - clear_aux_loss() + if attention_mask_info is None: + attention_mask_info = self._get_attention_mask_info(x=input_ids, attention_mask=attention_mask) + + if position_ids is None: + position_ids = attention_mask_info.get_position_ids() + + if is_generation_cache_enabled(): + past_key_values = ( + GenerationCache(self.config) if use_cache and past_key_values is None else past_key_values + ) + + input_ids = attention_mask_info.pack_sequence(input_ids) + transformer_outputs: BaseModelOutputWithPast = self.transformer( - input_ids, + input_ids=input_ids, + attention_mask_info=attention_mask_info, past_key_values=past_key_values, - attention_mask=attention_mask, position_ids=position_ids, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) hidden_states = transformer_outputs.last_hidden_state @@ -118,18 +117,22 @@ def forward( if self.m_width is not None: lm_logits = lm_logits / self.m_width + labels = attention_mask_info.pack_sequence(labels) + loss = get_autoregressive_language_modeling_loss( lm_logits=lm_logits, labels=labels, + attention_mask_info=attention_mask_info, hidden_states=None, vocab_weight=None, - cu_seqlens=cu_seqlens, - use_padding_free_transformer=self.use_padding_free_transformer, reduction=reduction, shift_logits_and_labels=True, tensor_parallel_enabled=False, ) + lm_logits = attention_mask_info.unpack_sequence(lm_logits) + hidden_states = attention_mask_info.unpack_sequence(hidden_states) + aux_loss = get_aux_loss() if loss is not None and not is_aux_loss_zero(aux_loss): @@ -161,8 +164,6 @@ def generate( top_p: float | None = None, **kwargs, ) -> torch.Tensor: - assert not self.use_padding_free_transformer - has_attention_mask = attention_mask is not None min_tokens_to_keep = 1 @@ -258,3 +259,14 @@ def generate( ) return generated_tokens + + def _get_attention_mask_info(self, x: torch.Tensor, attention_mask: torch.Tensor | None) -> AttentionMaskInfo: + kwargs = {} + if attention_mask is None: + kwargs["batch_size"] = x.size(0) + kwargs["max_seqlen"] = x.size(1) + kwargs["device"] = x.device + else: + kwargs["attention_mask"] = attention_mask + + return AttentionMaskInfo(**kwargs) diff --git a/lm_engine/hf_models/mixins/dense_TP/base.py b/lm_engine/hf_models/mixins/dense_TP/base.py index 8943f8f97..ca1e841d8 100644 --- a/lm_engine/hf_models/mixins/dense_TP/base.py +++ b/lm_engine/hf_models/mixins/dense_TP/base.py @@ -8,11 +8,10 @@ import torch.nn as nn from ....utils import ProcessGroupManager, divide_if_divisible -from ...cache import GenerationCache +from ...cache import GenerationCache, is_generation_cache_enabled from ...config import CommonConfig from ...modeling_utils import RoPE, YaRNScaledRoPE from ...modeling_utils_TP import Dropout_TP, Embedding_TP, get_normalization_function_TP -from ...utils import is_generation_cache_enabled from ..dense import BaseModelMixin, PreTrainedModelMixin from ..modeling_outputs import BaseModelOutputWithPast from .layer import Block_TP diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py index 802f431d1..4cb078c48 100644 --- 
a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py @@ -7,7 +7,7 @@ from .moe import MoE, ParameterizedExperts -def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int) -> MLP | MoE: +def get_mlp_block(config: CommonConfig, layer_idx: int) -> MLP | MoE: block = config.mlp_blocks[layer_idx] mlp_type = block.mlp_type @@ -33,7 +33,6 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye normalized_topk=block.normalized_topk, num_experts=block.num_experts, num_experts_per_tok=block.num_experts_per_tok, - use_padding_free_transformer=use_padding_free_transformer, ) else: raise ValueError(f"invalid mlp_type ({mlp_type}) for layer ({layer_idx})") diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index efa133c9e..0982ccd81 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -167,13 +167,11 @@ def __init__( initializer_range: float, m_width: float, num_layers: int, - use_padding_free_transformer: bool, ) -> MoE: super().__init__() self.num_experts = num_experts self.top_k = num_experts_per_tok - self.use_padding_free_transformer = use_padding_free_transformer self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.shared_intermediate_size = shared_intermediate_size @@ -247,9 +245,7 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_proj_shared.weight) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if not self.use_padding_free_transformer: - batch_size, sequence_length, _ = hidden_states.shape - + original_shape = hidden_states.size() hidden_states = hidden_states.view(-1, self.hidden_size) router_logits, router_weights, selected_experts = self._compute_routing_weights(hidden_states) @@ -263,9 +259,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: del moe_output - if not self.use_padding_free_transformer: - hidden_states = hidden_states.reshape(batch_size, sequence_length, self.hidden_size) - + hidden_states = hidden_states.reshape(*original_shape) hidden_states = self.dropout(hidden_states) aux_loss = ( diff --git a/lm_engine/hf_models/modeling_utils/position_embedding/rope.py b/lm_engine/hf_models/modeling_utils/position_embedding/rope.py index 3aa8b21f3..86426193a 100644 --- a/lm_engine/hf_models/modeling_utils/position_embedding/rope.py +++ b/lm_engine/hf_models/modeling_utils/position_embedding/rope.py @@ -13,12 +13,7 @@ class RoPE(nn.Module): - def __init__( - self, - head_dim: int, - max_position_embeddings: int = 2048, - base: int = 10000, - ) -> RoPE: + def __init__(self, head_dim: int, max_position_embeddings: int = 2048, base: int = 10000) -> RoPE: super().__init__() self.head_dim = head_dim diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index f15066f56..acd03fdf7 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -19,12 +19,7 @@ SEQUENCE_MIXER_TYPE = Attention | CausalConvolution | GRU | Mamba2 | MultiHeadLatentAttention | RNN -def get_sequence_mixer( - config: CommonConfig, - causal: bool, - use_padding_free_transformer: bool, - layer_idx: int, -) -> SEQUENCE_MIXER_TYPE: +def get_sequence_mixer(config: CommonConfig, causal: bool, 
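Since hidden states can now arrive either as `(batch, seqlen, hidden)` or packed as `(total_tokens, hidden)`, the MoE block records `original_shape`, flattens to 2-D for routing, and restores whatever layout it was given instead of branching on a flag. A minimal sketch of that pattern (the expert computation is a trivial stand-in):

```python
import torch

def reshape_agnostic_forward(hidden_states: torch.Tensor, hidden_size: int) -> torch.Tensor:
    original_shape = hidden_states.size()
    hidden_states = hidden_states.view(-1, hidden_size)  # (tokens, hidden) for either layout
    hidden_states = hidden_states * 2                     # placeholder for routing + expert MLPs
    return hidden_states.reshape(*original_shape)

assert reshape_agnostic_forward(torch.randn(2, 3, 8), 8).shape == (2, 3, 8)  # padded
assert reshape_agnostic_forward(torch.randn(5, 8), 8).shape == (5, 8)        # packed
```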
layer_idx: int) -> SEQUENCE_MIXER_TYPE: block = config.sequence_mixer_blocks[layer_idx] sequence_mixer_type = block.sequence_mixer_type @@ -42,7 +37,6 @@ def get_sequence_mixer( init_method=config.init_method, num_layers=config.num_layers, layer_idx=layer_idx, - use_padding_free_transformer=use_padding_free_transformer, ) elif sequence_mixer_type in ["rnn", "gru"]: return (GRU if sequence_mixer_type == "gru" else RNN)( @@ -59,7 +53,6 @@ def get_sequence_mixer( scaling_factor=block.scaling_factor, num_layers=config.num_layers, layer_idx=layer_idx, - use_padding_free_transformer=use_padding_free_transformer, ) elif sequence_mixer_type == "mamba2": return Mamba2( @@ -101,7 +94,6 @@ def get_sequence_mixer( num_layers=config.num_layers, causal=True, layer_idx=layer_idx, - use_padding_free_transformer=use_padding_free_transformer, normalization_function=block.normalization_function, layer_norm_epsilon=config.layer_norm_epsilon, ) @@ -124,11 +116,6 @@ def get_sequence_mixer( ) if sequence_mixer_type == "softmax_attention": - return Attention( - **sequence_mixer_kwargs, - qkv_bias=block.qkv_bias, - softmax_dropout=block.softmax_dropout, - use_padding_free_transformer=use_padding_free_transformer, - ) + return Attention(**sequence_mixer_kwargs, qkv_bias=block.qkv_bias, softmax_dropout=block.softmax_dropout) else: raise ValueError(f"unexpected sequence_mixer_type ({sequence_mixer_type})") diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 32ec1868c..85a9d6a6e 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -14,6 +14,7 @@ from ....kernels import is_kernel_allowed, wait_for_ACT from ....utils import Accelerator, divide_if_divisible, is_torch_xla_available from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from ...parameter import mark_parameter_as_mup_learning_rate from ..linear import ParameterizedLinear from ..position_embedding import apply_rotary_pos_emb @@ -84,7 +85,6 @@ def __init__( num_layers: int, causal: bool, layer_idx: int, - use_padding_free_transformer: bool, ) -> Attention: super().__init__() @@ -94,7 +94,6 @@ def __init__( self.num_key_value_heads = num_key_value_heads self.add_bias = add_bias self.qkv_bias = qkv_bias - self.use_padding_free_transformer = use_padding_free_transformer self.sliding_window = sliding_window self.head_dim = divide_if_divisible( @@ -138,97 +137,61 @@ def __init__( def forward( self, - hidden_states: torch.Tensor, + x: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: - use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) - use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) - accelerator = Accelerator.get_accelerator() + T = x.size(0) - if self.use_padding_free_transformer: - assert use_flash_attention_2 or use_flash_attention_3 - assert past_key_values is None + x = self.c_attn(x) + x = x.view(T, self.num_key_value_heads, -1) - total_q = hidden_states.shape[0] - input_shape = (total_q, self.num_key_value_heads, -1) - output_shape = (total_q, -1, self.head_dim) - else: - batch_size, query_length = hidden_states.shape[:-1] - - input_shape = (batch_size, 
query_length, self.num_key_value_heads, -1) - output_shape = (batch_size, query_length, -1, self.head_dim) - - hidden_states = self.c_attn(hidden_states) - hidden_states = hidden_states.view(*input_shape) - - query, key, value = hidden_states.split( + q, k, v = x.split( ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 ) - query = query.reshape(*output_shape) - - if not self.use_padding_free_transformer: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) + q = q.reshape(T, -1, self.head_dim) if self.position_embedding_type == "rope": - query = apply_rotary_pos_emb(query, rope_cos_sin) - key = apply_rotary_pos_emb(key, rope_cos_sin) + q = apply_rotary_pos_emb(q, rope_cos_sin) + k = apply_rotary_pos_emb(k, rope_cos_sin) if past_key_values is not None: - key, value = past_key_values.update(key_states=key, value_states=value, layer_idx=self.layer_idx) + k, v = past_key_values.update(key_states=k, value_states=v, layer_idx=self.layer_idx) - if use_flash_attention_2 or use_flash_attention_3: - assert accelerator == Accelerator.cuda + if is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3): + q, k, v = [wait_for_ACT(i, wait_in_forward=True, wait_in_backward=False) for i in (q, k, v)] - if self.use_padding_free_transformer: - output_shape = (-1, self.hidden_size) - else: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - output_shape = (batch_size, query_length, -1) - - query = wait_for_ACT(query, wait_in_forward=True, wait_in_backward=False) - key = wait_for_ACT(key, wait_in_forward=True, wait_in_backward=False) - value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False) - - hidden_states = flash_attention( - query=query, - key=key, - value=value, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - attention_mask=attention_mask, - use_padding_free_transformer=self.use_padding_free_transformer, + x = flash_attention( + q=q, + k=k, + v=v, + attention_mask_info=attention_mask_info, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, sliding_window=self.sliding_window, ) - del query, key, value - - hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) - hidden_states = hidden_states.view(*output_shape) + x = wait_for_ACT(x, wait_in_forward=False, wait_in_backward=True) else: assert self.sliding_window is None - if accelerator == Accelerator.tpu: + q, k, v = attention_mask_info.unpack_sequence([q, k, v]) + q, k, v = [i.transpose(1, 2) for i in (q, k, v)] + + attention_mask = attention_mask_info.get_causal_mask(query_length=q.size(-2), dtype=q.dtype) + + if Accelerator.get_accelerator() == Accelerator.tpu: assert attention_mask is None assert self.softmax_dropout_p == 0 - hidden_states = flash_attention_tpu( - query, - key, - value, - causal=self.causal if attention_mask is None else False, + x = flash_attention_tpu( + q, + k, + v, + causal=self.causal, sm_scale=( 1 / math.sqrt(self.head_dim) if self.attention_multiplier is None @@ -236,10 +199,10 @@ def forward( ), ) else: - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, + x = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, attn_mask=attention_mask, dropout_p=self.softmax_dropout_p if self.training else 0, is_causal=self.causal if attention_mask is None else False, @@ -247,13 +210,11 @@ def forward( enable_gqa=True, ) - del 
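In the packed layout the fused QKV projection is viewed as `(total_tokens, num_key_value_heads, -1)` and split so each KV head carries its group of query heads plus one key and one value head. A quick shape check with illustrative sizes (32 query heads, 8 KV heads, head_dim 128; these numbers are not taken from the diff):

```python
import torch

T, num_heads, num_key_value_heads, head_dim = 5, 32, 8, 128
qkv_dim = num_heads * head_dim + 2 * num_key_value_heads * head_dim  # 6144

x = torch.randn(T, qkv_dim).view(T, num_key_value_heads, -1)          # (5, 8, 768)
q, k, v = x.split(((num_heads // num_key_value_heads) * head_dim, head_dim, head_dim), dim=-1)
q = q.reshape(T, -1, head_dim)

print(q.shape, k.shape, v.shape)  # (5, 32, 128) (5, 8, 128) (5, 8, 128)
```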
query, key, value - - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.transpose(1, 2) - hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) + x = x.transpose(1, 2) + x = attention_mask_info.pack_sequence(x) - hidden_states = self.c_proj(hidden_states) - hidden_states = self.dropout(hidden_states) + x = x.flatten(-2, -1) + x = self.c_proj(x) + x = self.dropout(x) - return hidden_states + return x diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py index 6b8de7b21..99dd87cf7 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py @@ -147,13 +147,9 @@ def __init__( init_method: str, num_layers: int, layer_idx: int, - use_padding_free_transformer: bool, ) -> CausalConvolution: super().__init__() - if use_padding_free_transformer: - raise NotImplementedError() - self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index 14020ac3f..7c9418979 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -12,10 +12,10 @@ from ....utils import divide_if_divisible, is_xma_available from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay from ..linear import ParameterizedLinear from ..normalization import get_normalization_function -from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence if is_xma_available(): @@ -38,7 +38,6 @@ def __init__( scaling_factor: float | None, num_layers: int, layer_idx: int, - use_padding_free_transformer: bool, ) -> GRU: super().__init__() @@ -48,7 +47,6 @@ def __init__( self.num_heads = num_heads self.gradient_clipping = gradient_clipping self.layer_idx = layer_idx - self.use_padding_free_transformer = use_padding_free_transformer self.state_head_dim = divide_if_divisible(self.state_size, self.num_heads, "") std = initializer_range @@ -56,13 +54,7 @@ def __init__( std /= math.sqrt(m_width) self.state_weight_std = std - self.input_projection = ParameterizedLinear( - self.input_size, - 4 * self.state_size, - bias=add_bias, - std=std, - ) - + self.input_projection = ParameterizedLinear(self.input_size, 4 * self.state_size, bias=add_bias, std=std) self.state_weight = nn.Parameter(torch.empty(3 * self.num_heads, self.state_head_dim, self.state_head_dim)) std = initializer_range / math.sqrt(2 * num_layers) @@ -82,73 +74,53 @@ def __init__( mark_parameter_as_no_weight_decay(self.state_weight) def forward( - self, - input: torch.Tensor, - cache_params: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, + self, x: torch.Tensor, attention_mask_info: AttentionMaskInfo, cache_params: GenerationCache | None = None ) -> torch.Tensor: - if self.use_padding_free_transformer: - assert cache_params is None - assert attention_mask is None - else: - assert cu_seqlens is None - assert max_seqlen is None + x = self.input_projection(x) + x, g = x.split((3 * self.state_size, 
self.state_size), dim=-1) - - batch_size, sequence_length = input.size()[:2] - - if attention_mask is not None: - cu_seqlens, max_seqlen = compute_cu_seqlens_and_max_seqlen_from_attention_mask(attention_mask) - input = pack_sequence(inputs=input, cu_seqlens=cu_seqlens) - - input_state = None if cache_params is None else cache_params.get_cache(self.layer_idx) + if self.scaling_factor != 1: + x = x * self.scaling_factor - input = self.input_projection(input) - input, gate = input.split((3 * self.state_size, self.state_size), dim=-1) + x, x_forget, x_reset = x.chunk(3, dim=-1) + x, x_forget, x_reset = [i.view(x.size(0), self.num_heads, self.state_head_dim) for i in (x, x_forget, x_reset)] weight = self.state_weight - if self.scaling_factor != 1: - input = input * self.scaling_factor weight = weight * self.scaling_factor - input, forget_input, reset_input = input.chunk(3, dim=-1) weight, forget_weight, reset_weight = weight.chunk(3, dim=0) - input, forget_input, reset_input = [ - i.view(*input.size()[:-1], self.num_heads, self.state_head_dim) for i in (input, forget_input, reset_input) - ] + cu_seqlens = None if attention_mask_info.is_ragged() else attention_mask_info.get_cu_seqlens() + max_seqlen = None if attention_mask_info.is_ragged() else attention_mask_info.get_max_seqlen() - input = gru( - input=input, + x = gru( + input=attention_mask_info.unpack_sequence(x), weight=weight, - forget_input=forget_input, + forget_input=x_forget, forget_weight=forget_weight, - reset_input=reset_input, + reset_input=x_reset, reset_weight=reset_weight, - input_state=input_state, + input_state=None if cache_params is None else cache_params.get_cache(self.layer_idx), gradient_clipping=self.gradient_clipping, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) - if not self.use_padding_free_transformer and attention_mask is not None: - input = unpack_sequence( - inputs=input, cu_seqlens=cu_seqlens, output_shape=(batch_size, sequence_length, *input.size()[1:]) - ) - if cache_params is not None: - cache_params.update(state=input[:, -1], num_tokens_added=input.size(1), layer_idx=self.layer_idx) - - input = input.view(*input.size()[:-2], -1) - - input = input * F.silu(gate) - input = self.norm(input) - - input = self.output_projection(input) - - return input + if cu_seqlens is None: + cache_params.update(state=x[:, -1], num_tokens_added=x.size(1), layer_idx=self.layer_idx) + else: + cache_params.update( + state=x[cu_seqlens[1:] - 1], num_tokens_added=cu_seqlens[1:], layer_idx=self.layer_idx + ) + + x = x.flatten(-2, -1) + x = x * F.silu(g) + x = self.norm(x) + x = self.output_projection(x) + + return x @torch.no_grad() def reset_parameters(self) -> None: diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py index b9a9f414e..eabf4bb03 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py @@ -38,7 +38,6 @@ def __init__( num_layers: int, causal: bool, layer_idx: int, - use_padding_free_transformer: bool, normalization_function: str, layer_norm_epsilon: float = 1e-5, ) -> MultiHeadLatentAttention: @@ -49,7 +48,6 @@ def __init__( self.num_heads = num_attention_heads self.head_dim = head_dim self.add_bias = add_bias - self.use_padding_free_transformer = use_padding_free_transformer self.query_compression_size = query_compression_size
self.key_value_compression_size = key_value_compression_size self.position_embedding_type = position_embedding_type @@ -116,13 +114,6 @@ def forward( cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, ) -> torch.Tensor: - use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) - use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) - - if self.use_padding_free_transformer: - assert use_flash_attention_2 or use_flash_attention_3 - assert past_key_values is None - query = self.query_down_projection(hidden_states) query = self.query_ln(query) @@ -146,47 +137,32 @@ def forward( key = self.key_up_projection(key) value = self.value_up_projection(value) - if use_flash_attention_2 or use_flash_attention_3: - if self.use_padding_free_transformer: - total_q = query.shape[0] - - query = query.view(total_q, self.num_heads, -1) - key = key.view(total_q, self.num_heads, -1) - value = value.view(total_q, self.num_heads, -1) - - output_shape = (-1, self.hidden_size) - else: - batch_size, query_length = query.shape[:-1] - key_length = key.shape[1] + if is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3): + T = query.size(0) - query = query.view(batch_size, query_length, self.num_heads, -1) - key = key.view(batch_size, key_length, self.num_heads, -1) - value = value.view(batch_size, key_length, self.num_heads, -1) - - output_shape = (batch_size, query_length, -1) + query = query.view(T, self.num_heads, -1) + key = key.view(T, self.num_heads, -1) + value = value.view(T, self.num_heads, -1) query = wait_for_ACT(query, wait_in_forward=True, wait_in_backward=False) key = wait_for_ACT(key, wait_in_forward=True, wait_in_backward=False) value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False) hidden_states = flash_attention( - query=query, - key=key, - value=value, + q=query, + k=key, + v=value, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask, - use_padding_free_transformer=self.use_padding_free_transformer, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, sliding_window=self.sliding_window, ) - del query, key, value - hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(-1, self.hidden_size) else: assert self.sliding_window is None @@ -208,8 +184,6 @@ def forward( enable_gqa=True, ) - del query, key, value - batch_size = hidden_states.shape[0] hidden_states = hidden_states.transpose(1, 2) hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 114d0b187..b3423a4b3 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -12,10 +12,10 @@ from ....utils import divide_if_divisible, is_xma_available from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay from ..linear import ParameterizedLinear from ..normalization import get_normalization_function -from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence if is_xma_available(): @@ -38,7 +38,6 @@ def __init__( 
scaling_factor: float | None, num_layers: int, layer_idx: int, - use_padding_free_transformer: bool, ) -> RNN: super().__init__() @@ -48,7 +47,6 @@ def __init__( self.num_heads = num_heads self.gradient_clipping = gradient_clipping self.layer_idx = layer_idx - self.use_padding_free_transformer = use_padding_free_transformer self.state_head_dim = divide_if_divisible(self.state_size, self.num_heads, "") std = initializer_range @@ -76,64 +74,46 @@ def __init__( mark_parameter_as_no_weight_decay(self.state_weight) def forward( - self, - input: torch.Tensor, - cache_params: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, + self, x: torch.Tensor, attention_mask_info: AttentionMaskInfo, cache_params: GenerationCache | None = None ) -> torch.Tensor: - if self.use_padding_free_transformer: - assert cache_params is None - assert attention_mask is None - else: - assert cu_seqlens is None - assert max_seqlen is None - - batch_size, sequence_length = input.size()[:2] - - if attention_mask is not None: - cu_seqlens, max_seqlen = compute_cu_seqlens_and_max_seqlen_from_attention_mask(attention_mask) - input = pack_sequence(inputs=input, cu_seqlens=cu_seqlens) + T = x.size(0) - input_state = None if cache_params is None else cache_params.get_cache(self.layer_idx) + x = self.input_projection(x) + x, g = x.chunk(2, dim=-1) + x = x.view(T, self.num_heads, self.state_head_dim) - input = self.input_projection(input) - input, gate = input.chunk(2, dim=-1) - - input = input.view(*input.size()[:-1], self.num_heads, self.state_head_dim) + if self.scaling_factor != 1: + x = x * self.scaling_factor weight = self.state_weight - if self.scaling_factor != 1: - input = input * self.scaling_factor weight = weight * self.scaling_factor - input = rnn( - input=input, + has_padding = attention_mask_info.has_padding() + + x, s = rnn( + input=x if has_padding else attention_mask_info.unpack_sequence(x), weight=weight, - input_state=input_state, + input_state=None if cache_params is None else cache_params.get_cache(self.layer_idx), gradient_clipping=self.gradient_clipping, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + cu_seqlens=attention_mask_info.get_cu_seqlens(), + max_seqlen=attention_mask_info.get_max_seqlen(), ) - if not self.use_padding_free_transformer and attention_mask is not None: - input = unpack_sequence( - inputs=input, cu_seqlens=cu_seqlens, output_shape=(batch_size, sequence_length, *input.size()[1:]) - ) + if not has_padding: + x = attention_mask_info.pack_sequence(x) if cache_params is not None: - cache_params.update(state=input[:, -1], num_tokens_added=input.size(1), layer_idx=self.layer_idx) - - input = input.view(*input.size()[:-2], -1) - - input = input * F.silu(gate) - input = self.norm(input) + cache_params.update( + state=s, num_tokens_added=attention_mask_info.get_batch_size(), layer_idx=self.layer_idx + ) - input = self.output_projection(input) + x = x.flatten(-2, -1) + x = x * F.silu(g) + x = self.norm(x) + x = self.output_projection(x) - return input + return x @torch.no_grad() def reset_parameters(self) -> None: diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py new file mode 100644 index 000000000..36a27a7d0 --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py @@ -0,0 +1,104 @@ +# ************************************************** +# Copyright (c) 2025, Mayank 
Mishra +# ************************************************** + +import torch + +from ....enums import Kernel +from ....kernels import is_kernel_allowed +from ....utils import is_flash_attention_2_available, is_flash_attention_3_available +from ...mask import AttentionMaskInfo + + +if is_flash_attention_2_available(): + from flash_attn import flash_attn_func as flash_attention_2 + from flash_attn import flash_attn_varlen_func as flash_attention_2_varlen + +if is_flash_attention_3_available(): + from flash_attn_interface import flash_attn_func as flash_attention_3 + from flash_attn_interface import flash_attn_varlen_func as flash_attention_3_varlen + + +def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask_info: AttentionMaskInfo, + causal: bool, + dropout: float = 0, + softmax_scale: float | None = None, + sliding_window: int | None = None, + softcap: float = 0, +) -> torch.Tensor: + use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) + + if use_flash_attention_3: + assert dropout == 0 + + window_size = (-1, -1) + if sliding_window is not None and k.size(1) > sliding_window: + window_size = (sliding_window, sliding_window) + + assert q.dim() == 3 + assert k.dim() == 3 + assert v.dim() == 3 + + if attention_mask_info.has_padding(): + assert sliding_window is None + + cu_seqlens = attention_mask_info.get_cu_seqlens(False) + max_seqlen = attention_mask_info.get_max_seqlen(False) + + if use_flash_attention_3: + x, _ = flash_attention_3_varlen( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + x = flash_attention_2_varlen( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + q, k, v = attention_mask_info.unpack_sequence([q, k, v]) + + if use_flash_attention_3: + x, _ = flash_attention_3( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + ) + else: + x = flash_attention_2( + q=q, + k=k, + v=v, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + ) + + x = attention_mask_info.pack_sequence(x) + + return x diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/__init__.py deleted file mode 100644 index 3b7c41a83..000000000 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from .flash_attention_utils import flash_attention -from .packing import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py deleted file mode 100644 index b9f43b198..000000000 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py +++ /dev/null @@ -1,177 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# 
************************************************** - -import torch - -from .....enums import Kernel -from .....kernels import is_kernel_allowed -from .....utils import is_flash_attention_2_available, is_flash_attention_3_available -from .packing import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence - - -if is_flash_attention_2_available(): - from flash_attn import flash_attn_func as flash_attention_2 - from flash_attn import flash_attn_varlen_func as flash_attention_2_varlen - -if is_flash_attention_3_available(): - from flash_attn_interface import flash_attn_func as flash_attention_3 - from flash_attn_interface import flash_attn_varlen_func as flash_attention_3_varlen - - -def unpad_input( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - cu_seqlens_k, max_seqlen_k = compute_cu_seqlens_and_max_seqlen_from_attention_mask(attention_mask) - batch_size, kv_seq_len = key.size()[:2] - - if query_length == kv_seq_len: - query, key, value = pack_sequence(inputs=(query, key, value), cu_seqlens=cu_seqlens_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_q = max_seqlen_k - else: - key, value = pack_sequence(inputs=(key, value), cu_seqlens=cu_seqlens_k) - - if query_length == 1: - # There is a memcpy here, that is very bad. - cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query.device) - query = query.squeeze(1) - key, value = pack_sequence(inputs=(key, value), cu_seqlens=cu_seqlens_k) - max_seqlen_q = 1 - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - cu_seqlens_q, max_seqlen_q = compute_cu_seqlens_and_max_seqlen_from_attention_mask(attention_mask) - query = pack_sequence(inputs=query, cu_seqlens=cu_seqlens_q) - - return query, key, value, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k - - -def flash_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - cu_seqlens: torch.Tensor | None, - max_seqlen: int | None, - use_padding_free_transformer: bool, - causal: bool, - dropout: float = 0, - softmax_scale: float | None = None, - sliding_window: int | None = None, - softcap: float = 0, -) -> torch.Tensor: - use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) - use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) - - if use_flash_attention_3: - assert dropout == 0 - - assert use_flash_attention_3 or use_flash_attention_2, "enable flash_attention_2 or flash_attention_3" - - if use_padding_free_transformer: - assert use_flash_attention_3 or use_flash_attention_2 - - window_size = (-1, -1) - if sliding_window is not None and key.size(1) > sliding_window: - window_size = (sliding_window, sliding_window) - - if use_padding_free_transformer: - assert sliding_window is None - - if use_flash_attention_3: - attn_output, _ = flash_attention_3_varlen( - q=query, - k=key, - v=value, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attention_2_varlen( - q=query, - k=key, - v=value, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - if attention_mask is None: - 
if use_flash_attention_3: - attn_output, _ = flash_attention_3( - q=query, - k=key, - v=value, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - ) - else: - attn_output = flash_attention_2( - q=query, - k=key, - v=value, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - ) - else: - batch_size, query_length, num_heads, head_dim = query.size() - - query, key, value, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = unpad_input( - query, key, value, attention_mask, query_length - ) - - if use_flash_attention_3: - attn_output, _ = flash_attention_3_varlen( - q=query, - k=key, - v=value, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - ) - else: - attn_output = flash_attention_2_varlen( - q=query, - k=key, - v=value, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - ) - - attn_output = unpack_sequence( - inputs=attn_output, - cu_seqlens=cu_seqlens_q, - output_shape=(batch_size, query_length, num_heads, head_dim), - ) - - return attn_output diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/packing.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/packing.py deleted file mode 100644 index 1b93e03bb..000000000 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/packing.py +++ /dev/null @@ -1,54 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -import torch -import torch.nn.functional as F - -from .....utils import is_xma_available - - -if is_xma_available(): - from xma import pack_sequence as _pack_sequence - from xma import unpack_sequence as _unpack_sequence - - -def compute_cu_seqlens_and_max_seqlen_from_attention_mask( - attention_mask: torch.Tensor, -) -> tuple[torch.Tensor, int]: - seqlens = attention_mask.sum(dim=-1, dtype=torch.int32) - cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) - max_seqlen = seqlens.max().item() - return cu_seqlens, max_seqlen - - -def pack_sequence( - inputs: torch.Tensor | list[torch.Tensor], cu_seqlens: torch.Tensor -) -> torch.Tensor | list[torch.Tensor]: - is_tensor = isinstance(inputs, torch.Tensor) - if is_tensor: - inputs = [inputs] - - inputs = _pack_sequence(inputs=inputs, cu_seqlens=cu_seqlens, total_tokens=cu_seqlens[-1].item()) - - if is_tensor: - inputs = inputs[0] - - return inputs - - -def unpack_sequence( - inputs: torch.Tensor | list[torch.Tensor], cu_seqlens: torch.Tensor, output_shape: tuple[int] -) -> torch.Tensor | list[torch.Tensor]: - is_tensor = isinstance(inputs, torch.Tensor) - if is_tensor: - inputs = [inputs] - - inputs = _unpack_sequence( - inputs=inputs, cu_seqlens=cu_seqlens, batch_size=output_shape[0], sequence_length=output_shape[1] - ) - - if is_tensor: - inputs = inputs[0] - - return inputs diff --git a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py index 85d030be6..be482c4a7 100644 --- a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py +++ 
b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py @@ -14,6 +14,7 @@ from ....kernels import is_kernel_allowed, wait_for_ACT from ....utils import ProcessGroupManager, divide_if_divisible from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from ...modeling_utils import Attention, apply_rotary_pos_emb, flash_attention from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear from ..dropout import Dropout_TP @@ -127,11 +128,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) @@ -189,23 +188,20 @@ def forward( value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False) hidden_states = flash_attention( - query=query, - key=key, - value=value, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - attention_mask=attention_mask, - use_padding_free_transformer=self.use_padding_free_transformer, + q=query, + k=key, + v=value, + attention_mask_info=attention_mask_info, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, ) - del query, key, value - hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) hidden_states = hidden_states.view(*output_shape) else: + attention_mask = attention_mask_info.get_attention_mask() + hidden_states = F.scaled_dot_product_attention( query, key, @@ -217,8 +213,6 @@ def forward( enable_gqa=True, ) - del query, key, value - batch_size = hidden_states.shape[0] hidden_states = hidden_states.transpose(1, 2) hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) diff --git a/lm_engine/hf_models/models/gpt_crosslayer/base.py b/lm_engine/hf_models/models/gpt_crosslayer/base.py index ba5e827f6..c6cdc919b 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/base.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/base.py @@ -7,6 +7,7 @@ import torch from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from ...mixins import BaseModelMixin, BaseModelOutputWithPast, PreTrainedModelMixin from .config import GPTCrossLayerConfig from .layer import GPTCrossLayerBlock @@ -25,49 +26,29 @@ def __init__(self, config: GPTCrossLayerConfig, *args, **kwargs) -> GPTCrossLaye class GPTCrossLayerModel(GPTCrossLayerPreTrainedModel, BaseModelMixin): def forward( self, - input_ids: torch.Tensor | None = None, + input_ids: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, - use_cache: bool | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> BaseModelOutputWithPast: - ( - use_cache, - hidden_states, - attention_mask, - position_ids, - rope_cos_sin, - past_key_values, - ) = self._prepare_a_bunch_of_stuff( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + hidden_states = self._get_initial_hidden_state(input_ids, position_ids) + rope_cos_sin = 
self._get_rope_cos_sin( + attention_mask_info.get_max_seqlen(), position_ids, dtype=hidden_states.dtype ) - past_key_values = GenerationCache(self.config) if use_cache and past_key_values is None else past_key_values - key = None value = None for block in self.h: hidden_states, key, value = block( - hidden_states, + hidden_states=hidden_states, key=key, value=value, + attention_mask_info=attention_mask_info, past_key_values=past_key_values, - attention_mask=attention_mask, rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) - del key, value hidden_states = self.ln_f(hidden_states) return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values) diff --git a/lm_engine/hf_models/models/gpt_crosslayer/layer.py b/lm_engine/hf_models/models/gpt_crosslayer/layer.py index 3413911ff..f2fb79856 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/layer.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/layer.py @@ -7,19 +7,16 @@ import torch import torch.nn as nn -from ....enums import Kernel -from ....kernels import is_kernel_allowed from ....utils import divide_if_divisible from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from ...modeling_utils import apply_rotary_pos_emb, get_mlp_block, get_normalization_function from .config import GPTCrossLayerConfig -from .sequence_mixers import KeyValueProjection, get_sequence_mixer +from .sequence_mixer import KeyValueProjection, get_sequence_mixer class GPTCrossLayerBlock(nn.Module): - def __init__( - self, config: GPTCrossLayerConfig, use_padding_free_transformer: bool, layer_idx: int - ) -> GPTCrossLayerBlock: + def __init__(self, config: GPTCrossLayerConfig, layer_idx: int) -> GPTCrossLayerBlock: super().__init__() hidden_size = config.hidden_size @@ -30,8 +27,6 @@ def __init__( self.head_dim = divide_if_divisible(hidden_size, self.num_heads, "") self.num_key_value_heads = config.sequence_mixer_blocks[layer_idx].num_key_value_heads - self.use_padding_free_transformer = use_padding_free_transformer - self.kv_proj = None if config.sharing_pattern[layer_idx] == layer_idx: self.kv_proj = KeyValueProjection( @@ -42,30 +37,25 @@ def __init__( initializer_range=config.initializer_range, normalization_function=config.normalization_function, layer_norm_epsilon=config.layer_norm_epsilon, - use_padding_free_transformer=use_padding_free_transformer, ) self.ln_1 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) - self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) + self.sequence_mixer = get_sequence_mixer(config, True, layer_idx) self.ln_2 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) - self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx - ) + self.mlp_block = get_mlp_block(config, layer_idx=layer_idx) def forward( self, hidden_states: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: if self.kv_proj is not None: key, value = self.kv_proj(hidden_states) @@ -76,11 +66,6 @@ def forward( if past_key_values is not None: key, value = past_key_values.update(key_states=key, 
value_states=value, layer_idx=self.layer_idx) - if is_kernel_allowed(Kernel.flash_attention_3) or is_kernel_allowed(Kernel.flash_attention_2): - if not self.use_padding_free_transformer: - key = key.transpose(1, 2) - value = value.transpose(1, 2) - residual = hidden_states hidden_states = self.ln_1(hidden_states) @@ -88,10 +73,8 @@ def forward( hidden_states, key=key, value=value, - attention_mask=attention_mask, + attention_mask_info=attention_mask_info, rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) if self.m_residual is not None: diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py similarity index 54% rename from lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py rename to lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py index 4e99b775c..d588eb4fb 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py @@ -10,10 +10,12 @@ import torch.nn as nn import torch.nn.functional as F -from .....enums import Kernel -from .....kernels import is_kernel_allowed -from .....utils import divide_if_divisible -from ....modeling_utils import ParameterizedLinear, apply_rotary_pos_emb, flash_attention, get_normalization_function +from ....enums import Kernel +from ....kernels import is_kernel_allowed +from ....utils import divide_if_divisible +from ...mask import AttentionMaskInfo +from ...modeling_utils import ParameterizedLinear, apply_rotary_pos_emb, flash_attention, get_normalization_function +from .config import GPTCrossLayerConfig class CrossLayerAttention(nn.Module): @@ -31,7 +33,6 @@ def __init__( num_layers: int, causal: bool, layer_idx: int, - use_padding_free_transformer: bool, ) -> CrossLayerAttention: super().__init__() @@ -41,7 +42,6 @@ def __init__( self.num_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.add_bias = add_bias - self.use_padding_free_transformer = use_padding_free_transformer assert ( self.hidden_size % self.num_heads == 0 @@ -81,75 +81,33 @@ def forward( hidden_states: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: torch.Tensor | None = None, + attention_mask_info: AttentionMaskInfo, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: if is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3): - if self.use_padding_free_transformer: - total_q = hidden_states.shape[0] - - query = self.q_attn(hidden_states) - query = query.view(total_q, self.num_heads, -1) - - if self.position_embedding_type == "rope": - query = apply_rotary_pos_emb(query, rope_cos_sin) - - hidden_states = flash_attention( - query=query, - key=key, - value=value, - attention_mask=attention_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - use_padding_free_transformer=self.use_padding_free_transformer, - causal=self.causal, - dropout=self.softmax_dropout_p if self.training else 0, - softmax_scale=self.attention_multiplier, - ) - - del query, key, value - - hidden_states = hidden_states.view(-1, self.hidden_size) - else: - batch_size, query_length = hidden_states.shape[:2] - - query = self.q_attn(hidden_states) - query = query.view(batch_size, query_length, self.num_heads, -1) - - if self.position_embedding_type == "rope": - # TODO avoid this extra transpose - query = query.transpose(1, 2) - query = 
apply_rotary_pos_emb(query, rope_cos_sin) - query = query.transpose(1, 2) - - hidden_states = flash_attention( - query=query, - key=key, - value=value, - attention_mask=attention_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - use_padding_free_transformer=self.use_padding_free_transformer, - causal=self.causal, - dropout=self.softmax_dropout_p if self.training else 0, - softmax_scale=self.attention_multiplier, - ) - - del query, key, value - - hidden_states = hidden_states.view(batch_size, query_length, -1) - else: - batch_size, query_length = hidden_states.shape[:2] - query = self.q_attn(hidden_states) - query = query.view(batch_size, query_length, self.num_heads, -1) + query = query.view(*hidden_states.size()[:-1], self.num_heads, -1) + + if self.position_embedding_type == "rope": + query = apply_rotary_pos_emb(query, rope_cos_sin) + + hidden_states = flash_attention( + q=query, + k=key, + v=value, + attention_mask_info=attention_mask_info, + causal=self.causal, + dropout=self.softmax_dropout_p if self.training else 0, + softmax_scale=self.attention_multiplier, + ) + else: query = query.transpose(1, 2) if self.position_embedding_type == "rope": query = apply_rotary_pos_emb(query, rope_cos_sin) + attention_mask = attention_mask_info.get_attention_mask() + hidden_states = F.scaled_dot_product_attention( query, key, @@ -161,11 +119,9 @@ def forward( enable_gqa=True, ) - del query, key, value - hidden_states = hidden_states.transpose(1, 2) - hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.flatten(-2, -1) hidden_states = self.c_proj(hidden_states) hidden_states = self.dropout(hidden_states) @@ -182,7 +138,6 @@ def __init__( initializer_range: float, normalization_function: str, layer_norm_epsilon: float, - use_padding_free_transformer: bool, ) -> KeyValueProjection: super().__init__() @@ -197,28 +152,31 @@ def __init__( std=initializer_range, ) - self.use_padding_free_transformer = use_padding_free_transformer - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden_states = self.ln(hidden_states) hidden_states = self.kv_attn(hidden_states) - if self.use_padding_free_transformer: - total_q = hidden_states.shape[0] - - if self.num_key_value_heads == 1: - hidden_states = hidden_states.unsqueeze(1) - else: - hidden_states = hidden_states.view(total_q, self.num_key_value_heads, -1) - else: - batch_size, query_length = hidden_states.shape[:2] - - if self.num_key_value_heads == 1: - hidden_states = hidden_states.unsqueeze(1) - else: - hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1) - hidden_states = hidden_states.transpose(1, 2) - + hidden_states = hidden_states.view(*hidden_states.size()[:-1], self.num_key_value_heads, -1) key, value = hidden_states.chunk(2, -1) return key, value + + +def get_sequence_mixer(config: GPTCrossLayerConfig, causal: bool, layer_idx: int) -> CrossLayerAttention: + block = config.sequence_mixer_blocks[layer_idx] + assert block.sequence_mixer_type == "softmax_attention" + + return CrossLayerAttention( + hidden_size=config.hidden_size, + num_attention_heads=block.num_attention_heads, + num_key_value_heads=block.num_key_value_heads, + attention_multiplier=block.attention_multiplier, + position_embedding_type=config.position_embedding_type, + add_bias=block.add_bias, + softmax_dropout=block.softmax_dropout, + dropout=block.dropout, + initializer_range=config.initializer_range, + num_layers=config.num_layers, + 
causal=causal, + layer_idx=layer_idx, + ) diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/__init__.py b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/__init__.py deleted file mode 100644 index 1fafdc67a..000000000 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from ..config import GPTCrossLayerConfig -from .base import CrossLayerAttention, KeyValueProjection - - -def get_sequence_mixer( - config: GPTCrossLayerConfig, causal: bool, use_padding_free_transformer: bool, layer_idx: int -) -> CrossLayerAttention: - block = config.sequence_mixer_blocks[layer_idx] - assert block.sequence_mixer_type == "softmax_attention" - - return CrossLayerAttention( - hidden_size=config.hidden_size, - num_attention_heads=block.num_attention_heads, - num_key_value_heads=block.num_key_value_heads, - attention_multiplier=block.attention_multiplier, - position_embedding_type=config.position_embedding_type, - add_bias=block.add_bias, - softmax_dropout=block.softmax_dropout, - dropout=block.dropout, - initializer_range=config.initializer_range, - num_layers=config.num_layers, - causal=causal, - layer_idx=layer_idx, - use_padding_free_transformer=use_padding_free_transformer, - ) diff --git a/lm_engine/hf_models/models/ladder_residual/base.py b/lm_engine/hf_models/models/ladder_residual/base.py index e952b064b..aad260331 100644 --- a/lm_engine/hf_models/models/ladder_residual/base.py +++ b/lm_engine/hf_models/models/ladder_residual/base.py @@ -4,9 +4,8 @@ import torch -from ...cache import GenerationCache +from ...cache import GenerationCache, is_generation_cache_enabled from ...mixins import BaseModelMixin, BaseModelOutputWithPast, PreTrainedModelMixin -from ...utils import is_generation_cache_enabled from .config import LadderResidualConfig from .layer import LadderResidualBlock diff --git a/lm_engine/hf_models/models/palm/layer.py b/lm_engine/hf_models/models/palm/layer.py index 35dd9541a..0623086f4 100644 --- a/lm_engine/hf_models/models/palm/layer.py +++ b/lm_engine/hf_models/models/palm/layer.py @@ -13,9 +13,7 @@ class PaLMBlock(nn.Module): - def __init__( - self, config: PaLMConfig, use_padding_free_transformer: bool, layer_idx: int | None = None - ) -> PaLMBlock: + def __init__(self, config: PaLMConfig, layer_idx: int | None = None) -> PaLMBlock: super().__init__() self.m_residual = config.m_residual @@ -23,10 +21,8 @@ def __init__( self.ln = get_normalization_function( config.normalization_function, config.hidden_size, eps=config.layer_norm_epsilon ) - self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) - self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx - ) + self.sequence_mixer = get_sequence_mixer(config, True, layer_idx) + self.mlp_block = get_mlp_block(config, layer_idx=layer_idx) def forward( self, diff --git a/lm_engine/hf_models/utils.py b/lm_engine/hf_models/utils.py deleted file mode 100644 index c16f3a17b..000000000 --- a/lm_engine/hf_models/utils.py +++ /dev/null @@ -1,72 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from typing import Any - -import torch - - -def convert_padding_free_lists_to_tensors( - input_ids: list[list[int]] | None 
= None, - position_ids: list[list[int]] | None = None, - labels: list[list[int]] | None = None, - device: torch.device = None, -) -> tuple[torch.Tensor | int]: - - # check input types are correct - error_message = "{variable} should be of type List[List[{dtype}]]" - _check_list_type(input_ids, error_message.format(variable="input_ids", dtype="int")) - _check_list_type(position_ids, error_message.format(variable="position_ids", dtype="int")) - _check_list_type(labels, error_message.format(variable="labels", dtype="int")) - - # prepare inputs for the model - seqlens = torch.tensor([0] + [len(x) for x in input_ids], device=device) - cu_seqlens = seqlens.cumsum(dim=-1).to(torch.int32) - max_seqlen = seqlens.max().item() - - if position_ids is None: - position_ids = [list(range(len(x))) for x in input_ids] - position_ids = _flatten_and_convert_to_tensors(position_ids, device) - - input_ids = _flatten_and_convert_to_tensors(input_ids, device) - - if labels is not None: - labels = _flatten_and_convert_to_tensors(labels, device) - - return input_ids, position_ids, labels, cu_seqlens, max_seqlen - - -def _check_list_type(list_of_list: list[list[int | float]] | None, error_message: str) -> None: - if list_of_list is None: - return - - assert isinstance(list_of_list, list), error_message - assert isinstance(list_of_list[0], list), error_message - - -def _flatten_and_convert_to_tensors(x: list[int], device: torch.device) -> torch.Tensor: - y = [] - for sequence in x: - y.extend(sequence) - - return torch.tensor(y, device=device) - - -_IS_GENERATION_CACHE_ENABLED: bool = True - - -class disable_generation_cache: - def __enter__(self) -> Any: - global _IS_GENERATION_CACHE_ENABLED - self.original = _IS_GENERATION_CACHE_ENABLED - - _IS_GENERATION_CACHE_ENABLED = False - - def __exit__(self, exception_type, exception_value, exception_traceback) -> Any: - global _IS_GENERATION_CACHE_ENABLED - _IS_GENERATION_CACHE_ENABLED = self.original - - -def is_generation_cache_enabled() -> bool: - return _IS_GENERATION_CACHE_ENABLED diff --git a/lm_engine/model_wrapper/__init__.py b/lm_engine/model_wrapper/__init__.py index 12615bdde..57601f803 100644 --- a/lm_engine/model_wrapper/__init__.py +++ b/lm_engine/model_wrapper/__init__.py @@ -39,7 +39,6 @@ def get_model_container( "model_class": args.model_args.model_class, "dtype": args.mixed_precision_args.dtype, "efficient_initialization": efficient_initialization, - "use_padding_free_transformer": args.model_args.use_padding_free_transformer, "sequence_parallel": args.distributed_args.sequence_parallel, "num_pipeline_stages": num_pipeline_stages, "trust_remote_code": args.model_args.trust_remote_code, diff --git a/lm_engine/model_wrapper/base.py b/lm_engine/model_wrapper/base.py index bfb21ffac..47f8da9ac 100644 --- a/lm_engine/model_wrapper/base.py +++ b/lm_engine/model_wrapper/base.py @@ -28,7 +28,6 @@ def __init__( model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM, dtype: torch.dtype, efficient_initialization: bool, - use_padding_free_transformer: bool, sequence_parallel: bool, num_pipeline_stages: int, pipeline_stage_id: int, @@ -45,7 +44,6 @@ def __init__( model_class (AutoModelForCausalLM | AutoModelForSeq2SeqLM): HF model class to use for model loading dtype (torch.dtype): dtype for the model efficient_initialization (bool): whether to use efficient initialization for the model initialization, saves CPU memory - use_padding_free_transformer (bool): whether to use padding free transformer sequence_parallel (bool): whether to use sequence parallel 
num_pipeline_stages (int): number of stages for the pipeline pipeline_stage_id (int): current pipeline stage id @@ -62,7 +60,6 @@ def __init__( self.model_class = model_class self.efficient_initialization = efficient_initialization self.dtype = dtype - self.use_padding_free_transformer = use_padding_free_transformer self.sequence_parallel = sequence_parallel self.tokenizer_name = self.model_name if tokenizer_name is None else tokenizer_name self.trust_remote_code = trust_remote_code @@ -88,9 +85,6 @@ def __init__( self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() self.model_class = get_model_parallel_class(self.config.model_type) - if self.use_padding_free_transformer: - assert self.is_custom_model, "padding free transformer is not supported with the specified model" - self._setup_tokenizer() self._setup_model() @@ -148,8 +142,6 @@ def _get_model_kwargs(self) -> dict: "flash_attention_2" if is_kernel_allowed(Kernel.flash_attention_2) else "sdpa" ) - if self.use_padding_free_transformer: - model_kwargs["use_padding_free_transformer"] = True if self.sequence_parallel: model_kwargs["sequence_parallel"] = True if self.trust_remote_code: diff --git a/lm_engine/model_wrapper/distillation.py b/lm_engine/model_wrapper/distillation.py index 8394b1477..14e52d5a4 100644 --- a/lm_engine/model_wrapper/distillation.py +++ b/lm_engine/model_wrapper/distillation.py @@ -31,7 +31,6 @@ def __init__( model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM, dtype: torch.dtype, efficient_initialization: bool, - use_padding_free_transformer: bool, sequence_parallel: bool, micro_batch_size: int, sequence_length: int, @@ -57,7 +56,6 @@ def __init__( model_class (AutoModelForCausalLM | AutoModelForSeq2SeqLM): HF model class to use for model loading dtype (torch.dtype): dtype for the model efficient_initialization (bool): whether to use efficient initialization for the model initialization, saves CPU memory - use_padding_free_transformer (bool): whether to use padding free transformer sequence_parallel (bool): whether to use sequence parallel micro_batch_size (int): micro batch size for pretraining sequence_length (int): sequence length for pretraining @@ -83,7 +81,6 @@ def __init__( model_class=model_class, dtype=dtype, efficient_initialization=efficient_initialization, - use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, micro_batch_size=micro_batch_size, sequence_length=sequence_length, @@ -129,6 +126,8 @@ def forward( student_logits = output.logits del output + assert False + # TODO modify this when TP support is added lm_loss = get_autoregressive_language_modeling_loss( lm_logits=student_logits, @@ -136,7 +135,7 @@ def forward( hidden_states=None, vocab_weight=None, cu_seqlens=None, - use_padding_free_transformer=self.use_padding_free_transformer, + use_padding_free_transformer=True, reduction="sum", shift_logits_and_labels=False, tensor_parallel_enabled=False, diff --git a/lm_engine/model_wrapper/finetuning.py b/lm_engine/model_wrapper/finetuning.py index f05581e44..59021a8a6 100644 --- a/lm_engine/model_wrapper/finetuning.py +++ b/lm_engine/model_wrapper/finetuning.py @@ -53,13 +53,15 @@ def get_loss( tensor_parallel_enabled = ProcessGroupManager.is_tensor_parallel_enabled() use_fused_linear_cross_entropy_kernel = is_kernel_allowed(Kernel.fused_linear_cross_entropy) + assert False + lm_loss = get_autoregressive_language_modeling_loss( lm_logits=None if use_fused_linear_cross_entropy_kernel else model_outputs.logits, labels=labels, 
hidden_states=model_outputs.last_hidden_state if use_fused_linear_cross_entropy_kernel else None, vocab_weight=self.model.get_output_embeddings().weight if use_fused_linear_cross_entropy_kernel else None, cu_seqlens=cu_seqlens, - use_padding_free_transformer=self.use_padding_free_transformer, + use_padding_free_transformer=True, reduction="sum", shift_logits_and_labels=True, tensor_parallel_enabled=tensor_parallel_enabled, @@ -87,7 +89,7 @@ def _broadcast_inputs_for_tensor_parallel(self, batch: dict) -> dict: tp_source_rank = ProcessGroupManager.get_tensor_parallel_first_rank() tp_group = ProcessGroupManager.get_tensor_parallel_group() - if self.use_padding_free_transformer: + if self.is_custom_model: keys = ["input_ids", "position_ids", "labels", "cu_seqlens", "max_seqlen"] if is_tp_first_rank: diff --git a/lm_engine/model_wrapper/pretraining.py b/lm_engine/model_wrapper/pretraining.py index b849432a3..47fda399f 100644 --- a/lm_engine/model_wrapper/pretraining.py +++ b/lm_engine/model_wrapper/pretraining.py @@ -11,6 +11,7 @@ from ..dtensors import tensor_to_dtensor from ..enums import Kernel from ..hf_models import ( + AttentionMaskInfo, CausalLMOutputWithPast, PipelineParallelInput, PipelineParallelOutput, @@ -31,7 +32,6 @@ def __init__( model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM, dtype: torch.dtype, efficient_initialization: bool, - use_padding_free_transformer: bool, sequence_parallel: bool, micro_batch_size: int, sequence_length: int, @@ -52,7 +52,6 @@ def __init__( model_class (AutoModelForCausalLM | AutoModelForSeq2SeqLM): HF model class to use for model loading dtype (torch.dtype): dtype for the model efficient_initialization (bool): whether to use efficient initialization for the model initialization, saves CPU memory - use_padding_free_transformer (bool): whether to use padding free transformer sequence_parallel (bool): whether to use sequence parallel micro_batch_size (int): micro batch size for pretraining sequence_length (int): sequence length for pretraining @@ -77,7 +76,6 @@ def __init__( model_class=model_class, dtype=dtype, efficient_initialization=efficient_initialization, - use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, num_pipeline_stages=num_pipeline_stages, pipeline_stage_id=pipeline_stage_id, @@ -126,6 +124,8 @@ def forward( batch = self._prepare_model_inputs(batch) labels = batch.pop("labels") + attention_mask_info = batch["attention_mask_info"] + output: CausalLMOutputWithPast | PipelineParallelOutput = self.model(**batch, return_dict=True) if self.is_pipeline_parallel_enabled: @@ -146,12 +146,18 @@ def forward( if use_aux_loss: output = (output, aux_loss) else: - output = self.get_loss(output, labels, lm_loss_multiplier=lm_loss_multiplier) + output = self.get_loss( + output, labels, attention_mask_info=attention_mask_info, lm_loss_multiplier=lm_loss_multiplier + ) return output def get_loss( - self, model_outputs: CausalLMOutputWithPast, labels: torch.Tensor, lm_loss_multiplier: float = 1 + self, + model_outputs: CausalLMOutputWithPast, + labels: torch.Tensor, + attention_mask_info: AttentionMaskInfo, + lm_loss_multiplier: float = 1, ) -> torch.Tensor | dict: tensor_parallel_enabled = ProcessGroupManager.is_tensor_parallel_enabled() use_fused_linear_cross_entropy_kernel = is_kernel_allowed(Kernel.fused_linear_cross_entropy) @@ -159,10 +165,9 @@ def get_loss( lm_loss = get_autoregressive_language_modeling_loss( lm_logits=None if use_fused_linear_cross_entropy_kernel else model_outputs.logits, 
labels=labels, + attention_mask_info=attention_mask_info, hidden_states=model_outputs.last_hidden_state if use_fused_linear_cross_entropy_kernel else None, vocab_weight=self.model.get_output_embeddings().weight if use_fused_linear_cross_entropy_kernel else None, - cu_seqlens=None, - use_padding_free_transformer=self.use_padding_free_transformer, reduction="sum", shift_logits_and_labels=False, tensor_parallel_enabled=tensor_parallel_enabled, @@ -225,40 +230,39 @@ def _prepare_model_inputs(self, batch: dict) -> dict: input_ids = tokens[:, :-1] batch = {"labels": tokens[:, 1:]} - if self.use_padding_free_transformer: - batch_size, sequence_length = input_ids.shape - input_ids = input_ids.reshape(-1) - - if self.reset_attention_mask: - num_tokens_in_batch = batch_size * sequence_length - - document_end_positions = input_ids == self.eos_token_id - for i in range(sequence_length - 1, num_tokens_in_batch, sequence_length): - document_end_positions[i] = 1 - cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 - cu_seqlens = torch.cat([torch.tensor([0], device=input_ids.device), cu_seqlens]) - cu_seqlens = cu_seqlens.to(torch.int32) - - seqlen = cu_seqlens[1:] - cu_seqlens[:-1] - # we move to CPU here otherwise FlashAttention will move to CPU on every invocation i.e all layers - max_seqlen = seqlen.max().item() - - if self.reset_position_ids: - position_ids = torch.cat( - [torch.arange(0, i, 1, dtype=torch.int32, device=input_ids.device) for i in seqlen] - ) - else: - position_ids = self.position_ids + batch_size, sequence_length = input_ids.shape + + if self.reset_attention_mask: + input_ids = input_ids.flatten() + num_tokens_in_batch = batch_size * sequence_length + + document_end_positions = input_ids == self.eos_token_id + for i in range(sequence_length - 1, num_tokens_in_batch, sequence_length): + document_end_positions[i] = 1 + cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 + cu_seqlens = torch.cat([torch.tensor([0], device=input_ids.device), cu_seqlens]) + cu_seqlens = cu_seqlens.to(torch.int32) + + seqlen = cu_seqlens[1:] - cu_seqlens[:-1] + # we move to CPU here otherwise FlashAttention will move to CPU on every invocation i.e all layers + max_seqlen = seqlen.max().item() + + if self.reset_position_ids: + position_ids = torch.cat( + [torch.arange(0, i, 1, dtype=torch.int32, device=input_ids.device) for i in seqlen] + ) else: - cu_seqlens = self.cu_seqlens - max_seqlen = self.sequence_length position_ids = self.position_ids - - batch["cu_seqlens"] = cu_seqlens - batch["max_seqlen"] = max_seqlen - batch["position_ids"] = position_ids + else: + cu_seqlens = None + max_seqlen = None + position_ids = self.position_ids batch["input_ids"] = input_ids + batch["attention_mask_info"] = self._get_attention_mask_info( + x=input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + batch["position_ids"] = position_ids if ProcessGroupManager.is_tensor_parallel_enabled(): batch["output_parallel_lm_logits"] = True @@ -270,37 +274,43 @@ def _setup_model(self) -> None: self.reset_parameters() def reset_parameters(self) -> None: - if self.use_padding_free_transformer: - if not self.reset_attention_mask: - self.register_buffer( - "cu_seqlens", - torch.arange( - 0, - self.micro_batch_size * self.sequence_length + 1, - self.sequence_length, - dtype=torch.int32, - device=Accelerator.get_current_device(), - ), - persistent=False, - ) - - if self.reset_position_ids: - assert self.reset_attention_mask, "reset_attention_mask should be specified with reset_position_ids" - else: - 
self.register_buffer( - "position_ids", - torch.arange(0, self.sequence_length, 1, device=Accelerator.get_current_device()).repeat( - self.micro_batch_size - ), - persistent=False, - ) + # if not self.reset_attention_mask: + # self.register_buffer( + # "cu_seqlens", + # torch.arange( + # 0, + # self.micro_batch_size * self.sequence_length + 1, + # self.sequence_length, + # dtype=torch.int32, + # device=Accelerator.get_current_device(), + # ), + # persistent=False, + # ) + + if self.reset_position_ids: + assert self.reset_attention_mask, "reset_attention_mask should be specified with reset_position_ids" else: - assert ( - not self.reset_attention_mask - ), "currently reset_attention_mask is only implemented for padding free transformer" - assert ( - not self.reset_position_ids - ), "currently reset_position_ids is only implemented for padding free transformer" + self.register_buffer( + "position_ids", + torch.arange(0, self.sequence_length, 1, device=Accelerator.get_current_device()).repeat( + self.micro_batch_size + ), + persistent=False, + ) + + def _get_attention_mask_info( + self, x: torch.Tensor, cu_seqlens: torch.Tensor | None, max_seqlen: int | None + ) -> AttentionMaskInfo: + kwargs = {} + if cu_seqlens is None: + kwargs["batch_size"] = x.size(0) + kwargs["max_seqlen"] = x.size(1) + kwargs["device"] = x.device + else: + kwargs["cu_seqlens"] = cu_seqlens + kwargs["max_seqlen"] = max_seqlen + + return AttentionMaskInfo(**kwargs) class _F(torch.autograd.Function): diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py index fce40795e..ba7dfb450 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py @@ -22,7 +22,6 @@ parser.add_argument("--attention-implementation", type=str) parser.add_argument("--dtype", type=str) parser.add_argument("--tmp-path", type=str) -parser.add_argument("--use-padding-free-transformer", action="store_true") parser.add_argument("--sequence-parallel", action="store_true") args = parser.parse_args() @@ -81,9 +80,7 @@ # try sharding vocab matrices if really struggling for memory model_tp = get_model_parallel_class(config.model_type)._from_config( - config, - use_padding_free_transformer=args.use_padding_free_transformer, - sequence_parallel=args.sequence_parallel, + config, sequence_parallel=args.sequence_parallel ) # copy to device without copying storage @@ -109,21 +106,18 @@ 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False ) -if args.use_padding_free_transformer: - cu_seqlens = torch.arange( - 0, input_ids.numel() + 1, sequence_length, dtype=torch.int32, device=torch.cuda.current_device() - ) - position_ids = torch.arange(0, sequence_length, 1, device=torch.cuda.current_device()).repeat(batch_size) - - output_tp = model_tp( - input_ids=input_ids.view(-1), - labels=labels.view(-1), - cu_seqlens=cu_seqlens, - max_seqlen=sequence_length, - position_ids=position_ids, - ) -else: - output_tp = model_tp(input_ids=input_ids, labels=labels) +cu_seqlens = torch.arange( + 0, input_ids.numel() + 1, sequence_length, dtype=torch.int32, device=torch.cuda.current_device() +) +position_ids = torch.arange(0, sequence_length, 1, device=torch.cuda.current_device()).repeat(batch_size) + +output_tp = model_tp( + input_ids=input_ids.view(-1), + labels=labels.view(-1), + cu_seqlens=cu_seqlens, + max_seqlen=sequence_length, + 
+    position_ids=position_ids,
+)
 
 loss_tp = output_tp.loss
 logits_tp = output_tp.logits[..., : config.vocab_size]
@@ -135,9 +129,7 @@
 
 loss = output.loss
 logits = output.logits
-
-if args.use_padding_free_transformer:
-    logits_tp = logits_tp.reshape(batch_size, sequence_length, -1)
+logits_tp = logits_tp.reshape(batch_size, sequence_length, -1)
 
 error = (logits - logits_tp).abs().max()
 assert error < 5e-4, f"logits don't match for normal and tensor parallel model, error is ({error})"
diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py
index 184f94af5..acf3c899d 100644
--- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py
+++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py
@@ -20,7 +20,6 @@ class TensorParallelTest(TestCommons):
             TestCommons.get_attention_implementations(),
             TestCommons.get_dtypes(),
             [False, True],
-            [False, True],
         )
     )
     @TestCommons.slow_test
@@ -29,7 +28,6 @@ def test_tensor_parallel_forward(
         position_embedding_type: str,
         attention_implementation: str,
         dtype: torch.dtype,
-        use_padding_free_transformer: bool,
         sequence_parallel: bool,
     ) -> None:
         self.skip_test_if_device_unavailable(torch.device("cuda"))
@@ -40,7 +38,7 @@ def test_tensor_parallel_forward(
         ]:
             self.skipTest("skipping test since running all takes too long")
 
-        if use_padding_free_transformer and attention_implementation != "flash_attention_2":
+        if attention_implementation != "flash_attention_2":
             self.skipTest("skipping test since flash attention is needed for padding free transformer")
 
         gpus_per_node = torch.cuda.device_count()
@@ -62,9 +60,6 @@ def test_tensor_parallel_forward(
             tmp_path,
         ]
 
-        if use_padding_free_transformer:
-            command.append("--use-padding-free-transformer")
-
         if sequence_parallel:
             command.append("--sequence-parallel")
 
diff --git a/tests/hf_models/single_gpu/gpt_base_test.py b/tests/hf_models/single_gpu/gpt_base_test.py
index 579437443..aa5956eec 100644
--- a/tests/hf_models/single_gpu/gpt_base_test.py
+++ b/tests/hf_models/single_gpu/gpt_base_test.py
@@ -2,8 +2,6 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
-import itertools
-
 import torch
 from parameterized import parameterized
 from transformers import set_seed
@@ -20,63 +18,23 @@ class GPTBaseAttentionTest(TestCommons):
 
     @parameterized.expand(
         TestCommons.make_args_matrix(
-            [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16]
-        )
-    )
-    def test_sdpa_padding_free_transformer_equivalence(
-        self, device: torch.device, position_embedding_type: str, dtype: torch.dtype
-    ) -> None:
-        self.skip_test_if_device_unavailable(device)
-
-        set_seed(SEED)
-
-        config = self.get_dense_test_config(position_embedding_type, num_layers=1)
-
-        sdpa_model = self.from_config(config, dtype=dtype).to(device)
-        flash_model = self.from_config(config, dtype=dtype, use_padding_free_transformer=True).to(device)
-
-        sdpa_model.eval()
-        flash_model.eval()
-
-        flash_model.load_state_dict(sdpa_model.state_dict())
-
-        input_ids, attention_mask, labels = self.get_dummy_inputs(device)
-        sdpa_output = sdpa_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-        attention_mask = attention_mask.to(torch.bool)
-        sdpa_logits = sdpa_output.logits
-        sdpa_logits = torch.cat([sdpa_logits[i, ex, :] for i, ex in enumerate(attention_mask)])
-        sdpa_loss = sdpa_output.loss
-
-        with enable_kernels([Kernel.flash_attention_2]):
-            input_ids, attention_mask, labels = self.get_dummy_inputs(device, return_list=True)
-            flash_output = flash_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-            flash_logits = flash_output.logits
-            flash_loss = flash_output.loss
-
-        self.assert_equal_tensors(
-            sdpa_logits,
-            flash_logits,
-            False,
-            rtol_float16=1e-3,
-            atol_float16=3e-4,
-            rtol_bfloat16=5e-3,
-            atol_bfloat16=5e-3,
-        )
-        self.assert_equal_tensors(sdpa_loss, flash_loss, False, atol_float32=1.2e-4, rtol_float32=0)
-
-    @parameterized.expand(
-        TestCommons.make_args_matrix(
-            [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16]
+            [torch.device("cuda")],
+            TestCommons.get_position_embedding_types(),
+            [torch.float16, torch.bfloat16],
+            [False, True],
         )
     )
     def test_sdpa_flash_attention_equivalence(
-        self, device: torch.device, position_embedding_type: str, dtype: torch.dtype
+        self, device: torch.device, position_embedding_type: str, dtype: torch.dtype, has_attention_mask: bool
    ) -> None:
         self.skip_test_if_device_unavailable(device)
 
         set_seed(SEED)
 
         input_ids, attention_mask, labels = self.get_dummy_inputs(device)
+        if not has_attention_mask:
+            attention_mask = None
+
         config = self.get_dense_test_config(position_embedding_type, num_layers=1)
 
         model = self.from_config(config, dtype=dtype).to(device)
@@ -106,88 +64,6 @@ def test_sdpa_flash_attention_equivalence(
         )
         self.assert_equal_tensors(sdpa_loss, flash_loss, False, atol_float32=1.2e-4, rtol_float32=0)
 
-    @parameterized.expand(
-        TestCommons.make_args_matrix(
-            [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16]
-        )
-    )
-    def test_padding_free_transformer_with_list_and_tensor(
-        self, device: torch.device, position_embedding_type: str, dtype: torch.dtype
-    ) -> None:
-        self.skip_test_if_device_unavailable(device)
-
-        set_seed(SEED)
-
-        config = self.get_dense_test_config(position_embedding_type, num_layers=1)
-
-        model = self.from_config(config, dtype=dtype, use_padding_free_transformer=True).to(device)
-        model.eval()
-
-        with enable_kernels([Kernel.flash_attention_2]):
-            input_ids, attention_mask, labels = self.get_dummy_inputs(device, return_list=True)
-            list_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-            list_logits = list_output.logits
-            list_loss = list_output.loss
-
-            seqlens = torch.tensor([0] + [len(i) for i in input_ids])
-            cu_seqlens = seqlens.cumsum(dim=-1).to(device, torch.int32)
-            max_seqlen = seqlens.max().item()
-            position_ids = torch.tensor(
-                list(itertools.chain(*[list(range(len(i))) for i in input_ids])), device=device
-            )
-            input_ids = torch.tensor(list(itertools.chain(*input_ids)), device=device)
-            labels = torch.tensor(list(itertools.chain(*labels)), device=device)
-            tensor_output = model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                labels=labels,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen,
-            )
-            tensor_logits = tensor_output.logits
-            tensor_loss = tensor_output.loss
-
-        self.assert_equal_tensors(list_logits, tensor_logits, True)
-        self.assert_equal_tensors(list_loss, tensor_loss, True)
-
-    @parameterized.expand(
-        TestCommons.make_args_matrix(
-            [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16]
-        )
-    )
-    def test_sdpa_flash_enabled(self, device: torch.device, position_embedding_type: str, dtype: torch.dtype) -> None:
-        self.skip_test_if_device_unavailable(device)
-
-        set_seed(SEED)
-
-        config = self.get_dense_test_config(position_embedding_type, num_layers=1)
-
-        model = self.from_config(config, dtype=dtype).to(device)
-        model.eval()
-
-        input_ids, _, labels = self.get_dummy_inputs(device)
-        attention_mask = torch.ones_like(input_ids, dtype=torch.int, device=device)
-
-        sdpa_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-        sdpa_logits = sdpa_output.logits
-        sdpa_loss = sdpa_output.loss
-
-        flash_output = model(input_ids=input_ids, labels=labels)
-        flash_logits = flash_output.logits
-        flash_loss = flash_output.loss
-
-        self.assert_equal_tensors(
-            sdpa_logits,
-            flash_logits,
-            False,
-            rtol_float16=1e-3,
-            atol_float16=3e-4,
-            rtol_bfloat16=5e-3,
-            atol_bfloat16=5e-3,
-        )
-        self.assert_equal_tensors(sdpa_loss, flash_loss, False, atol_float32=3.8e-4, rtol_float32=0)
-
     @parameterized.expand(
         TestCommons.make_args_matrix(
             [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16]
diff --git a/tests/hf_models/single_gpu/multihead_latent_attention_test.py b/tests/hf_models/single_gpu/multihead_latent_attention_test.py
deleted file mode 100644
index 4accd572b..000000000
--- a/tests/hf_models/single_gpu/multihead_latent_attention_test.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# **************************************************
-# Copyright (c) 2025, Mayank Mishra
-# **************************************************
-
-import torch
-from parameterized import parameterized
-from transformers import set_seed
-
-from lm_engine.enums import Kernel
-from lm_engine.hf_models import GPTBaseConfig
-from lm_engine.kernels import enable_kernels
-
-from ..test_common import TestCommons
-
-
-SEED = 1234
-
-
-class MultiHeadLatentAttentionTest(TestCommons):
-    @parameterized.expand(TestCommons.make_args_matrix([torch.device("cuda")], [torch.float16, torch.bfloat16]))
-    def test_sdpa_padding_free_transformer_equivalence(self, device: torch.device, dtype: torch.dtype) -> None:
-        self.skip_test_if_device_unavailable(device)
-
-        set_seed(SEED)
-
-        config = self.get_dense_test_config(num_layers=1)
-
-        sdpa_model = self.from_config(config, dtype=dtype).to(device)
-        flash_model = self.from_config(config, dtype=dtype, use_padding_free_transformer=True).to(device)
-
-        sdpa_model.eval()
-        flash_model.eval()
-
-        flash_model.load_state_dict(sdpa_model.state_dict())
-
-        input_ids, attention_mask, labels = self.get_dummy_inputs(device)
-        sdpa_output = sdpa_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-        attention_mask = attention_mask.to(torch.bool)
-        sdpa_logits = sdpa_output.logits
-        sdpa_logits = torch.cat([sdpa_logits[i, ex, :] for i, ex in enumerate(attention_mask)])
-        sdpa_loss = sdpa_output.loss
-
-        with enable_kernels([Kernel.flash_attention_2]):
-            input_ids, attention_mask, labels = self.get_dummy_inputs(device, return_list=True)
-            flash_output = flash_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-            flash_logits = flash_output.logits
-            flash_loss = flash_output.loss
-
-        self.assert_equal_tensors(
-            sdpa_logits,
-            flash_logits,
-            False,
-            rtol_float16=1e-3,
-            atol_float16=3e-4,
-            rtol_bfloat16=5e-3,
-            atol_bfloat16=5e-3,
-        )
-        self.assert_equal_tensors(sdpa_loss, flash_loss, False)
-
-    @parameterized.expand(TestCommons.make_args_matrix([torch.device("cuda")], [torch.float16, torch.bfloat16]))
-    def test_sdpa_flash_attention_equivalence(self, device: torch.device, dtype: torch.dtype) -> None:
-        self.skip_test_if_device_unavailable(device)
-
-        set_seed(SEED)
-
-        input_ids, attention_mask, labels = self.get_dummy_inputs(device)
-        config = self.get_dense_test_config(num_layers=1)
-
-        model = self.from_config(config, dtype=dtype).to(device)
-        model.eval()
-
-        sdpa_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-        sdpa_logits = sdpa_output.logits
-        sdpa_loss = sdpa_output.loss
-
-        with enable_kernels([Kernel.flash_attention_2]):
-            flash_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-            flash_logits = flash_output.logits
-            flash_loss = flash_output.loss
-
-        # we don't care about what happens on masked values (they don't match btw)
-        sdpa_logits[attention_mask == 0] = 0
-        flash_logits[attention_mask == 0] = 0
-
-        self.assert_equal_tensors(
-            sdpa_logits[attention_mask],
-            flash_logits[attention_mask],
-            False,
-            rtol_float16=1e-3,
-            atol_float16=3e-4,
-            rtol_bfloat16=5e-3,
-            atol_bfloat16=5e-3,
-        )
-        self.assert_equal_tensors(sdpa_loss, flash_loss, False)
-
-    @staticmethod
-    def get_dense_test_config(
-        num_layers: int = 8,
-        add_bias: bool = True,
-        activation_function: str = "gelu_pytorch_tanh",
-        normalization_function: str = "layernorm",
-        m_emb: float = None,
-        m_width: float = None,
-        m_residual: float = None,
-        attention_multiplier: float = None,
-    ) -> GPTBaseConfig:
-        return GPTBaseConfig(
-            vocab_size=2048,
-            max_position_embeddings=1024,
-            hidden_size=32,
-            num_layers=num_layers,
-            position_embedding_type="nope",
-            normalization_function=normalization_function,
-            tie_word_embeddings=False,
-            bos_token_id=0,
-            eos_token_id=1,
-            pad_token_id=2,
-            m_emb=m_emb,
-            m_width=m_width,
-            m_residual=m_residual,
-            sequence_mixer_blocks=[
-                {
-                    "sequence_mixer_type": "multihead_latent_attention",
-                    "add_bias": add_bias,
-                    "attention_multiplier": attention_multiplier,
-                    "num_attention_heads": 4,
-                    "query_compression_size": 12,
-                    "key_value_compression_size": 8,
-                    "head_dim": 8,
-                }
-                for _ in range(num_layers)
-            ],
-            mlp_blocks=[
-                {"mlp_type": "MLP", "activation_function": activation_function, "add_bias": add_bias}
-                for _ in range(num_layers)
-            ],
-        )
diff --git a/tests/hf_models/single_gpu/typecheck_test.py b/tests/hf_models/single_gpu/typecheck_test.py
deleted file mode 100644
index ded25bb6a..000000000
--- a/tests/hf_models/single_gpu/typecheck_test.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# **************************************************
-# Copyright (c) 2025, Mayank Mishra
-# **************************************************
-
-import torch
-from parameterized import parameterized
-
-from lm_engine.enums import Kernel
-from lm_engine.kernels import enable_kernels
-
-from ..test_common import TestCommons
-
-
-class TypeCheckTest(TestCommons):
-    @parameterized.expand(TestCommons.make_args_matrix([torch.device("cuda")]))
-    def test_no_attention_mask_flash_attention(self, device: torch.device) -> None:
-        self.skip_test_if_device_unavailable(device)
-
-        config = self.get_dense_test_config(
-            position_embedding_type="learned_absolute", num_layers=8, num_attention_heads=32
-        )
-        model = self.from_config(config, use_padding_free_transformer=True).to(device)
-        model.eval()
-
-        input_ids, _, labels = self.get_dummy_inputs(device, return_list=True)
-        attention_mask = [[1] * len(i) for i in input_ids]
-
-        with enable_kernels([Kernel.flash_attention_2]):
-            self.assertRaises(AssertionError, model, input_ids=input_ids, attention_mask=attention_mask, labels=labels)
diff --git a/tests/hf_models/test_common.py b/tests/hf_models/test_common.py
index cbc4152b1..f2a7cb36e 100644
--- a/tests/hf_models/test_common.py
+++ b/tests/hf_models/test_common.py
@@ -269,19 +269,8 @@ def compare_saved_models(path1: str, path2: str) -> bool:
         return False
 
     def from_config(self, config: AutoConfig, **kwargs) -> AutoModelForCausalLM:
-        use_padding_free_transformer = kwargs.pop("use_padding_free_transformer", False)
-
-        model = AutoModelForCausalLM.from_config(
-            config,
-            use_padding_free_transformer=use_padding_free_transformer,
-            dtype=kwargs.pop("dtype", None),
-        )
-
-        if use_padding_free_transformer:
-            assert model.use_padding_free_transformer
-
+        model = AutoModelForCausalLM.from_config(config, dtype=kwargs.pop("dtype", None))
         assert len(kwargs) == 0
-
         return model
 
     def assert_equal_tensors(
diff --git a/tests/training/dataloader_test.py b/tests/training/dataloader_test.py
index a3884073f..b98c50ca7 100644
--- a/tests/training/dataloader_test.py
+++ b/tests/training/dataloader_test.py
@@ -53,7 +53,6 @@ def test_dataloader_has_correct_order(self) -> None:
                 use_output=True,
                 loss_mask=args.training_parameters.loss_mask,
                 eos_token_id=tokenizer.eos_token_id,
-                use_padding_free_transformer=args.model_args.use_padding_free_transformer,
                 device="cpu",
             ),
         )