8 changes: 8 additions & 0 deletions torchtitan/distributed/utils.py
@@ -293,6 +293,14 @@ def init_distributed(
     base_folder: str = "",
     ranks: list[int] | None = None,
 ) -> int:
+    # Skip initialization if already initialized
+    if torch.distributed.is_initialized():
+        logger.warning(
+            "torch.distributed is already initialized. Skipping init_distributed. "
+            "The provided comm_config and other settings will not take effect."
+        )
+        return torch.distributed.get_world_size()
+
Contributor:
If we go down this path, a lot of the settings in this function / config won't take effect. Shall we add a warning for users?

Contributor Author:
Makes sense, will add a warning. But I think this also means the user has initialized the distributed env somewhere else with their own settings.


     if comm_config.mode in ("fake_backend", "local_tensor"):
         ngpu_str = os.environ.get("NGPU")
         if ngpu_str is None:
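For context on the thread above, a minimal sketch (not part of the PR) of the scenario the author describes: the caller has already created the process group with its own settings, so the new early return in init_distributed fires, warns, and reports the existing world size. The backend and rendezvous values below are illustrative assumptions.

```python
import os
import torch

# Hypothetical caller-side setup: a single-process gloo group, purely for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1)

# If torchtitan's init_distributed(...) runs after this, the new guard sees
# torch.distributed.is_initialized() == True, logs the warning that comm_config
# and the other settings are ignored, and returns the existing world size.
assert torch.distributed.is_initialized()
print(torch.distributed.get_world_size())  # -> 1
```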
9 changes: 7 additions & 2 deletions torchtitan/models/attention.py
@@ -185,11 +185,16 @@ def forward(
         *,
         scale: float | None = None,
         enable_gqa: bool = False,
-        is_casual: bool = True,
+        is_causal: bool = True,
     ) -> torch.Tensor:
         with sdpa_kernel(self.sdpa_backends, set_priority=True):
             return F.scaled_dot_product_attention(
-                q, k, v, scale=scale, is_causal=is_casual, enable_gqa=enable_gqa
+                q,
+                k,
+                v,
+                scale=scale,
+                is_causal=is_causal,
+                enable_gqa=enable_gqa,
             )


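As a quick illustration of the fix above (the keyword is is_causal, not is_casual), a standalone call to F.scaled_dot_product_attention with the corrected argument; the tensor shapes are arbitrary examples, not values used in torchtitan.

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- arbitrary example shapes.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Causal masking is applied because is_causal=True; enable_gqa stays False
# since q and k/v use the same number of heads here.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=False)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```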
14 changes: 9 additions & 5 deletions torchtitan/train.py
@@ -128,11 +128,15 @@ def __init__(self, job_config: JobConfig):
             else None
         )
 
-        self.dataloader = self.train_spec.build_dataloader_fn(
-            dp_world_size=batch_degree,
-            dp_rank=batch_rank,
-            tokenizer=self.tokenizer,
-            job_config=job_config,
+        self.dataloader = (
+            self.train_spec.build_dataloader_fn(
+                dp_world_size=batch_degree,
+                dp_rank=batch_rank,
+                tokenizer=self.tokenizer,
+                job_config=job_config,
+            )
+            if self.train_spec.build_dataloader_fn is not None
Contributor (@tianyu-l, Feb 10, 2026):
According to https://github.com/pytorch/torchtitan/blob/main/torchtitan/protocols/train_spec.py#L51, it can't be None.

I'm OK with the type change, but you'll need to assert not None in torchtitan before it's used?

Contributor Author:
Since we are using verl's dataloader, I don't want to initialize titan's dataloader (otherwise I would hit an error loading the c4 dataset). I did a hack here: https://github.com/verl-project/verl/pull/5051/changes#diff-f658afe18d14b480f4067f7544fbdb0ef6962a20ef3b5f5d0c709ae31e91809dR101

+            else None
         )
 
         # build model (using meta init)
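To make the thread above concrete, a hedged sketch of the pattern the author describes: setting the train spec's build_dataloader_fn to None so the new guard in Trainer.__init__ skips building torchtitan's dataloader and an external (e.g. verl) dataloader is supplied instead. The helper name is hypothetical, and this assumes the spec object allows attribute assignment; the actual verl-side hack lives in the linked PR and is not reproduced here.

```python
# Hypothetical helper, for illustration only (assumes train_spec is mutable).
def disable_titan_dataloader(train_spec):
    # With build_dataloader_fn set to None, the new
    # `if self.train_spec.build_dataloader_fn is not None` check in train.py
    # falls through to `else None`, so self.dataloader stays None and the
    # caller can plug in an external dataloader.
    train_spec.build_dataloader_fn = None
    return train_spec
```

Per the reviewer's note, any torchtitan code path that still calls build_dataloader_fn would then need an explicit not-None assert before use.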