[fp8] optimize all-gather (#6043)
* [fp8] optimize all-gather

* [fp8] fix all gather fp8 ring

* [fp8] enable compile

* [fp8] fix all gather fp8 ring
ver217 authored Sep 3, 2024
1 parent c650a90 commit c3b5caf
Showing 1 changed file with 101 additions and 5 deletions.
colossalai/quantization/fp8.py: 101 additions & 5 deletions
@@ -8,6 +8,7 @@
 from torch.distributed import ReduceOp

 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
+SCALE_BYTES = 4


 class Handle:
@@ -22,7 +23,9 @@ def wait(self):
         self.remain_ops()


-def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
+def cast_to_fp8(
+    inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
+) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -55,12 +58,15 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
         scale = fp8_max / per_tensor_max
         scale_inv = 1.0 / scale

-    ret = (scale * inp.float()).to(fp8_type)
+    if out is not None:
+        ret = torch.mul(scale, inp.float(), out=out)
+    else:
+        ret = (scale * inp.float()).to(fp8_type)
     return ret, torch.unsqueeze(scale_inv, dim=0)


 def cast_from_fp8(
-    inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
+    inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False, out=None
 ) -> torch.Tensor:
     r"""
     Args:
@@ -74,9 +80,15 @@ def cast_from_fp8(
         raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")

     if per_channel_scale:
-        ret = scale_inv[:, None] * inp.float()
+        if out is not None:
+            return torch.mul(scale_inv[:, None], inp.float(), out=out)
+        else:
+            ret = scale_inv[:, None] * inp.float()
     else:
-        ret = scale_inv * inp.float()
+        if out is not None:
+            return torch.mul(scale_inv, inp.float(), out=out)
+        else:
+            ret = scale_inv * inp.float()
     return ret.to(ret_type)
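The new out= parameter lets both casts write directly into a caller-provided tensor instead of allocating a fresh one, which is what the all-gather paths below rely on when they write into views of a single communication buffer. A minimal sketch of that usage (tensor names and sizes are illustrative; SCALE_BYTES and the two cast functions come from this file), assuming a recent PyTorch with float8 dtypes and a CUDA device:

import torch

from colossalai.quantization.fp8 import SCALE_BYTES, cast_from_fp8, cast_to_fp8

x = torch.randn(1024, device="cuda")

# Preallocate one raw byte buffer holding a 4-byte scale slot plus the fp8
# payload, mirroring how the all-gather below slices its combined buffer.
buf = torch.empty(SCALE_BYTES + x.numel(), dtype=torch.uint8, device="cuda")
payload = buf[SCALE_BYTES:].view(torch.float8_e5m2)

# Quantize straight into the payload view; the inverse scale is returned.
_, scale_inv = cast_to_fp8(x, fp8_format="e5m2", out=payload)
buf[:SCALE_BYTES].view(torch.float)[0] = scale_inv

# Dequantize into a preallocated output tensor as well.
y = torch.empty_like(x)
cast_from_fp8(payload, scale_inv, x.dtype, out=y)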


@@ -664,6 +676,90 @@ def cast_op():
         cast_op()


+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
+    world_size = dist.get_world_size(group)
+    shape = input_.shape
+    input_type = input_.dtype
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+
+    combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
+    combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
+    cur_buffer = combined_buffers[dist.get_rank(group)]
+    ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
+    ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
+    cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
+    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
+    dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
+    for out, buf in zip(output_list, combined_buffers):
+        scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
+        output = buf[SCALE_BYTES:].view(fp8_type)
+        cast_from_fp8(output.view(shape), scale, input_type, out=out)
+    # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
+    # scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
+    # output = output.float() * scales
+    # for i, out in enumerate(output_list):
+    #     out.copy_(output[i].view(shape))
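all_gather_fp8 packs each rank's 4-byte float32 scale together with its fp8 payload into one contiguous uint8 buffer, so a single dist.all_gather moves both, and the loop afterwards dequantizes every rank's chunk into the caller's output list. A minimal driver sketch, assuming an NCCL process group launched with torchrun; the shard shape and dtype are illustrative:

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_gather_fp8

# Assumes torchrun has set the usual env vars; NCCL is needed for CUDA comms.
dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Hypothetical per-rank shard; every rank must pass the same shape and dtype.
shard = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

# One preallocated output slot per rank, filled in rank order.
outputs = [torch.empty_like(shard) for _ in range(dist.get_world_size())]

# Quantize to fp8, all-gather the packed scale+payload, dequantize per rank.
all_gather_fp8(outputs, shard, fp8_format="e4m3", async_op=False)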


+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
+    world_size = dist.get_world_size(group)
+    rank = dist.get_rank(group)
+
+    send_rank = (rank + 1) % world_size
+    recv_rank = (rank - 1) % world_size
+
+    shape = input_.shape
+    input_type = input_.dtype
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+
+    combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
+    combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
+    cur_buffer = combined_buffers[dist.get_rank(group)]
+    ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
+    ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
+    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
+    cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
+
+    def send_recv(idx):
+        send_idx = (rank - idx) % world_size
+        recv_idx = (rank - idx - 1) % world_size
+        ops = dist.batch_isend_irecv(
+            [
+                dist.P2POp(dist.isend, combined_buffers[send_idx], send_rank, group=group),
+                dist.P2POp(dist.irecv, combined_buffers[recv_idx], recv_rank, group=group),
+            ]
+        )
+        return ops
+
+    def cast(idx):
+        cast_idx = (rank - idx - 1) % world_size
+        scale = combined_buffers[cast_idx][:SCALE_BYTES].clone().view(torch.float)
+        output = combined_buffers[cast_idx][SCALE_BYTES:].view(fp8_type)
+        cast_from_fp8(output.view(shape), scale, input_type, out=output_list[cast_idx])
+
+    # warmup
+    ops = send_recv(0)
+    output_list[rank].copy_(input_)
+    for op in ops:
+        op.wait()
+    ops = []
+
+    # 1p-1c
+    for i in range(1, world_size - 1):
+        new_ops = send_recv(i)
+        for op in ops:
+            op.wait()
+        cast(i - 1)
+        ops = new_ops
+
+    # cooldown
+    for op in ops:
+        op.wait()
+    cast(world_size - 2)
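The ring variant replaces the collective with world_size - 1 point-to-point steps: the warmup posts the first send/recv while the local shard is copied straight into output_list[rank], the 1p-1c loop posts the next transfer and, while it is in flight, dequantizes the chunk that arrived in the previous step, and the cooldown casts the last chunk. Note that the chunk forwarded at each step is the one that just arrived, so it is relayed and dequantized at the same time. A standalone sketch of that schedule, using only the index arithmetic from send_recv and cast above and a hypothetical world_size of 4 (no torch.distributed calls involved):

# Purely illustrative: reproduce the chunk schedule of all_gather_fp8_ring.
world_size = 4  # hypothetical; any size >= 2 follows the same pattern

for rank in range(world_size):
    print(f"rank {rank}: warmup copies its own chunk {rank} into output_list[{rank}]")
    for step in range(world_size - 1):
        send_idx = (rank - step) % world_size       # forwarded to (rank + 1) % world_size
        recv_idx = (rank - step - 1) % world_size   # arriving from (rank - 1) % world_size
        # While this step's send/recv is in flight, the chunk received at the
        # previous step is dequantized; step 0 has nothing to cast yet, and
        # the final chunk is cast in the cooldown.
        cast_idx = (rank - step) % world_size if step > 0 else None
        print(f"  step {step}: send chunk {send_idx}, recv chunk {recv_idx}, cast chunk {cast_idx}")
    print(f"  cooldown: cast chunk {(rank + 1) % world_size}")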


 class _LinearFp8(torch.autograd.Function):
     @staticmethod
     def forward(
