Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[draft]Add fp8 cast and transpose #26

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions saxml/tools/offline_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def parse_known_args(argv):
default='gptj',
choices=[
'gptj',
'gpt5b',
'gemma2b',
'gemma7b',
'llama2-70b-weight-linear-only-int8',
Expand All @@ -74,6 +75,12 @@ def parse_known_args(argv):
type=lambda x: bool(str(x).lower() == 'true'),
help='Transpose embedding to reduce latency.',
)
parser.add_argument(
'--transpose_cast_fp8',
default=False,
type=lambda x: bool(str(x).lower() == 'true'),
help='Transpose and cast fp8 weights.',
)
parser.add_argument(
'--quantize_embedding',
default=False,
Expand Down
6 changes: 6 additions & 0 deletions saxml/tools/quant_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def process(self, action: quantization_actions.OptAction):
if action.transpose_embedding:
target_var = jnp.transpose(target_var)

if action.transpose_cast_fp8:
print(action.target_name)
target_var = jnp.transpose(target_var)

if action.quantize_axis:
quantize_axis = action.quantize_axis
quantize_factor = action.quantize_factor
Expand All @@ -124,6 +128,8 @@ def process(self, action: quantization_actions.OptAction):
p_value = action.optimization_p_value
per_channel = True

print('here in quantize')
print(action.target_name)
if self._symmetric:
target_var, scale = quantization_configs.quantize_tensor(
target_var,
Expand Down
18 changes: 18 additions & 0 deletions saxml/tools/quantization_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class OptAction:
quantize_axis: Optional[list[int]] = None
quantize_factor: float = 1.0
transpose_embedding: bool = False
transpose_cast_fp8: bool = False
number_bit: int = 8
pack_dim: int = 0
use_optimization: bool = False
Expand Down Expand Up @@ -170,6 +171,7 @@ def _get_sub_layer(var_name: str) -> int:
axis = None
var_dtype = 'bfloat16'
transpose_embedding = False
transpose_cast_fp8 = False
quantize_factor = 1.0
pack_dim = 0
curr_config = config.get_quantize_axis_and_factor(source_name)
Expand Down Expand Up @@ -249,6 +251,21 @@ def _get_sub_layer(var_name: str) -> int:
var_dtype = 'int8'
axis = [1] # we are quantizing embedding along the outer axis here
quantize_factor = config.factor # pytype: disable=attribute-error

# fp8 weights cast and transpose
# print(source_name)
if source_name in ('mdl_vars.params.lm.transformer.x_layers_0.ff_layer.ffn_layer1.linear.w',
'mdl_vars.params.lm.transformer.x_layers_0.ff_layer.ffn_layer2.linear.w',
'mdl_vars.params.lm.transformer.x_layers_1.ff_layer.ffn_layer2.linear.w'
'mdl_vars.params.lm.transformer.x_layers_1.ff_layer.ffn_layer2.linear.w'
'mdl_vars.params.lm.transformer.x_layers_2.ff_layer.ffn_layer2.linear.w'
'mdl_vars.params.lm.transformer.x_layers_2.ff_layer.ffn_layer2.linear.w'
'mdl_vars.params.lm.transformer.x_layers_3.ff_layer.ffn_layer2.linear.w'
'mdl_vars.params.lm.transformer.x_layers_3.ff_layer.ffn_layer2.linear.w'):
if use_fp and number_bit == 8:
axis = [-1]
transpose_cast_fp8 = True
target_name = source_name

# Setting quantization configs back to non-quantize when the source name
# matches the skip pattern.
Expand Down Expand Up @@ -278,6 +295,7 @@ def _get_sub_layer(var_name: str) -> int:
quantize_axis=axis,
quantize_factor=quantize_factor,
transpose_embedding=transpose_embedding,
transpose_cast_fp8=transpose_cast_fp8,
number_bit=layer_wise_num_bits,
pack_dim=pack_dim,
use_optimization=use_optimization,
Expand Down
9 changes: 9 additions & 0 deletions saxml/tools/quantization_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,16 @@ class QuantizationConfigsGPTJ(QuantizationConfigs):
'self_attention.combined_qkv.w': ([1], factor, 1, -1),
'self_attention.post.w': ([1, 2], factor, 0, -1),
}
class QuantizationConfigsGPT5B(QuantizationConfigs):
"""Quantization config for GPT5B model."""

factor = 1.0
configs = {
'ff_layer.ffn_layer1.linear.w': ([0, 1], factor, 0, -1),
'ff_layer.ffn_layer2.linear.w': ([0, 1], factor, 0, -1),
'self_attention.combined_qkv.w': ([0, 1, 2, 3], factor, 1, -1),
'self_attention.post.w': ([0, 1, 2], factor, 0, -1),
}

class QuantizationConfigsGPTJStacked(QuantizationConfigs):
"""Quantization config for GPTJ model."""
Expand Down
1 change: 1 addition & 0 deletions saxml/tools/quantization_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

NAME_TO_CONFIG = {
'gptj': quantization_configs.QuantizationConfigsGPTJ(),
'gpt5b': quantization_configs.QuantizationConfigsGPT5B(),
'gemma2b': quantization_configs.QuantizationConfigsGemma2B(),
'gemma7b': quantization_configs.QuantizationConfigsGemma7B(),
'llama2-70b-weight-linear-only-int8': (
Expand Down