From c6fa522f8b561bc919c1704c0dba6b87346aed17 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 25 Dec 2024 06:11:27 +0000 Subject: [PATCH] related-change with deepspeed#5445 --- intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py | 7 +++++++ intel_extension_for_pytorch/nn/utils/_weight_prepack.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py b/intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py index c945e48df..100a843e9 100644 --- a/intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py +++ b/intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py @@ -46,6 +46,13 @@ def IPEX_WEIGHT_PREPACK_MODULE_CPU(): deepspeed_modules_mapping.update( {LmHeadLinearAllreduce: _IPEXLmHeadLinearAllreduce} ) + if len(deepspeed_modules) > 3: + for module in deepspeed_modules[3:]: + if module not in deepspeed_modules_mapping: + if issubclass(module, LinearAllreduce): + deepspeed_modules_mapping[module] = _IPEXLinearAllreduce + elif issubclass(module, LinearLayer): + deepspeed_modules_mapping[module] = _IPEXLinear torch_modules.update(deepspeed_modules_mapping) return torch_modules diff --git a/intel_extension_for_pytorch/nn/utils/_weight_prepack.py b/intel_extension_for_pytorch/nn/utils/_weight_prepack.py index f6655050e..4cca48344 100644 --- a/intel_extension_for_pytorch/nn/utils/_weight_prepack.py +++ b/intel_extension_for_pytorch/nn/utils/_weight_prepack.py @@ -101,6 +101,7 @@ def may_import_deepspeed_modules(): try: # import deepspeed in a global space will raise circular import error # intel-extension-for-deepspeed imports both IPEX and deepspeed + import deepspeed.module_inject.layers as dslayers from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer ds_layers = [LinearAllreduce, LinearLayer] @@ -110,6 +111,8 @@ def may_import_deepspeed_modules(): from deepspeed.module_inject.layers import LmHeadLinearAllreduce ds_layers.append(LmHeadLinearAllreduce) + ds_layers += [cls for cls in dslayers.LinearAllreduce.__subclasses__()] + ds_layers += [cls for cls in dslayers.LinearLayer.__subclasses__()] return ds_layers except ImportError: return ds_layers