Commit 384a664

add TE BF16 and unit test
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
1 parent 8c91e8f commit 384a664

3 files changed: +22 -1 lines changed


src/MaxText/configs/types.py

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ class QuantizationType(str, Enum):
   FP8_NANO_V2 = "fp8_nanoo"
   FP8_GPU = "fp8_gpu"
   FP8_FULL = "fp8_full"
+  TE = "te_noscaling"
   TE_FP8_DS = "te_fp8_delayedscaling"
   TE_FP8_CS = "te_fp8_currentscaling"
   TE_MXFP8 = "te_mxfp8"
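
The new TE member maps the config string "te_noscaling" onto the enum, alongside the existing TransformerEngine FP8 modes. A minimal sketch of the string-to-enum round trip, assuming only the members shown in this diff (the rest of MaxText's config plumbing is omitted):

    from enum import Enum

    class QuantizationType(str, Enum):
      """Subset of the MaxText quantization modes shown in this diff."""
      TE = "te_noscaling"                  # TE without an FP8 recipe, i.e. plain BF16
      TE_FP8_DS = "te_fp8_delayedscaling"
      TE_FP8_CS = "te_fp8_currentscaling"
      TE_MXFP8 = "te_mxfp8"

    # Because the enum inherits from str, the config value selects the member directly.
    assert QuantizationType("te_noscaling") is QuantizationType.TE
    assert QuantizationType.TE == "te_noscaling"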

src/MaxText/layers/quantizations.py

Lines changed: 4 additions & 1 deletion
@@ -748,14 +748,17 @@ def _get_recipe(recipe_name: str):
     from transformer_engine.common import recipe  # pylint: disable=import-outside-toplevel # pytype: disable=import-error

     RECIPES = {
+        "te_noscaling": None,
         "te_fp8_delayedscaling": recipe.DelayedScaling,
         "te_fp8_currentscaling": recipe.Float8CurrentScaling,
         "te_mxfp8": recipe.MXFP8BlockScaling,
         "te_nvfp4": recipe.NVFP4BlockScaling,  # pytype: disable=module-attr
     }
     if recipe_name not in RECIPES:
       raise ValueError(f"Invalid TransformerEngine recipe: {recipe_name}")
-    return RECIPES[recipe_name]()
+
+    te_recipe = RECIPES[recipe_name]
+    return te_recipe() if te_recipe is not None else None

   def get_block_size(self):
     """Get the block size for quantization for recipes that require blocks.

tests/integration_tests/train_tests.py

Lines changed: 17 additions & 0 deletions
@@ -109,6 +109,18 @@ class TrainTests(unittest.TestCase):
           "enable_goodput_recording=False",
           rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
       ],
+      "te_noscaling": [  # tests base config with te_noscaling i.e. BF16
+          None,
+          os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
+          "base_output_directory=gs://runner-maxtext-logs",
+          "run_name=runner_test",
+          "dataset_path=gs://maxtext-dataset",
+          "quantization=te_noscaling",
+          "steps=2",
+          "enable_checkpointing=False",
+          "enable_goodput_recording=False",
+          rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+      ],
       "te_fp8_delayedscaling": [  # tests base config with te_fp8_delayedscaling
           None,
           os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
@@ -234,6 +246,11 @@ def test_gpu_fp8(self):
   def test_gpu_nanoo_fp8(self):
     train_main(TrainTests.CONFIGS["nanoo_fp8"] + ["attention=dot_product"])

+  @pytest.mark.integration_test
+  @pytest.mark.gpu_only
+  def test_gpu_te_noscaling(self):
+    train_main(TrainTests.CONFIGS["te_noscaling"] + ["attention=cudnn_flash_te"])
+
   @pytest.mark.skip(reason="No runner with GPU arch >= 89 is available")
   @pytest.mark.integration_test
   @pytest.mark.gpu_only
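
The new test mirrors the existing TE entries: it reuses the base.yml overrides from CONFIGS, sets quantization=te_noscaling, and pairs it with attention=cudnn_flash_te so the TransformerEngine path is exercised end to end without an FP8 recipe. On a GPU runner with TransformerEngine installed, it should be selectable by name with something like pytest tests/integration_tests/train_tests.py -k test_gpu_te_noscaling (exact marker and selection flags depend on the repository's pytest configuration).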
