@@ -80,7 +80,6 @@ def __init__(self, model, config):
        self.mp_group = config.tensor_parallel.tp_group
        self.mpu = config.tensor_parallel.mpu

-        #self._validate_args(self.mpu, config.replace_with_kernel_inject)
        self.quantize_merge_count = 1
        self.quantization_scales = None

@@ -300,29 +299,6 @@ def _init_quantization_setting(self, quantization_setting):
                 f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
                 f"quantize_groups = {self.quantize_groups}", [0])

-    # TODO: remove this function and add this functionality to pydantic config checking
-    def _validate_args(self, mpu, replace_with_kernel_inject):
-        # TODO: to support SD pipeline we need to avoid this check for now
-        if replace_with_kernel_inject and not isinstance(self.module, Module):
-            raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
-        if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1:
-            raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}")
-
-        if mpu:
-            methods = ["get_model_parallel_group", "get_data_parallel_group"]
-            for method in methods:
-                if not hasattr(mpu, method):
-                    raise ValueError(f"mpu is missing {method}")
-        if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
-            raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
-
-        supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16]
-        if self._config.dtype not in supported_dtypes:
-            raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
-
-        if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
-            raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")
-
    def load_model_with_checkpoint(self, r_module):
        self.mp_replace = ReplaceWithTensorSlicing(
            mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size)  #, out_dim=0, in_dim=1)
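
Note: per the TODO on the removed `_validate_args`, these checks are meant to move into pydantic-based config validation so invalid settings fail at config construction time rather than inside the engine. The sketch below is illustrative only and assumes pydantic v2; `InferenceConfigSketch`, `TensorParallelConfigSketch`, and their field names are hypothetical stand-ins, not DeepSpeed's actual config classes.

```python
# Minimal sketch (assumptions: pydantic v2; class/field names are hypothetical,
# not DeepSpeed's real config models) of expressing the removed dtype / tp_size
# checks as pydantic validators.
from typing import Any

import torch
from pydantic import BaseModel, field_validator

SUPPORTED_DTYPES = (None, torch.half, torch.int8, torch.float, torch.bfloat16)


class TensorParallelConfigSketch(BaseModel):
    # Stand-in for the tensor_parallel sub-config.
    tp_size: int = 1

    @field_validator("tp_size")
    @classmethod
    def _check_tp_size(cls, value: int) -> int:
        # Mirrors the removed "mp_size must be an int >= 1" check.
        if value < 1:
            raise ValueError(f"mp_size must be an int >= 1, got {value}")
        return value


class InferenceConfigSketch(BaseModel):
    tensor_parallel: TensorParallelConfigSketch = TensorParallelConfigSketch()
    dtype: Any = None  # e.g. torch.half

    @field_validator("dtype")
    @classmethod
    def _check_dtype(cls, value: Any) -> Any:
        # Mirrors the removed supported-dtype check.
        if value not in SUPPORTED_DTYPES:
            raise ValueError(f"{value} not supported, valid dtype: {list(SUPPORTED_DTYPES)}")
        return value


# Usage: validation errors surface when the config object is built, e.g.
# InferenceConfigSketch(tensor_parallel={"tp_size": 0}) raises a ValidationError.
```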