diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 4fe3f01a66..ab31f603c7 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -293,6 +293,14 @@ def init_distributed(
     base_folder: str = "",
     ranks: list[int] | None = None,
 ) -> int:
+    # Skip initialization if already initialized
+    if torch.distributed.is_initialized():
+        logger.warning(
+            "torch.distributed is already initialized. Skipping init_distributed. "
+            "The provided comm_config and other settings will not take effect."
+        )
+        return torch.distributed.get_world_size()
+
     if comm_config.mode in ("fake_backend", "local_tensor"):
         ngpu_str = os.environ.get("NGPU")
         if ngpu_str is None:
diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index 32130b6d5e..e034224019 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -185,11 +185,16 @@ def forward(
         *,
         scale: float | None = None,
         enable_gqa: bool = False,
-        is_casual: bool = True,
+        is_causal: bool = True,
     ) -> torch.Tensor:
         with sdpa_kernel(self.sdpa_backends, set_priority=True):
             return F.scaled_dot_product_attention(
-                q, k, v, scale=scale, is_causal=is_casual, enable_gqa=enable_gqa
+                q,
+                k,
+                v,
+                scale=scale,
+                is_causal=is_causal,
+                enable_gqa=enable_gqa,
             )
 
 
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 9378d742e3..ee6519ea3d 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -46,7 +46,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
 
     # swappable training components in TrainSpec
     tokenizer: train_spec_module.BaseTokenizer | None
-    dataloader: train_spec_module.BaseDataLoader
+    dataloader: train_spec_module.BaseDataLoader | None
     # TODO: we should make this list[ModelProtocol] but this will affect many components.
     # will do this in a separate PR
     model_parts: list[torch.nn.Module]
@@ -128,11 +128,15 @@ def __init__(self, job_config: JobConfig):
             else None
         )
 
-        self.dataloader = self.train_spec.build_dataloader_fn(
-            dp_world_size=batch_degree,
-            dp_rank=batch_rank,
-            tokenizer=self.tokenizer,
-            job_config=job_config,
+        self.dataloader = (
+            self.train_spec.build_dataloader_fn(
+                dp_world_size=batch_degree,
+                dp_rank=batch_rank,
+                tokenizer=self.tokenizer,
+                job_config=job_config,
+            )
+            if self.train_spec.build_dataloader_fn is not None
+            else None
         )
 
         # build model (using meta init)
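
The first hunk makes `init_distributed` idempotent: if a process group already exists, it warns and returns the current world size instead of re-initializing. The standalone sketch below illustrates the same guard pattern outside of torchtitan; `init_once` is a hypothetical helper, not the project's `init_distributed`, and it assumes a single-process gloo setup purely for demonstration.

```python
import os

import torch.distributed as dist


def init_once(world_size: int = 1, rank: int = 0) -> int:
    """Hypothetical helper: initialize the default process group only once."""
    if dist.is_initialized():
        # Mirrors the patch: skip re-init and report the existing world size;
        # any settings passed on this second call are ignored.
        return dist.get_world_size()
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    return dist.get_world_size()


if __name__ == "__main__":
    first = init_once()   # creates the process group
    second = init_once()  # no-op: returns the existing group's world size
    assert first == second == 1
    dist.destroy_process_group()
```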
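
The attention.py hunk renames the misspelled `is_casual` keyword to `is_causal`, so the flag is forwarded to `F.scaled_dot_product_attention` under its correct name. A minimal call with causal masking and grouped-query attention looks like the sketch below; it assumes a PyTorch version recent enough to expose the `enable_gqa` argument (2.5+), and the tensor shapes are arbitrary.

```python
import torch
import torch.nn.functional as F

# (batch, num_heads, seq_len, head_dim): 8 query heads share 2 key/value heads (GQA).
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 2, 16, 64)
v = torch.randn(1, 2, 16, 64)

# is_causal applies a lower-triangular mask; enable_gqa broadcasts the 2 KV heads
# across the 8 query heads instead of requiring matching head counts.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```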