diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 74d310dfc1..0ecf05500d 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -206,7 +206,6 @@ def context(cp_context: Generator[None, None, None] | None = None): if SDPBackend.MATH in ScaledDotProductAttention.backends: ScaledDotProductAttention.backends.remove(SDPBackend.MATH) - stack.enter_context(cp_context) yield diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py new file mode 100644 index 0000000000..715ce943e0 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -0,0 +1,55 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from .infra.optimizer import build_gptoss_optimizers + +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .infra.parallelize import parallelize_gptoss +from .model.args import GptOssModelArgs +from .model.model import GptOssModel + +__all__ = [ + "parallelize_gptoss", + "GptOssModelArgs", + "GptOssModel", + "gptoss_configs", +] + + +gptoss_configs = { + "debugmodel": GptOssModelArgs( + hidden_size=256, + num_hidden_layers=4, + use_flex_attn=False, + use_grouped_mm=False, + ), + "20b": GptOssModelArgs( + num_hidden_layers=24, + num_local_experts=32, + ), + "120b": GptOssModelArgs( + num_hidden_layers=36, + num_local_experts=128, + ), +} + + +register_train_spec( + TrainSpec( + name="gpt_oss", + cls=GptOssModel, + config=gptoss_configs, + parallelize_fn=parallelize_gptoss, + pipelining_fn=None, + build_optimizers_fn=build_gptoss_optimizers, # use optimizer hooks to update expert weights + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py new file mode 100644 index 0000000000..e47bdeec58 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py @@ -0,0 +1,297 @@ +from functools import partial +from typing import Callable + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._functional_collectives import all_to_all_single_autograd +from torch.distributed.tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Replicate, + Shard, +) +from torch.distributed.tensor.parallel import ParallelStyle +from torch.distributed.tensor.placement_types import Placement + + +# implementation of Tensor Parallel for the GroupedExperts in MoE +class TensorParallel(ParallelStyle): + def _partition_fn(self, name, module, device_mesh): + module.register_parameter( + "mlp1_weight", nn.Parameter(distribute_tensor(module.mlp1_weight, device_mesh, [Shard(2)])) + ) # Column-wise sharding + module.register_parameter( + "mlp1_bias", + nn.Parameter(distribute_tensor(module.mlp1_bias, device_mesh, [Shard(1)])), + ) # Column-wise sharding + module.register_parameter( + "mlp2_weight", + nn.Parameter(distribute_tensor(module.mlp2_weight, device_mesh, [Shard(1)])), + ) # Row-wise sharding + 
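        # Shape walk-through for these shardings (parameter shapes as defined by GroupedExperts
        # in model/moe.py of this diff):
        #   mlp1_weight (num_experts, dim, 2 * dim) -> Shard(2): split the output/column dim across TP ranks
        #   mlp1_bias   (num_experts, 2 * dim)      -> Shard(1): split the matching output dim
        #   mlp2_weight (num_experts, dim, dim)     -> Shard(1): split the reduction/input dim (row-wise)
        #   mlp2_bias   (num_experts, dim)          -> Replicate(): kept whole on every TP rank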
module.register_parameter( + "mlp2_bias", + nn.Parameter(distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()])), + ) # Replicate + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + self._partition_fn, + ) + + +# NOTE: This is to achieve replicate computation on the gate module in the MoE router. +# It does nothing other than (1) setting the module parameters as DTensors on the given mesh +# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back. +# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh, +# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation. +class NoParallel(ParallelStyle): + def __init__( + self, + *, + input_layout: Placement | None = None, + output_layout: Placement | None = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layout = input_layout or Replicate() + self.output_layout = output_layout or Replicate() + self.desired_input_layout = Replicate() + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh): + # annotate module input placements/sharding with input_layouts + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, (input_layout,), run_check=False + ) + + if input_layout != desired_input_layout: + input_tensor = input_tensor.redistribute( + placements=(desired_input_layout,), async_op=True + ) + return (input_tensor, *inputs[1:]) + + @staticmethod + def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh): + if outputs.placements != (output_layout,): + outputs = outputs.redistribute(placements=(output_layout,), async_op=True) + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + None, + partial( + self._prepare_input_fn, self.input_layout, self.desired_input_layout + ), + partial(self._prepare_output_fn, self.output_layout, self.use_local_output), + ) + + +class ExpertParallel(ParallelStyle): + def __init__(self): + super().__init__() + self.input_splits = None + self.output_splits = None + + # performing all-to-all dispatch on the input + def _token_dispatch(self, mod, inputs, device_mesh): + # annotate module input placements/sharding with input_layouts + routed_input, num_tokens_per_expert = inputs + + # generate the input splits and output splits for all-to-all + with torch.no_grad(): + num_tokens_per_expert_group = num_tokens_per_expert.new_empty( + num_tokens_per_expert.shape[0] + ) + dist.all_to_all_single( + num_tokens_per_expert_group, + num_tokens_per_expert, + group=device_mesh.get_group(), + ) + # NOTE: this would incur a device-to-host sync + self.input_splits = ( + num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist() + ) + self.output_splits = ( + num_tokens_per_expert_group.view(device_mesh.shape[0], -1) + .sum(dim=1) + .tolist() + ) + + # perform all-to-all + routed_input = all_to_all_single_autograd( + routed_input, + self.output_splits, + self.input_splits, + device_mesh.get_group(), + ) + + # NOTE: After this all-to-all, the routed input is put on proper EP rank. 
+ # However, the num_tokens_per_expert_group is not of the final target format + # [#tokens for local expert 0, #tokens for local expert 1, ...] + # Rather, it is of the format + # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ..., + # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...] + # We need to perform another shuffle to get the correct format -- this is done via the function + # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens + # each expert gets locally is a multiple of ALIGN_SIZE_M. + + return routed_input, num_tokens_per_expert_group + + @staticmethod + def _partition_fn(name, mod, device_mesh): + # shard on the expert dimension + for name, param in mod.named_parameters(recurse=False): + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) + mod.register_parameter(name, dist_param) + + # performing all-to-all combine on the output + def _token_combine(self, mod, routed_output, device_mesh): + routed_output = all_to_all_single_autograd( + routed_output, + self.input_splits, + self.output_splits, + device_mesh.get_group(), + ) + return routed_output + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=ExpertParallel._partition_fn, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +# This class is for dp2ep with TP (without TP we can just use ExpertParallel) +class ExpertTensorParallel(ExpertParallel): + def __init__( + self, + tp_mesh: DeviceMesh, + ep_mesh: DeviceMesh, + ): + super().__init__() + # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh, + # as DeviceMesh doesn't support slicing from a submesh. + self.tp_mesh = tp_mesh + self.ep_mesh = ep_mesh + + def _token_dispatch(self, mod, inputs, device_mesh): + # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_dispatch(mod, inputs, self.ep_mesh) + + def _partition_fn_2d(self, name, mod, ep_tp_mesh): + mod.register_parameter( + "mlp1_weight", + nn.Parameter(distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(2)])), + ) # Column-wise sharding + mod.register_parameter( + "mlp1_bias", + nn.Parameter(distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)])), + ) # Row-wise sharding + mod.register_parameter( + "mlp2_weight", + nn.Parameter(distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(2)])), + ) # Column-wise sharding + mod.register_parameter( + "mlp2_bias", + nn.Parameter(distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Shard(1)])), + ) # Row-wise sharding + + def _token_combine(self, mod, routed_output, device_mesh): + # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_combine(mod, routed_output, self.ep_mesh) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=self._partition_fn_2d, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +def expert_parallel(func: Callable) -> Callable: + """ + This is a wrapper applied to the GroupedExperts computation, serving + the following three purposes: + 1. Convert parameters from DTensors to plain Tensors, to work with + dynamic-shape inputs which cannot be easily expressed as DTensors. + 2. 
In Expert Parallel, apply the generate_permute_indices kernel to + permute the inputs to be ordered by local experts (see the _token_dispatch + function in ExpertParallel) and permute the outputs back. + 3. In order to use torch._grouped_mm, we need to make sure the number of + tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices + kernel also helps achieve this via padding, without incurring synchronization + between device and host. Note that this will create side effects when wrapping + the for-loop implementation of GroupedExperts, as it does not need padding. + + Among the above: + 1 and 2 are needed only when expert_parallel_degree > 1. + 3 is needed even for single-device computation. + 2 can be moved to ExpertParallel _token_dispatch if not coupled with 3. + """ + + def wrapper( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(mlp1_weight, DTensor): + mlp1_weight = mlp1_weight.to_local() + mlp1_bias = mlp1_bias.to_local() + mlp2_weight = mlp2_weight.to_local() + mlp2_bias = mlp2_bias.to_local() + + if num_tokens_per_expert is not None: + from torchtitan.experiments.kernels.moe.indices import ( + generate_permute_indices, + ) + + experts_per_ep_rank = mlp1_weight.shape[0] + num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank + + ALIGN_SIZE_M = 16 + with torch.no_grad(): + ( + permuted_indices, + num_tokens_per_expert, + _, # offsets, + ) = generate_permute_indices( + num_tokens_per_expert, + experts_per_ep_rank, + num_ep_ranks, + x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M, + ALIGN_SIZE_M, + ) + + x = torch.vstack((x, x.new_zeros((x.shape[-1])))) + input_shape = x.shape + x = x[permuted_indices, :] + + out = func(mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias, x, num_tokens_per_expert) + + if num_tokens_per_expert is not None: + out_unpermuted = out.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = out + out = out_unpermuted[:-1] + + return out + + return wrapper diff --git a/torchtitan/experiments/gpt_oss/infra/optimizer.py b/torchtitan/experiments/gpt_oss/infra/optimizer.py new file mode 100644 index 0000000000..de8537032d --- /dev/null +++ b/torchtitan/experiments/gpt_oss/infra/optimizer.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.components.ft import FTManager +from torchtitan.components.optimizer import build_optimizers, OptimizersContainer +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + + +# for MoE auxiliary-loss-free load balancing +def _update_expert_bias( + model_parts: list[nn.Module], + world_mesh: dict[str, DeviceMesh], + parallel_dims: ParallelDims, +): + dp_cp_mesh = world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + # TODO: Currently this sync is blocking (thus exposed) and happens on the + # default compute stream. Need to assess if this is OK performance-wise. 
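    # Worked example of the update performed below: with tokens_per_expert = [30, 10, 20, 40]
    # the mean load is 25, so sign(mean - tokens) = [-1, +1, +1, -1]; after subtracting its own
    # mean (0 here) and scaling by load_balance_coeff = 1e-3, the bias of the two over-loaded
    # experts drops by 1e-3 and the bias of the two under-loaded experts rises by 1e-3,
    # nudging the router toward a more even token distribution on the next step.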
+ for model_part in model_parts: + for transformer_block in model_part.layers.values(): + moe = transformer_block.moe + if moe.load_balance_coeff is None: + return + + if dp_cp_mesh is not None: + torch.distributed.all_reduce( + moe.tokens_per_expert, group=dp_cp_mesh.get_group() + ) + + with torch.no_grad(): + expert_bias_delta = moe.load_balance_coeff * torch.sign( + moe.tokens_per_expert.mean() - moe.tokens_per_expert + ) + expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() + moe.expert_bias.add_(expert_bias_delta) + moe.tokens_per_expert.zero_() + + +def build_gptoss_optimizers( + model_parts: list[nn.Module], + job_config: JobConfig, + parallel_dims: ParallelDims, + world_mesh: DeviceMesh, + ft_manager: FTManager, +) -> OptimizersContainer: + optimizers = build_optimizers( + model_parts=model_parts, + job_config=job_config, + parallel_dims=parallel_dims, + world_mesh=world_mesh, + ft_manager=ft_manager, + ) + + optimizers.register_step_pre_hook( + lambda *args, **kwargs: _update_expert_bias( + model_parts, world_mesh=world_mesh, parallel_dims=parallel_dims + ) + ) + + return optimizers diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py new file mode 100644 index 0000000000..47ad01b99e --- /dev/null +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -0,0 +1,431 @@ +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import Replicate, Shard, distribute_tensor +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) + +if torch.__version__ >= "2.9": + from torch.distributed.tensor.parallel import PrepareModuleInputOutput +else: + print(f"Since torch version {torch.__version__} < 2.9, PrepareModuleInputOutput is not available and MoE EP TP will fail.") + +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims +from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp +from torchtitan.tools.logging import logger + +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor import Partial, Replicate, Shard + +from .expert_parallel import ( + ExpertParallel, + ExpertTensorParallel, + NoParallel, + TensorParallel, +) + + +# Adapted from llama4/infra/parallelize.py +def parallelize_gptoss( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + if parallel_dims.tp_enabled: + if job_config.parallelism.enable_async_tensor_parallel: + # TODO(jianiw): This branch needs to be tested and enabled + raise NotImplementedError( + "Currently, async TP is not tested for gptoss. \ + torch.compile is not supported yet, which is required for async TP." 
+            )
+
+        enable_float8_linear = "float8" in job_config.model.converters
+        float8_is_rowwise = job_config.float8.recipe_name in (
+            "rowwise",
+            "rowwise_with_gw_hp",
+        )
+
+        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
+        if enable_float8_tensorwise_tp:
+            # TODO(jianiw): This branch needs to be tested and enabled
+            raise NotImplementedError(
+                "Currently, float8 tensorwise TP is not tested for gptoss"
+            )
+
+        apply_non_moe_tp(
+            model,
+            world_mesh["tp"],
+            loss_parallel=parallel_dims.loss_parallel_enabled,
+            enable_float8_tensorwise_tp=False,
+            enable_async_tp=False,
+        )
+
+    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
+        apply_moe_ep_tp(
+            model,
+            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
+            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
+            ep_tp_mesh=(
+                world_mesh["ep", "tp"]
+                if parallel_dims.tp_enabled and parallel_dims.ep_enabled
+                else None
+            ),
+        )
+
+    if job_config.activation_checkpoint.mode != "none":
+        apply_ac(model, job_config.activation_checkpoint)
+
+    if job_config.training.compile:
+        raise NotImplementedError("torch.compile is not supported yet for gptoss")
+
+    dp_mesh: DeviceMesh | None = None
+    if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
+        # apply FSDP or HSDP, potentially with Context Parallel
+        if parallel_dims.dp_replicate_enabled:
+            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+        else:
+            dp_mesh_dim_names = ("dp_shard_cp",)
+        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
+
+        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
+        dp_mod_ep_mesh_dim_names = []
+        if parallel_dims.ep_enabled:
+            if parallel_dims.dp_replicate_enabled:
+                dp_mod_ep_mesh_dim_names.append("dp_replicate")
+            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
+
+        apply_fsdp(
+            model,
+            dp_mesh,
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+            pp_enabled=parallel_dims.pp_enabled,
+            cpu_offload=job_config.training.enable_cpu_offload,
+            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
+            dp_mod_ep_mesh=(
+                world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
+                if dp_mod_ep_mesh_dim_names
+                else None
+            ),
+        )
+
+        if parallel_dims.dp_replicate_enabled:
+            logger.info("Applied HSDP to the model")
+        else:
+            logger.info("Applied FSDP to the model")
+
+        if parallel_dims.cp_enabled:
+            logger.info("Applied Context Parallel to the model")
+
+        if job_config.training.enable_cpu_offload:
+            logger.info("Applied CPU Offloading to the model")
+    elif parallel_dims.dp_replicate_enabled:
+        if world_mesh.ndim > 1:
+            raise RuntimeError("DDP has not supported > 1D parallelism")
+        dp_mesh = world_mesh
+        apply_ddp(
+            model,
+            dp_mesh,
+            enable_compile=job_config.training.compile,
+            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+        )
+
+    return model
+
+
+def apply_non_moe_tp(
+    model: nn.Module,
+    tp_mesh: DeviceMesh,
+    loss_parallel: bool,
+    enable_float8_tensorwise_tp: bool,
+    enable_async_tp: bool,
+):
+    """Apply tensor parallelism."""
+    # 1. Parallelize the embedding and shard its outputs (which are the first
+    # transformer block's inputs)
+    # 2. Parallelize the root norm layer over the sequence dim
+    # 3. Parallelize the final linear output layer
+    parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "norm": SequenceParallel(),
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Shard(-1) if loss_parallel else Replicate(),
+                use_local_output=not loss_parallel,
+            ),
+        },
+    )
+
+    # Parallel styles used for transformer block linear weights and their
+    # inputs may be different for float8 linears with tensorwise scaling.
+ if enable_float8_tensorwise_tp: + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "attention.wq": colwise_parallel(), + "attention.wk": colwise_parallel(), + "attention.wv": colwise_parallel(), + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + # shard attention.sinks across heads + attn = transformer_block.attention + attn.register_parameter( + "sinks", + nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])), + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", + dp_mod_ep_mesh: DeviceMesh | None = None, +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. 
+ + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + for layer_id, transformer_block in model.layers.items(): + if reshard_after_forward_policy == "always": + reshard_after_forward = True + elif reshard_after_forward_policy == "never": + reshard_after_forward = False + elif reshard_after_forward_policy == "default": + if pp_enabled: + # For PP, do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = False + else: + # As an optimization, do not reshard after forward for the last + # transformer block since FSDP would prefetch it immediately + reshard_after_forward = int(layer_id) < len(model.layers) - 1 + else: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + + # NOTE: in an MoE layer, the router and the shared experts + # are sharded together with the TransformerBlock + if dp_mod_ep_mesh: + fsdp_mod_ep_config = fsdp_config.copy() + fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + fully_shard( + transformer_block.moe.experts, + **fsdp_mod_ep_config, + reshard_after_forward=reshard_after_forward, + ) + + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + + +def apply_moe_ep_tp( + model: nn.Module, + tp_mesh: DeviceMesh | None, + ep_mesh: DeviceMesh | None, + ep_tp_mesh: DeviceMesh | None, +): + for transformer_block in model.layers.values(): + if tp_mesh is not None: + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "moe": PrepareModuleInputOutput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_input=True, + output_layouts=(Partial(),), + desired_output_layouts=(Shard(1),), + ), + # replicate computation for the router + "moe.router.gate": NoParallel(), + } + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=moe_layer_plan, + ) + + # if ep_mesh is not None: + experts_mesh, experts_plan = None, None + if ep_mesh is None: + experts_mesh = tp_mesh + # input Replicate, output Partial + experts_plan = TensorParallel() + elif tp_mesh is None: + experts_mesh = ep_mesh + # input / output sharding on the batch / tokens dim + experts_plan = ExpertParallel() + else: + experts_mesh = ep_tp_mesh + experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh) + parallelize_module( + module=transformer_block.moe.experts, + device_mesh=experts_mesh, + parallelize_plan=experts_plan, + ) diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py new file mode 100644 index 0000000000..63c2b6bb82 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/args.py @@ -0,0 +1,142 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
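# A minimal illustrative sketch (not part of the diff) of the expert-placement decision made
# by apply_moe_ep_tp above, i.e. which mesh and ParallelStyle handle the routed experts:

def pick_experts_parallelism(tp_mesh, ep_mesh, ep_tp_mesh):
    if ep_mesh is None:
        return tp_mesh, "TensorParallel"       # TP only: shard the experts' hidden dims
    if tp_mesh is None:
        return ep_mesh, "ExpertParallel"       # EP only: shard the expert dim, all-to-all the tokens
    return ep_tp_mesh, "ExpertTensorParallel"  # EP + TP: expert dim on the EP mesh, hidden dims on the TP mesh

assert pick_experts_parallelism(None, "ep", None) == ("ep", "ExpertParallel")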
+
+
+from dataclasses import dataclass
+from typing import Literal
+
+from torch import nn
+
+from torchtitan.components.tokenizer import Tokenizer
+from torchtitan.config_manager import JobConfig
+from torchtitan.protocols.train_spec import BaseModelArgs
+from torchtitan.tools.logging import logger
+
+# from transformers.models.gpt_oss.modeling_gpt_oss import GPT_OSS_PRETRAINED_INIT_CONFIGURATION
+
+
+# YaRN RoPE hyperparameters reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
+@dataclass
+class GptOssModelArgs(BaseModelArgs):
+    """
+    Data class for defining model arguments and hyperparameters.
+
+    Attributes:
+        max_batch_size (int): Maximum batch size.
+        max_seq_len (int): Maximum sequence length.
+        dtype (Literal["bf16", "fp8"]): Data type for computations.
+        vocab_size (int): Vocabulary size.
+        hidden_size (int): Model (hidden) dimension.
+        num_hidden_layers (int): Number of transformer layers.
+        norm_eps (float): Epsilon used for RMSNorm.
+        num_local_experts (int): Number of routed experts in each MoE layer.
+        num_experts_per_tok (int): Number of experts each token is routed to (top-k).
+        use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers.
+        load_balance_coeff (float | None): Auxiliary-loss-free load balancing coefficient for MoE layers.
+        head_dim (int): Dimension of each attention head.
+        num_attention_heads (int): Number of query heads.
+        num_key_value_heads (int): Number of key/value heads (GQA).
+        sliding_window (int): Window size for the sliding-window attention layers.
+        use_flex_attn (bool): Whether to use FlexAttention.
+        attn_mask_type (str): Attention mask type.
+        original_seq_len (int): Original (pre-YaRN-extension) sequence length.
+        rope_theta (float): Base for rotary positional encoding.
+        rope_factor (float): YaRN scaling factor for extended sequence lengths.
+        beta_fast (int): Fast beta correction factor for YaRN.
+        beta_slow (int): Slow beta correction factor for YaRN.
+    """
+
+    max_batch_size: int = 8
+    max_seq_len: int = 131072
+    dtype: Literal["bf16", "fp8"] = "bf16"
+    vocab_size: int = 201088
+    hidden_size: int = 2880
+    num_hidden_layers: int = 24
+    norm_eps: float = 1e-5  # eps used for RMSNorm
+    # MoE
+    num_local_experts: int = 32
+    num_experts_per_tok: int = 4
+    use_grouped_mm: bool = True
+    load_balance_coeff: float | None = 1e-3
+    # Attention (GQA with attention sinks and sliding window)
+    head_dim: int = 64
+    num_attention_heads: int = 64
+    num_key_value_heads: int = 8
+    sliding_window: int = 128
+    use_flex_attn: bool = True
+    attn_mask_type: str = "causal"
+    # YaRN
+    original_seq_len: int = 4096
+    rope_theta: float = 150000.0
+    rope_factor: float = 32
+    beta_fast: int = 32
+    beta_slow: int = 1
+
+    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
+        """
+        Update the model config from the given job config.
+        """
+        # self.vocab_size = tokenizer.vocab_size
+        self.max_seq_len = job_config.training.seq_len
+
+    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
+        """
+        Adopted from the llama4 implementation.
+ """ + nparams_embedding = 0 + nparams_moe_router = 0 + nparams_shared_expert = 0 + nparams_experts = 0 + nparams_dense = 0 + + for name, p in model.named_parameters(): + if "embedding" in name: + nparams_embedding += p.numel() + nparams_dense += p.numel() + elif "moe.shared_expert" in name: + nparams_shared_expert += p.numel() + elif "moe.router" in name: + nparams_moe_router += p.numel() + elif "moe.experts" in name: + nparams_experts += p.numel() + else: + nparams_dense += p.numel() + + nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts + nparams = nparams_dense + nparams_sparse + nparams_sparse_active = ( + nparams_moe_router + + nparams_shared_expert + + nparams_experts * self.num_experts_per_tok // self.num_local_experts + ) + + logger.info( + f"Total parameter count: dense {nparams_dense:,}, " + f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}" + ) + + l, h, q, t = ( + self.num_hidden_layers, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = ( + 6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + + 12 * l * h * q * t + ) + + return nparams, num_flops_per_token diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py new file mode 100644 index 0000000000..c16cb53274 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -0,0 +1,480 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +from torch import nn +from torch.distributed.tensor import DTensor +from torchtitan.models.attention import build_attention +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import GptOssModelArgs +from .moe import MoE + +# TODO: may be able to remove this once parallelized properly +def convert_submodules_to_bf16( + module: nn.Module, + exclude_names: tuple[str, ...] = ("freqs_cis", "attention_norm", "ffn_norm", "norm"), + attr_opt_out: str = "no_bf16", # if a submodule sets `self.no_bf16 = True`, it will be skipped + ) -> None: + """ + Recursively convert parameters & buffers of submodules to bfloat16, + except: + - modules whose *qualified name* ends with any of `exclude_names` + - modules with attribute `{attr_opt_out} == True` + Conversion is *shallow per-module* so exclusions are respected even deep in the tree. 
+ """ + + def should_skip(qname: str, mod: nn.Module) -> bool: + base = qname.rsplit(".", 1)[-1] # local (leaf) name + if base in exclude_names: + return True + if getattr(mod, attr_opt_out, False): + return True + return False + + def convert_shallow(mod: nn.Module): + # convert parameters owned by this module + for _, p in mod.named_parameters(recurse=False): + if p.is_floating_point(): + p.data = p.data.to(torch.bfloat16) + # convert buffers owned by this module + for _, b in mod.named_buffers(recurse=False): + # keep non-float buffers (e.g., ints, bool masks) as-is + if torch.is_floating_point(b): + b.data = b.data.to(torch.bfloat16) + + # walk the module tree; convert only *this* module's tensors if not skipped + for qname, mod in module.named_modules(): + # skip the root container name (empty) check gracefully + local_name = qname.rsplit(".", 1)[-1] if qname else "" + if local_name and should_skip(qname, mod): + continue + convert_shallow(mod) + +# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 +def precompute_freqs_cis(args: GptOssModelArgs) -> torch.Tensor: + """ + Precomputes frequency-based complex exponential values for rotary positional embeddings. + + Args: + args (GptOssModelArgs): Model arguments containing positional embedding parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional embeddings. + """ + dim = args.head_dim + seqlen = args.max_seq_len + beta_fast = args.beta_fast + beta_slow = args.beta_slow + base = args.rope_theta + factor = args.rope_factor + original_seq_len = args.original_seq_len + + # YaRN default m-scale (attention_factor). Matches HF when attention_factor is None. + mscale = 0.1 * math.log(factor) + 1.0 + + def find_correction_dim( + num_rotations: float, dim: int, base: float, max_seq_len: int + ) -> float: + """ + Computes the correction dimension for a given number of rotations in the rotary positional embedding. + + Args: + num_rotations (float): Number of rotations to compute the correction for. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + float: The correction dimension based on the input parameters. + """ + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + def find_correction_range( + low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int + ) -> Tuple[int, int]: + """ + Computes the range of correction dimensions for rotary positional embeddings. + + Args: + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + """ + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: + """ + Computes a linear ramp function used to smooth values between a minimum and maximum range. + + Args: + min (float): Minimum value for the ramp function. + max (float): Maximum value for the ramp function. + dim (int): Dimensionality of the ramp tensor. 
+ + Returns: + torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, + clamped to the range [0, 1]. + """ + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Basic RoPE frequency calculation + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training. + if seqlen > original_seq_len: + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, original_seq_len + ) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + # Create position indices + t = torch.arange(seqlen) + + # Outer product: [positions] Ɨ [frequencies] + freqs = torch.outer(t, freqs) + + # Convert to complex exponentials: e^(i*freq*pos) + freqs_cis = torch.polar(torch.full_like(freqs, fill_value=mscale), freqs) + + return freqs_cis + + +def apply_rotary_emb_inner(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Applies rotary positional embeddings to the input tensor. + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. + """ + dtype = x.dtype + x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + y = torch.view_as_real(x * freqs_cis).flatten(3) + return y.to(dtype) + +def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor): + """ + HF-style inputs (half-split last dim) -> interleave -> Torchtitan complex RoPE -> de-interleave. + Shapes: + q, k: [B, T, H, D] with D even (HF half-split: first D/2 real, last D/2 imag) + freqs_cis: complex, last dim == D/2. Typically [T, D/2] or [1, T, D/2]. + Returns: + q_out, k_out in HF half-split layout (same shape as q, k). + """ + B, T, H, D = q.shape + assert D % 2 == 0, "head_dim must be even for RoPE" + rot = D // 2 + assert freqs_cis.shape[-1] == rot, "freqs_cis last dim must be D/2" + freqs_cis = freqs_cis[:T, :] + + # --- inline: HF half-split -> interleaved (real0, imag0, real1, imag1, ...) + # q_i, k_i: [B, T, H, D] + q_i = torch.empty_like(q) + k_i = torch.empty_like(k) + q_i[..., 0::2] = q[..., :rot] + q_i[..., 1::2] = q[..., rot:] + k_i[..., 0::2] = k[..., :rot] + k_i[..., 1::2] = k[..., rot:] + + # --- Torchtitan default complex apply (expects interleaved last dim) + # freqs_cis will be reshaped inside to [1, T, 1, rot] + q_rot_i = apply_rotary_emb_inner(q_i, freqs_cis) # uses TT's complex path + k_rot_i = apply_rotary_emb_inner(k_i, freqs_cis) + + # --- inline: interleaved -> HF half-split + q_out = torch.cat([q_rot_i[..., 0::2], q_rot_i[..., 1::2]], dim=-1) + k_out = torch.cat([k_rot_i[..., 0::2], k_rot_i[..., 1::2]], dim=-1) + return q_out, k_out + +# Torch Attention backup implementation (for debugging and sampling) from HuggingFace +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def eager_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + attention_mask: torch.Tensor, + scaling: float, + dropout: float = 0.0, + num_key_value_groups: int = 1, + **kwargs, +): + key_states = repeat_kv(key, num_key_value_groups) + value_states = repeat_kv(value, num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + # attention_mask can be [Tq, Tk] or [B, H, Tq, Tk] + # Convert boolean "allowed" -> additive mask + if attention_mask.dtype == torch.bool: + m = attention_mask + add_mask = torch.zeros_like(m, dtype=attn_weights.dtype) + add_mask = add_mask.masked_fill(~m, -float("inf")) + else: + add_mask = attention_mask.to(attn_weights.dtype) + + # Truncate to current key length and add (broadcasts if needed) + add_mask = add_mask[..., : key_states.shape[-2]] + attn_weights = attn_weights + add_mask + + sinks = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = nn.functional.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=False) + attn_output = torch.matmul(attn_weights, value_states) + return attn_output + +class Attention(nn.Module): + """ + Multi-head attention (MLA) module. + """ + + def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = False): + super().__init__() + + self.sliding_window = model_args.sliding_window if use_sliding_attention else None + self.head_dim = model_args.head_dim + + self.wq = nn.Linear( + model_args.hidden_size, model_args.num_attention_heads * model_args.head_dim, bias=True + ) + self.wk = nn.Linear( + model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True + ) + self.wv = nn.Linear( + model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True + ) + self.wo = nn.Linear( + model_args.num_attention_heads * model_args.head_dim, model_args.hidden_size, bias=True + ) + self.sinks = nn.Parameter(torch.empty(model_args.num_attention_heads)) + + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.attn = build_attention(True, model_args.attn_mask_type) + else: + # NOTE: sampling with FlexAttn seems broken; use TorchAttn if needed + self.attn = eager_attention_forward + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. 
+ + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + bsz, seqlen, _ = x.size() + hidden_shape = (bsz, seqlen, -1, self.head_dim) + + q = self.wq(x).view(hidden_shape) + k = self.wk(x).view(hidden_shape) + v = self.wv(x).view(hidden_shape) + + q, k = apply_rotary_emb(q, k, freqs_cis) + + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + + if self.use_flex_attn: + output = self.attn( + q, k, v, + self.sinks.to_local() if isinstance(self.sinks, DTensor) else self.sinks, + sliding_window=self.sliding_window, + enable_gqa=True, + ) + else: + output = self.attn( + q, k, v, self.sinks, + attention_mask=self.sliding_window_causal(seqlen, x.device), + scaling=self.head_dim**-0.5, + dropout=0.0, + num_key_value_groups=8, + ) + output = output.transpose(1, 2).contiguous() # (B, H, T, D) -> (B, T, H, D) + + # Reshape and project output + output = output.reshape(bsz, seqlen, -1).contiguous() # (bsz, seqlen, n_heads * v_head_dim) + output = self.wo(output) # (bsz, seqlen, dim) + return output + + def init_weights(self, init_std: float): + linear_list = [ + self.wq, + self.wk, + self.wv, + ] + + nn.init.trunc_normal_(self.sinks, mean=0.0, std=init_std) + for linear in linear_list: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + # TODO: statically init the mask using train.seq_len + def sliding_window_causal(self, seqlen, device): + i = torch.arange(seqlen, device=device) + q_idx = i[:, None] + kv_idx = i[None, :] + + causal_mask = q_idx >= kv_idx + if self.sliding_window is None: + return causal_mask + window_mask = q_idx - kv_idx <= self.sliding_window + return causal_mask & window_mask + + +class TransformerBlock(nn.Module): + """ + Transformer block with attention and feed-forward layers. + """ + + def __init__(self, layer_id: int, model_args: GptOssModelArgs): + + super().__init__() + use_sliding_attention = layer_id % 2 == 0 + self.attention = Attention(model_args, use_sliding_attention=use_sliding_attention) + self.attention_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) + + self.moe = MoE(model_args) + + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + self.layer_id = layer_id + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + x = x + self.attention(self.attention_norm(x), freqs_cis) + x = x + self.moe(self.ffn_norm(x)) + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.moe.init_weights(self.weight_init_std, buffer_device) + + +class GptOssModel(nn.Module, ModelProtocol): + """ + GPT-OSS Transformer model with attention and feed-forward layers. 
+ """ + + def __init__(self, model_args: GptOssModelArgs): + super().__init__() + self.max_seq_len = model_args.max_seq_len + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.hidden_size) + self.register_buffer( + "freqs_cis", precompute_freqs_cis(model_args), persistent=True + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.num_hidden_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(torch.bfloat16) + # convert_submodules_to_bf16(self.layers[str(layer_id)]) + + self.norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) + self.output = nn.Linear( + model_args.hidden_size, + model_args.vocab_size, + dtype=torch.get_default_dtype(), + bias=False, + ) + self.model_args = model_args + self.init_weights() + # convert_submodules_to_bf16(self) + + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = precompute_freqs_cis(self.model_args) + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.hidden_size**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def forward(self, tokens: torch.Tensor): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len). + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). + """ + h = self.tok_embeddings(tokens) + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + h = self.norm(h) + output = self.output(h) + return output diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/experiments/gpt_oss/model/moe.py new file mode 100644 index 0000000000..c056819758 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/moe.py @@ -0,0 +1,324 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+import torch
+from torch.distributed.tensor import DTensor
+import torch.nn.functional as F
+from torch import nn
+from torchtitan.experiments.gpt_oss.infra.expert_parallel import expert_parallel
+
+from .args import GptOssModelArgs
+
+def swiglu(x, alpha: float = 1.702, limit: float = 7.0):
+    x_glu, x_linear = x[..., ::2], x[..., 1::2]
+    # Clamp the input values
+    x_glu = x_glu.clamp(min=None, max=limit)
+    x_linear = x_linear.clamp(min=-limit, max=limit)
+    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
+    # Note we add an extra bias of 1 to the linear layer
+    return out_glu * (x_linear + 1)
+
+class GroupedExperts(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_experts: int,
+        use_grouped_mm: bool,
+    ):
+        super().__init__()
+        self.num_experts = num_experts
+        self.use_grouped_mm = use_grouped_mm
+
+        self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, dim * 2)))
+        self.mlp1_bias = nn.Parameter(torch.empty((num_experts, dim * 2)))
+        self.mlp2_weight = nn.Parameter(torch.empty((num_experts, dim, dim)))
+        self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim)))
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        num_tokens_per_expert: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if self.use_grouped_mm:
+            return GroupedExperts._run_experts_grouped_mm(
+                self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, x, num_tokens_per_expert
+            )
+        else:
+            return GroupedExperts._run_experts_for_loop(
+                self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, x, num_tokens_per_expert
+            )
+
+    # TODO: keeping this for-loop implementation for comparison
+    # and readability, may remove later
+    @expert_parallel
+    @staticmethod
+    def _run_experts_for_loop(
+        mlp1_weight: torch.Tensor,
+        mlp1_bias: torch.Tensor,
+        mlp2_weight: torch.Tensor,
+        mlp2_bias: torch.Tensor,
+        x: torch.Tensor,
+        num_tokens_per_expert: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if num_tokens_per_expert is not None:
+            # NOTE: this would incur a synchronization between device and host
+            num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+            # side-effect code due to the usage of generate_permute_indices
+            num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+            # a tuple of tensors indexed by experts
+            # each with shape (tokens_per_expert(varying), dim)
+            x = torch.split(
+                x[: sum(num_tokens_per_expert)],
+                split_size_or_sections=num_tokens_per_expert,
+                dim=0,
+            )
+            out_experts_splits = []
+            for expert_idx, x_expert in enumerate(x):
+                h = torch.matmul(x_expert, mlp1_weight[expert_idx]) + mlp1_bias[expert_idx]
+                h = swiglu(h)
+                h = torch.matmul(h, mlp2_weight[expert_idx]) + mlp2_bias[expert_idx]
+                out_experts_splits.append(h)
+            out = torch.cat(out_experts_splits, dim=0)
+
+            # side-effect code due to the usage of generate_permute_indices
+            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+        else:
+            # x shape (num_experts, tokens_per_expert, dim)
+            h = torch.bmm(x, mlp1_weight) + mlp1_bias.unsqueeze(1)
+            h = swiglu(h)
+            out = torch.bmm(h, mlp2_weight) + mlp2_bias.unsqueeze(1)
+
+        return out
+
+    # @expert_parallel  # NOTE: EP currently reduces 20B MFU from 17.8% to 16.5%!
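# A pure-PyTorch reference (illustration only, not part of the diff) for the contract that
# _run_experts_grouped_mm below relies on: for a 2D input x whose rows are grouped by expert and
# delimited by the cumulative `offsets`, each contiguous group of rows is multiplied by that
# expert's weight. The actual torch._grouped_mm call fuses this into a single kernel; rows past
# offsets[-1] (padding, if any) are simply left at zero here.

import torch

def grouped_mm_reference(x: torch.Tensor, w: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    out = x.new_zeros(x.shape[0], w.shape[-1])
    start = 0
    for g in range(w.shape[0]):
        end = int(offsets[g])
        out[start:end] = x[start:end] @ w[g]
        start = end
    return out

x = torch.randn(7, 4)                                  # 7 routed tokens, dim = 4
w = torch.randn(3, 4, 8)                               # 3 experts, dim -> 2 * dim
offsets = torch.tensor([2, 5, 7], dtype=torch.int32)   # cumsum of tokens per expert
print(grouped_mm_reference(x, w, offsets).shape)       # torch.Size([7, 8])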
+ @staticmethod + def _run_experts_grouped_mm( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 + num_tokens_per_expert_long = num_tokens_per_expert.to(torch.long) + else: + offsets = None + # fall back to regular bmm between 3D tensors + assert x.dim() == 3 + + if isinstance(mlp1_weight, DTensor): + mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias = mlp1_weight.to_local(), mlp1_bias.to_local(), mlp2_weight.to_local(), mlp2_bias.to_local() + + h = torch._grouped_mm(x.bfloat16(), mlp1_weight.bfloat16(), offs=offsets) + if offsets is not None: + b1 = mlp1_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + tail_slack = x.shape[0] - int(offsets[-1]) + if tail_slack: + b1 = torch.cat([b1, b1.new_zeros((tail_slack, b1.shape[-1]))], dim=0) + h = h + b1.to(h.dtype) + + h = swiglu(h) + h = torch._grouped_mm(h, mlp2_weight.bfloat16(), offs=offsets) + if offsets is not None: + b2 = mlp2_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + tail_slack = x.shape[0] - int(offsets[-1]) + if tail_slack: + b2 = torch.cat([b2, b2.new_zeros((tail_slack, b2.shape[-1]))], dim=0) + h = h + b2.to(h.dtype) + + return h + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.mlp1_weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp1_bias, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp2_weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp2_bias, mean=0.0, std=init_std) + + def extra_repr(self): + return (f"num_experts={self.num_experts}, " + f"use_grouped_mm={self.use_grouped_mm}, " + f"mlp1_weight={tuple(self.mlp1_weight.shape)}, " + f"mlp1_bias={tuple(self.mlp1_bias.shape)}, " + f"mlp2_weight={tuple(self.mlp2_weight.shape)}, " + f"mlp2_bias={tuple(self.mlp2_bias.shape)}") + +class TokenChoiceTopKRouter(nn.Module): + """This class implements token-choice routing. In token-choice top-K routing, each token is + routed to top K experts based on the router scores. + + Args: + dim (int): Dimension of the input. + num_experts (int): Number of experts in each moe layer. + top_k (int): Number of experts each token will be routed to in token-choice routing. + """ + + def __init__( + self, + dim: int, + num_experts: int, + top_k: int, + ): + super().__init__() + + self.dim = dim + self.num_experts = num_experts + self.top_k = top_k + self.gate = nn.Linear(self.dim, self.num_experts, bias=True) + + def forward( + self, x: torch.Tensor, expert_bias: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + TODO: We haven't implement the group-based routing (node limit routing), + and currently EP is not supporting node limit routing yet. + + Args: + x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. + + Returns: + routed_input (torch.Tensor): + Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. + token_indices (torch.Tensor): + Token indices for routed_input with shape ``(bs*slen*top_k,)``. + num_tokens_per_expert (torch.Tensor): + Number of tokens assigned to each expert with shape ``(num_experts,)``. 
+ """ + # scores shape (bs*slen, num_experts) + router_logits = self.gate(x) + + if expert_bias is not None: + router_logits = router_logits + expert_bias + + # top scores shape (bs*slen, top_k) + top_scores, selected_experts_indices = torch.topk( + router_logits, k=self.top_k, dim=1 + ) + + top_scores = F.softmax(top_scores, dim=1) + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + + # Reorder the token indices to match the order of the experts + # token_indices_experts_sorted shape (bs*slen*top_k,) + token_indices_experts_sorted = torch.argsort( + selected_experts_indices.view(-1), stable=True + ) + + # reorder the scores to match the order of the token indices + top_scores = top_scores.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k + + return top_scores, token_indices_experts_sorted, num_tokens_per_expert + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) + + +class MoE(nn.Module): + def __init__(self, model_args: GptOssModelArgs): + + super().__init__() + dim = model_args.hidden_size + + num_experts = model_args.num_local_experts + top_k = model_args.num_experts_per_tok + + self.experts = GroupedExperts( + dim=dim, + num_experts=num_experts, + use_grouped_mm=model_args.use_grouped_mm, + ) + self.router = TokenChoiceTopKRouter( + dim=dim, + num_experts=num_experts, + top_k=top_k, + ) + self.load_balance_coeff = model_args.load_balance_coeff + if self.load_balance_coeff is not None: + assert self.load_balance_coeff > 0.0 + self.register_buffer( + "expert_bias", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + self.register_buffer( + "tokens_per_expert", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + else: + self.expert_bias = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. + + Returns: + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. 
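
The router forward above implements token-choice top-k routing: take the top-k experts per token, softmax only over those k scores, count tokens per expert with `histc`, and use a stable `argsort` of the flattened expert ids to group the (token, slot) pairs by expert. A tiny standalone sketch with illustrative sizes:

import torch
import torch.nn.functional as F

num_experts, top_k = 3, 2
router_logits = torch.randn(4, num_experts)              # 4 tokens

top_scores, selected = torch.topk(router_logits, k=top_k, dim=1)
top_scores = F.softmax(top_scores, dim=1)                # normalize over the k chosen experts only

# histc is only guaranteed for floating point inputs, hence the cast
num_tokens_per_expert = torch.histc(
    selected.view(-1).float(), bins=num_experts, min=0, max=num_experts
)
order = torch.argsort(selected.view(-1), stable=True)    # group (token, slot) pairs by expert id
token_indices = order // top_k                           # flattened layout is token-major

assert int(num_tokens_per_expert.sum()) == 4 * top_k
assert token_indices.shape == (4 * top_k,)
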
+ """ + bs, slen, dim = x.shape + + # top_scores and selected_indices shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + token_indices, + num_tokens_per_expert, + ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) + + if self.load_balance_coeff is not None and torch.is_grad_enabled(): + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) + + # shape (bs*slen*top_k, dim) + token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather( + x.view(-1, dim), + dim=0, + index=token_indices, + ) + + # shape (bs*slen*top_k, dim) + routed_output = self.experts(routed_input, num_tokens_per_expert) + + routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( + x.dtype + ) + + out = torch.zeros_like(x.reshape(bs * slen, dim)) + + # Accumulate multiple expert results becase each token can be routed to multiple experts + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.reshape(bs, slen, dim) + return out + + def init_weights( + self, + init_std: float, + buffer_device: torch.device, + ): + self.experts.init_weights(init_std) + self.router.init_weights(init_std) + if self.load_balance_coeff is not None: + with torch.device(buffer_device): + self.expert_bias = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) + self.tokens_per_expert = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) diff --git a/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py b/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py new file mode 100644 index 0000000000..dbbb880af5 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py @@ -0,0 +1,405 @@ +""" +Compare logits and generations of GPT-OSS implemented in TorchTitan and HuggingFace. +This requires at least a 2xH100. 
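
The MoE forward above dispatches with `gather` (each token is copied once per expert slot it was routed to) and combines with `scatter_add` (the score-weighted expert outputs are summed back per token). A toy round trip with the expert computation replaced by a per-slot scale (all values illustrative):

import torch

dim = 4
x = torch.randn(3, dim)                               # 3 tokens
token_indices = torch.tensor([0, 0, 1, 2, 2, 1])      # expert-sorted slot -> source token (top_k = 2)
scores = torch.rand(6)                                # routing weight of each slot

idx = token_indices.reshape(-1, 1).expand(-1, dim)
routed_input = torch.gather(x, dim=0, index=idx)      # dispatch: copy each token into its slots
routed_output = routed_input * scores.unsqueeze(-1)   # stand-in for the per-expert computation

out = torch.zeros_like(x).scatter_add(dim=0, index=idx, src=routed_output)
# each output row is the score-weighted sum of that token's expert outputs
assert torch.allclose(out[0], x[0] * (scores[0] + scores[1]))
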
+ +First ensure you convert the HF model to a TorchTitan DCP checkpoint: +uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py hf-to-dcp --input-path openai/gpt-oss-20b --output-path gptoss_dcp/ + +Then you can run a comparison like this: +uv run torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py \ + --tt_config torchtitan/models/gpt_oss/train_configs/gpt_oss_20b.toml \ + --tt_checkpoint_path gptoss_dcp/ \ + --hf_model_path openai/gpt-oss-20b \ + --prompt "Once upon a time, in a land far away," \ + --temperature 0.8 \ + --max_new_tokens 256 \ + --batch_size 1 \ + --out +""" + +import json +import os +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Sequence, Tuple, NamedTuple + +import torch +import torch.nn as nn +import torch.distributed.checkpoint as dcp +import tyro +from transformers import AutoModelForCausalLM, AutoTokenizer + +from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.utils import device_module, device_type +from torchtitan.components.metrics import build_device_memory_monitor +from torchtitan.config_manager import ConfigManager +from torchtitan.protocols.train_spec import get_train_spec +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torch.distributed import DeviceMesh +from torch.distributed.elastic.multiprocessing.errors import record + +# -------- Torchtitan Sampling Utils -------- +def multinomial_sample_one( + probs: torch.Tensor, rng: Optional[torch.Generator] = None +) -> torch.Tensor: + q = torch.empty_like(probs).exponential_(1, generator=rng) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) + + +def logits_to_probs( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, +) -> torch.Tensor: + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) + pivot = v.select(dim=-1, index=-1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def generate_next_token( + model, + x: torch.Tensor, + *, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, +) -> torch.Tensor: + logits = model(x) # (B, T, vocab_size) + probs = logits_to_probs(logits[:, -1, :], temperature, top_k) + next_token = multinomial_sample_one(probs, rng=rng) + return next_token + + +@torch.no_grad() +def tt_generate_text( + model, + input_ids: torch.Tensor, + *, + max_new_tokens: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + seed: Optional[int] = None, +) -> torch.Tensor: + # ensure batch dimension (T,) --> (B, T) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + rng = None + if seed is not None: + rng = torch.Generator(input_ids.device).manual_seed(seed) + + generated_tokens = input_ids.clone() + + for i in range(max_new_tokens): + next_token = generate_next_token( + model, + x=generated_tokens.to(input_ids.device), + temperature=temperature, + top_k=top_k, + rng=rng, + ) + print(f"generated token {i}: {next_token}") + + generated_tokens = torch.cat([generated_tokens, next_token], dim=1) + + return generated_tokens + +@dataclass +class GenerateConfig: + """Configuration for test generation.""" + hf_model_path: Optional[str] = None + """HuggingFace model path to load (if provided).""" + tt_config: Optional[str] = None + """TOML config file path for TorchTitan 
model.""" + tt_checkpoint_path: Optional[str] = None + """Checkpoint path for the TorchTitan model (if provided).""" + tt_tokenizer_path: Optional[str] = "libs/torchtitan/torchtitan/models/gpt_oss_20b/tokenizer" + """Tokenizer path to load.""" + temperature: float = 1.0 + """Sampling temperature (0 for greedy).""" + max_new_tokens: int = 32 + """Max number of tokens to generate.""" + batch_size: int = 1 + """Batch size for inputs.""" + top_k: Optional[int] = None + """Top-k sampling (optional).""" + seed: Optional[int] = None + """Random seed for reproducibility.""" + deterministic: bool = False + """Use deterministic algorithms.""" + prompt: str = "" + """Input prompt string.""" + out: bool = False + """If true, print JSON report at end.""" + + +class LogitsComparison(NamedTuple): + max_abs_diff: float + mean_abs_diff: float + max_rel_diff: float + mean_rel_diff: float + allclose_results: Sequence[Tuple[float, float, str, bool]] + sample_diffs: Optional[torch.Tensor] + systematic_offset: Optional[Tuple[float, float]] + + +def load_hf_model(path: str, device: torch.device) -> nn.Module: + model = AutoModelForCausalLM.from_pretrained(path).to(device) + model.eval() + return model + +def print_param_dtypes_first_block(model): + """ + Prints the dtype of every parameter in the given model. + For any parameters under a 'layers' module (e.g., layers.), + only prints those from the first block (idx == "0"). + This works for both GptOssForCausalLM (with a .model submodule) + and GptOssModel architectures. + """ + for name, param in model.named_parameters(): + parts = name.split('.') + # If this parameter is under a 'layers' module, check its index + if 'layers' in parts: + idx = parts.index('layers') + 1 + if idx < len(parts) and parts[idx] != '0': + continue + print(f"{name:50s} → {param.dtype}") + +def get_logits(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + out = model(input_ids) + if hasattr(out, "logits"): + return out.logits + else: + return out + + +def compare_logits( + tt_logits: torch.Tensor, + hf_logits: torch.Tensor, + tolerances: Sequence[Tuple[float, float, str]] = ( + (1e-4, 1e-6, "Very Strict"), + (1e-2, 1e-4, "Strict"), + (1e-1, 1e-2, "Moderate"), + ), +) -> LogitsComparison: + # Apply softmax to convert logits to probabilities + hf_logits = torch.nn.functional.softmax(hf_logits.float(), dim=-1) + tt_logits = torch.nn.functional.softmax(tt_logits.float(), dim=-1) + + diff = torch.abs(tt_logits - hf_logits) + max_abs = float(torch.max(diff)) + mean_abs = float(torch.mean(diff)) + rel = diff / (torch.abs(tt_logits) + 1e-8) + max_rel = float(torch.max(rel)) + mean_rel = float(torch.mean(rel)) + + results = [] + any_match = False + for rtol, atol, name in tolerances: + match = torch.allclose(tt_logits, hf_logits, rtol=rtol, atol=atol) + results.append((rtol, atol, name, bool(match))) + if match: + any_match = True + break + + sample_diffs = None + sys_offset = None + if not any_match: + flat = (tt_logits - hf_logits).flatten() + sample_diffs = flat[:25] + sys_offset = (float(torch.mean(flat)), float(torch.std(flat))) + + return LogitsComparison(max_abs, mean_abs, max_rel, mean_rel, results, sample_diffs, sys_offset) + + +def generate_text( + model: nn.Module, + input_ids: torch.Tensor, + max_new_tokens: int, + temperature: float = 0.0, + top_k: Optional[int] = None, +) -> torch.Tensor: + do_sample = temperature > 0 + temp_arg = temperature if do_sample else None + with torch.no_grad(): + return model.generate( + input_ids, + 
max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temp_arg, + top_k=top_k, + ) + + +def print_logits_comparison(comp: LogitsComparison): + print("\n" + "="*70) + print("LOGITS COMPARISON") + print("="*70) + print(f"Max abs diff: {comp.max_abs_diff:.6f}") + print(f"Mean abs diff: {comp.mean_abs_diff:.6f}") + print(f"Max rel diff: {comp.max_rel_diff:.6f}") + print(f"Mean rel diff: {comp.mean_rel_diff:.6f}\n") + print("Tolerance tests:") + for rtol, atol, name, match in comp.allclose_results: + print(f" {'āœ…' if match else 'āŒ'} {name} (rtol={rtol}, atol={atol})") + if comp.sample_diffs is not None: + print("\nšŸ” Sample diffs (first 25):") + for v in comp.sample_diffs.tolist(): + print(f" {v:.6f}") + mean, std = comp.systematic_offset + print(f"\nSystematic offset: mean={mean:.6f}, std={std:.6f}") + + +def print_generation(title: str, outputs: torch.Tensor, tokenizer): + text = tokenizer.decode(outputs[0].tolist()) + print("\n" + "="*60) + print(title) + print("="*60) + print(text) + print("="*60) + + +def print_generation_comparison( + tt_out: torch.Tensor, + hf_out: torch.Tensor, + tokenizer, + prompt_len: int, +): + tt_tokens = tt_out[0][prompt_len:].tolist() + hf_tokens = hf_out[0][prompt_len:].tolist() + n = min(len(tt_tokens), len(hf_tokens)) + matches = sum(1 for i in range(n) if tt_tokens[i] == hf_tokens[i]) + print("\n" + "="*70) + print("GENERATION COMPARISON") + print("="*70) + print(f"Match rate: {matches}/{n} ({matches/n*100:.1f}%)") + if matches != n or len(tt_tokens) != len(hf_tokens): + print("First mismatches:") + for i in range(min(10, n)): + if tt_tokens[i] != hf_tokens[i]: + tt_txt = tokenizer.decode([tt_tokens[i]]) + hf_txt = tokenizer.decode([hf_tokens[i]]) + print(f" Pos {i}: TT='{tt_txt}' vs HF='{hf_txt}'") + + +@record +def test_generate(args: GenerateConfig): + init_logger() + + if not args.hf_model_path and not args.tt_config: + raise ValueError("Either hf_model_path or tt_config must be provided.") + if not args.prompt: + logger.warning("Empty prompt; generating from scratch.") + + # --- Common setup: tokenizer & inputs --- + if args.hf_model_path: + tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path) + input_ids = tokenizer.encode(args.prompt, add_special_tokens=False, return_tensors="pt") + print(input_ids) + if args.tt_config: + config_mgr = ConfigManager() + config = config_mgr.parse_args([ + f"--job.config_file={args.tt_config}", + f"--model.tokenizer_path={args.tt_tokenizer_path}", + ]) + train_spec = get_train_spec(config.model.name) + + # --- HuggingFace model (optional) --- + hf_model = None + hf_logits = None + hf_out = None + if args.hf_model_path: # NOTE: comment this block out for rapid tt testing + hf_device = torch.device(f"{device_type}:0") + hf_model = load_hf_model(args.hf_model_path, hf_device) + print("\n" + "="*60) + print("HUGGINGFACE MODEL ARCHITECTURE:") + print(hf_model) + print("="*60) + print_param_dtypes_first_block(hf_model) + print("="*60) + + hf_in = input_ids.to(hf_device) + hf_logits = get_logits(hf_model, hf_in).to(input_ids.device) + print(f"hf_logits: {hf_logits[:, :, 42069:42072]}") + hf_out = generate_text( + hf_model, hf_in, + max_new_tokens=args.max_new_tokens, + temperature=0.0, + top_k=args.top_k, + ).to(input_ids.device) + + # --- TorchTitan model (optional) --- + tt_model = None + tt_logits = None + tt_out = None + if args.tt_config: + # (Original TT setup: distributed, device, checkpoint load, etc.) 
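
The sampling helpers earlier in this script avoid an explicit `torch.multinomial`: `multinomial_sample_one` divides the probabilities by i.i.d. Exp(1) noise and takes the argmax, which selects index i with probability probs[i] (equivalent to the Gumbel-max trick, since -log of an Exp(1) variable is Gumbel-distributed). A quick empirical check with an arbitrary sample count and tolerance:

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.6, 0.3])

# 20k draws with the exponential-race trick
q = torch.empty(20000, probs.numel()).exponential_(1.0)
draws = torch.argmax(probs / q, dim=-1)

freq = torch.bincount(draws, minlength=probs.numel()).float() / draws.numel()
assert torch.allclose(freq, probs, atol=0.02)   # empirical frequencies match probs
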
+ world_size = int(os.environ.get("WORLD_SIZE", 1)) + device = torch.device(f"{device_type}:1") + device_module.set_device(device) + dist_utils.set_determinism(None, device, args.seed, args.deterministic) + + # instantiate & load TT model + model_args = train_spec.config[config.model.flavor] + model_args.update_from_config(config, tokenizer) + init_dev = "meta" if world_size > 1 else device + with torch.device(init_dev): + tt_model = train_spec.cls(model_args) + if world_size > 1: + # parallelize if needed + pass + print("\n" + "="*60) + print("TORCHTITAN MODEL ARCHITECTURE:") + print(tt_model) + print("="*60) + print_param_dtypes_first_block(tt_model) + print("="*60) + + tt_model.eval() + if args.tt_checkpoint_path: # only load checkpoint if provided + tt_state = tt_model.state_dict() + tt_state.pop("freqs_cis", None) + state = {"model": tt_state} + dcp.load(state, checkpoint_id=args.tt_checkpoint_path) + + tt_logits = get_logits(tt_model, input_ids.to(device)).to(hf_logits.device if hf_logits is not None else device) + print(f"āœ… Torchtitan model forward pass succeeded: {tt_logits.shape=}") + print(f"tt_logits: {tt_logits[:, :, 42069:42072]}") + + tt_out = tt_generate_text( + tt_model, input_ids.to(device), + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=args.top_k, + seed=args.seed, + ) + + # --- Logits comparison (if both present) --- + if hf_logits is not None and tt_logits is not None: + comp = compare_logits(tt_logits, hf_logits) + print_logits_comparison(comp) + + # --- Print generations --- + if hf_out is not None: + print_generation("HUGGINGFACE MODEL OUTPUT:", hf_out, tokenizer) + if tt_out is not None: + print_generation("TORCHTITAN MODEL OUTPUT:", tt_out, tokenizer) + + # --- Generation comparison --- + if hf_out is not None and tt_out is not None: + prompt_len = input_ids.size(1) + print_generation_comparison(tt_out, hf_out, tokenizer, prompt_len) + + +if __name__ == "__main__": + args = tyro.cli(GenerateConfig) + test_generate(args) diff --git a/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py b/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py new file mode 100644 index 0000000000..59c15ab944 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py @@ -0,0 +1,661 @@ +""" +Convert checkpoints between TorchTitan and HuggingFace. 
+ +# Convert HF to TorchTitan DCP +uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py hf-to-dcp --input-path openai/gpt-oss-20b --output-path gptoss_dcp/ + +# Convert TorchTitan DCP to HF +uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py dcp-to-hf --input-path gptoss_dcp/ --output-path gptoss_hf/ +""" + +import re +import tempfile +from pathlib import Path +from typing import Union, Tuple, Optional + +import torch +import torch.distributed.checkpoint as DCP +from torch.distributed.checkpoint.format_utils import dcp_to_torch_save +from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner +from torch.distributed.checkpoint.state_dict_loader import _load_state_dict +from torchtitan.datasets.transformation import get_tokenizer_with_chat_template +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaConfig +from torchtitan.models.llama3.model import precompute_freqs_cis +from tqdm import tqdm +from tyro.extras import SubcommandApp + +from torchtitan.tools.logging import init_logger, logger + +app = SubcommandApp() + + + +def validate_hf_keys(hf_state_dict, model_config, model_name): + """Validate that all expected weight keys exist in the HF state dict.""" + missing_keys = [] + n_layers = model_config.num_hidden_layers + + # Check basic weights + required_keys = [ + "model.embed_tokens.weight", + "lm_head.weight", + "model.norm.weight" + ] + + for key in required_keys: + if key not in hf_state_dict: + missing_keys.append(key) + + # Check layer weights + for layer_idx in range(n_layers): + layer_prefix = f'model.layers.{layer_idx}' + + # Check attention weights + attention_keys = [ + f"{layer_prefix}.self_attn.q_proj.weight", + f"{layer_prefix}.self_attn.k_proj.weight", + f"{layer_prefix}.self_attn.v_proj.weight", + f"{layer_prefix}.self_attn.o_proj.weight", + f"{layer_prefix}.self_attn.q_proj.bias", + f"{layer_prefix}.self_attn.k_proj.bias", + f"{layer_prefix}.self_attn.v_proj.bias", + f"{layer_prefix}.self_attn.o_proj.bias", + f"{layer_prefix}.input_layernorm.weight", + f"{layer_prefix}.post_attention_layernorm.weight", + ] + + for key in attention_keys: + if key not in hf_state_dict: + missing_keys.append(key) + + # Check MoE weights + mlp_keys = [ + f"{layer_prefix}.mlp.router.weight", + f"{layer_prefix}.mlp.router.bias", + f"{layer_prefix}.mlp.experts.gate_up_proj", + f"{layer_prefix}.mlp.experts.gate_up_proj_bias", + f"{layer_prefix}.mlp.experts.down_proj", + f"{layer_prefix}.mlp.experts.down_proj_bias", + ] + + for key in mlp_keys: + if key not in hf_state_dict: + missing_keys.append(key) + + if missing_keys: + logger.error(f"Missing {len(missing_keys)} expected weight keys in HF model:") + for key in missing_keys[:10]: # Show first 10 + logger.error(f" - {key}") + if len(missing_keys) > 10: + logger.error(f" ... and {len(missing_keys) - 10} more") + + # Try to diagnose the issue + logger.info("Available keys in HF model:") + available_keys = list(hf_state_dict.keys()) + for key in available_keys[:20]: # Show first 20 + logger.info(f" - {key}") + if len(available_keys) > 20: + logger.info(f" ... and {len(available_keys) - 20} more") + + raise ValueError(f"HF model '{model_name}' is missing expected weight keys. 
" + f"This suggests the model architecture doesn't match expectations.") + + logger.info(f"āœ“ Weight key validation passed - found all expected keys") + + +def map_hf_to_torchtitan(hf_state_dict, model_config, max_seq_len=131072, rope_theta=500000.0, model_name="meta-llama/Llama-3.1-8B"): + """Map HuggingFace state dict to TorchTitan format. + + Note: TorchTitan and HuggingFace use different RoPE implementations: + - TorchTitan: Adjacent element pairing with complex arithmetic + - HuggingFace: First/second half pairing with cos/sin arithmetic + + This difference is architectural, not a bug. Converted models will have + slightly different positional encoding but typically minimal impact on performance. + """ + + # Validate that all expected keys exist + validate_hf_keys(hf_state_dict, model_config, model_name) + + n_layers = model_config.num_hidden_layers + n_heads = model_config.num_attention_heads + dim = model_config.hidden_size + dims_per_head = dim // n_heads + + # Fix: Corrected model family detection logic + if "llama" in model_name.lower(): + model_family = "llama3" + elif "qwen" in model_name.lower(): + model_family = "qwen3" + max_seq_len = model_config.max_position_embeddings + rope_theta = model_config.rope_theta + elif "gpt-oss" in model_name.lower(): + model_family = "gptoss" + max_seq_len = model_config.max_position_embeddings + rope_theta = model_config.rope_theta + else: + raise ValueError(f"Unsupported HuggingFace model for conversion: {model_name}") + + # Determine n_kv_heads for GQA models + n_kv_heads = model_config.num_key_value_heads + head_dim = model_config.head_dim + print(f"Model info: dim={dim}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, head_dim={head_dim}, model_family={model_family}, max_seq_len={max_seq_len}, rope_theta={rope_theta}") + torchtitan_state_dict = {} + + # Convert embeddings and output + torchtitan_state_dict["tok_embeddings.weight"] = hf_state_dict["model.embed_tokens.weight"].clone() + torchtitan_state_dict["output.weight"] = hf_state_dict["lm_head.weight"].clone() + torchtitan_state_dict["norm.weight"] = hf_state_dict["model.norm.weight"].clone() + + def permute(w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return w.view(n_heads_arg, 2, dim1 // n_heads_arg // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + # Convert layers + for layer_idx in tqdm(range(n_layers), desc="Converting layers"): + hf_layer_prefix = f'model.layers.{layer_idx}' + layer_prefix = f'layers.{layer_idx}' + + wq = hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wq.weight'] = wq.clone() + wq_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wq.bias'] = wq_bias.clone() + + wk = hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wk.weight'] = wk.clone() + wk_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wk.bias'] = wk_bias.clone() + + wv = hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wv.weight'] = wv.clone() + wv_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wv.bias'] = wv_bias.clone() + + wo = hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wo.weight'] 
= wo.clone() + wo_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wo.bias'] = wo_bias.clone() + + sinks = hf_state_dict[f'{hf_layer_prefix}.self_attn.sinks'] + torchtitan_state_dict[f'{layer_prefix}.attention.sinks'] = sinks.clone() + + # MoE weights + mlp1 = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.gate_up_proj'] + torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp1_weight'] = mlp1.clone() + + mlp1_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.gate_up_proj_bias'] + torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp1_bias'] = mlp1_bias.clone() + + mlp2 = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.down_proj'] + torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp2_weight'] = mlp2.clone() + + mlp2_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.down_proj_bias'] + torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp2_bias'] = mlp2_bias.clone() + + # router + gate = hf_state_dict[f'{hf_layer_prefix}.mlp.router.weight'] + torchtitan_state_dict[f'{layer_prefix}.moe.router.gate.weight'] = gate.clone() + router_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.router.bias'] + torchtitan_state_dict[f'{layer_prefix}.moe.router.gate.bias'] = router_bias.clone() + + # # @vwxyzjn: This is technically not needed, but we added here because we haven't figured out + # # how to tell torchtitan to ignore this parameter. + # tokens_per_expert = torch.zeros_like(expert_bias) + # torchtitan_state_dict[f'{layer_prefix}.moe.tokens_per_expert'] = tokens_per_expert.clone() + + # Layer norms + attention_norm = hf_state_dict[f'{hf_layer_prefix}.input_layernorm.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention_norm.weight'] = attention_norm.clone() + ffn_norm = hf_state_dict[f'{hf_layer_prefix}.post_attention_layernorm.weight'] + torchtitan_state_dict[f'{layer_prefix}.ffn_norm.weight'] = ffn_norm.clone() + + # Precompute RoPE frequencies + # NOTE: we no longer precompute RoPE frequencies in TorchTitan + # this `model_config` is HF but needs to be TT (to include e.g. beta_fast) + # torchtitan_state_dict["freqs_cis"] = precompute_freqs_cis(model_config) + + print(f"Converted {len(torchtitan_state_dict)} parameters from HuggingFace to TorchTitan format") + return torchtitan_state_dict + + +def map_torchtitan_to_hf_per_param(name: str, weight: torch.Tensor, model_family: str = "llama3") -> Tuple[Optional[str], Optional[torch.Tensor]]: + """Map a single TorchTitan parameter to HuggingFace format. 
+ + Args: + name: Parameter name in TorchTitan format + weight: Parameter tensor + model_family: Model family ("llama3", "qwen3", or "gptoss") + + Returns: + Tuple of (hf_name, hf_weight) or (None, None) if parameter should be skipped + """ + # Skip freqs_cis as it's computed dynamically in HF + if name == "freqs_cis": + return None, None + + assert model_family in ("llama3", "qwen3", "gptoss"), f"Unsupported model family: {model_family}" + + # HuggingFace permutation function (exact copy from their conversion script) + def permute(w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + # Handle embeddings and output weights + if name == "tok_embeddings.weight": + return "model.embed_tokens.weight", weight.clone() + elif name == "output.weight": + return "lm_head.weight", weight.clone() + elif name == "norm.weight": + return "model.norm.weight", weight.clone() + + # Handle layer-specific parameters + layer_match = re.match(r"layers\.(\d+)\.", name) + if not layer_match: + return None, None + + layer_idx = layer_match.group(1) + layer_suffix = name[len(f"layers.{layer_idx}."):] + hf_layer_prefix = f"model.layers.{layer_idx}" + + if model_family == "gptoss": + mapping = { + "attention.wq.weight": "self_attn.q_proj.weight", + "attention.wq.bias": "self_attn.q_proj.bias", + "attention.wk.weight": "self_attn.k_proj.weight", + "attention.wk.bias": "self_attn.k_proj.bias", + "attention.wv.weight": "self_attn.v_proj.weight", + "attention.wv.bias": "self_attn.v_proj.bias", + "attention.wo.weight": "self_attn.o_proj.weight", + "attention.wo.bias": "self_attn.o_proj.bias", + "attention.sinks": "self_attn.sinks", + "moe.experts.mlp1_weight": "mlp.experts.gate_up_proj", + "moe.experts.mlp1_bias": "mlp.experts.gate_up_proj_bias", + "moe.experts.mlp2_weight": "mlp.experts.down_proj", + "moe.experts.mlp2_bias": "mlp.experts.down_proj_bias", + "moe.router.gate.weight": "mlp.router.weight", + "moe.router.gate.bias": "mlp.router.bias", + "moe.expert_bias": "mlp.router.bias", # NOTE: this gets added into router bias + "attention_norm.weight": "input_layernorm.weight", + "ffn_norm.weight": "post_attention_layernorm.weight", + } + hf_suffix = mapping.get(layer_suffix) + if hf_suffix: + return f"{hf_layer_prefix}.{hf_suffix}", weight.clone() + return None, None + + # Handle attention weights + if layer_suffix == "attention.wq.weight": + if model_family == "llama3": + # For query weights, assume standard head_dim=128 + dim = weight.shape[1] + head_dim = 128 + n_heads = dim // head_dim + transformed_weight = permute(weight, n_heads) + else: + transformed_weight = weight + return f"{hf_layer_prefix}.self_attn.q_proj.weight", transformed_weight.clone() + + elif layer_suffix == "attention.wk.weight": + if model_family == "llama3": + # For key weights, infer n_kv_heads from weight shape + dim = weight.shape[1] + head_dim = 128 + n_kv_heads = weight.shape[0] // head_dim + key_value_dim = n_kv_heads * head_dim + transformed_weight = permute(weight, n_kv_heads, key_value_dim, dim) + else: + transformed_weight = weight + return f"{hf_layer_prefix}.self_attn.k_proj.weight", transformed_weight.clone() + + elif layer_suffix == "attention.wv.weight": + return f"{hf_layer_prefix}.self_attn.v_proj.weight", weight.clone() + + elif layer_suffix == "attention.wo.weight": + return f"{hf_layer_prefix}.self_attn.o_proj.weight", weight.clone() + + # Handle qwen3-specific 
attention norms + elif layer_suffix == "attention.q_norm.weight" and model_family == "qwen3": + return f"{hf_layer_prefix}.self_attn.q_norm.weight", weight.clone() + + elif layer_suffix == "attention.k_norm.weight" and model_family == "qwen3": + return f"{hf_layer_prefix}.self_attn.k_norm.weight", weight.clone() + + # Handle MLP weights + elif layer_suffix == "feed_forward.w1.weight": + return f"{hf_layer_prefix}.mlp.gate_proj.weight", weight.clone() + + elif layer_suffix == "feed_forward.w2.weight": + return f"{hf_layer_prefix}.mlp.down_proj.weight", weight.clone() + + elif layer_suffix == "feed_forward.w3.weight": + return f"{hf_layer_prefix}.mlp.up_proj.weight", weight.clone() + + # Handle layer norms + elif layer_suffix == "attention_norm.weight": + return f"{hf_layer_prefix}.input_layernorm.weight", weight.clone() + + elif layer_suffix == "ffn_norm.weight": + return f"{hf_layer_prefix}.post_attention_layernorm.weight", weight.clone() + + # If no mapping found, return None + return None, None + + +def map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len=131072, rope_theta=500000.0): + """Map TorchTitan state dict to HuggingFace format.""" + if any(k.endswith('.attention.q_norm.weight') for k in torchtitan_state_dict): + model_family = 'qwen3' + elif any(k.endswith('.attention.wq.bias') for k in torchtitan_state_dict): + model_family = 'gptoss' + else: + model_family = 'llama3' + + layer_keys = [k for k in torchtitan_state_dict.keys() if k.startswith("layers.")] + assert len(layer_keys) > 0, "No layers found in state dict" + n_layers = max([int(k.split(".")[1]) for k in layer_keys]) + 1 + hf_state_dict = {} + + # Get model info from sample weight + sample_wq_key = next(k for k in torchtitan_state_dict.keys() if k.endswith('.attention.wq.weight')) + wq_weight = torchtitan_state_dict[sample_wq_key] + dim = wq_weight.shape[1] # input dimension + + # Check if we have a key weight to determine n_kv_heads + sample_wk_key = next(k for k in torchtitan_state_dict.keys() if k.endswith('.attention.wk.weight')) + wk_weight = torchtitan_state_dict[sample_wk_key] + + # Standard Llama head dim is 128 for the 3B, 8B, 70B and 405B models + # NOTE: The only exception is the 1B model: https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json#L9 + # But let's ignore that for now + head_dim = 128 + n_heads = dim // head_dim + + # For GQA models, n_kv_heads might be different + n_kv_heads = wk_weight.shape[0] // head_dim + + print(f"Model info: dim={dim}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, head_dim={head_dim}, model_family={model_family}") + + # HuggingFace permutation function (exact copy from their conversion script) + def permute(w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + # Convert embeddings and output (no permutation needed) + if 'tok_embeddings.weight' in torchtitan_state_dict: + hf_state_dict['model.embed_tokens.weight'] = torchtitan_state_dict['tok_embeddings.weight'].clone() + if 'output.weight' in torchtitan_state_dict: + hf_state_dict['lm_head.weight'] = torchtitan_state_dict['output.weight'].clone() + if 'norm.weight' in torchtitan_state_dict: + hf_state_dict['model.norm.weight'] = torchtitan_state_dict['norm.weight'].clone() + + # Convert layers + for layer_idx in tqdm(range(n_layers), desc="Converting layers"): + layer_prefix = f'layers.{layer_idx}' + hf_layer_prefix = f'model.layers.{layer_idx}' 
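
`map_torchtitan_to_hf_per_param` above boils down to a regex on the layer index plus a suffix lookup table. A condensed sketch using a small subset of the gpt-oss mapping (illustrative only):

import re

mapping = {
    "attention.wq.weight": "self_attn.q_proj.weight",
    "moe.experts.mlp1_weight": "mlp.experts.gate_up_proj",
    "moe.router.gate.weight": "mlp.router.weight",
}

def tt_to_hf(name: str) -> str | None:
    m = re.match(r"layers\.(\d+)\.(.+)", name)
    if not m:
        return None                      # e.g. tok_embeddings.weight is handled separately
    hf_suffix = mapping.get(m.group(2))
    return f"model.layers.{m.group(1)}.{hf_suffix}" if hf_suffix else None

assert tt_to_hf("layers.3.moe.experts.mlp1_weight") == "model.layers.3.mlp.experts.gate_up_proj"
assert tt_to_hf("norm.weight") is None
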
+ + if model_family == 'gptoss': + # Attention projections and biases + mappings = { + f'{layer_prefix}.attention.wq.weight': f'{hf_layer_prefix}.self_attn.q_proj.weight', + f'{layer_prefix}.attention.wq.bias': f'{hf_layer_prefix}.self_attn.q_proj.bias', + f'{layer_prefix}.attention.wk.weight': f'{hf_layer_prefix}.self_attn.k_proj.weight', + f'{layer_prefix}.attention.wk.bias': f'{hf_layer_prefix}.self_attn.k_proj.bias', + f'{layer_prefix}.attention.wv.weight': f'{hf_layer_prefix}.self_attn.v_proj.weight', + f'{layer_prefix}.attention.wv.bias': f'{hf_layer_prefix}.self_attn.v_proj.bias', + f'{layer_prefix}.attention.wo.weight': f'{hf_layer_prefix}.self_attn.o_proj.weight', + f'{layer_prefix}.attention.wo.bias': f'{hf_layer_prefix}.self_attn.o_proj.bias', + f'{layer_prefix}.attention.sinks': f'{hf_layer_prefix}.self_attn.sinks', + f'{layer_prefix}.moe.experts.mlp1_weight': f'{hf_layer_prefix}.mlp.experts.gate_up_proj', + f'{layer_prefix}.moe.experts.mlp1_bias': f'{hf_layer_prefix}.mlp.experts.gate_up_proj_bias', + f'{layer_prefix}.moe.experts.mlp2_weight': f'{hf_layer_prefix}.mlp.experts.down_proj', + f'{layer_prefix}.moe.experts.mlp2_bias': f'{hf_layer_prefix}.mlp.experts.down_proj_bias', + f'{layer_prefix}.moe.router.gate.weight': f'{hf_layer_prefix}.mlp.router.weight', + f'{layer_prefix}.attention_norm.weight': f'{hf_layer_prefix}.input_layernorm.weight', + f'{layer_prefix}.ffn_norm.weight': f'{hf_layer_prefix}.post_attention_layernorm.weight', + } + for tt_key, hf_key in mappings.items(): + if tt_key in torchtitan_state_dict: + hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone() + # Combine router gate bias with expert bias (if present) + router_bias_key = f'{layer_prefix}.moe.router.gate.bias' + expert_bias_key = f'{layer_prefix}.moe.expert_bias' + if ( + router_bias_key in torchtitan_state_dict + or expert_bias_key in torchtitan_state_dict + ): + if router_bias_key in torchtitan_state_dict: + bias = torchtitan_state_dict[router_bias_key].clone() + else: + bias = torch.zeros_like(torchtitan_state_dict[expert_bias_key]) + if expert_bias_key in torchtitan_state_dict: + bias = bias + torchtitan_state_dict[expert_bias_key] + hf_state_dict[f'{hf_layer_prefix}.mlp.router.bias'] = bias + continue + + # Attention weights with proper permutation + if f'{layer_prefix}.attention.wq.weight' in torchtitan_state_dict: + wq = torchtitan_state_dict[f'{layer_prefix}.attention.wq.weight'] + if model_family == "llama3": + wq = permute(wq, n_heads) + hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.weight'] = wq.clone() + + if f'{layer_prefix}.attention.wk.weight' in torchtitan_state_dict: + wk = torchtitan_state_dict[f'{layer_prefix}.attention.wk.weight'] + key_value_dim = n_kv_heads * head_dim + if model_family == "llama3": + wk = permute(wk, n_kv_heads, key_value_dim, dim) + hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.weight'] = wk.clone() + + if f'{layer_prefix}.attention.wv.weight' in torchtitan_state_dict: + # Value weights don't get permuted + hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.wv.weight'].clone() + + if model_family == "qwen3": + if f'{layer_prefix}.attention.q_norm.weight' in torchtitan_state_dict: + hf_state_dict[f'{hf_layer_prefix}.self_attn.q_norm.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.q_norm.weight'].clone() + if f'{layer_prefix}.attention.k_norm.weight' in torchtitan_state_dict: + hf_state_dict[f'{hf_layer_prefix}.self_attn.k_norm.weight'] = 
torchtitan_state_dict[f'{layer_prefix}.attention.k_norm.weight'].clone() + + if f'{layer_prefix}.attention.wo.weight' in torchtitan_state_dict: + # Output projection doesn't get permuted + hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.wo.weight'].clone() + + # MLP weights (no permutation) + mlp_mappings = { + f'{layer_prefix}.feed_forward.w1.weight': f'{hf_layer_prefix}.mlp.gate_proj.weight', + f'{layer_prefix}.feed_forward.w2.weight': f'{hf_layer_prefix}.mlp.down_proj.weight', + f'{layer_prefix}.feed_forward.w3.weight': f'{hf_layer_prefix}.mlp.up_proj.weight', + } + + for tt_key, hf_key in mlp_mappings.items(): + if tt_key in torchtitan_state_dict: + hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone() + + # Layer norms (no permutation) + norm_mappings = { + f'{layer_prefix}.attention_norm.weight': f'{hf_layer_prefix}.input_layernorm.weight', + f'{layer_prefix}.ffn_norm.weight': f'{hf_layer_prefix}.post_attention_layernorm.weight', + } + + for tt_key, hf_key in norm_mappings.items(): + if tt_key in torchtitan_state_dict: + hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone() + + print(f"Converted {len(hf_state_dict)} parameters from TorchTitan to HuggingFace format") + return hf_state_dict + + +def map_torchtitan_to_hf2(torchtitan_state_dict, max_seq_len=131072, rope_theta=500000.0, validate_against_original=True): + """Map TorchTitan state dict to HuggingFace format using per-parameter function.""" + + # Auto-detect model family + if any(k.endswith('.attention.q_norm.weight') for k in torchtitan_state_dict): + model_family = "qwen3" + elif any(k.endswith('.attention.wq.bias') for k in torchtitan_state_dict): + model_family = "gptoss" + else: + model_family = "llama3" + + logger.info(f"Converting using per-parameter function with model_family={model_family}") + + hf_state_dict = {} + skipped_params = [] + + # Convert each parameter individually + for name, weight in tqdm(torchtitan_state_dict.items(), desc="Converting parameters"): + hf_name, hf_weight = map_torchtitan_to_hf_per_param(name, weight, model_family) + if hf_name is not None: + if hf_name in hf_state_dict: + hf_state_dict[hf_name] = hf_state_dict[hf_name] + hf_weight # NOTE: adds expert_bias into router bias + else: + hf_state_dict[hf_name] = hf_weight + else: + skipped_params.append(name) + + logger.info(f"Converted {len(hf_state_dict)} parameters, skipped {len(skipped_params)} parameters") + if skipped_params: + logger.info(f"Skipped parameters: {skipped_params}") + + # Validation against original function + if validate_against_original: + logger.info("Validating against original conversion function...") + + # Get original result + original_hf_state_dict = map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len, rope_theta) + + # Compare keys + new_keys = set(hf_state_dict.keys()) + original_keys = set(original_hf_state_dict.keys()) + + if new_keys != original_keys: + missing_in_new = original_keys - new_keys + extra_in_new = new_keys - original_keys + logger.error(f"Key mismatch! 
Missing in new: {missing_in_new}, Extra in new: {extra_in_new}") + raise ValueError("Key sets don't match between implementations") + + # Compare tensor values + mismatched_tensors = [] + for key in original_keys: + if not torch.allclose(hf_state_dict[key], original_hf_state_dict[key], rtol=1e-5, atol=1e-8): + mismatched_tensors.append(key) + + if mismatched_tensors: + logger.error(f"Tensor value mismatches in: {mismatched_tensors}") + # Show details for first mismatch + key = mismatched_tensors[0] + logger.error(f"First mismatch in {key}:") + logger.error(f" Max abs diff: {torch.max(torch.abs(hf_state_dict[key] - original_hf_state_dict[key]))}") + logger.error(f" Original shape: {original_hf_state_dict[key].shape}") + logger.error(f" New shape: {hf_state_dict[key].shape}") + raise ValueError("Tensor values don't match between implementations") + + logger.info("āœ“ Validation passed! New implementation matches original.") + + return hf_state_dict + + +@app.command(name="hf_to_dcp") +@torch.inference_mode() +def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, dtype: str = "auto"): + """Convert HuggingFace model to TorchTitan DCP format. + + Args: + input_path: HuggingFace model name or path + output_path: Output DCP checkpoint path + max_seq_len: Max sequence length for RoPE + rope_theta: RoPE theta parameter + dtype: Data type to use ("auto" to preserve original, or specific dtype like "float32") + """ + logger.info(f"Loading model from {input_path}") + + # Load model with original dtype if "auto", otherwise use specified dtype + hf_model = AutoModelForCausalLM.from_pretrained(input_path, torch_dtype=torch.bfloat16) + + hf_state_dict = hf_model.state_dict() + logger.info(f"Loaded model with dtype: {next(iter(hf_state_dict.values())).dtype}") + + logger.info("Converting weights to TorchTitan format") + torchtitan_state_dict = map_hf_to_torchtitan(hf_state_dict, hf_model.config, max_seq_len, rope_theta, input_path) + + logger.info(f"Writing to DCP at '{output_path}'") + output_path.mkdir(parents=True, exist_ok=True) + storage_writer = DCP.filesystem.FileSystemWriter(output_path, thread_count=8) + DCP.save({"model": torchtitan_state_dict}, storage_writer=storage_writer) + logger.info("Conversion complete!") + + +@app.command(name="dcp_to_hf") +@torch.inference_mode() +def convert_dcp_to_hf(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, default_model: str = "meta-llama/Meta-Llama-3.1-8B", validate_against_original: bool = False): + """Convert TorchTitan DCP format to HuggingFace model. 
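
One subtlety in `map_torchtitan_to_hf2` above: two TorchTitan parameters (`moe.router.gate.bias` and `moe.expert_bias`) map to the same HF key, so when a target key is hit twice the second tensor is accumulated onto the first, folding the load-balancing bias into the router bias. A toy illustration with made-up values:

import torch

hf_state: dict[str, torch.Tensor] = {}
converted = [
    ("model.layers.0.mlp.router.bias", torch.ones(4)),          # router gate bias
    ("model.layers.0.mlp.router.bias", torch.full((4,), 0.5)),  # load-balancing expert bias
]
for hf_name, w in converted:
    # same accumulate-on-collision rule as the conversion loop above
    hf_state[hf_name] = hf_state[hf_name] + w if hf_name in hf_state else w

assert torch.equal(hf_state["model.layers.0.mlp.router.bias"], torch.full((4,), 1.5))
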
+ + Args: + input_path: Input DCP checkpoint path + output_path: Output HuggingFace model path + max_seq_len: Max sequence length for RoPE + rope_theta: RoPE theta parameter + default_model: Default HuggingFace model for config + """ + + if str(input_path).startswith("s3://"): + import s3_utils + local_path = s3_utils.sync_to_nvme(str(input_path)) + input_path = Path(local_path) + + logger.info(f"Loading DCP checkpoint from {input_path}") + + from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner + from torch.distributed.checkpoint.state_dict_loader import _load_state_dict + # Load DCP input_path + state_dict = {} + _load_state_dict( + state_dict, + storage_reader=DCP.filesystem.FileSystemReader(input_path), + planner=_EmptyStateDictLoadPlanner(), + no_dist=True, + ) + torchtitan_state_dict = state_dict["model"] + logger.info("Converting weights to HuggingFace format") + hf_state_dict = map_torchtitan_to_hf2(torchtitan_state_dict, max_seq_len, rope_theta, validate_against_original=validate_against_original) + + if '/' not in default_model: + if 'qwen' in default_model.lower(): + default_model = f'Qwen/{default_model}' + elif 'llama' in default_model.lower(): + default_model = f'meta-llama/{default_model}' + else: + raise ValueError(f"Unsupported model: {default_model}") + + # Create HuggingFace config + hf_config = AutoConfig.from_pretrained(default_model) + + # Create and load model + logger.info("Creating HuggingFace model") + # tokenizer = AutoTokenizer.from_pretrained(default_model) + tokenizer = get_tokenizer_with_chat_template(default_model, "tulu", override=True) + hf_model = AutoModelForCausalLM.from_pretrained(default_model, device_map="auto") # NOTE: need device_map="auto" to avoid CPU OOM + + # load state dict + logger.info("Loading state dict") + hf_model.load_state_dict(hf_state_dict, strict=True) + + # Save model + logger.info(f"Saving model to {output_path}") + output_path.mkdir(parents=True, exist_ok=True) + hf_model.save_pretrained(output_path) + tokenizer.save_pretrained(output_path) + logger.info("Conversion complete!") + + +if __name__ == "__main__": + init_logger() + app.cli() diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml new file mode 100644 index 0000000000..878e478ff5 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml @@ -0,0 +1,73 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "GPT-OSS debug training" +print_args = false +use_for_integration_test = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "gpt_oss" +flavor = "debugmodel" +# test tokenizer, for debug purpose only +tokenizer_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 0.0 + +[training] +local_batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 1 +compile = false +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] 
+data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 2 +enable_async_tensor_parallel = false +expert_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] +moe_fqns = ["experts"] diff --git a/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml new file mode 100644 index 0000000000..81908972ad --- /dev/null +++ b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml @@ -0,0 +1,70 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "GPT-OSS 120B model training" +print_args = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 10 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "gpt_oss" +flavor = "120B" +tokenizer_path = "./assets/tokenizer/GPT-OSS" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2_000 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 2.2e-5 + +[training] +local_batch_size = 4 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 10_000 +compile = false +dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 8 +enable_async_tensor_parallel = false +expert_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" + +[activation_checkpoint] +mode = "full" # ["none", "selective", "full"] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] +moe_fqns = ["experts"] diff --git a/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml new file mode 100644 index 0000000000..88d1c4d27f --- /dev/null +++ b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml @@ -0,0 +1,70 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "GPT-OSS 20B model training" +print_args = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 10 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "gpt_oss" +flavor = 
"20B" +tokenizer_path = "./assets/tokenizer/GPT-OSS" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 2.2e-5 + +[training] +local_batch_size = 8 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 1000 +compile = false +dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +expert_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" + +[activation_checkpoint] +mode = "full" # ["none", "selective", "full"] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] +moe_fqns = ["experts"] diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index f66361a6d2..3c3b607571 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -72,6 +72,7 @@ def __init__( self.attn_mask_type = attn_mask_type self.fixed_block_size = fixed_block_size + self.mask_cache = {} FlexAttention.used_attn_mask_types.add(self.mask_key) @property @@ -84,9 +85,61 @@ def forward( k: torch.Tensor, v: torch.Tensor, scale: float | None = None, + sink_weights: torch.Tensor | None = None, + sliding_window: int = 0, + enable_gqa: bool = False, ) -> torch.Tensor: - block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) + if sink_weights is None: + block_mask = FlexAttention.block_masks[self.mask_key] + return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) + + B, H_q, S_q, D = q.shape + _, H_kv, S_kv, _ = k.shape + + # regular (no-sink) mask + no extra KV col + mask_key = (sliding_window, S_q, S_kv) + if mask_key not in self.mask_cache: + if sliding_window is not None and sliding_window > 0: + mask_mod = FlexAttention._get_sliding_window_mask_mod(sliding_window) + else: + mask_mod = FlexAttention._get_causal_mask_mod() + block_mask = create_block_mask( + mask_mod, B, H_q, S_q, S_kv, + _compile=True, device=q.device # NOTE: set _compile=False if sampling for debugging + ) + self.mask_cache[mask_key] = block_mask + + block_mask = self.mask_cache[mask_key] + + # run fast flex_attn and return LSE + out, lse = FlexAttention.flex_attn( + q, k, v, + block_mask=block_mask, + enable_gqa=enable_gqa, + return_lse=True + ) + + # rescale by sigma(lse - w[h]) and broadcast over D + if sink_weights is not None: + w = sink_weights # [H] + scale = torch.sigmoid(lse - w.view(1, -1, 1)).unsqueeze(-1) # [B,H,S,1] + out = out * scale + + out = out.to(q.dtype) + return out + + @staticmethod + def _get_sliding_window_mask_mod(window: int): + """ + Returns a mask_mod function that + - only allows kv_idx ≤ q_idx (causal) + - and only if (q_idx - kv_idx) ≤ window + """ + def sliding_mod(b, h, q_idx, kv_idx): + # causal within window + keep = (kv_idx <= q_idx) & (q_idx - kv_idx <= window) + return keep + return sliding_mod @staticmethod def _get_causal_mask_mod() -> _mask_mod_signature: 
diff --git a/torchtitan/train.py b/torchtitan/train.py index 966cf868a3..904dc96d45 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -325,7 +325,7 @@ def __init__(self, job_config: JobConfig): ) self.train_context = dist_utils.get_train_context( loss_parallel_enabled, - parallelism_config.enable_compiled_autograd, + parallelism_config.enable_compiled_autograd ) self.maybe_enable_amp = dist_utils.maybe_enable_amp( parallel_dims,
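
The sink-weights branch added to `FlexAttention.forward` above rescales the attention output by `sigmoid(lse - w_h)` rather than materializing an extra key/value column for the sink. For a single query and head, appending the sink logit to the softmax (with no value attached) is exactly the same as scaling the ordinary softmax output by that sigmoid; a small numerical check with illustrative shapes:

import torch

torch.manual_seed(0)
S, D = 5, 4
scores = torch.randn(S)            # attention logits of one query against S keys
v = torch.randn(S, D)
w = torch.tensor(0.3)              # per-head sink logit; the sink carries no value row

# reference: treat the sink as one extra logit whose value contribution is dropped
probs = torch.softmax(torch.cat([scores, w.view(1)]), dim=0)
out_ref = probs[:S] @ v

# rescaling used above: plain softmax output times sigmoid(lse - w)
lse = torch.logsumexp(scores, dim=0)
out_rescaled = (torch.softmax(scores, dim=0) @ v) * torch.sigmoid(lse - w)

assert torch.allclose(out_ref, out_rescaled, atol=1e-6)

This identity is why the forward pass requests `return_lse=True` from flex_attn and applies the per-head factor after the fact.
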