From 072a12cf69c7cb982eebd362e5f616cfc37cff9e Mon Sep 17 00:00:00 2001
From: Simon Fan
Date: Fri, 6 Feb 2026 15:31:43 -0800
Subject: [PATCH] Add run-to-run determinism testing to H100 CI

This adds automatic run-to-run determinism verification for H100
integration tests. Tests marked with `determinism_test=True` run twice
with identical configuration and deterministic flags, and the losses from
the two runs are compared; they must match exactly.

The core loss extraction logic is factored into
`torchtitan/tools/loss_utils.py` and shared between the integration test
runner and the existing `loss_compare.py` script. The scripts directory is
now a package, so the tool can be invoked as a module via
`python -m scripts.loss_compare` and can use clean absolute imports.

The Float8 and HSDP+CP+compile+Float8 tests in the H100 suite are enabled
for determinism testing (CUDA only).

The `--run-to-run-determinism` flag in loss_compare.py now explicitly
validates that no test-specific options are provided, raising a ValueError
if they are.

Co-authored-by: Claude
stack-info: PR: https://github.com/pytorch/torchtitan/pull/2339, branch: xmfan/stack/11
---
 .../integration_test_8gpu_features.yaml       |  6 +-
 .../workflows/integration_test_8gpu_h100.yaml |  2 +-
 CLAUDE.md                                     | 10 +++
 scripts/__init__.py                           |  5 ++
 scripts/loss_compare.py                       | 58 +++++++-----
 tests/integration_tests/__init__.py           |  1 +
 tests/integration_tests/h100.py               |  2 +
 tests/integration_tests/run_tests.py          | 85 ++++++++++++++++++
 torchtitan/tools/loss_utils.py                | 90 +++++++++++++++++++
 9 files changed, 235 insertions(+), 24 deletions(-)
 create mode 100644 CLAUDE.md
 create mode 100644 scripts/__init__.py
 create mode 100644 torchtitan/tools/loss_utils.py
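For orientation before the diffs: the new check boils down to running the
same deterministic training command twice and requiring bit-identical
losses. A minimal sketch in terms of the two helpers this patch adds in
torchtitan/tools/loss_utils.py (log file names are illustrative):

    from torchtitan.tools.loss_utils import compare_losses, extract_losses_from_log

    # After running the same deterministic training command twice and
    # teeing the output to run1.log and run2.log (names illustrative):
    losses1 = extract_losses_from_log("run1.log")
    losses2 = extract_losses_from_log("run2.log")
    ok, message = compare_losses(losses1, losses2, "run1", "run2")
    assert ok, message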
diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml
index 4e094d4eb1..014382041a 100644
--- a/.github/workflows/integration_test_8gpu_features.yaml
+++ b/.github/workflows/integration_test_8gpu_features.yaml
@@ -91,14 +91,14 @@ jobs:
             exit 1
           fi
 
-          python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --export-result="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs/result.txt" --steps=100
+          python -m scripts.loss_compare . . --baseline-options="${baseline_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --export-result="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs/result.txt" --steps=100
 
           echo "Checking FSDP8 the first step loss is the same as FSDP2HSDP4"
-          python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --steps=1
+          python -m scripts.loss_compare . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --steps=1
           rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*
 
           echo "Checking FSDP8 loss from a new run vs. FSDP8 loss from text file parity"
-          python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --import-result="${LOSS_FILE}" --assert-equal --steps=100
+          python -m scripts.loss_compare . . --baseline-options="${baseline_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --import-result="${LOSS_FILE}" --assert-equal --steps=100
           rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*
 
           python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8
diff --git a/.github/workflows/integration_test_8gpu_h100.yaml b/.github/workflows/integration_test_8gpu_h100.yaml
index 87da5b47d0..15bb13583e 100644
--- a/.github/workflows/integration_test_8gpu_h100.yaml
+++ b/.github/workflows/integration_test_8gpu_h100.yaml
@@ -71,5 +71,5 @@ jobs:
           sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded"
 
           # Enable CPP stacktraces for debugging symmetric memory initialization errors.
-          TORCH_SHOW_CPP_STACKTRACES=1 python -m tests.integration_tests.run_tests --test_suite h100 $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8
+          TORCH_SHOW_CPP_STACKTRACES=1 python -m tests.integration_tests.run_tests --test_suite h100 --gpu_arch_type ${{ matrix.gpu-arch-type }} $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8
           rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000000..627b7ce07d
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,10 @@
+# Commit messages
+
+Don't commit unless the user explicitly asks you to.
+
+When writing a commit message, don't make a bullet list of the individual
+changes. Instead, if the PR is large, explain the order in which to review
+the changes (e.g., the logical progression); if it's short, omit the list.
+
+Disclose that the PR was authored with Claude.
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000000..2e41cd717f
--- /dev/null
+++ b/scripts/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
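Note: making scripts/ a package above is what lets CI call
python -m scripts.loss_compare, and it also lets loss_compare.py pull the
shared helper from torchtitan with a plain absolute import instead of
keeping its own copy, as the next diff shows:

    # One-line sketch of the import loss_compare.py switches to:
    from torchtitan.tools.loss_utils import extract_losses_from_log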
""" + # Skip identical settings check for run-to-run determinism testing + if run_to_run_determinism: + return False # Not baseline-only mode + # Validate that we are comparing different settings commits_differ = baseline_commit != test_commit configs_differ = baseline_config != test_config @@ -224,8 +231,8 @@ def validate_arguments( "or options" ) log_print( - " Or use --import-result with --assert-equal " - "or --export-result to run baseline-only mode" + " Or use --import-result with --assert-equal, " + "--export-result, or --run-to-run-determinism" ) sys.exit(1) @@ -511,24 +518,6 @@ def run_training( # ============================================================================= -def extract_losses_from_log(log_file: str) -> dict[int, float]: - """Extract step and loss pairs from a log file.""" - losses = {} - step_loss_pattern = re.compile(r"step:\s*(\d+)\s*loss:\s*(\d+\.\d+)") - ansi_escape = re.compile(r"\x1b\[[0-9;]*m") - - with open(log_file, "r") as f: - for line in f: - # Strip ANSI codes before matching - clean_line = ansi_escape.sub("", line) - match = step_loss_pattern.search(clean_line) - if match: - step, loss = match.groups() - losses[int(step)] = float(loss) - - return losses - - def read_losses_from_file(loss_file: str) -> dict[int, float]: """Read losses from a processed loss file.""" losses = {} @@ -1002,9 +991,37 @@ def parse_arguments() -> argparse.Namespace: default=8, help="Number of GPUs for test run (default: 8)", ) + parser.add_argument( + "--run-to-run-determinism", + action="store_true", + help=( + "Test run-to-run determinism by running the same configuration twice. " + "Implies --assert-equal. Only baseline options should be provided; " + "test-specific options (--test-config, --test-options, etc.) are not allowed." + ), + ) args = parser.parse_args() + # Handle run-to-run determinism mode (must be before defaults are set) + if args.run_to_run_determinism: + # Validate that no test-specific options are provided + has_test_opts = ( + args.test_config + or args.test_options + or args.test_train_file + or args.test_ngpus != args.baseline_ngpus + ) + if has_test_opts: + raise ValueError( + "--run-to-run-determinism cannot be used with test-specific options " + "(--test-config, --test-options, --test-train-file, --test-ngpus)" + ) + + # Force assert_equal and copy baseline options to test + args.assert_equal = True + args.test_options = args.baseline_options + # Set default values if not provided if not args.test_config: args.test_config = args.baseline_config @@ -1090,6 +1107,7 @@ def main() -> None: args.assert_equal, args.export_result, args.import_result, + args.run_to_run_determinism, ) # Setup environment diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py index 7d676090e0..0833e34546 100644 --- a/tests/integration_tests/__init__.py +++ b/tests/integration_tests/__init__.py @@ -24,6 +24,7 @@ class OverrideDefinitions: ngpu: int = 4 disabled: bool = False skip_rocm_test: bool = False + determinism_test: bool = False # Run twice and verify losses are identical def __repr__(self): return self.test_descr diff --git a/tests/integration_tests/h100.py b/tests/integration_tests/h100.py index cd75548e56..c1850f4459 100755 --- a/tests/integration_tests/h100.py +++ b/tests/integration_tests/h100.py @@ -42,6 +42,7 @@ def build_h100_tests_list() -> list[OverrideDefinitions]: ], "Float8 test", "float8", + determinism_test=True, ), # TODO: re-enable this test once the async TP issue is fixed OverrideDefinitions( @@ -77,6 +78,7 @@ def 
diff --git a/tests/integration_tests/h100.py b/tests/integration_tests/h100.py
index cd75548e56..c1850f4459 100755
--- a/tests/integration_tests/h100.py
+++ b/tests/integration_tests/h100.py
@@ -42,6 +42,7 @@ def build_h100_tests_list() -> list[OverrideDefinitions]:
             ],
             "Float8 test",
             "float8",
+            determinism_test=True,
         ),
         # TODO: re-enable this test once the async TP issue is fixed
         OverrideDefinitions(
@@ -77,6 +78,7 @@ def build_h100_tests_list() -> list[OverrideDefinitions]:
             "HSDP+CP+torch.compile+Float8",
            "hsdp+cp+compile+float8",
             ngpu=8,
+            determinism_test=True,
         ),
     ]
     return integration_tests_flavors
diff --git a/tests/integration_tests/run_tests.py b/tests/integration_tests/run_tests.py
index 77851bd4a0..6ff0895f48 100644
--- a/tests/integration_tests/run_tests.py
+++ b/tests/integration_tests/run_tests.py
@@ -7,9 +7,11 @@
 import argparse
 import os
 import subprocess
+import tempfile
 import time
 
 from torchtitan.tools.logging import logger
+from torchtitan.tools.loss_utils import compare_losses, extract_losses_from_log
 
 from tests.integration_tests import OverrideDefinitions
 
@@ -71,6 +73,82 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
     )
 
 
+def run_determinism_test(
+    test_flavor: OverrideDefinitions, full_path: str, output_dir: str
+):
+    """Run a test twice and verify losses are identical (run-to-run determinism).
+
+    This runs the same configuration twice with deterministic settings enabled,
+    then compares the losses from both runs to ensure they match exactly.
+    """
+    test_name = test_flavor.test_name
+    all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
+
+    # Build the base command with determinism flags
+    override_arg = test_flavor.override_args[0] if test_flavor.override_args else []
+    override_str = " ".join(override_arg) if override_arg else ""
+
+    base_cmd = (
+        f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
+        f"./run_train.sh --job.dump_folder {output_dir}/{test_name}_determinism "
+        f"--debug.deterministic --debug.seed=42 --training.steps=10"
+    )
+    if override_str:
+        base_cmd += " " + override_str
+
+    logger.info(
+        f"===== {time.strftime('%Y-%m-%d %H:%M:%S')} Determinism test, flavor : {test_flavor.test_descr} ====="
+    )
+
+    # Create temp files for logs
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix="_run1.log", delete=False
+    ) as log1_file:
+        log1_path = log1_file.name
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix="_run2.log", delete=False
+    ) as log2_file:
+        log2_path = log2_file.name
+
+    try:
+        # Run 1
+        logger.info(f"Determinism test run 1: {base_cmd}")
+        cmd1 = f"{base_cmd} 2>&1 | tee {log1_path}"
+        result1 = _run_cmd(cmd1)
+        if result1.returncode != 0:
+            raise Exception(
+                f"Determinism test run 1 failed, flavor : {test_flavor.test_descr}"
+            )
+
+        # Run 2
+        logger.info(f"Determinism test run 2: {base_cmd}")
+        cmd2 = f"{base_cmd} 2>&1 | tee {log2_path}"
+        result2 = _run_cmd(cmd2)
+        if result2.returncode != 0:
+            raise Exception(
+                f"Determinism test run 2 failed, flavor : {test_flavor.test_descr}"
+            )
+
+        # Extract and compare losses
+        losses1 = extract_losses_from_log(log1_path)
+        losses2 = extract_losses_from_log(log2_path)
+
+        success, message = compare_losses(losses1, losses2, "run1", "run2")
+        if not success:
+            raise Exception(
+                f"Determinism test failed for {test_flavor.test_descr}: {message}"
+            )
+
+        logger.info(f"Determinism test passed for {test_flavor.test_descr}: {message}")
+
+    finally:
+        # Clean up temp files
+        if os.path.exists(log1_path):
+            os.remove(log1_path)
+        if os.path.exists(log2_path):
+            os.remove(log2_path)
+
+
 def run_tests(args, test_list: list[OverrideDefinitions]):
     """Run all integration tests to test the core features of TorchTitan"""
 
@@ -106,6 +184,13 @@ def run_tests(args, test_list: list[OverrideDefinitions]):
                 run_single_test(test_flavor, args.config_path, args.output_dir)
                 ran_any_test = True
 
+                # Run determinism test if enabled (CUDA only)
+                if (
+                    test_flavor.determinism_test
+                    and getattr(args, "gpu_arch_type", "cuda") == "cuda"
+                ):
+                    run_determinism_test(test_flavor, args.config_path, args.output_dir)
+
     if not ran_any_test:
         available_tests = [t.test_name for t in test_list if not t.disabled]
         if hasattr(args, "test_suite"):
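The runner's pass/fail decision rests entirely on the (success, message)
pair returned by compare_losses, defined in the new module below.
Illustrative behavior with made-up loss values:

    from torchtitan.tools.loss_utils import compare_losses

    # Mismatches are reported per step:
    ok, msg = compare_losses({1: 8.0, 2: 7.5}, {1: 8.0, 2: 7.6})
    assert not ok and "step 2" in msg

    # Exact equality across all steps passes:
    ok, msg = compare_losses({1: 8.0, 2: 7.5}, {1: 8.0, 2: 7.5})
    assert ok  # msg == "All 2 steps have identical losses"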
diff --git a/torchtitan/tools/loss_utils.py b/torchtitan/tools/loss_utils.py
new file mode 100644
index 0000000000..8d8ec4d38e
--- /dev/null
+++ b/torchtitan/tools/loss_utils.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Shared utilities for loss extraction and comparison.
+
+This module provides common functionality used by both:
+- scripts/loss_compare.py (CLI tool for comparing losses across commits)
+- tests/integration_tests/run_tests.py (integration test runner)
+"""
+
+import re
+
+
+def extract_losses_from_log(log_file: str) -> dict[int, float]:
+    """Extract step and loss pairs from a training log file.
+
+    Parses log lines matching the pattern: "step: N loss: X.XXXX"
+    Handles ANSI escape codes that may be present in colored terminal output.
+
+    Args:
+        log_file: Path to the training log file
+
+    Returns:
+        Dictionary mapping step numbers to loss values
+    """
+    losses = {}
+    step_loss_pattern = re.compile(r"step:\s*(\d+)\s*loss:\s*(\d+\.\d+)")
+    ansi_escape = re.compile(r"\x1b\[[0-9;]*m")
+
+    with open(log_file, "r") as f:
+        for line in f:
+            # Strip ANSI codes before matching
+            clean_line = ansi_escape.sub("", line)
+            match = step_loss_pattern.search(clean_line)
+            if match:
+                step, loss = match.groups()
+                losses[int(step)] = float(loss)
+
+    return losses
+
+
+def compare_losses(
+    losses1: dict[int, float],
+    losses2: dict[int, float],
+    name1: str = "run1",
+    name2: str = "run2",
+) -> tuple[bool, str]:
+    """Compare two loss dictionaries for equality.
+
+    Args:
+        losses1: First loss dictionary (step -> loss)
+        losses2: Second loss dictionary (step -> loss)
+        name1: Name for first run (for error messages)
+        name2: Name for second run (for error messages)
+
+    Returns:
+        Tuple of (success: bool, message: str)
+        - success is True if all losses match exactly
+        - message contains details about the comparison or mismatch
+    """
+    if not losses1:
+        return False, f"No losses found in {name1}"
+
+    if not losses2:
+        return False, f"No losses found in {name2}"
+
+    steps1 = set(losses1.keys())
+    steps2 = set(losses2.keys())
+
+    if steps1 != steps2:
+        return False, (
+            f"Steps mismatch: {name1} has {len(steps1)} steps, "
+            f"{name2} has {len(steps2)} steps"
+        )
+
+    mismatches = []
+    for step in sorted(steps1):
+        loss1 = losses1[step]
+        loss2 = losses2[step]
+        if loss1 != loss2:
+            mismatches.append(f"  step {step}: {name1}={loss1}, {name2}={loss2}")
+
+    if mismatches:
+        return False, "Loss mismatches:\n" + "\n".join(mismatches)
+
+    return True, f"All {len(steps1)} steps have identical losses"
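To make the extraction behavior concrete, here is a quick self-contained
check of the two regexes above against a synthetic colored log line (the
step and loss values are made up):

    import re

    # Same two patterns as extract_losses_from_log: strip ANSI color codes,
    # then pull out the step/loss pair.
    line = "\x1b[32mstep:  5  loss:  7.8213\x1b[0m"
    clean = re.sub(r"\x1b\[[0-9;]*m", "", line)
    m = re.search(r"step:\s*(\d+)\s*loss:\s*(\d+\.\d+)", clean)
    assert m is not None
    assert (int(m.group(1)), float(m.group(2))) == (5, 7.8213)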