From 31c40e22c599ceea6f8be60d5021da73bf5f88c1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 02:57:09 -0700 Subject: [PATCH 001/177] cleanup configs Signed-off-by: Mayank Mishra --- configs/distillation-example.yml | 1 - configs/finetuning-example.yml | 1 - configs/pretraining-examples/dense/pretrain-1.yml | 1 - configs/pretraining-examples/dense/pretrain-2.yml | 1 - configs/pretraining-examples/dense/pretrain-3.yml | 1 - configs/pretraining-examples/dense/pretrain-4.yml | 1 - configs/pretraining-examples/dense/pretrain-fsdp.yml | 1 - configs/research/cross-layer-attention/base.yml | 1 - configs/research/cross-layer-attention/cla.yml | 1 - configs/research/ladder-residual/1b-base.yml | 1 - configs/research/ladder-residual/1b-ladder.yml | 1 - configs/research/ladder-residual/1b-parallel.yml | 1 - configs/research/ladder-residual/3b-base.yml | 1 - configs/research/ladder-residual/3b-ladder.yml | 1 - configs/research/ladder-residual/3b-parallel.yml | 1 - 15 files changed, 15 deletions(-) diff --git a/configs/distillation-example.yml b/configs/distillation-example.yml index 05a697969..ed80a4c70 100644 --- a/configs/distillation-example.yml +++ b/configs/distillation-example.yml @@ -23,7 +23,6 @@ model_args: model_class: AutoModelForCausalLM model_name: ibm/PowerLM-3b efficient_initialization: false - use_padding_free_transformer: false teacher_args: model_class: AutoModelForCausalLM diff --git a/configs/finetuning-example.yml b/configs/finetuning-example.yml index 2c99e9005..74bf993c9 100644 --- a/configs/finetuning-example.yml +++ b/configs/finetuning-example.yml @@ -25,7 +25,6 @@ model_args: # padding free transformer needs a gpt_base model. # To convert granite models to this class and convert back after training, # take a look at the readme of this repo - use_padding_free_transformer: false random_args: # for replication of experiment (however, flash attention is non-deterministic so replication generally won't work) diff --git a/configs/pretraining-examples/dense/pretrain-1.yml b/configs/pretraining-examples/dense/pretrain-1.yml index 17d127514..7bb861374 100644 --- a/configs/pretraining-examples/dense/pretrain-1.yml +++ b/configs/pretraining-examples/dense/pretrain-1.yml @@ -120,7 +120,6 @@ model_args: intermediate_size: 3072 add_bias: true position_embedding_type: learned_absolute - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/pretraining-examples/dense/pretrain-2.yml b/configs/pretraining-examples/dense/pretrain-2.yml index 371efc271..63b8dd9d3 100644 --- a/configs/pretraining-examples/dense/pretrain-2.yml +++ b/configs/pretraining-examples/dense/pretrain-2.yml @@ -125,7 +125,6 @@ model_args: intermediate_size: 3072 add_bias: true position_embedding_type: learned_absolute - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/pretraining-examples/dense/pretrain-3.yml b/configs/pretraining-examples/dense/pretrain-3.yml index ff138a2cb..2a10695c0 100644 --- a/configs/pretraining-examples/dense/pretrain-3.yml +++ b/configs/pretraining-examples/dense/pretrain-3.yml @@ -138,7 +138,6 @@ model_args: intermediate_size: 3072 add_bias: true position_embedding_type: learned_absolute - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/pretraining-examples/dense/pretrain-4.yml b/configs/pretraining-examples/dense/pretrain-4.yml index 629956dcf..5c8b64c75 100644 --- a/configs/pretraining-examples/dense/pretrain-4.yml +++ 
b/configs/pretraining-examples/dense/pretrain-4.yml @@ -177,7 +177,6 @@ model_args: intermediate_size: 3072 add_bias: true position_embedding_type: learned_absolute - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/pretraining-examples/dense/pretrain-fsdp.yml b/configs/pretraining-examples/dense/pretrain-fsdp.yml index ab75e0908..1e429747b 100755 --- a/configs/pretraining-examples/dense/pretrain-fsdp.yml +++ b/configs/pretraining-examples/dense/pretrain-fsdp.yml @@ -132,7 +132,6 @@ model_args: intermediate_size: 3072 add_bias: true position_embedding_type: learned_absolute - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/research/cross-layer-attention/base.yml b/configs/research/cross-layer-attention/base.yml index 86dd56916..2688a968c 100644 --- a/configs/research/cross-layer-attention/base.yml +++ b/configs/research/cross-layer-attention/base.yml @@ -249,7 +249,6 @@ model_args: activation_function: swiglu intermediate_size: 8192 efficient_initialization: false - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/research/cross-layer-attention/cla.yml b/configs/research/cross-layer-attention/cla.yml index 867e7c232..a5ee1f63c 100644 --- a/configs/research/cross-layer-attention/cla.yml +++ b/configs/research/cross-layer-attention/cla.yml @@ -282,7 +282,6 @@ model_args: activation_function: swiglu intermediate_size: 8192 efficient_initialization: false - use_padding_free_transformer: true tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/1b-base.yml b/configs/research/ladder-residual/1b-base.yml index 29cee68f1..9ddee7075 100644 --- a/configs/research/ladder-residual/1b-base.yml +++ b/configs/research/ladder-residual/1b-base.yml @@ -278,7 +278,6 @@ model_args: activation_function: swiglu intermediate_size: 4096 efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/1b-ladder.yml b/configs/research/ladder-residual/1b-ladder.yml index f3fb1f56a..66f787d69 100644 --- a/configs/research/ladder-residual/1b-ladder.yml +++ b/configs/research/ladder-residual/1b-ladder.yml @@ -278,7 +278,6 @@ model_args: activation_function: swiglu intermediate_size: 4096 efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/1b-parallel.yml b/configs/research/ladder-residual/1b-parallel.yml index 4959b0378..c843e57fe 100644 --- a/configs/research/ladder-residual/1b-parallel.yml +++ b/configs/research/ladder-residual/1b-parallel.yml @@ -278,7 +278,6 @@ model_args: activation_function: swiglu intermediate_size: 4096 efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/3b-base.yml b/configs/research/ladder-residual/3b-base.yml index be1bccb68..50ddee10c 100644 --- a/configs/research/ladder-residual/3b-base.yml +++ b/configs/research/ladder-residual/3b-base.yml @@ -238,7 +238,6 @@ model_args: - mlp_type: MLP activation_function: swiglu efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/3b-ladder.yml b/configs/research/ladder-residual/3b-ladder.yml index 81b485e73..10c2fac95 100644 --- a/configs/research/ladder-residual/3b-ladder.yml +++ 
b/configs/research/ladder-residual/3b-ladder.yml @@ -238,7 +238,6 @@ model_args: - mlp_type: MLP activation_function: swiglu efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining diff --git a/configs/research/ladder-residual/3b-parallel.yml b/configs/research/ladder-residual/3b-parallel.yml index fb14fc5b8..9db78544d 100644 --- a/configs/research/ladder-residual/3b-parallel.yml +++ b/configs/research/ladder-residual/3b-parallel.yml @@ -238,7 +238,6 @@ model_args: - mlp_type: MLP activation_function: swiglu efficient_initialization: false - use_padding_free_transformer: false tuning_args: tuning_method: pretraining From 6191d014b7f4f02012fbc7594f8c868bce7afcbc Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 03:00:53 -0700 Subject: [PATCH 002/177] cleanup configs Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 4 +--- .../hf_models/modeling_utils/mlp_blocks/__init__.py | 3 +-- lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py | 9 ++------- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index a0c3541cd..6af3cd439 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -29,9 +29,7 @@ def __init__( self.ln_2 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) - self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx - ) + self.mlp_block = get_mlp_block(config, layer_idx=layer_idx) def forward( self, diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py index 802f431d1..4cb078c48 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py @@ -7,7 +7,7 @@ from .moe import MoE, ParameterizedExperts -def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int) -> MLP | MoE: +def get_mlp_block(config: CommonConfig, layer_idx: int) -> MLP | MoE: block = config.mlp_blocks[layer_idx] mlp_type = block.mlp_type @@ -33,7 +33,6 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye normalized_topk=block.normalized_topk, num_experts=block.num_experts, num_experts_per_tok=block.num_experts_per_tok, - use_padding_free_transformer=use_padding_free_transformer, ) else: raise ValueError(f"invalid mlp_type ({mlp_type}) for layer ({layer_idx})") diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index c624b9467..b5136befa 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -167,13 +167,11 @@ def __init__( initializer_range: float, m_width: float, num_layers: int, - use_padding_free_transformer: bool, ) -> MoE: super().__init__() self.num_experts = num_experts self.top_k = num_experts_per_tok - self.use_padding_free_transformer = use_padding_free_transformer self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.shared_intermediate_size = shared_intermediate_size @@ -247,8 +245,7 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_proj_shared.weight) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if not self.use_padding_free_transformer: - 
batch_size, sequence_length, _ = hidden_states.shape + original_shape = hidden_states.size() hidden_states = hidden_states.view(-1, self.hidden_size) @@ -263,9 +260,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: del moe_output - if not self.use_padding_free_transformer: - hidden_states = hidden_states.reshape(batch_size, sequence_length, self.hidden_size) - + hidden_states = hidden_states.reshape(*original_shape) hidden_states = self.dropout(hidden_states) aux_loss = ( From 1b153ea8ab9bf4750c3790712e04126393c6e858 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 03:13:48 -0700 Subject: [PATCH 003/177] cleanup TP test Signed-off-by: Mayank Mishra --- .../modeling_utils/mlp_blocks/moe.py | 1 - .../modeling_utils/position_embedding/rope.py | 7 +--- .../hf_models/models/gpt_crosslayer/layer.py | 4 +-- .../tensor_parallel_forward.py | 36 ++++++++----------- .../tensor_parallel_forward_test.py | 8 ++--- 5 files changed, 18 insertions(+), 38 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index b5136befa..7a3ae211b 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -246,7 +246,6 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: original_shape = hidden_states.size() - hidden_states = hidden_states.view(-1, self.hidden_size) router_logits, router_weights, selected_experts = self._compute_routing_weights(hidden_states) diff --git a/lm_engine/hf_models/modeling_utils/position_embedding/rope.py b/lm_engine/hf_models/modeling_utils/position_embedding/rope.py index 3aa8b21f3..86426193a 100644 --- a/lm_engine/hf_models/modeling_utils/position_embedding/rope.py +++ b/lm_engine/hf_models/modeling_utils/position_embedding/rope.py @@ -13,12 +13,7 @@ class RoPE(nn.Module): - def __init__( - self, - head_dim: int, - max_position_embeddings: int = 2048, - base: int = 10000, - ) -> RoPE: + def __init__(self, head_dim: int, max_position_embeddings: int = 2048, base: int = 10000) -> RoPE: super().__init__() self.head_dim = head_dim diff --git a/lm_engine/hf_models/models/gpt_crosslayer/layer.py b/lm_engine/hf_models/models/gpt_crosslayer/layer.py index 3413911ff..4c8a663b9 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/layer.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/layer.py @@ -52,9 +52,7 @@ def __init__( self.ln_2 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) - self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx - ) + self.mlp_block = get_mlp_block(config, layer_idx=layer_idx) def forward( self, diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py index 6a8e919c8..27d703a5e 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py @@ -22,7 +22,6 @@ parser.add_argument("--attention-implementation", type=str) parser.add_argument("--dtype", type=str) parser.add_argument("--tmp-path", type=str) -parser.add_argument("--use-padding-free-transformer", action="store_true") parser.add_argument("--sequence-parallel", action="store_true") args = parser.parse_args() @@ -81,9 +80,7 @@ # try sharding vocab matrices if really struggling for memory model_tp = 
get_model_parallel_class(config.model_type)._from_config( - config, - use_padding_free_transformer=args.use_padding_free_transformer, - sequence_parallel=args.sequence_parallel, + config, use_padding_free_transformer=True, sequence_parallel=args.sequence_parallel ) # copy to device without copying storage @@ -109,21 +106,18 @@ 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False ) -if args.use_padding_free_transformer: - cu_seqlens = torch.arange( - 0, input_ids.numel() + 1, sequence_length, dtype=torch.int32, device=torch.cuda.current_device() - ) - position_ids = torch.arange(0, sequence_length, 1, device=torch.cuda.current_device()).repeat(batch_size) - - output_tp = model_tp( - input_ids=input_ids.view(-1), - labels=labels.view(-1), - cu_seqlens=cu_seqlens, - max_seqlen=sequence_length, - position_ids=position_ids, - ) -else: - output_tp = model_tp(input_ids=input_ids, labels=labels) +cu_seqlens = torch.arange( + 0, input_ids.numel() + 1, sequence_length, dtype=torch.int32, device=torch.cuda.current_device() +) +position_ids = torch.arange(0, sequence_length, 1, device=torch.cuda.current_device()).repeat(batch_size) + +output_tp = model_tp( + input_ids=input_ids.view(-1), + labels=labels.view(-1), + cu_seqlens=cu_seqlens, + max_seqlen=sequence_length, + position_ids=position_ids, +) loss_tp = output_tp.loss logits_tp = output_tp.logits[..., : config.vocab_size] @@ -135,9 +129,7 @@ loss = output.loss logits = output.logits - - if args.use_padding_free_transformer: - logits_tp = logits_tp.reshape(batch_size, sequence_length, -1) + logits_tp = logits_tp.reshape(batch_size, sequence_length, -1) error = (logits - logits_tp).abs().max() assert error < 5e-4, f"logits don't match for normal and tensor parallel model, error is ({error})" diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py index 184f94af5..3a3423f1f 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py @@ -20,7 +20,6 @@ class TensorParallelTest(TestCommons): TestCommons.get_attention_implementations(), TestCommons.get_dtypes(), [False, True], - [False, True], ) ) @TestCommons.slow_test @@ -29,7 +28,6 @@ def test_tensor_parallel_forward( position_embedding_type: str, attention_implementation: str, dtype: torch.dtype, - use_padding_free_transformer: bool, sequence_parallel: bool, ) -> None: self.skip_test_if_device_unavailable(torch.device("cuda")) @@ -40,7 +38,7 @@ def test_tensor_parallel_forward( ]: self.skipTest("skipping test since running all takes too long") - if use_padding_free_transformer and attention_implementation != "flash_attention_2": + if attention_implementation != "flash_attention_2": self.skipTest("skipping test since flash attention is needed for padding free transformer") gpus_per_node = torch.cuda.device_count() @@ -60,11 +58,9 @@ def test_tensor_parallel_forward( attention_implementation, "--tmp-path", tmp_path, + "--use-padding-free-transformer", ] - if use_padding_free_transformer: - command.append("--use-padding-free-transformer") - if sequence_parallel: command.append("--sequence-parallel") From 94f4020fc9db5ef885a6abcf92e2290d381f3620 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 03:16:17 -0700 Subject: [PATCH 004/177] add SWA Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 26 
++++++------------------ 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 3be5b9568..238f1598a 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -59,26 +59,12 @@ def prepare_inputs_for_model( use_cache: bool, ) -> tuple[torch.Tensor]: if self.use_padding_free_transformer: - if isinstance(input_ids, list): - # this is managed internally - error_message = ( - "{variable} should not be passed for flash attention when using List[List[int]] " - "input types attention mask logic is handled internally" - ) - assert cu_seqlens is None, error_message.format(variable="cu_seqlens") - assert max_seqlen is None, error_message.format(variable="max_seqlen") - assert attention_mask is None, error_message.format(variable="attention_mask") - - input_ids, position_ids, labels, cu_seqlens, max_seqlen = convert_padding_free_lists_to_tensors( - input_ids=input_ids, position_ids=position_ids, labels=labels, device=torch.cuda.current_device() - ) - else: - assert ( - cu_seqlens is not None - ), "cu_seqlens needs to be specified when using tensor inputs with padding_free transformer" - assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert max_seqlen is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" + assert ( + cu_seqlens is not None + ), "cu_seqlens needs to be specified when using tensor inputs with padding_free transformer" + assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" + assert max_seqlen is not None, "max_seqlen needs to be specified when specifying cu_seqlens" + assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" if use_cache or past_key_values is not None: raise NotImplementedError("KV caching is not supported with padding_free transformer") From bd46de7a464dfdb3b453b19f5b5f03c601300206 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 03:18:11 -0700 Subject: [PATCH 005/177] add SWA Signed-off-by: Mayank Mishra --- README.md | 2 +- .../multi_gpu/tensor_parallel/tensor_parallel_forward_test.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 7adcc76b0..173e6a0bd 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,7 @@ labels = [[-100, -100, -100, 4, 5, 0], [-100, -100, 8, 0]] # this will throw a warning saying that the model is of gpt_bigcode class # ignore the warning -model = GPTBaseForCausalLM.from_pretrained(, use_padding_free_transformer=True).cuda() +model = GPTBaseForCausalLM.from_pretrained().cuda() with enable_kernels([Kernel.flash_attention_2]): loss = model(input_ids=input_ids, labels=labels).loss diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py index 3a3423f1f..acf3c899d 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py @@ -58,7 +58,6 @@ def test_tensor_parallel_forward( attention_implementation, "--tmp-path", tmp_path, - "--use-padding-free-transformer", ] if sequence_parallel: From 615e105bfcd40c3f59f3307896fba3fc7a14888e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 
03:22:51 -0700 Subject: [PATCH 006/177] dim in FA Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/attention.py | 1 - .../multihead_latent_attention.py | 1 - .../utils/flash_attention_utils.py | 64 +++++++++---------- .../sequence_mixer_blocks/attention.py | 1 - .../gpt_crosslayer/sequence_mixers/base.py | 2 - 5 files changed, 32 insertions(+), 37 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 0d784e2c3..ef74c4637 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -200,7 +200,6 @@ def forward( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask, - use_padding_free_transformer=self.use_padding_free_transformer, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py index b9a9f414e..9bec13d46 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py @@ -176,7 +176,6 @@ def forward( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask, - use_padding_free_transformer=self.use_padding_free_transformer, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py index b9f43b198..a04353d13 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py @@ -58,7 +58,6 @@ def flash_attention( attention_mask: torch.Tensor | None, cu_seqlens: torch.Tensor | None, max_seqlen: int | None, - use_padding_free_transformer: bool, causal: bool, dropout: float = 0, softmax_scale: float | None = None, @@ -73,42 +72,14 @@ def flash_attention( assert use_flash_attention_3 or use_flash_attention_2, "enable flash_attention_2 or flash_attention_3" - if use_padding_free_transformer: - assert use_flash_attention_3 or use_flash_attention_2 - window_size = (-1, -1) if sliding_window is not None and key.size(1) > sliding_window: window_size = (sliding_window, sliding_window) - if use_padding_free_transformer: - assert sliding_window is None + if cu_seqlens is None: + assert max_seqlen is None + assert query.dim() == 4 - if use_flash_attention_3: - attn_output, _ = flash_attention_3_varlen( - q=query, - k=key, - v=value, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attention_2_varlen( - q=query, - k=key, - v=value, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: if attention_mask is None: if use_flash_attention_3: attn_output, _ = flash_attention_3( @@ -173,5 +144,34 @@ def flash_attention( 
cu_seqlens=cu_seqlens_q, output_shape=(batch_size, query_length, num_heads, head_dim), ) + else: + assert sliding_window is None + assert query.dim() == 3 + + if use_flash_attention_3: + attn_output, _ = flash_attention_3_varlen( + q=query, + k=key, + v=value, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attention_2_varlen( + q=query, + k=key, + v=value, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) return attn_output diff --git a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py index 85d030be6..131d02d42 100644 --- a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py @@ -195,7 +195,6 @@ def forward( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask, - use_padding_free_transformer=self.use_padding_free_transformer, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py index 4e99b775c..df65ced1d 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py @@ -103,7 +103,6 @@ def forward( attention_mask=attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - use_padding_free_transformer=self.use_padding_free_transformer, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, @@ -131,7 +130,6 @@ def forward( attention_mask=attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - use_padding_free_transformer=self.use_padding_free_transformer, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, From 8d9f7d496f7d678778bc62238718617984cd5b26 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 03:24:40 -0700 Subject: [PATCH 007/177] dim in FA Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 26 ++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 238f1598a..3be5b9568 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -59,12 +59,26 @@ def prepare_inputs_for_model( use_cache: bool, ) -> tuple[torch.Tensor]: if self.use_padding_free_transformer: - assert ( - cu_seqlens is not None - ), "cu_seqlens needs to be specified when using tensor inputs with padding_free transformer" - assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert max_seqlen is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" + if isinstance(input_ids, list): + # this is managed internally + error_message = ( + "{variable} should not be passed for flash attention when using List[List[int]] " + "input types attention mask 
logic is handled internally" + ) + assert cu_seqlens is None, error_message.format(variable="cu_seqlens") + assert max_seqlen is None, error_message.format(variable="max_seqlen") + assert attention_mask is None, error_message.format(variable="attention_mask") + + input_ids, position_ids, labels, cu_seqlens, max_seqlen = convert_padding_free_lists_to_tensors( + input_ids=input_ids, position_ids=position_ids, labels=labels, device=torch.cuda.current_device() + ) + else: + assert ( + cu_seqlens is not None + ), "cu_seqlens needs to be specified when using tensor inputs with padding_free transformer" + assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" + assert max_seqlen is not None, "max_seqlen needs to be specified when specifying cu_seqlens" + assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" if use_cache or past_key_values is not None: raise NotImplementedError("KV caching is not supported with padding_free transformer") From 6b4d1dc222066b4976a3abf8cb07460e1c7e5180 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 03:28:31 -0700 Subject: [PATCH 008/177] drop SBA Signed-off-by: Mayank Mishra --- lm_engine/hf_models/cache/__init__.py | 1 - lm_engine/hf_models/config/__init__.py | 3 - lm_engine/hf_models/config/sequence_mixer.py | 12 - lm_engine/hf_models/mixins/dense/layer.py | 2 +- .../sequence_mixer_blocks/__init__.py | 6 - .../stickbreaking_attention.py | 235 ------------------ lm_engine/train_utils.py | 2 +- lm_engine/utils/__init__.py | 1 - lm_engine/utils/packages.py | 20 -- 9 files changed, 2 insertions(+), 280 deletions(-) delete mode 100644 lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/stickbreaking_attention.py diff --git a/lm_engine/hf_models/cache/__init__.py b/lm_engine/hf_models/cache/__init__.py index 2166867d7..7bbb5a15e 100644 --- a/lm_engine/hf_models/cache/__init__.py +++ b/lm_engine/hf_models/cache/__init__.py @@ -21,7 +21,6 @@ "multihead_latent_attention": _SoftmaxAttentionCache, "rnn": _RNNCache, "softmax_attention": _SoftmaxAttentionCache, - "stickbreaking_attention": _SoftmaxAttentionCache, } CACHE_TYPE = torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None diff --git a/lm_engine/hf_models/config/__init__.py b/lm_engine/hf_models/config/__init__.py index 608e0ab80..1e2b26767 100644 --- a/lm_engine/hf_models/config/__init__.py +++ b/lm_engine/hf_models/config/__init__.py @@ -18,7 +18,6 @@ _MultiHeadLatentAttentionArgs, _RNNArgs, _SoftmaxAttentionArgs, - _StickbreakingAttentionArgs, ) @@ -71,7 +70,6 @@ def _update_with_key_value(block: dict, kwargs: dict, key: str) -> None: "mamba2": _Mamba2Args, "multihead_latent_attention": _MultiHeadLatentAttentionArgs, "rnn": _RNNArgs, - "stickbreaking_attention": _StickbreakingAttentionArgs, "softmax_attention": _SoftmaxAttentionArgs, } @@ -214,7 +212,6 @@ def _set_sequence_mixer_blocks(self) -> None: | _MultiHeadLatentAttentionArgs | _RNNArgs | _SoftmaxAttentionArgs - | _StickbreakingAttentionArgs ] = [] for i in range(self.num_layers): sequence_mixer_block = deepcopy(self.sequence_mixer_blocks[i]) diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index 08ef1ca6e..a3708a71c 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -49,18 +49,6 @@ def model_post_init(self, __context: Any) -> None: assert self.head_dim is not None -class _StickbreakingAttentionArgs(BaseArgs): - sequence_mixer_type: str = 
"stickbreaking_attention" - num_attention_heads: int = 12 - num_key_value_heads: int = 1 - dropout: float = 0 - add_bias: bool = False - attention_multiplier: float | None = None - - def model_post_init(self, __context: Any) -> None: - assert self.sequence_mixer_type == "stickbreaking_attention" - - class _Mamba2Args(BaseArgs): sequence_mixer_type: str = "mamba2" state_size: int = 128 diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 6af3cd439..3cf420889 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -78,7 +78,7 @@ def _sequence_mixer_forward( cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, ) -> torch.Tensor: - if self.sequence_mixer_type in ["softmax_attention", "stickbreaking_attention", "multihead_latent_attention"]: + if self.sequence_mixer_type in ["softmax_attention", "multihead_latent_attention"]: hidden_states = self.sequence_mixer( hidden_states, past_key_values=past_key_values, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index 1001e2450..de75175ef 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -13,7 +13,6 @@ from .mamba2 import Mamba2 from .multihead_latent_attention import MultiHeadLatentAttention from .rnn import RNN -from .stickbreaking_attention import PaddingFreeSBAttention, SBAttention from .utils import flash_attention @@ -140,10 +139,5 @@ def get_sequence_mixer( softmax_dropout=block.softmax_dropout, use_padding_free_transformer=use_padding_free_transformer, ) - elif sequence_mixer_type == "stickbreaking_attention": - if use_padding_free_transformer: - return PaddingFreeSBAttention(**sequence_mixer_kwargs) - else: - return SBAttention(**sequence_mixer_kwargs) else: raise ValueError(f"unexpected sequence_mixer_type ({sequence_mixer_type})") diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/stickbreaking_attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/stickbreaking_attention.py deleted file mode 100644 index 6eaab0699..000000000 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/stickbreaking_attention.py +++ /dev/null @@ -1,235 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import math - -import torch -import torch.nn -import torch.nn.functional as F - -from ....utils import is_stickbreaking_available -from ...cache import GenerationCache -from .attention import Attention - - -if is_stickbreaking_available(): - from stickbreaking_attention import sb_attn, sb_attn_varlen - - -def decoding_stickbreaking(q, k, v, scale=None): - """ - Stick-breaking attention weights. 
- """ - if scale is None: - scale = 1 / math.sqrt(q.shape[-1]) - # logits = q @ k[..., :-1, :].transpose(-1, -2) * scale - - assert q.size(2) == 1 - original_dtype = q.dtype - q = q.float() - k = k.float() - logits = q @ k[..., :-1, :].transpose(-1, -2) * scale - log_z = F.logsigmoid(logits).to(original_dtype) - log_beta = F.logsigmoid(-logits).to(original_dtype) - re_cum_log_beta = log_beta.flip(-1).cumsum(dim=-1).flip(-1) - log_beta - log_att = log_z + re_cum_log_beta - att: torch.Tensor = log_att.exp() - v = v[..., :-1, :] - out = torch.einsum("bhij,bhjd->bhid", att, v) - return out, 1 - att.sum(dim=-1) - - -class SBAttention(Attention): - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - num_key_value_heads: int, - attention_multiplier: float, - position_embedding_type: str, - add_bias: bool, - dropout: float, - init_method: str, - initializer_range: float, - m_width: float, - num_layers: int, - causal: bool, - layer_idx: int, - use_padding_free_transformer: bool = False, - ) -> SBAttention: - super().__init__( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - attention_multiplier=attention_multiplier, - position_embedding_type=position_embedding_type, - add_bias=add_bias, - softmax_dropout=0, - dropout=dropout, - init_method=init_method, - initializer_range=initializer_range, - m_width=m_width, - num_layers=num_layers, - causal=causal, - layer_idx=layer_idx, - use_padding_free_transformer=use_padding_free_transformer, - ) - - self.head_bias = torch.nn.Parameter(torch.zeros(self.hidden_size // self.head_dim, self.head_dim)) - self.norm = torch.nn.GroupNorm(self.num_heads, self.hidden_size) - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, - rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - sb_metadata=None, - ) -> torch.Tensor: - # assert past_key_values is None - - query, key, value = self._prepare_qkv_for_forward(hidden_states) - bsz_, _, length_, _ = query.size() - - if query.size(2) == key.size(2): - hidden_states, rem = sb_attn(q=query, k=key, v=value, inv_temp=self.attention_multiplier) - else: - hidden_states, rem = decoding_stickbreaking(q=query, k=key, v=value, scale=self.attention_multiplier) - - hidden_states = hidden_states + rem[..., None] * self.head_bias[None, :, None, :] - - hidden_states = hidden_states.permute(0, 2, 1, 3) - hidden_states = hidden_states.view(bsz_ * length_, self.hidden_size) - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states.view(bsz_, length_, self.hidden_size) - - hidden_states = self.c_proj(hidden_states) - hidden_states = self.dropout(hidden_states) - - return hidden_states - - def _prepare_qkv_for_forward_gqa( - self, hidden_states: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - batch_size, query_length = hidden_states.shape[:-1] - - hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1) - - query, key, value = hidden_states.split( - ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 - ) - - # this needs to be a reshape instead of view sadly - query = query.reshape(batch_size, query_length, -1, self.head_dim) - - key = key.repeat(1, 1, self.num_heads // self.num_key_value_heads, 1) - value = value.repeat(1, 1, self.num_heads // self.num_key_value_heads, 1) - - 
query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - return query, key, value - - -class PaddingFreeSBAttention(SBAttention): - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - num_key_value_heads: int, - attention_multiplier: float, - position_embedding_type: str, - add_bias: bool, - dropout: float, - init_method: str, - initializer_range: float, - m_width: float, - num_layers: int, - causal: bool, - layer_idx: int, - ) -> PaddingFreeSBAttention: - super().__init__( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - attention_multiplier=attention_multiplier, - position_embedding_type=position_embedding_type, - add_bias=add_bias, - dropout=dropout, - init_method=init_method, - initializer_range=initializer_range, - m_width=m_width, - num_layers=num_layers, - causal=causal, - layer_idx=layer_idx, - use_padding_free_transformer=True, - ) - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, - rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - sb_metadata=None, - ) -> torch.Tensor: - assert past_key_values is None - query, key, value = self._prepare_qkv_for_forward(hidden_states) - - value = value.permute(1, 0, 2) - hidden_states, rem = sb_attn_varlen( - q=query.permute(1, 0, 2), - k=key.permute(1, 0, 2), - v=value, - inv_temp=self.attention_multiplier, - cu_seqlens=cu_seqlens, - max_seqlens=max_seqlen, - ) - hidden_states = hidden_states + rem[..., None] * self.head_bias[:, None, :] - hidden_states = hidden_states.permute(1, 0, 2) - - hidden_states = hidden_states.view(-1, self.hidden_size) - hidden_states = self.norm(hidden_states) - - hidden_states = self.c_proj(hidden_states) - hidden_states = self.dropout(hidden_states) - - return hidden_states - - def _prepare_qkv_for_forward_mha( - self, hidden_states: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - total_q = hidden_states.shape[0] - - hidden_states = hidden_states.view(total_q, self.num_key_value_heads, -1) - query, key, value = hidden_states.chunk(3, dim=-1) - - return query, key, value - - def _prepare_qkv_for_forward_gqa( - self, hidden_states: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - total_q = hidden_states.shape[0] - - hidden_states = hidden_states.view(total_q, self.num_key_value_heads, -1) - - query, key, value = hidden_states.split( - ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 - ) - - # this needs to be a reshape instead of view sadly - query = query.reshape(total_q, -1, self.head_dim) - # key = key.repeat(1, self.num_heads // self.num_key_value_heads, 1) - # value = value.repeat(1, self.num_heads // self.num_key_value_heads, 1) - group_size = self.num_heads // self.num_key_value_heads - key = key.repeat_interleave(repeats=group_size, dim=1) - value = value.repeat_interleave(repeats=group_size, dim=1) - return query, key, value diff --git a/lm_engine/train_utils.py b/lm_engine/train_utils.py index eb7ab961e..3cae0045a 100644 --- a/lm_engine/train_utils.py +++ b/lm_engine/train_utils.py @@ -129,7 +129,7 @@ def get_model_tflops( sequence_mixer_flops += _get_linear_flops( b * s, block.out_channels, h, gradient_checkpointing=gradient_checkpointing_enabled ) - elif sequence_mixer_type in ["softmax_attention", 
"stickbreaking_attention"]: + elif sequence_mixer_type == "softmax_attention": # QKV projection FLOPs sequence_mixer_flops = _get_linear_flops( b * s, diff --git a/lm_engine/utils/__init__.py b/lm_engine/utils/__init__.py index 50a672fd1..2419d976e 100644 --- a/lm_engine/utils/__init__.py +++ b/lm_engine/utils/__init__.py @@ -17,7 +17,6 @@ is_flash_attention_3_available, is_fma_available, is_mamba_2_ssm_available, - is_stickbreaking_available, is_torchao_available, is_triton_available, is_zstandard_available, diff --git a/lm_engine/utils/packages.py b/lm_engine/utils/packages.py index 04e53e4fc..1d3a33f40 100644 --- a/lm_engine/utils/packages.py +++ b/lm_engine/utils/packages.py @@ -163,22 +163,6 @@ def is_torchao_available() -> bool: return _IS_TORCHAO_AVAILABLE -try: - import stickbreaking_attention - - _IS_STICKBREAKING_AVAILABLE = True -except ImportError: - _IS_STICKBREAKING_AVAILABLE = False - - warn_rank_0( - "stickbreaking-attention is not available, install from https://github.com/shawntan/stickbreaking-attention" - ) - - -def is_stickbreaking_available(): - return _IS_STICKBREAKING_AVAILABLE - - try: import zstandard @@ -189,10 +173,6 @@ def is_stickbreaking_available(): warn_rank_0("zstandard is not available") -def is_zstandard_available(): - return _IS_STICKBREAKING_AVAILABLE - - @run_rank_n def log_environment() -> None: packages = sorted(["{}=={}".format(d.metadata["Name"], d.version) for d in distributions()]) From 3e1aef2a3005b25aae1ab9ee44922186863c0490 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 03:33:33 -0700 Subject: [PATCH 009/177] drop SBA Signed-off-by: Mayank Mishra --- lm_engine/utils/packages.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lm_engine/utils/packages.py b/lm_engine/utils/packages.py index 1d3a33f40..b1fee3221 100644 --- a/lm_engine/utils/packages.py +++ b/lm_engine/utils/packages.py @@ -173,6 +173,10 @@ def is_torchao_available() -> bool: warn_rank_0("zstandard is not available") +def is_zstandard_available(): + return _IS_ZSTANDARD_AVAILABLE + + @run_rank_n def log_environment() -> None: packages = sorted(["{}=={}".format(d.metadata["Name"], d.version) for d in distributions()]) From bd123a88a04d1c27e425c1faf229c0ec977a1273 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 03:36:29 -0700 Subject: [PATCH 010/177] drop SBA Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/__init__.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index de75175ef..f15066f56 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -16,16 +16,7 @@ from .utils import flash_attention -SEQUENCE_MIXER_TYPE = ( - Attention - | CausalConvolution - | GRU - | Mamba2 - | MultiHeadLatentAttention - | RNN - | SBAttention - | PaddingFreeSBAttention -) +SEQUENCE_MIXER_TYPE = Attention | CausalConvolution | GRU | Mamba2 | MultiHeadLatentAttention | RNN def get_sequence_mixer( From aec861b6546dfa9d5694863721f6513db3697727 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 04:16:08 -0700 Subject: [PATCH 011/177] drop SBA Signed-off-by: Mayank Mishra --- a.py | 17 ++++++++ lm_engine/hf_models/tensor.py | 77 +++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 a.py create mode 100644 
lm_engine/hf_models/tensor.py diff --git a/a.py b/a.py new file mode 100644 index 000000000..b3d274cdc --- /dev/null +++ b/a.py @@ -0,0 +1,17 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import torch + +from lm_engine.hf_models.tensor import PackedTensor + + +y = torch.randn(5, 4, requires_grad=True) +# Example usage +x = PackedTensor.from_unpacked_tensor(y, batch_size=5) +x.sum().backward() + +print(x) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py new file mode 100644 index 000000000..2840ed3d7 --- /dev/null +++ b/lm_engine/hf_models/tensor.py @@ -0,0 +1,77 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import torch +from fma import pack_sequence, unpack_sequence + + +class PackedTensor(torch.Tensor): + def __new__( + cls, + packed_tensor: torch.Tensor, + original_shape: tuple[int], + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + batch_size: int | None = None, + ) -> PackedTensor: + self = torch.as_tensor(packed_tensor).as_subclass(cls) + + assert batch_size or cu_seqlens + + if batch_size is not None: + assert packed_tensor.size(0) % batch_size == 0 + + self._packed_tensor = packed_tensor + self._original_shape = original_shape + self._cu_seqlens = cu_seqlens + self._max_seqlen = max_seqlen + self._batch_size = cu_seqlens.size(0) - 1 if batch_size is None else batch_size + + return self + + @staticmethod + def from_unpacked_tensor( + unpacked_tensor: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + batch_size: int | None = None, + ) -> PackedTensor: + assert batch_size or cu_seqlens + + if batch_size is None: + batch_size = cu_seqlens.size(0) - 1 + packed_tensor = pack_sequence(inputs=unpacked_tensor, cu_seqlens=cu_seqlens) + else: + assert unpacked_tensor.size(0) % batch_size == 0 + packed_tensor = unpacked_tensor.flatten(0, 1) + + packed_tensor = PackedTensor( + packed_tensor=packed_tensor, + original_shape=unpacked_tensor.size(), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + ) + + return packed_tensor + + def to_unpacked_tensor(self) -> torch.Tensor: + if self._batch_size is None: + unpacked_tensor = unpack_sequence( + inputs=self._packed_tensor, cu_seqlens=self._cu_seqlens, output_shape=self._original_shape + ) + else: + unpacked_tensor = self.view(self._batch_size, -1, *self.size()[1:]) + + return unpacked_tensor + + +y = torch.randn(5, 4, requires_grad=True) +# Example usage +x = PackedTensor.from_unpacked_tensor(y, batch_size=5) +x.sum().backward() + +print(x) From c45e89fcf103d86fc33647144b074b1ddd61b73e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 04:16:38 -0700 Subject: [PATCH 012/177] drop SBA Signed-off-by: Mayank Mishra --- a.py => b.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename a.py => b.py (100%) diff --git a/a.py b/b.py similarity index 100% rename from a.py rename to b.py From 97284d2fb663d8a3fd4949da0443fcf935b0fe12 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 04:17:21 -0700 Subject: [PATCH 013/177] drop SBA Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 
2840ed3d7..d4236f41c 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -67,11 +67,3 @@ def to_unpacked_tensor(self) -> torch.Tensor: unpacked_tensor = self.view(self._batch_size, -1, *self.size()[1:]) return unpacked_tensor - - -y = torch.randn(5, 4, requires_grad=True) -# Example usage -x = PackedTensor.from_unpacked_tensor(y, batch_size=5) -x.sum().backward() - -print(x) From 3a94928bbe06934c052dc8098c0c0c0aa5f4033e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 04:22:26 -0700 Subject: [PATCH 014/177] drop SBA Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index d4236f41c..d85941906 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -19,7 +19,7 @@ def __new__( ) -> PackedTensor: self = torch.as_tensor(packed_tensor).as_subclass(cls) - assert batch_size or cu_seqlens + assert batch_size is not None or cu_seqlens is not None if batch_size is not None: assert packed_tensor.size(0) % batch_size == 0 @@ -39,7 +39,7 @@ def from_unpacked_tensor( max_seqlen: int | None = None, batch_size: int | None = None, ) -> PackedTensor: - assert batch_size or cu_seqlens + assert batch_size is not None or cu_seqlens is not None if batch_size is None: batch_size = cu_seqlens.size(0) - 1 From 0b0af130850f78597d17fe263a3333bfb040747a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 04:31:58 -0700 Subject: [PATCH 015/177] drop SBA Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index d85941906..ef6d8b1cc 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -59,11 +59,11 @@ def from_unpacked_tensor( return packed_tensor def to_unpacked_tensor(self) -> torch.Tensor: - if self._batch_size is None: + if self._cu_seqlens is None: + unpacked_tensor = self.view(self._batch_size, -1, *self.size()[1:]) + else: unpacked_tensor = unpack_sequence( inputs=self._packed_tensor, cu_seqlens=self._cu_seqlens, output_shape=self._original_shape ) - else: - unpacked_tensor = self.view(self._batch_size, -1, *self.size()[1:]) return unpacked_tensor From fb01683a449945346d030956bb3eabf892775027 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 04:42:24 -0700 Subject: [PATCH 016/177] drop SBA Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index ef6d8b1cc..e5258c542 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -19,11 +19,6 @@ def __new__( ) -> PackedTensor: self = torch.as_tensor(packed_tensor).as_subclass(cls) - assert batch_size is not None or cu_seqlens is not None - - if batch_size is not None: - assert packed_tensor.size(0) % batch_size == 0 - self._packed_tensor = packed_tensor self._original_shape = original_shape self._cu_seqlens = cu_seqlens From a061be4cff4ba9921e179cc782e9b3ec56ef46e0 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 04:59:44 -0700 Subject: [PATCH 017/177] add packed tensor Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 48 ++++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/lm_engine/hf_models/tensor.py 
b/lm_engine/hf_models/tensor.py index e5258c542..0dd15e0c4 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -13,6 +13,7 @@ def __new__( cls, packed_tensor: torch.Tensor, original_shape: tuple[int], + assume_ragged: bool, cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, batch_size: int | None = None, @@ -21,9 +22,10 @@ def __new__( self._packed_tensor = packed_tensor self._original_shape = original_shape + self._assume_ragged = assume_ragged self._cu_seqlens = cu_seqlens self._max_seqlen = max_seqlen - self._batch_size = cu_seqlens.size(0) - 1 if batch_size is None else batch_size + self._batch_size = batch_size return self @@ -35,17 +37,31 @@ def from_unpacked_tensor( batch_size: int | None = None, ) -> PackedTensor: assert batch_size is not None or cu_seqlens is not None + assume_ragged = cu_seqlens is not None + + if assume_ragged: + assert max_seqlen is not None + + if batch_size is None: + batch_size = cu_seqlens.size(0) - 1 + + assert cu_seqlens.size(0) - 1 == batch_size - if batch_size is None: - batch_size = cu_seqlens.size(0) - 1 packed_tensor = pack_sequence(inputs=unpacked_tensor, cu_seqlens=cu_seqlens) else: assert unpacked_tensor.size(0) % batch_size == 0 + + if max_seqlen is None: + max_seqlen = unpacked_tensor.size(1) + + assert unpacked_tensor.size(1) == max_seqlen + packed_tensor = unpacked_tensor.flatten(0, 1) packed_tensor = PackedTensor( packed_tensor=packed_tensor, original_shape=unpacked_tensor.size(), + assume_ragged=assume_ragged, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, batch_size=batch_size, @@ -54,11 +70,31 @@ def from_unpacked_tensor( return packed_tensor def to_unpacked_tensor(self) -> torch.Tensor: - if self._cu_seqlens is None: - unpacked_tensor = self.view(self._batch_size, -1, *self.size()[1:]) - else: + if self._assume_ragged: unpacked_tensor = unpack_sequence( inputs=self._packed_tensor, cu_seqlens=self._cu_seqlens, output_shape=self._original_shape ) + else: + unpacked_tensor = self.view(self._batch_size, -1, *self.size()[1:]) return unpacked_tensor + + def is_ragged_tensor(self) -> bool: + return self._assume_ragged + + def get_batch_size(self) -> int: + return self._batch_size + + def get_max_seqlen(self) -> int: + return self._max_seqlen + + def get_cu_seqlens(self, force_compute: bool = False) -> torch.Tensor: + if force_compute: + if self._cu_seqlens is None: + self._cu_seqlens = torch.arange( + 0, self._batch_size * self._max_seqlen + 1, self._max_seqlen, device=self.device + ) + else: + raise NotImplementedError("code is not supposed to reach here") + + return self._cu_seqlens From a1c8d55ec3f9c9ffa24ac84ff7fd67091d456b77 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 05:36:05 -0700 Subject: [PATCH 018/177] add packed tensor Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 0dd15e0c4..dfc125572 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -98,3 +98,17 @@ def get_cu_seqlens(self, force_compute: bool = False) -> torch.Tensor: raise NotImplementedError("code is not supposed to reach here") return self._cu_seqlens + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + # Example: disallow reductions on dim 0 or dim 1 + if "dim" in kwargs: + dim = kwargs["dim"] + if isinstance(dim, int) and dim in (0, 1): + raise 
RuntimeError(f"{func.__name__} on dim {dim} is not valid for PackedTensor") + + # Fallback: run base implementation + return super().__torch_function__(func, types, args, kwargs) From 137618caebc3de17891d630e69a853b94008873d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 05:42:32 -0700 Subject: [PATCH 019/177] add packed tensor Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index dfc125572..b7203164c 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -101,14 +101,4 @@ def get_cu_seqlens(self, force_compute: bool = False) -> torch.Tensor: @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - - # Example: disallow reductions on dim 0 or dim 1 - if "dim" in kwargs: - dim = kwargs["dim"] - if isinstance(dim, int) and dim in (0, 1): - raise RuntimeError(f"{func.__name__} on dim {dim} is not valid for PackedTensor") - - # Fallback: run base implementation - return super().__torch_function__(func, types, args, kwargs) + raise NotImplementedError("unpack the tensor to run ops on it") From 18c9a93a4c3994f450774f24a9a0149b9db4be66 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 05:46:01 -0700 Subject: [PATCH 020/177] add packed tensor Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index b7203164c..78d61f0b6 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -4,11 +4,15 @@ from __future__ import annotations +from contextlib import contextmanager + import torch from fma import pack_sequence, unpack_sequence class PackedTensor(torch.Tensor): + _is_safe = False + def __new__( cls, packed_tensor: torch.Tensor, @@ -99,6 +103,16 @@ def get_cu_seqlens(self, force_compute: bool = False) -> torch.Tensor: return self._cu_seqlens + @contextmanager + @classmethod + def safe_mode(cls): + cls._is_safe = True + yield + cls._is_safe = False + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): + if cls._is_safe: + return super().__torch_function__(func, types, args, kwargs) + raise NotImplementedError("unpack the tensor to run ops on it") From e77686eefa296fabfa669dcf608c6666143c901b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 14:17:58 -0700 Subject: [PATCH 021/177] add packed tensor Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 10 ++- .../sequence_mixer_blocks/__init__.py | 1 - .../sequence_mixer_blocks/rnn.py | 89 +++++++------------ lm_engine/hf_models/tensor.py | 65 ++++++++++---- 4 files changed, 90 insertions(+), 75 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 3cf420889..c1995ac07 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -10,6 +10,7 @@ from ...cache import GenerationCache from ...config import CommonConfig from ...modeling_utils import get_mlp_block, get_normalization_function, get_sequence_mixer +from ...tensor import PackedTensor class Block(nn.Module): @@ -92,13 +93,16 @@ def _sequence_mixer_forward( hidden_states, cache_params=past_key_values, attention_mask=attention_mask ) elif self.sequence_mixer_type in ["gru", "rnn"]: - 
hidden_states = self.sequence_mixer( + hidden_states = PackedTensor.from_unpacked_tensor( hidden_states, - cache_params=past_key_values, - attention_mask=attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + batch_size=hidden_states.size(0) if cu_seqlens is None else None, ) + + hidden_states = self.sequence_mixer(hidden_states, cache_params=past_key_values) + + hidden_states = hidden_states.get_raw_data() else: raise ValueError(f"unexpected sequence_mixer_type ({self.sequence_mixer_type})") diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index f15066f56..1bf463fe6 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -59,7 +59,6 @@ def get_sequence_mixer( scaling_factor=block.scaling_factor, num_layers=config.num_layers, layer_idx=layer_idx, - use_padding_free_transformer=use_padding_free_transformer, ) elif sequence_mixer_type == "mamba2": return Mamba2( diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 00c520f05..a52faf5c7 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -15,9 +15,10 @@ from ....utils import divide_if_divisible, is_fma_available from ...cache import GenerationCache from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay +from ...tensor import PackedTensor from ..linear import ParameterizedLinear from ..normalization import get_normalization_function -from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence +from .utils import unpack_sequence if is_fma_available(): @@ -28,7 +29,7 @@ class RNN(nn.Module): def __init__( self, - input_size: int, + x_size: int, state_size: int, output_size: int, num_heads: int, @@ -41,17 +42,15 @@ def __init__( scaling_factor: float | None, num_layers: int, layer_idx: int, - use_padding_free_transformer: bool, ) -> RNN: super().__init__() - self.input_size = input_size + self.x_size = x_size self.state_size = state_size self.output_size = output_size self.num_heads = num_heads self.gradient_clipping = gradient_clipping self.layer_idx = layer_idx - self.use_padding_free_transformer = use_padding_free_transformer self.state_head_dim = divide_if_divisible(self.state_size, self.num_heads, "") std = initializer_range @@ -59,7 +58,7 @@ def __init__( std /= math.sqrt(m_width) self.state_weight_std = std - self.input_projection = ParameterizedLinear(self.input_size, 2 * self.state_size, bias=add_bias, std=std) + self.x_projection = ParameterizedLinear(self.x_size, 2 * self.state_size, bias=add_bias, std=std) self.state_weight = nn.Parameter(torch.empty(self.num_heads, self.state_head_dim, self.state_head_dim)) std = initializer_range / math.sqrt(2 * num_layers) @@ -72,72 +71,52 @@ def __init__( self.scaling_factor = scaling_factor self.reset_parameters() - mark_parameter_as_mup_learning_rate(self.input_projection.weight) + mark_parameter_as_mup_learning_rate(self.x_projection.weight) mark_parameter_as_mup_learning_rate(self.state_weight) mark_parameter_as_mup_learning_rate(self.output_projection.weight) mark_parameter_as_no_weight_decay(self.state_weight) - def forward( - self, - input: torch.Tensor, - cache_params: GenerationCache | None = None, - 
attention_mask: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - ) -> torch.Tensor: - if self.use_padding_free_transformer: - assert cache_params is None - assert attention_mask is None - else: - assert cu_seqlens is None - assert max_seqlen is None - - batch_size, sequence_length = input.size()[:2] - - if attention_mask is not None: - cu_seqlens, max_seqlen = compute_cu_seqlens_and_max_seqlen_from_attention_mask(attention_mask) - input = pack_sequence(inputs=input, cu_seqlens=cu_seqlens) + def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) -> PackedTensor: + state = None if cache_params is None else cache_params.get_cache(self.layer_idx) + T = x.get_num_tokens() - input_state = None if cache_params is None else cache_params.get_cache(self.layer_idx) + with x.safe_mode(): + x = self.x_projection(x) + x, gate = x.chunk(2, dim=-1) + x = x.view(T, self.num_heads, self.state_head_dim) - input = self.input_projection(input) - input, gate = input.chunk(2, dim=-1) - - input = input.view(*input.size()[:-1], self.num_heads, self.state_head_dim) + if self.scaling_factor != 1: + x = x * self.scaling_factor weight = self.state_weight - if self.scaling_factor != 1: - input = input * self.scaling_factor weight = weight * self.scaling_factor - input = rnn( - input=input, - weight=weight, - input_state=input_state, - gradient_clipping=self.gradient_clipping, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.rnn) else KernelBackend.torch, - ) - - if not self.use_padding_free_transformer and attention_mask is not None: - input = unpack_sequence( - inputs=input, cu_seqlens=cu_seqlens, output_shape=(batch_size, sequence_length, *input.size()[1:]) + x = x.with_new_data( + rnn( + input=x.get_raw_data(), + weight=weight, + input_state=state, + gradient_clipping=self.gradient_clipping, + cu_seqlens=x.get_cu_seqlens(), + max_seqlen=x.get_max_seqlen(), + kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.rnn) else KernelBackend.torch, ) + ) if cache_params is not None: - cache_params.update(state=input[:, -1], num_tokens_added=input.size(1), layer_idx=self.layer_idx) - - input = input.view(*input.size()[:-2], -1) - - input = input * F.silu(gate) - input = self.norm(input) + cache_params.update( + state=x.get_last_element_along_sequence(), num_tokens_added=x.size(1), layer_idx=self.layer_idx + ) - input = self.output_projection(input) + with x.safe_mode(): + x = x.view(T, -1) + x = x * F.silu(gate) + x = self.norm(x) + x = self.output_projection(x) - return input + return x @torch.no_grad() def reset_parameters(self) -> None: diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 78d61f0b6..38e40bc0f 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -15,16 +15,16 @@ class PackedTensor(torch.Tensor): def __new__( cls, - packed_tensor: torch.Tensor, + tensor: torch.Tensor, original_shape: tuple[int], assume_ragged: bool, cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, batch_size: int | None = None, ) -> PackedTensor: - self = torch.as_tensor(packed_tensor).as_subclass(cls) + self = torch.as_tensor(tensor).as_subclass(cls) - self._packed_tensor = packed_tensor + self._tensor = tensor self._original_shape = original_shape self._assume_ragged = assume_ragged self._cu_seqlens = cu_seqlens @@ -60,7 +60,7 @@ def from_unpacked_tensor( assert unpacked_tensor.size(1) == max_seqlen - 
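When no cu_seqlens is supplied, the batch is treated as rectangular (batch_size sequences of exactly max_seqlen tokens), so packed offsets can always be synthesised with an arange. A toy check of that equivalence (made-up sizes):

    import torch

    batch_size, max_seqlen, hidden = 4, 6, 8
    x = torch.randn(batch_size, max_seqlen, hidden)    # rectangular batch, no per-sequence lengths needed

    flat = x.flatten(0, 1)                             # (24, 8): the flat token view
    cu_seqlens = torch.arange(0, batch_size * max_seqlen + 1, max_seqlen)   # [0, 6, 12, 18, 24]

    # slicing the flat view at these offsets recovers each row of the batch
    for i in range(batch_size):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        assert torch.equal(flat[start:end], x[i])
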
packed_tensor = unpacked_tensor.flatten(0, 1) + packed_tensor = unpacked_tensor packed_tensor = PackedTensor( packed_tensor=packed_tensor, @@ -73,33 +73,66 @@ def from_unpacked_tensor( return packed_tensor + def get_num_tokens(self) -> int: + T = self.get_raw_data().size(0) + if not self.is_ragged_tensor(): + T *= self.get_raw_data().size(1) + + return T + + def with_new_data(self, tensor: torch.Tensor) -> PackedTensor: + return PackedTensor( + tensor=tensor, + original_shape=self._original_shape, + assume_ragged=self._assume_ragged, + cu_seqlens=self._cu_seqlens, + max_seqlen=self._max_seqlen, + batch_size=self._batch_size, + ) + def to_unpacked_tensor(self) -> torch.Tensor: - if self._assume_ragged: + if self.is_ragged_tensor(): unpacked_tensor = unpack_sequence( - inputs=self._packed_tensor, cu_seqlens=self._cu_seqlens, output_shape=self._original_shape + inputs=self.get_raw_data(), cu_seqlens=self._cu_seqlens, output_shape=self._original_shape ) else: - unpacked_tensor = self.view(self._batch_size, -1, *self.size()[1:]) + unpacked_tensor = self.get_raw_data() return unpacked_tensor + def get_raw_data(self) -> torch.Tensor: + return self._tensor + + def get_last_element_along_sequence(self) -> torch.Tensor: + output = self.get_raw_data() + + if self.is_ragged_tensor(): + output = output[self.get_cu_seqlens()[1:] - 1] + else: + output = output[:, -1] + + return output + def is_ragged_tensor(self) -> bool: return self._assume_ragged def get_batch_size(self) -> int: return self._batch_size - def get_max_seqlen(self) -> int: + def get_max_seqlen(self, return_none_allowed: bool = True) -> int: + if return_none_allowed and not self.is_ragged_tensor(): + return None + return self._max_seqlen - def get_cu_seqlens(self, force_compute: bool = False) -> torch.Tensor: - if force_compute: - if self._cu_seqlens is None: - self._cu_seqlens = torch.arange( - 0, self._batch_size * self._max_seqlen + 1, self._max_seqlen, device=self.device - ) - else: - raise NotImplementedError("code is not supposed to reach here") + def get_cu_seqlens(self, return_none_allowed: bool = False) -> torch.Tensor: + if return_none_allowed and not self.is_ragged_tensor(): + return None + + if self._cu_seqlens is None: + self._cu_seqlens = torch.arange( + 0, self._batch_size * self._max_seqlen + 1, self._max_seqlen, device=self.device + ) return self._cu_seqlens From 5232d26b87d94f6abfe112f09e764ba03dbbd818 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 14:24:35 -0700 Subject: [PATCH 022/177] add packed tensor Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/rnn.py | 8 +++++--- lm_engine/hf_models/tensor.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index a52faf5c7..2f7af417e 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -83,7 +83,7 @@ def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) with x.safe_mode(): x = self.x_projection(x) - x, gate = x.chunk(2, dim=-1) + x, g = x.chunk(2, dim=-1) x = x.view(T, self.num_heads, self.state_head_dim) if self.scaling_factor != 1: @@ -107,12 +107,14 @@ def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) if cache_params is not None: cache_params.update( - state=x.get_last_element_along_sequence(), num_tokens_added=x.size(1), 
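The per-sequence final state used for the generation cache can be read straight off the offsets. A toy illustration of the indexing (made-up sizes):

    import torch

    lengths = torch.tensor([3, 2, 4])
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)])  # [0, 3, 5, 9]

    packed = torch.arange(9 * 2, dtype=torch.float32).view(9, 2)   # 9 tokens, hidden size 2

    # ragged layout: the last token of each sequence sits one position before the next offset
    last_per_sequence = packed[cu_seqlens[1:] - 1]                  # rows 2, 4 and 8
    assert torch.equal(last_per_sequence, torch.stack([packed[2], packed[4], packed[8]]))

    # for a rectangular (batch, seqlen, hidden) tensor the same selection is simply x[:, -1]
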
layer_idx=self.layer_idx + state=x.get_last_element_along_sequence(), + num_tokens_added=x.get_cu_seqlens(False), + layer_idx=self.layer_idx, ) with x.safe_mode(): x = x.view(T, -1) - x = x * F.silu(gate) + x = x * F.silu(g) x = self.norm(x) x = self.output_projection(x) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 38e40bc0f..3c42a48b5 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -125,7 +125,7 @@ def get_max_seqlen(self, return_none_allowed: bool = True) -> int: return self._max_seqlen - def get_cu_seqlens(self, return_none_allowed: bool = False) -> torch.Tensor: + def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor: if return_none_allowed and not self.is_ragged_tensor(): return None From 0ba23bfac70891d1bd558604d65b1b22d20eca98 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 15:11:41 -0700 Subject: [PATCH 023/177] add packed tensor Signed-off-by: Mayank Mishra --- lm_engine/data/__init__.py | 1 - lm_engine/data/utils.py | 91 +++++------- tests/hf_models/single_gpu/gpt_base_test.py | 128 ---------------- .../multihead_latent_attention_test.py | 137 ------------------ tests/hf_models/single_gpu/typecheck_test.py | 29 ---- tests/training/dataloader_test.py | 1 - 6 files changed, 33 insertions(+), 354 deletions(-) delete mode 100644 tests/hf_models/single_gpu/multihead_latent_attention_test.py delete mode 100644 tests/hf_models/single_gpu/typecheck_test.py diff --git a/lm_engine/data/__init__.py b/lm_engine/data/__init__.py index 8b3b81a99..8cf38bf85 100644 --- a/lm_engine/data/__init__.py +++ b/lm_engine/data/__init__.py @@ -135,7 +135,6 @@ def get_finetuning_dataloader( use_output=use_output, loss_mask=args.training_parameters.loss_mask, eos_token_id=tokenizer.eos_token_id, - use_padding_free_transformer=args.model_args.use_padding_free_transformer, pad_to_multiple_of=ProcessGroupManager.get_tensor_parallel_world_size(), ), ) diff --git a/lm_engine/data/utils.py b/lm_engine/data/utils.py index f686acd32..63351c82b 100644 --- a/lm_engine/data/utils.py +++ b/lm_engine/data/utils.py @@ -16,7 +16,6 @@ def collate_fn( use_output: bool, loss_mask: LossMask, eos_token_id: int, - use_padding_free_transformer: bool, labels_mask_value: int = -100, pad_to_multiple_of: int = 1, device: torch.device = None, @@ -38,64 +37,40 @@ def collate_fn( device = torch.cuda.current_device() if device is None else device - if use_padding_free_transformer: - input_ids = inputs - attention_mask = None - - if loss_mask == LossMask.output_only: - labels = [ - [labels_mask_value] * (len(array_in) - len(array_out)) + array_out - for array_in, array_out in zip(inputs, outputs) - ] - elif loss_mask == LossMask.no_mask: - labels = inputs - else: - raise ValueError(f"unexpected loss_mask ({loss_mask})") - - tokens_to_add = 0 - if pad_to_multiple_of > 1: - total_tokens = sum([len(array) for array in input_ids]) - tokens_to_add = (math.ceil(total_tokens / pad_to_multiple_of) * pad_to_multiple_of) - total_tokens - - # we pad the last example in the batch on the right - # NOTE this can be done since the attention is causal - input_ids[-1].extend([eos_token_id] * tokens_to_add) - labels[-1].extend([labels_mask_value] * tokens_to_add) - - input_ids, position_ids, _, labels, cu_seqlens, max_seqlen = convert_padding_free_lists_to_tensors( - input_ids=input_ids, labels=labels, device=device - ) - - result = { - "input_ids": input_ids, - "position_ids": position_ids, - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, 
- } - if labels is not None: - result["labels"] = labels + input_ids = inputs + + if loss_mask == LossMask.output_only: + labels = [ + [labels_mask_value] * (len(array_in) - len(array_out)) + array_out + for array_in, array_out in zip(inputs, outputs) + ] + elif loss_mask == LossMask.no_mask: + labels = inputs else: - max_length = max(list(map(len, inputs))) - if pad_to_multiple_of > 1: - max_length = math.ceil(max_length / pad_to_multiple_of) * pad_to_multiple_of - - input_ids = [[eos_token_id] * (max_length - len(array)) + array for array in inputs] - attention_mask = [[0] * (max_length - len(array)) + [1] * len(array) for array in inputs] - - if outputs is not None: - if loss_mask == LossMask.output_only: - labels = [[labels_mask_value] * (max_length - len(array)) + array for array in outputs] - elif loss_mask == LossMask.no_mask: - labels = inputs - else: - raise ValueError(f"unexpected loss_mask ({loss_mask})") - - result = { - "input_ids": torch.tensor(input_ids, device=device), - "attention_mask": torch.tensor(attention_mask, device=device), - } - if labels is not None: - result["labels"] = torch.tensor(labels, device=device) + raise ValueError(f"unexpected loss_mask ({loss_mask})") + + tokens_to_add = 0 + if pad_to_multiple_of > 1: + total_tokens = sum([len(array) for array in input_ids]) + tokens_to_add = (math.ceil(total_tokens / pad_to_multiple_of) * pad_to_multiple_of) - total_tokens + + # we pad the last example in the batch on the right + # NOTE this can be done since the attention is causal + input_ids[-1].extend([eos_token_id] * tokens_to_add) + labels[-1].extend([labels_mask_value] * tokens_to_add) + + input_ids, position_ids, _, labels, cu_seqlens, max_seqlen = convert_padding_free_lists_to_tensors( + input_ids=input_ids, labels=labels, device=device + ) + + result = { + "input_ids": input_ids, + "position_ids": position_ids, + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + } + if labels is not None: + result["labels"] = labels return result diff --git a/tests/hf_models/single_gpu/gpt_base_test.py b/tests/hf_models/single_gpu/gpt_base_test.py index 579437443..540fdd929 100644 --- a/tests/hf_models/single_gpu/gpt_base_test.py +++ b/tests/hf_models/single_gpu/gpt_base_test.py @@ -18,52 +18,6 @@ class GPTBaseAttentionTest(TestCommons): - @parameterized.expand( - TestCommons.make_args_matrix( - [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] - ) - ) - def test_sdpa_padding_free_transformer_equivalence( - self, device: torch.device, position_embedding_type: str, dtype: torch.dtype - ) -> None: - self.skip_test_if_device_unavailable(device) - - set_seed(SEED) - - config = self.get_dense_test_config(position_embedding_type, num_layers=1) - - sdpa_model = self.from_config(config, dtype=dtype).to(device) - flash_model = self.from_config(config, dtype=dtype, use_padding_free_transformer=True).to(device) - - sdpa_model.eval() - flash_model.eval() - - flash_model.load_state_dict(sdpa_model.state_dict()) - - input_ids, attention_mask, labels = self.get_dummy_inputs(device) - sdpa_output = sdpa_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - attention_mask = attention_mask.to(torch.bool) - sdpa_logits = sdpa_output.logits - sdpa_logits = torch.cat([sdpa_logits[i, ex, :] for i, ex in enumerate(attention_mask)]) - sdpa_loss = sdpa_output.loss - - with enable_kernels([Kernel.flash_attention_2]): - input_ids, attention_mask, labels = self.get_dummy_inputs(device, return_list=True) - flash_output = 
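A worked example of the pad-to-multiple step kept in collate_fn above (toy token ids; the eos id, mask value and multiple are made up):

    import math

    eos_token_id, labels_mask_value, pad_to_multiple_of = 0, -100, 8

    input_ids = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]          # 9 tokens in total
    labels = [[-100, 6, 7], [-100, 9], [-100, 11, 12, 13]]

    total_tokens = sum(len(seq) for seq in input_ids)
    tokens_to_add = math.ceil(total_tokens / pad_to_multiple_of) * pad_to_multiple_of - total_tokens  # 16 - 9 = 7

    # only the last example is padded, on the right; causal attention never attends to these positions
    input_ids[-1].extend([eos_token_id] * tokens_to_add)
    labels[-1].extend([labels_mask_value] * tokens_to_add)

    assert sum(len(seq) for seq in input_ids) % pad_to_multiple_of == 0
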
flash_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - flash_logits = flash_output.logits - flash_loss = flash_output.loss - - self.assert_equal_tensors( - sdpa_logits, - flash_logits, - False, - rtol_float16=1e-3, - atol_float16=3e-4, - rtol_bfloat16=5e-3, - atol_bfloat16=5e-3, - ) - self.assert_equal_tensors(sdpa_loss, flash_loss, False, atol_float32=1.2e-4, rtol_float32=0) - @parameterized.expand( TestCommons.make_args_matrix( [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] @@ -106,88 +60,6 @@ def test_sdpa_flash_attention_equivalence( ) self.assert_equal_tensors(sdpa_loss, flash_loss, False, atol_float32=1.2e-4, rtol_float32=0) - @parameterized.expand( - TestCommons.make_args_matrix( - [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] - ) - ) - def test_padding_free_transformer_with_list_and_tensor( - self, device: torch.device, position_embedding_type: str, dtype: torch.dtype - ) -> None: - self.skip_test_if_device_unavailable(device) - - set_seed(SEED) - - config = self.get_dense_test_config(position_embedding_type, num_layers=1) - - model = self.from_config(config, dtype=dtype, use_padding_free_transformer=True).to(device) - model.eval() - - with enable_kernels([Kernel.flash_attention_2]): - input_ids, attention_mask, labels = self.get_dummy_inputs(device, return_list=True) - list_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - list_logits = list_output.logits - list_loss = list_output.loss - - seqlens = torch.tensor([0] + [len(i) for i in input_ids]) - cu_seqlens = seqlens.cumsum(dim=-1).to(device, torch.int32) - max_seqlen = seqlens.max().item() - position_ids = torch.tensor( - list(itertools.chain(*[list(range(len(i))) for i in input_ids])), device=device - ) - input_ids = torch.tensor(list(itertools.chain(*input_ids)), device=device) - labels = torch.tensor(list(itertools.chain(*labels)), device=device) - tensor_output = model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - tensor_logits = tensor_output.logits - tensor_loss = tensor_output.loss - - self.assert_equal_tensors(list_logits, tensor_logits, True) - self.assert_equal_tensors(list_loss, tensor_loss, True) - - @parameterized.expand( - TestCommons.make_args_matrix( - [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] - ) - ) - def test_sdpa_flash_enabled(self, device: torch.device, position_embedding_type: str, dtype: torch.dtype) -> None: - self.skip_test_if_device_unavailable(device) - - set_seed(SEED) - - config = self.get_dense_test_config(position_embedding_type, num_layers=1) - - model = self.from_config(config, dtype=dtype).to(device) - model.eval() - - input_ids, _, labels = self.get_dummy_inputs(device) - attention_mask = torch.ones_like(input_ids, dtype=torch.int, device=device) - - sdpa_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - sdpa_logits = sdpa_output.logits - sdpa_loss = sdpa_output.loss - - flash_output = model(input_ids=input_ids, labels=labels) - flash_logits = flash_output.logits - flash_loss = flash_output.loss - - self.assert_equal_tensors( - sdpa_logits, - flash_logits, - False, - rtol_float16=1e-3, - atol_float16=3e-4, - rtol_bfloat16=5e-3, - atol_bfloat16=5e-3, - ) - self.assert_equal_tensors(sdpa_loss, flash_loss, False, 
atol_float32=3.8e-4, rtol_float32=0) - @parameterized.expand( TestCommons.make_args_matrix( [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] diff --git a/tests/hf_models/single_gpu/multihead_latent_attention_test.py b/tests/hf_models/single_gpu/multihead_latent_attention_test.py deleted file mode 100644 index 4accd572b..000000000 --- a/tests/hf_models/single_gpu/multihead_latent_attention_test.py +++ /dev/null @@ -1,137 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -import torch -from parameterized import parameterized -from transformers import set_seed - -from lm_engine.enums import Kernel -from lm_engine.hf_models import GPTBaseConfig -from lm_engine.kernels import enable_kernels - -from ..test_common import TestCommons - - -SEED = 1234 - - -class MultiHeadLatentAttentionTest(TestCommons): - @parameterized.expand(TestCommons.make_args_matrix([torch.device("cuda")], [torch.float16, torch.bfloat16])) - def test_sdpa_padding_free_transformer_equivalence(self, device: torch.device, dtype: torch.dtype) -> None: - self.skip_test_if_device_unavailable(device) - - set_seed(SEED) - - config = self.get_dense_test_config(num_layers=1) - - sdpa_model = self.from_config(config, dtype=dtype).to(device) - flash_model = self.from_config(config, dtype=dtype, use_padding_free_transformer=True).to(device) - - sdpa_model.eval() - flash_model.eval() - - flash_model.load_state_dict(sdpa_model.state_dict()) - - input_ids, attention_mask, labels = self.get_dummy_inputs(device) - sdpa_output = sdpa_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - attention_mask = attention_mask.to(torch.bool) - sdpa_logits = sdpa_output.logits - sdpa_logits = torch.cat([sdpa_logits[i, ex, :] for i, ex in enumerate(attention_mask)]) - sdpa_loss = sdpa_output.loss - - with enable_kernels([Kernel.flash_attention_2]): - input_ids, attention_mask, labels = self.get_dummy_inputs(device, return_list=True) - flash_output = flash_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - flash_logits = flash_output.logits - flash_loss = flash_output.loss - - self.assert_equal_tensors( - sdpa_logits, - flash_logits, - False, - rtol_float16=1e-3, - atol_float16=3e-4, - rtol_bfloat16=5e-3, - atol_bfloat16=5e-3, - ) - self.assert_equal_tensors(sdpa_loss, flash_loss, False) - - @parameterized.expand(TestCommons.make_args_matrix([torch.device("cuda")], [torch.float16, torch.bfloat16])) - def test_sdpa_flash_attention_equivalence(self, device: torch.device, dtype: torch.dtype) -> None: - self.skip_test_if_device_unavailable(device) - - set_seed(SEED) - - input_ids, attention_mask, labels = self.get_dummy_inputs(device) - config = self.get_dense_test_config(num_layers=1) - - model = self.from_config(config, dtype=dtype).to(device) - model.eval() - - sdpa_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - sdpa_logits = sdpa_output.logits - sdpa_loss = sdpa_output.loss - - with enable_kernels([Kernel.flash_attention_2]): - flash_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - flash_logits = flash_output.logits - flash_loss = flash_output.loss - - # we don't care about what happens on masked values (they don't match btw) - sdpa_logits[attention_mask == 0] = 0 - flash_logits[attention_mask == 0] = 0 - - self.assert_equal_tensors( - sdpa_logits[attention_mask], - 
flash_logits[attention_mask], - False, - rtol_float16=1e-3, - atol_float16=3e-4, - rtol_bfloat16=5e-3, - atol_bfloat16=5e-3, - ) - self.assert_equal_tensors(sdpa_loss, flash_loss, False) - - @staticmethod - def get_dense_test_config( - num_layers: int = 8, - add_bias: bool = True, - activation_function: str = "gelu_pytorch_tanh", - normalization_function: str = "layernorm", - m_emb: float = None, - m_width: float = None, - m_residual: float = None, - attention_multiplier: float = None, - ) -> GPTBaseConfig: - return GPTBaseConfig( - vocab_size=2048, - max_position_embeddings=1024, - hidden_size=32, - num_layers=num_layers, - position_embedding_type="nope", - normalization_function=normalization_function, - tie_word_embeddings=False, - bos_token_id=0, - eos_token_id=1, - pad_token_id=2, - m_emb=m_emb, - m_width=m_width, - m_residual=m_residual, - sequence_mixer_blocks=[ - { - "sequence_mixer_type": "multihead_latent_attention", - "add_bias": add_bias, - "attention_multiplier": attention_multiplier, - "num_attention_heads": 4, - "query_compression_size": 12, - "key_value_compression_size": 8, - "head_dim": 8, - } - for _ in range(num_layers) - ], - mlp_blocks=[ - {"mlp_type": "MLP", "activation_function": activation_function, "add_bias": add_bias} - for _ in range(num_layers) - ], - ) diff --git a/tests/hf_models/single_gpu/typecheck_test.py b/tests/hf_models/single_gpu/typecheck_test.py deleted file mode 100644 index ded25bb6a..000000000 --- a/tests/hf_models/single_gpu/typecheck_test.py +++ /dev/null @@ -1,29 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -import torch -from parameterized import parameterized - -from lm_engine.enums import Kernel -from lm_engine.kernels import enable_kernels - -from ..test_common import TestCommons - - -class TypeCheckTest(TestCommons): - @parameterized.expand(TestCommons.make_args_matrix([torch.device("cuda")])) - def test_no_attention_mask_flash_attention(self, device: torch.device) -> None: - self.skip_test_if_device_unavailable(device) - - config = self.get_dense_test_config( - position_embedding_type="learned_absolute", num_layers=8, num_attention_heads=32 - ) - model = self.from_config(config, use_padding_free_transformer=True).to(device) - model.eval() - - input_ids, _, labels = self.get_dummy_inputs(device, return_list=True) - attention_mask = [[1] * len(i) for i in input_ids] - - with enable_kernels([Kernel.flash_attention_2]): - self.assertRaises(AssertionError, model, input_ids=input_ids, attention_mask=attention_mask, labels=labels) diff --git a/tests/training/dataloader_test.py b/tests/training/dataloader_test.py index a3884073f..b98c50ca7 100644 --- a/tests/training/dataloader_test.py +++ b/tests/training/dataloader_test.py @@ -53,7 +53,6 @@ def test_dataloader_has_correct_order(self) -> None: use_output=True, loss_mask=args.training_parameters.loss_mask, eos_token_id=tokenizer.eos_token_id, - use_padding_free_transformer=args.model_args.use_padding_free_transformer, device="cpu", ), ) From 21f7521f6469abc57491284560c4afbd5e657e17 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 17:20:39 -0700 Subject: [PATCH 024/177] add packed tensor Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/rnn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py 
b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 2f7af417e..2d0573cc4 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -29,7 +29,7 @@ class RNN(nn.Module): def __init__( self, - x_size: int, + input_size: int, state_size: int, output_size: int, num_heads: int, @@ -45,7 +45,7 @@ def __init__( ) -> RNN: super().__init__() - self.x_size = x_size + self.input_size = input_size self.state_size = state_size self.output_size = output_size self.num_heads = num_heads @@ -58,7 +58,7 @@ def __init__( std /= math.sqrt(m_width) self.state_weight_std = std - self.x_projection = ParameterizedLinear(self.x_size, 2 * self.state_size, bias=add_bias, std=std) + self.x_projection = ParameterizedLinear(self.input_size, 2 * self.state_size, bias=add_bias, std=std) self.state_weight = nn.Parameter(torch.empty(self.num_heads, self.state_head_dim, self.state_head_dim)) std = initializer_range / math.sqrt(2 * num_layers) From 0405210985705519cd97137ad99c922c700338f6 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 17:28:48 -0700 Subject: [PATCH 025/177] add packed tensor Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/rnn.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 2d0573cc4..7de4ebf89 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -18,7 +18,6 @@ from ...tensor import PackedTensor from ..linear import ParameterizedLinear from ..normalization import get_normalization_function -from .utils import unpack_sequence if is_fma_available(): @@ -58,7 +57,7 @@ def __init__( std /= math.sqrt(m_width) self.state_weight_std = std - self.x_projection = ParameterizedLinear(self.input_size, 2 * self.state_size, bias=add_bias, std=std) + self.input_projection = ParameterizedLinear(self.input_size, 2 * self.state_size, bias=add_bias, std=std) self.state_weight = nn.Parameter(torch.empty(self.num_heads, self.state_head_dim, self.state_head_dim)) std = initializer_range / math.sqrt(2 * num_layers) @@ -71,7 +70,7 @@ def __init__( self.scaling_factor = scaling_factor self.reset_parameters() - mark_parameter_as_mup_learning_rate(self.x_projection.weight) + mark_parameter_as_mup_learning_rate(self.input_projection.weight) mark_parameter_as_mup_learning_rate(self.state_weight) mark_parameter_as_mup_learning_rate(self.output_projection.weight) @@ -82,7 +81,7 @@ def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) T = x.get_num_tokens() with x.safe_mode(): - x = self.x_projection(x) + x = self.input_projection(x) x, g = x.chunk(2, dim=-1) x = x.view(T, self.num_heads, self.state_head_dim) From 156db5dd35dc7c3da334ac115c3875f5b053b08b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 17:29:04 -0700 Subject: [PATCH 026/177] add packed tensor Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/gru.py | 107 +++++++----------- 1 file changed, 40 insertions(+), 67 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index 5e519c4fd..925ff216a 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ 
b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -15,9 +15,9 @@ from ....utils import divide_if_divisible, is_fma_available from ...cache import GenerationCache from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay +from ...tensor import PackedTensor from ..linear import ParameterizedLinear from ..normalization import get_normalization_function -from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence if is_fma_available(): @@ -41,7 +41,6 @@ def __init__( scaling_factor: float | None, num_layers: int, layer_idx: int, - use_padding_free_transformer: bool, ) -> GRU: super().__init__() @@ -51,7 +50,6 @@ def __init__( self.num_heads = num_heads self.gradient_clipping = gradient_clipping self.layer_idx = layer_idx - self.use_padding_free_transformer = use_padding_free_transformer self.state_head_dim = divide_if_divisible(self.state_size, self.num_heads, "") std = initializer_range @@ -59,13 +57,7 @@ def __init__( std /= math.sqrt(m_width) self.state_weight_std = std - self.input_projection = ParameterizedLinear( - self.input_size, - 4 * self.state_size, - bias=add_bias, - std=std, - ) - + self.input_projection = ParameterizedLinear(self.input_size, 4 * self.state_size, bias=add_bias, std=std) self.state_weight = nn.Parameter(torch.empty(3 * self.num_heads, self.state_head_dim, self.state_head_dim)) std = initializer_range / math.sqrt(2 * num_layers) @@ -84,75 +76,56 @@ def __init__( mark_parameter_as_no_weight_decay(self.state_weight) - def forward( - self, - input: torch.Tensor, - cache_params: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - ) -> torch.Tensor: - if self.use_padding_free_transformer: - assert cache_params is None - assert attention_mask is None - else: - assert cu_seqlens is None - assert max_seqlen is None - - batch_size, sequence_length = input.size()[:2] - - if attention_mask is not None: - cu_seqlens, max_seqlen = compute_cu_seqlens_and_max_seqlen_from_attention_mask(attention_mask) - input = pack_sequence(inputs=input, cu_seqlens=cu_seqlens) - - input_state = None if cache_params is None else cache_params.get_cache(self.layer_idx) - - input = self.input_projection(input) - input, gate = input.split((3 * self.state_size, self.state_size), dim=-1) + def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) -> PackedTensor: + state = None if cache_params is None else cache_params.get_cache(self.layer_idx) + T = x.get_num_tokens() - weight = self.state_weight + with x.safe_mode(): + x = self.input_projection(x) + x, g = x.split((3 * self.state_size, self.state_size), dim=-1) + if self.scaling_factor != 1: + x = x * self.scaling_factor + + x, x_forget, x_reset = x.chunk(3, dim=-1) + x, x_forget, x_reset = [i.view(T, self.num_heads, self.state_head_dim) for i in (x, x_forget, x_reset)] + + weight = self.state_weight if self.scaling_factor != 1: - input = input * self.scaling_factor weight = weight * self.scaling_factor - input, forget_input, reset_input = input.chunk(3, dim=-1) weight, forget_weight, reset_weight = weight.chunk(3, dim=0) - input, forget_input, reset_input = [ - i.view(*input.size()[:-1], self.num_heads, self.state_head_dim) for i in (input, forget_input, reset_input) - ] - - input = gru( - input=input, - weight=weight, - forget_input=forget_input, - forget_weight=forget_weight, - reset_input=reset_input, - 
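The shape bookkeeping in the rewritten GRU forward, checked on toy sizes (made-up numbers; a plain nn.Linear stands in for ParameterizedLinear):

    import torch

    T, input_size, state_size, num_heads = 10, 16, 12, 3
    state_head_dim = state_size // num_heads                      # 4

    input_projection = torch.nn.Linear(input_size, 4 * state_size)
    y = input_projection(torch.randn(T, input_size))              # (10, 48)

    x, g = y.split((3 * state_size, state_size), dim=-1)          # (10, 36) and (10, 12)
    x, x_forget, x_reset = x.chunk(3, dim=-1)                     # three (10, 12) tensors
    x, x_forget, x_reset = [t.view(T, num_heads, state_head_dim) for t in (x, x_forget, x_reset)]

    assert x.shape == x_forget.shape == x_reset.shape == (T, num_heads, state_head_dim)
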
reset_weight=reset_weight, - input_state=input_state, - gradient_clipping=self.gradient_clipping, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.gru) else KernelBackend.torch, - ) - - if not self.use_padding_free_transformer and attention_mask is not None: - input = unpack_sequence( - inputs=input, cu_seqlens=cu_seqlens, output_shape=(batch_size, sequence_length, *input.size()[1:]) + x = x.with_new_data( + gru( + input=x, + weight=weight, + forget_input=x_forget, + forget_weight=forget_weight, + reset_input=x_reset, + reset_weight=reset_weight, + input_state=state, + gradient_clipping=self.gradient_clipping, + cu_seqlens=x.get_cu_seqlens(), + max_seqlen=x.get_max_seqlen(), + kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.gru) else KernelBackend.torch, ) + ) if cache_params is not None: - cache_params.update(state=input[:, -1], num_tokens_added=input.size(1), layer_idx=self.layer_idx) - - input = input.view(*input.size()[:-2], -1) - - input = input * F.silu(gate) - input = self.norm(input) + cache_params.update( + state=x.get_last_element_along_sequence(), + num_tokens_added=x.get_cu_seqlens(False), + layer_idx=self.layer_idx, + ) - input = self.output_projection(input) + with x.safe_mode(): + x = x.view(T, -1) + x = x * F.silu(g) + x = self.norm(x) + x = self.output_projection(x) - return input + return x @torch.no_grad() def reset_parameters(self) -> None: From bb6e1b94c0532f84baa66dfd0d66d845b5d9c7d6 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 17:31:47 -0700 Subject: [PATCH 027/177] add packed tensor Signed-off-by: Mayank Mishra --- lm_engine/hf_models/models/palm/layer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lm_engine/hf_models/models/palm/layer.py b/lm_engine/hf_models/models/palm/layer.py index 35dd9541a..16e8f7c93 100644 --- a/lm_engine/hf_models/models/palm/layer.py +++ b/lm_engine/hf_models/models/palm/layer.py @@ -24,9 +24,7 @@ def __init__( config.normalization_function, config.hidden_size, eps=config.layer_norm_epsilon ) self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) - self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx - ) + self.mlp_block = get_mlp_block(config, layer_idx=layer_idx) def forward( self, From 6e688e5afd841f064f685d10bbacdf66b5b1e918 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 17:34:11 -0700 Subject: [PATCH 028/177] add packed tensor Signed-off-by: Mayank Mishra --- tests/hf_models/test_common.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tests/hf_models/test_common.py b/tests/hf_models/test_common.py index 0bbbcb924..6becb41e2 100644 --- a/tests/hf_models/test_common.py +++ b/tests/hf_models/test_common.py @@ -269,19 +269,8 @@ def compare_saved_models(path1: str, path2: str) -> bool: return False def from_config(self, config: AutoConfig, **kwargs) -> AutoModelForCausalLM: - use_padding_free_transformer = kwargs.pop("use_padding_free_transformer", False) - - model = AutoModelForCausalLM.from_config( - config, - use_padding_free_transformer=use_padding_free_transformer, - dtype=kwargs.pop("dtype", None), - ) - - if use_padding_free_transformer: - assert model.use_padding_free_transformer - + model = AutoModelForCausalLM.from_config(config, dtype=kwargs.pop("dtype", None)) assert len(kwargs) == 0 - return model def assert_equal_tensors( From 
84cc4e1233ef2a19f37b4d5337982f2e2080255d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 18:22:39 -0700 Subject: [PATCH 029/177] drop SBA temporarily Signed-off-by: Mayank Mishra --- b.py | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 b.py diff --git a/b.py b/b.py deleted file mode 100644 index b3d274cdc..000000000 --- a/b.py +++ /dev/null @@ -1,17 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import torch - -from lm_engine.hf_models.tensor import PackedTensor - - -y = torch.randn(5, 4, requires_grad=True) -# Example usage -x = PackedTensor.from_unpacked_tensor(y, batch_size=5) -x.sum().backward() - -print(x) From 796b1da1ffc84b5622b6fe53a1a5a08c303318ae Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 18:28:36 -0700 Subject: [PATCH 030/177] drop SBA temporarily Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/attention.py | 6 +-- .../multihead_latent_attention.py | 6 +-- .../utils/flash_attention_utils.py | 52 +++++++++---------- .../sequence_mixer_blocks/attention.py | 6 +-- .../gpt_crosslayer/sequence_mixers/base.py | 12 ++--- 5 files changed, 41 insertions(+), 41 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index ef74c4637..898ac18e5 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -194,9 +194,9 @@ def forward( value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False) hidden_states = flash_attention( - query=query, - key=key, - value=value, + q=query, + k=key, + v=value, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py index 9bec13d46..6afba02bb 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py @@ -170,9 +170,9 @@ def forward( value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False) hidden_states = flash_attention( - query=query, - key=key, - value=value, + q=query, + k=key, + v=value, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py index a04353d13..2ab1b1f6d 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py @@ -52,9 +52,9 @@ def unpad_input( def flash_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, attention_mask: torch.Tensor | None, cu_seqlens: torch.Tensor | None, max_seqlen: int | None, @@ -78,14 +78,14 @@ def flash_attention( if cu_seqlens is None: assert max_seqlen is None - assert query.dim() == 4 + assert q.dim() == 4 if attention_mask is None: if use_flash_attention_3: attn_output, _ = 
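The two calling conventions the wrapper selects between can be related with plain torch: padded inputs are (batch, seq, heads, head_dim) plus an attention mask, varlen inputs are (total_tokens, heads, head_dim) plus cu_seqlens. A sketch of the unpadding step only (toy shapes, no flash-attn import):

    import torch
    import torch.nn.functional as F

    batch_size, seq_len, num_heads, head_dim = 2, 5, 4, 8
    q = torch.randn(batch_size, seq_len, num_heads, head_dim)
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]])             # lengths 3 and 5

    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)      # [3, 5]
    cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))   # [0, 3, 8]
    max_seqlen = int(seqlens.max())

    indices = attention_mask.flatten().nonzero(as_tuple=True)[0]
    q_varlen = q.flatten(0, 1)[indices]                          # (8, num_heads, head_dim)

    assert q_varlen.shape == (int(cu_seqlens[-1]), num_heads, head_dim)
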
flash_attention_3( - q=query, - k=key, - v=value, + q=q, + k=k, + v=v, softmax_scale=softmax_scale, causal=causal, window_size=window_size, @@ -93,9 +93,9 @@ def flash_attention( ) else: attn_output = flash_attention_2( - q=query, - k=key, - v=value, + q=q, + k=k, + v=v, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, @@ -103,17 +103,17 @@ def flash_attention( softcap=softcap, ) else: - batch_size, query_length, num_heads, head_dim = query.size() + batch_size, query_length, num_heads, head_dim = q.size() - query, key, value, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = unpad_input( - query, key, value, attention_mask, query_length + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = unpad_input( + q, k, v, attention_mask, query_length ) if use_flash_attention_3: attn_output, _ = flash_attention_3_varlen( - q=query, - k=key, - v=value, + q=q, + k=k, + v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, @@ -125,9 +125,9 @@ def flash_attention( ) else: attn_output = flash_attention_2_varlen( - q=query, - k=key, - v=value, + q=q, + k=k, + v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, @@ -146,13 +146,13 @@ def flash_attention( ) else: assert sliding_window is None - assert query.dim() == 3 + assert q.dim() == 3 if use_flash_attention_3: attn_output, _ = flash_attention_3_varlen( - q=query, - k=key, - v=value, + q=q, + k=k, + v=v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, @@ -162,9 +162,9 @@ def flash_attention( ) else: attn_output = flash_attention_2_varlen( - q=query, - k=key, - v=value, + q=q, + k=k, + v=v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, diff --git a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py index 131d02d42..27b4b887e 100644 --- a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py @@ -189,9 +189,9 @@ def forward( value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False) hidden_states = flash_attention( - query=query, - key=key, - value=value, + q=query, + k=key, + v=value, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask, diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py index df65ced1d..3bc88a560 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py @@ -97,9 +97,9 @@ def forward( query = apply_rotary_pos_emb(query, rope_cos_sin) hidden_states = flash_attention( - query=query, - key=key, - value=value, + q=query, + k=key, + v=value, attention_mask=attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -124,9 +124,9 @@ def forward( query = query.transpose(1, 2) hidden_states = flash_attention( - query=query, - key=key, - value=value, + q=query, + k=key, + v=value, attention_mask=attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, From 766bcded9033649e32a4ed630549d6e5e132ac76 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 19:23:34 -0700 Subject: [PATCH 031/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/model_wrapper/__init__.py | 1 - lm_engine/model_wrapper/base.py | 9 +- 
lm_engine/model_wrapper/distillation.py | 7 +- lm_engine/model_wrapper/finetuning.py | 6 +- lm_engine/model_wrapper/pretraining.py | 114 +++++++++++------------- 5 files changed, 59 insertions(+), 78 deletions(-) diff --git a/lm_engine/model_wrapper/__init__.py b/lm_engine/model_wrapper/__init__.py index 12615bdde..57601f803 100644 --- a/lm_engine/model_wrapper/__init__.py +++ b/lm_engine/model_wrapper/__init__.py @@ -39,7 +39,6 @@ def get_model_container( "model_class": args.model_args.model_class, "dtype": args.mixed_precision_args.dtype, "efficient_initialization": efficient_initialization, - "use_padding_free_transformer": args.model_args.use_padding_free_transformer, "sequence_parallel": args.distributed_args.sequence_parallel, "num_pipeline_stages": num_pipeline_stages, "trust_remote_code": args.model_args.trust_remote_code, diff --git a/lm_engine/model_wrapper/base.py b/lm_engine/model_wrapper/base.py index bfb21ffac..6f80d6315 100644 --- a/lm_engine/model_wrapper/base.py +++ b/lm_engine/model_wrapper/base.py @@ -28,7 +28,6 @@ def __init__( model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM, dtype: torch.dtype, efficient_initialization: bool, - use_padding_free_transformer: bool, sequence_parallel: bool, num_pipeline_stages: int, pipeline_stage_id: int, @@ -45,7 +44,6 @@ def __init__( model_class (AutoModelForCausalLM | AutoModelForSeq2SeqLM): HF model class to use for model loading dtype (torch.dtype): dtype for the model efficient_initialization (bool): whether to use efficient initialization for the model initialization, saves CPU memory - use_padding_free_transformer (bool): whether to use padding free transformer sequence_parallel (bool): whether to use sequence parallel num_pipeline_stages (int): number of stages for the pipeline pipeline_stage_id (int): current pipeline stage id @@ -62,7 +60,6 @@ def __init__( self.model_class = model_class self.efficient_initialization = efficient_initialization self.dtype = dtype - self.use_padding_free_transformer = use_padding_free_transformer self.sequence_parallel = sequence_parallel self.tokenizer_name = self.model_name if tokenizer_name is None else tokenizer_name self.trust_remote_code = trust_remote_code @@ -88,9 +85,6 @@ def __init__( self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() self.model_class = get_model_parallel_class(self.config.model_type) - if self.use_padding_free_transformer: - assert self.is_custom_model, "padding free transformer is not supported with the specified model" - self._setup_tokenizer() self._setup_model() @@ -148,8 +142,7 @@ def _get_model_kwargs(self) -> dict: "flash_attention_2" if is_kernel_allowed(Kernel.flash_attention_2) else "sdpa" ) - if self.use_padding_free_transformer: - model_kwargs["use_padding_free_transformer"] = True + model_kwargs["use_padding_free_transformer"] = True if self.sequence_parallel: model_kwargs["sequence_parallel"] = True if self.trust_remote_code: diff --git a/lm_engine/model_wrapper/distillation.py b/lm_engine/model_wrapper/distillation.py index 8394b1477..14e52d5a4 100644 --- a/lm_engine/model_wrapper/distillation.py +++ b/lm_engine/model_wrapper/distillation.py @@ -31,7 +31,6 @@ def __init__( model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM, dtype: torch.dtype, efficient_initialization: bool, - use_padding_free_transformer: bool, sequence_parallel: bool, micro_batch_size: int, sequence_length: int, @@ -57,7 +56,6 @@ def __init__( model_class (AutoModelForCausalLM | AutoModelForSeq2SeqLM): HF model class to use for model loading dtype 
(torch.dtype): dtype for the model efficient_initialization (bool): whether to use efficient initialization for the model initialization, saves CPU memory - use_padding_free_transformer (bool): whether to use padding free transformer sequence_parallel (bool): whether to use sequence parallel micro_batch_size (int): micro batch size for pretraining sequence_length (int): sequence length for pretraining @@ -83,7 +81,6 @@ def __init__( model_class=model_class, dtype=dtype, efficient_initialization=efficient_initialization, - use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, micro_batch_size=micro_batch_size, sequence_length=sequence_length, @@ -129,6 +126,8 @@ def forward( student_logits = output.logits del output + assert False + # TODO modify this when TP support is added lm_loss = get_autoregressive_language_modeling_loss( lm_logits=student_logits, @@ -136,7 +135,7 @@ def forward( hidden_states=None, vocab_weight=None, cu_seqlens=None, - use_padding_free_transformer=self.use_padding_free_transformer, + use_padding_free_transformer=True, reduction="sum", shift_logits_and_labels=False, tensor_parallel_enabled=False, diff --git a/lm_engine/model_wrapper/finetuning.py b/lm_engine/model_wrapper/finetuning.py index fdcdb4f64..86e2ba289 100644 --- a/lm_engine/model_wrapper/finetuning.py +++ b/lm_engine/model_wrapper/finetuning.py @@ -54,13 +54,15 @@ def get_loss( tensor_parallel_enabled = ProcessGroupManager.is_tensor_parallel_enabled() use_fused_linear_cross_entropy_kernel = is_kernel_allowed(Kernel.fused_linear_cross_entropy) + assert False + lm_loss = get_autoregressive_language_modeling_loss( lm_logits=None if use_fused_linear_cross_entropy_kernel else model_outputs.logits, labels=labels, hidden_states=model_outputs.last_hidden_state if use_fused_linear_cross_entropy_kernel else None, vocab_weight=self.model.get_output_embeddings().weight if use_fused_linear_cross_entropy_kernel else None, cu_seqlens=cu_seqlens, - use_padding_free_transformer=self.use_padding_free_transformer, + use_padding_free_transformer=True, reduction="sum", shift_logits_and_labels=True, tensor_parallel_enabled=tensor_parallel_enabled, @@ -88,7 +90,7 @@ def _broadcast_inputs_for_tensor_parallel(self, batch: dict) -> dict: tp_source_rank = ProcessGroupManager.get_tensor_parallel_first_rank() tp_group = ProcessGroupManager.get_tensor_parallel_group() - if self.use_padding_free_transformer: + if self.is_custom_model: keys = ["input_ids", "position_ids", "labels", "cu_seqlens", "max_seqlen"] if is_tp_first_rank: diff --git a/lm_engine/model_wrapper/pretraining.py b/lm_engine/model_wrapper/pretraining.py index 65658129a..f86c4aa8e 100644 --- a/lm_engine/model_wrapper/pretraining.py +++ b/lm_engine/model_wrapper/pretraining.py @@ -31,7 +31,6 @@ def __init__( model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM, dtype: torch.dtype, efficient_initialization: bool, - use_padding_free_transformer: bool, sequence_parallel: bool, micro_batch_size: int, sequence_length: int, @@ -52,7 +51,6 @@ def __init__( model_class (AutoModelForCausalLM | AutoModelForSeq2SeqLM): HF model class to use for model loading dtype (torch.dtype): dtype for the model efficient_initialization (bool): whether to use efficient initialization for the model initialization, saves CPU memory - use_padding_free_transformer (bool): whether to use padding free transformer sequence_parallel (bool): whether to use sequence parallel micro_batch_size (int): micro batch size for pretraining sequence_length (int): 
sequence length for pretraining @@ -77,7 +75,6 @@ def __init__( model_class=model_class, dtype=dtype, efficient_initialization=efficient_initialization, - use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, num_pipeline_stages=num_pipeline_stages, pipeline_stage_id=pipeline_stage_id, @@ -162,7 +159,7 @@ def get_loss( hidden_states=model_outputs.last_hidden_state if use_fused_linear_cross_entropy_kernel else None, vocab_weight=self.model.get_output_embeddings().weight if use_fused_linear_cross_entropy_kernel else None, cu_seqlens=None, - use_padding_free_transformer=self.use_padding_free_transformer, + use_padding_free_transformer=True, reduction="sum", shift_logits_and_labels=False, tensor_parallel_enabled=tensor_parallel_enabled, @@ -225,38 +222,37 @@ def _prepare_model_inputs(self, batch: dict) -> dict: input_ids = tokens[:, :-1] batch = {"labels": tokens[:, 1:]} - if self.use_padding_free_transformer: - batch_size, sequence_length = input_ids.shape - input_ids = input_ids.reshape(-1) - - if self.reset_attention_mask: - num_tokens_in_batch = batch_size * sequence_length - - document_end_positions = input_ids == self.eos_token_id - for i in range(sequence_length - 1, num_tokens_in_batch, sequence_length): - document_end_positions[i] = 1 - cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 - cu_seqlens = torch.cat([torch.tensor([0], device=input_ids.device), cu_seqlens]) - cu_seqlens = cu_seqlens.to(torch.int32) - - seqlen = cu_seqlens[1:] - cu_seqlens[:-1] - # we move to CPU here otherwise FlashAttention will move to CPU on every invocation i.e all layers - max_seqlen = seqlen.max().item() - - if self.reset_position_ids: - position_ids = torch.cat( - [torch.arange(0, i, 1, dtype=torch.int32, device=input_ids.device) for i in seqlen] - ) - else: - position_ids = self.position_ids + batch_size, sequence_length = input_ids.shape + input_ids = input_ids.reshape(-1) + + if self.reset_attention_mask: + num_tokens_in_batch = batch_size * sequence_length + + document_end_positions = input_ids == self.eos_token_id + for i in range(sequence_length - 1, num_tokens_in_batch, sequence_length): + document_end_positions[i] = 1 + cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 + cu_seqlens = torch.cat([torch.tensor([0], device=input_ids.device), cu_seqlens]) + cu_seqlens = cu_seqlens.to(torch.int32) + + seqlen = cu_seqlens[1:] - cu_seqlens[:-1] + # we move to CPU here otherwise FlashAttention will move to CPU on every invocation i.e all layers + max_seqlen = seqlen.max().item() + + if self.reset_position_ids: + position_ids = torch.cat( + [torch.arange(0, i, 1, dtype=torch.int32, device=input_ids.device) for i in seqlen] + ) else: - cu_seqlens = self.cu_seqlens - max_seqlen = self.sequence_length position_ids = self.position_ids + else: + cu_seqlens = self.cu_seqlens + max_seqlen = self.sequence_length + position_ids = self.position_ids - batch["cu_seqlens"] = cu_seqlens - batch["max_seqlen"] = max_seqlen - batch["position_ids"] = position_ids + batch["cu_seqlens"] = cu_seqlens + batch["max_seqlen"] = max_seqlen + batch["position_ids"] = position_ids batch["input_ids"] = input_ids @@ -270,37 +266,29 @@ def _setup_model(self) -> None: self.reset_parameters() def reset_parameters(self) -> None: - if self.use_padding_free_transformer: - if not self.reset_attention_mask: - self.register_buffer( - "cu_seqlens", - torch.arange( - 0, - self.micro_batch_size * self.sequence_length + 1, - self.sequence_length, - dtype=torch.int32, - 
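How the reset_attention_mask branch above turns eos positions into packed offsets, on a toy batch (eos id 0, two rows of four tokens; real inputs come from the data loader):

    import torch

    eos_token_id, sequence_length = 0, 4
    tokens = torch.tensor([[7, 0, 3, 9],      # eos inside the row: two documents
                           [5, 6, 7, 8]])     # no eos: the whole row is one document
    batch_size = tokens.size(0)

    input_ids = tokens.flatten()
    document_end_positions = input_ids == eos_token_id
    # the end of every row is always treated as a document end
    for i in range(sequence_length - 1, batch_size * sequence_length, sequence_length):
        document_end_positions[i] = True

    cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), cu_seqlens])   # [0, 2, 4, 8]

    seqlen = cu_seqlens[1:] - cu_seqlens[:-1]                                # [2, 2, 4]
    position_ids = torch.cat([torch.arange(0, int(n)) for n in seqlen])      # restart at each boundary
    assert position_ids.tolist() == [0, 1, 0, 1, 0, 1, 2, 3]
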
device=torch.cuda.current_device(), - ), - persistent=False, - ) - - if self.reset_position_ids: - assert self.reset_attention_mask, "reset_attention_mask should be specified with reset_position_ids" - else: - self.register_buffer( - "position_ids", - torch.arange(0, self.sequence_length, 1, device=torch.cuda.current_device()).repeat( - self.micro_batch_size - ), - persistent=False, - ) + if not self.reset_attention_mask: + self.register_buffer( + "cu_seqlens", + torch.arange( + 0, + self.micro_batch_size * self.sequence_length + 1, + self.sequence_length, + dtype=torch.int32, + device=torch.cuda.current_device(), + ), + persistent=False, + ) + + if self.reset_position_ids: + assert self.reset_attention_mask, "reset_attention_mask should be specified with reset_position_ids" else: - assert ( - not self.reset_attention_mask - ), "currently reset_attention_mask is only implemented for padding free transformer" - assert ( - not self.reset_position_ids - ), "currently reset_position_ids is only implemented for padding free transformer" + self.register_buffer( + "position_ids", + torch.arange(0, self.sequence_length, 1, device=torch.cuda.current_device()).repeat( + self.micro_batch_size + ), + persistent=False, + ) class _F(torch.autograd.Function): From 3dd0bd41b5cf95bd2eea13e8cedcd11736d39a49 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 28 Sep 2025 19:31:04 -0700 Subject: [PATCH 032/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/__init__.py | 1 + lm_engine/model_wrapper/pretraining.py | 35 +++++++++++++------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py index 72a0e2133..c30f5b462 100644 --- a/lm_engine/hf_models/__init__.py +++ b/lm_engine/hf_models/__init__.py @@ -29,6 +29,7 @@ mark_parameter_as_no_weight_decay, ) from .register_hf import get_model_parallel_class, is_custom_model, register_model_classes +from .tensor import PackedTensor from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts from .utils import convert_padding_free_lists_to_tensors, disable_generation_cache diff --git a/lm_engine/model_wrapper/pretraining.py b/lm_engine/model_wrapper/pretraining.py index f86c4aa8e..fb72c175d 100644 --- a/lm_engine/model_wrapper/pretraining.py +++ b/lm_engine/model_wrapper/pretraining.py @@ -12,6 +12,7 @@ from ..enums import Kernel from ..hf_models import ( CausalLMOutputWithPast, + PackedTensor, PipelineParallelInput, PipelineParallelOutput, get_autoregressive_language_modeling_loss, @@ -246,15 +247,15 @@ def _prepare_model_inputs(self, batch: dict) -> dict: else: position_ids = self.position_ids else: - cu_seqlens = self.cu_seqlens + cu_seqlens = None max_seqlen = self.sequence_length position_ids = self.position_ids - batch["cu_seqlens"] = cu_seqlens - batch["max_seqlen"] = max_seqlen - batch["position_ids"] = position_ids + batch["input_ids"] = PackedTensor.from_unpacked_tensor( + unpacked_tensor=input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) - batch["input_ids"] = input_ids + batch["position_ids"] = position_ids if ProcessGroupManager.is_tensor_parallel_enabled(): batch["output_parallel_lm_logits"] = True @@ -266,18 +267,18 @@ def _setup_model(self) -> None: self.reset_parameters() def reset_parameters(self) -> None: - if not self.reset_attention_mask: - self.register_buffer( - "cu_seqlens", - torch.arange( - 0, - self.micro_batch_size * self.sequence_length + 1, - self.sequence_length, - dtype=torch.int32, - 
device=torch.cuda.current_device(), - ), - persistent=False, - ) + # if not self.reset_attention_mask: + # self.register_buffer( + # "cu_seqlens", + # torch.arange( + # 0, + # self.micro_batch_size * self.sequence_length + 1, + # self.sequence_length, + # dtype=torch.int32, + # device=torch.cuda.current_device(), + # ), + # persistent=False, + # ) if self.reset_position_ids: assert self.reset_attention_mask, "reset_attention_mask should be specified with reset_position_ids" From 1e19a5ca7da7c7fbc847578b4fe2c83c45089794 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:08:53 -0700 Subject: [PATCH 033/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 2 +- lm_engine/hf_models/tensor.py | 37 ++++++++++++----------- lm_engine/model_wrapper/pretraining.py | 2 +- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index c1995ac07..7baf014df 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -93,7 +93,7 @@ def _sequence_mixer_forward( hidden_states, cache_params=past_key_values, attention_mask=attention_mask ) elif self.sequence_mixer_type in ["gru", "rnn"]: - hidden_states = PackedTensor.from_unpacked_tensor( + hidden_states = PackedTensor.from_torch_tensor( hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 3c42a48b5..7397bed47 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -34,11 +34,12 @@ def __new__( return self @staticmethod - def from_unpacked_tensor( - unpacked_tensor: torch.Tensor, + def from_torch_tensor( + tensor: torch.Tensor, cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, batch_size: int | None = None, + is_packed: bool = False, ) -> PackedTensor: assert batch_size is not None or cu_seqlens is not None assume_ragged = cu_seqlens is not None @@ -51,20 +52,20 @@ def from_unpacked_tensor( assert cu_seqlens.size(0) - 1 == batch_size - packed_tensor = pack_sequence(inputs=unpacked_tensor, cu_seqlens=cu_seqlens) + packed_tensor = pack_sequence(inputs=tensor, cu_seqlens=cu_seqlens) if is_packed else tensor else: - assert unpacked_tensor.size(0) % batch_size == 0 + assert tensor.size(0) == batch_size if max_seqlen is None: - max_seqlen = unpacked_tensor.size(1) + max_seqlen = tensor.size(1) - assert unpacked_tensor.size(1) == max_seqlen + assert tensor.size(1) == max_seqlen - packed_tensor = unpacked_tensor + packed_tensor = tensor packed_tensor = PackedTensor( packed_tensor=packed_tensor, - original_shape=unpacked_tensor.size(), + original_shape=tensor.size(), assume_ragged=assume_ragged, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -73,6 +74,16 @@ def from_unpacked_tensor( return packed_tensor + def to_torch_tensor(self) -> torch.Tensor: + if self.is_ragged_tensor(): + tensor = unpack_sequence( + inputs=self.get_raw_data(), cu_seqlens=self._cu_seqlens, output_shape=self._original_shape + ) + else: + tensor = self.get_raw_data() + + return tensor + def get_num_tokens(self) -> int: T = self.get_raw_data().size(0) if not self.is_ragged_tensor(): @@ -90,16 +101,6 @@ def with_new_data(self, tensor: torch.Tensor) -> PackedTensor: batch_size=self._batch_size, ) - def to_unpacked_tensor(self) -> torch.Tensor: - if self.is_ragged_tensor(): - unpacked_tensor = unpack_sequence( - inputs=self.get_raw_data(), 
cu_seqlens=self._cu_seqlens, output_shape=self._original_shape - ) - else: - unpacked_tensor = self.get_raw_data() - - return unpacked_tensor - def get_raw_data(self) -> torch.Tensor: return self._tensor diff --git a/lm_engine/model_wrapper/pretraining.py b/lm_engine/model_wrapper/pretraining.py index fb72c175d..2d3941035 100644 --- a/lm_engine/model_wrapper/pretraining.py +++ b/lm_engine/model_wrapper/pretraining.py @@ -251,7 +251,7 @@ def _prepare_model_inputs(self, batch: dict) -> dict: max_seqlen = self.sequence_length position_ids = self.position_ids - batch["input_ids"] = PackedTensor.from_unpacked_tensor( + batch["input_ids"] = PackedTensor.from_torch_tensor( unpacked_tensor=input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen ) From fa6345289ac782bf658f4ff09cdb0e40e706fc63 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:14:27 -0700 Subject: [PATCH 034/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 2 +- lm_engine/model_wrapper/pretraining.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 7397bed47..e136d4b9c 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -64,7 +64,7 @@ def from_torch_tensor( packed_tensor = tensor packed_tensor = PackedTensor( - packed_tensor=packed_tensor, + tensor=packed_tensor, original_shape=tensor.size(), assume_ragged=assume_ragged, cu_seqlens=cu_seqlens, diff --git a/lm_engine/model_wrapper/pretraining.py b/lm_engine/model_wrapper/pretraining.py index 2d3941035..86b381f43 100644 --- a/lm_engine/model_wrapper/pretraining.py +++ b/lm_engine/model_wrapper/pretraining.py @@ -224,9 +224,9 @@ def _prepare_model_inputs(self, batch: dict) -> dict: batch = {"labels": tokens[:, 1:]} batch_size, sequence_length = input_ids.shape - input_ids = input_ids.reshape(-1) if self.reset_attention_mask: + input_ids = input_ids.flatten() num_tokens_in_batch = batch_size * sequence_length document_end_positions = input_ids == self.eos_token_id @@ -252,7 +252,7 @@ def _prepare_model_inputs(self, batch: dict) -> dict: position_ids = self.position_ids batch["input_ids"] = PackedTensor.from_torch_tensor( - unpacked_tensor=input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + tensor=input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, batch_size=batch_size ) batch["position_ids"] = position_ids From c80e8736b96de781b9ae032e97363c198f08712c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:18:57 -0700 Subject: [PATCH 035/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 39 ------------------------ lm_engine/hf_models/mixins/dense/main.py | 25 ++++++--------- 2 files changed, 9 insertions(+), 55 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 3be5b9568..d75b67963 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -46,45 +46,6 @@ def _init_weights(self, module: nn.Module) -> None: if hasattr(module, "reset_parameters"): module.reset_parameters() - # FIXME typing - def prepare_inputs_for_model( - self, - input_ids: torch.Tensor | list[list[int]] | None, - position_ids: torch.Tensor | list[list[int]] | None, - labels: torch.Tensor | list[list[int]] | None, - cu_seqlens: torch.Tensor | None, - max_seqlen: int | None, - past_key_values: tuple[tuple[torch.Tensor]], - attention_mask: torch.Tensor | 
None, - use_cache: bool, - ) -> tuple[torch.Tensor]: - if self.use_padding_free_transformer: - if isinstance(input_ids, list): - # this is managed internally - error_message = ( - "{variable} should not be passed for flash attention when using List[List[int]] " - "input types attention mask logic is handled internally" - ) - assert cu_seqlens is None, error_message.format(variable="cu_seqlens") - assert max_seqlen is None, error_message.format(variable="max_seqlen") - assert attention_mask is None, error_message.format(variable="attention_mask") - - input_ids, position_ids, labels, cu_seqlens, max_seqlen = convert_padding_free_lists_to_tensors( - input_ids=input_ids, position_ids=position_ids, labels=labels, device=torch.cuda.current_device() - ) - else: - assert ( - cu_seqlens is not None - ), "cu_seqlens needs to be specified when using tensor inputs with padding_free transformer" - assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert max_seqlen is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" - - if use_cache or past_key_values is not None: - raise NotImplementedError("KV caching is not supported with padding_free transformer") - - return input_ids, position_ids, labels, cu_seqlens, max_seqlen - class BaseModelMixin(PreTrainedModelMixin): mask_value = None diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index ad4883430..6d4510e99 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -14,6 +14,7 @@ from ...config import CommonConfig from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear +from ...tensor import PackedTensor from ..modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .base import PreTrainedModelMixin @@ -56,31 +57,23 @@ def set_output_embeddings(self, new_embeddings: ParameterizedLinear) -> None: def forward( self, - input_ids: torch.Tensor | list[list[int]] | None = None, + input_ids: PackedTensor | torch.Tensor | None = None, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | list[list[int]] | None = None, - inputs_embeds: torch.Tensor | list[list[float]] | None = None, - labels: torch.Tensor | list[list[int]] | None = None, + position_ids: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + labels: torch.Tensor | None = None, use_cache: bool | None = None, return_dict: bool = True, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, reduction: str = "mean", ) -> CausalLMOutputWithPast: assert return_dict assert inputs_embeds is None + assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" + assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" - input_ids, position_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( - input_ids=input_ids, - position_ids=position_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - attention_mask=attention_mask, - use_cache=use_cache, - ) + if use_cache or past_key_values is not None: + raise NotImplementedError("KV caching is not supported with 
padding_free transformer") clear_aux_loss() From 3a40c9b2c22ac6b3dca2aa2f6aa82ae1f5644f22 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:26:56 -0700 Subject: [PATCH 036/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 31 +++++------------------- lm_engine/hf_models/mixins/dense/main.py | 2 -- 2 files changed, 6 insertions(+), 27 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index d75b67963..51c35b80e 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -13,7 +13,8 @@ from ...cache import GenerationCache from ...config import CommonConfig from ...modeling_utils import ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function -from ...utils import convert_padding_free_lists_to_tensors, is_generation_cache_enabled +from ...tensor import PackedTensor +from ...utils import is_generation_cache_enabled from ..modeling_outputs import BaseModelOutputWithPast from .layer import Block @@ -87,13 +88,11 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: def forward( self, - input_ids: torch.Tensor | None = None, + input_ids: PackedTensor | None = None, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, use_cache: bool | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> BaseModelOutputWithPast: ( use_cache, @@ -108,8 +107,6 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) if is_generation_cache_enabled(): @@ -220,26 +217,16 @@ def _get_initial_hidden_state(self, input_ids: torch.Tensor, position_ids: torch def _prepare_a_bunch_of_stuff( self, - input_ids: torch.Tensor | None = None, + input_ids: PackedTensor | None = None, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, use_cache: bool | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> tuple[bool, torch.Tensor, torch.Tensor, torch.Tensor | None, GenerationCache | None]: if use_cache is None: use_cache = False if self.use_padding_free_transformer else self.config.use_cache - input_shape = input_ids.size() - - # special handling for padding free transformer with list inputs - if self.use_padding_free_transformer: - # for flash attention, there is no padding and we do packing - # so, input_ids is of shape (s1 + s2 + ... 
+ sb) - batch_size = cu_seqlens.shape[0] - 1 - else: - batch_size = input_shape[0] + batch_size = input_ids.get_batch_size() if self.use_padding_free_transformer: assert position_ids is not None, ( @@ -249,13 +236,7 @@ def _prepare_a_bunch_of_stuff( past_length = None query_length = None - key_length = None - if self.use_padding_free_transformer: - key_length = max_seqlen.item() if isinstance(max_seqlen, torch.Tensor) else max_seqlen - else: - past_length = 0 if past_key_values is None else past_key_values.get_seq_length() - query_length = input_shape[-1] - key_length = past_length + query_length + key_length = input_ids.get_max_seqlen() if position_ids is None: position_ids = self._get_position_ids( diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 6d4510e99..73dd33d1a 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -83,8 +83,6 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) hidden_states = transformer_outputs.last_hidden_state From 82a9c8898425d2c9f5ff7dd2900cf19f85552edd Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:28:45 -0700 Subject: [PATCH 037/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 51c35b80e..510561275 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -243,7 +243,8 @@ def _prepare_a_bunch_of_stuff( attention_mask, past_length, query_length, key_length, input_ids.device ) - hidden_states = self._get_initial_hidden_state(input_ids, position_ids) + with input_ids.safe_mode(): + hidden_states = self._get_initial_hidden_state(input_ids, position_ids) rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) From 3df92f5b67d4869b384619b89adcaf0ba08cfea0 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:32:08 -0700 Subject: [PATCH 038/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index e136d4b9c..6c0d82f6d 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -137,8 +137,8 @@ def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor: return self._cu_seqlens - @contextmanager @classmethod + @contextmanager def safe_mode(cls): cls._is_safe = True yield From 8085a1743cfc53c572b54cb3af1453f5deff765c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:36:53 -0700 Subject: [PATCH 039/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 6c0d82f6d..c3c18a64a 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -137,6 +137,14 @@ def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor: return self._cu_seqlens + def get_dtype(self) -> torch.dtype: + with self.safe_mode(): + return self.dtype + + def get_device(self) -> torch.device: + with self.safe_mode(): + return super().get_device() + @classmethod 
@contextmanager def safe_mode(cls): From 06b35fedf6f3b683ba5ccdae289b15da87e5b918 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:38:34 -0700 Subject: [PATCH 040/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 510561275..a0ff722b9 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -246,10 +246,10 @@ def _prepare_a_bunch_of_stuff( with input_ids.safe_mode(): hidden_states = self._get_initial_hidden_state(input_ids, position_ids) - rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) + rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=input_ids.get_dtype()) attention_mask = self._get_maybe_causal_mask( - attention_mask, batch_size, query_length, key_length, hidden_states.dtype, input_ids.device + attention_mask, batch_size, query_length, key_length, input_ids.get_dtype(), input_ids.get_device() ) return ( From b04c4ed20380496ab32c048f050121e46bfd82f8 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:40:09 -0700 Subject: [PATCH 041/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 2 -- lm_engine/hf_models/mixins/dense/layer.py | 6 +----- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index a0ff722b9..dce2ce91a 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -129,8 +129,6 @@ def forward( past_key_values=past_key_values, attention_mask=mamba_mask if is_linear_layer else causal_mask, rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) hidden_states = self.ln_f(hidden_states) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 7baf014df..93e63e430 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -34,12 +34,10 @@ def __init__( def forward( self, - hidden_states: torch.Tensor, + hidden_states: PackedTensor, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) @@ -49,8 +47,6 @@ def forward( past_key_values=past_key_values, attention_mask=attention_mask, rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) if self.m_residual is not None: From 2ba48cd19eb067d6e32f1d7a802c14557509751a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:43:24 -0700 Subject: [PATCH 042/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index c3c18a64a..ecd447eb3 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -152,6 +152,10 @@ def safe_mode(cls): yield cls._is_safe = False + @classmethod + def set_safe_mode(cls, enable: bool = False) -> None: + cls._is_safe = enable + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if 
cls._is_safe: From f6beeade0afe6c7e9ba1d3a924d752afb32794a0 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:47:08 -0700 Subject: [PATCH 043/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index ecd447eb3..034769c81 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -145,16 +145,14 @@ def get_device(self) -> torch.device: with self.safe_mode(): return super().get_device() - @classmethod @contextmanager - def safe_mode(cls): - cls._is_safe = True + def safe_mode(self): + self._is_safe = True yield - cls._is_safe = False + self._is_safe = False - @classmethod - def set_safe_mode(cls, enable: bool = False) -> None: - cls._is_safe = enable + def set_safe_mode(self, enable: bool = False) -> None: + self._is_safe = enable @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): From 695c85a20bba3b5af422acfa587704cd796d2155 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:49:46 -0700 Subject: [PATCH 044/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 034769c81..ecd447eb3 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -145,14 +145,16 @@ def get_device(self) -> torch.device: with self.safe_mode(): return super().get_device() + @classmethod @contextmanager - def safe_mode(self): - self._is_safe = True + def safe_mode(cls): + cls._is_safe = True yield - self._is_safe = False + cls._is_safe = False - def set_safe_mode(self, enable: bool = False) -> None: - self._is_safe = enable + @classmethod + def set_safe_mode(cls, enable: bool = False) -> None: + cls._is_safe = enable @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): From 5f3d79087c525a0cab4ac479fb84099be87c8ed2 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 10:52:06 -0700 Subject: [PATCH 045/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 93e63e430..abdd26c6a 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -39,8 +39,12 @@ def forward( attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, ) -> torch.Tensor: + PackedTensor.set_safe_mode(False) + residual = hidden_states - hidden_states = self.ln_1(hidden_states) + + with hidden_states.safe_mode(): + hidden_states = self.ln_1(hidden_states) hidden_states = self._sequence_mixer_forward( hidden_states=hidden_states, @@ -68,7 +72,7 @@ def forward( def _sequence_mixer_forward( self, - hidden_states: torch.Tensor, + hidden_states: PackedTensor, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, @@ -89,16 +93,7 @@ def _sequence_mixer_forward( hidden_states, cache_params=past_key_values, attention_mask=attention_mask ) elif self.sequence_mixer_type in ["gru", "rnn"]: - hidden_states = PackedTensor.from_torch_tensor( - hidden_states, - cu_seqlens=cu_seqlens, - 
max_seqlen=max_seqlen, - batch_size=hidden_states.size(0) if cu_seqlens is None else None, - ) - hidden_states = self.sequence_mixer(hidden_states, cache_params=past_key_values) - - hidden_states = hidden_states.get_raw_data() else: raise ValueError(f"unexpected sequence_mixer_type ({self.sequence_mixer_type})") From f526bb119d01a13ffeea222c9fba19bf28bade2c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 11:00:42 -0700 Subject: [PATCH 046/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index ecd447eb3..c846c0e9f 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -33,6 +33,22 @@ def __new__( return self + def __init__( + self, + tensor: torch.Tensor, + original_shape: tuple[int], + assume_ragged: bool, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + batch_size: int | None = None, + ) -> PackedTensor: + self._tensor = tensor + self._original_shape = original_shape + self._assume_ragged = assume_ragged + self._cu_seqlens = cu_seqlens + self._max_seqlen = max_seqlen + self._batch_size = batch_size + @staticmethod def from_torch_tensor( tensor: torch.Tensor, From 07aa99e367cb77de7f23495db443b143cf666e5f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 11:12:27 -0700 Subject: [PATCH 047/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 43 ++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index c846c0e9f..ab40305c6 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -24,7 +24,6 @@ def __new__( ) -> PackedTensor: self = torch.as_tensor(tensor).as_subclass(cls) - self._tensor = tensor self._original_shape = original_shape self._assume_ragged = assume_ragged self._cu_seqlens = cu_seqlens @@ -33,22 +32,6 @@ def __new__( return self - def __init__( - self, - tensor: torch.Tensor, - original_shape: tuple[int], - assume_ragged: bool, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - batch_size: int | None = None, - ) -> PackedTensor: - self._tensor = tensor - self._original_shape = original_shape - self._assume_ragged = assume_ragged - self._cu_seqlens = cu_seqlens - self._max_seqlen = max_seqlen - self._batch_size = batch_size - @staticmethod def from_torch_tensor( tensor: torch.Tensor, @@ -118,7 +101,7 @@ def with_new_data(self, tensor: torch.Tensor) -> PackedTensor: ) def get_raw_data(self) -> torch.Tensor: - return self._tensor + return self.as_subclass(torch.Tensor) def get_last_element_along_sequence(self) -> torch.Tensor: output = self.get_raw_data() @@ -178,3 +161,27 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return super().__torch_function__(func, types, args, kwargs) raise NotImplementedError("unpack the tensor to run ops on it") + + def __tensor_flatten__(self): + ctx = { + "_tensor": self._tensor, + "_original_shape": self._original_shape, + "_assume_ragged": self._assume_ragged, + "_cu_seqlens": self._cu_seqlens, + "_max_seqlen": self._max_seqlen, + "_batch_size": self._batch_size, + } + + return ["data"], ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors: dict, metadata, outer_size, outer_stride): + assert len(inner_tensors) == 2 + return PackedTensor( + 
tensor=inner_tensors["data"], + original_shape=metadata["_original_shape"], + assume_ragged=metadata["_assume_ragged"], + cu_seqlens=metadata["_cu_seqlens"], + max_seqlen=metadata["_max_seqlen"], + batch_size=metadata["_batch_size"], + ) From bfc8eb1d6aad0e12ebd0951b7e63f587500555cf Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 11:16:02 -0700 Subject: [PATCH 048/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index ab40305c6..5f7853212 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -32,6 +32,22 @@ def __new__( return self + def __init__( + self, + tensor: torch.Tensor, + original_shape: tuple[int], + assume_ragged: bool, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + batch_size: int | None = None, + ) -> PackedTensor: + self._tensor = tensor + self._original_shape = original_shape + self._assume_ragged = assume_ragged + self._cu_seqlens = cu_seqlens + self._max_seqlen = max_seqlen + self._batch_size = batch_size + @staticmethod def from_torch_tensor( tensor: torch.Tensor, From f98a168ea4dd74001f98952a2c32862d64dffa73 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 11:24:25 -0700 Subject: [PATCH 049/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 5f7853212..996ae5b6a 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -172,9 +172,9 @@ def set_safe_mode(cls, enable: bool = False) -> None: cls._is_safe = enable @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): if cls._is_safe: - return super().__torch_function__(func, types, args, kwargs) + return super().__torch_dispatch__(func, types, args, kwargs) raise NotImplementedError("unpack the tensor to run ops on it") From 14e54b6808c388524d35db5ee8099c31e983ad2b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 11:43:41 -0700 Subject: [PATCH 050/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 34 ++-------------------------------- 1 file changed, 2 insertions(+), 32 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 996ae5b6a..64c55c44d 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -16,7 +16,6 @@ class PackedTensor(torch.Tensor): def __new__( cls, tensor: torch.Tensor, - original_shape: tuple[int], assume_ragged: bool, cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, @@ -24,7 +23,6 @@ def __new__( ) -> PackedTensor: self = torch.as_tensor(tensor).as_subclass(cls) - self._original_shape = original_shape self._assume_ragged = assume_ragged self._cu_seqlens = cu_seqlens self._max_seqlen = max_seqlen @@ -35,14 +33,12 @@ def __new__( def __init__( self, tensor: torch.Tensor, - original_shape: tuple[int], assume_ragged: bool, cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, batch_size: int | None = None, ) -> PackedTensor: self._tensor = tensor - self._original_shape = original_shape self._assume_ragged = assume_ragged self._cu_seqlens = cu_seqlens self._max_seqlen = max_seqlen @@ -80,7 +76,6 @@ 
def from_torch_tensor( packed_tensor = PackedTensor( tensor=packed_tensor, - original_shape=tensor.size(), assume_ragged=assume_ragged, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -89,10 +84,10 @@ def from_torch_tensor( return packed_tensor - def to_torch_tensor(self) -> torch.Tensor: + def to_torch_tensor(self, output_shape: tuple[int]) -> torch.Tensor: if self.is_ragged_tensor(): tensor = unpack_sequence( - inputs=self.get_raw_data(), cu_seqlens=self._cu_seqlens, output_shape=self._original_shape + inputs=self.get_raw_data(), cu_seqlens=self._cu_seqlens, output_shape=output_shape ) else: tensor = self.get_raw_data() @@ -109,7 +104,6 @@ def get_num_tokens(self) -> int: def with_new_data(self, tensor: torch.Tensor) -> PackedTensor: return PackedTensor( tensor=tensor, - original_shape=self._original_shape, assume_ragged=self._assume_ragged, cu_seqlens=self._cu_seqlens, max_seqlen=self._max_seqlen, @@ -177,27 +171,3 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return super().__torch_dispatch__(func, types, args, kwargs) raise NotImplementedError("unpack the tensor to run ops on it") - - def __tensor_flatten__(self): - ctx = { - "_tensor": self._tensor, - "_original_shape": self._original_shape, - "_assume_ragged": self._assume_ragged, - "_cu_seqlens": self._cu_seqlens, - "_max_seqlen": self._max_seqlen, - "_batch_size": self._batch_size, - } - - return ["data"], ctx - - @staticmethod - def __tensor_unflatten__(inner_tensors: dict, metadata, outer_size, outer_stride): - assert len(inner_tensors) == 2 - return PackedTensor( - tensor=inner_tensors["data"], - original_shape=metadata["_original_shape"], - assume_ragged=metadata["_assume_ragged"], - cu_seqlens=metadata["_cu_seqlens"], - max_seqlen=metadata["_max_seqlen"], - batch_size=metadata["_batch_size"], - ) From 88b0595858fddbd689e5d38b8e1aea32c6aca8c3 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 20:57:24 -0700 Subject: [PATCH 051/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 64c55c44d..1034f2f8c 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -168,6 +168,39 @@ def set_safe_mode(cls, enable: bool = False) -> None: @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): if cls._is_safe: - return super().__torch_dispatch__(func, types, args, kwargs) + output = super().__torch_dispatch__(func, types, args, kwargs) + + if isinstance(output, PackedTensor): + return output + else: + arg_packed = None + + for arg in args: + if isinstance(arg, PackedTensor): + arg_packed = arg + break + + if arg_packed is None: + for arg in kwargs.values(): + if isinstance(arg, PackedTensor): + arg_packed = arg + break + + if isinstance(output, torch.Tensor): + return arg_packed.with_new_data(output) + elif isinstance(output, (tuple, list)): + output_packed = [] + for i in output: + if isinstance(i, PackedTensor): + output_packed.append(i) + elif isinstance(i, torch.Tensor): + i = arg_packed.with_new_data(i) + output_packed.append(i) + else: + output_packed.append(i) + + return output_packed + else: + raise ValueError("unexpected output type") raise NotImplementedError("unpack the tensor to run ops on it") From b4a2b5ebde793b4f803e4a5d602dc9c75257ec5a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 23:53:06 -0700 Subject: [PATCH 
052/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 152 +++++++--------------------------- 1 file changed, 28 insertions(+), 124 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 1034f2f8c..7dd34595c 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -4,45 +4,19 @@ from __future__ import annotations -from contextlib import contextmanager +from dataclasses import dataclass import torch from fma import pack_sequence, unpack_sequence -class PackedTensor(torch.Tensor): - _is_safe = False - - def __new__( - cls, - tensor: torch.Tensor, - assume_ragged: bool, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - batch_size: int | None = None, - ) -> PackedTensor: - self = torch.as_tensor(tensor).as_subclass(cls) - - self._assume_ragged = assume_ragged - self._cu_seqlens = cu_seqlens - self._max_seqlen = max_seqlen - self._batch_size = batch_size - - return self - - def __init__( - self, - tensor: torch.Tensor, - assume_ragged: bool, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - batch_size: int | None = None, - ) -> PackedTensor: - self._tensor = tensor - self._assume_ragged = assume_ragged - self._cu_seqlens = cu_seqlens - self._max_seqlen = max_seqlen - self._batch_size = batch_size +@dataclass +class PackedTensor: + tensor: torch.Tensor + assume_ragged: bool + cu_seqlens: torch.Tensor | None = None + max_seqlen: int | None = None + batch_size: int | None = None @staticmethod def from_torch_tensor( @@ -85,122 +59,52 @@ def from_torch_tensor( return packed_tensor def to_torch_tensor(self, output_shape: tuple[int]) -> torch.Tensor: - if self.is_ragged_tensor(): - tensor = unpack_sequence( - inputs=self.get_raw_data(), cu_seqlens=self._cu_seqlens, output_shape=output_shape - ) + if self.assume_ragged: + tensor = unpack_sequence(inputs=self.tensor, cu_seqlens=self.cu_seqlens, output_shape=output_shape) else: - tensor = self.get_raw_data() + tensor = self.tensor return tensor def get_num_tokens(self) -> int: - T = self.get_raw_data().size(0) - if not self.is_ragged_tensor(): - T *= self.get_raw_data().size(1) + T = self.tensor.size(0) + if not self.assume_ragged: + T *= self.tensor.size(1) return T def with_new_data(self, tensor: torch.Tensor) -> PackedTensor: return PackedTensor( tensor=tensor, - assume_ragged=self._assume_ragged, - cu_seqlens=self._cu_seqlens, - max_seqlen=self._max_seqlen, - batch_size=self._batch_size, + assume_ragged=self.assume_ragged, + cu_seqlens=self.cu_seqlens, + max_seqlen=self.max_seqlen, + batch_size=self.batch_size, ) - def get_raw_data(self) -> torch.Tensor: - return self.as_subclass(torch.Tensor) - def get_last_element_along_sequence(self) -> torch.Tensor: - output = self.get_raw_data() + output = self.tensor - if self.is_ragged_tensor(): - output = output[self.get_cu_seqlens()[1:] - 1] + if self.assume_ragged: + output = output[self.getcu_seqlens()[1:] - 1] else: output = output[:, -1] return output - def is_ragged_tensor(self) -> bool: - return self._assume_ragged - - def get_batch_size(self) -> int: - return self._batch_size - def get_max_seqlen(self, return_none_allowed: bool = True) -> int: - if return_none_allowed and not self.is_ragged_tensor(): + if return_none_allowed and not self.assume_ragged: return None - return self._max_seqlen + return self.max_seqlen def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor: - if return_none_allowed and not 
self.is_ragged_tensor(): + if return_none_allowed and not self.assume_ragged: return None - if self._cu_seqlens is None: - self._cu_seqlens = torch.arange( - 0, self._batch_size * self._max_seqlen + 1, self._max_seqlen, device=self.device + if self.cu_seqlens is None: + self.cu_seqlens = torch.arange( + 0, self.batch_size * self.max_seqlen + 1, self.max_seqlen, device=self.tensor.device ) - return self._cu_seqlens - - def get_dtype(self) -> torch.dtype: - with self.safe_mode(): - return self.dtype - - def get_device(self) -> torch.device: - with self.safe_mode(): - return super().get_device() - - @classmethod - @contextmanager - def safe_mode(cls): - cls._is_safe = True - yield - cls._is_safe = False - - @classmethod - def set_safe_mode(cls, enable: bool = False) -> None: - cls._is_safe = enable - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - if cls._is_safe: - output = super().__torch_dispatch__(func, types, args, kwargs) - - if isinstance(output, PackedTensor): - return output - else: - arg_packed = None - - for arg in args: - if isinstance(arg, PackedTensor): - arg_packed = arg - break - - if arg_packed is None: - for arg in kwargs.values(): - if isinstance(arg, PackedTensor): - arg_packed = arg - break - - if isinstance(output, torch.Tensor): - return arg_packed.with_new_data(output) - elif isinstance(output, (tuple, list)): - output_packed = [] - for i in output: - if isinstance(i, PackedTensor): - output_packed.append(i) - elif isinstance(i, torch.Tensor): - i = arg_packed.with_new_data(i) - output_packed.append(i) - else: - output_packed.append(i) - - return output_packed - else: - raise ValueError("unexpected output type") - - raise NotImplementedError("unpack the tensor to run ops on it") + return self.cu_seqlens From 0f8b0aaa19d620b3344968ddd86e9b07b310dc19 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 23:54:42 -0700 Subject: [PATCH 053/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index dce2ce91a..f2e08ac7e 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -224,7 +224,7 @@ def _prepare_a_bunch_of_stuff( if use_cache is None: use_cache = False if self.use_padding_free_transformer else self.config.use_cache - batch_size = input_ids.get_batch_size() + B = input_ids.batch_size if self.use_padding_free_transformer: assert position_ids is not None, ( @@ -247,7 +247,7 @@ def _prepare_a_bunch_of_stuff( rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=input_ids.get_dtype()) attention_mask = self._get_maybe_causal_mask( - attention_mask, batch_size, query_length, key_length, input_ids.get_dtype(), input_ids.get_device() + attention_mask, B, query_length, key_length, input_ids.tensor.dtype, input_ids.tensor.device ) return ( From 7239f5b02aff9967933e83bfc4e6bb06973e03db Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 23:56:17 -0700 Subject: [PATCH 054/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index f2e08ac7e..e7adcdb6a 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py 
@@ -241,13 +241,14 @@ def _prepare_a_bunch_of_stuff( attention_mask, past_length, query_length, key_length, input_ids.device ) - with input_ids.safe_mode(): - hidden_states = self._get_initial_hidden_state(input_ids, position_ids) + hidden_states = self._get_initial_hidden_state(input_ids.tensor, position_ids) + hidden_states = input_ids.with_new_data(hidden_states) + del input_ids - rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=input_ids.get_dtype()) + rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.tensor.dtype) attention_mask = self._get_maybe_causal_mask( - attention_mask, B, query_length, key_length, input_ids.tensor.dtype, input_ids.tensor.device + attention_mask, B, query_length, key_length, hidden_states.tensor.dtype, hidden_states.tensor.device ) return ( From bf191e208be6d931f90b3f4444fe7c2c39c1149a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 29 Sep 2025 23:58:47 -0700 Subject: [PATCH 055/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index abdd26c6a..5be92512a 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -39,12 +39,8 @@ def forward( attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, ) -> torch.Tensor: - PackedTensor.set_safe_mode(False) - - residual = hidden_states - - with hidden_states.safe_mode(): - hidden_states = self.ln_1(hidden_states) + residual = hidden_states.tensor + hidden_states.tensor = self.ln_1(hidden_states.tensor) hidden_states = self._sequence_mixer_forward( hidden_states=hidden_states, @@ -54,19 +50,19 @@ def forward( ) if self.m_residual is not None: - hidden_states = hidden_states * self.m_residual + hidden_states.tensor = hidden_states.tensor * self.m_residual - hidden_states = hidden_states + residual + hidden_states.tensor = hidden_states.tensor + residual - residual = hidden_states - hidden_states = self.ln_2(hidden_states) + residual = hidden_states.tensor + hidden_states.tensor = self.ln_2(hidden_states.tensor) - hidden_states = self.mlp_block(hidden_states) + hidden_states.tensor = self.mlp_block(hidden_states.tensor) if self.m_residual is not None: - hidden_states = hidden_states * self.m_residual + hidden_states.tensor = hidden_states.tensor * self.m_residual - hidden_states = hidden_states + residual + hidden_states.tensor = hidden_states.tensor + residual return hidden_states From 42c3efb313b19b2596b993d0335d5d43ef37dceb Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:06:46 -0700 Subject: [PATCH 056/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/rnn.py | 43 +++++++++---------- lm_engine/hf_models/tensor.py | 18 +++++--- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 7de4ebf89..970ebcfad 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -77,31 +77,29 @@ def __init__( mark_parameter_as_no_weight_decay(self.state_weight) def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) -> PackedTensor: - state = None if cache_params 
is None else cache_params.get_cache(self.layer_idx) - T = x.get_num_tokens() + cu_seqlens = x.cu_seqlens + max_seqlen = x.max_seqlen + x: torch.Tensor = x.tensor - with x.safe_mode(): - x = self.input_projection(x) - x, g = x.chunk(2, dim=-1) - x = x.view(T, self.num_heads, self.state_head_dim) + x = self.input_projection(x) + x, g = x.tensor.chunk(2, dim=-1) + x = x.view(*x.size()[:-1], self.num_heads, self.state_head_dim) - if self.scaling_factor != 1: - x = x * self.scaling_factor + if self.scaling_factor != 1: + x = x * self.scaling_factor weight = self.state_weight if self.scaling_factor != 1: weight = weight * self.scaling_factor - x = x.with_new_data( - rnn( - input=x.get_raw_data(), - weight=weight, - input_state=state, - gradient_clipping=self.gradient_clipping, - cu_seqlens=x.get_cu_seqlens(), - max_seqlen=x.get_max_seqlen(), - kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.rnn) else KernelBackend.torch, - ) + x = rnn( + input=x, + weight=weight, + input_state=None if cache_params is None else cache_params.get_cache(self.layer_idx), + gradient_clipping=self.gradient_clipping, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.rnn) else KernelBackend.torch, ) if cache_params is not None: @@ -111,11 +109,10 @@ def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) layer_idx=self.layer_idx, ) - with x.safe_mode(): - x = x.view(T, -1) - x = x * F.silu(g) - x = self.norm(x) - x = self.output_projection(x) + x = x.flatten(-2, -1) + x = x * F.silu(g) + x = self.norm(x) + x = self.output_projection(x) return x diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 7dd34595c..93c36a98c 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -82,15 +82,23 @@ def with_new_data(self, tensor: torch.Tensor) -> PackedTensor: batch_size=self.batch_size, ) - def get_last_element_along_sequence(self) -> torch.Tensor: - output = self.tensor + def get_last_element_along_sequence( + self, tensor: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None + ) -> torch.Tensor: + if tensor is None: + assert cu_seqlens is None + + tensor = self.tensor + cu_seqlens = self.cu_seqlens + else: + assert cu_seqlens is not None if self.assume_ragged: - output = output[self.getcu_seqlens()[1:] - 1] + tensor = tensor[cu_seqlens[1:] - 1] else: - output = output[:, -1] + tensor = tensor[:, -1] - return output + return tensor def get_max_seqlen(self, return_none_allowed: bool = True) -> int: if return_none_allowed and not self.assume_ragged: From f04ae6bb0b650e80f6dad03dcb186d0a4cce95af Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:07:25 -0700 Subject: [PATCH 057/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 970ebcfad..40d87cc9c 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -82,7 +82,7 @@ def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) x: torch.Tensor = x.tensor x = self.input_projection(x) - x, g = x.tensor.chunk(2, dim=-1) + x, g = x.chunk(2, dim=-1) x = x.view(*x.size()[:-1], self.num_heads, self.state_head_dim) if 
self.scaling_factor != 1: From 7f7f0a2b14f48b6080c29046ee1ec49dd6e3dc88 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:10:02 -0700 Subject: [PATCH 058/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 93c36a98c..94cc61ae7 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -14,10 +14,11 @@ class PackedTensor: tensor: torch.Tensor assume_ragged: bool - cu_seqlens: torch.Tensor | None = None - max_seqlen: int | None = None batch_size: int | None = None + _cu_seqlens: torch.Tensor | None = None + _max_seqlen: int | None = None + @staticmethod def from_torch_tensor( tensor: torch.Tensor, @@ -60,7 +61,7 @@ def from_torch_tensor( def to_torch_tensor(self, output_shape: tuple[int]) -> torch.Tensor: if self.assume_ragged: - tensor = unpack_sequence(inputs=self.tensor, cu_seqlens=self.cu_seqlens, output_shape=output_shape) + tensor = unpack_sequence(inputs=self.tensor, cu_seqlens=self._cu_seqlens, output_shape=output_shape) else: tensor = self.tensor @@ -77,8 +78,8 @@ def with_new_data(self, tensor: torch.Tensor) -> PackedTensor: return PackedTensor( tensor=tensor, assume_ragged=self.assume_ragged, - cu_seqlens=self.cu_seqlens, - max_seqlen=self.max_seqlen, + cu_seqlens=self._cu_seqlens, + max_seqlen=self._max_seqlen, batch_size=self.batch_size, ) @@ -89,7 +90,7 @@ def get_last_element_along_sequence( assert cu_seqlens is None tensor = self.tensor - cu_seqlens = self.cu_seqlens + cu_seqlens = self._cu_seqlens else: assert cu_seqlens is not None @@ -104,15 +105,15 @@ def get_max_seqlen(self, return_none_allowed: bool = True) -> int: if return_none_allowed and not self.assume_ragged: return None - return self.max_seqlen + return self._max_seqlen def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor: if return_none_allowed and not self.assume_ragged: return None - if self.cu_seqlens is None: - self.cu_seqlens = torch.arange( - 0, self.batch_size * self.max_seqlen + 1, self.max_seqlen, device=self.tensor.device + if self._cu_seqlens is None: + self._cu_seqlens = torch.arange( + 0, self.batch_size * self._max_seqlen + 1, self._max_seqlen, device=self.tensor.device ) - return self.cu_seqlens + return self._cu_seqlens From 42147551c662ee4944a52802928ba6a4e36908b4 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:10:27 -0700 Subject: [PATCH 059/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/rnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 40d87cc9c..97b06390b 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -77,8 +77,8 @@ def __init__( mark_parameter_as_no_weight_decay(self.state_weight) def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) -> PackedTensor: - cu_seqlens = x.cu_seqlens - max_seqlen = x.max_seqlen + cu_seqlens = x.get_cu_seqlens() + max_seqlen = x.get_max_seqlen() x: torch.Tensor = x.tensor x = self.input_projection(x) From 3e970873c5ddf1b2a1bc976bb172781b6ad8be2b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:12:18 -0700 
Subject: [PATCH 060/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 94cc61ae7..977ddcb31 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -13,9 +13,9 @@ @dataclass class PackedTensor: tensor: torch.Tensor - assume_ragged: bool batch_size: int | None = None + _assume_ragged: bool _cu_seqlens: torch.Tensor | None = None _max_seqlen: int | None = None @@ -51,16 +51,16 @@ def from_torch_tensor( packed_tensor = PackedTensor( tensor=packed_tensor, - assume_ragged=assume_ragged, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, batch_size=batch_size, + _assume_ragged=assume_ragged, + _cu_seqlens=cu_seqlens, + _max_seqlen=max_seqlen, ) return packed_tensor def to_torch_tensor(self, output_shape: tuple[int]) -> torch.Tensor: - if self.assume_ragged: + if self._assume_ragged: tensor = unpack_sequence(inputs=self.tensor, cu_seqlens=self._cu_seqlens, output_shape=output_shape) else: tensor = self.tensor @@ -69,7 +69,7 @@ def to_torch_tensor(self, output_shape: tuple[int]) -> torch.Tensor: def get_num_tokens(self) -> int: T = self.tensor.size(0) - if not self.assume_ragged: + if not self._assume_ragged: T *= self.tensor.size(1) return T @@ -77,10 +77,10 @@ def get_num_tokens(self) -> int: def with_new_data(self, tensor: torch.Tensor) -> PackedTensor: return PackedTensor( tensor=tensor, - assume_ragged=self.assume_ragged, - cu_seqlens=self._cu_seqlens, - max_seqlen=self._max_seqlen, batch_size=self.batch_size, + _assume_ragged=self._assume_ragged, + _cu_seqlens=self._cu_seqlens, + _max_seqlen=self._max_seqlen, ) def get_last_element_along_sequence( @@ -94,7 +94,7 @@ def get_last_element_along_sequence( else: assert cu_seqlens is not None - if self.assume_ragged: + if self._assume_ragged: tensor = tensor[cu_seqlens[1:] - 1] else: tensor = tensor[:, -1] @@ -102,13 +102,13 @@ def get_last_element_along_sequence( return tensor def get_max_seqlen(self, return_none_allowed: bool = True) -> int: - if return_none_allowed and not self.assume_ragged: + if return_none_allowed and not self._assume_ragged: return None return self._max_seqlen def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor: - if return_none_allowed and not self.assume_ragged: + if return_none_allowed and not self._assume_ragged: return None if self._cu_seqlens is None: From 83e406311107594365c524ee22f20797fe7a3fb4 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:13:47 -0700 Subject: [PATCH 061/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index 977ddcb31..be7b45249 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -15,7 +15,7 @@ class PackedTensor: tensor: torch.Tensor batch_size: int | None = None - _assume_ragged: bool + _assume_ragged: bool | None = None _cu_seqlens: torch.Tensor | None = None _max_seqlen: int | None = None From b1dff71ec28319a6df4b48fadd27793fa9b7cb66 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:19:26 -0700 Subject: [PATCH 062/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/rnn.py | 12 +++++++----- lm_engine/hf_models/tensor.py | 8 ++++++++ 2 files changed, 15 
insertions(+), 5 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 97b06390b..cf7cbd90f 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -76,10 +76,10 @@ def __init__( mark_parameter_as_no_weight_decay(self.state_weight) - def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) -> PackedTensor: - cu_seqlens = x.get_cu_seqlens() - max_seqlen = x.get_max_seqlen() - x: torch.Tensor = x.tensor + def forward(self, x_packed: PackedTensor, cache_params: GenerationCache | None = None) -> PackedTensor: + cu_seqlens = x_packed.get_cu_seqlens() + max_seqlen = x_packed.get_max_seqlen() + x: torch.Tensor = x_packed.get_underlying_tensor(True) x = self.input_projection(x) x, g = x.chunk(2, dim=-1) @@ -114,7 +114,9 @@ def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) x = self.norm(x) x = self.output_projection(x) - return x + x_packed = x_packed.with_new_data(x) + + return x_packed @torch.no_grad() def reset_parameters(self) -> None: diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py index be7b45249..58b35e35c 100644 --- a/lm_engine/hf_models/tensor.py +++ b/lm_engine/hf_models/tensor.py @@ -101,6 +101,14 @@ def get_last_element_along_sequence( return tensor + def get_underlying_tensor(self, set_to_none: bool = False) -> torch.Tensor: + tensor = self.tensor + + if set_to_none: + self.tensor = None + + return tensor + def get_max_seqlen(self, return_none_allowed: bool = True) -> int: if return_none_allowed and not self._assume_ragged: return None From d6dffd3962b75d541cedf9786bd53e632da750d7 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:26:34 -0700 Subject: [PATCH 063/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 4 ++-- lm_engine/hf_models/mixins/dense/main.py | 12 ++++++------ lm_engine/hf_models/mixins/modeling_outputs.py | 4 +++- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index e7adcdb6a..ed741691e 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -124,14 +124,14 @@ def forward( mamba_mask = self._get_mamba_mask(attention_mask, past_key_values) mamba_mask_computed = True - hidden_states = block( + hidden_states: PackedTensor = block( hidden_states, past_key_values=past_key_values, attention_mask=mamba_mask if is_linear_layer else causal_mask, rope_cos_sin=rope_cos_sin, ) - hidden_states = self.ln_f(hidden_states) + hidden_states.tensor = self.ln_f(hidden_states.tensor) return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 73dd33d1a..38451cf08 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -95,26 +95,26 @@ def forward( if labels is None: if is_kernel_allowed(Kernel.fused_linear_cross_entropy): if self.m_width is not None: - hidden_states = hidden_states / self.m_width + hidden_states.tensor = hidden_states.tensor / self.m_width else: - lm_logits = self.get_lm_logits(hidden_states) + lm_logits = hidden_states.with_new_data(self.get_lm_logits(hidden_states.tensor)) if self.m_width is 
not None: - lm_logits = lm_logits / self.m_width + lm_logits.tensor = lm_logits.tensor / self.m_width else: assert not is_kernel_allowed(Kernel.fused_linear_cross_entropy) - lm_logits = self.get_lm_logits(hidden_states) + lm_logits = hidden_states.with_new_data(self.get_lm_logits(hidden_states.tensor)) if self.m_width is not None: - lm_logits = lm_logits / self.m_width + lm_logits.tensor = lm_logits.tensor / self.m_width loss = get_autoregressive_language_modeling_loss( lm_logits=lm_logits, labels=labels, hidden_states=None, vocab_weight=None, - cu_seqlens=cu_seqlens, + cu_seqlens=lm_logits.get_cu_seqlens(), use_padding_free_transformer=self.use_padding_free_transformer, reduction=reduction, shift_logits_and_labels=True, diff --git a/lm_engine/hf_models/mixins/modeling_outputs.py b/lm_engine/hf_models/mixins/modeling_outputs.py index e63688aa8..9ecbc8d13 100644 --- a/lm_engine/hf_models/mixins/modeling_outputs.py +++ b/lm_engine/hf_models/mixins/modeling_outputs.py @@ -7,10 +7,12 @@ import torch from transformers.modeling_outputs import ModelOutput +from ..tensor import PackedTensor + @dataclass class BaseModelOutputWithPast(ModelOutput): - last_hidden_state: torch.Tensor | None = None + last_hidden_state: PackedTensor | None = None past_key_values: tuple[tuple[torch.Tensor]] | None = None From f838cead1dcaab80a88a73c77cb813548f5e250b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:29:13 -0700 Subject: [PATCH 064/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 9 ++++++++- .../modeling_utils/sequence_mixer_blocks/rnn.py | 17 ++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 5be92512a..0336ec7eb 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -89,7 +89,14 @@ def _sequence_mixer_forward( hidden_states, cache_params=past_key_values, attention_mask=attention_mask ) elif self.sequence_mixer_type in ["gru", "rnn"]: - hidden_states = self.sequence_mixer(hidden_states, cache_params=past_key_values) + hidden_states = hidden_states.with_new_data( + self.sequence_mixer( + x=hidden_states.tensor, + cu_seqlens=hidden_states.get_cu_seqlens(), + max_seqlen=hidden_states.get_max_seqlen(), + cache_params=past_key_values, + ) + ) else: raise ValueError(f"unexpected sequence_mixer_type ({self.sequence_mixer_type})") diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index cf7cbd90f..17cf1692e 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -15,7 +15,6 @@ from ....utils import divide_if_divisible, is_fma_available from ...cache import GenerationCache from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay -from ...tensor import PackedTensor from ..linear import ParameterizedLinear from ..normalization import get_normalization_function @@ -76,11 +75,13 @@ def __init__( mark_parameter_as_no_weight_decay(self.state_weight) - def forward(self, x_packed: PackedTensor, cache_params: GenerationCache | None = None) -> PackedTensor: - cu_seqlens = x_packed.get_cu_seqlens() - max_seqlen = x_packed.get_max_seqlen() - x: torch.Tensor = x_packed.get_underlying_tensor(True) - + def forward( + self, + x: torch.Tensor, + cu_seqlens: 
torch.Tensor | None = None, + max_seqlen: int | None = None, + cache_params: GenerationCache | None = None, + ) -> torch.Tensor: x = self.input_projection(x) x, g = x.chunk(2, dim=-1) x = x.view(*x.size()[:-1], self.num_heads, self.state_head_dim) @@ -114,9 +115,7 @@ def forward(self, x_packed: PackedTensor, cache_params: GenerationCache | None = x = self.norm(x) x = self.output_projection(x) - x_packed = x_packed.with_new_data(x) - - return x_packed + return x @torch.no_grad() def reset_parameters(self) -> None: From 15c0824e2106568e10291b30330879471d9788a3 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:31:09 -0700 Subject: [PATCH 065/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 0336ec7eb..959cbb712 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -47,6 +47,8 @@ def forward( past_key_values=past_key_values, attention_mask=attention_mask, rope_cos_sin=rope_cos_sin, + cu_seqlens=hidden_states.get_cu_seqlens(), + max_seqlen=hidden_states.get_max_seqlen(), ) if self.m_residual is not None: @@ -68,7 +70,7 @@ def forward( def _sequence_mixer_forward( self, - hidden_states: PackedTensor, + hidden_states: torch.Tensor, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, @@ -89,13 +91,8 @@ def _sequence_mixer_forward( hidden_states, cache_params=past_key_values, attention_mask=attention_mask ) elif self.sequence_mixer_type in ["gru", "rnn"]: - hidden_states = hidden_states.with_new_data( - self.sequence_mixer( - x=hidden_states.tensor, - cu_seqlens=hidden_states.get_cu_seqlens(), - max_seqlen=hidden_states.get_max_seqlen(), - cache_params=past_key_values, - ) + hidden_states = self.sequence_mixer( + x=hidden_states.tensor, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, cache_params=past_key_values ) else: raise ValueError(f"unexpected sequence_mixer_type ({self.sequence_mixer_type})") From 3aed85901ac2cb539f62907c7b2ba71ec63fac5e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:34:07 -0700 Subject: [PATCH 066/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 29 ++++++++++++----------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 959cbb712..4df7aeb88 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -10,7 +10,6 @@ from ...cache import GenerationCache from ...config import CommonConfig from ...modeling_utils import get_mlp_block, get_normalization_function, get_sequence_mixer -from ...tensor import PackedTensor class Block(nn.Module): @@ -34,37 +33,39 @@ def __init__( def forward( self, - hidden_states: PackedTensor, + hidden_states: torch.Tensor, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, ) -> torch.Tensor: - residual = hidden_states.tensor - hidden_states.tensor = self.ln_1(hidden_states.tensor) + residual = hidden_states + hidden_states = self.ln_1(hidden_states) hidden_states = self._sequence_mixer_forward( 
hidden_states=hidden_states, past_key_values=past_key_values, attention_mask=attention_mask, rope_cos_sin=rope_cos_sin, - cu_seqlens=hidden_states.get_cu_seqlens(), - max_seqlen=hidden_states.get_max_seqlen(), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, ) if self.m_residual is not None: - hidden_states.tensor = hidden_states.tensor * self.m_residual + hidden_states = hidden_states * self.m_residual - hidden_states.tensor = hidden_states.tensor + residual + hidden_states = hidden_states + residual - residual = hidden_states.tensor - hidden_states.tensor = self.ln_2(hidden_states.tensor) + residual = hidden_states + hidden_states = self.ln_2(hidden_states) - hidden_states.tensor = self.mlp_block(hidden_states.tensor) + hidden_states = self.mlp_block(hidden_states) if self.m_residual is not None: - hidden_states.tensor = hidden_states.tensor * self.m_residual + hidden_states = hidden_states * self.m_residual - hidden_states.tensor = hidden_states.tensor + residual + hidden_states = hidden_states + residual return hidden_states @@ -92,7 +93,7 @@ def _sequence_mixer_forward( ) elif self.sequence_mixer_type in ["gru", "rnn"]: hidden_states = self.sequence_mixer( - x=hidden_states.tensor, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, cache_params=past_key_values + x=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, cache_params=past_key_values ) else: raise ValueError(f"unexpected sequence_mixer_type ({self.sequence_mixer_type})") From 012e42e7a6f31aff5b1b1c7b217901e6600edd0a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:35:03 -0700 Subject: [PATCH 067/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index ed741691e..d573cc8a3 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -13,7 +13,6 @@ from ...cache import GenerationCache from ...config import CommonConfig from ...modeling_utils import ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function -from ...tensor import PackedTensor from ...utils import is_generation_cache_enabled from ..modeling_outputs import BaseModelOutputWithPast from .layer import Block @@ -88,7 +87,7 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: def forward( self, - input_ids: PackedTensor | None = None, + input_ids: torch.Tensor | None = None, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, @@ -124,7 +123,7 @@ def forward( mamba_mask = self._get_mamba_mask(attention_mask, past_key_values) mamba_mask_computed = True - hidden_states: PackedTensor = block( + hidden_states: torch.Tensor = block( hidden_states, past_key_values=past_key_values, attention_mask=mamba_mask if is_linear_layer else causal_mask, @@ -215,7 +214,7 @@ def _get_initial_hidden_state(self, input_ids: torch.Tensor, position_ids: torch def _prepare_a_bunch_of_stuff( self, - input_ids: PackedTensor | None = None, + input_ids: torch.Tensor | None = None, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, From 14b9ff71ef5941f04a8f8a391e0f65c9f46e2b09 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:36:26 -0700 Subject: [PATCH 068/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- 
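Note (illustrative, not part of this commit): this patch deletes lm_engine/hf_models/tensor.py, so the PackedTensor wrapper is gone and callers now carry the flat packed tensor together with explicit cu_seqlens / max_seqlen. The sketch below shows that packed layout in plain PyTorch only; the pack/unpack helper names are hypothetical and are not the repo's pack_sequence/unpack_sequence kernels.

    import torch
    import torch.nn.functional as F

    def pack(seqs: list[torch.Tensor]):
        # seqs: variable-length (seq_len, hidden) tensors belonging to one batch
        lengths = torch.tensor([s.size(0) for s in seqs], dtype=torch.int32)
        cu_seqlens = F.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))  # (num_seqs + 1,)
        packed = torch.cat(seqs, dim=0)                                   # (total_tokens, hidden)
        return packed, cu_seqlens, int(lengths.max())

    def unpack(packed: torch.Tensor, cu_seqlens: torch.Tensor) -> list[torch.Tensor]:
        # inverse of pack: slice the flat tensor back into per-sequence tensors
        return [packed[cu_seqlens[i] : cu_seqlens[i + 1]] for i in range(cu_seqlens.numel() - 1)]

    seqs = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(7, 8)]
    packed, cu_seqlens, max_seqlen = pack(seqs)
    assert cu_seqlens.tolist() == [0, 5, 8, 15] and max_seqlen == 7
    assert all(torch.equal(a, b) for a, b in zip(unpack(packed, cu_seqlens), seqs))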
lm_engine/hf_models/tensor.py | 127 ---------------------------------- 1 file changed, 127 deletions(-) delete mode 100644 lm_engine/hf_models/tensor.py diff --git a/lm_engine/hf_models/tensor.py b/lm_engine/hf_models/tensor.py deleted file mode 100644 index 58b35e35c..000000000 --- a/lm_engine/hf_models/tensor.py +++ /dev/null @@ -1,127 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -from dataclasses import dataclass - -import torch -from fma import pack_sequence, unpack_sequence - - -@dataclass -class PackedTensor: - tensor: torch.Tensor - batch_size: int | None = None - - _assume_ragged: bool | None = None - _cu_seqlens: torch.Tensor | None = None - _max_seqlen: int | None = None - - @staticmethod - def from_torch_tensor( - tensor: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - batch_size: int | None = None, - is_packed: bool = False, - ) -> PackedTensor: - assert batch_size is not None or cu_seqlens is not None - assume_ragged = cu_seqlens is not None - - if assume_ragged: - assert max_seqlen is not None - - if batch_size is None: - batch_size = cu_seqlens.size(0) - 1 - - assert cu_seqlens.size(0) - 1 == batch_size - - packed_tensor = pack_sequence(inputs=tensor, cu_seqlens=cu_seqlens) if is_packed else tensor - else: - assert tensor.size(0) == batch_size - - if max_seqlen is None: - max_seqlen = tensor.size(1) - - assert tensor.size(1) == max_seqlen - - packed_tensor = tensor - - packed_tensor = PackedTensor( - tensor=packed_tensor, - batch_size=batch_size, - _assume_ragged=assume_ragged, - _cu_seqlens=cu_seqlens, - _max_seqlen=max_seqlen, - ) - - return packed_tensor - - def to_torch_tensor(self, output_shape: tuple[int]) -> torch.Tensor: - if self._assume_ragged: - tensor = unpack_sequence(inputs=self.tensor, cu_seqlens=self._cu_seqlens, output_shape=output_shape) - else: - tensor = self.tensor - - return tensor - - def get_num_tokens(self) -> int: - T = self.tensor.size(0) - if not self._assume_ragged: - T *= self.tensor.size(1) - - return T - - def with_new_data(self, tensor: torch.Tensor) -> PackedTensor: - return PackedTensor( - tensor=tensor, - batch_size=self.batch_size, - _assume_ragged=self._assume_ragged, - _cu_seqlens=self._cu_seqlens, - _max_seqlen=self._max_seqlen, - ) - - def get_last_element_along_sequence( - self, tensor: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None - ) -> torch.Tensor: - if tensor is None: - assert cu_seqlens is None - - tensor = self.tensor - cu_seqlens = self._cu_seqlens - else: - assert cu_seqlens is not None - - if self._assume_ragged: - tensor = tensor[cu_seqlens[1:] - 1] - else: - tensor = tensor[:, -1] - - return tensor - - def get_underlying_tensor(self, set_to_none: bool = False) -> torch.Tensor: - tensor = self.tensor - - if set_to_none: - self.tensor = None - - return tensor - - def get_max_seqlen(self, return_none_allowed: bool = True) -> int: - if return_none_allowed and not self._assume_ragged: - return None - - return self._max_seqlen - - def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor: - if return_none_allowed and not self._assume_ragged: - return None - - if self._cu_seqlens is None: - self._cu_seqlens = torch.arange( - 0, self.batch_size * self._max_seqlen + 1, self._max_seqlen, device=self.tensor.device - ) - - return self._cu_seqlens From 8183244ad977d4a7a6c455d607b21d030d7e4918 Mon Sep 17 
00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:37:22 -0700 Subject: [PATCH 069/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/__init__.py | 1 - lm_engine/hf_models/mixins/dense/main.py | 3 +-- lm_engine/hf_models/mixins/modeling_outputs.py | 4 +--- .../hf_models/modeling_utils/sequence_mixer_blocks/gru.py | 3 +-- lm_engine/model_wrapper/pretraining.py | 1 - 5 files changed, 3 insertions(+), 9 deletions(-) diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py index c30f5b462..72a0e2133 100644 --- a/lm_engine/hf_models/__init__.py +++ b/lm_engine/hf_models/__init__.py @@ -29,7 +29,6 @@ mark_parameter_as_no_weight_decay, ) from .register_hf import get_model_parallel_class, is_custom_model, register_model_classes -from .tensor import PackedTensor from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts from .utils import convert_padding_free_lists_to_tensors, disable_generation_cache diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 38451cf08..9a85c4f46 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -14,7 +14,6 @@ from ...config import CommonConfig from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear -from ...tensor import PackedTensor from ..modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .base import PreTrainedModelMixin @@ -57,7 +56,7 @@ def set_output_embeddings(self, new_embeddings: ParameterizedLinear) -> None: def forward( self, - input_ids: PackedTensor | torch.Tensor | None = None, + input_ids: torch.Tensor | None = None, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, diff --git a/lm_engine/hf_models/mixins/modeling_outputs.py b/lm_engine/hf_models/mixins/modeling_outputs.py index 9ecbc8d13..e63688aa8 100644 --- a/lm_engine/hf_models/mixins/modeling_outputs.py +++ b/lm_engine/hf_models/mixins/modeling_outputs.py @@ -7,12 +7,10 @@ import torch from transformers.modeling_outputs import ModelOutput -from ..tensor import PackedTensor - @dataclass class BaseModelOutputWithPast(ModelOutput): - last_hidden_state: PackedTensor | None = None + last_hidden_state: torch.Tensor | None = None past_key_values: tuple[tuple[torch.Tensor]] | None = None diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index 925ff216a..8df5ffa05 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -15,7 +15,6 @@ from ....utils import divide_if_divisible, is_fma_available from ...cache import GenerationCache from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay -from ...tensor import PackedTensor from ..linear import ParameterizedLinear from ..normalization import get_normalization_function @@ -76,7 +75,7 @@ def __init__( mark_parameter_as_no_weight_decay(self.state_weight) - def forward(self, x: PackedTensor, cache_params: GenerationCache | None = None) -> PackedTensor: + def forward(self, x: torch.Tensor, cache_params: GenerationCache | None = None) -> torch.Tensor: state = None if cache_params is None else 
cache_params.get_cache(self.layer_idx) T = x.get_num_tokens() diff --git a/lm_engine/model_wrapper/pretraining.py b/lm_engine/model_wrapper/pretraining.py index 86b381f43..e04ad9bd6 100644 --- a/lm_engine/model_wrapper/pretraining.py +++ b/lm_engine/model_wrapper/pretraining.py @@ -12,7 +12,6 @@ from ..enums import Kernel from ..hf_models import ( CausalLMOutputWithPast, - PackedTensor, PipelineParallelInput, PipelineParallelOutput, get_autoregressive_language_modeling_loss, From 4dfb0c7504747c5ed5c2743036c5ad1188a2fc46 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:49:33 -0700 Subject: [PATCH 070/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/model_wrapper/pretraining.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lm_engine/model_wrapper/pretraining.py b/lm_engine/model_wrapper/pretraining.py index e04ad9bd6..926319436 100644 --- a/lm_engine/model_wrapper/pretraining.py +++ b/lm_engine/model_wrapper/pretraining.py @@ -250,10 +250,9 @@ def _prepare_model_inputs(self, batch: dict) -> dict: max_seqlen = self.sequence_length position_ids = self.position_ids - batch["input_ids"] = PackedTensor.from_torch_tensor( - tensor=input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, batch_size=batch_size - ) - + batch["input_ids"] = input_ids + batch["cu_seqlens"] = cu_seqlens + batch["max_seqlen"] = max_seqlen batch["position_ids"] = position_ids if ProcessGroupManager.is_tensor_parallel_enabled(): From 8438c1b0907b313fcae3ada2e96414c32d06c0f3 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 00:51:45 -0700 Subject: [PATCH 071/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 4 +++- lm_engine/hf_models/mixins/dense/main.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index d573cc8a3..ad4d07bcc 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -88,8 +88,10 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: def forward( self, input_ids: torch.Tensor | None = None, - past_key_values: GenerationCache | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, attention_mask: torch.Tensor | None = None, + past_key_values: GenerationCache | None = None, position_ids: torch.Tensor | None = None, use_cache: bool | None = None, ) -> BaseModelOutputWithPast: diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 9a85c4f46..4aa563206 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -64,6 +64,8 @@ def forward( labels: torch.Tensor | None = None, use_cache: bool | None = None, return_dict: bool = True, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, reduction: str = "mean", ) -> CausalLMOutputWithPast: assert return_dict @@ -78,8 +80,10 @@ def forward( transformer_outputs: BaseModelOutputWithPast = self.transformer( input_ids, - past_key_values=past_key_values, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, attention_mask=attention_mask, + past_key_values=past_key_values, position_ids=position_ids, use_cache=use_cache, ) From 60ac022317db64e55df145d492de975b29e2fff3 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 01:00:17 -0700 Subject: [PATCH 072/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- 
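Note (illustrative, not part of this commit): the pretraining wrapper now places input_ids, cu_seqlens and max_seqlen into the batch dict directly instead of building a PackedTensor. Below is a minimal sketch of how such a dict could be assembled from already-packed token ids and per-document lengths; build_batch and the position_ids construction are assumptions for illustration, not code from this repo.

    import torch
    import torch.nn.functional as F

    def build_batch(input_ids: torch.Tensor, document_lengths: torch.Tensor) -> dict:
        # input_ids: (total_tokens,) already packed across documents
        # document_lengths: (num_documents,) integer length of each packed document
        cu_seqlens = F.pad(document_lengths.cumsum(0, dtype=torch.int32), (1, 0))
        position_ids = torch.cat([torch.arange(int(l)) for l in document_lengths])
        return {
            "input_ids": input_ids,
            "cu_seqlens": cu_seqlens,                    # (num_documents + 1,)
            "max_seqlen": int(document_lengths.max()),   # plain int, not a tensor
            "position_ids": position_ids,                # restarts at 0 for every document
        }

    batch = build_batch(torch.arange(10), torch.tensor([4, 6]))
    assert batch["cu_seqlens"].tolist() == [0, 4, 10]
    assert batch["position_ids"].tolist() == [0, 1, 2, 3, 0, 1, 2, 3, 4, 5]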
lm_engine/hf_models/mixins/dense/base.py | 7 ++++++- lm_engine/model_wrapper/pretraining.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index ad4d07bcc..febcd66c0 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -217,6 +217,7 @@ def _get_initial_hidden_state(self, input_ids: torch.Tensor, position_ids: torch def _prepare_a_bunch_of_stuff( self, input_ids: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, @@ -225,7 +226,11 @@ def _prepare_a_bunch_of_stuff( if use_cache is None: use_cache = False if self.use_padding_free_transformer else self.config.use_cache - B = input_ids.batch_size + if cu_seqlens is None: + assert input_ids.dim() == 2 + B = input_ids.size(0) + else: + B = cu_seqlens.size(0) - 1 if self.use_padding_free_transformer: assert position_ids is not None, ( diff --git a/lm_engine/model_wrapper/pretraining.py b/lm_engine/model_wrapper/pretraining.py index 926319436..dddb66a14 100644 --- a/lm_engine/model_wrapper/pretraining.py +++ b/lm_engine/model_wrapper/pretraining.py @@ -247,7 +247,7 @@ def _prepare_model_inputs(self, batch: dict) -> dict: position_ids = self.position_ids else: cu_seqlens = None - max_seqlen = self.sequence_length + max_seqlen = None position_ids = self.position_ids batch["input_ids"] = input_ids From 47ac63800e9d8a3195eff3dc1645e42611ba5817 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 01:03:26 -0700 Subject: [PATCH 073/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index febcd66c0..6dbcd7237 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -104,6 +104,8 @@ def forward( past_key_values, ) = self._prepare_a_bunch_of_stuff( input_ids=input_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, past_key_values=past_key_values, attention_mask=attention_mask, position_ids=position_ids, @@ -218,6 +220,7 @@ def _prepare_a_bunch_of_stuff( self, input_ids: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, @@ -240,7 +243,7 @@ def _prepare_a_bunch_of_stuff( past_length = None query_length = None - key_length = input_ids.get_max_seqlen() + key_length = max_seqlen if position_ids is None: position_ids = self._get_position_ids( From 71c8585a1ab87e6dfca22179de9c97b06db5a46e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 01:04:56 -0700 Subject: [PATCH 074/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 6dbcd7237..66d6f8c6c 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -134,7 +134,7 @@ def forward( rope_cos_sin=rope_cos_sin, ) - hidden_states.tensor = self.ln_f(hidden_states.tensor) + 
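Note (illustrative, not part of this commit): with cu_seqlens passed down, the base model derives the number of sequences from the boundaries themselves (cu_seqlens.size(0) - 1) instead of reading PackedTensor.batch_size, while a regular 2-D input still reports its batch size via size(0). A small worked example in plain PyTorch:

    import torch

    cu_seqlens = torch.tensor([0, 5, 9, 16], dtype=torch.int32)  # three packed sequences
    batch_size = cu_seqlens.size(0) - 1                          # -> 3
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]                   # -> tensor([5, 4, 7])
    max_seqlen = int(seqlens.max())                              # -> 7

    # the non-packed path: a dense (batch, seq_len) tensor carries the same information
    input_ids = torch.zeros(3, 7, dtype=torch.long)
    assert input_ids.dim() == 2 and input_ids.size(0) == batch_size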
hidden_states = self.ln_f(hidden_states) return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values) @@ -250,14 +250,11 @@ def _prepare_a_bunch_of_stuff( attention_mask, past_length, query_length, key_length, input_ids.device ) - hidden_states = self._get_initial_hidden_state(input_ids.tensor, position_ids) - hidden_states = input_ids.with_new_data(hidden_states) - del input_ids - - rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.tensor.dtype) + hidden_states = self._get_initial_hidden_state(input_ids, position_ids) + rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) attention_mask = self._get_maybe_causal_mask( - attention_mask, B, query_length, key_length, hidden_states.tensor.dtype, hidden_states.tensor.device + attention_mask, B, query_length, key_length, hidden_states.dtype, hidden_states.device ) return ( From 37a81421b4f9104dce29539132a8e1d181e1ff50 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 01:14:00 -0700 Subject: [PATCH 075/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 4aa563206..1bfc2d62e 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -98,26 +98,26 @@ def forward( if labels is None: if is_kernel_allowed(Kernel.fused_linear_cross_entropy): if self.m_width is not None: - hidden_states.tensor = hidden_states.tensor / self.m_width + hidden_states = hidden_states / self.m_width else: - lm_logits = hidden_states.with_new_data(self.get_lm_logits(hidden_states.tensor)) + lm_logits = self.get_lm_logits(hidden_states) if self.m_width is not None: - lm_logits.tensor = lm_logits.tensor / self.m_width + lm_logits = lm_logits / self.m_width else: assert not is_kernel_allowed(Kernel.fused_linear_cross_entropy) - lm_logits = hidden_states.with_new_data(self.get_lm_logits(hidden_states.tensor)) + lm_logits = self.get_lm_logits(hidden_states) if self.m_width is not None: - lm_logits.tensor = lm_logits.tensor / self.m_width + lm_logits = lm_logits / self.m_width loss = get_autoregressive_language_modeling_loss( lm_logits=lm_logits, labels=labels, hidden_states=None, vocab_weight=None, - cu_seqlens=lm_logits.get_cu_seqlens(), + cu_seqlens=cu_seqlens, use_padding_free_transformer=self.use_padding_free_transformer, reduction=reduction, shift_logits_and_labels=True, From 8041f6caed46ffa0ddc72b51aaa78206e429bdc6 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 01:22:23 -0700 Subject: [PATCH 076/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/models/palm/layer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lm_engine/hf_models/models/palm/layer.py b/lm_engine/hf_models/models/palm/layer.py index 16e8f7c93..0623086f4 100644 --- a/lm_engine/hf_models/models/palm/layer.py +++ b/lm_engine/hf_models/models/palm/layer.py @@ -13,9 +13,7 @@ class PaLMBlock(nn.Module): - def __init__( - self, config: PaLMConfig, use_padding_free_transformer: bool, layer_idx: int | None = None - ) -> PaLMBlock: + def __init__(self, config: PaLMConfig, layer_idx: int | None = None) -> PaLMBlock: super().__init__() self.m_residual = config.m_residual @@ -23,7 +21,7 @@ def __init__( self.ln = get_normalization_function( 
config.normalization_function, config.hidden_size, eps=config.layer_norm_epsilon ) - self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) + self.sequence_mixer = get_sequence_mixer(config, True, layer_idx) self.mlp_block = get_mlp_block(config, layer_idx=layer_idx) def forward( From 636987e975369ff8704a791730584201eae3f0e9 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 01:26:09 -0700 Subject: [PATCH 077/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/gru.py | 59 +++++++++---------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index 8df5ffa05..b283a47fa 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -75,19 +75,21 @@ def __init__( mark_parameter_as_no_weight_decay(self.state_weight) - def forward(self, x: torch.Tensor, cache_params: GenerationCache | None = None) -> torch.Tensor: - state = None if cache_params is None else cache_params.get_cache(self.layer_idx) - T = x.get_num_tokens() - - with x.safe_mode(): - x = self.input_projection(x) - x, g = x.split((3 * self.state_size, self.state_size), dim=-1) + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + cache_params: GenerationCache | None = None, + ) -> torch.Tensor: + x = self.input_projection(x) + x, g = x.split((3 * self.state_size, self.state_size), dim=-1) - if self.scaling_factor != 1: - x = x * self.scaling_factor + if self.scaling_factor != 1: + x = x * self.scaling_factor - x, x_forget, x_reset = x.chunk(3, dim=-1) - x, x_forget, x_reset = [i.view(T, self.num_heads, self.state_head_dim) for i in (x, x_forget, x_reset)] + x, x_forget, x_reset = x.chunk(3, dim=-1) + x, x_forget, x_reset = [i.view(T, self.num_heads, self.state_head_dim) for i in (x, x_forget, x_reset)] weight = self.state_weight if self.scaling_factor != 1: @@ -95,20 +97,18 @@ def forward(self, x: torch.Tensor, cache_params: GenerationCache | None = None) weight, forget_weight, reset_weight = weight.chunk(3, dim=0) - x = x.with_new_data( - gru( - input=x, - weight=weight, - forget_input=x_forget, - forget_weight=forget_weight, - reset_input=x_reset, - reset_weight=reset_weight, - input_state=state, - gradient_clipping=self.gradient_clipping, - cu_seqlens=x.get_cu_seqlens(), - max_seqlen=x.get_max_seqlen(), - kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.gru) else KernelBackend.torch, - ) + x = gru( + input=x, + weight=weight, + forget_input=x_forget, + forget_weight=forget_weight, + reset_input=x_reset, + reset_weight=reset_weight, + input_state=None if cache_params is None else cache_params.get_cache(self.layer_idx), + gradient_clipping=self.gradient_clipping, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.gru) else KernelBackend.torch, ) if cache_params is not None: @@ -118,11 +118,10 @@ def forward(self, x: torch.Tensor, cache_params: GenerationCache | None = None) layer_idx=self.layer_idx, ) - with x.safe_mode(): - x = x.view(T, -1) - x = x * F.silu(g) - x = self.norm(x) - x = self.output_projection(x) + x = x.flatten(-2, -1) + x = x * F.silu(g) + x = self.norm(x) + x = self.output_projection(x) return x From 424dab1f2beddcb360da02a1daf877f1e908f02c Mon Sep 17 
00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 01:30:11 -0700 Subject: [PATCH 078/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/gru.py | 11 ++++++----- .../modeling_utils/sequence_mixer_blocks/rnn.py | 11 ++++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index b283a47fa..41eb9fb7d 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -112,11 +112,12 @@ def forward( ) if cache_params is not None: - cache_params.update( - state=x.get_last_element_along_sequence(), - num_tokens_added=x.get_cu_seqlens(False), - layer_idx=self.layer_idx, - ) + if cu_seqlens is None: + cache_params.update(state=x[:, -1], num_tokens_added=input.size(1), layer_idx=self.layer_idx) + else: + cache_params.update( + state=x[cu_seqlens[1:] - 1], num_tokens_added=cu_seqlens[1:], layer_idx=self.layer_idx + ) x = x.flatten(-2, -1) x = x * F.silu(g) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 17cf1692e..31c198a80 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -104,11 +104,12 @@ def forward( ) if cache_params is not None: - cache_params.update( - state=x.get_last_element_along_sequence(), - num_tokens_added=x.get_cu_seqlens(False), - layer_idx=self.layer_idx, - ) + if cu_seqlens is None: + cache_params.update(state=x[:, -1], num_tokens_added=input.size(1), layer_idx=self.layer_idx) + else: + cache_params.update( + state=x[cu_seqlens[1:] - 1], num_tokens_added=cu_seqlens[1:], layer_idx=self.layer_idx + ) x = x.flatten(-2, -1) x = x * F.silu(g) From ab95037659c54505020020ae7cbc2d862330c942 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 01:52:27 -0700 Subject: [PATCH 079/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 6 ++---- .../sequence_mixer_blocks/__init__.py | 16 ++-------------- .../gpt_crosslayer/sequence_mixers/__init__.py | 5 +---- 3 files changed, 5 insertions(+), 22 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 4df7aeb88..a4aafcbe4 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -13,9 +13,7 @@ class Block(nn.Module): - def __init__( - self, config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int | None = None - ) -> Block: + def __init__(self, config: CommonConfig, layer_idx: int | None = None) -> Block: super().__init__() hidden_size = config.hidden_size @@ -25,7 +23,7 @@ def __init__( self.ln_1 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) - self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) + self.sequence_mixer = get_sequence_mixer(config, True, layer_idx) self.ln_2 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index 1bf463fe6..acd03fdf7 100644 --- 
a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -19,12 +19,7 @@ SEQUENCE_MIXER_TYPE = Attention | CausalConvolution | GRU | Mamba2 | MultiHeadLatentAttention | RNN -def get_sequence_mixer( - config: CommonConfig, - causal: bool, - use_padding_free_transformer: bool, - layer_idx: int, -) -> SEQUENCE_MIXER_TYPE: +def get_sequence_mixer(config: CommonConfig, causal: bool, layer_idx: int) -> SEQUENCE_MIXER_TYPE: block = config.sequence_mixer_blocks[layer_idx] sequence_mixer_type = block.sequence_mixer_type @@ -42,7 +37,6 @@ def get_sequence_mixer( init_method=config.init_method, num_layers=config.num_layers, layer_idx=layer_idx, - use_padding_free_transformer=use_padding_free_transformer, ) elif sequence_mixer_type in ["rnn", "gru"]: return (GRU if sequence_mixer_type == "gru" else RNN)( @@ -100,7 +94,6 @@ def get_sequence_mixer( num_layers=config.num_layers, causal=True, layer_idx=layer_idx, - use_padding_free_transformer=use_padding_free_transformer, normalization_function=block.normalization_function, layer_norm_epsilon=config.layer_norm_epsilon, ) @@ -123,11 +116,6 @@ def get_sequence_mixer( ) if sequence_mixer_type == "softmax_attention": - return Attention( - **sequence_mixer_kwargs, - qkv_bias=block.qkv_bias, - softmax_dropout=block.softmax_dropout, - use_padding_free_transformer=use_padding_free_transformer, - ) + return Attention(**sequence_mixer_kwargs, qkv_bias=block.qkv_bias, softmax_dropout=block.softmax_dropout) else: raise ValueError(f"unexpected sequence_mixer_type ({sequence_mixer_type})") diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/__init__.py b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/__init__.py index 1fafdc67a..a28161366 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/__init__.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/__init__.py @@ -6,9 +6,7 @@ from .base import CrossLayerAttention, KeyValueProjection -def get_sequence_mixer( - config: GPTCrossLayerConfig, causal: bool, use_padding_free_transformer: bool, layer_idx: int -) -> CrossLayerAttention: +def get_sequence_mixer(config: GPTCrossLayerConfig, causal: bool, layer_idx: int) -> CrossLayerAttention: block = config.sequence_mixer_blocks[layer_idx] assert block.sequence_mixer_type == "softmax_attention" @@ -25,5 +23,4 @@ def get_sequence_mixer( num_layers=config.num_layers, causal=causal, layer_idx=layer_idx, - use_padding_free_transformer=use_padding_free_transformer, ) From 21e6416d069e01f154a76cde9c60c74ddc6ac1f0 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 08:19:35 -0700 Subject: [PATCH 080/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- .../gpt_crosslayer/sequence_mixers/base.py | 30 ++++--------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py index 3bc88a560..47be7f08d 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py @@ -31,7 +31,6 @@ def __init__( num_layers: int, causal: bool, layer_idx: int, - use_padding_free_transformer: bool, ) -> CrossLayerAttention: super().__init__() @@ -41,7 +40,6 @@ def __init__( self.num_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.add_bias = add_bias - 
self.use_padding_free_transformer = use_padding_free_transformer assert ( self.hidden_size % self.num_heads == 0 @@ -87,12 +85,10 @@ def forward( max_seqlen: int | None = None, ) -> torch.Tensor: if is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3): - if self.use_padding_free_transformer: - total_q = hidden_states.shape[0] - - query = self.q_attn(hidden_states) - query = query.view(total_q, self.num_heads, -1) + query = self.q_attn(hidden_states) + query = query.view(*hidden_states.size()[:-1], self.num_heads, -1) + if self.use_padding_free_transformer: if self.position_embedding_type == "rope": query = apply_rotary_pos_emb(query, rope_cos_sin) @@ -107,16 +103,7 @@ def forward( dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, ) - - del query, key, value - - hidden_states = hidden_states.view(-1, self.hidden_size) else: - batch_size, query_length = hidden_states.shape[:2] - - query = self.q_attn(hidden_states) - query = query.view(batch_size, query_length, self.num_heads, -1) - if self.position_embedding_type == "rope": # TODO avoid this extra transpose query = query.transpose(1, 2) @@ -134,15 +121,8 @@ def forward( dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, ) - - del query, key, value - - hidden_states = hidden_states.view(batch_size, query_length, -1) else: - batch_size, query_length = hidden_states.shape[:2] - - query = self.q_attn(hidden_states) - query = query.view(batch_size, query_length, self.num_heads, -1) + query = query.view(*query.size()[:-1], self.num_heads, -1) query = query.transpose(1, 2) if self.position_embedding_type == "rope": @@ -162,8 +142,8 @@ def forward( del query, key, value hidden_states = hidden_states.transpose(1, 2) - hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.flatten(-2, -1) hidden_states = self.c_proj(hidden_states) hidden_states = self.dropout(hidden_states) From b0edda7f9ecb2f769a7333a856d317e204ecc030 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 08:53:33 -0700 Subject: [PATCH 081/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mask.py | 160 +++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 lm_engine/hf_models/modeling_utils/mask.py diff --git a/lm_engine/hf_models/modeling_utils/mask.py b/lm_engine/hf_models/modeling_utils/mask.py new file mode 100644 index 000000000..800acdf9c --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/mask.py @@ -0,0 +1,160 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +from ...enums import Kernel +from ...kernels import is_kernel_allowed +from ...utils import is_fma_available + + +if is_fma_available(): + from fma import KernelBackend + from fma import pack_sequence as _pack_sequence + from fma import unpack_sequence as _unpack_sequence + + +def pack_sequence( + inputs: torch.Tensor | list[torch.Tensor], cu_seqlens: torch.Tensor +) -> torch.Tensor | list[torch.Tensor]: + kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) else KernelBackend.torch + + inputs = _pack_sequence( + inputs=inputs, + cu_seqlens=cu_seqlens, + kernel_backend_forward=kernel_backend, + 
kernel_backend_backward=kernel_backend, + ) + + return inputs + + +def unpack_sequence( + inputs: torch.Tensor | list[torch.Tensor], cu_seqlens: torch.Tensor, output_shape: tuple[int] +) -> torch.Tensor | list[torch.Tensor]: + kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch + + inputs = _unpack_sequence( + inputs=inputs, + cu_seqlens=cu_seqlens, + output_shape=output_shape, + kernel_backend_forward=kernel_backend, + kernel_backend_backward=kernel_backend, + ) + + return inputs + + +# NOTE using dataclass here since pydantic doesn't work with torch.compile +@dataclass +class AttentionMaskInfo: + batch_size: int | None = None + cu_seqlens: torch.Tensor | None = None + max_seqlen: int | None = None + attention_mask: torch.Tensor | None = None + causal_mask: torch.Tensor | None = None + + def get_batch_size(self) -> int: + if self.batch_size is not None: + return self.batch_size + + if self.cu_seqlens is not None: + self.batch_size = self.cu_seqlens.size(0) - 1 + elif self.attention_mask is not None: + self.batch_size = self.attention_mask.size(0) + else: + raise NotImplementedError("code is not supposed to reach here") + + return self.batch_size + + def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor | None: + if return_none_allowed: + return self.cu_seqlens + + if self.cu_seqlens is None: + seqlens = self.attention_mask.sum(dim=-1, dtype=torch.int32) + self.cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) + self.max_seqlen = seqlens.max().item() + + return self.cu_seqlens + + def get_max_seqlen(self, return_none_allowed: bool = True) -> int | None: + if return_none_allowed: + return self.max_seqlen + + if self.max_seqlen is None: + # this will cache the max_seqlen + self.get_cu_seqlens(False) + + return self.max_seqlen + + def get_attention_mask( + self, return_none_allowed: bool = True, device: torch.device | None = None + ) -> torch.Tensor | None: + if return_none_allowed: + return self.attention_mask + + if self.attention_mask is None: + cu_seqlens = self.get_cu_seqlens() + batch_size = self.get_batch_size() + max_seqlen = self.get_max_seqlen() + assert max_seqlen is not None + + if cu_seqlens is None: + self.attention_mask = torch.ones(batch_size, max_seqlen, device=device, dtype=torch.int32) + else: + attention_mask_flat = torch.ones_like(cu_seqlens, device=device, dtype=torch.int32) + self.attention_mask = unpack_sequence( + inputs=attention_mask_flat, cu_seqlens=cu_seqlens, output_shape=(batch_size, max_seqlen) + ) + + return self.attention_mask + + def get_causal_mask( + self, return_none_allowed: bool = True, dtype: torch.dtype | None = None + ) -> torch.Tensor | None: + attention_mask = self.get_attention_mask(return_none_allowed) + + if attention_mask is not None: + _, Q, K = attention_mask.size() + L = K - Q + + if Q > 1: + device = attention_mask.device + + causal_mask = torch.empty((Q, K), dtype=torch.bool, device=device) + causal_mask[:, L:] = torch.tril(torch.ones(Q, K, dtype=torch.bool, device=device)) + + if L > 0: + causal_mask[:, :L] = True + + causal_mask = causal_mask[None, ...] + causal_mask = causal_mask & attention_mask[:, None, ...].to(torch.bool) + elif Q == 1: + causal_mask = attention_mask[:, None, ...].to(dtype=torch.bool, device=device) + else: + raise NotImplementedError("code is not expected to reach here") + + causal_mask = causal_mask[:, None, ...] 
+ causal_mask = torch.where(causal_mask, ~causal_mask, AttentionMaskInfo._get_mask_value(device, dtype)) + + # this is needed to prevent NaN since SDPA + # see issue: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = causal_mask * ~torch.all( + causal_mask == AttentionMaskInfo._get_mask_value(device, dtype), dim=-1, keepdim=True + ) + + return attention_mask + + @classmethod + def _get_mask_value(cls, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + # torch.where expects a tensor. We use a cache to avoid recreating it every time. + if cls.mask_value is None or cls.mask_value.dtype != dtype or cls.mask_value.device != device: + cls.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + return cls.mask_value From 6bedf74504439b7e9647d0c36802da2ae3bd32b1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 08:54:00 -0700 Subject: [PATCH 082/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mask.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/mask.py b/lm_engine/hf_models/modeling_utils/mask.py index 800acdf9c..19ca18ac4 100644 --- a/lm_engine/hf_models/modeling_utils/mask.py +++ b/lm_engine/hf_models/modeling_utils/mask.py @@ -125,8 +125,9 @@ def get_causal_mask( _, Q, K = attention_mask.size() L = K - Q + device = attention_mask.device + if Q > 1: - device = attention_mask.device causal_mask = torch.empty((Q, K), dtype=torch.bool, device=device) causal_mask[:, L:] = torch.tril(torch.ones(Q, K, dtype=torch.bool, device=device)) From 2b879628a24ed0c420d5949a756854c3a8dfd88e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 08:54:09 -0700 Subject: [PATCH 083/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mask.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/mask.py b/lm_engine/hf_models/modeling_utils/mask.py index 19ca18ac4..339175660 100644 --- a/lm_engine/hf_models/modeling_utils/mask.py +++ b/lm_engine/hf_models/modeling_utils/mask.py @@ -128,7 +128,6 @@ def get_causal_mask( device = attention_mask.device if Q > 1: - causal_mask = torch.empty((Q, K), dtype=torch.bool, device=device) causal_mask[:, L:] = torch.tril(torch.ones(Q, K, dtype=torch.bool, device=device)) From e75d48f03c3d47fa0c1df275d1d870ec9e8e4d04 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 08:55:36 -0700 Subject: [PATCH 084/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mask.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/mask.py b/lm_engine/hf_models/modeling_utils/mask.py index 339175660..3aab99baf 100644 --- a/lm_engine/hf_models/modeling_utils/mask.py +++ b/lm_engine/hf_models/modeling_utils/mask.py @@ -51,6 +51,9 @@ def unpack_sequence( return inputs +_ERROR_MESSAGE = "code is not supposed to reach here" + + # NOTE using dataclass here since pydantic doesn't work with torch.compile @dataclass class AttentionMaskInfo: @@ -69,7 +72,7 @@ def get_batch_size(self) -> int: elif self.attention_mask is not None: self.batch_size = self.attention_mask.size(0) else: - raise NotImplementedError("code is not supposed to reach here") + raise NotImplementedError(_ERROR_MESSAGE) return self.batch_size @@ -139,7 +142,7 @@ def get_causal_mask( elif Q == 1: causal_mask = attention_mask[:, None, ...].to(dtype=torch.bool, device=device) 
else: - raise NotImplementedError("code is not expected to reach here") + raise NotImplementedError(_ERROR_MESSAGE) causal_mask = causal_mask[:, None, ...] causal_mask = torch.where(causal_mask, ~causal_mask, AttentionMaskInfo._get_mask_value(device, dtype)) From 9b349c6538b5dd23c493f5f62c3dd25760fc7550 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 09:03:06 -0700 Subject: [PATCH 085/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/{modeling_utils => }/mask.py | 6 +++--- lm_engine/hf_models/mixins/dense/layer.py | 5 ++++- .../modeling_utils/sequence_mixer_blocks/gru.py | 11 +++++------ .../modeling_utils/sequence_mixer_blocks/rnn.py | 11 +++++------ 4 files changed, 17 insertions(+), 16 deletions(-) rename lm_engine/hf_models/{modeling_utils => }/mask.py (98%) diff --git a/lm_engine/hf_models/modeling_utils/mask.py b/lm_engine/hf_models/mask.py similarity index 98% rename from lm_engine/hf_models/modeling_utils/mask.py rename to lm_engine/hf_models/mask.py index 3aab99baf..29c9785a6 100644 --- a/lm_engine/hf_models/modeling_utils/mask.py +++ b/lm_engine/hf_models/mask.py @@ -9,9 +9,9 @@ import torch import torch.nn.functional as F -from ...enums import Kernel -from ...kernels import is_kernel_allowed -from ...utils import is_fma_available +from ..enums import Kernel +from ..kernels import is_kernel_allowed +from ..utils import is_fma_available if is_fma_available(): diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index a4aafcbe4..aa89d82cf 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -9,6 +9,7 @@ from ...cache import GenerationCache from ...config import CommonConfig +from ...mask import AttentionMaskInfo from ...modeling_utils import get_mlp_block, get_normalization_function, get_sequence_mixer @@ -91,7 +92,9 @@ def _sequence_mixer_forward( ) elif self.sequence_mixer_type in ["gru", "rnn"]: hidden_states = self.sequence_mixer( - x=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, cache_params=past_key_values + x=hidden_states, + attention_mask_info=AttentionMaskInfo(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen), + cache_params=past_key_values, ) else: raise ValueError(f"unexpected sequence_mixer_type ({self.sequence_mixer_type})") diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index 41eb9fb7d..363fde57f 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -14,6 +14,7 @@ from ....kernels import is_kernel_allowed from ....utils import divide_if_divisible, is_fma_available from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay from ..linear import ParameterizedLinear from ..normalization import get_normalization_function @@ -76,11 +77,7 @@ def __init__( mark_parameter_as_no_weight_decay(self.state_weight) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - cache_params: GenerationCache | None = None, + self, x: torch.Tensor, attention_mask_info: AttentionMaskInfo, cache_params: GenerationCache | None = None ) -> torch.Tensor: x = self.input_projection(x) x, g = x.split((3 * self.state_size, self.state_size), dim=-1) @@ -97,6 +94,8 @@ def forward( 
weight, forget_weight, reset_weight = weight.chunk(3, dim=0) + cu_seqlens = attention_mask_info.get_cu_seqlens() + x = gru( input=x, weight=weight, @@ -107,7 +106,7 @@ def forward( input_state=None if cache_params is None else cache_params.get_cache(self.layer_idx), gradient_clipping=self.gradient_clipping, cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + max_seqlen=attention_mask_info.get_max_seqlen(), kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.gru) else KernelBackend.torch, ) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 31c198a80..5528d96a7 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -14,6 +14,7 @@ from ....kernels import is_kernel_allowed from ....utils import divide_if_divisible, is_fma_available from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay from ..linear import ParameterizedLinear from ..normalization import get_normalization_function @@ -76,11 +77,7 @@ def __init__( mark_parameter_as_no_weight_decay(self.state_weight) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - cache_params: GenerationCache | None = None, + self, x: torch.Tensor, attention_mask_info: AttentionMaskInfo, cache_params: GenerationCache | None = None ) -> torch.Tensor: x = self.input_projection(x) x, g = x.chunk(2, dim=-1) @@ -93,13 +90,15 @@ def forward( if self.scaling_factor != 1: weight = weight * self.scaling_factor + cu_seqlens = attention_mask_info.get_cu_seqlens() + x = rnn( input=x, weight=weight, input_state=None if cache_params is None else cache_params.get_cache(self.layer_idx), gradient_clipping=self.gradient_clipping, cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + max_seqlen=attention_mask_info.get_max_seqlen(), kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.rnn) else KernelBackend.torch, ) From df27f3250e6328cb6f983aef2d91006b1f1e9ee8 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 09:07:30 -0700 Subject: [PATCH 086/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 98 ++--------------------- lm_engine/hf_models/mixins/dense/layer.py | 16 +--- 2 files changed, 12 insertions(+), 102 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 66d6f8c6c..c11ebe94f 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -8,10 +8,9 @@ import torch.nn as nn from transformers import GenerationConfig, PreTrainedModel -from ....enums import Kernel -from ....kernels import is_kernel_allowed from ...cache import GenerationCache from ...config import CommonConfig +from ...mask import AttentionMaskInfo from ...modeling_utils import ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function from ...utils import is_generation_cache_enabled from ..modeling_outputs import BaseModelOutputWithPast @@ -98,7 +97,6 @@ def forward( ( use_cache, hidden_states, - causal_mask, position_ids, rope_cos_sin, past_key_values, @@ -117,20 +115,22 @@ def forward( GenerationCache(self.config) if use_cache and past_key_values is None else past_key_values ) - mamba_mask = None mamba_mask_computed = False for 
sequence_mixer_type, block in zip(self.sequence_mixer_block_types, self.h): is_linear_layer = sequence_mixer_type in ["mamba2", "rnn", "gru"] if is_linear_layer and not mamba_mask_computed: - mamba_mask = self._get_mamba_mask(attention_mask, past_key_values) + self._get_mamba_mask(attention_mask, past_key_values) mamba_mask_computed = True hidden_states: torch.Tensor = block( hidden_states, + attention_mask_info=AttentionMaskInfo( + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask + ), past_key_values=past_key_values, - attention_mask=mamba_mask if is_linear_layer else causal_mask, + # attention_mask=mamba_mask if is_linear_layer else causal_mask, rope_cos_sin=rope_cos_sin, ) @@ -162,47 +162,6 @@ def _get_rope_cos_sin( sin = sin[position_ids].unsqueeze(1) return cos, sin - def _prepare_causal_attention_mask( - self, - attention_mask: torch.Tensor | None, - batch_size: int, - query_length: int, - key_length: int, - device: torch.device, - ) -> torch.Tensor: - past_length = key_length - query_length - - if query_length > 1: - # (query_length, key_length) - causal_mask = torch.empty((query_length, key_length), dtype=torch.bool, device=device) - causal_mask[:, past_length:] = torch.tril( - torch.ones(query_length, query_length, dtype=torch.bool, device=device) - ) - - if past_length > 0: - causal_mask[:, :past_length] = True - - # (query_length, key_length) -> (1, query_length, key_length) - causal_mask = causal_mask.unsqueeze(0) - - if attention_mask is None: - # (1, query_length, key_length) -> (batch_size, query_length, key_length) - causal_mask = causal_mask.expand(batch_size, -1, -1) - else: - # (1, query_length, key_length) & (batch_size, 1, key_length) -> (batch_size, query_length, key_length) - causal_mask = causal_mask & attention_mask.unsqueeze(1).to(torch.bool) - else: - if attention_mask is None: - # (batch_size, query_length, key_length) - causal_mask = torch.ones(batch_size, query_length, key_length, dtype=torch.bool, device=device) - else: - # (batch_size, query_length, key_length) - causal_mask = attention_mask.unsqueeze(1).to(dtype=torch.bool, device=device) - - causal_mask = causal_mask.unsqueeze(1) - - return causal_mask - def _get_initial_hidden_state(self, input_ids: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor: hidden_state = self.wte(input_ids) @@ -231,9 +190,9 @@ def _prepare_a_bunch_of_stuff( if cu_seqlens is None: assert input_ids.dim() == 2 - B = input_ids.size(0) + input_ids.size(0) else: - B = cu_seqlens.size(0) - 1 + cu_seqlens.size(0) - 1 if self.use_padding_free_transformer: assert position_ids is not None, ( @@ -253,14 +212,9 @@ def _prepare_a_bunch_of_stuff( hidden_states = self._get_initial_hidden_state(input_ids, position_ids) rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) - attention_mask = self._get_maybe_causal_mask( - attention_mask, B, query_length, key_length, hidden_states.dtype, hidden_states.device - ) - return ( use_cache, hidden_states, - attention_mask, position_ids, rope_cos_sin, past_key_values, @@ -291,42 +245,6 @@ def _setup_positional_encoding(self) -> None: else: raise NotImplementedError() - def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: - # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
- if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: - self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) - return self.mask_value - - def _get_maybe_causal_mask( - self, - attention_mask: torch.Tensor | None, - batch_size: int, - query_length: int, - key_length: int, - dtype: torch.dtype, - device: torch.device, - ) -> torch.Tensor: - if not (is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3)): - # we use the causal/non-causal argument of SDPA for attention in this case - if attention_mask is not None: - attention_mask = self._prepare_causal_attention_mask( - attention_mask, batch_size, query_length, key_length, device - ) - - attention_mask = torch.where( - attention_mask, - ~attention_mask, - self._get_mask_value(attention_mask.device, dtype), - ) - - # this is needed to prevent NaN since SDPA - # see issue: https://github.com/pytorch/pytorch/issues/110213 - attention_mask = attention_mask * ~torch.all( - attention_mask == self._get_mask_value(attention_mask.device, dtype), dim=-1, keepdim=True - ) - - return attention_mask - def _get_mamba_mask( self, attention_mask: torch.Tensor | None, past_key_values: GenerationCache ) -> torch.Tensor | None: diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index aa89d82cf..6c5ba6a04 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -33,11 +33,9 @@ def __init__(self, config: CommonConfig, layer_idx: int | None = None) -> Block: def forward( self, hidden_states: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) @@ -45,10 +43,8 @@ def forward( hidden_states = self._sequence_mixer_forward( hidden_states=hidden_states, past_key_values=past_key_values, - attention_mask=attention_mask, + attention_mask_info=attention_mask_info, rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) if self.m_residual is not None: @@ -71,11 +67,9 @@ def forward( def _sequence_mixer_forward( self, hidden_states: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: if self.sequence_mixer_type in ["softmax_attention", "multihead_latent_attention"]: hidden_states = self.sequence_mixer( @@ -92,9 +86,7 @@ def _sequence_mixer_forward( ) elif self.sequence_mixer_type in ["gru", "rnn"]: hidden_states = self.sequence_mixer( - x=hidden_states, - attention_mask_info=AttentionMaskInfo(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen), - cache_params=past_key_values, + x=hidden_states, attention_mask_info=attention_mask_info, cache_params=past_key_values ) else: raise ValueError(f"unexpected sequence_mixer_type ({self.sequence_mixer_type})") From 21e0ba9aeb33eaa94bd71dd84e5157c81fd24f09 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 09:07:56 -0700 Subject: [PATCH 087/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 7 ++++--- 1 file changed, 4 
insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index c11ebe94f..7cfbe36c6 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -116,6 +116,9 @@ def forward( ) mamba_mask_computed = False + attention_mask_info = AttentionMaskInfo( + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask + ) for sequence_mixer_type, block in zip(self.sequence_mixer_block_types, self.h): is_linear_layer = sequence_mixer_type in ["mamba2", "rnn", "gru"] @@ -126,9 +129,7 @@ def forward( hidden_states: torch.Tensor = block( hidden_states, - attention_mask_info=AttentionMaskInfo( - cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask - ), + attention_mask_info=attention_mask_info, past_key_values=past_key_values, # attention_mask=mamba_mask if is_linear_layer else causal_mask, rope_cos_sin=rope_cos_sin, From 976a9ba62edf432d2058a466fa3cee9c6f7ce780 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 09:09:02 -0700 Subject: [PATCH 088/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 7cfbe36c6..37e95a731 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -36,9 +36,7 @@ def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixi assert self.config_class is not None self.generation_config = GenerationConfig.from_model_config(self.config) - self.use_padding_free_transformer = kwargs.get("use_padding_free_transformer", False) self._tied_word_embeddings = config.tie_word_embeddings - self._has_mamba2 = any([block.sequence_mixer_type == "mamba2" for block in self.config.sequence_mixer_blocks]) def _init_weights(self, module: nn.Module) -> None: @@ -66,12 +64,7 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.embedding_dropout = ( nn.Identity() if config.embedding_dropout == 0 else nn.Dropout(config.embedding_dropout) ) - self.h = nn.ModuleList( - [ - self.layer_class(config, use_padding_free_transformer=self.use_padding_free_transformer, layer_idx=i) - for i in range(config.num_layers) - ] - ) + self.h = nn.ModuleList([self.layer_class(config, layer_idx=i) for i in range(config.num_layers)]) self.ln_f = get_normalization_function( config.normalization_function, self.embed_dim, eps=config.layer_norm_epsilon ) @@ -186,20 +179,15 @@ def _prepare_a_bunch_of_stuff( position_ids: torch.Tensor | None = None, use_cache: bool | None = None, ) -> tuple[bool, torch.Tensor, torch.Tensor, torch.Tensor | None, GenerationCache | None]: - if use_cache is None: - use_cache = False if self.use_padding_free_transformer else self.config.use_cache - if cu_seqlens is None: assert input_ids.dim() == 2 input_ids.size(0) else: cu_seqlens.size(0) - 1 - if self.use_padding_free_transformer: - assert position_ids is not None, ( - "GPTBaseModel needs position_ids from outside when using flash attention with List[List[int]] " - "inputs" - ) + assert position_ids is not None, ( + "GPTBaseModel needs position_ids from outside when using flash attention with List[List[int]] " "inputs" + ) past_length = None query_length = None From 1d949d5b04f4aa8fdeca64ced485c094fb7158c2 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 
09:10:40 -0700 Subject: [PATCH 089/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 37e95a731..a3f285bed 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -95,7 +95,6 @@ def forward( past_key_values, ) = self._prepare_a_bunch_of_stuff( input_ids=input_ids, - cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, past_key_values=past_key_values, attention_mask=attention_mask, @@ -172,19 +171,12 @@ def _get_initial_hidden_state(self, input_ids: torch.Tensor, position_ids: torch def _prepare_a_bunch_of_stuff( self, input_ids: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, use_cache: bool | None = None, ) -> tuple[bool, torch.Tensor, torch.Tensor, torch.Tensor | None, GenerationCache | None]: - if cu_seqlens is None: - assert input_ids.dim() == 2 - input_ids.size(0) - else: - cu_seqlens.size(0) - 1 - assert position_ids is not None, ( "GPTBaseModel needs position_ids from outside when using flash attention with List[List[int]] " "inputs" ) From 24ee79a14056d27fa6480813df8efec30d4c4932 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 09:12:39 -0700 Subject: [PATCH 090/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index a3f285bed..b2f4f82c9 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -87,16 +87,9 @@ def forward( position_ids: torch.Tensor | None = None, use_cache: bool | None = None, ) -> BaseModelOutputWithPast: - ( - use_cache, - hidden_states, - position_ids, - rope_cos_sin, - past_key_values, - ) = self._prepare_a_bunch_of_stuff( + use_cache, hidden_states, position_ids, rope_cos_sin = self._prepare_a_bunch_of_stuff( input_ids=input_ids, max_seqlen=max_seqlen, - past_key_values=past_key_values, attention_mask=attention_mask, position_ids=position_ids, use_cache=use_cache, @@ -172,11 +165,10 @@ def _prepare_a_bunch_of_stuff( self, input_ids: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, - past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, use_cache: bool | None = None, - ) -> tuple[bool, torch.Tensor, torch.Tensor, torch.Tensor | None, GenerationCache | None]: + ) -> tuple[bool, torch.Tensor, torch.Tensor, torch.Tensor | None]: assert position_ids is not None, ( "GPTBaseModel needs position_ids from outside when using flash attention with List[List[int]] " "inputs" ) @@ -193,13 +185,7 @@ def _prepare_a_bunch_of_stuff( hidden_states = self._get_initial_hidden_state(input_ids, position_ids) rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) - return ( - use_cache, - hidden_states, - position_ids, - rope_cos_sin, - past_key_values, - ) + return use_cache, hidden_states, position_ids, rope_cos_sin def _setup_positional_encoding(self) -> None: max_position_embeddings = self.config.max_position_embeddings From 
be7104e2e2110b1dfdacd903e23760c0fd0d8d36 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 09:13:28 -0700 Subject: [PATCH 091/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index b2f4f82c9..ced3e8480 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -100,23 +100,15 @@ def forward( GenerationCache(self.config) if use_cache and past_key_values is None else past_key_values ) - mamba_mask_computed = False attention_mask_info = AttentionMaskInfo( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask ) - for sequence_mixer_type, block in zip(self.sequence_mixer_block_types, self.h): - is_linear_layer = sequence_mixer_type in ["mamba2", "rnn", "gru"] - - if is_linear_layer and not mamba_mask_computed: - self._get_mamba_mask(attention_mask, past_key_values) - mamba_mask_computed = True - + for block in self.h: hidden_states: torch.Tensor = block( hidden_states, attention_mask_info=attention_mask_info, past_key_values=past_key_values, - # attention_mask=mamba_mask if is_linear_layer else causal_mask, rope_cos_sin=rope_cos_sin, ) From 09ea1ef50c733b9be012caf4d5844b0fac4addf4 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 09:16:32 -0700 Subject: [PATCH 092/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 21 +++------------------ lm_engine/hf_models/mixins/dense/main.py | 11 +++++++---- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index ced3e8480..546ce9679 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -79,18 +79,15 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: def forward( self, - input_ids: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - attention_mask: torch.Tensor | None = None, + input_ids: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, position_ids: torch.Tensor | None = None, use_cache: bool | None = None, ) -> BaseModelOutputWithPast: use_cache, hidden_states, position_ids, rope_cos_sin = self._prepare_a_bunch_of_stuff( input_ids=input_ids, - max_seqlen=max_seqlen, - attention_mask=attention_mask, + max_seqlen=attention_mask_info.get_max_seqlen(), position_ids=position_ids, use_cache=use_cache, ) @@ -100,10 +97,6 @@ def forward( GenerationCache(self.config) if use_cache and past_key_values is None else past_key_values ) - attention_mask_info = AttentionMaskInfo( - cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask - ) - for block in self.h: hidden_states: torch.Tensor = block( hidden_states, @@ -157,7 +150,6 @@ def _prepare_a_bunch_of_stuff( self, input_ids: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, use_cache: bool | None = None, ) -> tuple[bool, torch.Tensor, torch.Tensor, torch.Tensor | None]: @@ -165,15 +157,8 @@ def _prepare_a_bunch_of_stuff( "GPTBaseModel needs position_ids from outside when using flash attention with List[List[int]] " "inputs" ) - past_length = None - query_length = 
None key_length = max_seqlen - if position_ids is None: - position_ids = self._get_position_ids( - attention_mask, past_length, query_length, key_length, input_ids.device - ) - hidden_states = self._get_initial_hidden_state(input_ids, position_ids) rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 1bfc2d62e..b34fe064f 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -13,6 +13,7 @@ from ...cache import GenerationCache from ...config import CommonConfig from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero +from ...mask import AttentionMaskInfo from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear from ..modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .base import PreTrainedModelMixin @@ -78,11 +79,13 @@ def forward( clear_aux_loss() + attention_mask_info = AttentionMaskInfo( + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask + ) + transformer_outputs: BaseModelOutputWithPast = self.transformer( - input_ids, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - attention_mask=attention_mask, + input_ids=input_ids, + attention_mask_info=attention_mask_info, past_key_values=past_key_values, position_ids=position_ids, use_cache=use_cache, From 836900fa2e0e5b12fc907b27c5ea391edf176b7c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 09:19:03 -0700 Subject: [PATCH 093/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 33 +++--------------------- lm_engine/hf_models/mixins/dense/main.py | 7 ++++- 2 files changed, 9 insertions(+), 31 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 546ce9679..4e1aa22e6 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -12,7 +12,6 @@ from ...config import CommonConfig from ...mask import AttentionMaskInfo from ...modeling_utils import ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function -from ...utils import is_generation_cache_enabled from ..modeling_outputs import BaseModelOutputWithPast from .layer import Block @@ -83,20 +82,12 @@ def forward( attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, position_ids: torch.Tensor | None = None, - use_cache: bool | None = None, ) -> BaseModelOutputWithPast: - use_cache, hidden_states, position_ids, rope_cos_sin = self._prepare_a_bunch_of_stuff( - input_ids=input_ids, - max_seqlen=attention_mask_info.get_max_seqlen(), - position_ids=position_ids, - use_cache=use_cache, + hidden_states = self._get_initial_hidden_state(input_ids, position_ids) + rope_cos_sin = self._get_rope_cos_sin( + attention_mask_info.get_max_seqlen(), position_ids, dtype=hidden_states.dtype ) - if is_generation_cache_enabled(): - past_key_values = ( - GenerationCache(self.config) if use_cache and past_key_values is None else past_key_values - ) - for block in self.h: hidden_states: torch.Tensor = block( hidden_states, @@ -146,24 +137,6 @@ def _get_initial_hidden_state(self, input_ids: torch.Tensor, position_ids: torch return hidden_state - def _prepare_a_bunch_of_stuff( - self, - input_ids: torch.Tensor | None = None, - max_seqlen: torch.Tensor | None = None, - position_ids: torch.Tensor 
| None = None, - use_cache: bool | None = None, - ) -> tuple[bool, torch.Tensor, torch.Tensor, torch.Tensor | None]: - assert position_ids is not None, ( - "GPTBaseModel needs position_ids from outside when using flash attention with List[List[int]] " "inputs" - ) - - key_length = max_seqlen - - hidden_states = self._get_initial_hidden_state(input_ids, position_ids) - rope_cos_sin = self._get_rope_cos_sin(key_length, position_ids, dtype=hidden_states.dtype) - - return use_cache, hidden_states, position_ids, rope_cos_sin - def _setup_positional_encoding(self) -> None: max_position_embeddings = self.config.max_position_embeddings diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index b34fe064f..6123b3202 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -15,6 +15,7 @@ from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero from ...mask import AttentionMaskInfo from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear +from ...utils import is_generation_cache_enabled from ..modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .base import PreTrainedModelMixin @@ -83,12 +84,16 @@ def forward( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask ) + if is_generation_cache_enabled(): + past_key_values = ( + GenerationCache(self.config) if use_cache and past_key_values is None else past_key_values + ) + transformer_outputs: BaseModelOutputWithPast = self.transformer( input_ids=input_ids, attention_mask_info=attention_mask_info, past_key_values=past_key_values, position_ids=position_ids, - use_cache=use_cache, ) hidden_states = transformer_outputs.last_hidden_state From 6ca483d680fc518cd1153d44fb059e708b82320d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 09:27:33 -0700 Subject: [PATCH 094/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/arguments.py | 2 -- lm_engine/model_wrapper/base.py | 1 - 2 files changed, 3 deletions(-) diff --git a/lm_engine/arguments.py b/lm_engine/arguments.py index e9ac14641..549e9c66b 100644 --- a/lm_engine/arguments.py +++ b/lm_engine/arguments.py @@ -49,8 +49,6 @@ class ModelArgs(BaseArgs): model_class: str = None # trust remote code for models that are not directly supported by HuggingFace yet trust_remote_code: bool = False - # whether to use padding free transformer: https://huggingface.co/blog/mayank-mishra/padding-free-transformer - use_padding_free_transformer: bool = False # use lower memory to initialize model efficient_initialization: bool = False # whether to reset attention masks for pretraining diff --git a/lm_engine/model_wrapper/base.py b/lm_engine/model_wrapper/base.py index 6f80d6315..47f8da9ac 100644 --- a/lm_engine/model_wrapper/base.py +++ b/lm_engine/model_wrapper/base.py @@ -142,7 +142,6 @@ def _get_model_kwargs(self) -> dict: "flash_attention_2" if is_kernel_allowed(Kernel.flash_attention_2) else "sdpa" ) - model_kwargs["use_padding_free_transformer"] = True if self.sequence_parallel: model_kwargs["sequence_parallel"] = True if self.trust_remote_code: From 4ebee0d09e78911f28629d5350673fb828265146 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 09:29:13 -0700 Subject: [PATCH 095/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/models/gpt_crosslayer/layer.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git 
a/lm_engine/hf_models/models/gpt_crosslayer/layer.py b/lm_engine/hf_models/models/gpt_crosslayer/layer.py index 4c8a663b9..3e284ae14 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/layer.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/layer.py @@ -17,9 +17,7 @@ class GPTCrossLayerBlock(nn.Module): - def __init__( - self, config: GPTCrossLayerConfig, use_padding_free_transformer: bool, layer_idx: int - ) -> GPTCrossLayerBlock: + def __init__(self, config: GPTCrossLayerConfig, layer_idx: int) -> GPTCrossLayerBlock: super().__init__() hidden_size = config.hidden_size @@ -30,8 +28,6 @@ def __init__( self.head_dim = divide_if_divisible(hidden_size, self.num_heads, "") self.num_key_value_heads = config.sequence_mixer_blocks[layer_idx].num_key_value_heads - self.use_padding_free_transformer = use_padding_free_transformer - self.kv_proj = None if config.sharing_pattern[layer_idx] == layer_idx: self.kv_proj = KeyValueProjection( @@ -42,13 +38,12 @@ def __init__( initializer_range=config.initializer_range, normalization_function=config.normalization_function, layer_norm_epsilon=config.layer_norm_epsilon, - use_padding_free_transformer=use_padding_free_transformer, ) self.ln_1 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) - self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) + self.sequence_mixer = get_sequence_mixer(config, True, layer_idx) self.ln_2 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) @@ -74,11 +69,6 @@ def forward( if past_key_values is not None: key, value = past_key_values.update(key_states=key, value_states=value, layer_idx=self.layer_idx) - if is_kernel_allowed(Kernel.flash_attention_3) or is_kernel_allowed(Kernel.flash_attention_2): - if not self.use_padding_free_transformer: - key = key.transpose(1, 2) - value = value.transpose(1, 2) - residual = hidden_states hidden_states = self.ln_1(hidden_states) From 753ade19bf38881002f8892b1839d3ea3cb4a1bb Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 10:36:13 -0700 Subject: [PATCH 096/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/attention.py | 8 +- .../sequence_mixer_blocks/utils.py | 99 ++++++++++ .../sequence_mixer_blocks/utils/__init__.py | 6 - .../utils/flash_attention_utils.py | 177 ------------------ .../sequence_mixer_blocks/utils/packing.py | 56 ------ .../sequence_mixer_blocks/attention.py | 11 +- .../gpt_crosslayer/sequence_mixers/base.py | 15 +- 7 files changed, 113 insertions(+), 259 deletions(-) create mode 100644 lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py delete mode 100644 lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/__init__.py delete mode 100644 lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py delete mode 100644 lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/packing.py diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 898ac18e5..8f98df6a2 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -14,6 +14,7 @@ from ....kernels import is_kernel_allowed, wait_for_ACT from ....utils import divide_if_divisible from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from 
...parameter import mark_parameter_as_mup_learning_rate from ..linear import ParameterizedLinear from ..position_embedding import apply_rotary_pos_emb @@ -135,11 +136,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) @@ -197,9 +197,7 @@ def forward( q=query, k=key, v=value, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - attention_mask=attention_mask, + attention_mask_info=attention_mask_info, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py new file mode 100644 index 000000000..84de2418a --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py @@ -0,0 +1,99 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +import torch + +from ....enums import Kernel +from ....kernels import is_kernel_allowed +from ....utils import is_flash_attention_2_available, is_flash_attention_3_available +from ...mask import AttentionMaskInfo + + +if is_flash_attention_2_available(): + from flash_attn import flash_attn_func as flash_attention_2 + from flash_attn import flash_attn_varlen_func as flash_attention_2_varlen + +if is_flash_attention_3_available(): + from flash_attn_interface import flash_attn_func as flash_attention_3 + from flash_attn_interface import flash_attn_varlen_func as flash_attention_3_varlen + + +def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask_info: AttentionMaskInfo, + causal: bool, + dropout: float = 0, + softmax_scale: float | None = None, + sliding_window: int | None = None, + softcap: float = 0, +) -> torch.Tensor: + use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) + + if use_flash_attention_3: + assert dropout == 0 + + window_size = (-1, -1) + if sliding_window is not None and k.size(1) > sliding_window: + window_size = (sliding_window, sliding_window) + + cu_seqlens = attention_mask_info.get_cu_seqlens() + max_seqlen = attention_mask_info.get_max_seqlen() + + if cu_seqlens is None: + assert q.dim() == 4 + + if use_flash_attention_3: + attn_output, _ = flash_attention_3( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + ) + else: + attn_output = flash_attention_2( + q=q, + k=k, + v=v, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + ) + else: + assert sliding_window is None + assert q.dim() == 3 + + if use_flash_attention_3: + attn_output, _ = flash_attention_3_varlen( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attention_2_varlen( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + 
dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/__init__.py deleted file mode 100644 index 3b7c41a83..000000000 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from .flash_attention_utils import flash_attention -from .packing import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py deleted file mode 100644 index 2ab1b1f6d..000000000 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py +++ /dev/null @@ -1,177 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -import torch - -from .....enums import Kernel -from .....kernels import is_kernel_allowed -from .....utils import is_flash_attention_2_available, is_flash_attention_3_available -from .packing import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence - - -if is_flash_attention_2_available(): - from flash_attn import flash_attn_func as flash_attention_2 - from flash_attn import flash_attn_varlen_func as flash_attention_2_varlen - -if is_flash_attention_3_available(): - from flash_attn_interface import flash_attn_func as flash_attention_3 - from flash_attn_interface import flash_attn_varlen_func as flash_attention_3_varlen - - -def unpad_input( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - cu_seqlens_k, max_seqlen_k = compute_cu_seqlens_and_max_seqlen_from_attention_mask(attention_mask) - batch_size, kv_seq_len = key.size()[:2] - - if query_length == kv_seq_len: - query, key, value = pack_sequence(inputs=(query, key, value), cu_seqlens=cu_seqlens_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_q = max_seqlen_k - else: - key, value = pack_sequence(inputs=(key, value), cu_seqlens=cu_seqlens_k) - - if query_length == 1: - # There is a memcpy here, that is very bad. - cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query.device) - query = query.squeeze(1) - key, value = pack_sequence(inputs=(key, value), cu_seqlens=cu_seqlens_k) - max_seqlen_q = 1 - else: - # The -q_len: slice assumes left padding. 
- attention_mask = attention_mask[:, -query_length:] - cu_seqlens_q, max_seqlen_q = compute_cu_seqlens_and_max_seqlen_from_attention_mask(attention_mask) - query = pack_sequence(inputs=query, cu_seqlens=cu_seqlens_q) - - return query, key, value, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k - - -def flash_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attention_mask: torch.Tensor | None, - cu_seqlens: torch.Tensor | None, - max_seqlen: int | None, - causal: bool, - dropout: float = 0, - softmax_scale: float | None = None, - sliding_window: int | None = None, - softcap: float = 0, -) -> torch.Tensor: - use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) - use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) - - if use_flash_attention_3: - assert dropout == 0 - - assert use_flash_attention_3 or use_flash_attention_2, "enable flash_attention_2 or flash_attention_3" - - window_size = (-1, -1) - if sliding_window is not None and key.size(1) > sliding_window: - window_size = (sliding_window, sliding_window) - - if cu_seqlens is None: - assert max_seqlen is None - assert q.dim() == 4 - - if attention_mask is None: - if use_flash_attention_3: - attn_output, _ = flash_attention_3( - q=q, - k=k, - v=v, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - ) - else: - attn_output = flash_attention_2( - q=q, - k=k, - v=v, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - ) - else: - batch_size, query_length, num_heads, head_dim = q.size() - - q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = unpad_input( - q, k, v, attention_mask, query_length - ) - - if use_flash_attention_3: - attn_output, _ = flash_attention_3_varlen( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - ) - else: - attn_output = flash_attention_2_varlen( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - ) - - attn_output = unpack_sequence( - inputs=attn_output, - cu_seqlens=cu_seqlens_q, - output_shape=(batch_size, query_length, num_heads, head_dim), - ) - else: - assert sliding_window is None - assert q.dim() == 3 - - if use_flash_attention_3: - attn_output, _ = flash_attention_3_varlen( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attention_2_varlen( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - return attn_output diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/packing.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/packing.py deleted file mode 100644 index 1ca34b5d9..000000000 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/packing.py +++ /dev/null @@ -1,56 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# 
************************************************** - -import torch -import torch.nn.functional as F - -from .....enums import Kernel -from .....kernels import is_kernel_allowed -from .....utils import is_fma_available - - -if is_fma_available(): - from fma import KernelBackend - from fma import pack_sequence as _pack_sequence - from fma import unpack_sequence as _unpack_sequence - - -def compute_cu_seqlens_and_max_seqlen_from_attention_mask( - attention_mask: torch.Tensor, -) -> tuple[torch.Tensor, int]: - seqlens = attention_mask.sum(dim=-1, dtype=torch.int32) - cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) - max_seqlen = seqlens.max().item() - return cu_seqlens, max_seqlen - - -def pack_sequence( - inputs: torch.Tensor | list[torch.Tensor], cu_seqlens: torch.Tensor -) -> torch.Tensor | list[torch.Tensor]: - kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) else KernelBackend.torch - - inputs = _pack_sequence( - inputs=inputs, - cu_seqlens=cu_seqlens, - kernel_backend_forward=kernel_backend, - kernel_backend_backward=kernel_backend, - ) - - return inputs - - -def unpack_sequence( - inputs: torch.Tensor | list[torch.Tensor], cu_seqlens: torch.Tensor, output_shape: tuple[int] -) -> torch.Tensor | list[torch.Tensor]: - kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch - - inputs = _unpack_sequence( - inputs=inputs, - cu_seqlens=cu_seqlens, - output_shape=output_shape, - kernel_backend_forward=kernel_backend, - kernel_backend_backward=kernel_backend, - ) - - return inputs diff --git a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py index 27b4b887e..67bcb0cad 100644 --- a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py @@ -14,6 +14,7 @@ from ....kernels import is_kernel_allowed, wait_for_ACT from ....utils import ProcessGroupManager, divide_if_divisible from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from ...modeling_utils import Attention, apply_rotary_pos_emb, flash_attention from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear from ..dropout import Dropout_TP @@ -127,11 +128,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) @@ -192,9 +191,7 @@ def forward( q=query, k=key, v=value, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - attention_mask=attention_mask, + attention_mask_info=attention_mask_info, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, @@ -205,6 +202,8 @@ def forward( hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) hidden_states = hidden_states.view(*output_shape) else: + attention_mask = attention_mask_info.get_attention_mask() + hidden_states = F.scaled_dot_product_attention( query, key, diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py 
b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py index 47be7f08d..c95299761 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py @@ -13,6 +13,7 @@ from .....enums import Kernel from .....kernels import is_kernel_allowed from .....utils import divide_if_divisible +from ....mask import AttentionMaskInfo from ....modeling_utils import ParameterizedLinear, apply_rotary_pos_emb, flash_attention, get_normalization_function @@ -77,12 +78,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, + attention_mask_info: AttentionMaskInfo, key: torch.Tensor, value: torch.Tensor, - attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> torch.Tensor: if is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3): query = self.q_attn(hidden_states) @@ -96,9 +95,7 @@ def forward( q=query, k=key, v=value, - attention_mask=attention_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + attention_mask_info=attention_mask_info, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, @@ -114,9 +111,7 @@ def forward( q=query, k=key, v=value, - attention_mask=attention_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + attention_mask_info=attention_mask_info, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, softmax_scale=self.attention_multiplier, @@ -128,6 +123,8 @@ def forward( if self.position_embedding_type == "rope": query = apply_rotary_pos_emb(query, rope_cos_sin) + attention_mask = attention_mask_info.get_attention_mask() + hidden_states = F.scaled_dot_product_attention( query, key, From e682a14c4d92d5988a32c0d9104ee89a4bce9dab Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 10:42:06 -0700 Subject: [PATCH 097/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- .../hf_models/models/gpt_crosslayer/layer.py | 11 ++-- .../base.py => sequence_mixer.py} | 51 ++++++++++--------- .../sequence_mixers/__init__.py | 26 ---------- 3 files changed, 30 insertions(+), 58 deletions(-) rename lm_engine/hf_models/models/gpt_crosslayer/{sequence_mixers/base.py => sequence_mixer.py} (83%) delete mode 100644 lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/__init__.py diff --git a/lm_engine/hf_models/models/gpt_crosslayer/layer.py b/lm_engine/hf_models/models/gpt_crosslayer/layer.py index 3e284ae14..be8c672ec 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/layer.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/layer.py @@ -7,10 +7,9 @@ import torch import torch.nn as nn -from ....enums import Kernel -from ....kernels import is_kernel_allowed from ....utils import divide_if_divisible from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from ...modeling_utils import apply_rotary_pos_emb, get_mlp_block, get_normalization_function from .config import GPTCrossLayerConfig from .sequence_mixers import KeyValueProjection, get_sequence_mixer @@ -54,11 +53,9 @@ def forward( hidden_states: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> 
torch.Tensor: if self.kv_proj is not None: key, value = self.kv_proj(hidden_states) @@ -76,10 +73,8 @@ def forward( hidden_states, key=key, value=value, - attention_mask=attention_mask, + attention_mask_info=attention_mask_info, rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) if self.m_residual is not None: diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py similarity index 83% rename from lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py rename to lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py index c95299761..91518910f 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/base.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py @@ -10,11 +10,12 @@ import torch.nn as nn import torch.nn.functional as F -from .....enums import Kernel -from .....kernels import is_kernel_allowed -from .....utils import divide_if_divisible -from ....mask import AttentionMaskInfo -from ....modeling_utils import ParameterizedLinear, apply_rotary_pos_emb, flash_attention, get_normalization_function +from ....enums import Kernel +from ....kernels import is_kernel_allowed +from ....utils import divide_if_divisible +from ...mask import AttentionMaskInfo +from ...modeling_utils import ParameterizedLinear, apply_rotary_pos_emb, flash_attention, get_normalization_function +from .config import GPTCrossLayerConfig class CrossLayerAttention(nn.Module): @@ -157,7 +158,6 @@ def __init__( initializer_range: float, normalization_function: str, layer_norm_epsilon: float, - use_padding_free_transformer: bool, ) -> KeyValueProjection: super().__init__() @@ -172,28 +172,31 @@ def __init__( std=initializer_range, ) - self.use_padding_free_transformer = use_padding_free_transformer - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden_states = self.ln(hidden_states) hidden_states = self.kv_attn(hidden_states) - if self.use_padding_free_transformer: - total_q = hidden_states.shape[0] - - if self.num_key_value_heads == 1: - hidden_states = hidden_states.unsqueeze(1) - else: - hidden_states = hidden_states.view(total_q, self.num_key_value_heads, -1) - else: - batch_size, query_length = hidden_states.shape[:2] - - if self.num_key_value_heads == 1: - hidden_states = hidden_states.unsqueeze(1) - else: - hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1) - hidden_states = hidden_states.transpose(1, 2) - + hidden_states = hidden_states.view(*hidden_states.size()[:-1], self.num_key_value_heads, -1) key, value = hidden_states.chunk(2, -1) return key, value + + +def get_sequence_mixer(config: GPTCrossLayerConfig, causal: bool, layer_idx: int) -> CrossLayerAttention: + block = config.sequence_mixer_blocks[layer_idx] + assert block.sequence_mixer_type == "softmax_attention" + + return CrossLayerAttention( + hidden_size=config.hidden_size, + num_attention_heads=block.num_attention_heads, + num_key_value_heads=block.num_key_value_heads, + attention_multiplier=block.attention_multiplier, + position_embedding_type=config.position_embedding_type, + add_bias=block.add_bias, + softmax_dropout=block.softmax_dropout, + dropout=block.dropout, + initializer_range=config.initializer_range, + num_layers=config.num_layers, + causal=causal, + layer_idx=layer_idx, + ) diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/__init__.py 
b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/__init__.py deleted file mode 100644 index a28161366..000000000 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixers/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from ..config import GPTCrossLayerConfig -from .base import CrossLayerAttention, KeyValueProjection - - -def get_sequence_mixer(config: GPTCrossLayerConfig, causal: bool, layer_idx: int) -> CrossLayerAttention: - block = config.sequence_mixer_blocks[layer_idx] - assert block.sequence_mixer_type == "softmax_attention" - - return CrossLayerAttention( - hidden_size=config.hidden_size, - num_attention_heads=block.num_attention_heads, - num_key_value_heads=block.num_key_value_heads, - attention_multiplier=block.attention_multiplier, - position_embedding_type=config.position_embedding_type, - add_bias=block.add_bias, - softmax_dropout=block.softmax_dropout, - dropout=block.dropout, - initializer_range=config.initializer_range, - num_layers=config.num_layers, - causal=causal, - layer_idx=layer_idx, - ) From c5825baba8e3083e4c13701638aa6f52de75c3b0 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 10:46:24 -0700 Subject: [PATCH 098/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- .../hf_models/models/gpt_crosslayer/base.py | 35 +++++-------------- .../models/gpt_crosslayer/sequence_mixer.py | 2 +- 2 files changed, 9 insertions(+), 28 deletions(-) diff --git a/lm_engine/hf_models/models/gpt_crosslayer/base.py b/lm_engine/hf_models/models/gpt_crosslayer/base.py index ba5e827f6..c6cdc919b 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/base.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/base.py @@ -7,6 +7,7 @@ import torch from ...cache import GenerationCache +from ...mask import AttentionMaskInfo from ...mixins import BaseModelMixin, BaseModelOutputWithPast, PreTrainedModelMixin from .config import GPTCrossLayerConfig from .layer import GPTCrossLayerBlock @@ -25,49 +26,29 @@ def __init__(self, config: GPTCrossLayerConfig, *args, **kwargs) -> GPTCrossLaye class GPTCrossLayerModel(GPTCrossLayerPreTrainedModel, BaseModelMixin): def forward( self, - input_ids: torch.Tensor | None = None, + input_ids: torch.Tensor, + attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, - use_cache: bool | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, ) -> BaseModelOutputWithPast: - ( - use_cache, - hidden_states, - attention_mask, - position_ids, - rope_cos_sin, - past_key_values, - ) = self._prepare_a_bunch_of_stuff( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + hidden_states = self._get_initial_hidden_state(input_ids, position_ids) + rope_cos_sin = self._get_rope_cos_sin( + attention_mask_info.get_max_seqlen(), position_ids, dtype=hidden_states.dtype ) - past_key_values = GenerationCache(self.config) if use_cache and past_key_values is None else past_key_values - key = None value = None for block in self.h: hidden_states, key, value = block( - hidden_states, + hidden_states=hidden_states, key=key, value=value, + attention_mask_info=attention_mask_info, past_key_values=past_key_values, - 
attention_mask=attention_mask, rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) - del key, value hidden_states = self.ln_f(hidden_states) return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values) diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py index 91518910f..d8aed0019 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py @@ -79,9 +79,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - attention_mask_info: AttentionMaskInfo, key: torch.Tensor, value: torch.Tensor, + attention_mask_info: AttentionMaskInfo, rope_cos_sin: torch.Tensor | None = None, ) -> torch.Tensor: if is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3): From 9af4136320a03780d3f06a151bb6b30f6ccc4ef5 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 10:47:47 -0700 Subject: [PATCH 099/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- .gitignore | 1 + lm_engine/hf_models/models/gpt_crosslayer/layer.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 9ad896e97..e9155ff15 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ __pycache__ /appwrapper.yaml *.egg-info/ build/ +*.log diff --git a/lm_engine/hf_models/models/gpt_crosslayer/layer.py b/lm_engine/hf_models/models/gpt_crosslayer/layer.py index be8c672ec..f2fb79856 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/layer.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/layer.py @@ -12,7 +12,7 @@ from ...mask import AttentionMaskInfo from ...modeling_utils import apply_rotary_pos_emb, get_mlp_block, get_normalization_function from .config import GPTCrossLayerConfig -from .sequence_mixers import KeyValueProjection, get_sequence_mixer +from .sequence_mixer import KeyValueProjection, get_sequence_mixer class GPTCrossLayerBlock(nn.Module): From 08d3bfe905c5e2f420c333299ae874ff13077f91 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 14:25:46 -0700 Subject: [PATCH 100/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 6123b3202..cde8fa4e2 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -80,8 +80,8 @@ def forward( clear_aux_loss() - attention_mask_info = AttentionMaskInfo( - cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask + attention_mask_info = self._get_attention_mask_info( + x=input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask ) if is_generation_cache_enabled(): @@ -260,3 +260,19 @@ def generate( ) return generated_tokens + + def _get_attention_mask_info( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor | None, + max_seqlen: torch.Tensor, + attention_mask: torch.Tensor | None, + ) -> AttentionMaskInfo: + if cu_seqlens is None: + attention_mask_info = AttentionMaskInfo( + batch_size=x.size(0), max_seqlen=x.size(1), attention_mask=attention_mask + ) + else: + attention_mask_info = AttentionMaskInfo(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + + return attention_mask_info From ee3862deeaa3bc602f791f05cf0c851c382e3ed5 Mon Sep 17 
00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 17:08:39 -0700 Subject: [PATCH 101/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 54 +++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 29c9785a6..2312f1e3f 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -61,39 +61,59 @@ class AttentionMaskInfo: cu_seqlens: torch.Tensor | None = None max_seqlen: int | None = None attention_mask: torch.Tensor | None = None - causal_mask: torch.Tensor | None = None + _causal_mask: torch.Tensor | None = None + device: torch.device | None = None + + def __post_init__(self) -> None: + self._is_ragged = self.cu_seqlens is not None + + if self.batch_size is None: + assert self.max_seqlen is None + assert self.cu_seqlens is not None or self.attention_mask is not None def get_batch_size(self) -> int: - if self.batch_size is not None: - return self.batch_size - - if self.cu_seqlens is not None: - self.batch_size = self.cu_seqlens.size(0) - 1 - elif self.attention_mask is not None: - self.batch_size = self.attention_mask.size(0) - else: - raise NotImplementedError(_ERROR_MESSAGE) + if self.batch_size is None: + if self._is_ragged: + self.batch_size = self.cu_seqlens.size(0) - 1 + elif self.attention_mask is not None: + self.batch_size = self.attention_mask.size(0) + else: + raise NotImplementedError(_ERROR_MESSAGE) return self.batch_size def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor | None: - if return_none_allowed: + if self._is_ragged: return self.cu_seqlens + if return_none_allowed: + return None + if self.cu_seqlens is None: - seqlens = self.attention_mask.sum(dim=-1, dtype=torch.int32) - self.cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) - self.max_seqlen = seqlens.max().item() + if self.attention_mask is None: + B = self.get_batch_size() + S = self.get_max_seqlen(False) + + self.cu_seqlens = torch.arange(0, B * S, S, dtype=torch.int32, device=self.device) + else: + seqlens = self.attention_mask.sum(dim=-1, dtype=torch.int32) + self.cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) + self.max_seqlen = seqlens.max().item() return self.cu_seqlens def get_max_seqlen(self, return_none_allowed: bool = True) -> int | None: - if return_none_allowed: + if self._is_ragged: return self.max_seqlen + if return_none_allowed: + return None + + # this will cache the max_seqlen + self.get_cu_seqlens(False) + if self.max_seqlen is None: - # this will cache the max_seqlen - self.get_cu_seqlens(False) + raise NotImplementedError(_ERROR_MESSAGE) return self.max_seqlen From 21a395898b2cfe4a47db658b0360c92f27a0b8d9 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 17:17:38 -0700 Subject: [PATCH 102/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 40 +++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 2312f1e3f..af96f16a7 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -67,9 +67,23 @@ class AttentionMaskInfo: def __post_init__(self) -> None: self._is_ragged = self.cu_seqlens is not None - if self.batch_size is None: - assert self.max_seqlen is None - assert self.cu_seqlens is not None or self.attention_mask is not None + if self.batch_size is not None: + assert 
self.max_seqlen is not None + assert self.cu_seqlens is None + assert self.attention_mask is None + assert self.device is None + elif self.cu_seqlens is not None: + assert self.batch_size is None + assert self.max_seqlen is not None + assert self.attention_mask is None + self.device = self.cu_seqlens.device + elif self.attention_mask is not None: + assert self.batch_size is None + assert self.cu_seqlens is None + assert self.max_seqlen is not None + self.device = self.attention_mask.device + + assert self.device is not None def get_batch_size(self) -> int: if self.batch_size is None: @@ -117,9 +131,7 @@ def get_max_seqlen(self, return_none_allowed: bool = True) -> int | None: return self.max_seqlen - def get_attention_mask( - self, return_none_allowed: bool = True, device: torch.device | None = None - ) -> torch.Tensor | None: + def get_attention_mask(self, return_none_allowed: bool = True) -> torch.Tensor | None: if return_none_allowed: return self.attention_mask @@ -130,9 +142,9 @@ def get_attention_mask( assert max_seqlen is not None if cu_seqlens is None: - self.attention_mask = torch.ones(batch_size, max_seqlen, device=device, dtype=torch.int32) + self.attention_mask = torch.ones(batch_size, max_seqlen, device=self.device, dtype=torch.int32) else: - attention_mask_flat = torch.ones_like(cu_seqlens, device=device, dtype=torch.int32) + attention_mask_flat = torch.ones_like(cu_seqlens, device=self.device, dtype=torch.int32) self.attention_mask = unpack_sequence( inputs=attention_mask_flat, cu_seqlens=cu_seqlens, output_shape=(batch_size, max_seqlen) ) @@ -148,11 +160,9 @@ def get_causal_mask( _, Q, K = attention_mask.size() L = K - Q - device = attention_mask.device - if Q > 1: - causal_mask = torch.empty((Q, K), dtype=torch.bool, device=device) - causal_mask[:, L:] = torch.tril(torch.ones(Q, K, dtype=torch.bool, device=device)) + causal_mask = torch.empty((Q, K), dtype=torch.bool, device=self.device) + causal_mask[:, L:] = torch.tril(torch.ones(Q, K, dtype=torch.bool, device=self.device)) if L > 0: causal_mask[:, :L] = True @@ -160,17 +170,17 @@ def get_causal_mask( causal_mask = causal_mask[None, ...] causal_mask = causal_mask & attention_mask[:, None, ...].to(torch.bool) elif Q == 1: - causal_mask = attention_mask[:, None, ...].to(dtype=torch.bool, device=device) + causal_mask = attention_mask[:, None, ...].to(dtype=torch.bool, device=self.device) else: raise NotImplementedError(_ERROR_MESSAGE) causal_mask = causal_mask[:, None, ...] 
- causal_mask = torch.where(causal_mask, ~causal_mask, AttentionMaskInfo._get_mask_value(device, dtype)) + causal_mask = torch.where(causal_mask, ~causal_mask, AttentionMaskInfo._get_mask_value(self.device, dtype)) # this is needed to prevent NaN since SDPA # see issue: https://github.com/pytorch/pytorch/issues/110213 causal_mask = causal_mask * ~torch.all( - causal_mask == AttentionMaskInfo._get_mask_value(device, dtype), dim=-1, keepdim=True + causal_mask == AttentionMaskInfo._get_mask_value(self.device, dtype), dim=-1, keepdim=True ) return attention_mask From 55bb8f3bf3db8971f704488598d6e8e4e8a082d9 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 17:26:28 -0700 Subject: [PATCH 103/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index cde8fa4e2..147a20988 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -268,11 +268,16 @@ def _get_attention_mask_info( max_seqlen: torch.Tensor, attention_mask: torch.Tensor | None, ) -> AttentionMaskInfo: + kwargs = {} if cu_seqlens is None: - attention_mask_info = AttentionMaskInfo( - batch_size=x.size(0), max_seqlen=x.size(1), attention_mask=attention_mask - ) + if attention_mask is None: + kwargs["batch_size"] = x.size(0) + kwargs["max_seqlen"] = x.size(1) + kwargs["device"] = x.device + else: + kwargs["attention_mask"] = attention_mask else: - attention_mask_info = AttentionMaskInfo(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + kwargs["cu_seqlens"] = cu_seqlens + kwargs["max_seqlen"] = max_seqlen - return attention_mask_info + return AttentionMaskInfo(**kwargs) From 09805683db838a658cc28b5579f7964be938b560 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 17:27:51 -0700 Subject: [PATCH 104/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index af96f16a7..798c81d20 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -71,7 +71,6 @@ def __post_init__(self) -> None: assert self.max_seqlen is not None assert self.cu_seqlens is None assert self.attention_mask is None - assert self.device is None elif self.cu_seqlens is not None: assert self.batch_size is None assert self.max_seqlen is not None From a90184fc8ceda801a0c394b2d00a0218130127bc Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 30 Sep 2025 17:34:55 -0700 Subject: [PATCH 105/177] cleanup model_wrappers Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 5 ++--- lm_engine/hf_models/mixins/dense/base.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 798c81d20..1a9435d16 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -122,10 +122,9 @@ def get_max_seqlen(self, return_none_allowed: bool = True) -> int | None: if return_none_allowed: return None - # this will cache the max_seqlen - self.get_cu_seqlens(False) - if self.max_seqlen is None: + # this will cache the max_seqlen + self.get_cu_seqlens(False) raise NotImplementedError(_ERROR_MESSAGE) return self.max_seqlen diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 4e1aa22e6..807ad0a49 100644 
--- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -85,7 +85,7 @@ def forward( ) -> BaseModelOutputWithPast: hidden_states = self._get_initial_hidden_state(input_ids, position_ids) rope_cos_sin = self._get_rope_cos_sin( - attention_mask_info.get_max_seqlen(), position_ids, dtype=hidden_states.dtype + attention_mask_info.get_max_seqlen(False), position_ids, dtype=hidden_states.dtype ) for block in self.h: From 06fb51ab6f1a425187b1600c3706f1fccdc7ae4f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 1 Oct 2025 00:49:00 -0700 Subject: [PATCH 106/177] rsa torch Signed-off-by: Mayank Mishra --- lm_engine/hf_models/loss.py | 2 +- lm_engine/hf_models/mixins/dense/main.py | 2 +- .../sequence_mixer_blocks/attention.py | 2 -- .../causal_convolution.py | 4 --- .../multihead_latent_attention.py | 2 -- .../models/gpt_crosslayer/sequence_mixer.py | 36 +++++++------------ 6 files changed, 14 insertions(+), 34 deletions(-) diff --git a/lm_engine/hf_models/loss.py b/lm_engine/hf_models/loss.py index d2ead6e46..c2a609ddc 100644 --- a/lm_engine/hf_models/loss.py +++ b/lm_engine/hf_models/loss.py @@ -26,7 +26,7 @@ def get_autoregressive_language_modeling_loss( hidden_states: torch.Tensor | None = None, vocab_weight: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None, - use_padding_free_transformer: bool = False, + use_padding_free_transformer: bool = True, reduction: str = "mean", shift_logits_and_labels: bool = True, tensor_parallel_enabled: bool = False, diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 147a20988..c999503de 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -126,7 +126,7 @@ def forward( hidden_states=None, vocab_weight=None, cu_seqlens=cu_seqlens, - use_padding_free_transformer=self.use_padding_free_transformer, + use_padding_free_transformer=True, reduction=reduction, shift_logits_and_labels=True, tensor_parallel_enabled=False, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 8f98df6a2..8ca4ce6cf 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -81,7 +81,6 @@ def __init__( num_layers: int, causal: bool, layer_idx: int, - use_padding_free_transformer: bool, ) -> Attention: super().__init__() @@ -91,7 +90,6 @@ def __init__( self.num_key_value_heads = num_key_value_heads self.add_bias = add_bias self.qkv_bias = qkv_bias - self.use_padding_free_transformer = use_padding_free_transformer self.sliding_window = sliding_window self.head_dim = divide_if_divisible( diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py index 6b8de7b21..99dd87cf7 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py @@ -147,13 +147,9 @@ def __init__( init_method: str, num_layers: int, layer_idx: int, - use_padding_free_transformer: bool, ) -> CausalConvolution: super().__init__() - if use_padding_free_transformer: - raise NotImplementedError() - self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size diff --git 
a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py index 6afba02bb..6dcb9e4d7 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py @@ -38,7 +38,6 @@ def __init__( num_layers: int, causal: bool, layer_idx: int, - use_padding_free_transformer: bool, normalization_function: str, layer_norm_epsilon: float = 1e-5, ) -> MultiHeadLatentAttention: @@ -49,7 +48,6 @@ def __init__( self.num_heads = num_attention_heads self.head_dim = head_dim self.add_bias = add_bias - self.use_padding_free_transformer = use_padding_free_transformer self.query_compression_size = query_compression_size self.key_value_compression_size = key_value_compression_size self.position_embedding_type = position_embedding_type diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py index d8aed0019..d2770e36a 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py @@ -88,37 +88,25 @@ def forward( query = self.q_attn(hidden_states) query = query.view(*hidden_states.size()[:-1], self.num_heads, -1) - if self.use_padding_free_transformer: - if self.position_embedding_type == "rope": + if self.position_embedding_type == "rope": + if self.use_padding_free_transformer: query = apply_rotary_pos_emb(query, rope_cos_sin) - - hidden_states = flash_attention( - q=query, - k=key, - v=value, - attention_mask_info=attention_mask_info, - causal=self.causal, - dropout=self.softmax_dropout_p if self.training else 0, - softmax_scale=self.attention_multiplier, - ) - else: - if self.position_embedding_type == "rope": + else: # TODO avoid this extra transpose query = query.transpose(1, 2) query = apply_rotary_pos_emb(query, rope_cos_sin) query = query.transpose(1, 2) - hidden_states = flash_attention( - q=query, - k=key, - v=value, - attention_mask_info=attention_mask_info, - causal=self.causal, - dropout=self.softmax_dropout_p if self.training else 0, - softmax_scale=self.attention_multiplier, - ) + hidden_states = flash_attention( + q=query, + k=key, + v=value, + attention_mask_info=attention_mask_info, + causal=self.causal, + dropout=self.softmax_dropout_p if self.training else 0, + softmax_scale=self.attention_multiplier, + ) else: - query = query.view(*query.size()[:-1], self.num_heads, -1) query = query.transpose(1, 2) if self.position_embedding_type == "rope": From ee5bb333ac9ebf7f593d64552b82280cf820f733 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 15:37:18 -0700 Subject: [PATCH 107/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 45 +++++++++---------------------------- 1 file changed, 10 insertions(+), 35 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 1a9435d16..7505d26be 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -16,39 +16,7 @@ if is_fma_available(): from fma import KernelBackend - from fma import pack_sequence as _pack_sequence - from fma import unpack_sequence as _unpack_sequence - - -def pack_sequence( - inputs: torch.Tensor | list[torch.Tensor], cu_seqlens: torch.Tensor -) -> torch.Tensor | list[torch.Tensor]: - kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) 
else KernelBackend.torch - - inputs = _pack_sequence( - inputs=inputs, - cu_seqlens=cu_seqlens, - kernel_backend_forward=kernel_backend, - kernel_backend_backward=kernel_backend, - ) - - return inputs - - -def unpack_sequence( - inputs: torch.Tensor | list[torch.Tensor], cu_seqlens: torch.Tensor, output_shape: tuple[int] -) -> torch.Tensor | list[torch.Tensor]: - kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch - - inputs = _unpack_sequence( - inputs=inputs, - cu_seqlens=cu_seqlens, - output_shape=output_shape, - kernel_backend_forward=kernel_backend, - kernel_backend_backward=kernel_backend, - ) - - return inputs + from fma import unpack_sequence as unpack_sequence _ERROR_MESSAGE = "code is not supposed to reach here" @@ -142,9 +110,16 @@ def get_attention_mask(self, return_none_allowed: bool = True) -> torch.Tensor | if cu_seqlens is None: self.attention_mask = torch.ones(batch_size, max_seqlen, device=self.device, dtype=torch.int32) else: - attention_mask_flat = torch.ones_like(cu_seqlens, device=self.device, dtype=torch.int32) + kernel_backend = ( + KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch + ) + self.attention_mask = unpack_sequence( - inputs=attention_mask_flat, cu_seqlens=cu_seqlens, output_shape=(batch_size, max_seqlen) + inputs=torch.ones_like(cu_seqlens, device=self.device, dtype=torch.int32), + cu_seqlens=cu_seqlens, + output_shape=(batch_size, max_seqlen), + kernel_backend_forward=kernel_backend, + kernel_backend_backward=kernel_backend, ) return self.attention_mask From eb339e6c495361c2b98e77cf702bec368096e9e8 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 15:42:17 -0700 Subject: [PATCH 108/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py index d2770e36a..98291f2f4 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py @@ -125,8 +125,6 @@ def forward( enable_gqa=True, ) - del query, key, value - hidden_states = hidden_states.transpose(1, 2) hidden_states = hidden_states.flatten(-2, -1) From 1297653aedb9df101f2d39a5066bdc977f8978c8 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 15:42:48 -0700 Subject: [PATCH 109/177] merge Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/attention.py | 4 ---- .../sequence_mixer_blocks/multihead_latent_attention.py | 4 ---- .../modeling_utils_TP/sequence_mixer_blocks/attention.py | 4 ---- 3 files changed, 12 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 8ca4ce6cf..dd3adb2ae 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -202,8 +202,6 @@ def forward( sliding_window=self.sliding_window, ) - del query, key, value - hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) hidden_states = hidden_states.view(*output_shape) else: @@ -220,8 +218,6 @@ def forward( enable_gqa=True, ) - del query, key, value - batch_size = hidden_states.shape[0] hidden_states = hidden_states.transpose(1, 2) hidden_states = 
hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py index 6dcb9e4d7..83933c29b 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py @@ -180,8 +180,6 @@ def forward( sliding_window=self.sliding_window, ) - del query, key, value - hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) hidden_states = hidden_states.view(*output_shape) else: @@ -205,8 +203,6 @@ def forward( enable_gqa=True, ) - del query, key, value - batch_size = hidden_states.shape[0] hidden_states = hidden_states.transpose(1, 2) hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) diff --git a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py index 67bcb0cad..be482c4a7 100644 --- a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py @@ -197,8 +197,6 @@ def forward( softmax_scale=self.attention_multiplier, ) - del query, key, value - hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) hidden_states = hidden_states.view(*output_shape) else: @@ -215,8 +213,6 @@ def forward( enable_gqa=True, ) - del query, key, value - batch_size = hidden_states.shape[0] hidden_states = hidden_states.transpose(1, 2) hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) From b18fd58826f769cb1f11465bc5730875a90f2132 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 15:43:53 -0700 Subject: [PATCH 110/177] merge Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/multihead_latent_attention.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py index 83933c29b..e79ba9d8d 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py @@ -117,10 +117,6 @@ def forward( use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) - if self.use_padding_free_transformer: - assert use_flash_attention_2 or use_flash_attention_3 - assert past_key_values is None - query = self.query_down_projection(hidden_states) query = self.query_ln(query) From 62406584ed44b66f0c10ad00a480c5063005d795 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 15:52:06 -0700 Subject: [PATCH 111/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 5 --- .../sequence_mixer_blocks/attention.py | 34 +++---------------- 2 files changed, 4 insertions(+), 35 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index c999503de..fd4ceb9c3 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -75,9 +75,6 @@ def forward( assert position_ids is not None, "max_seqlen needs 
to be specified when specifying cu_seqlens" assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" - if use_cache or past_key_values is not None: - raise NotImplementedError("KV caching is not supported with padding_free transformer") - clear_aux_loss() attention_mask_info = self._get_attention_mask_info( @@ -163,8 +160,6 @@ def generate( top_p: float | None = None, **kwargs, ) -> torch.Tensor: - assert not self.use_padding_free_transformer - has_attention_mask = attention_mask is not None min_tokens_to_keep = 1 diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index dd3adb2ae..2d5a4ed3d 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -142,33 +142,16 @@ def forward( use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) - if self.use_padding_free_transformer: - assert use_flash_attention_2 or use_flash_attention_3 - assert past_key_values is None - - total_q = hidden_states.shape[0] - input_shape = (total_q, self.num_key_value_heads, -1) - output_shape = (total_q, -1, self.head_dim) - else: - batch_size, query_length = hidden_states.shape[:-1] - - input_shape = (batch_size, query_length, self.num_key_value_heads, -1) - output_shape = (batch_size, query_length, -1, self.head_dim) + T = hidden_states.size(0) hidden_states = self.c_attn(hidden_states) - - hidden_states = hidden_states.view(*input_shape) + hidden_states = hidden_states.view(T, self.num_key_value_heads, -1) query, key, value = hidden_states.split( ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 ) - query = query.reshape(*output_shape) - - if not self.use_padding_free_transformer: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) + query = query.reshape(T, -1, self.head_dim) if self.position_embedding_type == "rope": query = apply_rotary_pos_emb(query, rope_cos_sin) @@ -178,15 +161,6 @@ def forward( key, value = past_key_values.update(key_states=key, value_states=value, layer_idx=self.layer_idx) if use_flash_attention_2 or use_flash_attention_3: - if self.use_padding_free_transformer: - output_shape = (-1, self.hidden_size) - else: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - output_shape = (batch_size, query_length, -1) - query = wait_for_ACT(query, wait_in_forward=True, wait_in_backward=False) key = wait_for_ACT(key, wait_in_forward=True, wait_in_backward=False) value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False) @@ -203,7 +177,7 @@ def forward( ) hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(-1, self.hidden_size) else: assert self.sliding_window is None From 8c8259136200e5167585ea43e890a2603784c67f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 16:26:57 -0700 Subject: [PATCH 112/177] merge Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/attention.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py 
b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 2d5a4ed3d..5ce41e997 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -139,9 +139,6 @@ def forward( attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, ) -> torch.Tensor: - use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) - use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) - T = hidden_states.size(0) hidden_states = self.c_attn(hidden_states) @@ -160,7 +157,7 @@ def forward( if past_key_values is not None: key, value = past_key_values.update(key_states=key, value_states=value, layer_idx=self.layer_idx) - if use_flash_attention_2 or use_flash_attention_3: + if is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3): query = wait_for_ACT(query, wait_in_forward=True, wait_in_backward=False) key = wait_for_ACT(key, wait_in_forward=True, wait_in_backward=False) value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False) From 47aa2c4c5392bf19cbba32e4d303487947d7af73 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 16:29:00 -0700 Subject: [PATCH 113/177] merge Signed-off-by: Mayank Mishra --- .../multihead_latent_attention.py | 27 +++++-------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py index e79ba9d8d..eabf4bb03 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/multihead_latent_attention.py @@ -114,9 +114,6 @@ def forward( cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, ) -> torch.Tensor: - use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) - use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) - query = self.query_down_projection(hidden_states) query = self.query_ln(query) @@ -140,24 +137,12 @@ def forward( key = self.key_up_projection(key) value = self.value_up_projection(value) - if use_flash_attention_2 or use_flash_attention_3: - if self.use_padding_free_transformer: - total_q = query.shape[0] - - query = query.view(total_q, self.num_heads, -1) - key = key.view(total_q, self.num_heads, -1) - value = value.view(total_q, self.num_heads, -1) - - output_shape = (-1, self.hidden_size) - else: - batch_size, query_length = query.shape[:-1] - key_length = key.shape[1] - - query = query.view(batch_size, query_length, self.num_heads, -1) - key = key.view(batch_size, key_length, self.num_heads, -1) - value = value.view(batch_size, key_length, self.num_heads, -1) + if is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3): + T = query.size(0) - output_shape = (batch_size, query_length, -1) + query = query.view(T, self.num_heads, -1) + key = key.view(T, self.num_heads, -1) + value = value.view(T, self.num_heads, -1) query = wait_for_ACT(query, wait_in_forward=True, wait_in_backward=False) key = wait_for_ACT(key, wait_in_forward=True, wait_in_backward=False) @@ -177,7 +162,7 @@ def forward( ) hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(-1, 
self.hidden_size) else: assert self.sliding_window is None From efc4862c45dd33eae525f33135a2f245617f183d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 16:33:40 -0700 Subject: [PATCH 114/177] merge Signed-off-by: Mayank Mishra --- lm_engine/data/utils.py | 47 ++++++++++++++++- lm_engine/hf_models/__init__.py | 2 +- lm_engine/hf_models/cache/__init__.py | 21 +++++++- lm_engine/hf_models/utils.py | 72 --------------------------- 4 files changed, 67 insertions(+), 75 deletions(-) delete mode 100644 lm_engine/hf_models/utils.py diff --git a/lm_engine/data/utils.py b/lm_engine/data/utils.py index 63351c82b..1ec7f95de 100644 --- a/lm_engine/data/utils.py +++ b/lm_engine/data/utils.py @@ -8,7 +8,52 @@ import torch from ..enums import LossMask -from ..hf_models import convert_padding_free_lists_to_tensors + + +def _check_list_type(list_of_list: list[list[int | float]] | None, error_message: str) -> None: + if list_of_list is None: + return + + assert isinstance(list_of_list, list), error_message + assert isinstance(list_of_list[0], list), error_message + + +def _flatten_and_convert_to_tensors(x: list[int], device: torch.device) -> torch.Tensor: + y = [] + for sequence in x: + y.extend(sequence) + + return torch.tensor(y, device=device) + + +def convert_padding_free_lists_to_tensors( + input_ids: list[list[int]] | None = None, + position_ids: list[list[int]] | None = None, + labels: list[list[int]] | None = None, + device: torch.device = None, +) -> tuple[torch.Tensor | int]: + + # check input types are correct + error_message = "{variable} should be of type List[List[{dtype}]]" + _check_list_type(input_ids, error_message.format(variable="input_ids", dtype="int")) + _check_list_type(position_ids, error_message.format(variable="position_ids", dtype="int")) + _check_list_type(labels, error_message.format(variable="labels", dtype="int")) + + # prepare inputs for the model + seqlens = torch.tensor([0] + [len(x) for x in input_ids], device=device) + cu_seqlens = seqlens.cumsum(dim=-1).to(torch.int32) + max_seqlen = seqlens.max().item() + + if position_ids is None: + position_ids = [list(range(len(x))) for x in input_ids] + position_ids = _flatten_and_convert_to_tensors(position_ids, device) + + input_ids = _flatten_and_convert_to_tensors(input_ids, device) + + if labels is not None: + labels = _flatten_and_convert_to_tensors(labels, device) + + return input_ids, position_ids, labels, cu_seqlens, max_seqlen def collate_fn( diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py index 72a0e2133..6155f3e9f 100644 --- a/lm_engine/hf_models/__init__.py +++ b/lm_engine/hf_models/__init__.py @@ -2,6 +2,7 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** +from .cache import disable_generation_cache from .config import CommonConfig from .loss import get_autoregressive_language_modeling_loss, is_aux_loss_zero from .mixins import CausalLMOutputWithPast, PipelineParallelInput, PipelineParallelOutput @@ -30,7 +31,6 @@ ) from .register_hf import get_model_parallel_class, is_custom_model, register_model_classes from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts -from .utils import convert_padding_free_lists_to_tensors, disable_generation_cache register_model_classes() diff --git a/lm_engine/hf_models/cache/__init__.py b/lm_engine/hf_models/cache/__init__.py index 7bbb5a15e..4b3b9431f 100644 --- a/lm_engine/hf_models/cache/__init__.py +++ b/lm_engine/hf_models/cache/__init__.py @@ -4,7 +4,7 @@ from 
__future__ import annotations -from typing import Iterable +from typing import Any, Iterable import torch @@ -53,3 +53,22 @@ def get_seq_length(self, layer_idx: int = 0) -> int: def reorder_cache(self, beam_idx: torch.Tensor) -> None: for cache in self.cache: cache.reorder_cache(beam_idx) + + +_IS_GENERATION_CACHE_ENABLED: bool = True + + +class disable_generation_cache: + def __enter__(self) -> Any: + global _IS_GENERATION_CACHE_ENABLED + self.original = _IS_GENERATION_CACHE_ENABLED + + _IS_GENERATION_CACHE_ENABLED = False + + def __exit__(self, exception_type, exception_value, exception_traceback) -> Any: + global _IS_GENERATION_CACHE_ENABLED + _IS_GENERATION_CACHE_ENABLED = self.original + + +def is_generation_cache_enabled() -> bool: + return _IS_GENERATION_CACHE_ENABLED diff --git a/lm_engine/hf_models/utils.py b/lm_engine/hf_models/utils.py deleted file mode 100644 index c16f3a17b..000000000 --- a/lm_engine/hf_models/utils.py +++ /dev/null @@ -1,72 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from typing import Any - -import torch - - -def convert_padding_free_lists_to_tensors( - input_ids: list[list[int]] | None = None, - position_ids: list[list[int]] | None = None, - labels: list[list[int]] | None = None, - device: torch.device = None, -) -> tuple[torch.Tensor | int]: - - # check input types are correct - error_message = "{variable} should be of type List[List[{dtype}]]" - _check_list_type(input_ids, error_message.format(variable="input_ids", dtype="int")) - _check_list_type(position_ids, error_message.format(variable="position_ids", dtype="int")) - _check_list_type(labels, error_message.format(variable="labels", dtype="int")) - - # prepare inputs for the model - seqlens = torch.tensor([0] + [len(x) for x in input_ids], device=device) - cu_seqlens = seqlens.cumsum(dim=-1).to(torch.int32) - max_seqlen = seqlens.max().item() - - if position_ids is None: - position_ids = [list(range(len(x))) for x in input_ids] - position_ids = _flatten_and_convert_to_tensors(position_ids, device) - - input_ids = _flatten_and_convert_to_tensors(input_ids, device) - - if labels is not None: - labels = _flatten_and_convert_to_tensors(labels, device) - - return input_ids, position_ids, labels, cu_seqlens, max_seqlen - - -def _check_list_type(list_of_list: list[list[int | float]] | None, error_message: str) -> None: - if list_of_list is None: - return - - assert isinstance(list_of_list, list), error_message - assert isinstance(list_of_list[0], list), error_message - - -def _flatten_and_convert_to_tensors(x: list[int], device: torch.device) -> torch.Tensor: - y = [] - for sequence in x: - y.extend(sequence) - - return torch.tensor(y, device=device) - - -_IS_GENERATION_CACHE_ENABLED: bool = True - - -class disable_generation_cache: - def __enter__(self) -> Any: - global _IS_GENERATION_CACHE_ENABLED - self.original = _IS_GENERATION_CACHE_ENABLED - - _IS_GENERATION_CACHE_ENABLED = False - - def __exit__(self, exception_type, exception_value, exception_traceback) -> Any: - global _IS_GENERATION_CACHE_ENABLED - _IS_GENERATION_CACHE_ENABLED = self.original - - -def is_generation_cache_enabled() -> bool: - return _IS_GENERATION_CACHE_ENABLED From 0bd6eb3a9a1d76fee187977bd389592553e99d0f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 16:34:08 -0700 Subject: [PATCH 115/177] merge Signed-off-by: Mayank Mishra --- lm_engine/data/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/lm_engine/data/utils.py b/lm_engine/data/utils.py index 1ec7f95de..26d185964 100644 --- a/lm_engine/data/utils.py +++ b/lm_engine/data/utils.py @@ -26,7 +26,7 @@ def _flatten_and_convert_to_tensors(x: list[int], device: torch.device) -> torch return torch.tensor(y, device=device) -def convert_padding_free_lists_to_tensors( +def _convert_padding_free_lists_to_tensors( input_ids: list[list[int]] | None = None, position_ids: list[list[int]] | None = None, labels: list[list[int]] | None = None, @@ -104,7 +104,7 @@ def collate_fn( input_ids[-1].extend([eos_token_id] * tokens_to_add) labels[-1].extend([labels_mask_value] * tokens_to_add) - input_ids, position_ids, _, labels, cu_seqlens, max_seqlen = convert_padding_free_lists_to_tensors( + input_ids, position_ids, _, labels, cu_seqlens, max_seqlen = _convert_padding_free_lists_to_tensors( input_ids=input_ids, labels=labels, device=device ) From 02eb59eab293f255b7a3e5b06a4af5ff459be5da Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 16:35:32 -0700 Subject: [PATCH 116/177] merge Signed-off-by: Mayank Mishra --- .../hf_models/models/gpt_crosslayer/sequence_mixer.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py index 98291f2f4..d588eb4fb 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/sequence_mixer.py @@ -89,13 +89,7 @@ def forward( query = query.view(*hidden_states.size()[:-1], self.num_heads, -1) if self.position_embedding_type == "rope": - if self.use_padding_free_transformer: - query = apply_rotary_pos_emb(query, rope_cos_sin) - else: - # TODO avoid this extra transpose - query = query.transpose(1, 2) - query = apply_rotary_pos_emb(query, rope_cos_sin) - query = query.transpose(1, 2) + query = apply_rotary_pos_emb(query, rope_cos_sin) hidden_states = flash_attention( q=query, From 07fb5fd9abb109f76acf724797d7a62c39702a81 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:03:43 -0700 Subject: [PATCH 117/177] merge Signed-off-by: Mayank Mishra --- tests/hf_models/single_gpu/gpt_base_test.py | 128 ++++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/tests/hf_models/single_gpu/gpt_base_test.py b/tests/hf_models/single_gpu/gpt_base_test.py index 540fdd929..579437443 100644 --- a/tests/hf_models/single_gpu/gpt_base_test.py +++ b/tests/hf_models/single_gpu/gpt_base_test.py @@ -18,6 +18,52 @@ class GPTBaseAttentionTest(TestCommons): + @parameterized.expand( + TestCommons.make_args_matrix( + [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] + ) + ) + def test_sdpa_padding_free_transformer_equivalence( + self, device: torch.device, position_embedding_type: str, dtype: torch.dtype + ) -> None: + self.skip_test_if_device_unavailable(device) + + set_seed(SEED) + + config = self.get_dense_test_config(position_embedding_type, num_layers=1) + + sdpa_model = self.from_config(config, dtype=dtype).to(device) + flash_model = self.from_config(config, dtype=dtype, use_padding_free_transformer=True).to(device) + + sdpa_model.eval() + flash_model.eval() + + flash_model.load_state_dict(sdpa_model.state_dict()) + + input_ids, attention_mask, labels = self.get_dummy_inputs(device) + sdpa_output = sdpa_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + attention_mask = attention_mask.to(torch.bool) + 
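# Illustrative sketch with assumed toy tensors (B, S, V and the mask values are hypothetical):
# the padded SDPA logits are gathered at the non-padded positions so they line up
# token-for-token with the packed, padding-free logits before the element-wise comparison
# performed in the test below.
import torch

B, S, V = 2, 5, 11                                   # hypothetical batch size, sequence length, vocab size
padded_logits = torch.randn(B, S, V)
bool_mask = torch.tensor([[1, 1, 1, 0, 0],
                          [1, 1, 1, 1, 1]], dtype=torch.bool)

packed_like = torch.cat([padded_logits[i, m, :] for i, m in enumerate(bool_mask)])
assert packed_like.shape == (int(bool_mask.sum()), V)   # one row per real (non-padded) token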
sdpa_logits = sdpa_output.logits + sdpa_logits = torch.cat([sdpa_logits[i, ex, :] for i, ex in enumerate(attention_mask)]) + sdpa_loss = sdpa_output.loss + + with enable_kernels([Kernel.flash_attention_2]): + input_ids, attention_mask, labels = self.get_dummy_inputs(device, return_list=True) + flash_output = flash_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + flash_logits = flash_output.logits + flash_loss = flash_output.loss + + self.assert_equal_tensors( + sdpa_logits, + flash_logits, + False, + rtol_float16=1e-3, + atol_float16=3e-4, + rtol_bfloat16=5e-3, + atol_bfloat16=5e-3, + ) + self.assert_equal_tensors(sdpa_loss, flash_loss, False, atol_float32=1.2e-4, rtol_float32=0) + @parameterized.expand( TestCommons.make_args_matrix( [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] @@ -60,6 +106,88 @@ def test_sdpa_flash_attention_equivalence( ) self.assert_equal_tensors(sdpa_loss, flash_loss, False, atol_float32=1.2e-4, rtol_float32=0) + @parameterized.expand( + TestCommons.make_args_matrix( + [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] + ) + ) + def test_padding_free_transformer_with_list_and_tensor( + self, device: torch.device, position_embedding_type: str, dtype: torch.dtype + ) -> None: + self.skip_test_if_device_unavailable(device) + + set_seed(SEED) + + config = self.get_dense_test_config(position_embedding_type, num_layers=1) + + model = self.from_config(config, dtype=dtype, use_padding_free_transformer=True).to(device) + model.eval() + + with enable_kernels([Kernel.flash_attention_2]): + input_ids, attention_mask, labels = self.get_dummy_inputs(device, return_list=True) + list_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + list_logits = list_output.logits + list_loss = list_output.loss + + seqlens = torch.tensor([0] + [len(i) for i in input_ids]) + cu_seqlens = seqlens.cumsum(dim=-1).to(device, torch.int32) + max_seqlen = seqlens.max().item() + position_ids = torch.tensor( + list(itertools.chain(*[list(range(len(i))) for i in input_ids])), device=device + ) + input_ids = torch.tensor(list(itertools.chain(*input_ids)), device=device) + labels = torch.tensor(list(itertools.chain(*labels)), device=device) + tensor_output = model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + tensor_logits = tensor_output.logits + tensor_loss = tensor_output.loss + + self.assert_equal_tensors(list_logits, tensor_logits, True) + self.assert_equal_tensors(list_loss, tensor_loss, True) + + @parameterized.expand( + TestCommons.make_args_matrix( + [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] + ) + ) + def test_sdpa_flash_enabled(self, device: torch.device, position_embedding_type: str, dtype: torch.dtype) -> None: + self.skip_test_if_device_unavailable(device) + + set_seed(SEED) + + config = self.get_dense_test_config(position_embedding_type, num_layers=1) + + model = self.from_config(config, dtype=dtype).to(device) + model.eval() + + input_ids, _, labels = self.get_dummy_inputs(device) + attention_mask = torch.ones_like(input_ids, dtype=torch.int, device=device) + + sdpa_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + sdpa_logits = sdpa_output.logits + sdpa_loss = sdpa_output.loss + + flash_output = model(input_ids=input_ids, 
labels=labels) + flash_logits = flash_output.logits + flash_loss = flash_output.loss + + self.assert_equal_tensors( + sdpa_logits, + flash_logits, + False, + rtol_float16=1e-3, + atol_float16=3e-4, + rtol_bfloat16=5e-3, + atol_bfloat16=5e-3, + ) + self.assert_equal_tensors(sdpa_loss, flash_loss, False, atol_float32=3.8e-4, rtol_float32=0) + @parameterized.expand( TestCommons.make_args_matrix( [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] From 958f0819fb886979014a5d01c21c42fd38951890 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:05:22 -0700 Subject: [PATCH 118/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index fd4ceb9c3..a44b9ee68 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -10,12 +10,11 @@ from ....enums import Kernel from ....kernels import is_kernel_allowed -from ...cache import GenerationCache +from ...cache import GenerationCache, is_generation_cache_enabled from ...config import CommonConfig from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero from ...mask import AttentionMaskInfo from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear -from ...utils import is_generation_cache_enabled from ..modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .base import PreTrainedModelMixin From 38383ac0f4427edf772937ba9d1da7a2fb5267f4 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:06:12 -0700 Subject: [PATCH 119/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense_TP/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense_TP/base.py b/lm_engine/hf_models/mixins/dense_TP/base.py index 8943f8f97..ca1e841d8 100644 --- a/lm_engine/hf_models/mixins/dense_TP/base.py +++ b/lm_engine/hf_models/mixins/dense_TP/base.py @@ -8,11 +8,10 @@ import torch.nn as nn from ....utils import ProcessGroupManager, divide_if_divisible -from ...cache import GenerationCache +from ...cache import GenerationCache, is_generation_cache_enabled from ...config import CommonConfig from ...modeling_utils import RoPE, YaRNScaledRoPE from ...modeling_utils_TP import Dropout_TP, Embedding_TP, get_normalization_function_TP -from ...utils import is_generation_cache_enabled from ..dense import BaseModelMixin, PreTrainedModelMixin from ..modeling_outputs import BaseModelOutputWithPast from .layer import Block_TP From c5feb37dfee53217b0de9124b943166a018f2b3e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:07:12 -0700 Subject: [PATCH 120/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/models/ladder_residual/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lm_engine/hf_models/models/ladder_residual/base.py b/lm_engine/hf_models/models/ladder_residual/base.py index e952b064b..aad260331 100644 --- a/lm_engine/hf_models/models/ladder_residual/base.py +++ b/lm_engine/hf_models/models/ladder_residual/base.py @@ -4,9 +4,8 @@ import torch -from ...cache import GenerationCache +from ...cache import GenerationCache, is_generation_cache_enabled from ...mixins import BaseModelMixin, BaseModelOutputWithPast, PreTrainedModelMixin -from ...utils import 
is_generation_cache_enabled from .config import LadderResidualConfig from .layer import LadderResidualBlock From c11141bb38199b052fc199bf40427f0577710fd0 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:08:34 -0700 Subject: [PATCH 121/177] merge Signed-off-by: Mayank Mishra --- tests/hf_models/single_gpu/gpt_base_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/hf_models/single_gpu/gpt_base_test.py b/tests/hf_models/single_gpu/gpt_base_test.py index 579437443..ec1e83211 100644 --- a/tests/hf_models/single_gpu/gpt_base_test.py +++ b/tests/hf_models/single_gpu/gpt_base_test.py @@ -33,7 +33,7 @@ def test_sdpa_padding_free_transformer_equivalence( config = self.get_dense_test_config(position_embedding_type, num_layers=1) sdpa_model = self.from_config(config, dtype=dtype).to(device) - flash_model = self.from_config(config, dtype=dtype, use_padding_free_transformer=True).to(device) + flash_model = self.from_config(config, dtype=dtype).to(device) sdpa_model.eval() flash_model.eval() @@ -120,7 +120,7 @@ def test_padding_free_transformer_with_list_and_tensor( config = self.get_dense_test_config(position_embedding_type, num_layers=1) - model = self.from_config(config, dtype=dtype, use_padding_free_transformer=True).to(device) + model = self.from_config(config, dtype=dtype).to(device) model.eval() with enable_kernels([Kernel.flash_attention_2]): From fcf6761254c0ea83ad638a44e0921bd56ef481a1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:10:01 -0700 Subject: [PATCH 122/177] merge Signed-off-by: Mayank Mishra --- tests/hf_models/single_gpu/gpt_base_test.py | 37 --------------------- 1 file changed, 37 deletions(-) diff --git a/tests/hf_models/single_gpu/gpt_base_test.py b/tests/hf_models/single_gpu/gpt_base_test.py index ec1e83211..4a2388034 100644 --- a/tests/hf_models/single_gpu/gpt_base_test.py +++ b/tests/hf_models/single_gpu/gpt_base_test.py @@ -151,43 +151,6 @@ def test_padding_free_transformer_with_list_and_tensor( self.assert_equal_tensors(list_logits, tensor_logits, True) self.assert_equal_tensors(list_loss, tensor_loss, True) - @parameterized.expand( - TestCommons.make_args_matrix( - [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] - ) - ) - def test_sdpa_flash_enabled(self, device: torch.device, position_embedding_type: str, dtype: torch.dtype) -> None: - self.skip_test_if_device_unavailable(device) - - set_seed(SEED) - - config = self.get_dense_test_config(position_embedding_type, num_layers=1) - - model = self.from_config(config, dtype=dtype).to(device) - model.eval() - - input_ids, _, labels = self.get_dummy_inputs(device) - attention_mask = torch.ones_like(input_ids, dtype=torch.int, device=device) - - sdpa_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - sdpa_logits = sdpa_output.logits - sdpa_loss = sdpa_output.loss - - flash_output = model(input_ids=input_ids, labels=labels) - flash_logits = flash_output.logits - flash_loss = flash_output.loss - - self.assert_equal_tensors( - sdpa_logits, - flash_logits, - False, - rtol_float16=1e-3, - atol_float16=3e-4, - rtol_bfloat16=5e-3, - atol_bfloat16=5e-3, - ) - self.assert_equal_tensors(sdpa_loss, flash_loss, False, atol_float32=3.8e-4, rtol_float32=0) - @parameterized.expand( TestCommons.make_args_matrix( [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] From d4ac5cb42fae40e1038c6d2897482b67fd7579c6 Mon Sep 17 00:00:00 2001 From: 
Mayank Mishra Date: Wed, 8 Oct 2025 17:12:21 -0700 Subject: [PATCH 123/177] merge Signed-off-by: Mayank Mishra --- tests/hf_models/single_gpu/gpt_base_test.py | 47 --------------------- 1 file changed, 47 deletions(-) diff --git a/tests/hf_models/single_gpu/gpt_base_test.py b/tests/hf_models/single_gpu/gpt_base_test.py index 4a2388034..e575886f5 100644 --- a/tests/hf_models/single_gpu/gpt_base_test.py +++ b/tests/hf_models/single_gpu/gpt_base_test.py @@ -2,8 +2,6 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** -import itertools - import torch from parameterized import parameterized from transformers import set_seed @@ -106,51 +104,6 @@ def test_sdpa_flash_attention_equivalence( ) self.assert_equal_tensors(sdpa_loss, flash_loss, False, atol_float32=1.2e-4, rtol_float32=0) - @parameterized.expand( - TestCommons.make_args_matrix( - [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] - ) - ) - def test_padding_free_transformer_with_list_and_tensor( - self, device: torch.device, position_embedding_type: str, dtype: torch.dtype - ) -> None: - self.skip_test_if_device_unavailable(device) - - set_seed(SEED) - - config = self.get_dense_test_config(position_embedding_type, num_layers=1) - - model = self.from_config(config, dtype=dtype).to(device) - model.eval() - - with enable_kernels([Kernel.flash_attention_2]): - input_ids, attention_mask, labels = self.get_dummy_inputs(device, return_list=True) - list_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - list_logits = list_output.logits - list_loss = list_output.loss - - seqlens = torch.tensor([0] + [len(i) for i in input_ids]) - cu_seqlens = seqlens.cumsum(dim=-1).to(device, torch.int32) - max_seqlen = seqlens.max().item() - position_ids = torch.tensor( - list(itertools.chain(*[list(range(len(i))) for i in input_ids])), device=device - ) - input_ids = torch.tensor(list(itertools.chain(*input_ids)), device=device) - labels = torch.tensor(list(itertools.chain(*labels)), device=device) - tensor_output = model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - tensor_logits = tensor_output.logits - tensor_loss = tensor_output.loss - - self.assert_equal_tensors(list_logits, tensor_logits, True) - self.assert_equal_tensors(list_loss, tensor_loss, True) - @parameterized.expand( TestCommons.make_args_matrix( [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] From 4a8fcffb1f7056e27a2bdfd5ee79c4eee7618948 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:26:04 -0700 Subject: [PATCH 124/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 5 +++++ lm_engine/hf_models/mixins/dense/main.py | 1 - 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 7505d26be..19fc0d9f1 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -124,6 +124,11 @@ def get_attention_mask(self, return_none_allowed: bool = True) -> torch.Tensor | return self.attention_mask + def get_position_ids(self) -> torch.Tensor: + attention_mask = self.get_attention_mask() + position_ids = attention_mask.sum(-1) + return position_ids + def get_causal_mask( self, return_none_allowed: bool = True, dtype: torch.dtype | None = None ) -> torch.Tensor | None: diff --git 
a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index a44b9ee68..0760398a8 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -72,7 +72,6 @@ def forward( assert return_dict assert inputs_embeds is None assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" clear_aux_loss() From cb63b7948acf9dc6684f45752da42aa294935182 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:28:22 -0700 Subject: [PATCH 125/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 19fc0d9f1..32416c360 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -125,7 +125,7 @@ def get_attention_mask(self, return_none_allowed: bool = True) -> torch.Tensor | return self.attention_mask def get_position_ids(self) -> torch.Tensor: - attention_mask = self.get_attention_mask() + attention_mask = self.get_attention_mask(False) position_ids = attention_mask.sum(-1) return position_ids From 224ab27d376742b7ed8e2537ab5322439aea4263 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:29:29 -0700 Subject: [PATCH 126/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 0760398a8..22caa0573 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -71,7 +71,6 @@ def forward( ) -> CausalLMOutputWithPast: assert return_dict assert inputs_embeds is None - assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" clear_aux_loss() @@ -79,6 +78,9 @@ def forward( x=input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask ) + if position_ids is None: + position_ids = attention_mask_info.get_position_ids() + if is_generation_cache_enabled(): past_key_values = ( GenerationCache(self.config) if use_cache and past_key_values is None else past_key_values From ff5dc5503decdcb161bee7122663c109a66c535d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:33:02 -0700 Subject: [PATCH 127/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 32416c360..2420c558c 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -43,11 +43,13 @@ def __post_init__(self) -> None: assert self.batch_size is None assert self.max_seqlen is not None assert self.attention_mask is None + self.device = self.cu_seqlens.device elif self.attention_mask is not None: assert self.batch_size is None assert self.cu_seqlens is None - assert self.max_seqlen is not None + assert self.max_seqlen is None + self.device = self.attention_mask.device assert self.device is not None From f2c860db078001faa6153dd0aae95a3f8446f3e2 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:39:13 -0700 Subject: [PATCH 128/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 2420c558c..fc2bc7b90 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -128,7 +128,7 @@ def get_attention_mask(self, return_none_allowed: bool = True) -> torch.Tensor | def get_position_ids(self) -> torch.Tensor: attention_mask = self.get_attention_mask(False) - position_ids = attention_mask.sum(-1) + position_ids = attention_mask.cumsum(-1) return position_ids def get_causal_mask( From 764c233fbd12461955081b0ccb4da78caaba204f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 17:45:25 -0700 Subject: [PATCH 129/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index fc2bc7b90..6b2432a9b 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -95,10 +95,10 @@ def get_max_seqlen(self, return_none_allowed: bool = True) -> int | None: if self.max_seqlen is None: # this will cache the max_seqlen self.get_cu_seqlens(False) + return self.max_seqlen + else: raise NotImplementedError(_ERROR_MESSAGE) - return self.max_seqlen - def get_attention_mask(self, return_none_allowed: bool = True) -> torch.Tensor | None: if return_none_allowed: return self.attention_mask From d173775216997a511fd9551eed39ac57b93e847c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 8 Oct 2025 19:32:52 -0700 Subject: [PATCH 130/177] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 39 +++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 6b2432a9b..048b545fc 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -15,8 +15,7 @@ if is_fma_available(): - from fma import KernelBackend - from fma import unpack_sequence as unpack_sequence + from fma import KernelBackend, pack_sequence, unpack_sequence _ERROR_MESSAGE = "code is not supposed to reach here" @@ -112,16 +111,9 @@ def get_attention_mask(self, return_none_allowed: bool = True) -> torch.Tensor | if cu_seqlens is None: self.attention_mask = torch.ones(batch_size, max_seqlen, device=self.device, dtype=torch.int32) else: - kernel_backend = ( - KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch - ) - - self.attention_mask = unpack_sequence( + self.attention_mask = self.unpack_sequence( inputs=torch.ones_like(cu_seqlens, device=self.device, dtype=torch.int32), - cu_seqlens=cu_seqlens, output_shape=(batch_size, max_seqlen), - kernel_backend_forward=kernel_backend, - kernel_backend_backward=kernel_backend, ) return self.attention_mask @@ -165,6 +157,33 @@ def get_causal_mask( return attention_mask + def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: + kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) else KernelBackend.torch + + inputs = pack_sequence( + inputs=inputs, + cu_seqlens=self.get_cu_seqlens(False), + kernel_backend_forward=kernel_backend, + kernel_backend_backward=kernel_backend, + ) + + return inputs + + def unpack_sequence( + self, inputs: torch.Tensor | list[torch.Tensor], output_shape: tuple[int] + ) -> torch.Tensor | list[torch.Tensor]: + kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch + + inputs = unpack_sequence( + inputs=inputs, + 
cu_seqlens=self.get_cu_seqlens(False), + output_shape=output_shape, + kernel_backend_forward=kernel_backend, + kernel_backend_backward=kernel_backend, + ) + + return inputs + @classmethod def _get_mask_value(cls, device: torch.device, dtype: torch.dtype) -> torch.Tensor: # torch.where expects a tensor. We use a cache to avoid recreating it every time. From 369a3b3f2ad7df9f1497f560f65998e69ac678c8 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 02:28:36 -0700 Subject: [PATCH 131/177] better Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 169 ++++++++++-------- .../sequence_mixer_blocks/attention.py | 64 +++---- 2 files changed, 124 insertions(+), 109 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 048b545fc..6636020bb 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -30,6 +30,7 @@ class AttentionMaskInfo: attention_mask: torch.Tensor | None = None _causal_mask: torch.Tensor | None = None device: torch.device | None = None + mask_value: torch.Tensor | None = None def __post_init__(self) -> None: self._is_ragged = self.cu_seqlens is not None @@ -53,9 +54,12 @@ def __post_init__(self) -> None: assert self.device is not None + def is_ragged(self) -> bool: + return self._is_ragged + def get_batch_size(self) -> int: if self.batch_size is None: - if self._is_ragged: + if self.is_ragged(): self.batch_size = self.cu_seqlens.size(0) - 1 elif self.attention_mask is not None: self.batch_size = self.attention_mask.size(0) @@ -64,17 +68,14 @@ def get_batch_size(self) -> int: return self.batch_size - def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor | None: - if self._is_ragged: + def get_cu_seqlens(self) -> torch.Tensor | None: + if self.is_ragged(): return self.cu_seqlens - if return_none_allowed: - return None - if self.cu_seqlens is None: if self.attention_mask is None: B = self.get_batch_size() - S = self.get_max_seqlen(False) + S = self.get_max_seqlen() self.cu_seqlens = torch.arange(0, B * S, S, dtype=torch.int32, device=self.device) else: @@ -84,103 +85,113 @@ def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor | Non return self.cu_seqlens - def get_max_seqlen(self, return_none_allowed: bool = True) -> int | None: - if self._is_ragged: + def get_max_seqlen(self) -> int | None: + if self.is_ragged(): return self.max_seqlen - if return_none_allowed: - return None - if self.max_seqlen is None: # this will cache the max_seqlen - self.get_cu_seqlens(False) - return self.max_seqlen - else: - raise NotImplementedError(_ERROR_MESSAGE) + self.get_cu_seqlens() - def get_attention_mask(self, return_none_allowed: bool = True) -> torch.Tensor | None: - if return_none_allowed: - return self.attention_mask + if self.max_seqlen is None: + raise NotImplementedError(_ERROR_MESSAGE) - if self.attention_mask is None: - cu_seqlens = self.get_cu_seqlens() - batch_size = self.get_batch_size() - max_seqlen = self.get_max_seqlen() - assert max_seqlen is not None + return self.max_seqlen - if cu_seqlens is None: - self.attention_mask = torch.ones(batch_size, max_seqlen, device=self.device, dtype=torch.int32) - else: - self.attention_mask = self.unpack_sequence( - inputs=torch.ones_like(cu_seqlens, device=self.device, dtype=torch.int32), - output_shape=(batch_size, max_seqlen), - ) + def get_attention_mask(self) -> torch.Tensor | None: + if self.is_ragged() and self.attention_mask is None: + B = self.get_batch_size() + S = self.get_max_seqlen() + + self.attention_mask = 
self.unpack_sequence( + inputs=torch.ones_like(self.get_cu_seqlens(), device=self.device, dtype=torch.int32), + output_shape=(B, S), + ) return self.attention_mask def get_position_ids(self) -> torch.Tensor: - attention_mask = self.get_attention_mask(False) - position_ids = attention_mask.cumsum(-1) + if self.is_ragged(): + attention_mask = self.get_attention_mask() + position_ids = attention_mask.cumsum(-1) + else: + position_ids = torch.arange(0, self.get_max_seqlen(), device=self.device) + position_ids = position_ids[None, ...].repeat(self.get_batch_size(), 1) + return position_ids - def get_causal_mask( - self, return_none_allowed: bool = True, dtype: torch.dtype | None = None - ) -> torch.Tensor | None: - attention_mask = self.get_attention_mask(return_none_allowed) + def get_causal_mask(self, query_length: int, dtype: torch.dtype) -> torch.Tensor | None: + attention_mask = self.get_attention_mask() - if attention_mask is not None: - _, Q, K = attention_mask.size() - L = K - Q + if attention_mask is None: + return None - if Q > 1: - causal_mask = torch.empty((Q, K), dtype=torch.bool, device=self.device) - causal_mask[:, L:] = torch.tril(torch.ones(Q, K, dtype=torch.bool, device=self.device)) + Q = query_length + K = attention_mask.size(1) + L = K - Q - if L > 0: - causal_mask[:, :L] = True + if Q > 1: + causal_mask = torch.empty((Q, K), dtype=torch.bool, device=self.device) + causal_mask[:, L:] = torch.tril(torch.ones(Q, K, dtype=torch.bool, device=self.device)) - causal_mask = causal_mask[None, ...] - causal_mask = causal_mask & attention_mask[:, None, ...].to(torch.bool) - elif Q == 1: - causal_mask = attention_mask[:, None, ...].to(dtype=torch.bool, device=self.device) - else: - raise NotImplementedError(_ERROR_MESSAGE) + if L > 0: + causal_mask[:, :L] = True - causal_mask = causal_mask[:, None, ...] - causal_mask = torch.where(causal_mask, ~causal_mask, AttentionMaskInfo._get_mask_value(self.device, dtype)) + causal_mask = causal_mask[None, ...] + causal_mask = causal_mask & attention_mask[:, None, ...].to(torch.bool) + elif Q == 1: + causal_mask = attention_mask[:, None, ...].to(dtype=torch.bool, device=self.device) + else: + raise NotImplementedError(_ERROR_MESSAGE) - # this is needed to prevent NaN since SDPA - # see issue: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = causal_mask * ~torch.all( - causal_mask == AttentionMaskInfo._get_mask_value(self.device, dtype), dim=-1, keepdim=True - ) + causal_mask = causal_mask[:, None, ...] 
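# broadcast to (batch, 1, query_length, key_length) and turn the boolean mask into an additive
# bias: 0 where attention is allowed, dtype-min where it is masked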
+ causal_mask = torch.where(causal_mask, ~causal_mask, AttentionMaskInfo._get_mask_value(self.device, dtype)) + + # this is needed to prevent NaN since SDPA + # see issue: https://github.com/pytorch/pytorch/issues/110213 + self._causal_mask = causal_mask * ~torch.all( + causal_mask == AttentionMaskInfo._get_mask_value(self.device, dtype), dim=-1, keepdim=True + ) - return attention_mask + return self._causal_mask def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: - kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) else KernelBackend.torch - - inputs = pack_sequence( - inputs=inputs, - cu_seqlens=self.get_cu_seqlens(False), - kernel_backend_forward=kernel_backend, - kernel_backend_backward=kernel_backend, - ) + if self.is_ragged(): + kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) else KernelBackend.torch + inputs = pack_sequence( + inputs=inputs, + cu_seqlens=self.get_cu_seqlens(False), + kernel_backend_forward=kernel_backend, + kernel_backend_backward=kernel_backend, + ) + else: + if isinstance(inputs, torch.Tensor): + inputs = inputs.flatten(0, 1) + else: + inputs = [i.flatten(0, 1) for i in inputs] return inputs - def unpack_sequence( - self, inputs: torch.Tensor | list[torch.Tensor], output_shape: tuple[int] - ) -> torch.Tensor | list[torch.Tensor]: - kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch - - inputs = unpack_sequence( - inputs=inputs, - cu_seqlens=self.get_cu_seqlens(False), - output_shape=output_shape, - kernel_backend_forward=kernel_backend, - kernel_backend_backward=kernel_backend, - ) + def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: + B = self.get_batch_size() + S = self.get_max_seqlen() + + if self.is_ragged(): + kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch + other_shape = inputs.size()[1:] if isinstance(inputs, torch.Tensor) else inputs[0].size()[1:] + + inputs = unpack_sequence( + inputs=inputs, + cu_seqlens=self.get_cu_seqlens(False), + output_shape=(B, S, *other_shape), + kernel_backend_forward=kernel_backend, + kernel_backend_backward=kernel_backend, + ) + else: + if isinstance(inputs, torch.Tensor): + inputs = inputs.reshape(B, S, *inputs.size()[1:]) + else: + inputs = [i.reshape(B, S, *i.size()[1:]) for i in inputs] return inputs diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 5ce41e997..1b7f3c6c0 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -133,39 +133,38 @@ def __init__( def forward( self, - hidden_states: torch.Tensor, + x: torch.Tensor, attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, rope_cos_sin: torch.Tensor | None = None, ) -> torch.Tensor: - T = hidden_states.size(0) + T = x.size(0) - hidden_states = self.c_attn(hidden_states) - hidden_states = hidden_states.view(T, self.num_key_value_heads, -1) + x = self.c_attn(x) + x = x.view(T, self.num_key_value_heads, -1) - query, key, value = hidden_states.split( + q, k, v = x.split( ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 ) - query = query.reshape(T, -1, 
self.head_dim) + q = q.reshape(T, -1, self.head_dim) if self.position_embedding_type == "rope": - query = apply_rotary_pos_emb(query, rope_cos_sin) - key = apply_rotary_pos_emb(key, rope_cos_sin) + q = apply_rotary_pos_emb(q, rope_cos_sin) + k = apply_rotary_pos_emb(k, rope_cos_sin) if past_key_values is not None: - key, value = past_key_values.update(key_states=key, value_states=value, layer_idx=self.layer_idx) + k, v = past_key_values.update(key_states=k, value_states=v, layer_idx=self.layer_idx) if is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3): - query = wait_for_ACT(query, wait_in_forward=True, wait_in_backward=False) - key = wait_for_ACT(key, wait_in_forward=True, wait_in_backward=False) - value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False) - - hidden_states = flash_attention( - q=query, - k=key, - v=value, + q = wait_for_ACT(q, wait_in_forward=True, wait_in_backward=False) + k = wait_for_ACT(k, wait_in_forward=True, wait_in_backward=False) + v = wait_for_ACT(v, wait_in_forward=True, wait_in_backward=False) + + x = flash_attention( + q=q, + k=k, + v=v, attention_mask_info=attention_mask_info, causal=self.causal, dropout=self.softmax_dropout_p if self.training else 0, @@ -173,15 +172,20 @@ def forward( sliding_window=self.sliding_window, ) - hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) - hidden_states = hidden_states.view(-1, self.hidden_size) + x = wait_for_ACT(x, wait_in_forward=False, wait_in_backward=True) else: assert self.sliding_window is None + q, k, v = attention_mask_info.unpack_sequence((q, k, v)) - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + attention_mask = attention_mask_info.get_causal_mask(query_length=q.size(-2), dtype=q.dtype) + + x = F.scaled_dot_product_attention( + q, + k, + v, attn_mask=attention_mask, dropout_p=self.softmax_dropout_p if self.training else 0, is_causal=self.causal if attention_mask is None else False, @@ -189,11 +193,11 @@ def forward( enable_gqa=True, ) - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.transpose(1, 2) - hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) + x = x.transpose(1, 2) + x = attention_mask_info.pack_sequence(x) - hidden_states = self.c_proj(hidden_states) - hidden_states = self.dropout(hidden_states) + x = x.flatten(-2, -1) + x = self.c_proj(x) + x = self.dropout(x) - return hidden_states + return x From 935af460548e6be72507f10957a572b07f19e1ae Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 02:32:00 -0700 Subject: [PATCH 132/177] better Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 6c5ba6a04..71761f2c1 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -75,10 +75,8 @@ def _sequence_mixer_forward( hidden_states = self.sequence_mixer( hidden_states, past_key_values=past_key_values, - attention_mask=attention_mask, + attention_mask_info=attention_mask_info, rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) elif self.sequence_mixer_type in ["causal_convolution", "mamba2"]: hidden_states = self.sequence_mixer( From 30f890b8226d731fd34745a72290f3b1cf8c9973 Mon Sep 17 
00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 03:31:58 -0700 Subject: [PATCH 133/177] better Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 6636020bb..cfb9aabc3 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -77,7 +77,7 @@ def get_cu_seqlens(self) -> torch.Tensor | None: B = self.get_batch_size() S = self.get_max_seqlen() - self.cu_seqlens = torch.arange(0, B * S, S, dtype=torch.int32, device=self.device) + self.cu_seqlens = torch.arange(0, B * S + 1, S, dtype=torch.int32, device=self.device) else: seqlens = self.attention_mask.sum(dim=-1, dtype=torch.int32) self.cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) From 9d6b09520983b03c4cc374adbcab08ccd205119c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 03:36:03 -0700 Subject: [PATCH 134/177] better Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 2 +- .../modeling_utils/sequence_mixer_blocks/attention.py | 3 ++- .../hf_models/modeling_utils/sequence_mixer_blocks/gru.py | 7 ++++--- .../hf_models/modeling_utils/sequence_mixer_blocks/rnn.py | 7 ++++--- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 807ad0a49..4e1aa22e6 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -85,7 +85,7 @@ def forward( ) -> BaseModelOutputWithPast: hidden_states = self._get_initial_hidden_state(input_ids, position_ids) rope_cos_sin = self._get_rope_cos_sin( - attention_mask_info.get_max_seqlen(False), position_ids, dtype=hidden_states.dtype + attention_mask_info.get_max_seqlen(), position_ids, dtype=hidden_states.dtype ) for block in self.h: diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 1b7f3c6c0..2d1604514 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -175,11 +175,12 @@ def forward( x = wait_for_ACT(x, wait_in_forward=False, wait_in_backward=True) else: assert self.sliding_window is None - q, k, v = attention_mask_info.unpack_sequence((q, k, v)) + q, k, v = attention_mask_info.unpack_sequence((q, k, v)) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) + attention_mask = attention_mask_info.get_causal_mask(query_length=q.size(-2), dtype=q.dtype) x = F.scaled_dot_product_attention( diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index 363fde57f..b3da514e6 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -94,10 +94,11 @@ def forward( weight, forget_weight, reset_weight = weight.chunk(3, dim=0) - cu_seqlens = attention_mask_info.get_cu_seqlens() + cu_seqlens = None if attention_mask_info.is_ragged() else attention_mask_info.get_cu_seqlens() + max_seqlen = None if attention_mask_info.is_ragged() else attention_mask_info.get_max_seqlen() x = gru( - input=x, + input=attention_mask_info.unpack_sequence(x), weight=weight, forget_input=x_forget, forget_weight=forget_weight, @@ -106,7 +107,7 @@ def forward( input_state=None if 
cache_params is None else cache_params.get_cache(self.layer_idx), gradient_clipping=self.gradient_clipping, cu_seqlens=cu_seqlens, - max_seqlen=attention_mask_info.get_max_seqlen(), + max_seqlen=max_seqlen, kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.gru) else KernelBackend.torch, ) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 5528d96a7..0207d0158 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -90,15 +90,16 @@ def forward( if self.scaling_factor != 1: weight = weight * self.scaling_factor - cu_seqlens = attention_mask_info.get_cu_seqlens() + cu_seqlens = None if attention_mask_info.is_ragged() else attention_mask_info.get_cu_seqlens() + max_seqlen = None if attention_mask_info.is_ragged() else attention_mask_info.get_max_seqlen() x = rnn( - input=x, + input=attention_mask_info.unpack_sequence(x), weight=weight, input_state=None if cache_params is None else cache_params.get_cache(self.layer_idx), gradient_clipping=self.gradient_clipping, cu_seqlens=cu_seqlens, - max_seqlen=attention_mask_info.get_max_seqlen(), + max_seqlen=max_seqlen, kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.rnn) else KernelBackend.torch, ) From f5d7175f4a3eb6cecf04623a99e46bf49ad6b448 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 03:37:56 -0700 Subject: [PATCH 135/177] better Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/utils.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py index 84de2418a..a672b82ea 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py @@ -39,61 +39,61 @@ def flash_attention( if sliding_window is not None and k.size(1) > sliding_window: window_size = (sliding_window, sliding_window) - cu_seqlens = attention_mask_info.get_cu_seqlens() - max_seqlen = attention_mask_info.get_max_seqlen() + if attention_mask_info.is_ragged(): + assert sliding_window is None + assert q.dim() == 3 - if cu_seqlens is None: - assert q.dim() == 4 + cu_seqlens = attention_mask_info.get_cu_seqlens() + max_seqlen = attention_mask_info.get_max_seqlen() if use_flash_attention_3: - attn_output, _ = flash_attention_3( + attn_output, _ = flash_attention_3_varlen( q=q, k=k, v=v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, softmax_scale=softmax_scale, causal=causal, - window_size=window_size, - softcap=softcap, ) else: - attn_output = flash_attention_2( + attn_output = flash_attention_2_varlen( q=q, k=k, v=v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, - window_size=window_size, - softcap=softcap, ) else: - assert sliding_window is None - assert q.dim() == 3 + assert q.dim() == 4 if use_flash_attention_3: - attn_output, _ = flash_attention_3_varlen( + attn_output, _ = flash_attention_3( q=q, k=k, v=v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, softmax_scale=softmax_scale, causal=causal, + window_size=window_size, + softcap=softcap, ) else: - 
attn_output = flash_attention_2_varlen( + attn_output = flash_attention_2( q=q, k=k, v=v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, + window_size=window_size, + softcap=softcap, ) return attn_output From 2407caf9d47e3994957a049bdbcee79f01cb3580 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 03:39:55 -0700 Subject: [PATCH 136/177] better Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/utils.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py index a672b82ea..22c1cc0f5 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py @@ -39,15 +39,18 @@ def flash_attention( if sliding_window is not None and k.size(1) > sliding_window: window_size = (sliding_window, sliding_window) + assert q.dim() == 3 + assert k.dim() == 3 + assert v.dim() == 3 + if attention_mask_info.is_ragged(): assert sliding_window is None - assert q.dim() == 3 cu_seqlens = attention_mask_info.get_cu_seqlens() max_seqlen = attention_mask_info.get_max_seqlen() if use_flash_attention_3: - attn_output, _ = flash_attention_3_varlen( + x, _ = flash_attention_3_varlen( q=q, k=k, v=v, @@ -59,7 +62,7 @@ def flash_attention( causal=causal, ) else: - attn_output = flash_attention_2_varlen( + x = flash_attention_2_varlen( q=q, k=k, v=v, @@ -72,10 +75,10 @@ def flash_attention( causal=causal, ) else: - assert q.dim() == 4 + q, k, v = attention_mask_info.unpack_sequence(q, k, v) if use_flash_attention_3: - attn_output, _ = flash_attention_3( + x, _ = flash_attention_3( q=q, k=k, v=v, @@ -85,7 +88,7 @@ def flash_attention( softcap=softcap, ) else: - attn_output = flash_attention_2( + x = flash_attention_2( q=q, k=k, v=v, @@ -96,4 +99,6 @@ def flash_attention( softcap=softcap, ) - return attn_output + x = attention_mask_info.pack_sequence(x) + + return x From 205d6aba45c9cf1936e37b225cb6a5ff53dd7950 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 03:40:33 -0700 Subject: [PATCH 137/177] better Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py index 22c1cc0f5..0f0e477f3 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py @@ -75,7 +75,7 @@ def flash_attention( causal=causal, ) else: - q, k, v = attention_mask_info.unpack_sequence(q, k, v) + q, k, v = attention_mask_info.unpack_sequence((q, k, v)) if use_flash_attention_3: x, _ = flash_attention_3( From 656767ca9d8ef708e62f43c840cae10265a99c0f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 03:42:00 -0700 Subject: [PATCH 138/177] better Signed-off-by: Mayank Mishra --- tests/hf_models/single_gpu/gpt_base_test.py | 46 --------------------- 1 file changed, 46 deletions(-) diff --git a/tests/hf_models/single_gpu/gpt_base_test.py b/tests/hf_models/single_gpu/gpt_base_test.py index e575886f5..51e5b90ef 100644 --- a/tests/hf_models/single_gpu/gpt_base_test.py +++ b/tests/hf_models/single_gpu/gpt_base_test.py @@ 
-16,52 +16,6 @@ class GPTBaseAttentionTest(TestCommons): - @parameterized.expand( - TestCommons.make_args_matrix( - [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] - ) - ) - def test_sdpa_padding_free_transformer_equivalence( - self, device: torch.device, position_embedding_type: str, dtype: torch.dtype - ) -> None: - self.skip_test_if_device_unavailable(device) - - set_seed(SEED) - - config = self.get_dense_test_config(position_embedding_type, num_layers=1) - - sdpa_model = self.from_config(config, dtype=dtype).to(device) - flash_model = self.from_config(config, dtype=dtype).to(device) - - sdpa_model.eval() - flash_model.eval() - - flash_model.load_state_dict(sdpa_model.state_dict()) - - input_ids, attention_mask, labels = self.get_dummy_inputs(device) - sdpa_output = sdpa_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - attention_mask = attention_mask.to(torch.bool) - sdpa_logits = sdpa_output.logits - sdpa_logits = torch.cat([sdpa_logits[i, ex, :] for i, ex in enumerate(attention_mask)]) - sdpa_loss = sdpa_output.loss - - with enable_kernels([Kernel.flash_attention_2]): - input_ids, attention_mask, labels = self.get_dummy_inputs(device, return_list=True) - flash_output = flash_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - flash_logits = flash_output.logits - flash_loss = flash_output.loss - - self.assert_equal_tensors( - sdpa_logits, - flash_logits, - False, - rtol_float16=1e-3, - atol_float16=3e-4, - rtol_bfloat16=5e-3, - atol_bfloat16=5e-3, - ) - self.assert_equal_tensors(sdpa_loss, flash_loss, False, atol_float32=1.2e-4, rtol_float32=0) - @parameterized.expand( TestCommons.make_args_matrix( [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] From 6a3623826d8f643c8da7a3776e90224a0151ecaa Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 05:06:05 -0700 Subject: [PATCH 139/177] better Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 2d1604514..8ba6fe4b6 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -184,9 +184,9 @@ def forward( attention_mask = attention_mask_info.get_causal_mask(query_length=q.size(-2), dtype=q.dtype) x = F.scaled_dot_product_attention( - q, - k, - v, + query=q, + key=k, + value=v, attn_mask=attention_mask, dropout_p=self.softmax_dropout_p if self.training else 0, is_causal=self.causal if attention_mask is None else False, From 4630666ed02efc4b8d6712f981ac89eee9280733 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 05:24:49 -0700 Subject: [PATCH 140/177] better Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index cfb9aabc3..e06d54b34 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -144,14 +144,7 @@ def get_causal_mask(self, query_length: int, dtype: torch.dtype) -> torch.Tensor else: raise NotImplementedError(_ERROR_MESSAGE) - causal_mask = causal_mask[:, None, ...] 
- causal_mask = torch.where(causal_mask, ~causal_mask, AttentionMaskInfo._get_mask_value(self.device, dtype)) - - # this is needed to prevent NaN since SDPA - # see issue: https://github.com/pytorch/pytorch/issues/110213 - self._causal_mask = causal_mask * ~torch.all( - causal_mask == AttentionMaskInfo._get_mask_value(self.device, dtype), dim=-1, keepdim=True - ) + self._causal_mask = causal_mask[:, None, ...] return self._causal_mask From 96a9f861dc72e8d50897176a2f60576d85006a0d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 11:59:01 -0700 Subject: [PATCH 141/177] better Signed-off-by: Mayank Mishra --- lm_engine/hf_models/loss.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/lm_engine/hf_models/loss.py b/lm_engine/hf_models/loss.py index c2a609ddc..0778be821 100644 --- a/lm_engine/hf_models/loss.py +++ b/lm_engine/hf_models/loss.py @@ -14,6 +14,7 @@ from ..enums import Kernel from ..kernels import is_kernel_allowed from ..utils import ProcessGroupManager, is_fma_available +from .mask import AttentionMaskInfo if is_fma_available(): @@ -23,10 +24,9 @@ def get_autoregressive_language_modeling_loss( lm_logits: torch.Tensor, labels: torch.Tensor, + attention_mask_info: AttentionMaskInfo, hidden_states: torch.Tensor | None = None, vocab_weight: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - use_padding_free_transformer: bool = True, reduction: str = "mean", shift_logits_and_labels: bool = True, tensor_parallel_enabled: bool = False, @@ -40,15 +40,13 @@ def get_autoregressive_language_modeling_loss( labels = labels[..., 1:] - if use_padding_free_transformer: - if shift_logits_and_labels: - assert cu_seqlens is not None + if shift_logits_and_labels: + cu_seqlens = attention_mask_info.get_cu_seqlens() + if cu_seqlens is not None: # this is needed so that the last token of current example doesn't predict first token of next example drop_loss_positions = cu_seqlens[1:-1] - 1 labels[drop_loss_positions] = -100 - else: - assert cu_seqlens is None if is_kernel_allowed(Kernel.fused_linear_cross_entropy): assert lm_logits is None From 5e60c0dc6059f321363bad943332a234ce1988ca Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 12:01:34 -0700 Subject: [PATCH 142/177] better Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 20 +++++++++---------- .../sequence_mixer_blocks/attention.py | 8 ++------ 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index e06d54b34..a007e43c1 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -33,7 +33,7 @@ class AttentionMaskInfo: mask_value: torch.Tensor | None = None def __post_init__(self) -> None: - self._is_ragged = self.cu_seqlens is not None + self._has_cu_seqlens = self.cu_seqlens is not None if self.batch_size is not None: assert self.max_seqlen is not None @@ -54,12 +54,12 @@ def __post_init__(self) -> None: assert self.device is not None - def is_ragged(self) -> bool: - return self._is_ragged + def has_cu_seqlens(self) -> bool: + return self._has_cu_seqlens def get_batch_size(self) -> int: if self.batch_size is None: - if self.is_ragged(): + if self.has_cu_seqlens(): self.batch_size = self.cu_seqlens.size(0) - 1 elif self.attention_mask is not None: self.batch_size = self.attention_mask.size(0) @@ -69,7 +69,7 @@ def get_batch_size(self) -> int: return self.batch_size def get_cu_seqlens(self) -> torch.Tensor | None: - if self.is_ragged(): + if 
self.has_cu_seqlens(): return self.cu_seqlens if self.cu_seqlens is None: @@ -86,7 +86,7 @@ def get_cu_seqlens(self) -> torch.Tensor | None: return self.cu_seqlens def get_max_seqlen(self) -> int | None: - if self.is_ragged(): + if self.has_cu_seqlens(): return self.max_seqlen if self.max_seqlen is None: @@ -99,7 +99,7 @@ def get_max_seqlen(self) -> int | None: return self.max_seqlen def get_attention_mask(self) -> torch.Tensor | None: - if self.is_ragged() and self.attention_mask is None: + if self.has_cu_seqlens() and self.attention_mask is None: B = self.get_batch_size() S = self.get_max_seqlen() @@ -111,7 +111,7 @@ def get_attention_mask(self) -> torch.Tensor | None: return self.attention_mask def get_position_ids(self) -> torch.Tensor: - if self.is_ragged(): + if self.has_cu_seqlens(): attention_mask = self.get_attention_mask() position_ids = attention_mask.cumsum(-1) else: @@ -149,7 +149,7 @@ def get_causal_mask(self, query_length: int, dtype: torch.dtype) -> torch.Tensor return self._causal_mask def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: - if self.is_ragged(): + if self.has_cu_seqlens(): kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) else KernelBackend.torch inputs = pack_sequence( inputs=inputs, @@ -169,7 +169,7 @@ def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Te B = self.get_batch_size() S = self.get_max_seqlen() - if self.is_ragged(): + if self.has_cu_seqlens(): kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch other_shape = inputs.size()[1:] if isinstance(inputs, torch.Tensor) else inputs[0].size()[1:] diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 8ba6fe4b6..83fd1383e 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -157,9 +157,7 @@ def forward( k, v = past_key_values.update(key_states=k, value_states=v, layer_idx=self.layer_idx) if is_kernel_allowed(Kernel.flash_attention_2) or is_kernel_allowed(Kernel.flash_attention_3): - q = wait_for_ACT(q, wait_in_forward=True, wait_in_backward=False) - k = wait_for_ACT(k, wait_in_forward=True, wait_in_backward=False) - v = wait_for_ACT(v, wait_in_forward=True, wait_in_backward=False) + q, k, v = [wait_for_ACT(i, wait_in_forward=True, wait_in_backward=False) for i in (q, k, v)] x = flash_attention( q=q, @@ -177,9 +175,7 @@ def forward( assert self.sliding_window is None q, k, v = attention_mask_info.unpack_sequence((q, k, v)) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) + q, k, v = [i.transpose(1, 2) for i in (q, k, v)] attention_mask = attention_mask_info.get_causal_mask(query_length=q.size(-2), dtype=q.dtype) From c93232aa3ef66c623acdd73b8ab92c79760de863 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 12:21:11 -0700 Subject: [PATCH 143/177] better Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 94 +++++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 40 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index a007e43c1..a3860d3e2 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -28,26 +28,27 @@ class AttentionMaskInfo: cu_seqlens: torch.Tensor | None = None max_seqlen: int | None = None 
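# a batch is described in exactly one of three ways: (batch_size, max_seqlen) for a dense batch,
# cu_seqlens (+ max_seqlen) for packed/ragged sequences, or a padded attention_mask;
# __post_init__ asserts that only one of these is supplied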
attention_mask: torch.Tensor | None = None - _causal_mask: torch.Tensor | None = None device: torch.device | None = None mask_value: torch.Tensor | None = None + causal_mask: torch.Tensor | None = None def __post_init__(self) -> None: self._has_cu_seqlens = self.cu_seqlens is not None + self._has_attention_mask = self.attention_mask is not None if self.batch_size is not None: assert self.max_seqlen is not None - assert self.cu_seqlens is None - assert self.attention_mask is None + assert not self.has_cu_seqlens() + assert not self.has_attention_mask() elif self.cu_seqlens is not None: assert self.batch_size is None assert self.max_seqlen is not None - assert self.attention_mask is None + assert not self.has_attention_mask() self.device = self.cu_seqlens.device - elif self.attention_mask is not None: + elif self.has_attention_mask(): assert self.batch_size is None - assert self.cu_seqlens is None + assert not self.has_cu_seqlens() assert self.max_seqlen is None self.device = self.attention_mask.device @@ -57,40 +58,49 @@ def __post_init__(self) -> None: def has_cu_seqlens(self) -> bool: return self._has_cu_seqlens + def has_attention_mask(self) -> bool: + return self._has_attention_mask + def get_batch_size(self) -> int: if self.batch_size is None: if self.has_cu_seqlens(): self.batch_size = self.cu_seqlens.size(0) - 1 - elif self.attention_mask is not None: + elif self.has_attention_mask(): self.batch_size = self.attention_mask.size(0) else: raise NotImplementedError(_ERROR_MESSAGE) return self.batch_size - def get_cu_seqlens(self) -> torch.Tensor | None: + def get_cu_seqlens(self, return_none_allowed: bool = True) -> torch.Tensor | None: if self.has_cu_seqlens(): return self.cu_seqlens - if self.cu_seqlens is None: - if self.attention_mask is None: - B = self.get_batch_size() - S = self.get_max_seqlen() + if return_none_allowed: + return None - self.cu_seqlens = torch.arange(0, B * S + 1, S, dtype=torch.int32, device=self.device) - else: - seqlens = self.attention_mask.sum(dim=-1, dtype=torch.int32) - self.cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) - self.max_seqlen = seqlens.max().item() + if self.has_attention_mask(): + seqlens = self.attention_mask.sum(dim=-1, dtype=torch.int32) + self.cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) + self.max_seqlen = seqlens.max().item() + else: + B = self.get_batch_size() + S = self.get_max_seqlen() + + self.cu_seqlens = torch.arange(0, B * S + 1, S, dtype=torch.int32, device=self.device) return self.cu_seqlens - def get_max_seqlen(self) -> int | None: + def get_max_seqlen(self, return_none_allowed: bool = True) -> int | None: if self.has_cu_seqlens(): + assert self.max_seqlen is not None return self.max_seqlen + if return_none_allowed: + return None + if self.max_seqlen is None: - # this will cache the max_seqlen + # this will cache the max_seqlen but causes synchronization with CPU self.get_cu_seqlens() if self.max_seqlen is None: @@ -98,32 +108,43 @@ def get_max_seqlen(self) -> int | None: return self.max_seqlen - def get_attention_mask(self) -> torch.Tensor | None: - if self.has_cu_seqlens() and self.attention_mask is None: - B = self.get_batch_size() - S = self.get_max_seqlen() + def get_attention_mask(self, return_none_allowed: bool = True) -> torch.Tensor | None: + if self.has_attention_mask(): + return self.attention_mask + + if return_none_allowed: + return None + + B = self.get_batch_size() + S = self.get_max_seqlen() + if self.has_cu_seqlens(): self.attention_mask = 
self.unpack_sequence( inputs=torch.ones_like(self.get_cu_seqlens(), device=self.device, dtype=torch.int32), output_shape=(B, S), ) + else: + self.attention_mask = torch.ones(B, S, device=self.device, dtype=torch.int32) return self.attention_mask def get_position_ids(self) -> torch.Tensor: - if self.has_cu_seqlens(): - attention_mask = self.get_attention_mask() + if self.has_cu_seqlens() or self.has_attention_mask(): + attention_mask = self.get_attention_mask(False) position_ids = attention_mask.cumsum(-1) else: - position_ids = torch.arange(0, self.get_max_seqlen(), device=self.device) - position_ids = position_ids[None, ...].repeat(self.get_batch_size(), 1) + B = self.get_batch_size() + S = self.get_max_seqlen(False) - return position_ids + position_ids = torch.arange(0, S, device=self.device) + position_ids = position_ids[None, ...].repeat(B, 1) - def get_causal_mask(self, query_length: int, dtype: torch.dtype) -> torch.Tensor | None: - attention_mask = self.get_attention_mask() + return position_ids - if attention_mask is None: + def get_causal_mask(self, query_length: int, return_none_allowed: bool = True) -> torch.Tensor | None: + if self.has_cu_seqlens() or self.has_attention_mask(): + attention_mask = self.get_attention_mask() + elif return_none_allowed: return None Q = query_length @@ -144,9 +165,9 @@ def get_causal_mask(self, query_length: int, dtype: torch.dtype) -> torch.Tensor else: raise NotImplementedError(_ERROR_MESSAGE) - self._causal_mask = causal_mask[:, None, ...] + self.causal_mask = causal_mask[:, None, ...] - return self._causal_mask + return self.causal_mask def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: if self.has_cu_seqlens(): @@ -187,10 +208,3 @@ def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Te inputs = [i.reshape(B, S, *i.size()[1:]) for i in inputs] return inputs - - @classmethod - def _get_mask_value(cls, device: torch.device, dtype: torch.dtype) -> torch.Tensor: - # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
- if cls.mask_value is None or cls.mask_value.dtype != dtype or cls.mask_value.device != device: - cls.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) - return cls.mask_value From 12d5c83a1562df8f83109deee46b816afc23363b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 9 Oct 2025 12:44:30 -0700 Subject: [PATCH 144/177] better Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 4e1aa22e6..e392270a2 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -84,9 +84,7 @@ def forward( position_ids: torch.Tensor | None = None, ) -> BaseModelOutputWithPast: hidden_states = self._get_initial_hidden_state(input_ids, position_ids) - rope_cos_sin = self._get_rope_cos_sin( - attention_mask_info.get_max_seqlen(), position_ids, dtype=hidden_states.dtype - ) + rope_cos_sin = self._get_rope_cos_sin(attention_mask_info, position_ids, dtype=hidden_states.dtype) for block in self.h: hidden_states: torch.Tensor = block( @@ -116,10 +114,10 @@ def _get_position_ids( return position_ids def _get_rope_cos_sin( - self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype + self, attention_mask_info: AttentionMaskInfo, position_ids: torch.Tensor, dtype: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: if self.position_embedding_type == "rope": - cos, sin = self.rope(key_length, dtype=dtype) + cos, sin = self.rope(attention_mask_info.get_max_seqlen(False), dtype=dtype) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) return cos, sin From ed73a440eca03f96e129442f4b30bf5c0cd08051 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 13 Oct 2025 14:25:11 -0700 Subject: [PATCH 145/177] cleanup Signed-off-by: Mayank Mishra --- .../multi_gpu/tensor_parallel/tensor_parallel_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py index 27d703a5e..46f23c641 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py @@ -80,7 +80,7 @@ # try sharding vocab matrices if really struggling for memory model_tp = get_model_parallel_class(config.model_type)._from_config( - config, use_padding_free_transformer=True, sequence_parallel=args.sequence_parallel + config, sequence_parallel=args.sequence_parallel ) # copy to device without copying storage From 44c3a9a5090ba83b680e8be5ca46c224950622b8 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 16:57:26 -0700 Subject: [PATCH 146/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index a3860d3e2..459bfee83 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -101,7 +101,7 @@ def get_max_seqlen(self, return_none_allowed: bool = True) -> int | None: if self.max_seqlen is None: # this will cache the max_seqlen but causes synchronization with CPU - self.get_cu_seqlens() + self.get_cu_seqlens(False) if self.max_seqlen is None: raise NotImplementedError(_ERROR_MESSAGE) From 8e0ea462e29a7c3c1762edfb119d7c63e0f21b90 Mon Sep 17 00:00:00 2001 From: 
Mayank Mishra Date: Sun, 19 Oct 2025 16:59:36 -0700 Subject: [PATCH 147/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 459bfee83..05510368f 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -137,7 +137,7 @@ def get_position_ids(self) -> torch.Tensor: S = self.get_max_seqlen(False) position_ids = torch.arange(0, S, device=self.device) - position_ids = position_ids[None, ...].repeat(B, 1) + position_ids = position_ids[None, ...].repeat(B, 1).flatten() return position_ids From 3855ee73ab713fb6ac017861f664a4b175fded81 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 16:59:50 -0700 Subject: [PATCH 148/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 05510368f..a2ec30c45 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -137,7 +137,7 @@ def get_position_ids(self) -> torch.Tensor: S = self.get_max_seqlen(False) position_ids = torch.arange(0, S, device=self.device) - position_ids = position_ids[None, ...].repeat(B, 1).flatten() + position_ids = position_ids[None, ...].expand(B, 1).flatten() return position_ids From bfa093c7c52accfcec7fbba1c1fdcd625c4d177e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:00:46 -0700 Subject: [PATCH 149/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index a2ec30c45..238175d56 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -137,7 +137,7 @@ def get_position_ids(self) -> torch.Tensor: S = self.get_max_seqlen(False) position_ids = torch.arange(0, S, device=self.device) - position_ids = position_ids[None, ...].expand(B, 1).flatten() + position_ids = position_ids[None, ...].expand(B, -1).flatten() return position_ids From 5b55c84655f76ec3ce231b811c6d376279350889 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:09:17 -0700 Subject: [PATCH 150/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 238175d56..d9500957e 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -137,7 +137,9 @@ def get_position_ids(self) -> torch.Tensor: S = self.get_max_seqlen(False) position_ids = torch.arange(0, S, device=self.device) - position_ids = position_ids[None, ...].expand(B, -1).flatten() + position_ids = position_ids[None, ...].expand(B, -1) + + position_ids = position_ids.flatten() return position_ids From 52efb7668094facb5b0799c22d1999e6deb3a6c1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:11:01 -0700 Subject: [PATCH 151/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 22caa0573..61d5d247d 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -86,6 +86,8 @@ def forward( GenerationCache(self.config) if use_cache and past_key_values is None else 
past_key_values ) + input_ids = input_ids.flatten() + transformer_outputs: BaseModelOutputWithPast = self.transformer( input_ids=input_ids, attention_mask_info=attention_mask_info, From f3df0602ebf327d8f929749d7066d6838691b701 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:12:50 -0700 Subject: [PATCH 152/177] cleanup Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py index 0f0e477f3..2d2bf8643 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py @@ -43,7 +43,7 @@ def flash_attention( assert k.dim() == 3 assert v.dim() == 3 - if attention_mask_info.is_ragged(): + if attention_mask_info.has_cu_seqlens(): assert sliding_window is None cu_seqlens = attention_mask_info.get_cu_seqlens() From 5c1e1691e96e3b8062e3a2e85f15556ab0ec718b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:17:24 -0700 Subject: [PATCH 153/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index d9500957e..bc2bafbec 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -190,7 +190,7 @@ def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tens def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: B = self.get_batch_size() - S = self.get_max_seqlen() + S = self.get_max_seqlen(False) if self.has_cu_seqlens(): kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch From 4622db9f7ee7a5c6f287306489c8c07326d04a3c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:19:44 -0700 Subject: [PATCH 154/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 61d5d247d..e7cc524ae 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -122,10 +122,9 @@ def forward( loss = get_autoregressive_language_modeling_loss( lm_logits=lm_logits, labels=labels, + attention_mask_info=attention_mask_info, hidden_states=None, vocab_weight=None, - cu_seqlens=cu_seqlens, - use_padding_free_transformer=True, reduction=reduction, shift_logits_and_labels=True, tensor_parallel_enabled=False, From c99211b81d42f3a455d421fbfbc6202d83557ee5 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:23:24 -0700 Subject: [PATCH 155/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index e7cc524ae..1c08d58b8 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -119,6 +119,8 @@ def forward( if self.m_width is not None: lm_logits = lm_logits / self.m_width + labels = labels.flatten() + loss = get_autoregressive_language_modeling_loss( lm_logits=lm_logits, labels=labels, From 45ed3c5022973bd3518177a42cb90c907de7661d 
Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:32:21 -0700 Subject: [PATCH 156/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index bc2bafbec..eaf6dee98 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -131,7 +131,8 @@ def get_attention_mask(self, return_none_allowed: bool = True) -> torch.Tensor | def get_position_ids(self) -> torch.Tensor: if self.has_cu_seqlens() or self.has_attention_mask(): attention_mask = self.get_attention_mask(False) - position_ids = attention_mask.cumsum(-1) + position_ids = attention_mask.cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) else: B = self.get_batch_size() S = self.get_max_seqlen(False) From a65f9de73d92ede0dcb8c9aecf8aa1307900414f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:38:37 -0700 Subject: [PATCH 157/177] cleanup Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 83fd1383e..a2dad78dc 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -177,7 +177,7 @@ def forward( q, k, v = attention_mask_info.unpack_sequence((q, k, v)) q, k, v = [i.transpose(1, 2) for i in (q, k, v)] - attention_mask = attention_mask_info.get_causal_mask(query_length=q.size(-2), dtype=q.dtype) + attention_mask = attention_mask_info.get_causal_mask(query_length=q.size(-2)) x = F.scaled_dot_product_attention( query=q, From 37de101d86db0353dd5ef16669029fe63423cf10 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:40:28 -0700 Subject: [PATCH 158/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 50 +++++++++-------------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 71761f2c1..ff84dfc15 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -32,61 +32,51 @@ def __init__(self, config: CommonConfig, layer_idx: int | None = None) -> Block: def forward( self, - hidden_states: torch.Tensor, + x: torch.Tensor, attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, rope_cos_sin: torch.Tensor | None = None, ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - - hidden_states = self._sequence_mixer_forward( - hidden_states=hidden_states, - past_key_values=past_key_values, - attention_mask_info=attention_mask_info, - rope_cos_sin=rope_cos_sin, + r = x + x = self.ln_1(x) + + x = self._sequence_mixer_forward( + x=x, past_key_values=past_key_values, attention_mask_info=attention_mask_info, rope_cos_sin=rope_cos_sin ) if self.m_residual is not None: - hidden_states = hidden_states * self.m_residual + x = x * self.m_residual - hidden_states = hidden_states + residual + x = x + r - residual = hidden_states - hidden_states = self.ln_2(hidden_states) + r = x + x = self.ln_2(x) - hidden_states = self.mlp_block(hidden_states) + x = self.mlp_block(x) if self.m_residual is not None: - hidden_states = 
hidden_states * self.m_residual + x = x * self.m_residual - hidden_states = hidden_states + residual + x = x + r - return hidden_states + return x def _sequence_mixer_forward( self, - hidden_states: torch.Tensor, + x: torch.Tensor, attention_mask_info: AttentionMaskInfo, past_key_values: GenerationCache | None = None, rope_cos_sin: torch.Tensor | None = None, ) -> torch.Tensor: if self.sequence_mixer_type in ["softmax_attention", "multihead_latent_attention"]: - hidden_states = self.sequence_mixer( - hidden_states, - past_key_values=past_key_values, - attention_mask_info=attention_mask_info, - rope_cos_sin=rope_cos_sin, + x = self.sequence_mixer( + x, past_key_values=past_key_values, attention_mask_info=attention_mask_info, rope_cos_sin=rope_cos_sin ) elif self.sequence_mixer_type in ["causal_convolution", "mamba2"]: - hidden_states = self.sequence_mixer( - hidden_states, cache_params=past_key_values, attention_mask=attention_mask - ) + x = self.sequence_mixer(x, cache_params=past_key_values, attention_mask=attention_mask) elif self.sequence_mixer_type in ["gru", "rnn"]: - hidden_states = self.sequence_mixer( - x=hidden_states, attention_mask_info=attention_mask_info, cache_params=past_key_values - ) + x = self.sequence_mixer(x=x, attention_mask_info=attention_mask_info, cache_params=past_key_values) else: raise ValueError(f"unexpected sequence_mixer_type ({self.sequence_mixer_type})") - return hidden_states + return x From d762281b097ff6c09d46ace9a3a4636766edcafb Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:42:01 -0700 Subject: [PATCH 159/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 1c08d58b8..a61c22d60 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -132,6 +132,12 @@ def forward( tensor_parallel_enabled=False, ) + if lm_logits is not None: + lm_logits = attention_mask_info.unpack_sequence(lm_logits) + + if hidden_states is not None: + hidden_states = attention_mask_info.unpack_sequence(hidden_states) + aux_loss = get_aux_loss() if loss is not None and not is_aux_loss_zero(aux_loss): From dcc22456c82c6070ec79ffa174d8f12b883e1117 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 17:44:09 -0700 Subject: [PATCH 160/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 6 ++++++ lm_engine/hf_models/mixins/dense/main.py | 7 ++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index eaf6dee98..3253d1f10 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -173,6 +173,9 @@ def get_causal_mask(self, query_length: int, return_none_allowed: bool = True) - return self.causal_mask def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: + if inputs is None: + return None + if self.has_cu_seqlens(): kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) else KernelBackend.torch inputs = pack_sequence( @@ -190,6 +193,9 @@ def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tens return inputs def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: + if inputs is None: + return None + B = self.get_batch_size() S = self.get_max_seqlen(False) diff --git 
a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index a61c22d60..09eeb0f57 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -132,11 +132,8 @@ def forward( tensor_parallel_enabled=False, ) - if lm_logits is not None: - lm_logits = attention_mask_info.unpack_sequence(lm_logits) - - if hidden_states is not None: - hidden_states = attention_mask_info.unpack_sequence(hidden_states) + lm_logits = attention_mask_info.unpack_sequence(lm_logits) + hidden_states = attention_mask_info.unpack_sequence(hidden_states) aux_loss = get_aux_loss() From c93267a278e92b61145d4d2923fad81f55a07cdc Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 19:04:15 -0700 Subject: [PATCH 161/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 3253d1f10..5a3af54e4 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -184,11 +184,10 @@ def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tens kernel_backend_forward=kernel_backend, kernel_backend_backward=kernel_backend, ) + elif isinstance(inputs, torch.Tensor): + inputs = inputs.flatten(0, 1) else: - if isinstance(inputs, torch.Tensor): - inputs = inputs.flatten(0, 1) - else: - inputs = [i.flatten(0, 1) for i in inputs] + inputs = [i.flatten(0, 1) for i in inputs] return inputs @@ -210,10 +209,9 @@ def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Te kernel_backend_forward=kernel_backend, kernel_backend_backward=kernel_backend, ) + elif isinstance(inputs, torch.Tensor): + inputs = inputs.reshape(B, S, *inputs.size()[1:]) else: - if isinstance(inputs, torch.Tensor): - inputs = inputs.reshape(B, S, *inputs.size()[1:]) - else: - inputs = [i.reshape(B, S, *i.size()[1:]) for i in inputs] + inputs = [i.reshape(B, S, *i.size()[1:]) for i in inputs] return inputs From a40b459e1453fb9a896956c76c7bb1ce42730d00 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 19:36:27 -0700 Subject: [PATCH 162/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 5a3af54e4..873c69f92 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -144,7 +144,12 @@ def get_position_ids(self) -> torch.Tensor: return position_ids - def get_causal_mask(self, query_length: int, return_none_allowed: bool = True) -> torch.Tensor | None: + def get_causal_mask( + self, query_length: int, return_none_allowed: bool = True, dtype: torch.dtype | None = None + ) -> torch.Tensor | None: + if self.causal_mask is not None: + return self.causal_mask + if self.has_cu_seqlens() or self.has_attention_mask(): attention_mask = self.get_attention_mask() elif return_none_allowed: @@ -168,7 +173,16 @@ def get_causal_mask(self, query_length: int, return_none_allowed: bool = True) - else: raise NotImplementedError(_ERROR_MESSAGE) - self.causal_mask = causal_mask[:, None, ...] + causal_mask = causal_mask[:, None, ...] 
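# convert the boolean mask into an additive bias (0 where attention is allowed, dtype-min where
# masked); the result is cached in self.causal_mask so repeated calls return it directly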
+ causal_mask = torch.where(causal_mask, ~causal_mask, self._get_mask_value(attention_mask.device, dtype)) + + # this is needed to prevent NaN since SDPA + # see issue: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = causal_mask * ~torch.all( + causal_mask == self._get_mask_value(self.device, dtype), dim=-1, keepdim=True + ) + + self.causal_mask = causal_mask return self.causal_mask From 6a322f88fed38711ec678ef6802b82ad1376c7b7 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 19:38:53 -0700 Subject: [PATCH 163/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 873c69f92..abcd6120d 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -31,6 +31,7 @@ class AttentionMaskInfo: device: torch.device | None = None mask_value: torch.Tensor | None = None causal_mask: torch.Tensor | None = None + _mask_value: torch.Tensor | None = None def __post_init__(self) -> None: self._has_cu_seqlens = self.cu_seqlens is not None @@ -229,3 +230,9 @@ def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Te inputs = [i.reshape(B, S, *i.size()[1:]) for i in inputs] return inputs + + def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + # torch.where expects a tensor. We use a cache to avoid recreating it every time. + if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: + self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + return self.mask_value From edebff4482e8c3ad3acc3424fff2adb55bd4400b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 19:44:27 -0700 Subject: [PATCH 164/177] cleanup Signed-off-by: Mayank Mishra --- tests/hf_models/single_gpu/gpt_base_test.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/hf_models/single_gpu/gpt_base_test.py b/tests/hf_models/single_gpu/gpt_base_test.py index 51e5b90ef..aa5956eec 100644 --- a/tests/hf_models/single_gpu/gpt_base_test.py +++ b/tests/hf_models/single_gpu/gpt_base_test.py @@ -18,17 +18,23 @@ class GPTBaseAttentionTest(TestCommons): @parameterized.expand( TestCommons.make_args_matrix( - [torch.device("cuda")], TestCommons.get_position_embedding_types(), [torch.float16, torch.bfloat16] + [torch.device("cuda")], + TestCommons.get_position_embedding_types(), + [torch.float16, torch.bfloat16], + [False, True], ) ) def test_sdpa_flash_attention_equivalence( - self, device: torch.device, position_embedding_type: str, dtype: torch.dtype + self, device: torch.device, position_embedding_type: str, dtype: torch.dtype, has_attention_mask: bool ) -> None: self.skip_test_if_device_unavailable(device) set_seed(SEED) input_ids, attention_mask, labels = self.get_dummy_inputs(device) + if not has_attention_mask: + attention_mask = None + config = self.get_dense_test_config(position_embedding_type, num_layers=1) model = self.from_config(config, dtype=dtype).to(device) From aa4bcdd2b959eecad18d8d2d517f64708dc50355 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 20:14:35 -0700 Subject: [PATCH 165/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index abcd6120d..7349f870e 100644 --- a/lm_engine/hf_models/mask.py +++ 
b/lm_engine/hf_models/mask.py @@ -191,7 +191,7 @@ def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tens if inputs is None: return None - if self.has_cu_seqlens(): + if self.has_cu_seqlens() or self.has_attention_mask(): kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) else KernelBackend.torch inputs = pack_sequence( inputs=inputs, @@ -213,7 +213,7 @@ def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Te B = self.get_batch_size() S = self.get_max_seqlen(False) - if self.has_cu_seqlens(): + if self.has_cu_seqlens() or self.has_attention_mask(): kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch other_shape = inputs.size()[1:] if isinstance(inputs, torch.Tensor) else inputs[0].size()[1:] From d40273ac39df008aa02fcd9b981b7eb8a7c9844e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 20:16:16 -0700 Subject: [PATCH 166/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 7a3ae211b..ab1a41d83 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -24,7 +24,7 @@ if is_fma_available(): from fma import continuous_count - from fma.modules.moe import group_with_padding, grouped_gemm_experts, scattered_experts, ungroup_with_padding + from fma.layers.moe import group_with_padding, grouped_gemm_experts, scattered_experts, ungroup_with_padding # TODO add support for combileable bincount in PyTorch directly From 8cb4529bdfc6c6d854a8df1afc68a5323202c0f1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 20:16:53 -0700 Subject: [PATCH 167/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py | 2 +- lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py | 2 +- lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index b3da514e6..4bd966e20 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -22,7 +22,7 @@ if is_fma_available(): from fma import KernelBackend - from fma.modules.gru import gru + from fma.layers.gru import gru class GRU(nn.Module): diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 0207d0158..b2bca98c2 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -22,7 +22,7 @@ if is_fma_available(): from fma import KernelBackend - from fma.modules.rnn import rnn + from fma.layers.rnn import rnn class RNN(nn.Module): diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py index 2a625e370..59e6f7b53 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py @@ -23,7 +23,7 @@ if is_fma_available(): - from fma.modules.moe import scattered_experts + from 
fma.layers.moe import scattered_experts class ReplicatedLinear_TP(ParameterizedLinear, DTensorModule): From aced6959d9b39cbce7e93cc3f25337856c579c73 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 20:51:55 -0700 Subject: [PATCH 168/177] cleanup Signed-off-by: Mayank Mishra --- flash-model-architectures | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash-model-architectures b/flash-model-architectures index 76c07ca4f..9239d2923 160000 --- a/flash-model-architectures +++ b/flash-model-architectures @@ -1 +1 @@ -Subproject commit 76c07ca4fdb9df55b5493e3a2f06bad2a0f841cd +Subproject commit 9239d2923fb063f95d4ef7b299a7682eefc8b446 From 6f439e39fd3163771828b4d20b6b4afa72625f71 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 20:54:00 -0700 Subject: [PATCH 169/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 7349f870e..dd78bd3b5 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -193,9 +193,12 @@ def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tens if self.has_cu_seqlens() or self.has_attention_mask(): kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) else KernelBackend.torch + cu_seqlens = self.get_cu_seqlens(False) + inputs = pack_sequence( inputs=inputs, - cu_seqlens=self.get_cu_seqlens(False), + cu_seqlens=cu_seqlens, + total_tokens=cu_seqlens[-1].item(), kernel_backend_forward=kernel_backend, kernel_backend_backward=kernel_backend, ) @@ -215,12 +218,12 @@ def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Te if self.has_cu_seqlens() or self.has_attention_mask(): kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.unpack_sequence) else KernelBackend.torch - other_shape = inputs.size()[1:] if isinstance(inputs, torch.Tensor) else inputs[0].size()[1:] inputs = unpack_sequence( inputs=inputs, cu_seqlens=self.get_cu_seqlens(False), - output_shape=(B, S, *other_shape), + batch_size=B, + sequence_length=S, kernel_backend_forward=kernel_backend, kernel_backend_backward=kernel_backend, ) From 0a3db83d5142316a3398c0c84a38ebf8783700c2 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 20:54:50 -0700 Subject: [PATCH 170/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index dd78bd3b5..142a2cdd7 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -141,7 +141,7 @@ def get_position_ids(self) -> torch.Tensor: position_ids = torch.arange(0, S, device=self.device) position_ids = position_ids[None, ...].expand(B, -1) - position_ids = position_ids.flatten() + position_ids = self.pack_sequence(position_ids) return position_ids From 8e316339b407960fbbb03d3b06b740d867043a64 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 20:57:23 -0700 Subject: [PATCH 171/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index 142a2cdd7..e25792b3d 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -191,6 +191,10 @@ def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> 
torch.Tens if inputs is None: return None + is_tensor = isinstance(inputs, torch.Tensor) + if is_tensor: + inputs = [inputs] + if self.has_cu_seqlens() or self.has_attention_mask(): kernel_backend = KernelBackend.cuda if is_kernel_allowed(Kernel.pack_sequence) else KernelBackend.torch cu_seqlens = self.get_cu_seqlens(False) @@ -202,17 +206,22 @@ def pack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tens kernel_backend_forward=kernel_backend, kernel_backend_backward=kernel_backend, ) - elif isinstance(inputs, torch.Tensor): - inputs = inputs.flatten(0, 1) else: inputs = [i.flatten(0, 1) for i in inputs] + if is_tensor: + inputs = inputs[0] + return inputs def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: if inputs is None: return None + is_tensor = isinstance(inputs, torch.Tensor) + if is_tensor: + inputs = [inputs] + B = self.get_batch_size() S = self.get_max_seqlen(False) @@ -227,11 +236,12 @@ def unpack_sequence(self, inputs: torch.Tensor | list[torch.Tensor]) -> torch.Te kernel_backend_forward=kernel_backend, kernel_backend_backward=kernel_backend, ) - elif isinstance(inputs, torch.Tensor): - inputs = inputs.reshape(B, S, *inputs.size()[1:]) else: inputs = [i.reshape(B, S, *i.size()[1:]) for i in inputs] + if is_tensor: + inputs = inputs[0] + return inputs def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: From d9750fb7073c4ab9401bc7c2e694d11366f1b84f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 20:58:36 -0700 Subject: [PATCH 172/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 09eeb0f57..8a604502f 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -86,7 +86,7 @@ def forward( GenerationCache(self.config) if use_cache and past_key_values is None else past_key_values ) - input_ids = input_ids.flatten() + input_ids = attention_mask_info.pack_sequence(input_ids) transformer_outputs: BaseModelOutputWithPast = self.transformer( input_ids=input_ids, @@ -119,7 +119,7 @@ def forward( if self.m_width is not None: lm_logits = lm_logits / self.m_width - labels = labels.flatten() + labels = attention_mask_info.pack_sequence(labels) loss = get_autoregressive_language_modeling_loss( lm_logits=lm_logits, From 2a2ee6734a0528e5a543505c223d7e14bd26f259 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 21:03:50 -0700 Subject: [PATCH 173/177] cleanup Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/attention.py | 4 ++-- .../modeling_utils/sequence_mixer_blocks/utils.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index a2dad78dc..cea3d52b3 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -174,10 +174,10 @@ def forward( else: assert self.sliding_window is None - q, k, v = attention_mask_info.unpack_sequence((q, k, v)) + q, k, v = attention_mask_info.unpack_sequence([q, k, v]) q, k, v = [i.transpose(1, 2) for i in (q, k, v)] - attention_mask = attention_mask_info.get_causal_mask(query_length=q.size(-2)) + 
attention_mask = attention_mask_info.get_causal_mask(query_length=q.size(-2), dtype=q.dtype) x = F.scaled_dot_product_attention( query=q, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py index 2d2bf8643..1478d9c16 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py @@ -43,11 +43,11 @@ def flash_attention( assert k.dim() == 3 assert v.dim() == 3 - if attention_mask_info.has_cu_seqlens(): + if attention_mask_info.has_cu_seqlens() or attention_mask_info.has_attention_mask(): assert sliding_window is None - cu_seqlens = attention_mask_info.get_cu_seqlens() - max_seqlen = attention_mask_info.get_max_seqlen() + cu_seqlens = attention_mask_info.get_cu_seqlens(False) + max_seqlen = attention_mask_info.get_max_seqlen(False) if use_flash_attention_3: x, _ = flash_attention_3_varlen( @@ -75,7 +75,7 @@ def flash_attention( causal=causal, ) else: - q, k, v = attention_mask_info.unpack_sequence((q, k, v)) + q, k, v = attention_mask_info.unpack_sequence([q, k, v]) if use_flash_attention_3: x, _ = flash_attention_3( From 6fcd52511a5f8ac57e40ffe95982f6c48e0dbc65 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 21:10:42 -0700 Subject: [PATCH 174/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mask.py | 3 +++ .../sequence_mixer_blocks/rnn.py | 27 ++++++++++--------- .../sequence_mixer_blocks/utils.py | 2 +- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/lm_engine/hf_models/mask.py b/lm_engine/hf_models/mask.py index e25792b3d..9f5f191a5 100644 --- a/lm_engine/hf_models/mask.py +++ b/lm_engine/hf_models/mask.py @@ -62,6 +62,9 @@ def has_cu_seqlens(self) -> bool: def has_attention_mask(self) -> bool: return self._has_attention_mask + def has_padding(self) -> bool: + return self.has_cu_seqlens() or self.has_attention_mask() + def get_batch_size(self) -> int: if self.batch_size is None: if self.has_cu_seqlens(): diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index b2bca98c2..9152a0a2f 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -79,9 +79,11 @@ def __init__( def forward( self, x: torch.Tensor, attention_mask_info: AttentionMaskInfo, cache_params: GenerationCache | None = None ) -> torch.Tensor: + T = x.size(0) + x = self.input_projection(x) x, g = x.chunk(2, dim=-1) - x = x.view(*x.size()[:-1], self.num_heads, self.state_head_dim) + x = x.view(T, self.num_heads, self.state_head_dim) if self.scaling_factor != 1: x = x * self.scaling_factor @@ -90,26 +92,25 @@ def forward( if self.scaling_factor != 1: weight = weight * self.scaling_factor - cu_seqlens = None if attention_mask_info.is_ragged() else attention_mask_info.get_cu_seqlens() - max_seqlen = None if attention_mask_info.is_ragged() else attention_mask_info.get_max_seqlen() + has_padding = attention_mask_info.has_padding() - x = rnn( - input=attention_mask_info.unpack_sequence(x), + x, s = rnn( + input=x if has_padding else attention_mask_info.unpack_sequence(x), weight=weight, input_state=None if cache_params is None else cache_params.get_cache(self.layer_idx), gradient_clipping=self.gradient_clipping, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + cu_seqlens=attention_mask_info.get_cu_seqlens(), + 
max_seqlen=attention_mask_info.get_max_seqlen(), kernel_backend=KernelBackend.triton if is_kernel_allowed(Kernel.rnn) else KernelBackend.torch, ) + if not has_padding: + x = attention_mask_info.pack_sequence(x) + if cache_params is not None: - if cu_seqlens is None: - cache_params.update(state=x[:, -1], num_tokens_added=input.size(1), layer_idx=self.layer_idx) - else: - cache_params.update( - state=x[cu_seqlens[1:] - 1], num_tokens_added=cu_seqlens[1:], layer_idx=self.layer_idx - ) + cache_params.update( + state=s, num_tokens_added=attention_mask_info.get_batch_size(), layer_idx=self.layer_idx + ) x = x.flatten(-2, -1) x = x * F.silu(g) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py index 1478d9c16..36a27a7d0 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils.py @@ -43,7 +43,7 @@ def flash_attention( assert k.dim() == 3 assert v.dim() == 3 - if attention_mask_info.has_cu_seqlens() or attention_mask_info.has_attention_mask(): + if attention_mask_info.has_padding(): assert sliding_window is None cu_seqlens = attention_mask_info.get_cu_seqlens(False) From a8b9072d66b673124e6803f5d2edd4e6669cbae0 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 21:27:36 -0700 Subject: [PATCH 175/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/hf_models/__init__.py | 1 + lm_engine/hf_models/mixins/dense/main.py | 30 +++++++----------------- lm_engine/model_wrapper/pretraining.py | 24 +++++++++++++++++-- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py index 6155f3e9f..ef7933bee 100644 --- a/lm_engine/hf_models/__init__.py +++ b/lm_engine/hf_models/__init__.py @@ -5,6 +5,7 @@ from .cache import disable_generation_cache from .config import CommonConfig from .loss import get_autoregressive_language_modeling_loss, is_aux_loss_zero +from .mask import AttentionMaskInfo from .mixins import CausalLMOutputWithPast, PipelineParallelInput, PipelineParallelOutput from .model_conversion import export_to_huggingface, import_from_huggingface from .models import ( diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 8a604502f..5e535927e 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -65,8 +65,7 @@ def forward( labels: torch.Tensor | None = None, use_cache: bool | None = None, return_dict: bool = True, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, + attention_mask_info: AttentionMaskInfo | None = None, reduction: str = "mean", ) -> CausalLMOutputWithPast: assert return_dict @@ -74,9 +73,8 @@ def forward( clear_aux_loss() - attention_mask_info = self._get_attention_mask_info( - x=input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask - ) + if attention_mask_info is None: + attention_mask_info = self._get_attention_mask_info(x=input_ids, attention_mask=attention_mask) if position_ids is None: position_ids = attention_mask_info.get_position_ids() @@ -262,23 +260,13 @@ def generate( return generated_tokens - def _get_attention_mask_info( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor | None, - max_seqlen: torch.Tensor, - attention_mask: torch.Tensor | None, - ) -> AttentionMaskInfo: + def _get_attention_mask_info(self, x: torch.Tensor, attention_mask: 
torch.Tensor | None) -> AttentionMaskInfo: kwargs = {} - if cu_seqlens is None: - if attention_mask is None: - kwargs["batch_size"] = x.size(0) - kwargs["max_seqlen"] = x.size(1) - kwargs["device"] = x.device - else: - kwargs["attention_mask"] = attention_mask + if attention_mask is None: + kwargs["batch_size"] = x.size(0) + kwargs["max_seqlen"] = x.size(1) + kwargs["device"] = x.device else: - kwargs["cu_seqlens"] = cu_seqlens - kwargs["max_seqlen"] = max_seqlen + kwargs["attention_mask"] = attention_mask return AttentionMaskInfo(**kwargs) diff --git a/lm_engine/model_wrapper/pretraining.py b/lm_engine/model_wrapper/pretraining.py index dddb66a14..084a0424c 100644 --- a/lm_engine/model_wrapper/pretraining.py +++ b/lm_engine/model_wrapper/pretraining.py @@ -11,6 +11,7 @@ from ..dtensors import tensor_to_dtensor from ..enums import Kernel from ..hf_models import ( + AttentionMaskInfo, CausalLMOutputWithPast, PipelineParallelInput, PipelineParallelOutput, @@ -251,8 +252,9 @@ def _prepare_model_inputs(self, batch: dict) -> dict: position_ids = self.position_ids batch["input_ids"] = input_ids - batch["cu_seqlens"] = cu_seqlens - batch["max_seqlen"] = max_seqlen + batch["attention_mask_info"] = self._get_attention_mask_info( + x=input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) batch["position_ids"] = position_ids if ProcessGroupManager.is_tensor_parallel_enabled(): @@ -289,6 +291,24 @@ def reset_parameters(self) -> None: persistent=False, ) + def _get_attention_mask_info( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor | None, + max_seqlen: torch.Tensor, + attention_mask: torch.Tensor | None, + ) -> AttentionMaskInfo: + kwargs = {} + if cu_seqlens is None: + kwargs["batch_size"] = x.size(0) + kwargs["max_seqlen"] = x.size(1) + kwargs["device"] = x.device + else: + kwargs["cu_seqlens"] = cu_seqlens + kwargs["max_seqlen"] = max_seqlen + + return AttentionMaskInfo(**kwargs) + class _F(torch.autograd.Function): @staticmethod From 5c846527ec612f7dfc4beefd3380519053843810 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 21:28:09 -0700 Subject: [PATCH 176/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/model_wrapper/pretraining.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lm_engine/model_wrapper/pretraining.py b/lm_engine/model_wrapper/pretraining.py index 084a0424c..8d0ce4708 100644 --- a/lm_engine/model_wrapper/pretraining.py +++ b/lm_engine/model_wrapper/pretraining.py @@ -292,11 +292,7 @@ def reset_parameters(self) -> None: ) def _get_attention_mask_info( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor | None, - max_seqlen: torch.Tensor, - attention_mask: torch.Tensor | None, + self, x: torch.Tensor, cu_seqlens: torch.Tensor | None, max_seqlen: torch.Tensor ) -> AttentionMaskInfo: kwargs = {} if cu_seqlens is None: From 1e3b0551e4c0b87e17a816ba2ef772b51ee5bd32 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 19 Oct 2025 21:30:35 -0700 Subject: [PATCH 177/177] cleanup Signed-off-by: Mayank Mishra --- lm_engine/model_wrapper/pretraining.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/lm_engine/model_wrapper/pretraining.py b/lm_engine/model_wrapper/pretraining.py index 8d0ce4708..bd8085b6f 100644 --- a/lm_engine/model_wrapper/pretraining.py +++ b/lm_engine/model_wrapper/pretraining.py @@ -124,6 +124,8 @@ def forward( batch = self._prepare_model_inputs(batch) labels = batch.pop("labels") + attention_mask_info = batch["attention_mask_info"] + output: CausalLMOutputWithPast | 
PipelineParallelOutput = self.model(**batch, return_dict=True) if self.is_pipeline_parallel_enabled: @@ -144,12 +146,18 @@ def forward( if use_aux_loss: output = (output, aux_loss) else: - output = self.get_loss(output, labels, lm_loss_multiplier=lm_loss_multiplier) + output = self.get_loss( + output, labels, attention_mask_info=attention_mask_info, lm_loss_multiplier=lm_loss_multiplier + ) return output def get_loss( - self, model_outputs: CausalLMOutputWithPast, labels: torch.Tensor, lm_loss_multiplier: float = 1 + self, + model_outputs: CausalLMOutputWithPast, + labels: torch.Tensor, + attention_mask_info: AttentionMaskInfo, + lm_loss_multiplier: float = 1, ) -> torch.Tensor | dict: tensor_parallel_enabled = ProcessGroupManager.is_tensor_parallel_enabled() use_fused_linear_cross_entropy_kernel = is_kernel_allowed(Kernel.fused_linear_cross_entropy) @@ -157,10 +165,9 @@ def get_loss( lm_loss = get_autoregressive_language_modeling_loss( lm_logits=None if use_fused_linear_cross_entropy_kernel else model_outputs.logits, labels=labels, + attention_mask_info=attention_mask_info, hidden_states=model_outputs.last_hidden_state if use_fused_linear_cross_entropy_kernel else None, vocab_weight=self.model.get_output_embeddings().weight if use_fused_linear_cross_entropy_kernel else None, - cu_seqlens=None, - use_padding_free_transformer=True, reduction="sum", shift_logits_and_labels=False, tensor_parallel_enabled=tensor_parallel_enabled,
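
The patches above converge on a single AttentionMaskInfo object that owns sequence packing, unpacking, position-id construction, and the SDPA causal mask, so callers no longer thread cu_seqlens/max_seqlen by hand. The standalone sketch below illustrates the intended semantics with plain PyTorch only; the helper names (cu_seqlens_from_mask, toy_pack, toy_unpack, toy_causal_mask) are illustrative assumptions rather than the lm_engine API, and the cu_seqlens layout assumed here is the usual flash-attention convention (prefix sums of per-sequence lengths, with the total token count last).

    # Minimal sketch of the packing/masking semantics implied by the
    # AttentionMaskInfo refactor above. Not the lm_engine implementation.
    import torch
    import torch.nn.functional as F

    def cu_seqlens_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
        # attention_mask: (B, S) with 1 for real tokens, 0 for padding.
        seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
        return F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))

    def toy_pack(x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # (B, S, ...) -> (T, ...): keep only unpadded positions, batch-major,
        # which matches the ordering described by cu_seqlens.
        return x[attention_mask.bool()]

    def toy_unpack(x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # (T, ...) -> (B, S, ...): scatter packed tokens back, zero-fill padding.
        B, S = attention_mask.shape
        out = x.new_zeros(B, S, *x.shape[1:])
        out[attention_mask.bool()] = x
        return out

    def toy_causal_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # Additive (B, 1, S, S) mask for SDPA: 0 where attention is allowed,
        # dtype's most negative value where it is not.
        B, S = attention_mask.shape
        causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=attention_mask.device))
        allowed = causal & attention_mask.bool()[:, None, :]  # causal AND key is not padding
        mask_value = torch.finfo(dtype).min
        mask = torch.zeros_like(allowed, dtype=dtype).masked_fill(~allowed, mask_value)
        # A query row that is entirely masked (a pure-padding position) would make
        # softmax inside SDPA return NaN; zero such rows out, mirroring the
        # all-masked-row guard added to get_causal_mask above
        # (see the pytorch/pytorch issue 110213 reference in the diff).
        mask = mask * ~(mask == mask_value).all(dim=-1, keepdim=True)
        return mask[:, None, :, :]  # broadcastable over attention heads

    if __name__ == "__main__":
        attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
        x = torch.randn(2, 4, 8)
        cu = cu_seqlens_from_mask(attention_mask)        # tensor([0, 3, 5], dtype=torch.int32)
        packed = toy_pack(x, attention_mask)             # (5, 8): 3 tokens from row 0, 2 from row 1
        assert packed.shape[0] == int(cu[-1])
        unpacked = toy_unpack(packed, attention_mask)    # (2, 4, 8), padded slots zero-filled
        assert torch.equal(unpacked[attention_mask.bool()], packed)
        print(toy_causal_mask(attention_mask, torch.float32).shape)  # torch.Size([2, 1, 4, 4])

Under these assumptions, a pack/unpack round trip recovers the original non-padded positions exactly, which is the property the later patches rely on when they replace explicit flatten/reshape calls with attention_mask_info.pack_sequence and unpack_sequence.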