8 changes: 7 additions & 1 deletion torchtitan/components/checkpoint.py
@@ -431,10 +431,16 @@ def dcp_load(
self.sd_adapter is not None
), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
hf_state_dict = self.sd_adapter.to_hf(state_dict)
hf_storage_reader = self.sd_adapter.get_hf_storage_reader(checkpoint_id)

Contributor: Nice, this is cleaner!

begin_load = time.monotonic()
logger.info(f"Starting dcp.load with {hf_storage_reader}")
dcp.load(
hf_state_dict,
storage_reader=HuggingFaceStorageReader(path=checkpoint_id),
storage_reader=hf_storage_reader,
)
logger.info(
f"dcp.load with HuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds"
)

state_dict = self.sd_adapter.from_hf(hf_state_dict)
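
For context, the change above lets the state-dict adapter choose the storage reader instead of hard-coding `HuggingFaceStorageReader`. A minimal sketch of the resulting call pattern, assuming only what the hunk shows; the `adapter` object and the `load_hf_checkpoint` wrapper are illustrative, not part of the PR:

```python
# Sketch only: dcp.load through an adapter-provided storage reader,
# with the same timing log added in this PR.
import time

import torch.distributed.checkpoint as dcp


def load_hf_checkpoint(adapter, hf_state_dict, checkpoint_id):
    # The adapter decides which reader to use (a plain HuggingFaceStorageReader
    # by default; subclasses may return a specialized reader).
    hf_storage_reader = adapter.get_hf_storage_reader(checkpoint_id)

    begin_load = time.monotonic()
    dcp.load(hf_state_dict, storage_reader=hf_storage_reader)
    print(f"dcp.load completed in {time.monotonic() - begin_load:.2f} seconds")
```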
3 changes: 2 additions & 1 deletion torchtitan/distributed/activation_checkpoint.py
@@ -103,7 +103,8 @@ def _custom_policy(ctx, func, *args, **kwargs):
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
if args[1].shape in mm_recompute_shapes:
return CheckpointPolicy.PREFER_RECOMPUTE
# return CheckpointPolicy.PREFER_RECOMPUTE
return CheckpointPolicy.MUST_SAVE # TODO(jianiw): testing
meta[mm_count_key] += 1
# Saves output of all compute ops, except every second mm
to_save = func in op_sac_save_list and not (
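
For context (not part of the diff), a policy function like `_custom_policy` plugs into PyTorch's selective activation checkpointing. A minimal, self-contained sketch of that wiring, assuming the public `torch.utils.checkpoint` API and omitting torchtitan's per-module mm-count bookkeeping:

```python
# Sketch only: a custom selective-activation-checkpointing policy.
# MUST_SAVE keeps an op's output for backward; PREFER_RECOMPUTE drops it
# and recomputes it during the backward pass.
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def _policy(ctx, func, *args, **kwargs):
    if func == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE  # e.g. always save matmul outputs
    return CheckpointPolicy.PREFER_RECOMPUTE


def checkpointed(module, *inputs):
    context_fn = partial(create_selective_checkpoint_contexts, _policy)
    return checkpoint(module, *inputs, use_reentrant=False, context_fn=context_fn)
```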
1 change: 1 addition & 0 deletions torchtitan/distributed/expert_parallel.py
@@ -272,6 +272,7 @@ def wrapper(
num_ep_ranks,
padded_max_len,
TOKEN_GROUP_ALIGN_SIZE_M,
use_cpu=True
)

x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
5 changes: 5 additions & 0 deletions torchtitan/experiments/qwen3/model/state_dict_adapter.py
@@ -15,6 +15,8 @@
import re
from typing import Any

from torch.distributed.checkpoint import HuggingFaceStorageReader

from torchtitan.protocols.state_dict_adapter import StateDictAdapter

from .args import Qwen3ModelArgs
@@ -45,6 +47,9 @@ def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None):
"lm_head.weight": "output.weight",
}

def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
return HuggingFaceStorageReader(path)

def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:

to_hf_map = {v: k for k, v in self.from_hf_map.items()}
Empty file.
10 changes: 6 additions & 4 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -78,9 +78,9 @@
attn_mask_type="block_causal",
),
"16B": DeepSeekV3ModelArgs(
vocab_size=102400,
vocab_size=163840,
dim=2048,
inter_dim=10944,
inter_dim=11264,
moe_inter_dim=1408,
n_layers=27,
n_dense_layers=1,
@@ -92,15 +92,16 @@
score_func="softmax",
route_norm=False,
score_before_experts=False,
use_grouped_mm=False,
),
q_lora_rank=0,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
use_flex_attn=True,
attn_mask_type="block_causal",
# use_flex_attn=True,
# attn_mask_type="block_causal",
),
"236B": DeepSeekV3ModelArgs(
vocab_size=102400,
@@ -155,6 +156,7 @@
v_head_dim=128,
use_flex_attn=True,
attn_mask_type="block_causal",
hf_weight_quantized=True,
),
}

3 changes: 3 additions & 0 deletions torchtitan/models/deepseek_v3/model/args.py
@@ -86,6 +86,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
beta_slow: int = 1
mscale: float = 1.0

# HF checkpoint args
hf_weight_quantized: bool = False

def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
seq_len = job_config.training.seq_len
if seq_len > self.max_seq_len:
66 changes: 62 additions & 4 deletions torchtitan/models/deepseek_v3/model/model.py
@@ -11,10 +11,11 @@
from torch import nn

from torchtitan.models.attention import build_attention
from torchtitan.models.moe import FeedForward, MoE
from torchtitan.models.moe import FeedForward, MoE, create_tensor_hook
from torchtitan.protocols.train_spec import ModelProtocol

from .args import DeepSeekV3ModelArgs
from torchtitan.tools.logging import logger


# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
@@ -284,6 +285,52 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs):
self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
self.layer_id = layer_id

# Register backward hook to monitor gradients
# self.register_full_backward_hook(self._layer_gradient_hook)
# logger.info(f"[HOOK REGISTRATION] Layer {self.layer_id} TransformerBlock gradient hook registered")

def _layer_gradient_hook(self, module, grad_input, grad_output):
"""Backward hook to monitor gradients of all parameters in this layer."""
logger.info(f"[LAYER GRAD HOOK] Layer {self.layer_id} TransformerBlock backward pass")

total_params = 0

# Check gradients for all named parameters in this layer
for name, param in self.named_parameters():
total_params += 1
if param.grad is not None:
if param.grad.dtype.is_floating_point or param.grad.dtype.is_complex:
# Check for NaN and Inf elements first
has_nan = torch.isnan(param.grad).any().item()
has_inf = torch.isinf(param.grad).any().item()

if has_nan:
logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: "
f"norm=NaN, max=NaN, mean=NaN, shape={param.grad.shape}")
elif has_inf:
logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: "
f"norm=Inf, max=Inf, mean=Inf, shape={param.grad.shape}")
else:
# Calculate gradient statistics safely
try:
grad_norm = param.grad.norm().item()
grad_max = param.grad.abs().max().item()
grad_mean = param.grad.mean().item()

logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: "
f"norm={grad_norm:.6e}, max={grad_max:.6e}, mean={grad_mean:.6e}, "
f"shape={param.grad.shape}")
except Exception:
logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: "
f"failed to compute stats, shape={param.grad.shape}")
else:
logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: "
f"dtype={param.grad.dtype} (non-float grad)")
else:
logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: grad is None")

logger.info(f"[LAYER GRAD SUMMARY] Layer {self.layer_id}: total_params={total_params}")

def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
"""
Forward pass for the Transformer block.
@@ -295,11 +342,22 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
Returns:
torch.Tensor: Output tensor with the same shape as the input.
"""
x = x + self.attention(self.attention_norm(x), freqs_cis)
t1 = self.attention_norm(x)
# t1.register_hook(create_tensor_hook("t_after_attention_norm"))
x = x + self.attention(t1, freqs_cis)
# x.register_hook(create_tensor_hook("t_after_attn"))
if self.moe_enabled:
x = x + self.moe(self.ffn_norm(x))
t = self.ffn_norm(x)
# t.register_hook(create_tensor_hook("t_after_ffn_norm"))
x = x + self.moe(t)
# x.register_hook(create_tensor_hook("x_after_moe"))

else:
x = x + self.feed_forward(self.ffn_norm(x))
t = self.ffn_norm(x)
# t.register_hook(create_tensor_hook("t_after_ffn_norm"))
x = x + self.feed_forward(t)
# x.register_hook(create_tensor_hook("x_after_feedforward"))

return x

def init_weights(self, buffer_device: torch.device):
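
`create_tensor_hook` is imported from `torchtitan.models.moe` above, but its definition is not shown in this diff. Purely as an assumption about what such a factory does, here is a minimal sketch of a tensor-gradient debugging hook in the same spirit:

```python
# Hypothetical sketch of a create_tensor_hook-style factory; the real
# implementation in torchtitan.models.moe may differ. The returned hook is
# meant to be attached with tensor.register_hook(...) to log gradient stats.
import torch

from torchtitan.tools.logging import logger


def create_tensor_hook(name: str):
    def hook(grad: torch.Tensor) -> None:
        if torch.isnan(grad).any() or torch.isinf(grad).any():
            logger.info(f"[TENSOR GRAD] {name}: NaN/Inf detected, shape={tuple(grad.shape)}")
        else:
            logger.info(
                f"[TENSOR GRAD] {name}: norm={grad.norm().item():.6e}, "
                f"max={grad.abs().max().item():.6e}, shape={tuple(grad.shape)}"
            )

    return hook
```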
73 changes: 0 additions & 73 deletions torchtitan/models/deepseek_v3/model/quantization.py

This file was deleted.
