Skip to content

Commit c7f3032

Browse files
nelyahut, tjruwase, loadams
authored
inference: remove unused _validate_args function (#5505)
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
1 parent f2cc809 commit c7f3032

File tree

1 file changed

+0
-24
lines changed

1 file changed

+0
-24
lines changed

deepspeed/inference/engine.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def __init__(self, model, config):
8080
self.mp_group = config.tensor_parallel.tp_group
8181
self.mpu = config.tensor_parallel.mpu
8282

83-
#self._validate_args(self.mpu, config.replace_with_kernel_inject)
8483
self.quantize_merge_count = 1
8584
self.quantization_scales = None
8685

@@ -300,29 +299,6 @@ def _init_quantization_setting(self, quantization_setting):
300299
f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
301300
f"quantize_groups = {self.quantize_groups}", [0])
302301

303-
# TODO: remove this function and add this functionality to pydantic config checking
304-
def _validate_args(self, mpu, replace_with_kernel_inject):
305-
# TODO: to support SD pipeline we need to avoid this check for now
306-
if replace_with_kernel_inject and not isinstance(self.module, Module):
307-
raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
308-
if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1:
309-
raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}")
310-
311-
if mpu:
312-
methods = ["get_model_parallel_group", "get_data_parallel_group"]
313-
for method in methods:
314-
if not hasattr(mpu, method):
315-
raise ValueError(f"mpu is missing {method}")
316-
if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
317-
raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
318-
319-
supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16]
320-
if self._config.dtype not in supported_dtypes:
321-
raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
322-
323-
if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
324-
raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")
325-
326302
def load_model_with_checkpoint(self, r_module):
327303
self.mp_replace = ReplaceWithTensorSlicing(
328304
mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)

0 commit comments

Comments (0)