
Conversation

@tushar00jain (Contributor) commented on Feb 2, 2026

Summary:

  • extract the logic in train.py that differed for ft into separate functions
  • override these functions in ft's train.py

Test Plan:
Run lighthouse and two replicas

```bash
RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 2 --quorum_tick_ms 100 --join_timeout_ms 10000

TRAIN_FILE=torchtitan.experiments.ft.train CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=0 NGPU=1 ./run_train.sh --parallelism.data_parallel_shard_degree=1 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0

TRAIN_FILE=torchtitan.experiments.ft.train CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=1 NGPU=1 ./run_train.sh --parallelism.data_parallel_shard_degree=1 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1
```

```python
    world_size=world_size,
)

def get_dp_info(self, batch_degree: int, batch_rank: int) -> tuple[int, int]:
```
Contributor:
I feel the whole point is for the ft trainer not to influence how the main trainer API looks. These functions are there purely so that the ft trainer can override them, which sounds like the opposite.

The point is that ft and main cannot share all the code, so they probably shouldn't. E.g., the Flux trainer has its own logic. I feel the ft trainer can just rewrite the code instead of trying to share as much as possible.

@tushar00jain (Contributor, Author) replied on Feb 5, 2026:

batch_degree is used in a lot of places, which would require pretty much copy-pasting all of init.py. Let me know if you'd still prefer that.

Contributor:

IMO, whether a change makes sense or not depends on whether the function is too intrusive. This one doesn't make sense because the function is not needed for the regular trainer. I'm okay if this method contains the original logic from init():

```python
def get_batch_dim_info(parallel_dims):
    if parallel_dims.dp_enabled:
        batch_mesh = parallel_dims.get_mesh("batch")
        batch_degree, batch_rank = batch_mesh.size(), batch_mesh.get_local_rank()
    else:
        batch_degree, batch_rank = 1, 0
    return batch_degree, batch_rank
```

This is just my thought.
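For context, a minimal sketch of the override pattern being debated, assuming hypothetical `Trainer`/`FTTrainer` class names and FT attributes (`group_size`, `replica_id`); the arithmetic in the FT override is an illustrative assumption, not necessarily what this PR does:

```python
# Hedged sketch only: class names and the FT math below are illustrative
# assumptions, not the exact code in this PR.


class Trainer:
    def get_dp_info(self, batch_degree: int, batch_rank: int) -> tuple[int, int]:
        # Regular trainer: the data-parallel view is just the batch mesh.
        return batch_degree, batch_rank


class FTTrainer(Trainer):
    def __init__(self, group_size: int, replica_id: int) -> None:
        self.group_size = group_size
        self.replica_id = replica_id

    def get_dp_info(self, batch_degree: int, batch_rank: int) -> tuple[int, int]:
        # FT trainer: each replica group contributes batch_degree ranks, so the
        # effective DP degree grows by group_size and the rank is offset by
        # this replica's position.
        return batch_degree * self.group_size, batch_degree * self.replica_id + batch_rank


if __name__ == "__main__":
    print(Trainer().get_dp_info(4, 1))                               # (4, 1)
    print(FTTrainer(group_size=2, replica_id=1).get_dp_info(4, 1))   # (8, 5)
```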


@fegin (Contributor) left a comment:
Overall, this PR introduces two extra methods, compute_global_losses and get_dp_info, while removing all the FT logic from the main trainer. I think this is a balanced compromise. @tianyu-l any thoughts?
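To make the shape of that compromise concrete, here is a hedged sketch of how the two hooks could be split between the main and FT trainers. Only the method names `compute_global_losses` and `get_dp_info` come from the PR; the `ft_pg` attribute and the method bodies are assumptions for illustration:

```python
# Illustrative sketch; not the PR's actual implementation.
import torch
import torch.distributed as dist


class Trainer:
    def compute_global_losses(self, loss: torch.Tensor) -> torch.Tensor:
        # Main trainer: average the loss across the default process group.
        dist.all_reduce(loss, op=dist.ReduceOp.AVG)
        return loss


class FTTrainer(Trainer):
    def __init__(self, ft_pg: dist.ProcessGroup) -> None:
        self.ft_pg = ft_pg  # torchft-managed process group (assumed attribute)

    def compute_global_losses(self, loss: torch.Tensor) -> torch.Tensor:
        # FT trainer: additionally average across replica groups through the
        # torchft process group so the logged loss covers the whole quorum.
        loss = super().compute_global_losses(loss)
        dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=self.ft_pg)
        return loss
```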

```python
)

def get_dp_info(self) -> tuple[int, int]:
    """"""
```
Contributor:
Forgot to delete?
