diff --git a/dockerfile/torch1.14-cuda11.8.dockerfile b/dockerfile/torch1.14-cuda11.8.dockerfile index 8b159501..0ece04ea 100644 --- a/dockerfile/torch1.14-cuda11.8.dockerfile +++ b/dockerfile/torch1.14-cuda11.8.dockerfile @@ -50,7 +50,7 @@ RUN cd third_party/msccl && \ make install # cache TE build to save time in CI RUN python3 -m pip install --upgrade pip && \ - python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable + python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@v1.1 ADD . . RUN python3 -m pip install . && \ diff --git a/dockerfile/torch2.1-cuda12.2.dockerfile b/dockerfile/torch2.1-cuda12.2.dockerfile index 3fd0a0e8..4af75a29 100644 --- a/dockerfile/torch2.1-cuda12.2.dockerfile +++ b/dockerfile/torch2.1-cuda12.2.dockerfile @@ -50,7 +50,7 @@ RUN cd third_party/msccl && \ make install # cache TE build to save time in CI RUN python3 -m pip install --upgrade pip && \ - python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable + python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@v1.1 ADD . . RUN python3 -m pip install . && \ diff --git a/msamp/te/modules.py b/msamp/te/modules.py index 415cee86..20c41057 100644 --- a/msamp/te/modules.py +++ b/msamp/te/modules.py @@ -62,7 +62,7 @@ def set_fp8_weights(self): weight_cast_attr = f'weight{i}_fp8' weight_transpose_attr = f'weight{i}_t_fp8' - if (hasattr(self, weight_cast_attr) and getattr(self, weight_cast_attr).shape == shape): + if (hasattr(self, weight_cast_attr) and getattr(self, weight_cast_attr)._data.shape == shape): return setattr( diff --git a/pyproject.toml b/pyproject.toml index 0a988102..7382c59b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ classifiers=[ ] dependencies = [ "torch", - "transformer-engine@git+https://github.com/NVIDIA/TransformerEngine.git@stable", + "transformer-engine@git+https://github.com/NVIDIA/TransformerEngine.git@v1.1", "colorlog>=6.7.0", "deepspeed==0.13.1", "mpi4py", diff --git a/tests/te/test_te_replacer.py b/tests/te/test_te_replacer.py index f336c776..4071de1b 100644 --- a/tests/te/test_te_replacer.py +++ b/tests/te/test_te_replacer.py @@ -164,5 +164,8 @@ def test_fp8_ddp_with_te(self): fp8_format = Format.HYBRID fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max') with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - output = model(x, attention_mask=None) + output = model(x, attention_mask=None, is_first_microbatch=True) + output.sum().backward() + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + output = model(x, attention_mask=None, is_first_microbatch=False) output.sum().backward()