src/MaxText/configs/types.py (6 changes: 5 additions & 1 deletion)

@@ -78,6 +78,7 @@ class QuantizationType(str, Enum):
   FP8_NANO_V2 = "fp8_nanoo"
   FP8_GPU = "fp8_gpu"
   FP8_FULL = "fp8_full"
+  TE = "te_noscaling"
   TE_FP8_DS = "te_fp8_delayedscaling"
   TE_FP8_CS = "te_fp8_currentscaling"
   TE_MXFP8 = "te_mxfp8"

@@ -1780,7 +1781,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       self.global_batch_size_to_eval_on,
       self.micro_batch_size_to_eval_on,
     ) = calculate_global_batch_sizes(
-        self.eval_per_device_batch_size, self.expansion_factor_real_data, self.num_target_devices, 1
+        self.eval_per_device_batch_size,
+        self.expansion_factor_real_data,
+        self.num_target_devices,
+        1,
     )

     # Calculate ramp-up batch size parameters if enabled.
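For reference, a minimal sketch of how the str-valued QuantizationType enum resolves the new config string. The members are copied from the diff; the asserts illustrate standard Python Enum-by-value lookup, not MaxText-specific behavior:

from enum import Enum

class QuantizationType(str, Enum):
  # Members taken from the diff; TE is the new no-scaling TransformerEngine entry.
  TE = "te_noscaling"
  TE_FP8_DS = "te_fp8_delayedscaling"
  TE_FP8_CS = "te_fp8_currentscaling"

# Because the enum subclasses str, a raw config value round-trips cleanly:
assert QuantizationType("te_noscaling") is QuantizationType.TE
assert QuantizationType.TE == "te_noscaling"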
src/MaxText/layers/quantizations.py (5 changes: 4 additions & 1 deletion)

@@ -748,14 +748,17 @@ def _get_recipe(recipe_name: str):
     from transformer_engine.common import recipe  # pylint: disable=import-outside-toplevel # pytype: disable=import-error

     RECIPES = {
+        "te_noscaling": None,
         "te_fp8_delayedscaling": recipe.DelayedScaling,
         "te_fp8_currentscaling": recipe.Float8CurrentScaling,
         "te_mxfp8": recipe.MXFP8BlockScaling,
         "te_nvfp4": recipe.NVFP4BlockScaling,  # pytype: disable=module-attr
     }
     if recipe_name not in RECIPES:
       raise ValueError(f"Invalid TransformerEngine recipe: {recipe_name}")
-    return RECIPES[recipe_name]()
+
+    te_recipe = RECIPES[recipe_name]
+    return te_recipe() if te_recipe is not None else None

   def get_block_size(self):
     """Get the block size for quantization for recipes that require blocks.
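Mapping "te_noscaling" to None means _get_recipe now returns None instead of instantiating a recipe class, so TransformerEngine runs without any FP8 scaling recipe (i.e. plain BF16). A standalone sketch of the new lookup logic, assuming a RECIPES mapping shaped like the one above, with the real recipe classes replaced by a stub:

class _StubRecipe:
  """Stands in for a transformer_engine recipe class in this sketch."""

RECIPES = {
    "te_noscaling": None,  # no scaling recipe registered: run TE in plain BF16
    "te_fp8_delayedscaling": _StubRecipe,
}

def get_recipe(recipe_name):
  if recipe_name not in RECIPES:
    raise ValueError(f"Invalid TransformerEngine recipe: {recipe_name}")
  te_recipe = RECIPES[recipe_name]
  # Only instantiate when a recipe class is registered; None propagates as-is.
  return te_recipe() if te_recipe is not None else None

assert get_recipe("te_noscaling") is None
assert isinstance(get_recipe("te_fp8_delayedscaling"), _StubRecipe)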
tests/integration_tests/train_tests.py (17 changes: 17 additions & 0 deletions)

@@ -109,6 +109,18 @@ class TrainTests(unittest.TestCase):
         "enable_goodput_recording=False",
         rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
     ],
+    "te_noscaling": [  # tests base config with te_noscaling, i.e. BF16
+        None,
+        os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
+        "base_output_directory=gs://runner-maxtext-logs",
+        "run_name=runner_test",
+        "dataset_path=gs://maxtext-dataset",
+        "quantization=te_noscaling",
+        "steps=2",
+        "enable_checkpointing=False",
+        "enable_goodput_recording=False",
+        rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+    ],
     "te_fp8_delayedscaling": [  # tests base config with te_fp8_delayedscaling
         None,
         os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),

@@ -234,6 +246,11 @@ def test_gpu_fp8(self):
   def test_gpu_nanoo_fp8(self):
     train_main(TrainTests.CONFIGS["nanoo_fp8"] + ["attention=dot_product"])

+  @pytest.mark.integration_test
+  @pytest.mark.gpu_only
+  def test_gpu_te_noscaling(self):
+    train_main(TrainTests.CONFIGS["te_noscaling"] + ["attention=cudnn_flash_te"])
+
   @pytest.mark.skip(reason="No runner with GPU arch >= 89 is available")
   @pytest.mark.integration_test
   @pytest.mark.gpu_only
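For illustration, a hedged sketch of what the new test exercises when invoked directly; the import path and the literal config paths here are assumptions, since the real test resolves them through MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT, and the shared CONFIGS table:

import os

# Assumed import; train_tests.py obtains train_main from the MaxText package.
from MaxText.train import main as train_main

config = [
    None,  # placeholder for argv[0], matching the other CONFIGS entries
    os.path.join("configs", "base.yml"),  # the test resolves this under MAXTEXT_PKG_DIR
    "quantization=te_noscaling",
    "attention=cudnn_flash_te",  # appended by test_gpu_te_noscaling
    "steps=2",
    "enable_checkpointing=False",
]
train_main(config)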