Skip to content

Commit

Permalink
[Hotfix] Avoid fused RMSnorm import error without apex (#5985)
Browse files Browse the repository at this point in the history
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
  • Loading branch information
Edenzzzz and Edenzzzz authored Aug 9, 2024
1 parent ad3fa4f commit b4d2377
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions colossalai/shardformer/layer/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def forward(self, input):
return output

except ImportError:
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")

FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,
Expand Down Expand Up @@ -270,12 +270,6 @@ def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *arg
Returns:
nn.Module: FusedRMSNorm module.
"""
try:
pass
except ImportError:
raise ImportError(
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
)

LazyInitContext.materialize(module)

Expand All @@ -284,11 +278,18 @@ def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *arg
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
elementwise_affine = getattr(module, "elementwise_affine", True)

rmsnorm = FusedRMSNormWithHook(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
)
try:
rmsnorm = FusedRMSNormWithHook(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
)
except ImportError:
warnings.warn(
"Module replacement failed.\
Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
)
return module

rmsnorm.weight = module.weight

Expand Down

0 comments on commit b4d2377

Please sign in to comment.