Commit 2d362ac

fix merge
2 parents: 2e4cbe3 + dcc44aa

9 files changed: +118 additions, −82 deletions

colossalai/booster/booster.py

Lines changed: 13 additions & 5 deletions
@@ -1,4 +1,3 @@
-import warnings
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union
 
@@ -8,6 +7,8 @@
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from colossalai.logging import get_dist_logger
+
 SUPPORT_PEFT = False
 try:
     import peft
@@ -81,20 +82,26 @@ def __init__(
             plugin, Plugin
         ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
         self.plugin = plugin
+        self.logger = get_dist_logger()
 
         # set accelerator
         if self.plugin and self.plugin.control_device():
             self.accelerator = None
             if device is not None:
-                warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0]
+                )
         else:
             device = device or "cuda"
             self.accelerator = Accelerator(device)
 
         # set precision
         if self.plugin and self.plugin.control_precision():
             if mixed_precision is not None:
-                warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the precision," "so the mixed_precision argument will be ignored.",
+                    ranks=[0],
+                )
             self.mixed_precision = None
         elif mixed_precision is None:
             self.mixed_precision = None
@@ -267,8 +274,9 @@ def enable_lora(
            ), "Please provide pretrained directory path if not passing in lora configuration."
         if quantize is True:
             if bnb_quantization_config is not None:
-                warnings.warn(
-                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+                self.logger.warning(
+                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.",
+                    ranks=[0],
                 )
             else:
                 bnb_quantization_config = BnbQuantizationConfig(
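
The pattern this commit applies across all nine files: process-local `warnings.warn` / `logging` calls are swapped for ColossalAI's distributed logger, so each message is emitted once from rank 0 instead of once per process. A minimal sketch of the new usage, based only on the calls visible in the diff above (the `ExamplePlugin` class and its `device` argument are illustrative, not ColossalAI code):

from colossalai.logging import get_dist_logger


class ExamplePlugin:
    # Illustrative class; it only mirrors the logging pattern introduced in this commit.
    def __init__(self, device=None):
        # Each component now creates its own distributed logger in __init__.
        self.logger = get_dist_logger()
        if device is not None:
            # ranks=[0] restricts the message to the global rank-0 process,
            # replacing the old per-process warnings.warn(...).
            self.logger.warning(
                "The plugin will control the accelerator, so the device argument will be ignored.",
                ranks=[0],
            )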

colossalai/booster/plugin/gemini_plugin.py

Lines changed: 16 additions & 10 deletions
@@ -1,5 +1,4 @@
 import gc
-import logging
 import os
 import random
 from pathlib import Path
@@ -27,6 +26,7 @@
 )
 from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -63,6 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()
 
     def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
@@ -118,7 +119,7 @@ def save_sharded_model(
         """
         assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
-            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0])
             return
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -143,10 +144,11 @@ def save_sharded_model(
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
             save_config_file(model.unwrap(), checkpoint_path)
-            logging.info(
+            self.logger.info(
                 f"The model is split into checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
+                f"index located at {save_index_file}.",
+                ranks=[0],
             )
 
     def load_sharded_model(
@@ -168,7 +170,7 @@ def save_sharded_optimizer(
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
 
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
 
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -201,10 +203,11 @@ def save_sharded_optimizer(
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-            logging.info(
+            self.logger.info(
                 f"The optimizer is going to be split to checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
+                f"index located at {save_index_file}.",
+                ranks=[0],
             )
 
     def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
@@ -214,7 +217,7 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi
         """
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         if not os.path.isfile(checkpoint_index_file):
-            logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
+            self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file", ranks=[0])
 
         assert isinstance(optimizer, GeminiOptimizer)
 
@@ -371,9 +374,12 @@ def __init__(
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
+
+        self.logger = get_dist_logger()
         if enable_async_reduce and not pin_memory:
-            logging.warning(
-                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
+            self.logger.warning(
+                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.",
+                ranks=[0],
             )
             pin_memory = True
         self.gemini_config = dict(
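
In gemini_plugin.py the module-level `logging.error` / `logging.info` calls inside GeminiCheckpointIO go through the same rank-0 logger, so bad checkpoint paths and shard-index notices are reported once rather than by every rank. A hedged sketch of that flow (the standalone `save_sharded_checkpoint` helper is illustrative; only `get_dist_logger`, `error(..., ranks=[0])` and `info(..., ranks=[0])` are taken from the diff):

import os
from pathlib import Path

from colossalai.logging import get_dist_logger


def save_sharded_checkpoint(checkpoint_path: str, save_index_file: str) -> None:
    # Illustrative helper mirroring the logging flow of GeminiCheckpointIO.save_sharded_model.
    logger = get_dist_logger()
    if os.path.isfile(checkpoint_path):
        # Rank-0-only error instead of the old logging.error on every rank.
        logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0])
        return
    Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
    # ... shard writing and index-file generation would happen here ...
    logger.info(
        f"The model is split into checkpoint shards. "
        f"You can find where each parameters has been saved in the "
        f"index located at {save_index_file}.",
        ranks=[0],
    )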

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 22 additions & 13 deletions
@@ -1,6 +1,5 @@
 import ctypes
 import random
-import warnings
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext
 from copy import deepcopy
@@ -27,6 +26,7 @@
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1036,6 +1036,7 @@ def __init__(
         inner_ring_size: int = None,
     ) -> None:
         super().__init__()
+        self.logger = get_dist_logger()
 
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
@@ -1053,8 +1054,9 @@ def __init__(
                     tp_size > 1
                 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
                 if sp_size != 1:
-                    warnings.warn(
-                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                    self.logger.warning(
+                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.",
+                        ranks=[0],
                     )
                 self.sp_size = 1
                 self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -1143,7 +1145,12 @@ def __init__(
         else:
             raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
-            assert parallel_output, "Ring Attention doesn't support gathering output yet."
+            if not parallel_output:
+                self.logger.warning(
+                    "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.",
+                    ranks=[0],
+                )
+                parallel_output = True
 
         self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
@@ -1249,7 +1256,10 @@ def configure(
         optimizer = cast_to_distributed(optimizer)
 
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
-            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            self.logger.warning(
+                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+                ranks=[0],
+            )
             zero_config["partition_grad"] = False
             zero_stage = 0
 
@@ -1306,9 +1316,10 @@ def configure(
             else:
                 is_zero = self.dp_size > 1
                 if self.dp_size == 1:
-                    warnings.warn(
+                    self.logger.warning(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-                        "If you do not intend to use cpu_offload, please consider set zero_stage=0."
+                        "If you do not intend to use cpu_offload, please consider set zero_stage=0.",
+                        ranks=[0],
                     )
 
                 assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
@@ -1351,7 +1362,7 @@ def execute_pipeline(
         assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
 
         if return_outputs:
-            warnings.warn("return_outputs may lead to significant extra memory consumption.")
+            self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0])
 
         # Create a context for gradient synchronization based on the optimizer type.
         # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
@@ -1365,10 +1376,8 @@ def execute_pipeline(
         )
 
         # run with gradients accumulation
-        if (
-            model.require_grad_sync == False
-            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
-            or not torch.is_grad_enabled()
+        if model.require_grad_sync == False or (
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
         ):
             return outputs
 
@@ -1468,7 +1477,7 @@ def enable_lora(
         assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
         assert self.pp_size == 1 and self.tp_size == 1
         self.lora_enabled = True
-        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+        self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
 
         if bnb_quantization_config is not None:
             model = quantize_model(model, bnb_quantization_config)
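
Beyond the logger migration, hybrid_parallel_plugin.py softens the ring_attn check: passing parallel_output=False no longer trips an AssertionError; the plugin now warns on rank 0 and forces the flag to True. A small standalone sketch of that before/after behaviour (the `resolve_parallel_output` helper is illustrative, not the plugin's actual code):

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


def resolve_parallel_output(sequence_parallelism_mode: str, parallel_output: bool) -> bool:
    # Illustrative helper mirroring the new branch in HybridParallelPlugin.__init__.
    if sequence_parallelism_mode == "ring_attn" and not parallel_output:
        # Old behaviour: assert parallel_output, "Ring Attention doesn't support gathering output yet."
        # New behaviour: warn once on rank 0 and coerce the flag instead of failing hard.
        logger.warning(
            "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.",
            ranks=[0],
        )
        parallel_output = True
    return parallel_output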

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 16 additions & 22 deletions
@@ -1,7 +1,5 @@
 import enum
-import logging
 import os
-import warnings
 from contextlib import nullcontext
 from functools import partial
 from pathlib import Path
@@ -33,6 +31,7 @@
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.quantization.fp8_hook import FP8Hook
@@ -64,12 +63,7 @@ class OptimizerParamCheckState(enum.Enum):
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(
-        self,
-        module: nn.Module,
-        precision: str,
-        overlap_allgather: bool = False,
-        cast_inputs: bool = True,
-        use_fp8: bool = False,
+        self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False
     ) -> None:
         super().__init__(module)
         self.dtype = None
@@ -87,8 +81,6 @@ def __init__(
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
         self.op_hooks = []
-        if self.dtype is not None and cast_inputs:
-            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         if overlap_allgather:
             self.op_hooks.append(ZeroOpHook())
         if use_fp8:
@@ -153,7 +145,7 @@ def save_sharded_optimizer(
         """
         assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
 
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -190,10 +182,11 @@ def save_sharded_optimizer(
         index_file.append_meta_data("total_size", total_size)
         if self.coordinator.is_master():
             index_file.write_index_file(save_index_file)
-        logging.info(
+        self.logger.info(
             f"The optimizer is going to be split to checkpoint shards. "
             f"You can find where each parameters has been saved in the "
-            f"index located at {save_index_file}."
+            f"index located at {save_index_file}.",
+            ranks=[0],
         )
 
     def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
@@ -280,7 +273,7 @@ def save_sharded_model(
 
     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
         from peft import PeftModel
 
@@ -349,7 +342,6 @@ def __init__(
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
-        cast_inputs: bool = True,
        fp8_communication: bool = False,
         use_fp8: bool = False,
     ) -> None:
@@ -379,9 +371,8 @@ def __init__(
         )
         self.lora_enabled = False
         self.verbose = verbose
+        self.logger = get_dist_logger()
         self.use_fp8 = use_fp8
-        self.cast_inputs = cast_inputs
-
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
 
@@ -417,7 +408,7 @@ def enable_lora(
 
         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
-        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+        self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
 
         if bnb_quantization_config is not None:
             model = quantize_model(model, bnb_quantization_config)
@@ -466,8 +457,9 @@ def add_lora_params_to_optimizer(self, model, optimizer):
                 origin_param = name2param[origin_key]
                 group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                 if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
-                    warnings.warn(
-                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                    self.logger.warning(
+                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.",
+                        ranks=[0],
                     )
                 elif (
                     check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
@@ -498,7 +490,6 @@ def configure(
                 model,
                 self.precision,
                 overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
-                cast_inputs=self.cast_inputs,
                 use_fp8=self.use_fp8,
             )
 
@@ -511,7 +502,10 @@ def configure(
         optimizer = cast_to_distributed(optimizer)
 
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
-            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            self.logger.warning(
+                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+                ranks=[0],
+            )
             zero_optim_kwargs["partition_grad"] = False
             zero_stage = 0
 
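
This merge also backs out the recently added cast_inputs knob: LowLevelZeroModel.__init__ and LowLevelZeroPlugin.__init__ no longer accept it, and the plugin no longer forwards it when wrapping the model. Callers that adopted the argument need to drop it; a hedged usage sketch under that assumption (the keyword values below are placeholders, not recommended settings):

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# After this commit, passing cast_inputs=... to LowLevelZeroPlugin raises a TypeError,
# since the parameter was removed from its __init__ signature in the diff above.
plugin = LowLevelZeroPlugin(stage=2, precision="bf16")  # placeholder settings for illustration
booster = Booster(plugin=plugin)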
