13 changes: 11 additions & 2 deletions torchtitan/models/attention.py
@@ -11,13 +11,16 @@

 import torch
 import torch.nn.functional as F
-from torch.distributed.tensor.experimental._attention import create_cp_block_mask
+from torch.distributed.tensor.experimental._attention import (
+    _DispatchMode,
+    _flex_attention_wrapper,
+    create_cp_block_mask,
+)
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
     BlockMask,
     create_block_mask,
-    flex_attention,
 )

 from torchtitan.tools.utils import has_cuda_capability
@@ -48,9 +51,12 @@ class FlexAttention(torch.nn.Module):

     # We registered flex_attention related attributes as class variables as we
     # need to amortize the cost of compilation.
+    """
     flex_attn: ClassVar[Callable] = torch.compile(
         flex_attention, mode="max-autotune-no-cudagraphs"
     )
+    """
+    flex_attn: ClassVar[Callable] = _flex_attention_wrapper
     compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
     used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set()
     # Attention mask type to the created BlockMask.
@@ -251,6 +257,9 @@ def init_attention_mask(
     # while we continue debugging accuracy issues. However, we want to evaluate
     # the user experience with CP enabled.
     if cp_mesh is not None:
+        torch.distributed.tensor.experimental._attention._dispatch_mode = (
+            _DispatchMode.MODULE_WRAPPER
+        )
         FlexAttention.compiled_create_block_mask = functools.partial(
             create_cp_block_mask, device_mesh=cp_mesh
         )
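Taken together, the attention.py changes stop compiling flex_attention locally and route it through PyTorch's experimental context-parallel attention hooks: when a CP mesh is present, dispatch is switched to the module-wrapper mode and block-mask creation is bound to that mesh. Below is a minimal sketch of that wiring, not code from this PR: it assumes the private, experimental APIs keep the names imported in the diff above, and both the `enable_cp_flex_attention` helper and the `cp_mesh` argument are illustrative stand-ins.

```python
# Sketch only (not part of the PR): mirrors the CP branch added to
# init_attention_mask above. Assumes the private, experimental PyTorch APIs
# imported in the diff exist under these names; they may change across releases.
import functools

import torch
from torch.distributed.tensor.experimental._attention import (  # experimental/private
    _DispatchMode,
    create_cp_block_mask,
)


def enable_cp_flex_attention(cp_mesh):
    """Hypothetical helper: switch FlexAttention onto the context-parallel path.

    `cp_mesh` is assumed to be the DeviceMesh for the CP dimension.
    """
    # Route flex attention through the module-wrapper dispatch mode so the CP
    # machinery can wrap the attention call, as the diff does when cp_mesh is set.
    torch.distributed.tensor.experimental._attention._dispatch_mode = (
        _DispatchMode.MODULE_WRAPPER
    )
    # Bind the CP device mesh so block masks built afterwards are CP-aware,
    # matching FlexAttention.compiled_create_block_mask in the diff.
    return functools.partial(create_cp_block_mask, device_mesh=cp_mesh)
```

This also appears to be why the torch.compile(flex_attention, ...) class attribute is left behind in a docstring rather than kept live: the entry point becomes _flex_attention_wrapper, which the CP dispatch layer can presumably wrap.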
6 changes: 1 addition & 5 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -66,10 +66,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """

-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.float8.recipe_name in (
@@ -99,7 +95,7 @@
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            use_flex_attn=getattr(model.model_args, "use_flex_attn", False),
             save_list=_save_list,
         )

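With the early CP guard removed, parallelize_llama reads the flag inline via getattr(model.model_args, "use_flex_attn", False), so model-args classes that never define the field quietly fall back to False. A small, self-contained illustration of that lookup pattern follows; the two dataclasses are hypothetical stand-ins, not torchtitan types.

```python
# Illustration of the getattr-with-default pattern used in parallelize.py above.
# FlexArgs and LegacyArgs are made up for this example.
from dataclasses import dataclass


@dataclass
class FlexArgs:
    use_flex_attn: bool = True


@dataclass
class LegacyArgs:
    dim: int = 4096  # no use_flex_attn field defined


print(getattr(FlexArgs(), "use_flex_attn", False))    # -> True
print(getattr(LegacyArgs(), "use_flex_attn", False))  # -> False
```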
7 changes: 0 additions & 7 deletions torchtitan/models/llama3/model/args.py
@@ -45,13 +45,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
             )
         self.max_seq_len = seq_len

-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
-
-        self.max_seq_len = seq_len
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         nparams = sum(p.numel() for p in model.parameters())
         nparams_embedding = sum(