pre_scale support in add_to_fp8
wkcn committed Dec 7, 2023
1 parent bcd0f74 commit d248505
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions msamp/operators/arithmetic/arithmetic.py
@@ -12,14 +12,13 @@
 class Arithmetic:
     """Arithmetic operator for FP8 tensor."""
     @staticmethod
-    def add_to_fp8(fp8_tensor, meta, other, pre_scale=1.0):
+    def add_to_fp8(fp8_tensor, meta, other):
         """Add high precision tensor to fp8_tensor in-place.
 
         Args:
             fp8_tensor (torch.Tensor): fp8 tensor to add to.
             meta (ScalingTensorMeta): meta data of fp8_tensor.
             other (torch.Tensor): high precision tensor to add.
-            pre_scale (float, optional): Pre-scale factor, defaults to 1.0.
         """
         if not (fp8_tensor.is_cuda and fp8_tensor.is_contiguous()):
             raise ValueError('The fp8 tensor is not in cuda memory or contiguous.')
@@ -33,4 +32,6 @@ def add_to_fp8(fp8_tensor, meta, other, pre_scale=1.0):
 
         is_e4m3 = meta.qtype == Dtypes.kfloat8_e4m3
 
-        msamp_arithmetic.add_to_fp8(fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], pre_scale, other, is_e4m3)
+        msamp_arithmetic.add_to_fp8(
+            fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], meta.pre_scale, other, is_e4m3
+        )
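
After this change, callers no longer pass pre_scale as an argument; the CUDA kernel reads it from the tensor's scaling meta (meta.pre_scale). A minimal usage sketch follows. It is hypothetical: the import paths are inferred from the repository layout, and the ScalingMeta construction, pre_scale assignment, and FP8 buffer setup are illustrative assumptions, not part of this commit — only Dtypes.kfloat8_e4m3 and the Arithmetic.add_to_fp8 call itself appear in the diff.

import torch

from msamp.common.dtype import Dtypes         # assumed import path
from msamp.common.tensor import ScalingMeta   # assumed meta class and path
from msamp.operators.arithmetic import Arithmetic

# Illustrative setup: a uint8 buffer standing in for FP8 storage, plus its
# scaling metadata; real MS-AMP code would obtain these from a scaling tensor.
fp8_tensor = torch.zeros(1024, dtype=torch.uint8, device='cuda')
meta = ScalingMeta(Dtypes.kfloat8_e4m3)

# Assumption per this commit: pre_scale now lives on the meta object, e.g. to
# fold a gradient-accumulation factor into the accumulation.
meta.pre_scale = 0.5

other = torch.randn(1024, device='cuda')  # high-precision tensor to add

# In-place update: the kernel receives meta.scale, meta.scale_inv, meta.amax[0]
# and meta.pre_scale, and accumulates `other` into fp8_tensor.
Arithmetic.add_to_fp8(fp8_tensor, meta, other)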
