From 0a8d2b0a6b4ebadee4f99e7cbc924714adf4a0d1 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 4 Sep 2025 22:44:58 -0700 Subject: [PATCH 01/13] benchmarking --- torchtitan/components/checkpoint.py | 27 ++++++++++++++++--- torchtitan/config/job_config.py | 3 +++ torchtitan/models/deepseek_v3/__init__.py | 2 +- .../train_configs/deepseek_v3_671b.toml | 6 +++-- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index e9e7014425..f8c10c7d5a 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -40,6 +40,7 @@ from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer from torchtitan.config import Checkpoint as CheckpointConfig, TORCH_DTYPE_MAP +from torchtitan.models.deepseek_v3.model.quantization import BLOCK_SIZE from torchtitan.protocols import BaseStateDictAdapter from torchtitan.tools.logging import logger from torchtitan.tools.utils import GarbageCollection @@ -247,6 +248,7 @@ def load_state_dict(state_dict): # Checkpoint policy related fields. self.initial_load_model_only = checkpoint_config.initial_load_model_only self.initial_load_in_hf = checkpoint_config.initial_load_in_hf + self.initial_load_dequantize = checkpoint_config.initial_load_dequantize self.initial_load_path = checkpoint_config.initial_load_path self.last_save_model_only = checkpoint_config.last_save_model_only self.last_save_in_hf = checkpoint_config.last_save_in_hf @@ -417,6 +419,7 @@ def dcp_load( state_dict: dict[str, Any], checkpoint_id: str, from_hf: bool, + dequantize: bool = False, ) -> None: """Load the checkpoint with dcp. Args: @@ -432,10 +435,25 @@ def dcp_load( ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided." hf_state_dict = self.sd_adapter.to_hf(state_dict) - dcp.load( - hf_state_dict, - storage_reader=HuggingFaceStorageReader(path=checkpoint_id), - ) + if not dequantize: + dcp.load( + hf_state_dict, + storage_reader=HuggingFaceStorageReader(path=checkpoint_id), + ) + else: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) + BLOCK_SIZE = 128 # hardcode for deepseek 671b now + dcp.load( + hf_state_dict, + storage_reader=QuantizedHuggingFaceStorageReader( + path=checkpoint_id, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=4, + ), + ) state_dict = self.sd_adapter.from_hf(hf_state_dict) self.states[MODEL].load_state_dict(state_dict) @@ -600,6 +618,7 @@ def load(self, step: int = -1) -> bool: states, checkpoint_id=checkpoint_id, from_hf=from_hf, + dequantize=self.initial_load_dequantize, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index ff509521e0..b253bbf012 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -450,6 +450,9 @@ class Checkpoint: non-tensors. The default value is False. 
""" + initial_load_dequantize: bool = False + + last_save_model_only: bool = True """ When last_save_model_only=True, only the model will be saved at the end of training, diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index a290ea7e66..5125a7904c 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -134,7 +134,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=61, + n_layers=4, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 5e933a4772..88464af450 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -39,7 +39,6 @@ local_batch_size = 4 seq_len = 4096 max_norm = 1.0 # grad norm clipping steps = 10_000 -compile = false dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -54,12 +53,15 @@ expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 [checkpoint] -enable = false +enable = true folder = "checkpoint" interval = 500 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" +initial_load_in_hf = true +initial_load_dequantize = true +initial_load_path = "/data/users/jianiw/model/DeepSeek-V3.1-Base" [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] From 224cd37834c6a5a042c688592c8d8e52dea8ef38 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 8 Sep 2025 16:51:02 -0700 Subject: [PATCH 02/13] test --- torchtitan/components/checkpoint.py | 6 ++++ torchtitan/models/deepseek_v3/__init__.py | 2 +- .../deepseek_v3/model/state_dict_adapter.py | 33 ++++++++++++++++--- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index f8c10c7d5a..22f0bda8c0 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -436,15 +436,20 @@ def dcp_load( hf_state_dict = self.sd_adapter.to_hf(state_dict) if not dequantize: + begin_load = time.monotonic() + logger.info("Starting dcp.load with HuggingFaceStorageReader") dcp.load( hf_state_dict, storage_reader=HuggingFaceStorageReader(path=checkpoint_id), ) + logger.info(f"dcp.load with HuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds") else: from torch.distributed.checkpoint.quantized_hf_storage import ( QuantizedHuggingFaceStorageReader, ) BLOCK_SIZE = 128 # hardcode for deepseek 671b now + begin_load = time.monotonic() + logger.info("Starting dcp.load with QuantizedHuggingFaceStorageReader") dcp.load( hf_state_dict, storage_reader=QuantizedHuggingFaceStorageReader( @@ -454,6 +459,7 @@ def dcp_load( thread_count=4, ), ) + logger.info(f"dcp.load with QuantizedHuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds") state_dict = self.sd_adapter.from_hf(hf_state_dict) self.states[MODEL].load_state_dict(state_dict) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 5125a7904c..a290ea7e66 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -134,7 +134,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=4, + n_layers=61, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( diff --git 
a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index e947d70695..7631752518 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -6,6 +6,7 @@ import re +import time from typing import Any import torch @@ -16,6 +17,7 @@ from .args import DeepSeekV3ModelArgs from .quantization import calculate_scale_shape, dequantize_from_fp8 +from torchtitan.tools.logging import logger class DeepSeekV3StateDictAdapter(StateDictAdapter): @@ -220,6 +222,8 @@ def _get_local_experts_weights( Returns: Dictionary mapping individual expert keys to their DTensor weights """ + start_time = time.time() + logger.info(f"Starting _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}") device_mesh = grouped_expert_weight.device_mesh dtensor_placements = grouped_expert_weight.placements @@ -285,6 +289,9 @@ def _get_local_experts_weights( local_expert_tensors[expert_key] = expert_dtensor + end_time = time.time() + duration = end_time - start_time + logger.info(f"Completed _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}, duration: {duration:.4f}s") return local_expert_tensors def _concatenate_expert_weights_dtensor( @@ -312,6 +319,8 @@ def _concatenate_expert_weights_dtensor( Returns: Concatenated GroupedExperts weight DTensor if all experts are available, otherwise None """ + start_time = time.time() + logger.info(f"Starting _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}") # If we have all the experts for this abstract_key, concatenate them experts = expert_weights_by_layer[layer_num][abstract_key] expected_n_experts = ( @@ -341,6 +350,9 @@ def _concatenate_expert_weights_dtensor( if not expert_weights_by_layer[layer_num]: del expert_weights_by_layer[layer_num] + end_time = time.time() + duration = end_time - start_time + logger.info(f"Completed _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}, duration: {duration:.4f}s") return stacked_dtensor def _split_experts_weights( @@ -453,6 +465,9 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: 1. Convert between the HF shape and the torchtitan shape. 2. Split the GroupedExperts' weight into separate expert's wegiht. """ + start_time = time.time() + logger.info(f"Starting to_hf conversion, state_dict has {len(state_dict)} keys") + to_hf_map = {v: k for k, v in self.from_hf_map.items()} hf_state_dict = {} @@ -501,10 +516,13 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: hf_state_dict[new_key] = value # Prepare for dequantization - hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( - hf_state_dict - ) - return hf_state_dict_with_scale_inv + # hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( + # hf_state_dict + # ) + end_time = time.time() + duration = end_time - start_time + logger.info(f"Completed to_hf conversion, generated {len(hf_state_dict)} keys, duration: {duration:.4f}s") + return hf_state_dict def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ @@ -512,10 +530,12 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: 2. Convert between the HF shape and the torchtitan shape. 3. Concate separate expert's wegiht into GroupedExperts' weight. 
""" + start_time = time.time() + logger.info(f"Starting from_hf conversion, state_dict has {len(hf_state_dict)} keys") # dequantize the tensor in state_dict and remove the scale_inv tensor - hf_state_dict = self._dequantize(hf_state_dict) + # hf_state_dict = self._dequantize(hf_state_dict) state_dict = {} expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} @@ -565,4 +585,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: new_key = self.from_hf_map[key] state_dict[new_key] = value + end_time = time.time() + duration = end_time - start_time + logger.info(f"Completed from_hf conversion, processed {len(hf_state_dict)} keys, duration: {duration:.4f}s") return state_dict From 117d96a7614265ed68f684cf974b463ff8391601 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 9 Sep 2025 16:53:56 -0700 Subject: [PATCH 03/13] benchmarking --- torchtitan/components/checkpoint.py | 7 +++++-- .../models/deepseek_v3/train_configs/deepseek_v3_671b.toml | 5 +---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 22f0bda8c0..7c70bcb242 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -447,7 +447,10 @@ def dcp_load( from torch.distributed.checkpoint.quantized_hf_storage import ( QuantizedHuggingFaceStorageReader, ) - BLOCK_SIZE = 128 # hardcode for deepseek 671b now + + # NOTE: The following config is for DeepSeek-V3 671B model, which is using + # FP8 weight format with 128x128 block scaling. + BLOCK_SIZE = 128 begin_load = time.monotonic() logger.info("Starting dcp.load with QuantizedHuggingFaceStorageReader") dcp.load( @@ -456,7 +459,7 @@ def dcp_load( path=checkpoint_id, target_dtype=torch.float32, block_size=BLOCK_SIZE, - thread_count=4, + thread_count=8, ), ) logger.info(f"dcp.load with QuantizedHuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds") diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 88464af450..922ed5b37e 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -53,15 +53,12 @@ expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 [checkpoint] -enable = true +enable = false folder = "checkpoint" interval = 500 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" -initial_load_in_hf = true -initial_load_dequantize = true -initial_load_path = "/data/users/jianiw/model/DeepSeek-V3.1-Base" [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] From ff35def7763513e2c24e51549d87c7709d06c37b Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 9 Sep 2025 16:55:37 -0700 Subject: [PATCH 04/13] remove dequantize --- .../models/deepseek_v3/model/quantization.py | 73 ------------------- .../deepseek_v3/model/state_dict_adapter.py | 58 --------------- 2 files changed, 131 deletions(-) delete mode 100644 torchtitan/models/deepseek_v3/model/quantization.py diff --git a/torchtitan/models/deepseek_v3/model/quantization.py b/torchtitan/models/deepseek_v3/model/quantization.py deleted file mode 100644 index a8ac6003a2..0000000000 --- a/torchtitan/models/deepseek_v3/model/quantization.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torchtitan.tools.logging import logger - -# Fixed block size of 128x128 as specified in the algorithm -BLOCK_SIZE = 128 - - -def calculate_scale_shape( - weight: torch.Tensor, BLOCK_SIZE: int = BLOCK_SIZE -) -> torch.Size: - # Calculate the scale tensor shape - orig_shape = weight.shape - - # Calculate number of blocks needed - block_rows = (orig_shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE - block_cols = (orig_shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE - - # Verify scale_inv shape matches expected block dimensions - expected_scale_shape = torch.Size((block_rows, block_cols)) - - return expected_scale_shape - - -def dequantize_from_fp8( - weight: torch.Tensor, - scale_inv: torch.Tensor, - dtype=torch.bfloat16, - BLOCK_SIZE: int = BLOCK_SIZE, -) -> torch.Tensor: - # Convert to float32 for computation - float_weight = weight.to(torch.float32) - # Get original dimensions - orig_shape = weight.shape - - # Verify scale_inv shape matches expected block dimensions - expected_scale_shape = calculate_scale_shape(weight, BLOCK_SIZE) - block_rows, block_cols = expected_scale_shape - if scale_inv.shape != expected_scale_shape: - logger.warning( - f"scale_inv shape {scale_inv.shape} doesn't match expected shape {expected_scale_shape}" - ) - - # NOTE: When processing large models on-the-fly, misalignment between block boundaries - # and DTensor local shape partitioning can lead to silent numerical inaccuracies. - dequantized = float_weight.detach().clone().to(dtype=dtype) - - # Apply scaling factors to each block - for i in range(block_rows): - row_start = i * BLOCK_SIZE - row_end = min(row_start + BLOCK_SIZE, orig_shape[0]) - - for j in range(block_cols): - col_start = j * BLOCK_SIZE - col_end = min(col_start + BLOCK_SIZE, orig_shape[1]) - - # Get the block - block = float_weight[row_start:row_end, col_start:col_end] - - scale = scale_inv[i, j] - block = block * scale - - # Explicitly convert block to dtype - block_converted = block.to(dtype=torch.float32) - # Store the dequantized block - dequantized[row_start:row_end, col_start:col_end] = block_converted - - return dequantized diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 7631752518..c14703a094 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -16,7 +16,6 @@ from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import DeepSeekV3ModelArgs -from .quantization import calculate_scale_shape, dequantize_from_fp8 from torchtitan.tools.logging import logger @@ -410,55 +409,6 @@ def _concatenate_expert_weights( return stacked_tensor - def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: - """ - Dequantize the weights from float8 to float32. 
- """ - - scale_inv_keys = [] - for key, weight in state_dict.items(): - if key.endswith(".weight") and key + "_scale_inv" in state_dict: - scale_inv = state_dict[key + "_scale_inv"] - dequantized_weight = dequantize_from_fp8( - weight, scale_inv, dtype=torch.float32 - ) - # update the weight and remove the scale_inv tensor - state_dict[key] = dequantized_weight - scale_inv_keys.append(key + "_scale_inv") - - for key in scale_inv_keys: - state_dict.pop(key) - - return state_dict - - def _add_quantization_scale_inv_tensors( - self, state_dict: dict[str, Any] - ) -> dict[str, Any]: - """ - Add quantization scale tensors the state_dict. - """ - non_quantized_keys = [ - "input_layernorm.weight", - "post_attention_layernorm.weight", - "norm.weight", - "lm_head.weight", - "embed_tokens.weight", - "mlp.gate.weight", - ] - - weight_scale_inv_state_dict = {} - for key, value in state_dict.items(): - if key.endswith(".weight") and not any( - non_quantized_key in key for non_quantized_key in non_quantized_keys - ): - expected_scale_shape = calculate_scale_shape(value) - # add weight_scale_inv to the state_dict - weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones( - expected_scale_shape, dtype=torch.float32 - ) - - state_dict.update(weight_scale_inv_state_dict) - return state_dict def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ @@ -515,10 +465,6 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: new_key = to_hf_map[key] hf_state_dict[new_key] = value - # Prepare for dequantization - # hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( - # hf_state_dict - # ) end_time = time.time() duration = end_time - start_time logger.info(f"Completed to_hf conversion, generated {len(hf_state_dict)} keys, duration: {duration:.4f}s") @@ -533,11 +479,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: start_time = time.time() logger.info(f"Starting from_hf conversion, state_dict has {len(hf_state_dict)} keys") - # dequantize the tensor in state_dict and remove the scale_inv tensor - - # hf_state_dict = self._dequantize(hf_state_dict) state_dict = {} - expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} for key, value in hf_state_dict.items(): From bda5294417c752b54ff6c34aa563278f09e0bea8 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Wed, 10 Sep 2025 14:05:28 -0700 Subject: [PATCH 05/13] reformat --- torchtitan/components/checkpoint.py | 42 ++++----------- torchtitan/config/job_config.py | 3 -- .../qwen3/model/state_dict_adapter.py | 5 ++ torchtitan/models/deepseek_v3/__init__.py | 5 +- torchtitan/models/deepseek_v3/model/args.py | 3 ++ .../deepseek_v3/model/state_dict_adapter.py | 54 +++++++++++++++---- .../train_configs/deepseek_v3_671b.toml | 2 +- .../models/llama3/model/state_dict_adapter.py | 4 ++ torchtitan/protocols/state_dict_adapter.py | 16 ++++++ 9 files changed, 86 insertions(+), 48 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 7c70bcb242..933a6eda3b 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -40,7 +40,6 @@ from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer from torchtitan.config import Checkpoint as CheckpointConfig, TORCH_DTYPE_MAP -from torchtitan.models.deepseek_v3.model.quantization import BLOCK_SIZE from torchtitan.protocols import BaseStateDictAdapter from torchtitan.tools.logging import logger from 
torchtitan.tools.utils import GarbageCollection @@ -248,7 +247,6 @@ def load_state_dict(state_dict): # Checkpoint policy related fields. self.initial_load_model_only = checkpoint_config.initial_load_model_only self.initial_load_in_hf = checkpoint_config.initial_load_in_hf - self.initial_load_dequantize = checkpoint_config.initial_load_dequantize self.initial_load_path = checkpoint_config.initial_load_path self.last_save_model_only = checkpoint_config.last_save_model_only self.last_save_in_hf = checkpoint_config.last_save_in_hf @@ -419,7 +417,6 @@ def dcp_load( state_dict: dict[str, Any], checkpoint_id: str, from_hf: bool, - dequantize: bool = False, ) -> None: """Load the checkpoint with dcp. Args: @@ -434,35 +431,17 @@ def dcp_load( self.sd_adapter is not None ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided." hf_state_dict = self.sd_adapter.to_hf(state_dict) + hf_storage_reader = self.sd_adapter.get_hf_storage_reader(checkpoint_id) - if not dequantize: - begin_load = time.monotonic() - logger.info("Starting dcp.load with HuggingFaceStorageReader") - dcp.load( - hf_state_dict, - storage_reader=HuggingFaceStorageReader(path=checkpoint_id), - ) - logger.info(f"dcp.load with HuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds") - else: - from torch.distributed.checkpoint.quantized_hf_storage import ( - QuantizedHuggingFaceStorageReader, - ) - - # NOTE: The following config is for DeepSeek-V3 671B model, which is using - # FP8 weight format with 128x128 block scaling. - BLOCK_SIZE = 128 - begin_load = time.monotonic() - logger.info("Starting dcp.load with QuantizedHuggingFaceStorageReader") - dcp.load( - hf_state_dict, - storage_reader=QuantizedHuggingFaceStorageReader( - path=checkpoint_id, - target_dtype=torch.float32, - block_size=BLOCK_SIZE, - thread_count=8, - ), - ) - logger.info(f"dcp.load with QuantizedHuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds") + begin_load = time.monotonic() + logger.info("Starting dcp.load with HuggingFaceStorageReader") + dcp.load( + hf_state_dict, + storage_reader=hf_storage_reader, + ) + logger.info( + f"dcp.load with HuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds" + ) state_dict = self.sd_adapter.from_hf(hf_state_dict) self.states[MODEL].load_state_dict(state_dict) @@ -627,7 +606,6 @@ def load(self, step: int = -1) -> bool: states, checkpoint_id=checkpoint_id, from_hf=from_hf, - dequantize=self.initial_load_dequantize, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index b253bbf012..ff509521e0 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -450,9 +450,6 @@ class Checkpoint: non-tensors. The default value is False. 
""" - initial_load_dequantize: bool = False - - last_save_model_only: bool = True """ When last_save_model_only=True, only the model will be saved at the end of training, diff --git a/torchtitan/experiments/qwen3/model/state_dict_adapter.py b/torchtitan/experiments/qwen3/model/state_dict_adapter.py index 760cc662be..600d9b511e 100644 --- a/torchtitan/experiments/qwen3/model/state_dict_adapter.py +++ b/torchtitan/experiments/qwen3/model/state_dict_adapter.py @@ -15,6 +15,8 @@ import re from typing import Any +from torch.distributed.checkpoint import HuggingFaceStorageReader + from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import Qwen3ModelArgs @@ -45,6 +47,9 @@ def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None): "lm_head.weight": "output.weight", } + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + return HuggingFaceStorageReader(path) + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: to_hf_map = {v: k for k, v in self.from_hf_map.items()} diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index a290ea7e66..221378bdca 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -153,8 +153,9 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", + hf_weight_quantized=True, ), } diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index d6afedfa34..b27b7a9d50 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -86,6 +86,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs): beta_slow: int = 1 mscale: float = 1.0 + # HF checkpoint args + hf_weight_quantized: bool = False + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: seq_len = job_config.training.seq_len if seq_len > self.max_seq_len: diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index c14703a094..4c18a944a8 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -10,13 +10,14 @@ from typing import Any import torch +from torch.distributed.checkpoint import HuggingFaceStorageReader from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard from torchtitan.protocols.state_dict_adapter import StateDictAdapter +from torchtitan.tools.logging import logger from .args import DeepSeekV3ModelArgs -from torchtitan.tools.logging import logger class DeepSeekV3StateDictAdapter(StateDictAdapter): @@ -79,6 +80,26 @@ def __init__( self.grouped_expert_weight_shape = {} # {titan_abstract_key: shape} self.local_experts_indices = {} # {titan_abstract_key: (start_idx, end_idx)} + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + if self.model_args.hf_weight_quantized: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) + + # NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model. 
+ # If loading checkpoints without quantization, use HuggingFaceStorageReader instead + BLOCK_SIZE = 128 + return ( + QuantizedHuggingFaceStorageReader( + path=path, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=8, + ), + ) + else: + return HuggingFaceStorageReader(path) + def _calculate_strided_shard_shard_indices( self, strided_shard_dim_degree: int, @@ -222,7 +243,9 @@ def _get_local_experts_weights( Dictionary mapping individual expert keys to their DTensor weights """ start_time = time.time() - logger.info(f"Starting _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}") + logger.info( + f"Starting _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}" + ) device_mesh = grouped_expert_weight.device_mesh dtensor_placements = grouped_expert_weight.placements @@ -290,7 +313,9 @@ def _get_local_experts_weights( end_time = time.time() duration = end_time - start_time - logger.info(f"Completed _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}, duration: {duration:.4f}s") + logger.info( + f"Completed _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}, duration: {duration:.4f}s" + ) return local_expert_tensors def _concatenate_expert_weights_dtensor( @@ -319,7 +344,9 @@ def _concatenate_expert_weights_dtensor( Concatenated GroupedExperts weight DTensor if all experts are available, otherwise None """ start_time = time.time() - logger.info(f"Starting _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}") + logger.info( + f"Starting _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}" + ) # If we have all the experts for this abstract_key, concatenate them experts = expert_weights_by_layer[layer_num][abstract_key] expected_n_experts = ( @@ -351,7 +378,9 @@ def _concatenate_expert_weights_dtensor( end_time = time.time() duration = end_time - start_time - logger.info(f"Completed _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}, duration: {duration:.4f}s") + logger.info( + f"Completed _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}, duration: {duration:.4f}s" + ) return stacked_dtensor def _split_experts_weights( @@ -409,7 +438,6 @@ def _concatenate_expert_weights( return stacked_tensor - def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ 1. Convert between the HF shape and the torchtitan shape. @@ -417,7 +445,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ start_time = time.time() logger.info(f"Starting to_hf conversion, state_dict has {len(state_dict)} keys") - + to_hf_map = {v: k for k, v in self.from_hf_map.items()} hf_state_dict = {} @@ -467,7 +495,9 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: end_time = time.time() duration = end_time - start_time - logger.info(f"Completed to_hf conversion, generated {len(hf_state_dict)} keys, duration: {duration:.4f}s") + logger.info( + f"Completed to_hf conversion, generated {len(hf_state_dict)} keys, duration: {duration:.4f}s" + ) return hf_state_dict def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: @@ -477,7 +507,9 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: 3. Concate separate expert's wegiht into GroupedExperts' weight. 
""" start_time = time.time() - logger.info(f"Starting from_hf conversion, state_dict has {len(hf_state_dict)} keys") + logger.info( + f"Starting from_hf conversion, state_dict has {len(hf_state_dict)} keys" + ) state_dict = {} expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} @@ -529,5 +561,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: end_time = time.time() duration = end_time - start_time - logger.info(f"Completed from_hf conversion, processed {len(hf_state_dict)} keys, duration: {duration:.4f}s") + logger.info( + f"Completed from_hf conversion, processed {len(hf_state_dict)} keys, duration: {duration:.4f}s" + ) return state_dict diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 922ed5b37e..153566efbd 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -65,7 +65,7 @@ mode = "selective" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] -enable=true +enable = true components = ["loss"] # ["model", "loss"] [quantize.dense.float8] diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index 2c386ece0d..1475ba2055 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -10,6 +10,7 @@ logger = logging.getLogger() +from torch.distributed.checkpoint import HuggingFaceStorageReader from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import TransformerModelArgs @@ -41,6 +42,9 @@ def __init__( "lm_head.weight": "output.weight", } + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + return HuggingFaceStorageReader(path) + # HuggingFace permutation function (exact copy from their conversion script) def _permute(self, w, n_heads_arg, dim1=None, dim2=None): if dim1 is None: diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index 5b441e9bbf..a6fedd0af4 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -11,6 +11,9 @@ from abc import ABC, abstractmethod from typing import Any +from torch.distributed.checkpoint import HuggingFaceStorageReader + + logger = logging.getLogger() from .model import BaseModelArgs @@ -58,6 +61,19 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ pass + @abstractmethod + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + """Returns hf storage reader to read HF checkpoint + + Args: + path: the path to read HF checkpoint + + Returns: + THe HuggingFace storage reader to read rom HF checkpoint + + """ + pass + class StateDictAdapter(BaseStateDictAdapter): """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping""" From f673f31451a2f737b6f8a63e30244afaf5b8ba7a Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 15 Sep 2025 11:36:25 -0700 Subject: [PATCH 06/13] fix return --- torchtitan/components/checkpoint.py | 5 +++-- .../models/deepseek_v3/model/state_dict_adapter.py | 12 +++++------- .../deepseek_v3/train_configs/deepseek_v3_671b.toml | 8 +++++--- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/torchtitan/components/checkpoint.py 
b/torchtitan/components/checkpoint.py index 933a6eda3b..3d11fa31c4 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -434,7 +434,7 @@ def dcp_load( hf_storage_reader = self.sd_adapter.get_hf_storage_reader(checkpoint_id) begin_load = time.monotonic() - logger.info("Starting dcp.load with HuggingFaceStorageReader") + logger.info(f"Starting dcp.load with {hf_storage_reader}") dcp.load( hf_state_dict, storage_reader=hf_storage_reader, @@ -759,7 +759,8 @@ def _save_last_step(self, curr_step: int) -> None: self.dcp_save( states, checkpoint_id=self._create_checkpoint_id(curr_step), - async_mode=AsyncMode.DISABLED, + async_mode=AsyncMode.DISA/dcp.load wit + BLED, enable_garbage_collection=True, to_hf=self.last_save_in_hf, ) diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 4c18a944a8..f0e40cd87b 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -89,13 +89,11 @@ def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: # NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model. # If loading checkpoints without quantization, use HuggingFaceStorageReader instead BLOCK_SIZE = 128 - return ( - QuantizedHuggingFaceStorageReader( - path=path, - target_dtype=torch.float32, - block_size=BLOCK_SIZE, - thread_count=8, - ), + return QuantizedHuggingFaceStorageReader( + path=path, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=8, ) else: return HuggingFaceStorageReader(path) diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 153566efbd..97e7a30fad 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -35,7 +35,7 @@ decay_type = "cosine" min_lr_factor = 0.1 [training] -local_batch_size = 4 +local_batch_size = 1 seq_len = 4096 max_norm = 1.0 # grad norm clipping steps = 10_000 @@ -53,19 +53,21 @@ expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 [checkpoint] -enable = false +enable = true folder = "checkpoint" interval = 500 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" +initial_load_in_hf = true +initial_load_path = "/home/jianiw/tmp/mffuse/deepseek-v3/DeepSeek-V3.1-Base" [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] -enable = true +enable = false components = ["loss"] # ["model", "loss"] [quantize.dense.float8] From db627d41f79ca31659cfa877819e9ce8394ce939 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 15 Sep 2025 12:31:07 -0700 Subject: [PATCH 07/13] fix checkpoint --- torchtitan/components/checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 3d11fa31c4..0855e3856f 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -759,8 +759,7 @@ def _save_last_step(self, curr_step: int) -> None: self.dcp_save( states, checkpoint_id=self._create_checkpoint_id(curr_step), - async_mode=AsyncMode.DISA/dcp.load wit - BLED, + async_mode=AsyncMode.DISABLED, 
enable_garbage_collection=True, to_hf=self.last_save_in_hf, ) From 393e46f1ba786853be5cfedd1f67850046da7110 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 15 Sep 2025 18:22:52 -0700 Subject: [PATCH 08/13] add 671b model configs --- .../models/deepseek_v3/train_configs/deepseek_v3_671b.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 97e7a30fad..e5448c3e4a 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -63,7 +63,7 @@ initial_load_in_hf = true initial_load_path = "/home/jianiw/tmp/mffuse/deepseek-v3/DeepSeek-V3.1-Base" [activation_checkpoint] -mode = "selective" # ["none", "selective", "full"] +mode = "full" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] From 10df007eedd018de273297c6a93b45d425aeff8f Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 16 Sep 2025 10:58:39 -0700 Subject: [PATCH 09/13] add support for CPU index --- torchtitan/distributed/expert_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 12512bfac0..4761f331d8 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -272,6 +272,7 @@ def wrapper( num_ep_ranks, padded_max_len, TOKEN_GROUP_ALIGN_SIZE_M, + use_cpu=True ) x = torch.vstack((x, x.new_zeros((x.shape[-1])))) From bcb9a6e1c30682a8f629a12966569a12d737671f Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 16 Sep 2025 12:38:09 -0700 Subject: [PATCH 10/13] add logging --- torchtitan/distributed/activation_checkpoint.py | 3 ++- torchtitan/models/deepseek_v3/__init__.py | 4 ++-- .../models/deepseek_v3/model/state_dict_adapter.py | 9 ++++++++- .../deepseek_v3/train_configs/deepseek_v3_671b.toml | 2 +- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 227c2ca211..b36aec855c 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -103,7 +103,8 @@ def _custom_policy(ctx, func, *args, **kwargs): mm_count_key = f"{mode}_mm_count" if func == torch.ops.aten.mm.default: if args[1].shape in mm_recompute_shapes: - return CheckpointPolicy.PREFER_RECOMPUTE + # return CheckpointPolicy.PREFER_RECOMPUTE + return CheckpointPolicy.MUST_SAVE # TODO(jianiw): testing meta[mm_count_key] += 1 # Saves output of all compute ops, except every second mm to_save = func in op_sac_save_list and not ( diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 221378bdca..bd9a734409 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -153,8 +153,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - # use_flex_attn=True, - # attn_mask_type="block_causal", + use_flex_attn=True, + attn_mask_type="block_causal", hf_weight_quantized=True, ), } diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index f0e40cd87b..d22faadf1d 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ 
b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. +import logging import re import time from typing import Any @@ -93,7 +94,7 @@ def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: path=path, target_dtype=torch.float32, block_size=BLOCK_SIZE, - thread_count=8, + thread_count=4, ) else: return HuggingFaceStorageReader(path) @@ -471,6 +472,9 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: hf_state_dict.update(local_expert_fqn) else: + logger.info( + f"Using the old torch.split for value {new_abstract_key} " + ) # keep this path for offline conversion split_values = self._split_experts_weights( value, self.model_args.moe_args.num_experts @@ -536,6 +540,9 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: value.device_mesh, ) else: # keep this path to be compatibile with offline conversion + logger.info( + f"Using the old torch.split for value {titan_abstract_key} " + ) stacked_value = self._concatenate_expert_weights( expert_weights_by_layer, titan_abstract_key, diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index e5448c3e4a..97e7a30fad 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -63,7 +63,7 @@ initial_load_in_hf = true initial_load_path = "/home/jianiw/tmp/mffuse/deepseek-v3/DeepSeek-V3.1-Base" [activation_checkpoint] -mode = "full" # ["none", "selective", "full"] +mode = "selective" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] From 6dfc5a0e652ca7a24facf160f91273ac675c5026 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 29 Sep 2025 16:14:57 -0700 Subject: [PATCH 11/13] detect anomaly --- torchtitan/models/deepseek-v3/model/model.py | 0 torchtitan/models/deepseek_v3/__init__.py | 10 +-- torchtitan/models/deepseek_v3/model/model.py | 69 +++++++++++++++++++ .../train_configs/deepseek_v3_16b.toml | 15 ++-- .../train_configs/deepseek_v3_671b.toml | 2 +- torchtitan/train.py | 30 ++++++++ 6 files changed, 114 insertions(+), 12 deletions(-) create mode 100644 torchtitan/models/deepseek-v3/model/model.py diff --git a/torchtitan/models/deepseek-v3/model/model.py b/torchtitan/models/deepseek-v3/model/model.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index bd9a734409..3adac87e3d 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -78,9 +78,9 @@ attn_mask_type="block_causal", ), "16B": DeepSeekV3ModelArgs( - vocab_size=102400, + vocab_size=163840, dim=2048, - inter_dim=10944, + inter_dim=11264, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, @@ -134,7 +134,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=61, + n_layers=45, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( @@ -153,8 +153,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", hf_weight_quantized=True, ), } diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index dc612fafb7..c9c6705a36 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ 
b/torchtitan/models/deepseek_v3/model/model.py @@ -15,6 +15,7 @@ from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs +from torchtitan.tools.logging import logger # Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 @@ -284,6 +285,74 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 self.layer_id = layer_id + # Register backward hook to monitor gradients + self.register_full_backward_hook(self._layer_gradient_hook) + logger.info(f"[HOOK REGISTRATION] Layer {self.layer_id} TransformerBlock gradient hook registered") + + def _layer_gradient_hook(self, module, grad_input, grad_output): + """Backward hook to monitor gradients of all parameters in this layer.""" + logger.info(f"[LAYER GRAD HOOK] Layer {self.layer_id} TransformerBlock backward pass") + + # Collect gradient statistics for this layer + layer_grad_norms = [] + nan_params = [] + inf_params = [] + total_params = 0 + + # Check gradients for all named parameters in this layer + for name, param in self.named_parameters(): + total_params += 1 + if param.grad is not None: + if param.grad.dtype.is_floating_point or param.grad.dtype.is_complex: + # Check for NaN and Inf elements first + has_nan = torch.isnan(param.grad).any().item() + has_inf = torch.isinf(param.grad).any().item() + + # Calculate norm safely + try: + grad_norm = param.grad.norm().item() + # Check if norm overflowed to inf (but individual elements might be finite) + if torch.isinf(torch.tensor(grad_norm)) and not has_inf: + print(f"[LAYER GRAD OVERFLOW] Layer {self.layer_id} - {name}: " + f"norm overflow to inf, max_val={param.grad.abs().max().item():.6e}, " + f"shape={param.grad.shape}") + except: + grad_norm = float('inf') + print(f"[LAYER GRAD ERROR] Layer {self.layer_id} - {name}: failed to compute norm") + + layer_grad_norms.append(grad_norm) + + if has_nan: + nan_params.append(name) + print(f"[LAYER GRAD NaN] Layer {self.layer_id} - {name}: shape={param.grad.shape}") + elif has_inf: + inf_params.append(name) + print(f"[LAYER GRAD INF] Layer {self.layer_id} - {name}: shape={param.grad.shape}") + elif torch.isinf(torch.tensor(grad_norm)): + print(f"[LAYER GRAD LARGE] Layer {self.layer_id} - {name}: " + f"very large gradients, max={param.grad.abs().max().item():.6e}") + else: + print(f"[LAYER GRAD] Layer {self.layer_id} - {name}: dtype={param.grad.dtype} (non-float grad)") + else: + print(f"[LAYER GRAD] Layer {self.layer_id} - {name}: grad is None") + + # Compute and logger.info layer statistics + if layer_grad_norms: + avg_grad_norm = sum(layer_grad_norms) / len(layer_grad_norms) + max_grad_norm = max(layer_grad_norms) + min_grad_norm = min(layer_grad_norms) + + logger.info(f"[LAYER GRAD SUMMARY] Layer {self.layer_id}: " + f"avg_norm={avg_grad_norm:.6f}, max_norm={max_grad_norm:.6f}, " + f"min_norm={min_grad_norm:.6f}, total_params={total_params}") + else: + logger.info(f"[LAYER GRAD SUMMARY] Layer {self.layer_id}: No floating point gradients found, total_params={total_params}") + + if nan_params: + logger.info(f"[LAYER GRAD ERROR] Layer {self.layer_id} NaN gradients: {nan_params}") + if inf_params: + logger.info(f"[LAYER GRAD ERROR] Layer {self.layer_id} Inf gradients: {inf_params}") + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): """ Forward pass for the Transformer block. 
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index d0cd250583..95e76180a6 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -35,11 +35,12 @@ decay_type = "cosine" min_lr_factor = 0.1 [training] -local_batch_size = 8 +local_batch_size = 1 seq_len = 4096 max_norm = 1.0 # grad norm clipping -steps = 1000 +steps = 100 dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) +deterministic = true [parallelism] data_parallel_replicate_degree = 1 @@ -49,23 +50,25 @@ tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 pipeline_parallel_schedule = "Interleaved1F1B" -expert_parallel_degree = 8 +expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 [checkpoint] -enable = false +enable = true folder = "checkpoint" -interval = 10 +interval = 100 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" +initial_load_in_hf = true +initial_load_path = "/data/users/jianiw/model/Moonlight-16B-A3B" [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] -enable=true +enable=false components = ["loss"] # ["model", "loss"] [quantize.dense.float8] diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 97e7a30fad..d4fc35c6d9 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -25,7 +25,7 @@ hf_assets_path = "./assets/hf/DeepSeek-V3.1-Base" [optimizer] name = "AdamW" -lr = 2.2e-4 +lr = 2.2e-5 eps = 1e-8 [lr_scheduler] diff --git a/torchtitan/train.py b/torchtitan/train.py index c343279dda..6685948ce9 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -409,6 +409,9 @@ def batch_generator( def forward_backward_step( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor ) -> torch.Tensor: + + # torch.autograd.set_detect_anomaly(True) + model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -473,7 +476,34 @@ def forward_backward_step( loss = self.loss_fn(pred, labels) # need to free pred before bwd to avoid peaking memory del pred + # with torch.autograd.detect_anomaly(): loss.backward() + + # Check for NaN/Inf gradients after backward pass + print(f"[GRAD CHECK] Checking gradients after backward pass...") + nan_params = [] + inf_params = [] + total_params = 0 + total_with_grad = 0 + + for name, param in model_parts[0].named_parameters(): + total_params += 1 + if param.grad is not None: + total_with_grad += 1 + if torch.isnan(param.grad).any(): + nan_params.append(name) + print(f"[GRAD NaN] {name}: shape={param.grad.shape}, norm={param.grad.norm():.6f}") + elif torch.isinf(param.grad).any(): + inf_params.append(name) + print(f"[GRAD INF] {name}: shape={param.grad.shape}, norm={param.grad.norm():.6f}") + + print(f"[GRAD SUMMARY] Total params: {total_params}, with grad: {total_with_grad}") + print(f"[GRAD SUMMARY] NaN gradients: {len(nan_params)}, Inf gradients: {len(inf_params)}") + + if nan_params: + print(f"[GRAD ERROR] Parameters with NaN gradients: {nan_params[:10]}...") # Show first 10 + if inf_params: + print(f"[GRAD ERROR] Parameters with 
Inf gradients: {inf_params[:10]}...") # Show first 10 return loss From da1c1c62b561bd233c90d3e96f43a94a7d8aae14 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 2 Oct 2025 16:58:49 -0700 Subject: [PATCH 12/13] add torchtitan --- torchtitan/models/deepseek_v3/__init__.py | 1 + torchtitan/models/deepseek_v3/model/model.py | 14 +++++++------- .../train_configs/deepseek_v3_671b.toml | 2 +- torchtitan/train.py | 6 +++--- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 3adac87e3d..63581c4aa6 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -92,6 +92,7 @@ score_func="softmax", route_norm=False, score_before_experts=False, + use_grouped_mm=False, ), q_lora_rank=0, kv_lora_rank=512, diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index c9c6705a36..65fc841fee 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -313,28 +313,28 @@ def _layer_gradient_hook(self, module, grad_input, grad_output): grad_norm = param.grad.norm().item() # Check if norm overflowed to inf (but individual elements might be finite) if torch.isinf(torch.tensor(grad_norm)) and not has_inf: - print(f"[LAYER GRAD OVERFLOW] Layer {self.layer_id} - {name}: " + logger.info(f"[LAYER GRAD OVERFLOW] Layer {self.layer_id} - {name}: " f"norm overflow to inf, max_val={param.grad.abs().max().item():.6e}, " f"shape={param.grad.shape}") except: grad_norm = float('inf') - print(f"[LAYER GRAD ERROR] Layer {self.layer_id} - {name}: failed to compute norm") + logger.info(f"[LAYER GRAD ERROR] Layer {self.layer_id} - {name}: failed to compute norm") layer_grad_norms.append(grad_norm) if has_nan: nan_params.append(name) - print(f"[LAYER GRAD NaN] Layer {self.layer_id} - {name}: shape={param.grad.shape}") + logger.info(f"[LAYER GRAD NaN] Layer {self.layer_id} - {name}: shape={param.grad.shape}") elif has_inf: inf_params.append(name) - print(f"[LAYER GRAD INF] Layer {self.layer_id} - {name}: shape={param.grad.shape}") + logger.info(f"[LAYER GRAD INF] Layer {self.layer_id} - {name}: shape={param.grad.shape}") elif torch.isinf(torch.tensor(grad_norm)): - print(f"[LAYER GRAD LARGE] Layer {self.layer_id} - {name}: " + logger.info(f"[LAYER GRAD LARGE] Layer {self.layer_id} - {name}: " f"very large gradients, max={param.grad.abs().max().item():.6e}") else: - print(f"[LAYER GRAD] Layer {self.layer_id} - {name}: dtype={param.grad.dtype} (non-float grad)") + logger.info(f"[LAYER GRAD] Layer {self.layer_id} - {name}: dtype={param.grad.dtype} (non-float grad)") else: - print(f"[LAYER GRAD] Layer {self.layer_id} - {name}: grad is None") + logger.info(f"[LAYER GRAD] Layer {self.layer_id} - {name}: grad is None") # Compute and logger.info layer statistics if layer_grad_norms: diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index d4fc35c6d9..55b6737833 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -63,7 +63,7 @@ initial_load_in_hf = true initial_load_path = "/home/jianiw/tmp/mffuse/deepseek-v3/DeepSeek-V3.1-Base" [activation_checkpoint] -mode = "selective" # ["none", "selective", "full"] +mode = "full" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every 
positive int layer or 'op', ac based on ops policy [compile] diff --git a/torchtitan/train.py b/torchtitan/train.py index 6685948ce9..a91dc9262d 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -410,7 +410,7 @@ def forward_backward_step( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor ) -> torch.Tensor: - # torch.autograd.set_detect_anomaly(True) + torch.autograd.set_detect_anomaly(True) model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -476,8 +476,8 @@ def forward_backward_step( loss = self.loss_fn(pred, labels) # need to free pred before bwd to avoid peaking memory del pred - # with torch.autograd.detect_anomaly(): - loss.backward() + with torch.autograd.detect_anomaly(): + loss.backward() # Check for NaN/Inf gradients after backward pass print(f"[GRAD CHECK] Checking gradients after backward pass...") From c79c3f9e8a1e25bcc50fd3d3b996c01ad9640148 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Fri, 3 Oct 2025 15:17:38 -0700 Subject: [PATCH 13/13] add hooks --- torchtitan/models/deepseek_v3/__init__.py | 10 +-- torchtitan/models/deepseek_v3/model/model.py | 87 ++++++++---------- .../train_configs/deepseek_v3_16b.toml | 2 +- .../train_configs/deepseek_v3_671b.toml | 4 +- torchtitan/models/moe.py | 88 ++++++++++++++++++- torchtitan/train.py | 6 +- 6 files changed, 135 insertions(+), 62 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 63581c4aa6..b13490e0a3 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -100,8 +100,8 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( vocab_size=102400, @@ -135,7 +135,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=45, + n_layers=61, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( @@ -154,8 +154,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - # use_flex_attn=True, - # attn_mask_type="block_causal", + use_flex_attn=True, + attn_mask_type="block_causal", hf_weight_quantized=True, ), } diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 65fc841fee..14ea53fe5e 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -11,7 +11,7 @@ from torch import nn from torchtitan.models.attention import build_attention -from torchtitan.models.moe import FeedForward, MoE +from torchtitan.models.moe import FeedForward, MoE, create_tensor_hook from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs @@ -286,17 +286,13 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.layer_id = layer_id # Register backward hook to monitor gradients - self.register_full_backward_hook(self._layer_gradient_hook) - logger.info(f"[HOOK REGISTRATION] Layer {self.layer_id} TransformerBlock gradient hook registered") + # self.register_full_backward_hook(self._layer_gradient_hook) + # logger.info(f"[HOOK REGISTRATION] Layer {self.layer_id} TransformerBlock gradient hook registered") def _layer_gradient_hook(self, module, grad_input, grad_output): """Backward hook to monitor gradients of all parameters in this layer.""" logger.info(f"[LAYER GRAD HOOK] Layer {self.layer_id} TransformerBlock backward pass") - # Collect gradient statistics for this layer - layer_grad_norms = [] - 
nan_params = [] - inf_params = [] total_params = 0 # Check gradients for all named parameters in this layer @@ -308,50 +304,32 @@ def _layer_gradient_hook(self, module, grad_input, grad_output): has_nan = torch.isnan(param.grad).any().item() has_inf = torch.isinf(param.grad).any().item() - # Calculate norm safely - try: - grad_norm = param.grad.norm().item() - # Check if norm overflowed to inf (but individual elements might be finite) - if torch.isinf(torch.tensor(grad_norm)) and not has_inf: - logger.info(f"[LAYER GRAD OVERFLOW] Layer {self.layer_id} - {name}: " - f"norm overflow to inf, max_val={param.grad.abs().max().item():.6e}, " - f"shape={param.grad.shape}") - except: - grad_norm = float('inf') - logger.info(f"[LAYER GRAD ERROR] Layer {self.layer_id} - {name}: failed to compute norm") - - layer_grad_norms.append(grad_norm) - if has_nan: - nan_params.append(name) - logger.info(f"[LAYER GRAD NaN] Layer {self.layer_id} - {name}: shape={param.grad.shape}") + logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: " + f"norm=NaN, max=NaN, mean=NaN, shape={param.grad.shape}") elif has_inf: - inf_params.append(name) - logger.info(f"[LAYER GRAD INF] Layer {self.layer_id} - {name}: shape={param.grad.shape}") - elif torch.isinf(torch.tensor(grad_norm)): - logger.info(f"[LAYER GRAD LARGE] Layer {self.layer_id} - {name}: " - f"very large gradients, max={param.grad.abs().max().item():.6e}") + logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: " + f"norm=Inf, max=Inf, mean=Inf, shape={param.grad.shape}") + else: + # Calculate gradient statistics safely + try: + grad_norm = param.grad.norm().item() + grad_max = param.grad.abs().max().item() + grad_mean = param.grad.mean().item() + + logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: " + f"norm={grad_norm:.6e}, max={grad_max:.6e}, mean={grad_mean:.6e}, " + f"shape={param.grad.shape}") + except: + logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: " + f"failed to compute stats, shape={param.grad.shape}") else: - logger.info(f"[LAYER GRAD] Layer {self.layer_id} - {name}: dtype={param.grad.dtype} (non-float grad)") + logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: " + f"dtype={param.grad.dtype} (non-float grad)") else: - logger.info(f"[LAYER GRAD] Layer {self.layer_id} - {name}: grad is None") + logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: grad is None") - # Compute and logger.info layer statistics - if layer_grad_norms: - avg_grad_norm = sum(layer_grad_norms) / len(layer_grad_norms) - max_grad_norm = max(layer_grad_norms) - min_grad_norm = min(layer_grad_norms) - - logger.info(f"[LAYER GRAD SUMMARY] Layer {self.layer_id}: " - f"avg_norm={avg_grad_norm:.6f}, max_norm={max_grad_norm:.6f}, " - f"min_norm={min_grad_norm:.6f}, total_params={total_params}") - else: - logger.info(f"[LAYER GRAD SUMMARY] Layer {self.layer_id}: No floating point gradients found, total_params={total_params}") - - if nan_params: - logger.info(f"[LAYER GRAD ERROR] Layer {self.layer_id} NaN gradients: {nan_params}") - if inf_params: - logger.info(f"[LAYER GRAD ERROR] Layer {self.layer_id} Inf gradients: {inf_params}") + logger.info(f"[LAYER GRAD SUMMARY] Layer {self.layer_id}: total_params={total_params}") def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): """ @@ -364,11 +342,22 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): Returns: torch.Tensor: Output tensor with the same shape as the input. 
""" - x = x + self.attention(self.attention_norm(x), freqs_cis) + t1 = self.attention_norm(x) + # t1.register_hook(create_tensor_hook("t_after_attention_norm")) + x = x + self.attention(t1, freqs_cis) + # x.register_hook(create_tensor_hook("t_after_attn")) if self.moe_enabled: - x = x + self.moe(self.ffn_norm(x)) + t = self.ffn_norm(x) + # t.register_hook(create_tensor_hook("t_after_ffn_norm")) + x = x + self.moe(t) + # x.register_hook(create_tensor_hook("x_after_moe")) + else: - x = x + self.feed_forward(self.ffn_norm(x)) + t = self.ffn_norm(x) + # t.register_hook(create_tensor_hook("t_after_ffn_norm")) + x = x + self.feed_forward(t) + # x.register_hook(create_tensor_hook("x_after_feedforward")) + return x def init_weights(self, buffer_device: torch.device): diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 95e76180a6..e9d4f21550 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -38,7 +38,7 @@ min_lr_factor = 0.1 local_batch_size = 1 seq_len = 4096 max_norm = 1.0 # grad norm clipping -steps = 100 +steps = 3 dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) deterministic = true diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 55b6737833..4a63f599a5 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -25,7 +25,7 @@ hf_assets_path = "./assets/hf/DeepSeek-V3.1-Base" [optimizer] name = "AdamW" -lr = 2.2e-5 +lr = 2.2e-6 eps = 1e-8 [lr_scheduler] @@ -67,7 +67,7 @@ mode = "full" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] -enable = false +enable = true components = ["loss"] # ["model", "loss"] [quantize.dense.float8] diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index 9f519dc04e..cab150cc9b 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import functools from dataclasses import dataclass from typing import Literal @@ -12,6 +13,77 @@ from torch import nn from torchtitan.distributed.expert_parallel import expert_parallel +from torchtitan.tools.logging import logger + + +def _tensor_gradient_hook_fn(tensor_name: str, grad): + """ + Utility function for tensor gradient hooks that prints gradient statistics. + This function signature matches the format expected by tensor.register_hook(). 
+ + Args: + tensor_name (str): Name identifier for the tensor + grad (torch.Tensor): Gradient tensor + """ + if grad is None: + logger.info(f"[TENSOR GRAD] {tensor_name}: grad is None") + return + + if not (grad.dtype.is_floating_point or grad.dtype.is_complex): + logger.info(f"[TENSOR GRAD] {tensor_name}: dtype={grad.dtype} (non-float grad)") + return + + # Check for NaN and Inf elements + has_nan = torch.isnan(grad).any().item() + has_inf = torch.isinf(grad).any().item() + + if has_nan: + logger.info( + f"[TENSOR GRAD] {tensor_name}: " + f"mean=NaN, max=NaN, min=NaN, has_nan=True, has_inf={has_inf}, " + f"shape={grad.shape}" + ) + elif has_inf: + logger.info( + f"[TENSOR GRAD] {tensor_name}: " + f"mean=Inf, max=Inf, min=Inf, has_nan=False, has_inf=True, " + f"shape={grad.shape}" + ) + else: + # Calculate gradient statistics safely + try: + grad_mean = grad.mean().item() + grad_max = grad.max().item() + grad_min = grad.min().item() + + logger.info( + f"[TENSOR GRAD] {tensor_name}: " + f"mean={grad_mean:.6e}, max={grad_max:.6e}, min={grad_min:.6e}, " + f"has_nan=False, has_inf=False, shape={grad.shape}" + ) + except Exception as e: + logger.info( + f"[TENSOR GRAD] {tensor_name}: " + f"failed to compute stats: {e}, shape={grad.shape}" + ) + + +def create_tensor_hook(tensor_name: str): + """ + Utility function to create a tensor gradient hook using functools.partial. + This follows the pattern shown in the user's example. + + Args: + tensor_name (str): Name identifier for the tensor being hooked + + Returns: + Callable: Hook function that can be registered on a tensor using tensor.register_hook() + + Example usage: + hook_fn = create_tensor_hook("my_tensor") + tensor.register_hook(hook_fn) + """ + return functools.partial(_tensor_gradient_hook_fn, tensor_name) @dataclass @@ -370,14 +442,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: bs, slen, dim = x.shape x = x.view(-1, dim) - # top_scores and selected_experts_indices shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) + # Register hook on input tensor + # x.register_hook(create_tensor_hook("moe_input")) + + # Router forward pass - compute expert scores and routing ( top_scores, selected_experts_indices, num_tokens_per_expert, ) = self.router(x, self.expert_bias) + # Register hooks on router outputs for gradient monitoring + # top_scores.register_hook(create_tensor_hook("router_top_scores")) + # tokens_per_expert will be used to update the expert bias for load balancing. 
# and also to count the expert usage # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- @@ -407,21 +484,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # shape (bs*slen*top_k, dim) routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) + # routed_input.register_hook(create_tensor_hook("routed_input")) if self.score_before_experts: routed_input = ( routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) ).to(x.dtype) + # routed_input.register_hook(create_tensor_hook("routed_input_after_scoring")) # shape (bs*slen*top_k, dim) routed_output = self.experts(routed_input, num_tokens_per_expert) + # routed_output.register_hook(create_tensor_hook("experts_output")) # shared expert # Note: we execute the shared expert before scoring the output of the routed expert # to "implicitly" overlap the shared expert compute with token combine communication if self.shared_experts is not None: out = self.shared_experts(x) + # out.register_hook(create_tensor_hook("shared_experts_output")) else: out = torch.zeros_like(x) @@ -430,10 +511,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) ).to(x.dtype) + # routed_output.register_hook(create_tensor_hook("routed_output_after_scoring")) out = out.scatter_add( dim=0, index=token_indices_experts_sorted, src=routed_output ) + # out.register_hook(create_tensor_hook("moe_final_output")) + out = out.reshape(bs, slen, dim) return out diff --git a/torchtitan/train.py b/torchtitan/train.py index a91dc9262d..6685948ce9 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -410,7 +410,7 @@ def forward_backward_step( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor ) -> torch.Tensor: - torch.autograd.set_detect_anomaly(True) + # torch.autograd.set_detect_anomaly(True) model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -476,8 +476,8 @@ def forward_backward_step( loss = self.loss_fn(pred, labels) # need to free pred before bwd to avoid peaking memory del pred - with torch.autograd.detect_anomaly(): - loss.backward() + # with torch.autograd.detect_anomaly(): + loss.backward() # Check for NaN/Inf gradients after backward pass print(f"[GRAD CHECK] Checking gradients after backward pass...")
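The debugging changes in these patches lean on three standard PyTorch mechanisms. The first is the module-level hook behind the (now commented-out) TransformerBlock registration: nn.Module.register_full_backward_hook fires once gradients reach the module, at which point each parameter's .grad can be inspected. A minimal, self-contained sketch of that pattern, using an illustrative ToyBlock rather than the torchtitan model:

import logging

import torch
from torch import nn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def log_param_grads(module: nn.Module, grad_input, grad_output) -> None:
    # Runs during backward, after gradients w.r.t. this module's inputs exist.
    # Note: parameter .grad fields may still be accumulating at this point,
    # so treat the numbers as diagnostic rather than final.
    for name, param in module.named_parameters():
        if param.grad is None:
            logger.info("%s: grad is None", name)
        elif torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
            logger.info("%s: NaN/Inf gradient, shape=%s", name, tuple(param.grad.shape))
        else:
            logger.info(
                "%s: norm=%.6e max=%.6e",
                name, param.grad.norm().item(), param.grad.abs().max().item(),
            )


class ToyBlock(nn.Module):  # illustrative stand-in for a transformer block
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


block = ToyBlock()
block.register_full_backward_hook(log_param_grads)
block(torch.randn(2, 8)).sum().backward()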
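create_tensor_hook in moe.py applies the same idea at tensor granularity: functools.partial binds a human-readable name to a generic statistics function, and the resulting callable is attached to an intermediate activation with tensor.register_hook so it fires when that tensor's gradient is produced. A self-contained sketch (the tensor names here are illustrative, not the MoE internals):

import functools

import torch


def grad_stats(tag: str, grad: torch.Tensor) -> None:
    # Autograd calls this with the gradient flowing into the hooked tensor.
    if torch.isnan(grad).any() or torch.isinf(grad).any():
        print(f"[TENSOR GRAD] {tag}: NaN/Inf, shape={tuple(grad.shape)}")
    else:
        print(
            f"[TENSOR GRAD] {tag}: mean={grad.mean().item():.6e} "
            f"max={grad.max().item():.6e} min={grad.min().item():.6e}"
        )


def make_hook(tag: str):
    # Bind the name now; autograd supplies only the gradient later.
    return functools.partial(grad_stats, tag)


x = torch.randn(4, 4, requires_grad=True)
hidden = x.tanh()
hidden.register_hook(make_hook("hidden"))  # fires during backward
hidden.sum().backward()

Because each hook closes over nothing but a string, leaving the register_hook calls commented out in the forward pass, as the patch does, is the cheapest way to toggle this instrumentation on and off.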
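Finally, patches 12 and 13 flip torch.autograd anomaly detection on and then back off around loss.backward(). Anomaly mode makes any backward computation that produces NaN raise immediately, and if the forward pass also runs under it, the error carries the traceback of the offending forward op; it is costly, which is presumably why the last patch reverts it. A small sketch of both scopes, on a generic model and loss:

import torch

model = torch.nn.Linear(4, 1)
inputs = torch.randn(8, 4)

# Scope 1: global toggle, as in the earlier patch.
torch.autograd.set_detect_anomaly(True)
loss = model(inputs).pow(2).mean()
loss.backward()
torch.autograd.set_detect_anomaly(False)

# Scope 2: context manager wrapping only the backward pass.
# NaNs produced in backward still raise, but forward tracebacks are only
# recorded when the forward pass itself runs inside the context.
model.zero_grad()
loss = model(inputs).pow(2).mean()
with torch.autograd.detect_anomaly():
    loss.backward()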