From d1f6761d08fa870f477390c5951718131f7561ae Mon Sep 17 00:00:00 2001
From: JessicaZhong
Date: Thu, 5 Feb 2026 14:53:00 -0800
Subject: [PATCH 01/10] verl integration changes

---
 torchtitan/components/lr_scheduler.py  | 11 ++--
 torchtitan/distributed/utils.py        |  4 ++
 torchtitan/models/attention.py         | 37 +++++++++++++
 torchtitan/models/qwen3/model/args.py  |  4 +-
 torchtitan/models/qwen3/model/model.py | 76 +++++++++++++++++++++-----
 torchtitan/train.py                    | 47 ++++++++++------
 6 files changed, 139 insertions(+), 40 deletions(-)

diff --git a/torchtitan/components/lr_scheduler.py b/torchtitan/components/lr_scheduler.py
index 15a3fc6bd1..818964bf14 100644
--- a/torchtitan/components/lr_scheduler.py
+++ b/torchtitan/components/lr_scheduler.py
@@ -155,20 +155,17 @@ def linear_warmup_stable_decay(
     """
     warmup_stable_steps = warmup_steps + stable_steps
     if current_step < warmup_steps:
-        # linear warmup
-        # 0-indexed step, hence + 1 adjustments
-        current_step += 1
+        # linear warmup (0-indexed to match FSDP/HuggingFace)
         assert (
             warmup_steps != 0
         ), "warmup_steps must not be zero to reach this branch"
-        curr_adjustment = float(current_step / warmup_steps)
+        curr_adjustment = float(current_step) / float(warmup_steps)
     elif current_step < warmup_stable_steps:
         curr_adjustment = 1.0
     else:
-        # 0-indexed step, hence + 1 adjustments
-        current_step += 1
+        # Decay phase (0-indexed to match FSDP/HuggingFace)
         assert decay_steps != 0, "decay_steps must not be zero to reach this branch"
-        progress = float(current_step - warmup_stable_steps) / decay_steps
+        progress = float(current_step - warmup_stable_steps) / float(decay_steps)
 
         if lr_decay_type == "linear":
             curr_adjustment = 1 - progress
diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 4fe3f01a66..26d31bc545 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -293,6 +293,10 @@ def init_distributed(
     base_folder: str = "",
     ranks: list[int] | None = None,
 ) -> int:
+    # Skip initialization if already initialized
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_world_size()
+
     if comm_config.mode in ("fake_backend", "local_tensor"):
         ngpu_str = os.environ.get("NGPU")
         if ngpu_str is None:
diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index 32130b6d5e..f110853361 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -185,12 +185,25 @@ def forward(
         *,
         scale: float | None = None,
         enable_gqa: bool = False,
+<<<<<<< HEAD
         is_casual: bool = True,
+=======
+        attn_mask: torch.Tensor | None = None,
+>>>>>>> 87b7210b (verl integration changes)
     ) -> torch.Tensor:
+        import torch.distributed as dist
+        # Use is_causal=True only if no explicit mask is provided
+        is_causal = attn_mask is None
+        if dist.is_initialized() and dist.get_rank() == 0:
+            print(f"DEBUG TITAN SDPA wrapper: scale={scale}, is_causal={is_causal}, enable_gqa={enable_gqa}, attn_mask={attn_mask}, backends={self.sdpa_backends}")
         with sdpa_kernel(self.sdpa_backends, set_priority=True):
+<<<<<<< HEAD
             return F.scaled_dot_product_attention(
                 q, k, v, scale=scale, is_causal=is_casual, enable_gqa=enable_gqa
             )
+=======
+            return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale, is_causal=is_causal, enable_gqa=enable_gqa)
+>>>>>>> 87b7210b (verl integration changes)
 
 
 def get_causal_mask_mod() -> _mask_mod_signature:
@@ -234,6 +247,30 @@ def document_mask(
     return document_mask
 
 
+def get_document_mask_mod_from_positions(positions: torch.Tensor) -> _mask_mod_signature:
+ """Creates a document mask from position_ids (like HuggingFace). + + Detects document boundaries where position_ids reset (diff != 1). + + Args: + positions: Position IDs tensor with shape [batch, seq] + + Returns: + A mask modifier function that implements document-level masking. + """ + # HuggingFace logic: find where position_ids reset (diff != 1) + first_dummy_value = positions[:, :1] - 1 + position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) + sequence_indices = (position_diff != 1).cumsum(-1) # [batch, seq] + + def document_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ) -> torch.Tensor: + return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx] + + return document_mask + + def get_fixed_block_mask_mod(fixed_block_size: int) -> _mask_mod_signature: """ Divide the input sequence into blocks and only allow attention within the same block. diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py index d0a0556bf1..ffaeae6a77 100644 --- a/torchtitan/models/qwen3/model/args.py +++ b/torchtitan/models/qwen3/model/args.py @@ -36,8 +36,8 @@ class Qwen3ModelArgs(BaseModelArgs): max_seq_len: int = 4096 depth_init: bool = True - attn_type: str = "sdpa" - attn_mask_type: str = "causal" + attn_type: str = "flex" + attn_mask_type: str = "document_mask" eos_id: int = 151645 enable_weight_tying: bool = False diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 2b7173546d..0d0078286f 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -9,6 +9,7 @@ from typing import cast import torch +import torch.distributed as dist import torch.nn.functional as F from torch import nn from torch.nn.attention.flex_attention import and_masks, BlockMask @@ -20,6 +21,7 @@ FlexAttentionWrapper, get_causal_mask_mod, get_document_mask_mod, + get_document_mask_mod_from_positions, ScaledDotProductAttentionWrapper, VarlenAttentionWrapper, VarlenMetadata, @@ -252,7 +254,15 @@ def forward( xq = xq.transpose(1, 2) # (bs, n_heads, seqlen, head_dim) xk = xk.transpose(1, 2) # (bs, n_kv_heads, seqlen, head_dim) xv = xv.transpose(1, 2) # (bs, n_kv_heads, seqlen, head_dim) - + # self.enable_gqa = False + # if not self.enable_gqa: + # keys = repeat_kv(xk, self.n_rep) + # values = repeat_kv(xv, self.n_rep) + # xq = xq.transpose(1, 2) + # xk = keys.transpose(1, 2) + # xv = values.transpose(1, 2) + + match self.attn_type: case "flex": assert isinstance(attention_masks, BlockMask), attention_masks @@ -278,23 +288,56 @@ def forward( scale=self.scaling, ) case "sdpa": - assert attention_masks is None - output = ( - self.inner_attention( - xq, # (bs, n_heads, seqlen, head_dim) - xk, # (bs, n_kv_heads, seqlen, head_dim) - xv, # (bs, n_kv_heads, seqlen, head_dim) - scale=self.scaling, - enable_gqa=self.enable_gqa, - ) - .transpose(1, 2) - .contiguous() - ) # (bs, seqlen, n_local_heads, head_dim) + # assert attention_masks is None + # output = ( + # self.inner_attention( + # xq, # (bs, n_heads, seqlen, head_dim) + # xk, # (bs, n_kv_heads, seqlen, head_dim) + # xv, # (bs, n_kv_heads, seqlen, head_dim) + # scale=self.scaling, + # enable_gqa=self.enable_gqa, + # ) + # .transpose(1, 2) + # .contiguous() + # ) # (bs, seqlen, n_local_heads, head_dim) + # Create document-aware causal mask from positions (like HuggingFace) + attn_mask = None + if positions is not None: + # Detect packed sequences from position_ids + # HuggingFace logic: find where position_ids 
reset (diff != 1) + first_dummy_value = positions[:, :1] - 1 + position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) + packed_sequence_mask = (position_diff != 1).cumsum(-1) # [batch, seq] + + # Create document-aware mask: tokens can only attend to same document + # Shape: [batch, 1, q_len, kv_len] + doc_mask = packed_sequence_mask.unsqueeze(2) == packed_sequence_mask.unsqueeze(1) # [batch, seq, seq] + doc_mask = doc_mask.unsqueeze(1) # [batch, 1, seq, seq] + + # Create causal mask + causal_mask = torch.tril(torch.ones(seqlen, seqlen, device=x.device, dtype=torch.bool)) + + # Combine: document mask AND causal mask + attn_mask = doc_mask & causal_mask.unsqueeze(0).unsqueeze(0) + + # if dist.is_initialized() and dist.get_rank() == 0: + # print(f"DEBUG TITAN SDPA attn_mask: shape={attn_mask.shape}, attn_mask={attn_mask}") + + output = self.inner_attention(xq, xk, xv, scale=self.scaling, enable_gqa=self.enable_gqa, attn_mask=attn_mask) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) case _: raise ValueError(f"Unknown attention type: {self.attn_type}") output = output.view(bs, seqlen, -1) - return self.wo(output) + + if dist.is_initialized() and dist.get_rank() == 0: + print(f"DEBUG TITAN model.py attention output: shape={output.shape}, dtype={output.dtype}, min={output.min().item():.4f}, max={output.max().item():.4f}, mean={output.mean().item():.4f}") + output = self.wo(output) + if dist.is_initialized() and dist.get_rank() == 0: + print(f"DEBUG TITAN model.py attention output after wo: shape={output.shape}, dtype={output.dtype}, min={output.min().item():.4f}, max={output.max().item():.4f}, mean={output.mean().item():.4f}") + return output class FeedForward(nn.Module): @@ -530,6 +573,11 @@ def _get_flex_attention_masks( B = input_batch.shape[0] assert tokenizer.eos_id is not None mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case "document_mask": + assert extra_inputs is not None and "positions" in extra_inputs + positions = extra_inputs["positions"] + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod_from_positions(positions)) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" diff --git a/torchtitan/train.py b/torchtitan/train.py index 9378d742e3..b5e8a8e32b 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -128,11 +128,15 @@ def __init__(self, job_config: JobConfig): else None ) - self.dataloader = self.train_spec.build_dataloader_fn( - dp_world_size=batch_degree, - dp_rank=batch_rank, - tokenizer=self.tokenizer, - job_config=job_config, + self.dataloader = ( + self.train_spec.build_dataloader_fn( + dp_world_size=batch_degree, + dp_rank=batch_rank, + tokenizer=self.tokenizer, + job_config=job_config, + ) + if self.train_spec.build_dataloader_fn is not None + else None ) # build model (using meta init) @@ -459,16 +463,20 @@ def post_dataloading_process( extra_kwargs: dict[str, Any] = {} attn_type = getattr(self.model_args, "attn_type", "sdpa") - if attn_type in ["flex", "varlen"]: - assert ( - self.tokenizer is not None - ), "tokenizer is required for flex/varlen attention" - model = cast(ModelProtocol, self.model_parts[0]) - extra_kwargs["attention_masks"] = model.get_attention_masks( - input_batch=inputs, - tokenizer=self.tokenizer, - extra_inputs=extra_inputs, - ) + if "attention_masks" not in extra_inputs.keys() or extra_inputs["attention_masks"] is None: + if attn_type in ["flex", "varlen"]: + assert ( + self.tokenizer is not None 
+ ), "tokenizer is required for flex/varlen attention" + model = cast(ModelProtocol, self.model_parts[0]) + extra_inputs.pop("attention_masks") + extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, + extra_inputs=extra_inputs, + ) + else: + extra_kwargs["attention_masks"] = extra_inputs.pop("attention_masks") if self.parallel_dims.cp_enabled: inputs, labels, extra_kwargs = prepare_context_parallel_input( @@ -533,6 +541,7 @@ def forward_backward_step( assert len(model_parts) == 1 with self.train_context(): with self.maybe_enable_amp: + print(f"jessica: {extra_inputs=} {extra_kwargs=}") pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) # Compute loss sum (reduction='sum') loss_sum = self.loss_fn(pred, labels) @@ -542,11 +551,15 @@ def forward_backward_step( loss = loss_sum / global_valid_tokens # need to free pred before bwd to avoid peaking memory - del pred - loss.backward() + # del pred + # loss.backward() +<<<<<<< HEAD # The returned loss here is local SUM loss / global_valid_tokens return loss +======= + return loss, pred +>>>>>>> 87b7210b (verl integration changes) def train_step( self, data_iterator: Iterator[tuple[dict[str, torch.Tensor], torch.Tensor]] From 3e72875282a234b350e571d4dc743c2400080ab9 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Fri, 6 Feb 2026 02:16:13 -0800 Subject: [PATCH 02/10] pp=1 working --- torchtitan/components/lr_scheduler.py | 4 +- torchtitan/distributed/utils.py | 4 + torchtitan/models/attention.py | 81 +++++++++++++++----- torchtitan/models/llama3/model/model.py | 6 +- torchtitan/models/qwen3/model/args.py | 4 +- torchtitan/models/qwen3/model/model.py | 82 +++++++------------- torchtitan/protocols/model.py | 2 +- torchtitan/train.py | 99 ++++++++++++++++++++++--- 8 files changed, 190 insertions(+), 92 deletions(-) diff --git a/torchtitan/components/lr_scheduler.py b/torchtitan/components/lr_scheduler.py index 818964bf14..17289df6be 100644 --- a/torchtitan/components/lr_scheduler.py +++ b/torchtitan/components/lr_scheduler.py @@ -159,13 +159,13 @@ def linear_warmup_stable_decay( assert ( warmup_steps != 0 ), "warmup_steps must not be zero to reach this branch" - curr_adjustment = float(current_step) / float(warmup_steps) + curr_adjustment = float(current_step / warmup_steps) elif current_step < warmup_stable_steps: curr_adjustment = 1.0 else: # Decay phase (0-indexed to match FSDP/HuggingFace) assert decay_steps != 0, "decay_steps must not be zero to reach this branch" - progress = float(current_step - warmup_stable_steps) / float(decay_steps) + progress = float(current_step - warmup_stable_steps) / decay_steps if lr_decay_type == "linear": curr_adjustment = 1 - progress diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 26d31bc545..ab31f603c7 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -295,6 +295,10 @@ def init_distributed( ) -> int: # Skip initialization if already initialized if torch.distributed.is_initialized(): + logger.warning( + "torch.distributed is already initialized. Skipping init_distributed. " + "The provided comm_config and other settings will not take effect." 
+ ) return torch.distributed.get_world_size() if comm_config.mode in ("fake_backend", "local_tensor"): diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index f110853361..d5672eb1de 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -31,9 +31,11 @@ "VarlenMetadata", "get_causal_mask_mod", "get_document_mask_mod", + "get_document_mask_mod_from_positions", "get_sliding_window_mask_mod", "get_fixed_block_mask_mod", "create_attention_mask", + "create_sdpa_document_causal_mask", ] @@ -185,25 +187,19 @@ def forward( *, scale: float | None = None, enable_gqa: bool = False, -<<<<<<< HEAD - is_casual: bool = True, -======= + is_causal: bool = True, attn_mask: torch.Tensor | None = None, ->>>>>>> 87b7210b (verl integration changes) ) -> torch.Tensor: - import torch.distributed as dist - # Use is_causal=True only if no explicit mask is provided - is_causal = attn_mask is None - if dist.is_initialized() and dist.get_rank() == 0: - print(f"DEBUG TITAN SDPA wrapper: scale={scale}, is_causal={is_causal}, enable_gqa={enable_gqa}, attn_mask={attn_mask}, backends={self.sdpa_backends}") with sdpa_kernel(self.sdpa_backends, set_priority=True): -<<<<<<< HEAD return F.scaled_dot_product_attention( - q, k, v, scale=scale, is_causal=is_casual, enable_gqa=enable_gqa + q, + k, + v, + attn_mask=attn_mask, + scale=scale, + is_causal=is_causal, + enable_gqa=enable_gqa, ) -======= - return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale, is_causal=is_causal, enable_gqa=enable_gqa) ->>>>>>> 87b7210b (verl integration changes) def get_causal_mask_mod() -> _mask_mod_signature: @@ -247,8 +243,8 @@ def document_mask( return document_mask -def get_document_mask_mod_from_positions(positions: torch.Tensor) -> _mask_mod_signature: - """Creates a document mask from position_ids (like HuggingFace). +def find_packed_sequence_indices(positions: torch.Tensor) -> torch.Tensor: + """Compute sequence/document indices from position_ids. Detects document boundaries where position_ids reset (diff != 1). @@ -256,12 +252,28 @@ def get_document_mask_mod_from_positions(positions: torch.Tensor) -> _mask_mod_s positions: Position IDs tensor with shape [batch, seq] Returns: - A mask modifier function that implements document-level masking. + A tensor of shape [batch, seq] where each unique integer indicates + tokens belonging to the same document/sequence. """ - # HuggingFace logic: find where position_ids reset (diff != 1) first_dummy_value = positions[:, :1] - 1 position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) - sequence_indices = (position_diff != 1).cumsum(-1) # [batch, seq] + return (position_diff != 1).cumsum(-1) # [batch, seq] + + +def get_document_mask_mod_from_positions( + positions: torch.Tensor, +) -> _mask_mod_signature: + """Creates a document mask from position_ids for flex attention. + + Detects document boundaries where position_ids reset (diff != 1). + + Args: + positions: Position IDs tensor with shape [batch, seq] + + Returns: + A mask modifier function that implements document-level masking. + """ + sequence_indices = find_packed_sequence_indices(positions) def document_mask( b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor @@ -271,6 +283,37 @@ def document_mask( return document_mask +def create_sdpa_document_causal_mask(positions: torch.Tensor) -> torch.Tensor: + """Creates a 4D document-aware causal mask for SDPA from position_ids. 
+ + Detects document boundaries where position_ids reset (diff != 1) and creates + a combined mask that enforces both causal attention and document isolation. + + Args: + positions: Position IDs tensor with shape [batch, seq] + + Returns: + A boolean tensor of shape [batch, 1, seq, seq] where True means "can attend". + """ + seqlen = positions.shape[1] + device = positions.device + + sequence_indices = find_packed_sequence_indices(positions) + + # Create document-aware mask: tokens can only attend to same document + # Shape: [batch, 1, seq, seq] + doc_mask = sequence_indices.unsqueeze(2) == sequence_indices.unsqueeze(1) + doc_mask = doc_mask.unsqueeze(1) # [batch, 1, seq, seq] + + # Create causal mask + causal_mask = torch.tril( + torch.ones(seqlen, seqlen, device=device, dtype=torch.bool) + ) + + # Combine: document mask AND causal mask + return doc_mask & causal_mask.unsqueeze(0).unsqueeze(0) + + def get_fixed_block_mask_mod(fixed_block_size: int) -> _mask_mod_signature: """ Divide the input sequence into blocks and only allow attention within the same block. diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index f79dbcaea7..6a9ba73c48 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -17,6 +17,7 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, + create_sdpa_document_causal_mask, create_varlen_metadata_for_document, FlexAttentionWrapper, get_causal_mask_mod, @@ -531,6 +532,9 @@ def get_attention_masks( extra_inputs: dict[str, torch.Tensor] | None = None, ) -> AttentionMasksType: match self.model_args.attn_type: + case "sdpa": + assert extra_inputs is not None and "positions" in extra_inputs + return create_sdpa_document_causal_mask(extra_inputs["positions"]) case "flex": return self._get_flex_attention_masks( input_batch, tokenizer, extra_inputs @@ -546,7 +550,7 @@ def get_attention_masks( input_batch, tokenizer.eos_id ) case _: - raise TypeError("Only varlen and flex attn masks are supported") + raise TypeError("Only sdpa, varlen, and flex attn masks are supported") def forward( self, diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py index ffaeae6a77..d0a0556bf1 100644 --- a/torchtitan/models/qwen3/model/args.py +++ b/torchtitan/models/qwen3/model/args.py @@ -36,8 +36,8 @@ class Qwen3ModelArgs(BaseModelArgs): max_seq_len: int = 4096 depth_init: bool = True - attn_type: str = "flex" - attn_mask_type: str = "document_mask" + attn_type: str = "sdpa" + attn_mask_type: str = "causal" eos_id: int = 151645 enable_weight_tying: bool = False diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 0d0078286f..9e497f2e3e 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -9,7 +9,6 @@ from typing import cast import torch -import torch.distributed as dist import torch.nn.functional as F from torch import nn from torch.nn.attention.flex_attention import and_masks, BlockMask @@ -17,6 +16,7 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, + create_sdpa_document_causal_mask, create_varlen_metadata_for_document, FlexAttentionWrapper, get_causal_mask_mod, @@ -254,15 +254,7 @@ def forward( xq = xq.transpose(1, 2) # (bs, n_heads, seqlen, head_dim) xk = xk.transpose(1, 2) # (bs, n_kv_heads, seqlen, head_dim) xv = xv.transpose(1, 
2) # (bs, n_kv_heads, seqlen, head_dim) - # self.enable_gqa = False - # if not self.enable_gqa: - # keys = repeat_kv(xk, self.n_rep) - # values = repeat_kv(xv, self.n_rep) - # xq = xq.transpose(1, 2) - # xk = keys.transpose(1, 2) - # xv = values.transpose(1, 2) - - + match self.attn_type: case "flex": assert isinstance(attention_masks, BlockMask), attention_masks @@ -288,55 +280,28 @@ def forward( scale=self.scaling, ) case "sdpa": - # assert attention_masks is None - # output = ( - # self.inner_attention( - # xq, # (bs, n_heads, seqlen, head_dim) - # xk, # (bs, n_kv_heads, seqlen, head_dim) - # xv, # (bs, n_kv_heads, seqlen, head_dim) - # scale=self.scaling, - # enable_gqa=self.enable_gqa, - # ) - # .transpose(1, 2) - # .contiguous() - # ) # (bs, seqlen, n_local_heads, head_dim) - # Create document-aware causal mask from positions (like HuggingFace) - attn_mask = None - if positions is not None: - # Detect packed sequences from position_ids - # HuggingFace logic: find where position_ids reset (diff != 1) - first_dummy_value = positions[:, :1] - 1 - position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) - packed_sequence_mask = (position_diff != 1).cumsum(-1) # [batch, seq] - - # Create document-aware mask: tokens can only attend to same document - # Shape: [batch, 1, q_len, kv_len] - doc_mask = packed_sequence_mask.unsqueeze(2) == packed_sequence_mask.unsqueeze(1) # [batch, seq, seq] - doc_mask = doc_mask.unsqueeze(1) # [batch, 1, seq, seq] - - # Create causal mask - causal_mask = torch.tril(torch.ones(seqlen, seqlen, device=x.device, dtype=torch.bool)) - - # Combine: document mask AND causal mask - attn_mask = doc_mask & causal_mask.unsqueeze(0).unsqueeze(0) - - # if dist.is_initialized() and dist.get_rank() == 0: - # print(f"DEBUG TITAN SDPA attn_mask: shape={attn_mask.shape}, attn_mask={attn_mask}") - - output = self.inner_attention(xq, xk, xv, scale=self.scaling, enable_gqa=self.enable_gqa, attn_mask=attn_mask) - output = output.transpose( - 1, 2 - ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + if attention_masks is not None: + is_causal = False + else: + is_causal = True + output = ( + self.inner_attention( + xq, # (bs, n_heads, seqlen, head_dim) + xk, # (bs, n_kv_heads, seqlen, head_dim) + xv, # (bs, n_kv_heads, seqlen, head_dim) + scale=self.scaling, + enable_gqa=self.enable_gqa, + is_causal=is_causal, + attn_mask=attention_masks, + ) + .transpose(1, 2) + .contiguous() + ) # (bs, seqlen, n_local_heads, head_dim) case _: raise ValueError(f"Unknown attention type: {self.attn_type}") output = output.view(bs, seqlen, -1) - - if dist.is_initialized() and dist.get_rank() == 0: - print(f"DEBUG TITAN model.py attention output: shape={output.shape}, dtype={output.dtype}, min={output.min().item():.4f}, max={output.max().item():.4f}, mean={output.mean().item():.4f}") output = self.wo(output) - if dist.is_initialized() and dist.get_rank() == 0: - print(f"DEBUG TITAN model.py attention output after wo: shape={output.shape}, dtype={output.dtype}, min={output.min().item():.4f}, max={output.max().item():.4f}, mean={output.mean().item():.4f}") return output @@ -574,8 +539,8 @@ def _get_flex_attention_masks( assert tokenizer.eos_id is not None mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) case "document_mask": - assert extra_inputs is not None and "positions" in extra_inputs - positions = extra_inputs["positions"] + assert extra_inputs is not None and "position_ids" in extra_inputs + positions = extra_inputs["position_ids"] B = 
input_batch.shape[0] mask_mods.append(get_document_mask_mod_from_positions(positions)) case _: @@ -593,6 +558,9 @@ def get_attention_masks( extra_inputs: dict[str, torch.Tensor] | None = None, ) -> AttentionMasksType: match self.model_args.attn_type: + case "sdpa": + assert extra_inputs is not None and "positions" in extra_inputs + return create_sdpa_document_causal_mask(extra_inputs["positions"]) case "flex": return self._get_flex_attention_masks( input_batch, tokenizer, extra_inputs @@ -608,7 +576,7 @@ def get_attention_masks( input_batch, tokenizer.eos_id ) case _: - raise TypeError("Only varlen and flex attn masks are supported") + raise TypeError("Only sdpa, varlen, and flex attn masks are supported") def forward( self, diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index 99e4c34dc0..29dda6e1a4 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -18,7 +18,7 @@ from torchtitan.models.attention import VarlenMetadata -AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata +AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata | torch.Tensor @dataclass diff --git a/torchtitan/train.py b/torchtitan/train.py index b5e8a8e32b..5fadd95b3b 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -463,18 +463,34 @@ def post_dataloading_process( extra_kwargs: dict[str, Any] = {} attn_type = getattr(self.model_args, "attn_type", "sdpa") - if "attention_masks" not in extra_inputs.keys() or extra_inputs["attention_masks"] is None: + if ( + "attention_masks" not in extra_inputs.keys() + or extra_inputs["attention_masks"] is None + ): if attn_type in ["flex", "varlen"]: assert ( self.tokenizer is not None ), "tokenizer is required for flex/varlen attention" - model = cast(ModelProtocol, self.model_parts[0]) extra_inputs.pop("attention_masks") - extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( + extra_kwargs["attention_masks"] = self.model_parts[ + 0 + ].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, extra_inputs=extra_inputs, ) + elif attn_type == "sdpa": + extra_inputs.pop("attention_masks") + if "positions" in extra_inputs.keys(): + extra_kwargs["attention_masks"] = self.model_parts[ + 0 + ].get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, + extra_inputs=extra_inputs, + ) + else: + raise ValueError(f"Unknown attention type: {attn_type}") else: extra_kwargs["attention_masks"] = extra_inputs.pop("attention_masks") @@ -541,7 +557,6 @@ def forward_backward_step( assert len(model_parts) == 1 with self.train_context(): with self.maybe_enable_amp: - print(f"jessica: {extra_inputs=} {extra_kwargs=}") pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) # Compute loss sum (reduction='sum') loss_sum = self.loss_fn(pred, labels) @@ -551,15 +566,79 @@ def forward_backward_step( loss = loss_sum / global_valid_tokens # need to free pred before bwd to avoid peaking memory - # del pred - # loss.backward() + del pred + loss.backward() -<<<<<<< HEAD # The returned loss here is local SUM loss / global_valid_tokens return loss -======= - return loss, pred ->>>>>>> 87b7210b (verl integration changes) + + def forward_step( + self, + *, + input_dict: dict[str, torch.Tensor], + labels: torch.Tensor, + global_valid_tokens: torch.Tensor, + return_outputs: bool = False, + ) -> torch.Tensor: + model_parts = self.model_parts + parallel_dims = self.parallel_dims + + inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( + input_dict, 
labels + ) + + if parallel_dims.pp_enabled: + # Pipeline Parallel forward inside step() call + with self.train_context(): + targets, losses = ( + (labels, []) if self.pp_has_last_stage else (None, None) + ) + if self.pp_has_first_stage: + outputs = self.pp_schedule.eval( + inputs, + **extra_inputs, + **extra_kwargs, + target=targets, + losses=losses, + return_outputs=return_outputs, + ) + else: + outputs = self.pp_schedule.eval( + **extra_kwargs, + target=targets, + losses=losses, + return_outputs=return_outputs, + ) + + pred = outputs if self.pp_has_last_stage else None + # accumulate losses across pipeline microbatches + # TODO: PP+FSDP unexpectedly puts the loss back to the CPU + loss = ( + # Rescale PP loss to be "local loss sum / global valid tokens) + # because each microbathes could have different number of valid tokens + (torch.sum(torch.stack(losses)) / global_valid_tokens).to(self.device) + if self.pp_has_last_stage + else torch.tensor([-1.0], device=self.device) + ) + else: + # Non-PP forward / backward + assert len(model_parts) == 1 + with self.train_context(): + with self.maybe_enable_amp: + pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) + loss_sum = self.loss_fn(pred, labels) + + # Scale the loss by the inverse of the total weight denominator before backward + # This ensures gradients are properly normalized across all microbatches + loss = loss_sum / global_valid_tokens + + # The returned loss here is local SUM loss / global_valid_tokens + if return_outputs: + return loss, pred + else: + # need to free pred before bwd to avoid peaking memory + del pred + return loss, None def train_step( self, data_iterator: Iterator[tuple[dict[str, torch.Tensor], torch.Tensor]] From 82821c7f58b665aca056459a853a266bd74e49a8 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Fri, 6 Feb 2026 10:25:25 -0800 Subject: [PATCH 03/10] pp bug --- torchtitan/train.py | 28 +++------------------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 5fadd95b3b..9081a884e8 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -577,8 +577,6 @@ def forward_step( *, input_dict: dict[str, torch.Tensor], labels: torch.Tensor, - global_valid_tokens: torch.Tensor, - return_outputs: bool = False, ) -> torch.Tensor: model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -600,45 +598,25 @@ def forward_step( **extra_kwargs, target=targets, losses=losses, - return_outputs=return_outputs, + return_outputs=True, ) else: outputs = self.pp_schedule.eval( **extra_kwargs, target=targets, losses=losses, - return_outputs=return_outputs, + return_outputs=True, ) pred = outputs if self.pp_has_last_stage else None - # accumulate losses across pipeline microbatches - # TODO: PP+FSDP unexpectedly puts the loss back to the CPU - loss = ( - # Rescale PP loss to be "local loss sum / global valid tokens) - # because each microbathes could have different number of valid tokens - (torch.sum(torch.stack(losses)) / global_valid_tokens).to(self.device) - if self.pp_has_last_stage - else torch.tensor([-1.0], device=self.device) - ) else: # Non-PP forward / backward assert len(model_parts) == 1 with self.train_context(): with self.maybe_enable_amp: pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) - loss_sum = self.loss_fn(pred, labels) - - # Scale the loss by the inverse of the total weight denominator before backward - # This ensures gradients are properly normalized across all microbatches - loss = loss_sum / global_valid_tokens 
- # The returned loss here is local SUM loss / global_valid_tokens - if return_outputs: - return loss, pred - else: - # need to free pred before bwd to avoid peaking memory - del pred - return loss, None + return pred def train_step( self, data_iterator: Iterator[tuple[dict[str, torch.Tensor], torch.Tensor]] From 0d3f3c69476a2bdc9000bf9ffad534ce1768bedc Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Fri, 6 Feb 2026 16:17:31 -0800 Subject: [PATCH 04/10] formatting --- torchtitan/models/attention.py | 46 +++++++++++++------------- torchtitan/models/qwen3/model/model.py | 7 ++-- torchtitan/train.py | 45 ++++++------------------- 3 files changed, 37 insertions(+), 61 deletions(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index d5672eb1de..e0a85f57d8 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -260,29 +260,6 @@ def find_packed_sequence_indices(positions: torch.Tensor) -> torch.Tensor: return (position_diff != 1).cumsum(-1) # [batch, seq] -def get_document_mask_mod_from_positions( - positions: torch.Tensor, -) -> _mask_mod_signature: - """Creates a document mask from position_ids for flex attention. - - Detects document boundaries where position_ids reset (diff != 1). - - Args: - positions: Position IDs tensor with shape [batch, seq] - - Returns: - A mask modifier function that implements document-level masking. - """ - sequence_indices = find_packed_sequence_indices(positions) - - def document_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ) -> torch.Tensor: - return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx] - - return document_mask - - def create_sdpa_document_causal_mask(positions: torch.Tensor) -> torch.Tensor: """Creates a 4D document-aware causal mask for SDPA from position_ids. @@ -314,6 +291,29 @@ def create_sdpa_document_causal_mask(positions: torch.Tensor) -> torch.Tensor: return doc_mask & causal_mask.unsqueeze(0).unsqueeze(0) +def get_document_mask_mod_from_positions( + positions: torch.Tensor, +) -> _mask_mod_signature: + """Creates a document mask from position_ids for flex attention. + + Detects document boundaries where position_ids reset (diff != 1). + + Args: + positions: Position IDs tensor with shape [batch, seq] + + Returns: + A mask modifier function that implements document-level masking. + """ + sequence_indices = find_packed_sequence_indices(positions) + + def document_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ) -> torch.Tensor: + return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx] + + return document_mask + + def get_fixed_block_mask_mod(fixed_block_size: int) -> _mask_mod_signature: """ Divide the input sequence into blocks and only allow attention within the same block. 
diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 9e497f2e3e..db6e7f3417 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -301,8 +301,7 @@ def forward( raise ValueError(f"Unknown attention type: {self.attn_type}") output = output.view(bs, seqlen, -1) - output = self.wo(output) - return output + return self.wo(output) class FeedForward(nn.Module): @@ -539,8 +538,8 @@ def _get_flex_attention_masks( assert tokenizer.eos_id is not None mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) case "document_mask": - assert extra_inputs is not None and "position_ids" in extra_inputs - positions = extra_inputs["position_ids"] + assert extra_inputs is not None and "positions" in extra_inputs + positions = extra_inputs["positions"] B = input_batch.shape[0] mask_mods.append(get_document_mask_mod_from_positions(positions)) case _: diff --git a/torchtitan/train.py b/torchtitan/train.py index 9081a884e8..ba98c9ef7c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -467,24 +467,20 @@ def post_dataloading_process( "attention_masks" not in extra_inputs.keys() or extra_inputs["attention_masks"] is None ): + model = cast(ModelProtocol, self.model_parts[0]) + extra_inputs.pop("attention_masks") + assert ( + self.tokenizer is not None + ), "tokenizer is required for sdpa/flex/varlen attention" if attn_type in ["flex", "varlen"]: - assert ( - self.tokenizer is not None - ), "tokenizer is required for flex/varlen attention" - extra_inputs.pop("attention_masks") - extra_kwargs["attention_masks"] = self.model_parts[ - 0 - ].get_attention_masks( + extra_kwargs["attention_masks"] = model.get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, extra_inputs=extra_inputs, ) elif attn_type == "sdpa": - extra_inputs.pop("attention_masks") if "positions" in extra_inputs.keys(): - extra_kwargs["attention_masks"] = self.model_parts[ - 0 - ].get_attention_masks( + extra_kwargs["attention_masks"] = model.get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, extra_inputs=extra_inputs, @@ -586,29 +582,10 @@ def forward_step( ) if parallel_dims.pp_enabled: - # Pipeline Parallel forward inside step() call - with self.train_context(): - targets, losses = ( - (labels, []) if self.pp_has_last_stage else (None, None) - ) - if self.pp_has_first_stage: - outputs = self.pp_schedule.eval( - inputs, - **extra_inputs, - **extra_kwargs, - target=targets, - losses=losses, - return_outputs=True, - ) - else: - outputs = self.pp_schedule.eval( - **extra_kwargs, - target=targets, - losses=losses, - return_outputs=True, - ) - - pred = outputs if self.pp_has_last_stage else None + raise NotImplementedError( + "Pipeline parallelism is not yet supported in forward_step. " + "This will be implemented in a follow-up PR." 
+ ) else: # Non-PP forward / backward assert len(model_parts) == 1 From c08d6cdcc1728a8fadb034d0320b704ad5a89eda Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Sat, 7 Feb 2026 23:26:25 -0800 Subject: [PATCH 05/10] address comments --- tests/unit_tests/test_lr_scheduler.py | 21 ++- torchtitan/experiments/vlm/model/siglip2.py | 11 +- torchtitan/models/attention.py | 163 ++++++++----------- torchtitan/models/deepseek_v3/model/model.py | 11 +- torchtitan/models/gpt_oss/model/model.py | 9 +- torchtitan/models/llama3/model/model.py | 49 ++++-- torchtitan/models/llama4/model/model.py | 11 +- torchtitan/models/qwen3/model/model.py | 56 ++++--- torchtitan/train.py | 27 --- 9 files changed, 187 insertions(+), 171 deletions(-) diff --git a/tests/unit_tests/test_lr_scheduler.py b/tests/unit_tests/test_lr_scheduler.py index 00c817a46a..394cebcefe 100644 --- a/tests/unit_tests/test_lr_scheduler.py +++ b/tests/unit_tests/test_lr_scheduler.py @@ -9,7 +9,6 @@ import torch from torch.optim import Adam - from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import OptimizersContainer from torchtitan.config import ConfigManager @@ -88,16 +87,16 @@ def test_linear_warmup_decay(self): # Expected adjustment factors for each step expected_factors = [ - 0.5, # Step 0: 50% of max LR (warmup) - 1.0, # Step 1: 100% of max LR (warmup complete) - 1.0, # Step 2: We maunally added step of stable phase, to prevent LR from dropping to 0 at last step - 7.0 / 8.0, # Step 3: 7/8 of max LR - 6.0 / 8.0, # Step 4: 3/4 of max LR - 5.0 / 8.0, # Step 5: 5/8 of max LR - 4.0 / 8.0, # Step 6: 1/2 of max LR - 3.0 / 8.0, # Step 7: 3/8 of max LR - 2.0 / 8.0, # Step 8: 1/4 of max LR - 1.0 / 8.0, # Step 9: 1/8 of max LR + 0.0, # Step 0: 0% of max LR (warmup start) + 0.5, # Step 1: 50% of max LR (warmup) + 1.0, # Step 2: 100% of max LR (warmup complete) + 1.0, # Step 3: Stable phase (virtual step to prevent LR dropping to 0 at last step) + 7.0 / 8.0, # Step 4: 7/8 of max LR (decay starts) + 6.0 / 8.0, # Step 5: 3/4 of max LR + 5.0 / 8.0, # Step 6: 5/8 of max LR + 4.0 / 8.0, # Step 7: 1/2 of max LR + 3.0 / 8.0, # Step 8: 3/8 of max LR + 2.0 / 8.0, # Step 9: 1/4 of max LR ] # Check the learning rate at each step diff --git a/torchtitan/experiments/vlm/model/siglip2.py b/torchtitan/experiments/vlm/model/siglip2.py index 69278350d1..46e3c05a93 100644 --- a/torchtitan/experiments/vlm/model/siglip2.py +++ b/torchtitan/experiments/vlm/model/siglip2.py @@ -227,7 +227,16 @@ def get_attention_masks( B = 1 case "block_causal": B = pixel_masks.shape[0] - mask_mods.append(get_document_mask_mod(pixel_masks, tokenizer.eos_id)) + mask_mods.append( + get_document_mask_mod( + input_ids=pixel_masks, eos_id=tokenizer.eos_id + ) + ) + case "position_block_causal": + assert extra_inputs is not None and "positions" in extra_inputs + positions = extra_inputs["positions"] + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(positions=positions)) case _: raise ValueError( f"Unknown attention mask type: {self.args.attn_mask_type}" diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index e0a85f57d8..e5aaf174f3 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -31,11 +31,9 @@ "VarlenMetadata", "get_causal_mask_mod", "get_document_mask_mod", - "get_document_mask_mod_from_positions", "get_sliding_window_mask_mod", "get_fixed_block_mask_mod", "create_attention_mask", - "create_sdpa_document_causal_mask", ] @@ -218,93 +216,44 @@ def _causal_mask( 
return _causal_mask -def get_document_mask_mod(batch: torch.Tensor, eos_id: int) -> _mask_mod_signature: - """Creates a document mask that prevents attention across document boundaries. - - Args: - batch: Input batch tensor with shape [b, s, h, d] - eos_id: End-of-sequence token ID that marks document boundaries - - Returns: - A mask modifier function that implements document-level masking. - """ - # batch is [b, s, h, d] shape - eos_mask = batch == eos_id - eos_mask[:, -1] = True - cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1) - sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32) - sequence_indices[:, 1:] = cumulative_mask[:, :-1] - - def document_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ) -> torch.Tensor: - return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx] - - return document_mask - - -def find_packed_sequence_indices(positions: torch.Tensor) -> torch.Tensor: - """Compute sequence/document indices from position_ids. - - Detects document boundaries where position_ids reset (diff != 1). - - Args: - positions: Position IDs tensor with shape [batch, seq] - - Returns: - A tensor of shape [batch, seq] where each unique integer indicates - tokens belonging to the same document/sequence. - """ - first_dummy_value = positions[:, :1] - 1 - position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) - return (position_diff != 1).cumsum(-1) # [batch, seq] - - -def create_sdpa_document_causal_mask(positions: torch.Tensor) -> torch.Tensor: - """Creates a 4D document-aware causal mask for SDPA from position_ids. - - Detects document boundaries where position_ids reset (diff != 1) and creates - a combined mask that enforces both causal attention and document isolation. - - Args: - positions: Position IDs tensor with shape [batch, seq] - - Returns: - A boolean tensor of shape [batch, 1, seq, seq] where True means "can attend". - """ - seqlen = positions.shape[1] - device = positions.device - - sequence_indices = find_packed_sequence_indices(positions) - - # Create document-aware mask: tokens can only attend to same document - # Shape: [batch, 1, seq, seq] - doc_mask = sequence_indices.unsqueeze(2) == sequence_indices.unsqueeze(1) - doc_mask = doc_mask.unsqueeze(1) # [batch, 1, seq, seq] - - # Create causal mask - causal_mask = torch.tril( - torch.ones(seqlen, seqlen, device=device, dtype=torch.bool) - ) - - # Combine: document mask AND causal mask - return doc_mask & causal_mask.unsqueeze(0).unsqueeze(0) - - -def get_document_mask_mod_from_positions( - positions: torch.Tensor, +def get_document_mask_mod( + *, + input_ids: torch.Tensor | None = None, + eos_id: int | None = None, + positions: torch.Tensor | None = None, ) -> _mask_mod_signature: - """Creates a document mask from position_ids for flex attention. + """Creates a document mask that prevents attention across document boundaries. - Detects document boundaries where position_ids reset (diff != 1). + Document boundaries can be detected either from EOS tokens or position ID resets. Args: - positions: Position IDs tensor with shape [batch, seq] + input_ids: Input token IDs with shape [batch, seq]. Required with eos_id. + eos_id: End-of-sequence token ID that marks document boundaries. + positions: Position IDs with shape [batch, seq]. Boundaries detected where + position diff != 1 (i.e., position resets). Returns: A mask modifier function that implements document-level masking. 
+ + Raises: + ValueError: If neither or both separator methods are provided. """ - sequence_indices = find_packed_sequence_indices(positions) + if positions is not None: + # Detect boundaries from position resets + first_dummy_value = positions[:, :1] - 1 + position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) + sequence_indices = (position_diff != 1).cumsum(-1) # [batch, seq] + elif input_ids is not None and eos_id is not None: + # Detect boundaries from EOS tokens + eos_mask = input_ids == eos_id + eos_mask[:, -1] = True + cumulative_mask = torch.cumsum(eos_mask.int(), dim=1) + sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32) + sequence_indices[:, 1:] = cumulative_mask[:, :-1] + else: + raise ValueError( + "Must provide either 'positions' or both 'input_ids' and 'eos_id'" + ) def document_mask( b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor @@ -380,31 +329,61 @@ def create_attention_mask(*args, **kwargs): def create_varlen_metadata_for_document( - input_batch: torch.Tensor, eos_id: int + *, + input_ids: torch.Tensor | None = None, + eos_id: int | None = None, + positions: torch.Tensor | None = None, ) -> VarlenMetadata: """ - Creates cumulative sequence length indices needed for variable length attention + Creates cumulative sequence length indices needed for variable length attention. + + Document boundaries can be detected either from EOS tokens or position ID resets. + Exactly one method must be specified. Args: - input_batch - eos_id: the EOS id marker + input_ids: Input token IDs with shape [batch, seq]. Required with eos_id. + eos_id: End-of-sequence token ID that marks document boundaries. + positions: Position IDs with shape [batch, seq]. Boundaries detected where + position diff != 1 (i.e., position resets). Returns: VarlenMetadata containing cumulative sequence length indices for q, k, and max_seq_len + + Raises: + ValueError: If neither or both separator methods are provided. 
""" - batch_size, seq_len = input_batch.shape - device = input_batch.device + if positions is not None: + batch_size, seq_len = positions.shape + device = positions.device + + # Detect boundaries from position resets (where diff != 1) + first_dummy_value = positions[:, :1] - 1 + position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) + # boundary_mask[b, i] is True if position i starts a new document + boundary_mask = position_diff != 1 # [batch, seq] + elif input_ids is not None and eos_id is not None: + batch_size, seq_len = input_ids.shape + device = input_ids.device + + # Detect boundaries from EOS tokens + eos_mask = input_ids == eos_id + boundary_mask = torch.zeros_like(eos_mask) + boundary_mask[:, 0] = True # First position always starts a document + boundary_mask[:, 1:] = eos_mask[:, :-1] + else: + raise ValueError( + "Must provide either 'positions' or both 'input_ids' and 'eos_id'" + ) + cu_seqlens_list, all_seq_lengths = [], [] offset = 0 - max_seqlen = 0 for b in range(batch_size): - tokens = input_batch[b] - eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32) + # Find positions where new documents start + boundary_positions = boundary_mask[b].nonzero(as_tuple=True)[0].to(torch.int32) sample_cu_seqlens = torch.cat( [ - torch.tensor([0], dtype=torch.int32, device=device), - eos_positions + 1, + boundary_positions, torch.tensor([seq_len], dtype=torch.int32, device=device), ] ) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 498b012605..8af8e5d490 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -465,7 +465,16 @@ def get_attention_masks( case "block_causal": B = input_batch.shape[0] assert tokenizer.eos_id is not None - mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + mask_mods.append( + get_document_mask_mod( + input_ids=input_batch, eos_id=tokenizer.eos_id + ) + ) + case "position_block_causal": + assert extra_inputs is not None and "positions" in extra_inputs + positions = extra_inputs["positions"] + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(positions=positions)) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" diff --git a/torchtitan/models/gpt_oss/model/model.py b/torchtitan/models/gpt_oss/model/model.py index 92b9c42fe6..e3f47775d1 100644 --- a/torchtitan/models/gpt_oss/model/model.py +++ b/torchtitan/models/gpt_oss/model/model.py @@ -338,8 +338,15 @@ def get_attention_masks( B = input_batch.shape[0] assert tokenizer.eos_id is not None basic_mask_mods.append( - get_document_mask_mod(input_batch, tokenizer.eos_id) + get_document_mask_mod( + input_ids=input_batch, eos_id=tokenizer.eos_id + ) ) + case "position_block_causal": + assert extra_inputs is not None and "positions" in extra_inputs + positions = extra_inputs["positions"] + B = input_batch.shape[0] + basic_mask_mods.append(get_document_mask_mod(positions=positions)) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 6a9ba73c48..6898efffc9 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -17,7 +17,6 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, - create_sdpa_document_causal_mask, 
create_varlen_metadata_for_document, FlexAttentionWrapper, get_causal_mask_mod, @@ -515,7 +514,16 @@ def _get_flex_attention_masks( case "block_causal": B = input_batch.shape[0] assert tokenizer.eos_id is not None - mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + mask_mods.append( + get_document_mask_mod( + input_ids=input_batch, eos_id=tokenizer.eos_id + ) + ) + case "position_block_causal": + assert extra_inputs is not None and "positions" in extra_inputs + positions = extra_inputs["positions"] + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(positions=positions)) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" @@ -525,6 +533,28 @@ def _get_flex_attention_masks( and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] ) + def _get_varlen_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + match self.model_args.attn_mask_type: + case "block_causal": + assert tokenizer.eos_id is not None + return create_varlen_metadata_for_document( + input_ids=input_batch, eos_id=tokenizer.eos_id + ) + case "position_block_causal": + assert extra_inputs is not None and "positions" in extra_inputs + positions = extra_inputs["positions"] + return create_varlen_metadata_for_document(positions=positions) + case _: + raise ValueError( + f"varlen attention is only supported with block_causal or " + f"position_block_causal attention mask type, got {self.model_args.attn_mask_type}" + ) + def get_attention_masks( self, input_batch: torch.Tensor, @@ -532,25 +562,16 @@ def get_attention_masks( extra_inputs: dict[str, torch.Tensor] | None = None, ) -> AttentionMasksType: match self.model_args.attn_type: - case "sdpa": - assert extra_inputs is not None and "positions" in extra_inputs - return create_sdpa_document_causal_mask(extra_inputs["positions"]) case "flex": return self._get_flex_attention_masks( input_batch, tokenizer, extra_inputs ) case "varlen": - if self.model_args.attn_mask_type != "block_causal": - raise ValueError( - f"varlen attention is only supported with block_causal \ - attention mask type, got {self.model_args.attn_mask_type}" - ) - assert tokenizer.eos_id is not None - return create_varlen_metadata_for_document( - input_batch, tokenizer.eos_id + return self._get_varlen_attention_masks( + input_batch, tokenizer, extra_inputs ) case _: - raise TypeError("Only sdpa, varlen, and flex attn masks are supported") + raise TypeError("Only varlen and flex attn masks are supported") def forward( self, diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index 342f5b5f18..af856fe20f 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -549,8 +549,17 @@ def get_attention_masks( B = 1 case "block_causal": assert tokenizer.eos_id is not None - mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + mask_mods.append( + get_document_mask_mod( + input_ids=input_batch, eos_id=tokenizer.eos_id + ) + ) + B = input_batch.shape[0] + case "position_block_causal": + assert extra_inputs is not None and "positions" in extra_inputs + positions = extra_inputs["positions"] B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(positions=positions)) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" diff --git a/torchtitan/models/qwen3/model/model.py 
b/torchtitan/models/qwen3/model/model.py index db6e7f3417..ed84e44d89 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -16,12 +16,10 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, - create_sdpa_document_causal_mask, create_varlen_metadata_for_document, FlexAttentionWrapper, get_causal_mask_mod, get_document_mask_mod, - get_document_mask_mod_from_positions, ScaledDotProductAttentionWrapper, VarlenAttentionWrapper, VarlenMetadata, @@ -280,10 +278,7 @@ def forward( scale=self.scaling, ) case "sdpa": - if attention_masks is not None: - is_causal = False - else: - is_causal = True + assert attention_masks is None output = ( self.inner_attention( xq, # (bs, n_heads, seqlen, head_dim) @@ -291,8 +286,6 @@ def forward( xv, # (bs, n_kv_heads, seqlen, head_dim) scale=self.scaling, enable_gqa=self.enable_gqa, - is_causal=is_causal, - attn_mask=attention_masks, ) .transpose(1, 2) .contiguous() @@ -536,12 +529,16 @@ def _get_flex_attention_masks( case "block_causal": B = input_batch.shape[0] assert tokenizer.eos_id is not None - mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) - case "document_mask": + mask_mods.append( + get_document_mask_mod( + input_ids=input_batch, eos_id=tokenizer.eos_id + ) + ) + case "position_block_causal": assert extra_inputs is not None and "positions" in extra_inputs positions = extra_inputs["positions"] B = input_batch.shape[0] - mask_mods.append(get_document_mask_mod_from_positions(positions)) + mask_mods.append(get_document_mask_mod(positions=positions)) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" @@ -550,6 +547,28 @@ def _get_flex_attention_masks( and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] ) + def _get_varlen_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + match self.model_args.attn_mask_type: + case "block_causal": + assert tokenizer.eos_id is not None + return create_varlen_metadata_for_document( + input_ids=input_batch, eos_id=tokenizer.eos_id + ) + case "position_block_causal": + assert extra_inputs is not None and "positions" in extra_inputs + positions = extra_inputs["positions"] + return create_varlen_metadata_for_document(positions=positions) + case _: + raise ValueError( + f"varlen attention is only supported with block_causal or " + f"position_block_causal attention mask type, got {self.model_args.attn_mask_type}" + ) + def get_attention_masks( self, input_batch: torch.Tensor, @@ -557,25 +576,16 @@ def get_attention_masks( extra_inputs: dict[str, torch.Tensor] | None = None, ) -> AttentionMasksType: match self.model_args.attn_type: - case "sdpa": - assert extra_inputs is not None and "positions" in extra_inputs - return create_sdpa_document_causal_mask(extra_inputs["positions"]) case "flex": return self._get_flex_attention_masks( input_batch, tokenizer, extra_inputs ) case "varlen": - if self.model_args.attn_mask_type != "block_causal": - raise ValueError( - f"varlen attention is only supported with block_causal \ - attention mask type, got {self.model_args.attn_mask_type}" - ) - assert tokenizer.eos_id is not None - return create_varlen_metadata_for_document( - input_batch, tokenizer.eos_id + return self._get_varlen_attention_masks( + input_batch, tokenizer, extra_inputs ) case _: - raise TypeError("Only sdpa, varlen, 
and flex attn masks are supported") + raise TypeError("Only varlen and flex attn masks are supported") def forward( self, diff --git a/torchtitan/train.py b/torchtitan/train.py index ba98c9ef7c..e4f564735c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -568,33 +568,6 @@ def forward_backward_step( # The returned loss here is local SUM loss / global_valid_tokens return loss - def forward_step( - self, - *, - input_dict: dict[str, torch.Tensor], - labels: torch.Tensor, - ) -> torch.Tensor: - model_parts = self.model_parts - parallel_dims = self.parallel_dims - - inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( - input_dict, labels - ) - - if parallel_dims.pp_enabled: - raise NotImplementedError( - "Pipeline parallelism is not yet supported in forward_step. " - "This will be implemented in a follow-up PR." - ) - else: - # Non-PP forward / backward - assert len(model_parts) == 1 - with self.train_context(): - with self.maybe_enable_amp: - pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) - - return pred - def train_step( self, data_iterator: Iterator[tuple[dict[str, torch.Tensor], torch.Tensor]] ): From 8c8db22a958993780f0d80578b6393fb840bfa32 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Sat, 7 Feb 2026 23:30:39 -0800 Subject: [PATCH 06/10] cleanup --- torchtitan/models/attention.py | 2 -- torchtitan/protocols/model.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index e5aaf174f3..1c1fc65a3d 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -186,14 +186,12 @@ def forward( scale: float | None = None, enable_gqa: bool = False, is_causal: bool = True, - attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: with sdpa_kernel(self.sdpa_backends, set_priority=True): return F.scaled_dot_product_attention( q, k, v, - attn_mask=attn_mask, scale=scale, is_causal=is_causal, enable_gqa=enable_gqa, diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index 29dda6e1a4..99e4c34dc0 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -18,7 +18,7 @@ from torchtitan.models.attention import VarlenMetadata -AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata | torch.Tensor +AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata @dataclass From 047712b36b6ae789d39c88f5f7fe36125b2da4e8 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Sat, 7 Feb 2026 23:35:22 -0800 Subject: [PATCH 07/10] cleanup --- torchtitan/models/attention.py | 4 ++-- torchtitan/train.py | 9 --------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 1c1fc65a3d..4530348328 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -84,7 +84,7 @@ def forward( cu_seq_k, max_q, max_k, - scale=scale, + # scale=scale, # window_size=(left, right) controls the attention window relative to each # query position. 'left' is how many tokens before the query to attend to, # and 'right' is how many tokens after. A value of -1 means unlimited. @@ -95,7 +95,7 @@ def forward( # - (-1, -1): Full bidirectional attention (no masking). Equivalent to # is_causal=False. # - (W, 0): Sliding window causal - attend to at most W previous tokens. 
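The (left, right) window convention documented in the comment above can be made concrete with a small standalone sketch. This is not the varlen kernel API; window_mask is a hypothetical helper that only materializes the dense boolean mask implied by a given window_size:

    import torch

    def window_mask(seq_len: int, left: int, right: int) -> torch.Tensor:
        # Dense boolean mask equivalent to window_size=(left, right);
        # a value of -1 means "unlimited" in that direction.
        q = torch.arange(seq_len).unsqueeze(-1)   # query positions
        kv = torch.arange(seq_len).unsqueeze(0)   # key/value positions
        allow_left = (q - kv) <= left if left != -1 else torch.ones(seq_len, seq_len, dtype=torch.bool)
        allow_right = (kv - q) <= right if right != -1 else torch.ones(seq_len, seq_len, dtype=torch.bool)
        return allow_left & allow_right

    print(window_mask(4, -1, 0))   # causal: lower triangular
    print(window_mask(4, 1, 0))    # sliding-window causal, width 2
    print(window_mask(4, -1, -1))  # full bidirectional attention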
- window_size=(-1, 0), + # window_size=(-1, 0), ) diff --git a/torchtitan/train.py b/torchtitan/train.py index e4f564735c..a5d095ef61 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -478,15 +478,6 @@ def post_dataloading_process( tokenizer=self.tokenizer, extra_inputs=extra_inputs, ) - elif attn_type == "sdpa": - if "positions" in extra_inputs.keys(): - extra_kwargs["attention_masks"] = model.get_attention_masks( - input_batch=inputs, - tokenizer=self.tokenizer, - extra_inputs=extra_inputs, - ) - else: - raise ValueError(f"Unknown attention type: {attn_type}") else: extra_kwargs["attention_masks"] = extra_inputs.pop("attention_masks") From e28ce318a874f6c29767c370d1064f564cfd6d78 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Sun, 8 Feb 2026 12:26:34 -0800 Subject: [PATCH 08/10] add more tests --- tests/unit_tests/test_varlen_metadata.py | 144 +++++++++++++++++++++++ torchtitan/models/attention.py | 5 +- 2 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 tests/unit_tests/test_varlen_metadata.py diff --git a/tests/unit_tests/test_varlen_metadata.py b/tests/unit_tests/test_varlen_metadata.py new file mode 100644 index 0000000000..f6f1811ccf --- /dev/null +++ b/tests/unit_tests/test_varlen_metadata.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for create_varlen_metadata_for_document function. + +Verifies that both position-based and EOS-based boundary detection +produce correct cumulative sequence lengths. +""" + +import unittest + +import torch + +from torchtitan.models.attention import create_varlen_metadata_for_document, VarlenMetadata + + +class TestCreateVarlenMetadataForDocument(unittest.TestCase): + """Test create_varlen_metadata_for_document with positions and input_ids+eos_id.""" + + def test_single_document_with_positions(self): + """Single document per batch item - positions are sequential.""" + # positions: [0, 1, 2, 3, 4] - single document + positions = torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.int64) + + metadata = create_varlen_metadata_for_document(positions=positions) + + # Expected: one document of length 5 + # cu_seqlens: [0, 5] + self.assertIsInstance(metadata, VarlenMetadata) + expected_cu_seqlens = torch.tensor([0, 5], dtype=torch.int32) + self.assertTrue( + torch.equal(metadata.cu_seq_q, expected_cu_seqlens), + f"Expected {expected_cu_seqlens}, got {metadata.cu_seq_q}" + ) + self.assertTrue(torch.equal(metadata.cu_seq_k, expected_cu_seqlens)) + self.assertEqual(metadata.max_q, 5) + self.assertEqual(metadata.max_k, 5) + + def test_two_documents_with_positions(self): + """Two documents detected via position reset.""" + # positions: [0, 1, 2, 0, 1] - doc1 has 3 tokens, doc2 has 2 tokens + positions = torch.tensor([[0, 1, 2, 0, 1]], dtype=torch.int64) + + metadata = create_varlen_metadata_for_document(positions=positions) + + # Expected: two documents, lengths 3 and 2 + # cu_seqlens: [0, 3, 5] + expected_cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32) + assert torch.equal(metadata.cu_seq_q, expected_cu_seqlens) + assert torch.equal(metadata.cu_seq_k, expected_cu_seqlens) + assert metadata.max_q == 3 + assert metadata.max_k == 3 + + def test_single_document_with_eos(self): + """Single document with EOS at the end.""" + # input_ids: [1, 2, 3, 4, EOS] where EOS=0 + input_ids = torch.tensor([[1, 2, 3, 4, 0]], dtype=torch.int64) + 
eos_id = 0 + + metadata = create_varlen_metadata_for_document(input_ids=input_ids, eos_id=eos_id) + + # Expected: one document of length 5 + # cu_seqlens: [0, 5] + expected_cu_seqlens = torch.tensor([0, 5], dtype=torch.int32) + assert torch.equal(metadata.cu_seq_q, expected_cu_seqlens) + assert metadata.max_q == 5 + + def test_two_documents_with_eos(self): + """Two documents separated by EOS token.""" + # input_ids: [1, 2, EOS, 3, 4, EOS] where EOS=0 + # doc1: [1, 2, EOS], doc2: [3, 4, EOS] + input_ids = torch.tensor([[1, 2, 0, 3, 4, 0]], dtype=torch.int64) + eos_id = 0 + + metadata = create_varlen_metadata_for_document(input_ids=input_ids, eos_id=eos_id) + + # Expected: two documents, lengths 3 and 3 + # cu_seqlens: [0, 3, 6] + expected_cu_seqlens = torch.tensor([0, 3, 6], dtype=torch.int32) + assert torch.equal(metadata.cu_seq_q, expected_cu_seqlens) + assert metadata.max_q == 3 + + def test_batch_size_two_with_positions(self): + """Batch of 2 samples with different document structures.""" + # Sample 0: [0, 1, 2, 3, 4] - one document of length 5 + # Sample 1: [0, 1, 0, 1, 2] - two documents of lengths 2 and 3 + positions = torch.tensor([ + [0, 1, 2, 3, 4], + [0, 1, 0, 1, 2], + ], dtype=torch.int64) + + metadata = create_varlen_metadata_for_document(positions=positions) + + # Expected cu_seqlens for packed format: + # Sample 0 contributes: starts at 0, length 5 + # Sample 1 contributes: starts at 5 (offset), doc1 starts at 5, doc2 starts at 7 + expected_cu_seqlens = torch.tensor([0, 5, 7, 10], dtype=torch.int32) + assert torch.equal(metadata.cu_seq_q, expected_cu_seqlens), \ + f"Expected {expected_cu_seqlens}, got {metadata.cu_seq_q}" + assert metadata.max_q == 5 # max of [5, 2, 3] + + def test_batch_size_two_with_eos(self): + """Batch of 2 samples with different document structures using EOS.""" + # Sample 0: [1, 2, 3, 4, EOS] - one document + # Sample 1: [1, EOS, 2, 3, EOS] - two documents + input_ids = torch.tensor([ + [1, 2, 3, 4, 0], + [1, 0, 2, 3, 0], + ], dtype=torch.int64) + eos_id = 0 + + metadata = create_varlen_metadata_for_document(input_ids=input_ids, eos_id=eos_id) + + # Sample 0: one doc of length 5, cu_seqlens contribution: [0] + # Sample 1: doc1 length 2, doc2 length 3, cu_seqlens contribution: [5, 7] + expected_cu_seqlens = torch.tensor([0, 5, 7, 10], dtype=torch.int32) + assert torch.equal(metadata.cu_seq_q, expected_cu_seqlens), \ + f"Expected {expected_cu_seqlens}, got {metadata.cu_seq_q}" + + def test_positions_and_eos_produce_same_result(self): + """Verify that positions and EOS-based detection produce equivalent results.""" + # Create equivalent inputs + # Two documents: lengths 3 and 2 + # Positions: [0, 1, 2, 0, 1] + # EOS-based: [A, B, EOS, C, D] where EOS marks end of doc1 + + positions = torch.tensor([[0, 1, 2, 0, 1]], dtype=torch.int64) + input_ids = torch.tensor([[1, 2, 0, 3, 4]], dtype=torch.int64) # 0 is EOS + eos_id = 0 + + metadata_positions = create_varlen_metadata_for_document(positions=positions) + metadata_eos = create_varlen_metadata_for_document(input_ids=input_ids, eos_id=eos_id) + + assert torch.equal(metadata_positions.cu_seq_q, metadata_eos.cu_seq_q), \ + f"Position-based: {metadata_positions.cu_seq_q}, EOS-based: {metadata_eos.cu_seq_q}" + assert metadata_positions.max_q == metadata_eos.max_q + + +if __name__ == "__main__": + unittest.main() diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 4530348328..3d7a3e0ab3 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -84,7 +84,7 @@ 
def forward( cu_seq_k, max_q, max_k, - # scale=scale, + scale=scale, # window_size=(left, right) controls the attention window relative to each # query position. 'left' is how many tokens before the query to attend to, # and 'right' is how many tokens after. A value of -1 means unlimited. @@ -95,7 +95,7 @@ def forward( # - (-1, -1): Full bidirectional attention (no masking). Equivalent to # is_causal=False. # - (W, 0): Sliding window causal - attend to at most W previous tokens. - # window_size=(-1, 0), + window_size=(-1, 0), ) @@ -359,6 +359,7 @@ def create_varlen_metadata_for_document( position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) # boundary_mask[b, i] is True if position i starts a new document boundary_mask = position_diff != 1 # [batch, seq] + boundary_mask[:, 0] = True elif input_ids is not None and eos_id is not None: batch_size, seq_len = input_ids.shape device = input_ids.device From e3a9eb58f372fe766ab6ace9d548a353a466a7a4 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Mon, 9 Feb 2026 16:21:21 -0800 Subject: [PATCH 09/10] address comments --- tests/unit_tests/test_lr_scheduler.py | 21 +-- tests/unit_tests/test_varlen_metadata.py | 144 ------------------- torchtitan/components/lr_scheduler.py | 7 +- torchtitan/experiments/vlm/model/siglip2.py | 11 +- torchtitan/models/attention.py | 92 +++--------- torchtitan/models/deepseek_v3/model/model.py | 11 +- torchtitan/models/gpt_oss/model/model.py | 9 +- torchtitan/models/llama3/model/model.py | 43 ++---- torchtitan/models/llama4/model/model.py | 11 +- torchtitan/models/qwen3/model/model.py | 43 ++---- 10 files changed, 59 insertions(+), 333 deletions(-) delete mode 100644 tests/unit_tests/test_varlen_metadata.py diff --git a/tests/unit_tests/test_lr_scheduler.py b/tests/unit_tests/test_lr_scheduler.py index 394cebcefe..00c817a46a 100644 --- a/tests/unit_tests/test_lr_scheduler.py +++ b/tests/unit_tests/test_lr_scheduler.py @@ -9,6 +9,7 @@ import torch from torch.optim import Adam + from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import OptimizersContainer from torchtitan.config import ConfigManager @@ -87,16 +88,16 @@ def test_linear_warmup_decay(self): # Expected adjustment factors for each step expected_factors = [ - 0.0, # Step 0: 0% of max LR (warmup start) - 0.5, # Step 1: 50% of max LR (warmup) - 1.0, # Step 2: 100% of max LR (warmup complete) - 1.0, # Step 3: Stable phase (virtual step to prevent LR dropping to 0 at last step) - 7.0 / 8.0, # Step 4: 7/8 of max LR (decay starts) - 6.0 / 8.0, # Step 5: 3/4 of max LR - 5.0 / 8.0, # Step 6: 5/8 of max LR - 4.0 / 8.0, # Step 7: 1/2 of max LR - 3.0 / 8.0, # Step 8: 3/8 of max LR - 2.0 / 8.0, # Step 9: 1/4 of max LR + 0.5, # Step 0: 50% of max LR (warmup) + 1.0, # Step 1: 100% of max LR (warmup complete) + 1.0, # Step 2: We manually add a stable-phase step to prevent LR from dropping to 0 at the last step + 7.0 / 8.0, # Step 3: 7/8 of max LR + 6.0 / 8.0, # Step 4: 3/4 of max LR + 5.0 / 8.0, # Step 5: 5/8 of max LR + 4.0 / 8.0, # Step 6: 1/2 of max LR + 3.0 / 8.0, # Step 7: 3/8 of max LR + 2.0 / 8.0, # Step 8: 1/4 of max LR + 1.0 / 8.0, # Step 9: 1/8 of max LR ] # Check the learning rate at each step diff --git a/tests/unit_tests/test_varlen_metadata.py b/tests/unit_tests/test_varlen_metadata.py deleted file mode 100644 index f6f1811ccf..0000000000 --- a/tests/unit_tests/test_varlen_metadata.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
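The expected_factors above follow directly from the warmup-stable-decay adjustment that this patch restores in lr_scheduler.py below. A minimal sketch, assuming warmup_steps=2, stable_steps=1, decay_steps=8 and linear decay (hypothetical values chosen only to reproduce the ten factors in this test; the test's real config is not shown here), computes the same sequence:

    def wsd_adjustment(step: int, warmup_steps: int = 2, stable_steps: int = 1, decay_steps: int = 8) -> float:
        # Mirrors the restored linear_warmup_stable_decay logic with linear decay.
        warmup_stable_steps = warmup_steps + stable_steps
        if step < warmup_steps:
            # 0-indexed step, hence the + 1 adjustment
            return float(step + 1) / warmup_steps
        if step < warmup_stable_steps:
            return 1.0
        progress = float(step + 1 - warmup_stable_steps) / decay_steps
        return 1.0 - progress

    print([wsd_adjustment(s) for s in range(10)])
    # [0.5, 1.0, 1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125]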
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Tests for create_varlen_metadata_for_document function. - -Verifies that both position-based and EOS-based boundary detection -produce correct cumulative sequence lengths. -""" - -import unittest - -import torch - -from torchtitan.models.attention import create_varlen_metadata_for_document, VarlenMetadata - - -class TestCreateVarlenMetadataForDocument(unittest.TestCase): - """Test create_varlen_metadata_for_document with positions and input_ids+eos_id.""" - - def test_single_document_with_positions(self): - """Single document per batch item - positions are sequential.""" - # positions: [0, 1, 2, 3, 4] - single document - positions = torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.int64) - - metadata = create_varlen_metadata_for_document(positions=positions) - - # Expected: one document of length 5 - # cu_seqlens: [0, 5] - self.assertIsInstance(metadata, VarlenMetadata) - expected_cu_seqlens = torch.tensor([0, 5], dtype=torch.int32) - self.assertTrue( - torch.equal(metadata.cu_seq_q, expected_cu_seqlens), - f"Expected {expected_cu_seqlens}, got {metadata.cu_seq_q}" - ) - self.assertTrue(torch.equal(metadata.cu_seq_k, expected_cu_seqlens)) - self.assertEqual(metadata.max_q, 5) - self.assertEqual(metadata.max_k, 5) - - def test_two_documents_with_positions(self): - """Two documents detected via position reset.""" - # positions: [0, 1, 2, 0, 1] - doc1 has 3 tokens, doc2 has 2 tokens - positions = torch.tensor([[0, 1, 2, 0, 1]], dtype=torch.int64) - - metadata = create_varlen_metadata_for_document(positions=positions) - - # Expected: two documents, lengths 3 and 2 - # cu_seqlens: [0, 3, 5] - expected_cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32) - assert torch.equal(metadata.cu_seq_q, expected_cu_seqlens) - assert torch.equal(metadata.cu_seq_k, expected_cu_seqlens) - assert metadata.max_q == 3 - assert metadata.max_k == 3 - - def test_single_document_with_eos(self): - """Single document with EOS at the end.""" - # input_ids: [1, 2, 3, 4, EOS] where EOS=0 - input_ids = torch.tensor([[1, 2, 3, 4, 0]], dtype=torch.int64) - eos_id = 0 - - metadata = create_varlen_metadata_for_document(input_ids=input_ids, eos_id=eos_id) - - # Expected: one document of length 5 - # cu_seqlens: [0, 5] - expected_cu_seqlens = torch.tensor([0, 5], dtype=torch.int32) - assert torch.equal(metadata.cu_seq_q, expected_cu_seqlens) - assert metadata.max_q == 5 - - def test_two_documents_with_eos(self): - """Two documents separated by EOS token.""" - # input_ids: [1, 2, EOS, 3, 4, EOS] where EOS=0 - # doc1: [1, 2, EOS], doc2: [3, 4, EOS] - input_ids = torch.tensor([[1, 2, 0, 3, 4, 0]], dtype=torch.int64) - eos_id = 0 - - metadata = create_varlen_metadata_for_document(input_ids=input_ids, eos_id=eos_id) - - # Expected: two documents, lengths 3 and 3 - # cu_seqlens: [0, 3, 6] - expected_cu_seqlens = torch.tensor([0, 3, 6], dtype=torch.int32) - assert torch.equal(metadata.cu_seq_q, expected_cu_seqlens) - assert metadata.max_q == 3 - - def test_batch_size_two_with_positions(self): - """Batch of 2 samples with different document structures.""" - # Sample 0: [0, 1, 2, 3, 4] - one document of length 5 - # Sample 1: [0, 1, 0, 1, 2] - two documents of lengths 2 and 3 - positions = torch.tensor([ - [0, 1, 2, 3, 4], - [0, 1, 0, 1, 2], - ], dtype=torch.int64) - - metadata = create_varlen_metadata_for_document(positions=positions) - - # Expected cu_seqlens for packed format: - # Sample 0 
contributes: starts at 0, length 5 - # Sample 1 contributes: starts at 5 (offset), doc1 starts at 5, doc2 starts at 7 - expected_cu_seqlens = torch.tensor([0, 5, 7, 10], dtype=torch.int32) - assert torch.equal(metadata.cu_seq_q, expected_cu_seqlens), \ - f"Expected {expected_cu_seqlens}, got {metadata.cu_seq_q}" - assert metadata.max_q == 5 # max of [5, 2, 3] - - def test_batch_size_two_with_eos(self): - """Batch of 2 samples with different document structures using EOS.""" - # Sample 0: [1, 2, 3, 4, EOS] - one document - # Sample 1: [1, EOS, 2, 3, EOS] - two documents - input_ids = torch.tensor([ - [1, 2, 3, 4, 0], - [1, 0, 2, 3, 0], - ], dtype=torch.int64) - eos_id = 0 - - metadata = create_varlen_metadata_for_document(input_ids=input_ids, eos_id=eos_id) - - # Sample 0: one doc of length 5, cu_seqlens contribution: [0] - # Sample 1: doc1 length 2, doc2 length 3, cu_seqlens contribution: [5, 7] - expected_cu_seqlens = torch.tensor([0, 5, 7, 10], dtype=torch.int32) - assert torch.equal(metadata.cu_seq_q, expected_cu_seqlens), \ - f"Expected {expected_cu_seqlens}, got {metadata.cu_seq_q}" - - def test_positions_and_eos_produce_same_result(self): - """Verify that positions and EOS-based detection produce equivalent results.""" - # Create equivalent inputs - # Two documents: lengths 3 and 2 - # Positions: [0, 1, 2, 0, 1] - # EOS-based: [A, B, EOS, C, D] where EOS marks end of doc1 - - positions = torch.tensor([[0, 1, 2, 0, 1]], dtype=torch.int64) - input_ids = torch.tensor([[1, 2, 0, 3, 4]], dtype=torch.int64) # 0 is EOS - eos_id = 0 - - metadata_positions = create_varlen_metadata_for_document(positions=positions) - metadata_eos = create_varlen_metadata_for_document(input_ids=input_ids, eos_id=eos_id) - - assert torch.equal(metadata_positions.cu_seq_q, metadata_eos.cu_seq_q), \ - f"Position-based: {metadata_positions.cu_seq_q}, EOS-based: {metadata_eos.cu_seq_q}" - assert metadata_positions.max_q == metadata_eos.max_q - - -if __name__ == "__main__": - unittest.main() diff --git a/torchtitan/components/lr_scheduler.py b/torchtitan/components/lr_scheduler.py index 17289df6be..15a3fc6bd1 100644 --- a/torchtitan/components/lr_scheduler.py +++ b/torchtitan/components/lr_scheduler.py @@ -155,7 +155,9 @@ def linear_warmup_stable_decay( """ warmup_stable_steps = warmup_steps + stable_steps if current_step < warmup_steps: - # linear warmup (0-indexed to match FSDP/HuggingFace) + # linear warmup + # 0-indexed step, hence + 1 adjustments + current_step += 1 assert ( warmup_steps != 0 ), "warmup_steps must not be zero to reach this branch" @@ -163,7 +165,8 @@ def linear_warmup_stable_decay( elif current_step < warmup_stable_steps: curr_adjustment = 1.0 else: - # Decay phase (0-indexed to match FSDP/HuggingFace) + # 0-indexed step, hence + 1 adjustments + current_step += 1 assert decay_steps != 0, "decay_steps must not be zero to reach this branch" progress = float(current_step - warmup_stable_steps) / decay_steps diff --git a/torchtitan/experiments/vlm/model/siglip2.py b/torchtitan/experiments/vlm/model/siglip2.py index 46e3c05a93..69278350d1 100644 --- a/torchtitan/experiments/vlm/model/siglip2.py +++ b/torchtitan/experiments/vlm/model/siglip2.py @@ -227,16 +227,7 @@ def get_attention_masks( B = 1 case "block_causal": B = pixel_masks.shape[0] - mask_mods.append( - get_document_mask_mod( - input_ids=pixel_masks, eos_id=tokenizer.eos_id - ) - ) - case "position_block_causal": - assert extra_inputs is not None and "positions" in extra_inputs - positions = extra_inputs["positions"] - B = 
input_batch.shape[0] - mask_mods.append(get_document_mask_mod(positions=positions)) + mask_mods.append(get_document_mask_mod(pixel_masks, tokenizer.eos_id)) case _: raise ValueError( f"Unknown attention mask type: {self.args.attn_mask_type}" diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 3d7a3e0ab3..39b1826c8b 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -214,44 +214,25 @@ def _causal_mask( return _causal_mask -def get_document_mask_mod( - *, - input_ids: torch.Tensor | None = None, - eos_id: int | None = None, - positions: torch.Tensor | None = None, -) -> _mask_mod_signature: +def get_document_mask_mod(batch: torch.Tensor, eos_id: int) -> _mask_mod_signature: """Creates a document mask that prevents attention across document boundaries. Document boundaries can be detected either from EOS tokens or position ID resets. Args: - input_ids: Input token IDs with shape [batch, seq]. Required with eos_id. - eos_id: End-of-sequence token ID that marks document boundaries. - positions: Position IDs with shape [batch, seq]. Boundaries detected where - position diff != 1 (i.e., position resets). + batch: Input batch tensor with shape [b, s, h, d] + eos_id: End-of-sequence token ID that marks document boundaries Returns: A mask modifier function that implements document-level masking. - - Raises: - ValueError: If neither or both separator methods are provided. """ - if positions is not None: - # Detect boundaries from position resets - first_dummy_value = positions[:, :1] - 1 - position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) - sequence_indices = (position_diff != 1).cumsum(-1) # [batch, seq] - elif input_ids is not None and eos_id is not None: - # Detect boundaries from EOS tokens - eos_mask = input_ids == eos_id - eos_mask[:, -1] = True - cumulative_mask = torch.cumsum(eos_mask.int(), dim=1) - sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32) - sequence_indices[:, 1:] = cumulative_mask[:, :-1] - else: - raise ValueError( - "Must provide either 'positions' or both 'input_ids' and 'eos_id'" - ) + # batch is [b, s, h, d] shape + # batch is [b, s, h, d] shape + eos_mask = batch == eos_id + eos_mask[:, -1] = True + cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1) + sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32) + sequence_indices[:, 1:] = cumulative_mask[:, :-1] def document_mask( b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor @@ -327,62 +308,31 @@ def create_attention_mask(*args, **kwargs): def create_varlen_metadata_for_document( - *, - input_ids: torch.Tensor | None = None, - eos_id: int | None = None, - positions: torch.Tensor | None = None, + input_batch: torch.Tensor, eos_id: int ) -> VarlenMetadata: """ - Creates cumulative sequence length indices needed for variable length attention. - - Document boundaries can be detected either from EOS tokens or position ID resets. - Exactly one method must be specified. + Creates cumulative sequence length indices needed for variable length attention Args: - input_ids: Input token IDs with shape [batch, seq]. Required with eos_id. - eos_id: End-of-sequence token ID that marks document boundaries. - positions: Position IDs with shape [batch, seq]. Boundaries detected where - position diff != 1 (i.e., position resets). 
+ input_batch + eos_id: the EOS id marker Returns: VarlenMetadata containing cumulative sequence length indices for q, k, and max_seq_len - - Raises: - ValueError: If neither or both separator methods are provided. """ - if positions is not None: - batch_size, seq_len = positions.shape - device = positions.device - - # Detect boundaries from position resets (where diff != 1) - first_dummy_value = positions[:, :1] - 1 - position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) - # boundary_mask[b, i] is True if position i starts a new document - boundary_mask = position_diff != 1 # [batch, seq] - boundary_mask[:, 0] = True - elif input_ids is not None and eos_id is not None: - batch_size, seq_len = input_ids.shape - device = input_ids.device - - # Detect boundaries from EOS tokens - eos_mask = input_ids == eos_id - boundary_mask = torch.zeros_like(eos_mask) - boundary_mask[:, 0] = True # First position always starts a document - boundary_mask[:, 1:] = eos_mask[:, :-1] - else: - raise ValueError( - "Must provide either 'positions' or both 'input_ids' and 'eos_id'" - ) - + batch_size, seq_len = input_batch.shape + device = input_batch.device cu_seqlens_list, all_seq_lengths = [], [] offset = 0 + max_seqlen = 0 for b in range(batch_size): - # Find positions where new documents start - boundary_positions = boundary_mask[b].nonzero(as_tuple=True)[0].to(torch.int32) + tokens = input_batch[b] + eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32) sample_cu_seqlens = torch.cat( [ - boundary_positions, + torch.tensor([0], dtype=torch.int32, device=device), + eos_positions + 1, torch.tensor([seq_len], dtype=torch.int32, device=device), ] ) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 8af8e5d490..498b012605 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -465,16 +465,7 @@ def get_attention_masks( case "block_causal": B = input_batch.shape[0] assert tokenizer.eos_id is not None - mask_mods.append( - get_document_mask_mod( - input_ids=input_batch, eos_id=tokenizer.eos_id - ) - ) - case "position_block_causal": - assert extra_inputs is not None and "positions" in extra_inputs - positions = extra_inputs["positions"] - B = input_batch.shape[0] - mask_mods.append(get_document_mask_mod(positions=positions)) + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" diff --git a/torchtitan/models/gpt_oss/model/model.py b/torchtitan/models/gpt_oss/model/model.py index e3f47775d1..92b9c42fe6 100644 --- a/torchtitan/models/gpt_oss/model/model.py +++ b/torchtitan/models/gpt_oss/model/model.py @@ -338,15 +338,8 @@ def get_attention_masks( B = input_batch.shape[0] assert tokenizer.eos_id is not None basic_mask_mods.append( - get_document_mask_mod( - input_ids=input_batch, eos_id=tokenizer.eos_id - ) + get_document_mask_mod(input_batch, tokenizer.eos_id) ) - case "position_block_causal": - assert extra_inputs is not None and "positions" in extra_inputs - positions = extra_inputs["positions"] - B = input_batch.shape[0] - basic_mask_mods.append(get_document_mask_mod(positions=positions)) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 6898efffc9..f79dbcaea7 100644 --- 
a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -514,16 +514,7 @@ def _get_flex_attention_masks( case "block_causal": B = input_batch.shape[0] assert tokenizer.eos_id is not None - mask_mods.append( - get_document_mask_mod( - input_ids=input_batch, eos_id=tokenizer.eos_id - ) - ) - case "position_block_causal": - assert extra_inputs is not None and "positions" in extra_inputs - positions = extra_inputs["positions"] - B = input_batch.shape[0] - mask_mods.append(get_document_mask_mod(positions=positions)) + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" @@ -533,28 +524,6 @@ def _get_flex_attention_masks( and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] ) - def _get_varlen_attention_masks( - self, - input_batch: torch.Tensor, - tokenizer: BaseTokenizer, - extra_inputs: dict[str, torch.Tensor] | None = None, - ) -> AttentionMasksType: - match self.model_args.attn_mask_type: - case "block_causal": - assert tokenizer.eos_id is not None - return create_varlen_metadata_for_document( - input_ids=input_batch, eos_id=tokenizer.eos_id - ) - case "position_block_causal": - assert extra_inputs is not None and "positions" in extra_inputs - positions = extra_inputs["positions"] - return create_varlen_metadata_for_document(positions=positions) - case _: - raise ValueError( - f"varlen attention is only supported with block_causal or " - f"position_block_causal attention mask type, got {self.model_args.attn_mask_type}" - ) - def get_attention_masks( self, input_batch: torch.Tensor, @@ -567,8 +536,14 @@ def get_attention_masks( input_batch, tokenizer, extra_inputs ) case "varlen": - return self._get_varlen_attention_masks( - input_batch, tokenizer, extra_inputs + if self.model_args.attn_mask_type != "block_causal": + raise ValueError( + f"varlen attention is only supported with block_causal \ + attention mask type, got {self.model_args.attn_mask_type}" + ) + assert tokenizer.eos_id is not None + return create_varlen_metadata_for_document( + input_batch, tokenizer.eos_id ) case _: raise TypeError("Only varlen and flex attn masks are supported") diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index af856fe20f..342f5b5f18 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -549,17 +549,8 @@ def get_attention_masks( B = 1 case "block_causal": assert tokenizer.eos_id is not None - mask_mods.append( - get_document_mask_mod( - input_ids=input_batch, eos_id=tokenizer.eos_id - ) - ) - B = input_batch.shape[0] - case "position_block_causal": - assert extra_inputs is not None and "positions" in extra_inputs - positions = extra_inputs["positions"] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) B = input_batch.shape[0] - mask_mods.append(get_document_mask_mod(positions=positions)) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index ed84e44d89..2b7173546d 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -529,16 +529,7 @@ def _get_flex_attention_masks( case "block_causal": B = input_batch.shape[0] assert tokenizer.eos_id is not None - mask_mods.append( - get_document_mask_mod( - input_ids=input_batch, eos_id=tokenizer.eos_id - ) - ) - 
case "position_block_causal": - assert extra_inputs is not None and "positions" in extra_inputs - positions = extra_inputs["positions"] - B = input_batch.shape[0] - mask_mods.append(get_document_mask_mod(positions=positions)) + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" @@ -547,28 +538,6 @@ def _get_flex_attention_masks( and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] ) - def _get_varlen_attention_masks( - self, - input_batch: torch.Tensor, - tokenizer: BaseTokenizer, - extra_inputs: dict[str, torch.Tensor] | None = None, - ) -> AttentionMasksType: - match self.model_args.attn_mask_type: - case "block_causal": - assert tokenizer.eos_id is not None - return create_varlen_metadata_for_document( - input_ids=input_batch, eos_id=tokenizer.eos_id - ) - case "position_block_causal": - assert extra_inputs is not None and "positions" in extra_inputs - positions = extra_inputs["positions"] - return create_varlen_metadata_for_document(positions=positions) - case _: - raise ValueError( - f"varlen attention is only supported with block_causal or " - f"position_block_causal attention mask type, got {self.model_args.attn_mask_type}" - ) - def get_attention_masks( self, input_batch: torch.Tensor, @@ -581,8 +550,14 @@ def get_attention_masks( input_batch, tokenizer, extra_inputs ) case "varlen": - return self._get_varlen_attention_masks( - input_batch, tokenizer, extra_inputs + if self.model_args.attn_mask_type != "block_causal": + raise ValueError( + f"varlen attention is only supported with block_causal \ + attention mask type, got {self.model_args.attn_mask_type}" + ) + assert tokenizer.eos_id is not None + return create_varlen_metadata_for_document( + input_batch, tokenizer.eos_id ) case _: raise TypeError("Only varlen and flex attn masks are supported") From 2da51a3c72ac19839f2a0dd2e0f4033357d7c6f0 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Mon, 9 Feb 2026 16:23:34 -0800 Subject: [PATCH 10/10] address comments --- torchtitan/models/attention.py | 3 --- torchtitan/train.py | 23 ++++++++--------------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 39b1826c8b..e034224019 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -217,8 +217,6 @@ def _causal_mask( def get_document_mask_mod(batch: torch.Tensor, eos_id: int) -> _mask_mod_signature: """Creates a document mask that prevents attention across document boundaries. - Document boundaries can be detected either from EOS tokens or position ID resets. - Args: batch: Input batch tensor with shape [b, s, h, d] eos_id: End-of-sequence token ID that marks document boundaries @@ -227,7 +225,6 @@ def get_document_mask_mod(batch: torch.Tensor, eos_id: int) -> _mask_mod_signatu A mask modifier function that implements document-level masking. 
""" # batch is [b, s, h, d] shape - # batch is [b, s, h, d] shape eos_mask = batch == eos_id eos_mask[:, -1] = True cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1) diff --git a/torchtitan/train.py b/torchtitan/train.py index a5d095ef61..35ec9a389c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -463,23 +463,16 @@ def post_dataloading_process( extra_kwargs: dict[str, Any] = {} attn_type = getattr(self.model_args, "attn_type", "sdpa") - if ( - "attention_masks" not in extra_inputs.keys() - or extra_inputs["attention_masks"] is None - ): - model = cast(ModelProtocol, self.model_parts[0]) - extra_inputs.pop("attention_masks") + if attn_type in ["flex", "varlen"]: assert ( self.tokenizer is not None - ), "tokenizer is required for sdpa/flex/varlen attention" - if attn_type in ["flex", "varlen"]: - extra_kwargs["attention_masks"] = model.get_attention_masks( - input_batch=inputs, - tokenizer=self.tokenizer, - extra_inputs=extra_inputs, - ) - else: - extra_kwargs["attention_masks"] = extra_inputs.pop("attention_masks") + ), "tokenizer is required for flex/varlen attention" + model = cast(ModelProtocol, self.model_parts[0]) + extra_kwargs["attention_masks"] = model.get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, + extra_inputs=extra_inputs, + ) if self.parallel_dims.cp_enabled: inputs, labels, extra_kwargs = prepare_context_parallel_input(