From c3b5caff0e637a4ba78eacd44bb17aa1a41f5c97 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 3 Sep 2024 15:45:17 +0800 Subject: [PATCH] [fp8] optimize all-gather (#6043) * [fp8] optimize all-gather * [fp8] fix all gather fp8 ring * [fp8] enable compile * [fp8] fix all gather fp8 ring --- colossalai/quantization/fp8.py | 106 +++++++++++++++++++++++++++++++-- 1 file changed, 101 insertions(+), 5 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index c022fab158c8..6a0bd14d1071 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -8,6 +8,7 @@ from torch.distributed import ReduceOp SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0") +SCALE_BYTES = 4 class Handle: @@ -22,7 +23,9 @@ def wait(self): self.remain_ops() -def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]: +def cast_to_fp8( + inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None +) -> Tuple[torch.Tensor, torch.Tensor]: r""" casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling. Args: @@ -55,12 +58,15 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) - scale = fp8_max / per_tensor_max scale_inv = 1.0 / scale - ret = (scale * inp.float()).to(fp8_type) + if out is not None: + ret = torch.mul(scale, inp.float(), out=out) + else: + ret = (scale * inp.float()).to(fp8_type) return ret, torch.unsqueeze(scale_inv, dim=0) def cast_from_fp8( - inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False + inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False, out=None ) -> torch.Tensor: r""" Args: @@ -74,9 +80,15 @@ def cast_from_fp8( raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.") if per_channel_scale: - ret = scale_inv[:, None] * inp.float() + if out is not None: + return torch.mul(scale_inv[:, None], inp.float(), out=out) + else: + ret = scale_inv[:, None] * inp.float() else: - ret = scale_inv * inp.float() + if out is not None: + return torch.mul(scale_inv, inp.float(), out=out) + else: + ret = scale_inv * inp.float() return ret.to(ret_type) @@ -664,6 +676,90 @@ def cast_op(): cast_op() +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: + world_size = dist.get_world_size(group) + shape = input_.shape + input_type = input_.dtype + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + + combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device) + combined_buffers = list(combined_buffer.chunk(world_size, dim=0)) + cur_buffer = combined_buffers[dist.get_rank(group)] + ret = cur_buffer[SCALE_BYTES:].view(fp8_type) + ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret) + cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale + # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8) + dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op) + for out, buf in zip(output_list, combined_buffers): + scale = buf[:SCALE_BYTES].clone().view(scale.dtype) + output = buf[SCALE_BYTES:].view(fp8_type) + cast_from_fp8(output.view(shape), scale, input_type, out=out) + # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type) + # scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float) + # output = output.float() * scales + # for i, out in enumerate(output_list): + # out.copy_(output[i].view(shape)) + + +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + + send_rank = (rank + 1) % world_size + recv_rank = (rank - 1) % world_size + + shape = input_.shape + input_type = input_.dtype + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + + combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device) + combined_buffers = list(combined_buffer.chunk(world_size, dim=0)) + cur_buffer = combined_buffers[dist.get_rank(group)] + ret = cur_buffer[SCALE_BYTES:].view(fp8_type) + ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret) + # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8) + cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale + + def send_recv(idx): + send_idx = (rank - idx) % world_size + recv_idx = (rank - idx - 1) % world_size + ops = dist.batch_isend_irecv( + [ + dist.P2POp(dist.isend, combined_buffers[send_idx], send_rank, group=group), + dist.P2POp(dist.irecv, combined_buffers[recv_idx], recv_rank, group=group), + ] + ) + return ops + + def cast(idx): + cast_idx = (rank - idx - 1) % world_size + scale = combined_buffers[cast_idx][:SCALE_BYTES].clone().view(torch.float) + output = combined_buffers[cast_idx][SCALE_BYTES:].view(fp8_type) + cast_from_fp8(output.view(shape), scale, input_type, out=output_list[cast_idx]) + + # warmup + ops = send_recv(0) + output_list[rank].copy_(input_) + for op in ops: + op.wait() + ops = [] + + # 1p-1c + for i in range(1, world_size - 1): + new_ops = send_recv(i) + for op in ops: + op.wait() + cast(i - 1) + ops = new_ops + + # cooldown + for op in ops: + op.wait() + cast(world_size - 2) + + class _LinearFp8(torch.autograd.Function): @staticmethod def forward(