From 12b44012d95878004c4026bab637ddb563cd6759 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Mon, 19 Aug 2024 09:02:16 +0000
Subject: [PATCH] fix

---
 .../booster/plugin/hybrid_parallel_plugin.py  | 49 ++++++++++---------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 9e5d7b0d77f7..bd970878f1dd 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1278,30 +1278,31 @@ def configure(
                 overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
                 use_fp8=self.use_fp8,
             )
-        if zero_stage == 0:
-            is_zero = False
-            if self.precision in ["fp16", "bf16"]:
-                optimizer = HybridParallelAMPOptimizer(
-                    optimizer,
-                    model,
-                    use_pipeline=self.enable_pipeline_parallelism,
-                    param_info=param_info,
-                    precision=self.precision,
-                    max_norm=self.max_norm,
-                    pp_process_group=self.pp_group,
-                    tp_process_group=self.tp_group,
-                    **self.amp_config,
-                )
-            else:
-                optimizer = HybridParallelNaiveOptimizer(
-                    optimizer,
-                    model,
-                    use_pipeline=self.enable_pipeline_parallelism,
-                    param_info=param_info,
-                    max_norm=self.max_norm,
-                    pp_process_group=self.pp_group,
-                    tp_process_group=self.tp_group,
-                )
+        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+            if zero_stage == 0:
+                is_zero = False
+                if self.precision in ["fp16", "bf16"]:
+                    optimizer = HybridParallelAMPOptimizer(
+                        optimizer,
+                        model,
+                        use_pipeline=self.enable_pipeline_parallelism,
+                        param_info=param_info,
+                        precision=self.precision,
+                        max_norm=self.max_norm,
+                        pp_process_group=self.pp_group,
+                        tp_process_group=self.tp_group,
+                        **self.amp_config,
+                    )
+                else:
+                    optimizer = HybridParallelNaiveOptimizer(
+                        optimizer,
+                        model,
+                        use_pipeline=self.enable_pipeline_parallelism,
+                        param_info=param_info,
+                        max_norm=self.max_norm,
+                        pp_process_group=self.pp_group,
+                        tp_process_group=self.tp_group,
+                    )
         else:
             is_zero = self.dp_size > 1
             if self.dp_size == 1:
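
For context: the change guards the optimizer-wrapping path in `configure` so that it only runs when the caller passed a real, not-yet-wrapped optimizer. Below is a minimal, self-contained sketch of that pattern, assuming nothing from ColossalAI itself; `OptimizerWrapper`, `SGD`, and `wrap_optimizer` here are hypothetical stand-ins, not the plugin's actual classes or API.

```python
# Sketch of the guard introduced by the patch (hypothetical stand-ins;
# not ColossalAI's real classes).
from typing import Optional


class OptimizerWrapper:
    """Stand-in for a booster's optimizer wrapper base class."""

    def __init__(self, optim):
        self.optim = optim


class SGD:
    """Stand-in for a plain, unwrapped optimizer."""


def wrap_optimizer(optimizer: Optional[object]) -> Optional[object]:
    # The condition the patch adds: wrap only a real, not-yet-wrapped
    # optimizer. Without it, calling configure() with an already-boosted
    # optimizer would nest wrappers, and optimizer=None would crash.
    if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
        optimizer = OptimizerWrapper(optimizer)
    return optimizer


raw = SGD()
once = wrap_optimizer(raw)
twice = wrap_optimizer(once)  # idempotent: returns the same wrapper object
assert isinstance(once, OptimizerWrapper)
assert twice is once
assert wrap_optimizer(None) is None  # no-op when no optimizer is given
```

In the plugin itself, the same condition wraps the `zero_stage == 0` branch shown in the hunk, which picks `HybridParallelAMPOptimizer` for fp16/bf16 and `HybridParallelNaiveOptimizer` otherwise.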