Skip to content

Commit

Permalink
MX4 group size configuration for pyper (#3516)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#597

Pull Request resolved: #3516

Added pyper configuration for MX4 group size.

Reviewed By: irobert0126, renganxu

Differential Revision: D67407064

fbshipit-source-id: a23765777879491836fcb9f1a00ba8f1e1b26b76
  • Loading branch information
qchip authored and facebook-github-bot committed Dec 19, 2024
1 parent eaa0961 commit ca4ea00
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/quantize_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def __init__(
self._loss_scale = loss_scale
self._is_fwd = is_fwd
self._row_dim: int = -1 if row_dim is None else row_dim
if self._comm_precision == SparseType.MX4:
self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim

def encode(
self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
Expand Down Expand Up @@ -252,11 +254,12 @@ def quantized_dtype(self) -> torch.dtype:

def create_context(self) -> Optional[QuantizationContext]:
    """Build the QuantizationContext matching this codec's precision.

    Returns:
        - FP8: a context carrying ``self._row_dim`` (fp8 rowwise is
          activated when row_dim > 0).
        - MX4: a context carrying ``self._row_dim`` both as the row
          dimension and as the MX4 group size (``__init__`` sets
          ``self._row_dim`` to the MX4 group size for this precision).
        - anything else: a default context (int8 rowwise is the default).
    """
    # fp8 rowwise is activated when row_dim > 0
    if self._comm_precision == SparseType.FP8:
        return QuantizationContext(self._row_dim)
    if self._comm_precision == SparseType.MX4:
        # For MX4, row_dim doubles as the quantization group size.
        return QuantizationContext(
            row_dim=self._row_dim, mx_group_size=self._row_dim
        )
    # int8 rowwise is default
    return QuantizationContext()

Expand Down

0 comments on commit ca4ea00

Please sign in to comment.