Merge branch 'main' into metric_logger

SurbhiJainUSC authored Feb 27, 2025
2 parents 21d7269 + 4f18851 commit 96b6f74

Showing 2 changed files with 73 additions and 23 deletions.
38 changes: 27 additions & 11 deletions MaxText/layers/attentions.py
@@ -24,6 +24,7 @@
from jax import lax
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import shard_map
from jax.experimental.pallas.ops.gpu import attention as gpu_pallas_attention
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
import jax.numpy as jnp
@@ -237,17 +238,27 @@ def apply_attention(
):
return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode)
elif self.attention_kernel == "flash" or self.attention_kernel == "autoselected":
if isinstance(key, KVTensor):
key = key.dequant()
if isinstance(value, KVTensor):
value = value.dequant()

if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
raise ValueError(
"""Decode not supported with flash attention.
Use `dot_product` instead."""
)
return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap), None, None
if jax.devices()[0].platform == "tpu":
if isinstance(key, KVTensor):
key = key.dequant()
if isinstance(value, KVTensor):
value = value.dequant()

if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
raise ValueError(
"""Decode not supported with flash attention.
Use `dot_product` instead."""
)
return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap), None, None
else:
if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
# fallback to dot_product as pallas gpu flash attention doesn't support decode stage
return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode)
else:
key = jnp.repeat(key, self.num_query_heads // self.num_kv_heads, axis=2)
value = jnp.repeat(value, self.num_query_heads // self.num_kv_heads, axis=2)
out = gpu_pallas_attention.mha(query, key, value, decoder_segment_ids, sm_scale=1.0, causal=True)
return out, None, None
elif self.attention_kernel == "cudnn_flash_te":
if isinstance(key, KVTensor):
key = key.dequant()
@@ -562,6 +573,9 @@ def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: str

einsum = jnp.einsum
if self.kv_quant:
# manually cast to bf16 to avoid the fp32 XLA ops for speedup
if isinstance(value, KVTensor) and self.kv_quant.dtype == jnp.float8_e4m3fn:
value.qvalue = value.qvalue.astype(jnp.bfloat16)
einsum = self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(value)
if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0, 1, 2, 3):
out = einsum("bkgts,bskd->btkgd", attn_weights, value)
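The cast added above is about dtype steering: with the fp8 payload cast to bfloat16, the weights-times-values contraction stays a bf16 matmul instead of being promoted to fp32. A rough illustration with made-up shapes (this is not the AQT einsum the code actually builds, only the dtype behaviour it targets):

import jax.numpy as jnp

# b=2 batch, k=1 KV head, g=8 query groups, t=s=16 tokens, d=64 head dim.
attn_weights = jnp.ones((2, 1, 8, 16, 16), dtype=jnp.bfloat16)
qvalue = jnp.ones((2, 16, 1, 64), dtype=jnp.float8_e4m3fn)

# Cast the fp8 payload to bf16 before contracting; the KVTensor's scale is
# applied separately by the dequantizing einsum.
out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, qvalue.astype(jnp.bfloat16))
assert out.dtype == jnp.bfloat16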
@@ -899,6 +913,8 @@ def get_cached_values(self, cache_vars, target_dtype, cache_axis_order) -> jax.Array:
scale_value /= quantizations.MAX_INT8
elif dtype == jnp.int4:
scale_value /= quantizations.MAX_INT4
elif dtype == jnp.float8_e4m3fn:
scale_value /= quantizations.E4M3_MAX

cache_value = KVTensor(qvalue=cache_value, scale=[scale_value], scale_t=None, dequant_dtype=target_dtype, bias=[])
cache_value_in_logical_shape = jax.tree.map(lambda x: self.reverse_transepose(x, cache_axis_order), cache_value)
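For readers new to this cache layout: quantize() stores scale = max(|kv|) along the quantization axis, so dividing it here by the dtype's range constant makes dequantization a single multiply by the cached scale. A tiny int8 round trip with illustrative values:

import jax.numpy as jnp

MAX_INT8 = 127.5
kv = jnp.array([0.3, -1.7, 0.9], dtype=jnp.float32)
scale = jnp.max(jnp.abs(kv))                        # what quantize() keeps

qvalue = jnp.rint(kv * (MAX_INT8 / scale)).astype(jnp.int8)
cached_scale = scale / MAX_INT8                     # the division done above
recovered = qvalue.astype(jnp.float32) * cached_scale
assert jnp.max(jnp.abs(recovered - kv)) < 0.05      # small quantization error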
58 changes: 46 additions & 12 deletions MaxText/layers/quantizations.py
@@ -44,6 +44,7 @@

MAX_INT8 = 127.5
MAX_INT4 = 7.5
E4M3_MAX = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)

Array = common_types.Array
Config = common_types.Config
@@ -295,6 +296,10 @@ def _get_int8_quant_config(config):
)


def _get_aqt_fp8_quant_config(config):
return aqt_config.config_fwd_fp8()


def _dot_general_make(quant_cfg):
lhs_bits = quant_cfg[_A_BITS]
lhs_scale = quant_cfg[_A_SCALE]
@@ -344,6 +349,8 @@ def _get_quant_config(config):
return _get_mixed_precision_quant_config(mixed_precision_config)
if config.quantization == "fp8":
return "fp8"
if config.quantization == "aqt_fp8":
return _get_aqt_fp8_quant_config(config)
raise ValueError(f"Invalid value configured for quantization {config.quantization}.")


@@ -440,6 +447,8 @@ def _get_dtype(self, dtype_cfg: str):
return jnp.int4
if dtype_cfg == "int8":
return jnp.int8
if dtype_cfg == "fp8":
return jnp.float8_e4m3fn
raise ValueError(f"Invalid kv_quant_dtype: {dtype_cfg}")

def _get_max_axis(self, axis_names: AxisNames):
@@ -460,31 +469,47 @@ def quantize(self, kv: Array, axis_names: AxisNames):
if self.dtype == jnp.int4:
value = jnp.int4(jnp.rint(kv * (MAX_INT4 / scale)))
return value, scale
if self.dtype == jnp.float8_e4m3fn:
value = jnp.float8_e4m3fn(kv * (E4M3_MAX / scale))
return value, scale
raise ValueError(f"Invalid KV quant dtype:{self.dtype}.")

def einsum_fn_with_rhs_qtensor(
self,
kv: Array | aqt_tensor.QTensor,
rhs_dequant_mode=None,
rhs_calibration_mode=None,
lhs_dequant_mode=None,
lhs_calibration_mode=None,
):
# Assumes kv is already quantized.
einsum = jnp.einsum
if isinstance(kv, aqt_tensor.QTensor):
num_bits = 4 if kv.qvalue.dtype == jnp.int4 else 8
kv_cfg = aqt_config.dot_general_make(
lhs_bits=None,
rhs_bits=num_bits,
bwd_bits=None,
use_fwd_quant=False,
)
if kv.qvalue.dtype != jnp.float8_e4m3fn:
num_bits = 4 if kv.qvalue.dtype == jnp.int4 else 8
kv_cfg = aqt_config.dot_general_make(
lhs_bits=None,
rhs_bits=num_bits,
bwd_bits=None,
use_fwd_quant=False,
)
else:
kv_cfg = aqt_config.config_fwd_fp8()

if rhs_dequant_mode:
aqt_config.set_fwd_dequant_mode(kv_cfg, rhs_dequant_mode=rhs_dequant_mode)
if rhs_calibration_mode:
aqt_config.set_fwd_calibration_mode(
kv_cfg,
rhs_calibration_mode=rhs_calibration_mode,
)
if lhs_dequant_mode:
aqt_config.set_fwd_dequant_mode(kv_cfg, lhs_dequant_mode=lhs_dequant_mode)
if lhs_calibration_mode:
aqt_config.set_fwd_calibration_mode(
kv_cfg,
lhs_calibration_mode=lhs_calibration_mode,
)
einsum = aqt_flax.AqtEinsum(
rhs_quant_mode=aqt_flax.QuantMode.TRAIN,
lhs_freeze_mode=aqt_flax.FreezerMode.NONE,
@@ -494,8 +519,17 @@ def einsum_fn_with_rhs_qtensor(
return einsum

def einsum_fn_with_rhs_qtensor_and_dequant(self, value):
return self.einsum_fn_with_rhs_qtensor(
value,
rhs_dequant_mode=aqt_config.DequantMode.OTHER_INPUT,
rhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS,
)
if self.dtype == jnp.float8_e4m3fn:
return self.einsum_fn_with_rhs_qtensor(
value,
lhs_dequant_mode=aqt_config.DequantMode.THIS_INPUT,
lhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS,
rhs_dequant_mode=aqt_config.DequantMode.OTHER_INPUT,
rhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS,
)
else:
return self.einsum_fn_with_rhs_qtensor(
value,
rhs_dequant_mode=aqt_config.DequantMode.OTHER_INPUT,
rhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS,
)
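Both return paths above hand back an AQT einsum that contracts directly against the still-quantized cache and folds the scale back in, rather than materializing a dequantized copy of the values. The identity this leans on can be written in plain jnp (illustrative shapes only; this is not how AQT implements the DequantMode/CalibrationMode settings internally):

import jax.numpy as jnp

b, s, k, g, t, d = 2, 16, 1, 8, 16, 64
attn = jnp.ones((b, k, g, t, s), dtype=jnp.float32)
qvalue = jnp.ones((b, s, k, d), dtype=jnp.float32)       # stand-in for the int8/fp8 payload
scale = jnp.full((b, s, k, 1), 0.01, dtype=jnp.float32)  # per-entry scale, already divided by the range constant

# Dequantize first, then contract ...
out_a = jnp.einsum("bkgts,bskd->btkgd", attn, qvalue * scale)
# ... or contract against the raw payload with the scale riding along as a
# separate factor in the same sum; the results match because the scale just
# multiplies each term of the contraction.
out_b = jnp.einsum("bkgts,bsk,bskd->btkgd", attn, scale[..., 0], qvalue)
assert jnp.allclose(out_a, out_b)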
