diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index 22ef5af07..23ffda548 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -78,6 +78,7 @@ class QuantizationType(str, Enum):
   FP8_NANO_V2 = "fp8_nanoo"
   FP8_GPU = "fp8_gpu"
   FP8_FULL = "fp8_full"
+  TE = "te_noscaling"
   TE_FP8_DS = "te_fp8_delayedscaling"
   TE_FP8_CS = "te_fp8_currentscaling"
   TE_MXFP8 = "te_mxfp8"
@@ -1780,7 +1781,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         self.global_batch_size_to_eval_on,
         self.micro_batch_size_to_eval_on,
     ) = calculate_global_batch_sizes(
-        self.eval_per_device_batch_size, self.expansion_factor_real_data, self.num_target_devices, 1
+        self.eval_per_device_batch_size,
+        self.expansion_factor_real_data,
+        self.num_target_devices,
+        1,
     )
 
     # Calculate ramp-up batch size parameters if enabled.
diff --git a/src/MaxText/layers/quantizations.py b/src/MaxText/layers/quantizations.py
index d0f9353b6..eafdd99ed 100644
--- a/src/MaxText/layers/quantizations.py
+++ b/src/MaxText/layers/quantizations.py
@@ -748,6 +748,7 @@ def _get_recipe(recipe_name: str):
     from transformer_engine.common import recipe  # pylint: disable=import-outside-toplevel # pytype: disable=import-error
 
     RECIPES = {
+        "te_noscaling": None,
         "te_fp8_delayedscaling": recipe.DelayedScaling,
         "te_fp8_currentscaling": recipe.Float8CurrentScaling,
         "te_mxfp8": recipe.MXFP8BlockScaling,
@@ -755,7 +756,9 @@ def _get_recipe(recipe_name: str):
     }
     if recipe_name not in RECIPES:
       raise ValueError(f"Invalid TransformerEngine recipe: {recipe_name}")
-    return RECIPES[recipe_name]()
+
+    te_recipe = RECIPES[recipe_name]
+    return te_recipe() if te_recipe is not None else None
 
   def get_block_size(self):
     """Get the block size for quantization for recipes that require blocks.
diff --git a/tests/integration_tests/train_tests.py b/tests/integration_tests/train_tests.py
index d29338e50..f09376759 100644
--- a/tests/integration_tests/train_tests.py
+++ b/tests/integration_tests/train_tests.py
@@ -109,6 +109,18 @@ class TrainTests(unittest.TestCase):
           "enable_goodput_recording=False",
           rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
       ],
+      "te_noscaling": [  # tests base config with te_noscaling i.e. BF16
+          None,
+          os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
+          "base_output_directory=gs://runner-maxtext-logs",
+          "run_name=runner_test",
+          "dataset_path=gs://maxtext-dataset",
+          "quantization=te_noscaling",
+          "steps=2",
+          "enable_checkpointing=False",
+          "enable_goodput_recording=False",
+          rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+      ],
       "te_fp8_delayedscaling": [  # tests base config with te_fp8_delayedscaling
           None,
           os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
@@ -234,6 +246,11 @@ def test_gpu_fp8(self):
   def test_gpu_nanoo_fp8(self):
     train_main(TrainTests.CONFIGS["nanoo_fp8"] + ["attention=dot_product"])
 
+  @pytest.mark.integration_test
+  @pytest.mark.gpu_only
+  def test_gpu_te_noscaling(self):
+    train_main(TrainTests.CONFIGS["te_noscaling"] + ["attention=cudnn_flash_te"])
+
   @pytest.mark.skip(reason="No runner with GPU arch >= 89 is available")
   @pytest.mark.integration_test
   @pytest.mark.gpu_only
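
For reference, a minimal self-contained sketch (not MaxText code; the class and function names below are illustrative stand-ins) of how the updated `_get_recipe` lookup behaves: `"te_noscaling"` maps to `None`, so no TransformerEngine recipe object is instantiated and the layer runs without FP8 scaling (i.e. BF16), while the other keys still construct their recipe classes.

# Illustrative sketch only; DelayedScaling here is a placeholder for
# transformer_engine.common.recipe.DelayedScaling.
class DelayedScaling:
  pass

RECIPES = {
    "te_noscaling": None,  # no scaling recipe -> run without FP8 quantization (BF16)
    "te_fp8_delayedscaling": DelayedScaling,
}

def get_recipe(recipe_name: str):
  # Mirrors the guard-and-instantiate pattern in the patched _get_recipe.
  if recipe_name not in RECIPES:
    raise ValueError(f"Invalid TransformerEngine recipe: {recipe_name}")
  te_recipe = RECIPES[recipe_name]
  return te_recipe() if te_recipe is not None else None

assert get_recipe("te_noscaling") is None
assert isinstance(get_recipe("te_fp8_delayedscaling"), DelayedScaling)

The new integration test exercises this path end to end on the base config with `quantization=te_noscaling` and `attention=cudnn_flash_te`.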