diff --git a/saxml/tools/offline_quantize.py b/saxml/tools/offline_quantize.py
index 43062570..5f9e22eb 100644
--- a/saxml/tools/offline_quantize.py
+++ b/saxml/tools/offline_quantize.py
@@ -56,6 +56,7 @@ def parse_known_args(argv):
       default='gptj',
       choices=[
           'gptj',
+          'gpt5b',
           'gemma2b',
           'gemma7b',
           'llama2-70b-weight-linear-only-int8',
@@ -74,6 +75,12 @@ def parse_known_args(argv):
       type=lambda x: bool(str(x).lower() == 'true'),
       help='Transpose embedding to reduce latency.',
   )
+  parser.add_argument(
+      '--transpose_cast_fp8',
+      default=False,
+      type=lambda x: bool(str(x).lower() == 'true'),
+      help='Transpose and cast weights to fp8.',
+  )
   parser.add_argument(
       '--quantize_embedding',
       default=False,
diff --git a/saxml/tools/quant_fn.py b/saxml/tools/quant_fn.py
index 54d77d42..d4fc7cf1 100644
--- a/saxml/tools/quant_fn.py
+++ b/saxml/tools/quant_fn.py
@@ -108,6 +108,10 @@ def process(self, action: quantization_actions.OptAction):
     if action.transpose_embedding:
       target_var = jnp.transpose(target_var)
 
+    # Transpose weights that are being cast to fp8.
+    if action.transpose_cast_fp8:
+      target_var = jnp.transpose(target_var)
+
     if action.quantize_axis:
       quantize_axis = action.quantize_axis
       quantize_factor = action.quantize_factor
diff --git a/saxml/tools/quantization_actions.py b/saxml/tools/quantization_actions.py
index 6d36be03..1c7adf13 100644
--- a/saxml/tools/quantization_actions.py
+++ b/saxml/tools/quantization_actions.py
@@ -36,6 +36,7 @@ class OptAction:
   quantize_axis: Optional[list[int]] = None
   quantize_factor: float = 1.0
   transpose_embedding: bool = False
+  transpose_cast_fp8: bool = False
   number_bit: int = 8
   pack_dim: int = 0
   use_optimization: bool = False
@@ -170,6 +171,7 @@ def _get_sub_layer(var_name: str) -> int:
   axis = None
   var_dtype = 'bfloat16'
   transpose_embedding = False
+  transpose_cast_fp8 = False
   quantize_factor = 1.0
   pack_dim = 0
   curr_config = config.get_quantize_axis_and_factor(source_name)
@@ -249,6 +251,22 @@ def _get_sub_layer(var_name: str) -> int:
     var_dtype = 'int8'
     axis = [1]  # we are quantizing embedding along the outer axis here
     quantize_factor = config.factor  # pytype: disable=attribute-error
+
+  # fp8 weights: transpose and cast the per-layer FFN linear weights.
+  if source_name in (
+      'mdl_vars.params.lm.transformer.x_layers_0.ff_layer.ffn_layer1.linear.w',
+      'mdl_vars.params.lm.transformer.x_layers_0.ff_layer.ffn_layer2.linear.w',
+      'mdl_vars.params.lm.transformer.x_layers_1.ff_layer.ffn_layer1.linear.w',
+      'mdl_vars.params.lm.transformer.x_layers_1.ff_layer.ffn_layer2.linear.w',
+      'mdl_vars.params.lm.transformer.x_layers_2.ff_layer.ffn_layer1.linear.w',
+      'mdl_vars.params.lm.transformer.x_layers_2.ff_layer.ffn_layer2.linear.w',
+      'mdl_vars.params.lm.transformer.x_layers_3.ff_layer.ffn_layer1.linear.w',
+      'mdl_vars.params.lm.transformer.x_layers_3.ff_layer.ffn_layer2.linear.w',
+  ):
+    if use_fp and number_bit == 8:
+      axis = [-1]
+      transpose_cast_fp8 = True
+      target_name = source_name
 
   # Setting quantization configs back to non-quantize when the source name
   # matches the skip pattern.
@@ -278,6 +296,7 @@ def _get_sub_layer(var_name: str) -> int:
       quantize_axis=axis,
       quantize_factor=quantize_factor,
       transpose_embedding=transpose_embedding,
+      transpose_cast_fp8=transpose_cast_fp8,
       number_bit=layer_wise_num_bits,
       pack_dim=pack_dim,
       use_optimization=use_optimization,
diff --git a/saxml/tools/quantization_configs.py b/saxml/tools/quantization_configs.py
index 55d28a99..06e34920 100644
--- a/saxml/tools/quantization_configs.py
+++ b/saxml/tools/quantization_configs.py
@@ -64,7 +64,19 @@ class QuantizationConfigsGPTJ(QuantizationConfigs):
       'self_attention.combined_qkv.w': ([1], factor, 1, -1),
       'self_attention.post.w': ([1, 2], factor, 0, -1),
   }
+
+
+class QuantizationConfigsGPT5B(QuantizationConfigs):
+  """Quantization config for GPT5B model."""
+
+  factor = 1.0
+  configs = {
+      'ff_layer.ffn_layer1.linear.w': ([0, 1], factor, 0, -1),
+      'ff_layer.ffn_layer2.linear.w': ([0, 1], factor, 0, -1),
+      'self_attention.combined_qkv.w': ([0, 1, 2, 3], factor, 1, -1),
+      'self_attention.post.w': ([0, 1, 2], factor, 0, -1),
+  }
 
 
 class QuantizationConfigsGPTJStacked(QuantizationConfigs):
   """Quantization config for GPTJ model."""
diff --git a/saxml/tools/quantization_provider.py b/saxml/tools/quantization_provider.py
index 73768459..79445e82 100644
--- a/saxml/tools/quantization_provider.py
+++ b/saxml/tools/quantization_provider.py
@@ -18,6 +18,7 @@
 NAME_TO_CONFIG = {
     'gptj': quantization_configs.QuantizationConfigsGPTJ(),
+    'gpt5b': quantization_configs.QuantizationConfigsGPT5B(),
     'gemma2b': quantization_configs.QuantizationConfigsGemma2B(),
     'gemma7b': quantization_configs.QuantizationConfigsGemma7B(),
     'llama2-70b-weight-linear-only-int8': (
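
Note for reviewers: in this patch the transpose_cast_fp8 action only transposes the weight in quant_fn.py; the actual fp8 cast is expected to happen further down the quantization path. Below is a minimal standalone sketch of the intended transform. The float8_e4m3fn target dtype, the helper name, and the toy weight shape are illustrative assumptions and are not pinned down by the patch.

# Sketch only: the transpose-then-cast step that `transpose_cast_fp8` is meant
# to drive. The fp8 format (e4m3) and the weight shape are assumed, not taken
# from the patch.
import jax.numpy as jnp


def transpose_cast_fp8(w: jnp.ndarray) -> jnp.ndarray:
  """Transposes a 2-D linear weight and casts it to an assumed fp8 dtype."""
  w_t = jnp.transpose(w)  # mirrors the jnp.transpose call added in quant_fn.py
  return w_t.astype(jnp.float8_e4m3fn)  # assumed cast target; not in the patch


w = jnp.ones((8, 16), dtype=jnp.bfloat16)  # toy stand-in for an FFN linear weight
w_fp8 = transpose_cast_fp8(w)
print(w_fp8.shape, w_fp8.dtype)  # (16, 8) float8_e4m3fn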