From 7b107cce04ae807d268e682fc5a724bbe6f420d8 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 28 May 2024 11:55:54 +0800 Subject: [PATCH 01/11] add all2all ops --- bmtrain/distributed/ops.py | 69 +++++++++++++--- bmtrain/nccl/__init__.py | 84 ++++++++++++++++++++ csrc/bind.cpp | 1 + csrc/include/nccl.hpp | 33 ++++++++ tests/test_all.py | 1 + tests/test_nccl_all2all.py | 159 +++++++++++++++++++++++++++++++++++++ 6 files changed, 334 insertions(+), 13 deletions(-) create mode 100644 tests/test_nccl_all2all.py diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index d1b489e2..a52529f8 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -6,6 +6,7 @@ from ..nccl import reduceScatter as ncclReduceScatter from ..nccl import send as ncclSend from ..nccl import recv as ncclRecv +from ..nccl import all2all as ncclAllToAll from ..nccl import commCount,commRank,NCCLCommunicator DTYPE_LIST = [ torch.float64, @@ -44,6 +45,13 @@ def recv_meta(prev_rank, comm): shape = meta_data[2:n_dims+2].tolist() return dtype,shape +def to_contiguous(x): + if not x.is_contiguous(): + x = x.contiguous() + if x.storage_offset() != 0 or x.storage().size() != x.numel(): + x = x.clone() + return x + class OpBroadcast(torch.autograd.Function): @staticmethod @@ -72,10 +80,7 @@ def forward(ctx, input : torch.Tensor, comm = None): if comm is None: comm = config["comm"] world_size = commCount(comm) - if not input.is_contiguous(): - input = input.contiguous() - if input.storage_offset() != 0 or input.storage().size() != input.numel(): - input = input.clone() + input = to_contiguous(input) output = torch.empty( (world_size,) + input.size(), dtype=input.dtype, device=input.device) ctx.comm = comm ncclAllGather( @@ -87,6 +92,7 @@ def forward(ctx, input : torch.Tensor, comm = None): @staticmethod def backward(ctx, grad_output): + grad_output = to_contiguous(grad_output) return grad_output[commRank(ctx.comm)], None def all_gather(x : torch.Tensor, comm = None): @@ -113,10 +119,7 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None) ctx.comm = comm rank = commRank(comm) assert input.shape[0] % commCount(comm) == 0, "The dimension 0 must be divisible by the number of communication processes" - if not input.is_contiguous(): - input = input.contiguous() - if input.storage_offset() != 0 or input.storage().size() != input.numel(): - input = input.clone() + input = to_contiguous(input) output_shape = (input.shape[0] // commCount(comm), *input.shape[1:]) output = torch.empty( output_shape, dtype=input.dtype, device=input.device ) ncclReduceScatter( @@ -136,6 +139,7 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None) @staticmethod def backward(ctx, grad_output): + grad_output = to_contiguous(grad_output) with torch.no_grad(): grad_output = OpAllGather.apply(grad_output, ctx.comm).flatten(0,1) if ctx.op in ["max", "min", "prod"]: @@ -169,10 +173,7 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None) if comm is None: comm = config["comm"] ctx.comm = comm - if not input.is_contiguous(): - input = input.contiguous() - if input.storage_offset() != 0 or input.storage().size() != input.numel(): - input = input.clone() + input = to_contiguous(input) output = torch.empty( input.size(), dtype=input.dtype, device=input.device) ncclAllReduce( @@ -193,6 +194,7 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None) @staticmethod def backward(ctx, grad_output): + 
grad_output = to_contiguous(grad_output) if ctx.op == "sum": return grad_output, None, None elif ctx.op == "avg": @@ -220,4 +222,45 @@ def all_reduce(x : torch.Tensor, op : str = "sum", comm = None): return OpAllReduce.apply(x, op, comm) - +class OpAllToAll(torch.autograd.Function): + @staticmethod + def forward(ctx, input : torch.Tensor, comm : NCCLCommunicator = None): + if comm is None: + comm = config["comm"] + ctx.comm = comm + input = to_contiguous(input) + output = torch.empty(input.size(), dtype=input.dtype, device=input.device) + + ncclAllToAll( + input.storage(), + output.storage(), + comm + ) + return output + + @staticmethod + def backward(ctx, grad_output): + grad_output = to_contiguous(grad_output) + grad_input = torch.empty(grad_output.size(), dtype=grad_output.dtype, device=grad_output.device) + ncclAllToAll( + grad_output.storage(), + grad_input.storage(), + ctx.comm + ) + return grad_input, None + +def all_to_all(x : torch.Tensor, comm = None): + """Split input tensor and then scatter the split list to all processes in a group. + + Args: + x (torch.Tensor): The input tensor of shape (...). + + Returns: + torch.Tensor: the concatenated of received tensors + + """ + if not config["initialized"]: + raise RuntimeError("BMTrain is not initialized") + + assert x.is_cuda + return OpAllToAll.apply(x, comm) diff --git a/bmtrain/nccl/__init__.py b/bmtrain/nccl/__init__.py index 0f4129d5..531c5498 100644 --- a/bmtrain/nccl/__init__.py +++ b/bmtrain/nccl/__init__.py @@ -46,6 +46,26 @@ def dtype2nccl(dtype : torch.dtype) -> int: raise TypeError("Unsupport dtype %s" % dtype) return MAP[dtype] +def dtype2byte(dtype : torch.dtype) -> int: + MAP = { + torch.int8: 1, + torch.uint8 : 1, + torch.int32 : 4, + torch.int : 4, + torch.int64 : 8, + torch.float16 : 2, + torch.half : 2, + torch.bfloat16 : 2, + torch.float32 : 4, + torch.float : 4, + torch.float64 : 8, + torch.double : 8, + torch.bool : 1 + } + if dtype not in MAP: + raise TypeError("Unsupport dtype %s" % dtype) + return MAP[dtype] + def op2nccl( op : Literal["sum", "prod", "max", "min", "avg"] ): @@ -323,6 +343,70 @@ def reduceScatter( torch.cuda.current_stream().cuda_stream ) +def all2all( + src : torch.storage._StorageBase, + dst : torch.storage._StorageBase, + comm : NCCLCommunicator + ): + """NCCL all2all (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#all-to-all) + Args: + src (torch.storage._StorageBase): Source buffer. + dst (torch.storage._StorageBase): Destination buffer. + comm (NCCLCommunicator): NCCL communicator. + + The size of the dst buffer must be equal to the size of src buffer / world_size. + The dst buffer on rank `i` will contail the i-th block of the reduced result. 
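    A minimal usage sketch (illustrative only; it assumes 4 ranks and that
    ``bmt.init_distributed()`` has already been called -- see
    ``tests/test_nccl_all2all.py`` for the full test):

        x = torch.arange(8, dtype=torch.half, device="cuda")   # 2 elements per peer
        y = torch.empty_like(x)
        nccl.all2all(x.storage(), y.storage(), bmt.config["comm"])
        # chunk i of x is sent to rank i; y collects one chunk from every rank,
        # so y ends up with the same total size as x (see the assert below).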
+ """ + assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.is_cuda and dst.is_cuda + + sendbuff = src.data_ptr() + recvbuff = dst.data_ptr() + assert src.size() == dst.size(), "src and dst Buffer size not equal" + # assert src.size() % world_size == 0, "Buffer size cannot be evenly divided by world_size" + datatype = dtype2nccl(src.dtype) + databyte = dtype2byte(src.dtype) + datacount = src.size() + databytes = datacount * databyte + + C.ncclAll2All(sendbuff, recvbuff, datacount, databytes, datatype, comm.ptr, torch.cuda.current_stream().cuda_stream) + + +def all2one( + src : torch.storage._StorageBase, + dst : torch.storage._StorageBase, + rank : int, + comm : NCCLCommunicator + ): + """NCCL all2one (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html?highlight=point#all-to-one-gather) + Args: + src (torch.storage._StorageBase): Source buffer. + dst (torch.storage._StorageBase): Destination buffer. + rank : all send to rank. + comm (NCCLCommunicator): NCCL communicator. + + The size of the dst buffer must be equal to the size of src buffer / world_size. + The dst buffer on rank `i` will contail the i-th block of the reduced result. + """ + assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.is_cuda and dst.is_cuda + + sendbuff = src.data_ptr() + recvbuff = dst.data_ptr() + world_size = commCount(comm) + assert src.size() == dst.size(), "src and dst Buffer size not equal" + assert src.size() % world_size == 0, "Buffer size cannot be evenly divided by world_size" + datacount = src.size() // world_size + datatype = dtype2nccl(src.dtype) + databyte = dtype2byte(src.dtype) + + groupStart() + if commRank(comm) == rank: + for r in range(world_size): + C.ncclRecv(recvbuff + r * datacount * databyte, datacount, datatype, r, comm.ptr, torch.cuda.current_stream().cuda_stream) + C.ncclSend(sendbuff + rank * datacount * databyte, datacount, datatype, rank, comm.ptr, torch.cuda.current_stream().cuda_stream) + groupEnd() + def groupStart(): """ NCCL API: `ncclGroupStart `_ diff --git a/csrc/bind.cpp b/csrc/bind.cpp index b8f6fa85..71c6eef1 100644 --- a/csrc/bind.cpp +++ b/csrc/bind.cpp @@ -32,4 +32,5 @@ PYBIND11_MODULE(C, m) { m.def("ncclRecv", &pyNCCLRecv, "nccl recv"); m.def("ncclCommCount", &pyNCCLCommCount, "nccl comm count"); m.def("ncclCommUserRank", &pyNCCLCommUserRank, "nccl comm user rank"); + m.def("ncclAll2All", &pyNCCLAll2All, "nccl All2All"); } diff --git a/csrc/include/nccl.hpp b/csrc/include/nccl.hpp index bba0278b..7c332cec 100644 --- a/csrc/include/nccl.hpp +++ b/csrc/include/nccl.hpp @@ -131,6 +131,39 @@ void pyNCCLReduceScatter( reinterpret_cast(stream) )); } + +void pyNCCLAll2All( + std::uintptr_t sendbuff, + std::uintptr_t recvbuff, + size_t count, + size_t bytes, + int data_type, + std::uintptr_t comm, + std::uintptr_t stream) { + int num_rank; + checkNCCLStatus(ncclCommCount(reinterpret_cast(comm), &num_rank)); + count = count / num_rank; + bytes = bytes / num_rank; + checkNCCLStatus(ncclGroupStart()); + for (int r=0; r < num_rank; r++) { + checkNCCLStatus(ncclSend( + reinterpret_cast(sendbuff + r * bytes), + count, + static_cast(data_type), + r, + reinterpret_cast(comm), + reinterpret_cast(stream))); + checkNCCLStatus(ncclRecv( + reinterpret_cast(recvbuff + r * bytes), + count, + static_cast(data_type), + r, + reinterpret_cast(comm), + reinterpret_cast(stream))); + } + checkNCCLStatus(ncclGroupEnd()); +} + void pyNCCLSend( std::uintptr_t sendbuff, size_t sendcount, diff --git 
a/tests/test_all.py b/tests/test_all.py index db5d2dd4..a9b431c1 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -26,6 +26,7 @@ ("send_recv", 4), ("nccl_backward", 4), + ("nccl_all2all", 4), ("no_grad", 1), ("column_parallel_linear", 2), ("row_parallel_linear", 2), diff --git a/tests/test_nccl_all2all.py b/tests/test_nccl_all2all.py new file mode 100644 index 00000000..ce01ed91 --- /dev/null +++ b/tests/test_nccl_all2all.py @@ -0,0 +1,159 @@ + +from utils import * + +import bmtrain as bmt +import torch +from bmtrain import nccl +import math +import os + +def test_main(dtype): + # x shape (2,8) + refx = torch.tensor([[(bmt.rank()*2+y)*10+x for x in range(8)] for y in range(2)], dtype=dtype, device="cuda") + refy = torch.tensor([[y*10+bmt.rank()*2+x for x in range(2)] for y in range(8)], dtype=dtype, device="cuda") + + x = refx.clone() + bmt.print_rank("x") + for r in range(4): + bmt.print_rank(x, rank=r) + bmt.synchronize() + + x = torch.cat(x.chunk(4, dim=1), dim=0).contiguous() + y = torch.zeros((8,2), dtype=dtype, device="cuda") + nccl.all2all(x.storage(), y.storage(), bmt.config['comm']) + bmt.print_rank("y") + for r in range(4): + bmt.print_rank(y, rank=r) + if bmt.rank() == r: assert (y == refy).all() + bmt.synchronize() + + x = torch.zeros((8,2), dtype=dtype, device="cuda") + nccl.all2all(y.storage(), x.storage(), bmt.config['comm']) + x = torch.cat(x.chunk(4, dim=0), dim=1).contiguous() + bmt.print_rank("x") + for r in range(4): + bmt.print_rank(x, rank=r) + if bmt.rank() == r: assert (x == refx).all() + bmt.synchronize() + +class Attention(bmt.DistributedModule): + def __init__(self, + dim_model : int, dim_head : int, + num_heads : int, bias : bool = True, + sequence_parallel : bool = False, + dtype = None + ) -> None: + super().__init__() + + self.project_q = bmt.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = bmt.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = bmt.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_out = bmt.nn.Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + + self.softmax = torch.nn.Softmax(dim=-1) + self.num_heads = num_heads + self.dim_head = dim_head + self.dim_model = dim_model + + self.sequence_parallel = sequence_parallel + + def forward(self, + hidden : torch.Tensor, # (batch_size, seq_q, dim_model) + ) -> torch.Tensor: + batch_size, seq, dim_model = hidden.size() + + h_q : torch.Tensor = self.project_q(hidden) + h_k : torch.Tensor = self.project_k(hidden) + h_v : torch.Tensor = self.project_v(hidden) + + if self.sequence_parallel: + assert batch_size == 1 + h_q = h_q.view(seq, -1) + h_k = h_k.view(seq, -1) + h_v = h_v.view(seq, -1) + h_q = torch.cat(h_q.chunk(bmt.world_size(), dim=1), dim=0).contiguous() + h_k = torch.cat(h_k.chunk(bmt.world_size(), dim=1), dim=0).contiguous() + h_v = torch.cat(h_v.chunk(bmt.world_size(), dim=1), dim=0).contiguous() + h_q = bmt.distributed.all_to_all(h_q, bmt.config['comm']) + h_k = bmt.distributed.all_to_all(h_k, bmt.config['comm']) + h_v = bmt.distributed.all_to_all(h_v, bmt.config['comm']) + seq = seq * bmt.world_size() + h_q = h_q.view(batch_size, seq, -1) + h_k = h_k.view(batch_size, seq, -1) + h_v = h_v.view(batch_size, seq, -1) + + h_q = h_q.view(batch_size, seq, -1, self.dim_head) + h_k = h_k.view(batch_size, seq, -1, self.dim_head) + h_v = h_v.view(batch_size, seq, -1, self.dim_head) + + h_q = h_q.permute(0, 2, 1, 3).contiguous() + h_k = h_k.permute(0, 2, 1, 3).contiguous() + h_v 
= h_v.permute(0, 2, 1, 3).contiguous() + + h_q = h_q.view(-1, seq, self.dim_head) + h_k = h_k.view(-1, seq, self.dim_head) + h_v = h_v.view(-1, seq, self.dim_head) + + score = torch.bmm( + h_q, h_k.transpose(1, 2) + ) + score = score / math.sqrt(self.dim_head) + + score = score.view(batch_size, -1, seq, seq) + + score = score.view(-1, seq, seq) + + h_out = torch.bmm( + score, h_v + ) + h_out = h_out.view(batch_size, -1, seq, self.dim_head) + h_out = h_out.permute(0, 2, 1, 3).contiguous() + h_out = h_out.view(batch_size, seq, -1) + + if self.sequence_parallel: + h_out = h_out.view(seq, -1) + h_out = bmt.distributed.all_to_all(h_out, bmt.config['comm']) + h_out = torch.cat(h_out.chunk(bmt.world_size(), dim=0), dim=1).contiguous() + seq = seq // bmt.world_size() + h_out = h_out.view(batch_size, seq, -1) + + attn_out = self.project_out(h_out) + return attn_out + +def test_ulysses(dtype): + model1 = Attention(dim_model=768, dim_head=32, num_heads=8, dtype=dtype, sequence_parallel=False) + bmt.init_parameters(model1) + bmt.save(model1, "test.pt") + model2 = Attention(dim_model=768, dim_head=32, num_heads=8, dtype=dtype, sequence_parallel=True) + bmt.load(model2, "test.pt") + + xx = torch.randn((1, 128, 768), dtype=dtype, device="cuda").requires_grad_() + x_sp = xx.clone().chunk(bmt.world_size(), dim=1)[bmt.rank()].detach().requires_grad_() + + yy = model1(xx) + y_sp = model2(x_sp) + + gg = torch.randn((1, 128, 768), dtype=dtype, device="cuda") + g = gg.chunk(bmt.world_size(), dim=1)[bmt.rank()] + + yy.backward(gg) + y_sp.backward(g) + + for r in range(bmt.world_size()): + if bmt.rank() == r: + print(r) + print(y_sp) + print(yy.chunk(bmt.world_size(), dim=1)[bmt.rank()]) + assert torch.allclose(y_sp, yy.chunk(bmt.world_size(), dim=1)[bmt.rank()]) + print(x_sp.grad) + print(xx.grad.chunk(bmt.world_size(), dim=1)[bmt.rank()]) + assert torch.allclose(x_sp.grad, xx.grad.chunk(bmt.world_size(), dim=1)[bmt.rank()], rtol=1e-3, atol=1e-1) + bmt.synchronize() + + if bmt.rank() == 0: os.remove("test.pt") + +if __name__ == "__main__": + bmt.init_distributed() + + # test_main(torch.half) + test_ulysses(torch.half) From d658b12ae3d9d1a59b4afeb5c1f51e3c949fd0fe Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 28 May 2024 12:58:01 +0800 Subject: [PATCH 02/11] benchmark and ops update --- bmtrain/benchmark/__init__.py | 3 +- bmtrain/benchmark/all2all.py | 49 +++++++++++++++++++++++++++++++++ bmtrain/distributed/__init__.py | 2 +- 3 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 bmtrain/benchmark/all2all.py diff --git a/bmtrain/benchmark/__init__.py b/bmtrain/benchmark/__init__.py index 571d621f..de72ecfb 100644 --- a/bmtrain/benchmark/__init__.py +++ b/bmtrain/benchmark/__init__.py @@ -1,3 +1,4 @@ from .all_gather import all_gather from .reduce_scatter import reduce_scatter -from .send_recv import send_recv \ No newline at end of file +from .send_recv import send_recv +from .all2all import all2all, all2one diff --git a/bmtrain/benchmark/all2all.py b/bmtrain/benchmark/all2all.py new file mode 100644 index 00000000..36805584 --- /dev/null +++ b/bmtrain/benchmark/all2all.py @@ -0,0 +1,49 @@ +from .. 
import nccl +from .shape import SHAPES +from ..global_var import config +from ..utils import round_up, print_rank +from .utils import format_size +import torch + +def all2all(): + current_stream = torch.cuda.current_stream() + for shape in SHAPES: + global_size = round_up(shape, config['world_size'] * 2) + + result_tensor = torch.empty(global_size // 2, dtype=torch.half, device="cuda") + global_tensor = torch.empty(global_size // 2, dtype=torch.half, device="cuda") + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + current_stream.record_event(start_evt) + nccl.all2all(global_tensor.storage(), result_tensor.storage(), config['comm']) + current_stream.record_event(end_evt) + current_stream.synchronize() + + time_usage = start_evt.elapsed_time(end_evt) + bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage + print_rank("All to All:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw)) + +def all2one(): + current_stream = torch.cuda.current_stream() + for shape in SHAPES: + global_size = round_up(shape, config['world_size'] * 2) + + result_tensor = torch.empty(global_size // 2, dtype=torch.half, device="cuda") + global_tensor = torch.empty(global_size // 2, dtype=torch.half, device="cuda") + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + current_stream.record_event(start_evt) + nccl.groupStart() + for r in range(config['world_size']): + nccl.all2one(global_tensor.storage(), result_tensor.storage(), r, config['comm']) + nccl.groupEnd() + current_stream.record_event(end_evt) + current_stream.synchronize() + + time_usage = start_evt.elapsed_time(end_evt) + bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage + print_rank("All to one:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw)) diff --git a/bmtrain/distributed/__init__.py b/bmtrain/distributed/__init__.py index 84a4adf8..bb6844a8 100644 --- a/bmtrain/distributed/__init__.py +++ b/bmtrain/distributed/__init__.py @@ -1 +1 @@ -from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations, reduce_scatter +from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations, reduce_scatter, all_to_all From d9e026c4b10c453371d167cb79b923291445d1cf Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 28 May 2024 15:02:40 +0800 Subject: [PATCH 03/11] add all2all benchmark --- bmtrain/optim/optim_manager.py | 3 ++- example/benchmark.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index 1a98ed92..dd33c534 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -5,6 +5,7 @@ from ..lr_scheduler.warmup import WarmupLRScheduler from .. 
import nccl from ..global_var import config +import bmtrain as bmt def check_overflow(param_groups): # check overflow @@ -209,7 +210,7 @@ def _justify_scale(self, scale): def state_dict(self, gather_opt=False) -> dict: return { - "optimizers": [opt.state_dict(gather_opt) for opt in self.optimizers], + "optimizers": [opt.state_dict(gather_opt) if isinstance(opt, bmt.optim.AdamOffloadOptimizer) else opt.state_dict() for opt in self.optimizers], "lr_schedulers": [lrs.state_dict() if lrs else None for lrs in self.lr_schedulers], "loss_scale": self.loss_scale, "loss_scale_enabled": self.loss_scale_enabled, diff --git a/example/benchmark.py b/example/benchmark.py index 8a7092d9..bfac5fbd 100644 --- a/example/benchmark.py +++ b/example/benchmark.py @@ -1,12 +1,16 @@ import bmtrain as bmt +from bmtrain import benchmark def main(): bmt.init_distributed() bmt.print_rank("======= All Gather =======") - bmt.benchmark.all_gather() + benchmark.all_gather() bmt.print_rank("===== Reduce Scatter =====") - bmt.benchmark.reduce_scatter() - + benchmark.reduce_scatter() + bmt.print_rank("===== All 2 All =====") + benchmark.all2all() + bmt.print_rank("===== All 2 One =====") + benchmark.all2one() if __name__ == '__main__': - main() \ No newline at end of file + main() From bac71f63dce80553cb445081d8e871c2f15ee26d Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Wed, 29 May 2024 16:01:39 +0800 Subject: [PATCH 04/11] sequence parallel comm init and update topology --- bmtrain/init.py | 71 ++++++++++++++++++++++++++----------- bmtrain/utils.py | 47 ++++++++++++++++++++++++ example/layers/attention.py | 34 +++++++++++++++++- 3 files changed, 131 insertions(+), 21 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index 69273c09..b78105ad 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -3,9 +3,10 @@ import random import torch.distributed as dist import os -from .utils import print_dict +from .utils import print_dict, topology_helper import ctypes from .global_var import config +from collections import OrderedDict from . import nccl from .synchronize import synchronize @@ -16,6 +17,7 @@ def init_distributed( pipe_size: int = -1, num_micro_batches: int = None, tp_size : int = 1, + sp_size : int = 1, ): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. 
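For reference, a sketch of how the new `sp_size` argument is intended to be used (the launch sizes here are illustrative; the example added later in this series uses `tp_size=2, sp_size=2` in `example/train.py`):

    import bmtrain as bmt

    # e.g. 8 GPUs -> tp_size * sp_size * data-parallel = 2 * 2 * 2
    bmt.init_distributed(
        seed=0,
        tp_size=2,   # tensor-parallel group size
        sp_size=2,   # sequence-parallel group size (new in this patch)
    )
    # After initialization, config["sp_comm"] holds the NCCL communicator
    # for the sequence-parallel group created above.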
@@ -66,6 +68,7 @@ def init_distributed( torch.cuda.set_device(local_rank) config["initialized"] = True config["pipe_size"] = pipe_size if pipe_size > 0 else 1 + config["sp_size"] = sp_size if size > 0 else 1 config["pipe_enabled"] = pipe_size > 0 config["local_rank"] = local_rank config["local_size"] = local_size @@ -148,6 +151,13 @@ def init_distributed( unique_id = bytes.fromhex(store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()) config['pp_tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.pp_tp_zero_id) + if config['sp_size'] > 1: + assert world_size // (config['pipe_size'] * config['tp_size']) % config['sp_size'] == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size * sp size" + if topo.sp_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"SP_UNIQUE_ID{topo.sp_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"SP_UNIQUE_ID{topo.sp_idx}").decode()) + config['sp_comm'] = nccl.commInitRank(unique_id, sp_size, topo.sp_id) config ['zero_comm'] = config['comm'] for i in range(world_size): @@ -175,27 +185,44 @@ def __init__(self,config): dp_size = world_size // (pp_size * tp_size) config['tp_zero_size'] = dp_size config['zero_size'] = world_size // pp_size + order = ["tp", "dp", "pp"] + order_dict = OrderedDict() + for o in order: + if o == "tp": + order_dict[o] = tp_size + elif o == "dp": + order_dict[o] = dp_size + elif o == "pp": + order_dict[o] = pp_size + self._topo = topology_helper(order_dict) self.stages = config['pipe_size'] + self.pipe_idx = self._topo.get_group_id(self.rank, "pipe") + self.stage_id = self._topo.get_group_rank(self.rank, "pipe") + self.tp_id = self._topo.get_group_rank(self.rank, "tp") + self.tp_idx = self._topo.get_group_id(self.rank, "tp") + # pp->zero + self.pp_zero_idx = self.stage_id + self.pp_zero_id = self.pipe_idx + # tp->zero + self.tp_zero_idx = self.tp_id + self.tp_zero_id = self.tp_idx + # pp->tp->zero + self.dp_id = self._topo.get_group_rank(self.rank, "dp") + self.dp_idx = self._topo.get_group_id(self.rank, "dp") + self.pp_tp_zero_id = self.dp_id + self.pp_tp_zero_idx = self.dp_idx + # only zero + self.zero_idx = self._topo.get_group_id(self.rank, "dp") + self.zero_id = self._topo.get_group_rank(self.rank, "dp") - stage_size = world_size // pp_size - for i in range(world_size): - self.pipe_idx = self.rank % stage_size - self.stage_id = self.rank // stage_size - self.tp_id = self.rank % tp_size - self.tp_idx = self.rank // tp_size - #pp->zero - self.pp_zero_idx = self.stage_id - self.pp_zero_id = self.pipe_idx - #tp->zero - self.tp_zero_idx = self.tp_id - self.tp_zero_id = self.tp_idx - #pp->tp->zero - self.pp_tp_zero_idx = self.stage_id * tp_size + self.tp_id - self.pp_tp_zero_id = self.pipe_idx // tp_size - #only zero - self.zero_idx = 0 - self.zero_id = self.rank - + # divide sp group based on dp + order_dict = OrderedDict() + order_dict["sp"] = config["sp_size"] + order_dict["dp_sp"] = dp_size // config["sp_size"] + self._topo_sp = topology_helper(order_dict) + offset = self.dp_idx * self.dp_size // self.sp_size + self.sp_idx = offset + self._topo_sp.get_group_id(self.dp_id, "sp") + self.sp_id = self._topo_sp.get_group_rank(self.dp_id, "sp") def get_group_id(self,group_name): if group_name == "pipe": @@ -206,6 +233,8 @@ def get_group_id(self,group_name): return self.tp_zero_idx elif group_name == "tp": return self.tp_idx + elif group_name == "sp": + return self.sp_idx def get_group_rank(self,group_name): if group_name == 
"pipe": @@ -216,6 +245,8 @@ def get_group_rank(self,group_name): return self.tp_zero_id elif group_name == "tp": return self.tp_id + elif group_name == "sp": + return self.sp_id def is_initialized() -> bool: return config["initialized"] diff --git a/bmtrain/utils.py b/bmtrain/utils.py index 8cb87808..ed35b377 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -1,6 +1,8 @@ import torch import sys from typing import Any, Dict, Iterable, Optional +from itertools import product +import math from .global_var import config import os import ctypes @@ -143,3 +145,48 @@ def value(self): if self._steps <= 0: return self._value return self._value / (1 - pow(self.alpha, self._steps)) + + +def get_offset(sizes, pos): + offset = 0 + for idx in range(len(sizes)): + if idx == (len(sizes) - 1): + offset += pos[idx] + else: + offset += math.prod(sizes[idx+1:]) * pos[idx] + return offset + +class topology_helper: + def __init__(self, groups_dict): + self.keys = list(groups_dict.keys())[::-1] + self.values = list(groups_dict.values())[::-1] + self.world_size = math.prod(list(groups_dict.values())) + self.rank_grid, self.group_ids = self._init_rank_grid() + + def _init_rank_grid(self): + grid_shape = tuple(self.values) + rank_grid = [None for _ in range(self.world_size)] + group_ids = [None for _ in range(self.world_size)] + + for rank, group_ranks in enumerate(product(*[range(size) for size in grid_shape])): + rank_grid[rank] = group_ranks + group_ids[rank] = [0 for _ in range(len(self.keys))] + for i,r in enumerate(rank_grid[rank]): + ranks = list(rank_grid[rank]) + sizes = list(self.values) + group_pos = [group_ranks[j] for j in range(len(group_ranks))if j != i ] + group_offset = get_offset([sizes[s] for s in range(len(sizes)) if s != i], group_pos) + group_ids[rank][i] = group_offset + + return rank_grid, group_ids + + def get_group_rank(self, rank, group): + if group not in self.keys: + raise ValueError(f"Group {group} not found ") + return self.rank_grid[rank][self.keys.index(group)] + + def get_group_id(self, rank, group): + if group not in self.keys: + raise ValueError(f"Group {group} not found ") + idx = self.keys.index(group) + return self.group_ids[rank][idx] diff --git a/example/layers/attention.py b/example/layers/attention.py index 0f5155d4..cfcbcbf9 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -9,6 +9,26 @@ import math from bmtrain.global_var import config from bmtrain.distributed import all_gather +def all2all_tensor(tensor, gather_dim, scatter_dim) + # Input shape: (B, S, N, D) | (B, N, S, D) + origin_size = list(tensor.size()) + output_size = origin_size.copy() + output_size[gather_dim] = origin_size[gather_dim] * bmt.config['sp_size'] + output_size[scatter_dim] = origin_size[scatter_dim] / bmt.config['sp_size'] + tensor = tensor.permute(seq_dim, head_dim, tensor.size(0), tensor.size(-1)) + tensor = torch.cat(tensor.chunk(bmt.config['sp_size'], dim=1), dim=0).contiguous() + tensor = bmt.distributed.all_to_all(tensor, bmt.config['sp_comm']) + tensor = tensor.view(*output_size) + return tensor + +def all2all_qkv(q, k, v, seq_dim, head_dim): + q = all2all_tensor(q, seq_dim, head_dim) + k = all2all_tensor(k, seq_dim, head_dim) + v = all2all_tensor(v, seq_dim, head_dim) + return q, k, v + + + class Attention(bmt.DistributedModule): def __init__(self, @@ -62,10 +82,19 @@ def forward(self, seq_q = h_q.size()[1] seq_kv = h_k.size(1) - + h_q = h_q.view(batch_size, seq_q, -1, self.dim_head) h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) h_v = 
h_v.view(batch_size, seq_kv, -1, self.dim_head) + if config['sequence_parallel']: + seq_dim = 1 + head_dim = 2 + all2all_qkv(h_q, h_k, h_v, seq_dim, head_dim) + seq_q = h_q.size()[1] + seq_kv = h_k.size(1) + + + h_q = h_q.permute(0, 2, 1, 3).contiguous() h_k = h_k.permute(0, 2, 1, 3).contiguous() @@ -103,6 +132,9 @@ def forward(self, score, h_v ) h_out = h_out.view(batch_size, -1, seq_q, self.dim_head) + if config['sequence_parallel']: + h_out = all2all_tensor(h_out, 1, 2) + seq_q = h_out.size(2) h_out = h_out.permute(0, 2, 1, 3).contiguous() h_out = h_out.view(batch_size, seq_q, -1) if config['tp_size'] > 1: From eaa371932a093e19cdc455ed64f14e502a251921 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 3 Jun 2024 18:30:22 +0800 Subject: [PATCH 05/11] provide sequence parallel example --- bmtrain/init.py | 110 +++++++++++++++---------------- bmtrain/nn/parallel_embedding.py | 2 +- bmtrain/optim/optim_manager.py | 4 +- bmtrain/utils.py | 39 +---------- example/layers/attention.py | 41 ++++++------ example/layers/embedding.py | 3 +- example/models/gpt.py | 6 ++ example/train.py | 10 +-- 8 files changed, 92 insertions(+), 123 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index b78105ad..62ef60e9 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -3,22 +3,22 @@ import random import torch.distributed as dist import os -from .utils import print_dict, topology_helper -import ctypes +from .utils import print_dict from .global_var import config from collections import OrderedDict from . import nccl from .synchronize import synchronize + def init_distributed( - init_method : str = "env://", - seed : int = 0, - pipe_size: int = -1, - num_micro_batches: int = None, - tp_size : int = 1, - sp_size : int = 1, - ): + init_method: str = "env://", + seed: int = 0, + pipe_size: int = -1, + num_micro_batches: int = None, + tp_size: int = 1, + sp_size: int = 1, +): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. It must be called before any other distributed functions. 
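The sequence-parallel attention in `example/layers/attention.py` relies on an all-to-all that trades a sequence split for a head split. A single-process sketch of the shape bookkeeping it performs (no NCCL involved; `sp` and the tensor sizes are made up for illustration):

    import torch

    sp = 4                              # sequence-parallel world size
    B, S_local, N, D = 1, 128, 8, 32    # each rank holds S_local tokens and all N heads

    h = torch.randn(B, S_local, N, D)   # local projection output, (B, S/sp, N, D)

    # Emulate what all2all_tensor achieves: after the exchange every rank keeps
    # N/sp heads but sees the full sequence of S_local * sp tokens.
    chunks = h.chunk(sp, dim=2)          # split along the head dimension
    full_seq = torch.cat(chunks, dim=1)  # stand-in for data arriving from other ranks
    assert full_seq.shape == (B, S_local * sp, N // sp, D)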
@@ -49,14 +49,14 @@ def init_distributed( local_rank = int(os.environ.get("LOCAL_RANK", "0")) rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_size = int(os.environ.get("LOCAL_WORLD_SIZE","1")) + local_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1")) if "MASTER_ADDR" not in os.environ: - os.environ["MASTER_ADDR"]="localhost" + os.environ["MASTER_ADDR"] = "localhost" if "MASTER_PORT" not in os.environ: - os.environ["MASTER_PORT"]="10010" + os.environ["MASTER_PORT"] = "10010" addr = os.environ["MASTER_ADDR"] port = os.environ["MASTER_PORT"] - master = addr+":"+port + master = addr + ":" + port timeout = datetime.timedelta(seconds=1800) rendezvous_iterator = dist.rendezvous( init_method, rank, world_size, timeout=timeout @@ -68,7 +68,7 @@ def init_distributed( torch.cuda.set_device(local_rank) config["initialized"] = True config["pipe_size"] = pipe_size if pipe_size > 0 else 1 - config["sp_size"] = sp_size if size > 0 else 1 + config["sp_size"] = sp_size if sp_size > 0 else 1 config["pipe_enabled"] = pipe_size > 0 config["local_rank"] = local_rank config["local_size"] = local_size @@ -84,7 +84,9 @@ def init_distributed( config["topology"] = topology(config) config["zero_rank"] = config['topology'].get_group_rank("zero") config["tp_rank"] = config['topology'].get_group_rank("tp") + config["sp_rank"] = config['topology'].get_group_rank("sp") config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") + config["tp_sp_zero_rank"] = config['topology'].get_group_rank("tp_sp_zero") config["save_param_to_cpu"] = True cpus_this_worker = None @@ -149,16 +151,16 @@ def init_distributed( unique_id = nccl.getUniqueId() store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex() ) unique_id = bytes.fromhex(store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()) - config['pp_tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.pp_tp_zero_id) + config['pp_tp_zero_comm'] = nccl.commInitRank(unique_id, world_size // (config['pipe_size'] * config['tp_size']), topo.pp_tp_zero_id) if config['sp_size'] > 1: - assert world_size // (config['pipe_size'] * config['tp_size']) % config['sp_size'] == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size * sp size" + assert world_size % (config['pipe_size'] * config['tp_size'] * config['sp_size']) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size * sp size" if topo.sp_id == 0: unique_id = nccl.getUniqueId() store.set(f"SP_UNIQUE_ID{topo.sp_idx}", unique_id.hex()) unique_id = bytes.fromhex(store.get(f"SP_UNIQUE_ID{topo.sp_idx}").decode()) config['sp_comm'] = nccl.commInitRank(unique_id, sp_size, topo.sp_id) - config ['zero_comm'] = config['comm'] + config['zero_comm'] = config['comm'] for i in range(world_size): if i == rank: @@ -167,64 +169,52 @@ def init_distributed( "local_rank": local_rank, "world_size": world_size, "local_size": local_size, - "master" : master, + "master": master, "device": torch.cuda.current_device(), "cpus": cpus_this_worker }) synchronize() + class topology: - def __init__(self,config): + def __init__(self, config): # pipe_idx is the idx of the pipeline in the group self.rank = config['rank'] pp_size = config["pipe_size"] tp_size = config["tp_size"] world_size = config["world_size"] - assert world_size % (pp_size * tp_size) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size" + sp_size = 
config["sp_size"] + assert world_size % (pp_size * tp_size * sp_size) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size * sequence parallel size" dp_size = world_size // (pp_size * tp_size) config['tp_zero_size'] = dp_size config['zero_size'] = world_size // pp_size - order = ["tp", "dp", "pp"] - order_dict = OrderedDict() - for o in order: - if o == "tp": - order_dict[o] = tp_size - elif o == "dp": - order_dict[o] = dp_size - elif o == "pp": - order_dict[o] = pp_size - self._topo = topology_helper(order_dict) self.stages = config['pipe_size'] - self.pipe_idx = self._topo.get_group_id(self.rank, "pipe") - self.stage_id = self._topo.get_group_rank(self.rank, "pipe") - self.tp_id = self._topo.get_group_rank(self.rank, "tp") - self.tp_idx = self._topo.get_group_id(self.rank, "tp") - # pp->zero + + stage_size = world_size // pp_size + self.pipe_idx = self.rank % stage_size + self.stage_id = self.rank // stage_size + self.tp_id = self.rank % tp_size + self.tp_idx = self.rank // tp_size + self.tp_sp_idx = self.rank // tp_size // sp_size + self.tp_sp_id = self.rank % (tp_size * sp_size) + #pp->zero self.pp_zero_idx = self.stage_id self.pp_zero_id = self.pipe_idx - # tp->zero + #tp->zero self.tp_zero_idx = self.tp_id self.tp_zero_id = self.tp_idx - # pp->tp->zero - self.dp_id = self._topo.get_group_rank(self.rank, "dp") - self.dp_idx = self._topo.get_group_id(self.rank, "dp") - self.pp_tp_zero_id = self.dp_id - self.pp_tp_zero_idx = self.dp_idx - # only zero - self.zero_idx = self._topo.get_group_id(self.rank, "dp") - self.zero_id = self._topo.get_group_rank(self.rank, "dp") - - # divide sp group based on dp - order_dict = OrderedDict() - order_dict["sp"] = config["sp_size"] - order_dict["dp_sp"] = dp_size // config["sp_size"] - self._topo_sp = topology_helper(order_dict) - offset = self.dp_idx * self.dp_size // self.sp_size - self.sp_idx = offset + self._topo_sp.get_group_id(self.dp_id, "sp") - self.sp_id = self._topo_sp.get_group_rank(self.dp_id, "sp") - - def get_group_id(self,group_name): + #tp->sp->zero + self.tp_sp_zero_idx = self.tp_sp_id + self.tp_sp_zero_id = self.tp_sp_idx + #pp->tp->zero + self.pp_tp_zero_idx = self.stage_id * tp_size + self.tp_id + self.pp_tp_zero_id = self.pipe_idx // tp_size + #only zero + self.zero_idx = 0 + self.zero_id = self.rank + + def get_group_id(self, group_name): if group_name == "pipe": return self.pipe_idx elif group_name == "zero": @@ -235,8 +225,10 @@ def get_group_id(self,group_name): return self.tp_idx elif group_name == "sp": return self.sp_idx - - def get_group_rank(self,group_name): + elif group_name == "tp_sp_zero": + return self.tp_sp_zero_idx + + def get_group_rank(self, group_name): if group_name == "pipe": return self.stage_id elif group_name == "zero": @@ -247,7 +239,9 @@ def get_group_rank(self,group_name): return self.tp_id elif group_name == "sp": return self.sp_id + elif group_name == "tp_sp_zero": + return self.tp_sp_zero_id + def is_initialized() -> bool: return config["initialized"] - diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py index 43e7397d..0f0868c0 100644 --- a/bmtrain/nn/parallel_embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -16,7 +16,7 @@ def __init__( embedding_size: int, dtype: torch.dtype = torch.half, init_mean: float = 0.0, - init_std: float = 1, + init_std: float = 0.02, ): super().__init__() diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index dd33c534..7e717a99 100644 --- 
a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -67,7 +67,7 @@ def __init__(self, self.min_loss_scale = min_loss_scale self.max_loss_scale = max_loss_scale if grad_scale is None: - grad_scale = config['zero_size'] + grad_scale = config['zero_size'] // config['tp_size'] self.grad_scale = grad_scale self.optimizers = [] @@ -90,7 +90,7 @@ def add_optimizer( def scale_loss(self, loss : torch.Tensor) -> torch.Tensor: - return loss * ( self.loss_scale / self.grad_scale ) # loss scale + return loss * ( self.loss_scale / self.grad_scale) # loss scale def backward(self, loss : torch.Tensor): """ diff --git a/bmtrain/utils.py b/bmtrain/utils.py index ed35b377..580a79c1 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -1,6 +1,6 @@ import torch import sys -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Optional from itertools import product import math from .global_var import config @@ -154,39 +154,4 @@ def get_offset(sizes, pos): offset += pos[idx] else: offset += math.prod(sizes[idx+1:]) * pos[idx] - return offset - -class topology_helper: - def __init__(self, groups_dict): - self.keys = list(groups_dict.keys())[::-1] - self.values = list(groups_dict.values())[::-1] - self.world_size = math.prod(list(groups_dict.values())) - self.rank_grid, self.group_ids = self._init_rank_grid() - - def _init_rank_grid(self): - grid_shape = tuple(self.values) - rank_grid = [None for _ in range(self.world_size)] - group_ids = [None for _ in range(self.world_size)] - - for rank, group_ranks in enumerate(product(*[range(size) for size in grid_shape])): - rank_grid[rank] = group_ranks - group_ids[rank] = [0 for _ in range(len(self.keys))] - for i,r in enumerate(rank_grid[rank]): - ranks = list(rank_grid[rank]) - sizes = list(self.values) - group_pos = [group_ranks[j] for j in range(len(group_ranks))if j != i ] - group_offset = get_offset([sizes[s] for s in range(len(sizes)) if s != i], group_pos) - group_ids[rank][i] = group_offset - - return rank_grid, group_ids - - def get_group_rank(self, rank, group): - if group not in self.keys: - raise ValueError(f"Group {group} not found ") - return self.rank_grid[rank][self.keys.index(group)] - - def get_group_id(self, rank, group): - if group not in self.keys: - raise ValueError(f"Group {group} not found ") - idx = self.keys.index(group) - return self.group_ids[rank][idx] + return offset \ No newline at end of file diff --git a/example/layers/attention.py b/example/layers/attention.py index cfcbcbf9..ddf12adb 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -8,19 +8,27 @@ ) import math from bmtrain.global_var import config -from bmtrain.distributed import all_gather -def all2all_tensor(tensor, gather_dim, scatter_dim) + +def inverse_permute(permute_dims): + inverse_dims = [0] * len(permute_dims) + for i, dim in enumerate(permute_dims): + inverse_dims[dim] = i + return inverse_dims + +def all2all_tensor(tensor, gather_dim, scatter_dim): # Input shape: (B, S, N, D) | (B, N, S, D) origin_size = list(tensor.size()) output_size = origin_size.copy() output_size[gather_dim] = origin_size[gather_dim] * bmt.config['sp_size'] - output_size[scatter_dim] = origin_size[scatter_dim] / bmt.config['sp_size'] - tensor = tensor.permute(seq_dim, head_dim, tensor.size(0), tensor.size(-1)) + output_size[scatter_dim] = origin_size[scatter_dim] // bmt.config['sp_size'] + inv_order = inverse_permute([gather_dim, scatter_dim, 0, -1]) + tensor = tensor.permute(gather_dim, scatter_dim, 0, -1) tensor = 
torch.cat(tensor.chunk(bmt.config['sp_size'], dim=1), dim=0).contiguous() tensor = bmt.distributed.all_to_all(tensor, bmt.config['sp_comm']) - tensor = tensor.view(*output_size) + tensor = tensor.permute(inv_order).contiguous() return tensor + def all2all_qkv(q, k, v, seq_dim, head_dim): q = all2all_tensor(q, seq_dim, head_dim) k = all2all_tensor(k, seq_dim, head_dim) @@ -28,8 +36,6 @@ def all2all_qkv(q, k, v, seq_dim, head_dim): return q, k, v - - class Attention(bmt.DistributedModule): def __init__(self, dim_model : int, dim_head : int, @@ -49,12 +55,11 @@ def __init__(self, self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) - self.softmax = torch.nn.Softmax(dim=-1) self.num_heads = num_heads self.dim_head = dim_head self.dim_model = dim_model - + def forward(self, hidden_q : torch.Tensor, # (batch_size, seq_q, dim_model) hidden_kv : torch.Tensor, # (batch_size, seq_kv, dim_model) @@ -82,19 +87,16 @@ def forward(self, seq_q = h_q.size()[1] seq_kv = h_k.size(1) - + h_q = h_q.view(batch_size, seq_q, -1, self.dim_head) h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head) - if config['sequence_parallel']: + if config['sp_size'] > 1: seq_dim = 1 head_dim = 2 - all2all_qkv(h_q, h_k, h_v, seq_dim, head_dim) + h_q, h_k, h_v = all2all_qkv(h_q, h_k, h_v, seq_dim, head_dim) seq_q = h_q.size()[1] seq_kv = h_k.size(1) - - - h_q = h_q.permute(0, 2, 1, 3).contiguous() h_k = h_k.permute(0, 2, 1, 3).contiguous() @@ -131,10 +133,10 @@ def forward(self, h_out = torch.bmm( score, h_v ) - h_out = h_out.view(batch_size, -1, seq_q, self.dim_head) - if config['sequence_parallel']: - h_out = all2all_tensor(h_out, 1, 2) - seq_q = h_out.size(2) + h_out = h_out.view(batch_size, -1, seq_q, self.dim_head).contiguous() + if config['sp_size'] > 1: + h_out = all2all_tensor(h_out, 1, 2) + seq_q = h_out.size(2) h_out = h_out.permute(0, 2, 1, 3).contiguous() h_out = h_out.view(batch_size, seq_q, -1) if config['tp_size'] > 1: @@ -143,7 +145,6 @@ def forward(self, attn_out = self.project_out(h_out) return attn_out - diff --git a/example/layers/embedding.py b/example/layers/embedding.py index f62151c4..ba0f576e 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -25,7 +25,8 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona self.norm_type = norm_type self.scale_grad_by_freq = scale_grad_by_freq if _weight is None: - self.weight = bmt.DistributedParameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device="cuda"), init_method=torch.nn.init.normal_) + init_method = lambda t: torch.nn.init.normal_(t, 0, 0.02) + self.weight = bmt.DistributedParameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device="cuda"), init_method=init_method) else: self.weight = bmt.DistributedParameter(_weight) diff --git a/example/models/gpt.py b/example/models/gpt.py index ed604382..ae52cffb 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -53,12 +53,18 @@ def forward(self, if config["tp_size"] > 1: input = input.chunk(config["tp_size"], dim=1)[config["tp_rank"]] pos = pos.chunk(config["tp_size"], dim=1)[config["tp_rank"]] + if config["sp_size"] > 1: + input = input.chunk(config["sp_size"], dim=1)[config["sp_rank"]] + pos = pos.chunk(config["sp_size"], dim=1)[config["sp_rank"]] + out = self.pos_emb(pos) + self.word_emb(input) # for layer in self.transformers: out = self.transformers(out, 
mask_2d, None) out = self.layernorm(out) logits = self.word_emb(out, projection=True) + if bmt.config['sp_size'] > 1: + logits = bmt.distributed.all_gather(logits, comm=bmt.config['sp_comm']).view(logits.shape[0], -1, logits.shape[-1]) bmt.inspect.record_tensor(logits, "logits") return logits diff --git a/example/train.py b/example/train.py index d5906a06..3e5d526a 100644 --- a/example/train.py +++ b/example/train.py @@ -10,6 +10,7 @@ def main(): bmt.init_distributed( seed=0, tp_size=2, + sp_size=2, ) model = GPT( @@ -36,8 +37,8 @@ def main(): batch_size = 2 seq_len = 512 - world_size = bmt.config["world_size"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_size"] - r = bmt.config["rank"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_rank"] + world_size = bmt.world_size() // config['sp_size'] // config['tp_size'] + r = bmt.config["tp_sp_zero_rank"] for i in range(world_size): sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) @@ -93,7 +94,7 @@ def main(): optim_manager.zero_grad() optim_manager.backward(loss) - + grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0) # print inspected tensors in the forward & backward pass # print parameters of the model if iteration % 100 == 0: @@ -118,10 +119,11 @@ def main(): # print time and loss bmt.print_rank( - "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( + "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | grad_norm: {:.2f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( iteration, global_loss, avg_loss_recorder.value, + grad_norm, lr_scheduler.current_lr, optim_manager.loss_scale, avg_time_recorder.value From d4d42f3ca3338c4cac191825f97457a4f633c56f Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 3 Jun 2024 18:43:34 +0800 Subject: [PATCH 06/11] empty commit --- example/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/example/README.md b/example/README.md index 395b5e64..292d32c2 100644 --- a/example/README.md +++ b/example/README.md @@ -2,4 +2,5 @@ This is an example of BMTrain's implementation of GPT-2. -For more model implementations, please refer to [Model Center](https://github.com/OpenBMB/ModelCenter). \ No newline at end of file +For more model implementations, please refer to [Model Center](https://github.com/OpenBMB/ModelCenter). 
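Before the burst-attention changes in the next patch, a recap of the data flow in the sequence-parallel example from PATCH 05: each rank trains on one contiguous sequence chunk and the logits are all-gathered at the end of the forward pass. A single-process sketch (the collective is emulated with plain concatenation; the sizes mirror `example/train.py` but are otherwise illustrative):

    import torch

    sp_size, sp_rank = 2, 0                 # illustrative rank/layout
    batch, seq, vocab = 2, 512, 10240

    input_ids = torch.randint(0, vocab, (batch, seq))

    # example/models/gpt.py keeps only this rank's sequence chunk.
    local_ids = input_ids.chunk(sp_size, dim=1)[sp_rank]     # (batch, seq // sp_size)

    # ... the model forward on local_ids yields local logits ...
    local_logits = torch.randn(batch, seq // sp_size, vocab)

    # Emulated bmt.distributed.all_gather over sp_comm: every rank recovers
    # full-length logits before the loss is computed.
    gathered = torch.cat([local_logits] * sp_size, dim=1)
    assert gathered.shape == (batch, seq, vocab)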
+ From 43d119095a44416eb692ef0a82d7413cf88a3921 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 6 Jun 2024 19:01:23 +0800 Subject: [PATCH 07/11] update ops and example for burst_attn --- bmtrain/distributed/ops.py | 24 ++ bmtrain/init.py | 7 +- bmtrain/loss/cross_entropy.py | 2 +- bmtrain/nn/__init__.py | 1 + bmtrain/nn/burst_attn.py | 145 +++++++ bmtrain/nn/burst_lao.py | 719 ++++++++++++++++++++++++++++++++++ bmtrain/nn/burst_utils.py | 284 ++++++++++++++ bmtrain/optim/adam_offload.py | 4 +- example/layers/attention.py | 52 +-- example/models/gpt.py | 2 +- pyproject.toml | 2 + tests/test_all2all.py | 106 +++++ tests/test_burst.py | 88 +++++ tests/test_loss_func.py | 3 +- 14 files changed, 1406 insertions(+), 33 deletions(-) create mode 100644 bmtrain/nn/burst_attn.py create mode 100644 bmtrain/nn/burst_lao.py create mode 100644 bmtrain/nn/burst_utils.py create mode 100644 tests/test_all2all.py create mode 100644 tests/test_burst.py diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index a52529f8..6dc42846 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -264,3 +264,27 @@ def all_to_all(x : torch.Tensor, comm = None): assert x.is_cuda return OpAllToAll.apply(x, comm) + +def inverse_permute(permute_dims): + inverse_dims = [0] * len(permute_dims) + for i, dim in enumerate(permute_dims): + inverse_dims[dim] = i + return inverse_dims + +def all2all_transpose(tensor : torch.Tensor, gather_dim : int, scatter_dim : int, comm = None): + # Input shape: (B, S, N, D) | (B, N, S, D) + origin_size = list(tensor.size()) + output_size = origin_size.copy() + count = commCount(comm) + output_size[gather_dim] = origin_size[gather_dim] * count + output_size[scatter_dim] = origin_size[scatter_dim] // count + inv_order = inverse_permute([gather_dim, scatter_dim, 0, -1]) + tensor = tensor.permute(gather_dim, scatter_dim, 0, -1) + tensor = torch.cat(tensor.chunk(count, dim=1), dim=0).contiguous() + tensor = all_to_all(tensor, count) + tensor = tensor.permute(inv_order).contiguous() + return tensor + + + + diff --git a/bmtrain/init.py b/bmtrain/init.py index 62ef60e9..a6d1641e 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -76,6 +76,7 @@ def init_distributed( config["world_size"] = world_size config["calc_stream"] = torch.cuda.current_stream() config["load_stream"] = torch.cuda.Stream(priority=-1) + config["sp_stream"] = torch.cuda.Stream(priority=-1) config["tp_comm_stream"] = torch.cuda.Stream(priority=-1) config["pp_comm_stream"] = torch.cuda.Stream(priority=-1) config['barrier_stream'] = torch.cuda.Stream() @@ -194,10 +195,12 @@ def __init__(self, config): stage_size = world_size // pp_size self.pipe_idx = self.rank % stage_size self.stage_id = self.rank // stage_size - self.tp_id = self.rank % tp_size - self.tp_idx = self.rank // tp_size self.tp_sp_idx = self.rank // tp_size // sp_size self.tp_sp_id = self.rank % (tp_size * sp_size) + self.tp_id = self.tp_sp_id % tp_size + self.tp_idx = self.tp_sp_idx * sp_size + self.tp_sp_id // tp_size + self.sp_id = self.tp_sp_id // tp_size + self.sp_idx = self.tp_sp_idx * tp_size + self.tp_sp_id % tp_size #pp->zero self.pp_zero_idx = self.stage_id self.pp_zero_id = self.pipe_idx diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 5be07665..d51bda09 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -244,7 +244,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.weight is not None: if 
self.weight.dim() != 1 or self.weight.size(0) != input.size(1): - raise ValueError("weight should be a 1D tensor of size C"); + raise ValueError("weight should be a 1D tensor of size C") w = self.weight[torch.where(target==self.ignore_index, 0, target)].float() w[target==self.ignore_index] = 0 else: diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index 60fed663..26a8f667 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -3,3 +3,4 @@ from .row_parallel_linear import RowParallelLinear from .parallel_embedding import VPEmbedding from .parallel_linear_func import OpParallelLinear +from .burst_attn import OpBurstAttn diff --git a/bmtrain/nn/burst_attn.py b/bmtrain/nn/burst_attn.py new file mode 100644 index 00000000..1193d126 --- /dev/null +++ b/bmtrain/nn/burst_attn.py @@ -0,0 +1,145 @@ +import bmtrain as bmt +import torch +import math +from .burst_utils import ( + inter_normal_attn, + inter_normal_attn_backward, + inter_flash_attn_triton, + inter_flash_attn_backward_triton, + inter_flash_cuda_fwd, + inter_flash_cuda_bwd, +) +from .burst_utils import triton_scale_out, record_stream, Ring + + +class OpBurstAttn(torch.autograd.Function): + """ + for Normal Attention: + q, k, v: [B, N, S, H] (batch_size, num_heads, sub_seqlen, head_dim) + for Flash: + q, k, v: [B, S, N, H] (batch_size, num_heads, sub_seqlen, head_dim) + + """ + + @staticmethod + def forward( + ctx, q, k, v, softmax_scale=None, flash=None, optimize_bwd_comm=False + ): + m_i = None + acc_o = None + lse_i = None + ctx.optimize_bwd_comm = optimize_bwd_comm or flash is None + if softmax_scale is None: + ctx.softmax_scale = 1 / math.sqrt(q.shape[-1]) + else: + ctx.softmax_scale = softmax_scale + ctx.flash = None if flash not in ["cuda", "triton"] else flash + if ctx.flash: + forward_func = ( + inter_flash_attn_triton + if ctx.flash == "triton" + else inter_flash_cuda_fwd + ) + else: + forward_func = inter_normal_attn + sp_count = bmt.config["sp_size"] + burst_comm = Ring( + bmt.config["sp_comm"], bmt.config["sp_rank"] + ) + ctx.burst_comm = burst_comm + + for r in range(1, sp_count + 1): + bufs = burst_comm.ring_send_recv(k, v) + burst_comm.commit() + if ctx.flash: + if ctx.flash == "triton": + acc_o, m_i, lse_i = forward_func( + q, k, v, m_i, lse_i, acc_o, ctx.softmax_scale, None + ) + else: + acc_o, lse_i = forward_func( + q, k, v, acc_o, lse_i, ctx.softmax_scale + ) + else: + acc_o, m_i, lse_i = forward_func( + q, k, v, m_i, lse_i, acc_o, ctx.softmax_scale, None + ) + k, v = record_stream(*bufs) + burst_comm.wait() + + if ctx.flash == "triton": + acc_o = triton_scale_out(acc_o, m_i, lse_i) + elif not ctx.flash: + o_scale = torch.exp(m_i - lse_i) + acc_o = acc_o * o_scale + acc_o = acc_o.to(dtype=q.dtype) + if flash is not None: + lse_i = lse_i.squeeze(dim=-1).transpose(1, 2).contiguous() + ctx.save_for_backward(q, k, v, lse_i.contiguous(), acc_o) + return acc_o + + @staticmethod + def backward(ctx, grad_output): + q, k, v, lse_i, o_i = ctx.saved_tensors + d_q = torch.zeros_like(q) + d_k = torch.zeros_like(k) + d_v = torch.zeros_like(v) + if not ctx.optimize_bwd_comm: + delta = o_i.contiguous() + else: + delta = (o_i * grad_output) + if ctx.flash: + delta = delta.to(torch.float32) + delta = delta.sum(-1, keepdim=not ctx.flash) + if ctx.flash: + delta = delta.transpose(1, 2).contiguous() + + if ctx.flash: + backward_func = ( + inter_flash_attn_backward_triton + if ctx.flash == "triton" + else inter_flash_cuda_bwd + ) + else: + backward_func = inter_normal_attn_backward + + burst_comm = ctx.burst_comm + 
#i = bmt.config["sp_rank"] + sp_count = bmt.config["sp_size"] + dq = torch.zeros_like(d_q) + for r in range(1, sp_count + 1): + #j = (i + sp_count - r) % sp_count + + if r != sp_count: + bufs = burst_comm.ring_send_recv(delta, grad_output, q, lse_i) + if r != 1: + dq_buf = burst_comm.ring_send_recv(d_q) + burst_comm.commit() + backward_func( + grad_output, + q, + k, + v, + delta, + lse_i, + dq, + d_k, + d_v, + ctx.softmax_scale, + None, + ) + burst_comm.wait() + if r != sp_count: + delta, grad_output, q, lse_i = record_stream(*bufs) + torch.cuda.current_stream().wait_stream(bmt.config["sp_stream"]) + if r != 1: + (d_q,) = record_stream(*dq_buf) + d_q += dq + else: + d_q = dq.clone().detach() + + (d_q,) = burst_comm.ring_send_recv(d_q) + burst_comm.commit() + burst_comm.wait() + + return d_q, d_k, d_v, None, None, None, None diff --git a/bmtrain/nn/burst_lao.py b/bmtrain/nn/burst_lao.py new file mode 100644 index 00000000..63fd25e9 --- /dev/null +++ b/bmtrain/nn/burst_lao.py @@ -0,0 +1,719 @@ +""" +*Experimental* implementation of FlashAttention in Triton. +Tested with triton==2.0.0.dev20221202. +Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions +other than 64: +https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 +We'll update this implementation with the new Triton backend once this is fixed. + +We use the FlashAttention implementation from Phil Tillet a starting point. +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +Changes: +- Implement both causal and non-causal attention. +- Implement both self-attention and cross-attention. +- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. +- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. +- Support attention bias. +- Speed up the forward pass a bit, and only store the LSE instead of m and l. +- Make the backward for d=128 much faster by reducing register spilling. +- Optionally parallelize the backward pass across seqlen_k, to deal with the case of +small batch size * nheads. + +Caution: +- This is an *experimental* implementation. The forward pass should be quite robust but +I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). +- This implementation has only been tested on A100. +- If you plan to use headdim other than 64 and 128, you should test for race conditions +(due to the Triton compiler), as done in tests/test_flash_attn.py +"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions +for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident +that there are none left for other head dimensions. + +Differences between this Triton version and the CUDA version: +- Triton version doesn't support dropout. +- Triton forward is generally faster than CUDA forward, while Triton backward is +generally slower than CUDA backward. Overall Triton forward + backward is slightly slower +than CUDA forward + backward. +- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). +- Triton version supports attention bias, while CUDA version doesn't. 
+""" + +import math + +import torch + +import triton +import triton.language as tl + + +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, K, V, Bias, Out, M_in, Lse_in, O_in, + Lse, M_out, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + softmax_scale, + stride_qb, stride_qh, stride_qm, + stride_kb, stride_kh, stride_kn, + stride_vb, stride_vh, stride_vn, + stride_bb, stride_bh, stride_bm, + stride_ob, stride_oh, stride_om, + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, + CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # off_b = tl.program_id(1) + # off_h = tl.program_id(2) + # off_hb = off_b * nheads + off_h + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V + # Adding parenthesis around indexing might use int32 math instead of int64 math? + # https://github.com/openai/triton/issues/741 + # I'm seeing a tiny bit of difference (5-7us) + q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + if BIAS_TYPE == 'vector': + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n + elif BIAS_TYPE == 'matrix': + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) + # initialize pointer to m and l + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m + lin_ptrs = Lse_in + off_hb * seqlen_q_rounded + offs_m + acc_o_ptrs = O_in + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + lse_i = tl.load(lin_ptrs) + # m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_ptrs = M_in + off_hb * seqlen_q_rounded + offs_m + m_i = tl.load(m_ptrs) + acc_o = tl.load(acc_o_ptrs) + # acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call + # tl.load(q_ptrs), we get the wrong output! 
+ if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0) + # loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0) + else: + k = tl.load(k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if BIAS_TYPE != 'none': + if BIAS_TYPE == 'vector': + if EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == 'matrix': + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load(b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) + & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0).to(tl.float32) + # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler + # can then fuse the mult and add into an fma instruction. But if we have bias we need to + # to multiply with softmax_scale here. 
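+        # The block below is the streaming-softmax update: take the new row
+        # maximum m_ij (bounded below by the running log-sum-exp), rescale the
+        # output accumulator by exp(m_i - m_ij), add the new p @ v tile, then
+        # fold this tile's row sums into the running lse_i.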
+ qk = qk * softmax_scale + bias + #m_ij = tl.maximum(tl.maximum(tl.max(qk, 1) , m_i), -1e16) + #m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, m_i) + m_ij = tl.maximum(tl.max(qk, 1), lse_i) + p = tl.exp(qk - m_ij[:, None]) + else: + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + # m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, m_i) + # m_ij = tl.maximum(tl.maximum(tl.max(qk, 1) * softmax_scale, m_i), -1e16) + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # # -- update output accumulator -- + # BUG: have to store and immediately load + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + # update acc_o + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0) + else: + v = tl.load(v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # -- update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + #o_scale = tl.exp(m_i - lse_i) + # BUG: have to store and immediately load + #tl.store(t_ptrs, o_scale) + #o_scale = tl.load(t_ptrs) + #acc_o = acc_o * o_scale[:, None] + #acc_o = acc_o * o_scale[:, None] + # rematerialize offsets to save registers + + + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + m_ptrs = M_out + off_hb * seqlen_q_rounded + offs_m + tl.store(m_ptrs, m_i) + tl.store(lse_ptrs, lse_i) + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store(out_ptrs, acc_o, + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, DO, Delta, + stride_ob, stride_oh, stride_om, + stride_dob, stride_doh, stride_dom, + nheads, seqlen_q, seqlen_q_rounded, headdim, + BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) + do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_store_dk_dv( + dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + 
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, +): + # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.store(dv_ptrs), there's a race condition + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD: tl.constexpr, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) + if BIAS_TYPE == 'vector': + b_ptrs = Bias + offs_n + elif BIAS_TYPE == 'matrix': + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # There seems to be some problem with Triton pipelining that makes results wrong for + # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop + # may have zero step, and pipelining with the bias matrix could screw it up. + # So we just exit early. + if begin_m >= seqlen_q: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) + return + # k and v stay in SRAM throughout + # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.load(k_ptrs), we get the wrong output! 
+ if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + else: + k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + else: + k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + # loop over rows + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_d[None, :] < headdim), other=0.0) + # recompute p = softmax(qk, dim=-1).T + qk = tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) + if IS_CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + if BIAS_TYPE != 'none': + tl.debug_barrier() # Race condition otherwise + if BIAS_TYPE == 'vector': + if EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == 'matrix': + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_n[None, :] < seqlen_k), + other=0.0).to(tl.float32) + qk = qk * softmax_scale + bias + # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. + # Also wrong for headdim=64. + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + lse_i = tl.load(LSE + offs_m_curr) + if BIAS_TYPE == 'none': + p = tl.exp(qk * softmax_scale - lse_i[:, None]) + else: + p = tl.exp(qk - lse_i[:, None]) + # compute dv + # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs + # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, + # the output is correct. + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. + do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_d[None, :] < headdim), other=0.0) + # if EVEN_M: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs) + # else: + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + # else: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + # else: + # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + # & (offs_d[None, :] < headdim), other=0.0) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. 
+ # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True + # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + dp = tl.dot(do, v, trans_b=True) + # There's a race condition for headdim=48 + if not EVEN_HEADDIM: + tl.debug_barrier() + # compute ds = p * (dp - delta[:, None]) + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + # compute dk = dot(ds.T, q) + dk += tl.dot(ds, q, trans_a=True) + # compute dq + if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix' + tl.debug_barrier() + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, + eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last") + else: + dq = tl.load(dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last") + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add(dq_ptrs, dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) + # increment pointers + dq_ptrs += BLOCK_M * stride_dqm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + if BIAS_TYPE == 'matrix': + b_ptrs += BLOCK_M * stride_bm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now + # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, 
pre_hook=init_to_zero('DQ')), + ], + key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qb, stride_qh, stride_qm, + stride_kb, stride_kh, stride_kn, + stride_vb, stride_vh, stride_vn, + stride_bb, stride_bh, stride_bm, + stride_dob, stride_doh, stride_dom, + stride_dqb, stride_dqh, stride_dqm, + stride_dkb, stride_dkh, stride_dkn, + stride_dvb, stride_dvh, stride_dvn, + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, + CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + if BIAS_TYPE != 'none': + Bias += off_b * stride_bb + off_h * stride_bh + # pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD=False, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD=True, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + + +def _flash_attn_forward(q, k, v, prev_m, prev_lse, prev_o, bias=None, causal=False, softmax_scale=None): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert prev_m.shape == (batch, nheads, seqlen_k) + assert d <= 128, 'FlashAttention only support head dimensions up to 128' + assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type' + assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16' + assert q.is_cuda and k.is_cuda and v.is_cuda + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + + has_bias = bias is not None + bias_type = 'none' + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert 
bias.is_cuda + assert bias.dim() == 4 + bias = bias.transpose(0,1).contiguous() + if bias.stride(-1) != 1: + bias = bias.contiguous() + if bias.shape[2:] == (1, seqlen_k): + bias_type = 'vector' + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = 'matrix' + else: + raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' + ' or (seqlen_q, seqlen_k)') + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + m = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + BLOCK = 128 + num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, k, v, bias, o, prev_m, prev_lse, prev_o, + lse, m, tmp, + softmax_scale, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + *bias_strides, + o.stride(0), o.stride(2), o.stride(1), + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, + seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, causal, BLOCK_HEADDIM, + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return o, lse, m, softmax_scale, # softmax_scale could have been updated + + +def _flash_attn_backward(do, q, k, v, delta, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None): + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + assert do.shape == (batch, seqlen_q, nheads, d) , f'do shape is {do.shape} and q shape is {q.shape}' + assert k.shape == (batch, seqlen_q, nheads, d), f'k shape is {k.shape} and q shape is {q.shape}' + assert v.shape == (batch, seqlen_q, nheads, d), f'v shape is {v.shape} and q shape is {q.shape}' + _, seqlen_k, _, _ = k.shape + # assert d in {16, 32, 64, 128} + assert d <= 128 + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded), f"lse shape is {lse.shape}" + assert delta.shape == (batch, nheads, seqlen_q_rounded), f"delta shape is {delta.shape}" + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == 1 + assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + dq_accum = torch.empty_like(q, dtype=torch.float32) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + + has_bias = bias is not None + bias_type = 'none' + if has_bias: + bias = bias.transpose(0,1).contiguous() + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.stride(-1) == 1 + if bias.shape[2:] == (1, seqlen_k): + bias_type = 'vector' + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = 'matrix' + else: + raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' + ' or (seqlen_q, seqlen_k)') + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + 
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads) + _bwd_kernel[grid]( + q, k, v, bias, + do, dq_accum, dk, dv, + lse, delta, + softmax_scale, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + *bias_strides, + do.stride(0), do.stride(2), do.stride(1), + dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), + dk.stride(0), dk.stride(2), dk.stride(1), + dv.stride(0), dv.stride(2), dv.stride(1), + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, + seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, causal, BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + + diff --git a/bmtrain/nn/burst_utils.py b/bmtrain/nn/burst_utils.py new file mode 100644 index 00000000..2217b190 --- /dev/null +++ b/bmtrain/nn/burst_utils.py @@ -0,0 +1,284 @@ +import bmtrain as bmt +import torch +from flash_attn.flash_attn_interface import ( + _flash_attn_forward as _flash_attn_forward_cuda, +) +from flash_attn.flash_attn_interface import ( + _flash_attn_backward as _flash_attn_backward_cuda, +) +import inspect + +class ops_wrapper: + def __init__(self, op, tensor, *args, **kwargs): + self.op = op + self.tensor = tensor + self.args = args + self.kwargs = kwargs + + def __call__(self): + return self.op(self.tensor.storage(), *self.args, **self.kwargs) + +class Ring: + def __init__(self, comm, rank): + self.comm = comm + self.rank = rank + self.reqs = [] + self.ops = [] + + def ring_send_recv(self, *tensor_list): + comm = self.comm + rank = self.rank + count = bmt.nccl.commCount(comm) + next_rank = (rank + 1) % count + prev_rank = (rank - 1 + count) % count + output = [] + i = 0 + for tensor in tensor_list: + i += 1 + res = torch.zeros_like(tensor) + send_op = ops_wrapper(bmt.nccl.send, tensor, next_rank, comm) + recv_op = ops_wrapper(bmt.nccl.recv, res, prev_rank, comm) + self.ops.append(send_op) + self.ops.append(recv_op) + output.append(res) + return output + + def commit(self): + torch.cuda.synchronize() + with torch.cuda.stream(bmt.config["sp_stream"]): + for op in self.ops: + op.tensor.record_stream(bmt.config["sp_stream"]) + bmt.nccl.groupStart() + for op in self.ops: + op() + bmt.nccl.groupEnd() + reqs = None + self.reqs = reqs + + def wait(self): + torch.cuda.current_stream().wait_stream(bmt.config["sp_stream"]) + self.reqs = [] + self.ops = [] + +@torch.jit.script +def triton_scale_out(acc_o, m_i, lse_i): + o_scale = torch.exp(m_i - lse_i) + o_scale = o_scale.unsqueeze(-1).transpose(1, 2) + acc_o = acc_o * o_scale + return acc_o + + +@torch.jit.script +def cuda_scale_out_lse_helper( + o, + lse, + o_i, + lse_i, +): + o_i = o_i.to(torch.float32) + lse_i = lse_i.transpose(-2, -1).unsqueeze(dim=-1).contiguous() + new_lse = lse + torch.log(1 + torch.exp(lse_i - lse)) + o = torch.exp(lse - new_lse) * o + torch.exp(lse_i - new_lse) * o_i + + lse = new_lse + return o, lse + + +def record_stream(*tensorlist): + for t in tensorlist: + t.record_stream(torch.cuda.current_stream()) + return tensorlist + + +def inter_normal_attn(q, k, v, m_i, lse_i, acc_o, softmax_scale=1.0, mask_bias=None): + m_i = m_i.to(q.dtype) if m_i 
is not None else None + qk = q @ k.transpose(-2, -1) * softmax_scale + if mask_bias is not None: + qk = torch.masked_fill( + qk, + not mask_bias, + torch.scalar_tensor(float("-10000"), device=qk.device, dtype=qk.dtype), + ) + + m_ij = torch.max(qk, dim=-1, keepdim=True)[0] + if m_i is not None: + m_ij = torch.maximum(m_ij, m_i) + p = torch.exp(qk - m_ij) + if mask_bias is not None: + p = torch.masked_fill( + p, + not mask_bias, + torch.scalar_tensor(float("0"), device=qk.device, dtype=qk.dtype), + ) + l_ij = torch.sum(p, dim=-1, keepdim=True) + if acc_o is not None: + acc_o_scale = torch.exp(m_i - m_ij) + pv = (p @ v).to(dtype=torch.float32) + acc_o = pv + acc_o_scale * acc_o + else: + acc_o = (p @ v).to(dtype=torch.float32) + + if lse_i is None: + lse_i = torch.log(l_ij + 1e-5) + m_ij + else: + lse_i = torch.log(torch.exp(lse_i - m_ij) + l_ij + 1e-5) + m_ij + return acc_o, m_ij, lse_i + + +def inter_normal_attn_backward( + do, q, k, v, delta, lse, d_q, d_k, d_v, softmax_scale, mask_bias +): + # ensure q,k,v with shape [b,n,s,d] + qk = q @ k.transpose(-2, -1) * softmax_scale + if mask_bias is not None: + qk = torch.masked_fill( + qk, + not mask_bias, + torch.scalar_tensor(float("-10000"), device=qk.device, dtype=qk.dtype), + ) + p = torch.exp(qk - lse) + if mask_bias is not None: + p = torch.masked_fill( + p, + not mask_bias, + torch.scalar_tensor(float("0"), device=qk.device, dtype=qk.dtype), + ) + d_v += p.transpose(-2, -1) @ do + d_p = do @ v.transpose(-2, -1) + softmax_scale = softmax_scale + d_s = p * (d_p - delta) * softmax_scale + d_q += d_s @ k + d_k += d_s.transpose(-2, -1) @ q + + +def inter_flash_attn_triton( + q, k, v, m_i, lse_i, acc_o, softmax_scale=1.0, mask_bias=None +): + from .burst_lao import _flash_attn_forward + b, s, n, d = q.shape + if m_i is None: + m_i = ( + -torch.ones((b, n, s), dtype=torch.float32, device="cuda") * torch.inf + ).contiguous() + if lse_i is None: + lse_i = ( + -torch.ones((b, n, s), dtype=torch.float32, device="cuda") * torch.inf + ).contiguous() + if acc_o is None: + acc_o = torch.zeros((b, s, n, d), dtype=torch.float32, device="cuda") + acc_o, lse_i, m_ij, softamx_scale = _flash_attn_forward( + q, + k, + v, + m_i, + lse_i, + acc_o.to(dtype=torch.float32), + causal=False, + bias=mask_bias, + softmax_scale=softmax_scale, + ) + return acc_o, m_ij, lse_i + + +def inter_flash_attn_backward_triton( + do, q, k, v, delta, lse, dq, dk, dv, softmax_scale, mask_bias +): + from .burst_lao import _flash_attn_backward + # dq_ = torch.empty_like(q) + dk_ = torch.empty_like(q) + dv_ = torch.empty_like(q) + _flash_attn_backward( + do, + q, + k, + v, + delta, + lse, + dq, + dk_, + dv_, + softmax_scale=softmax_scale, + bias=mask_bias, + ) + # dq += dq_ + dk += dk_ + dv += dv_ + + +def inter_flash_cuda_fwd(q, k, v, o, lse, softmax_scale=1.0): + o_i, _, _, _, _, lse_i, _, _ = _flash_attn_forward_cuda( + q, + k, + v, + 0.0, + softmax_scale, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + return_softmax=False, + ) + if o is None: + o = o_i.to(torch.float32) + lse = lse_i.transpose(-2, -1).unsqueeze(dim=-1).contiguous() + else: + o, lse = cuda_scale_out_lse_helper(o, lse, o_i, lse_i) + return o, lse + + +def inter_flash_cuda_bwd(do, q, k, v, o, lse, dq, dk, dv, softmax_scale, mask_bias): + dk_ = torch.empty_like(q) + dv_ = torch.empty_like(q) + if len(o.shape) == 3: + # use sum(o_i * gradoutput) as delta and pass a empty out to flash backward + # this feature requires a build of this PR: https://github.com/Dao-AILab/flash-attention/pull/905 + delta = o + 
o = q + elif len(o.shape) == 4: + delta = None + if delta is not None: + assert inspect.signature(_flash_attn_backward_cuda).parameters.get( + "softmax_d" + ), "optimize_bwd_comm is not supported for this version of flash-attention, \ + you have to compile flash-attention with this PR: \ + https://github.com/Dao-AILab/flash-attention/pull/905" + _flash_attn_backward_cuda( + do, + q, + k, + v, + o, + lse, + dq, + dk_, + dv_, + 0.0, + softmax_scale, + False, + (-1, -1), + None, + False, + None, + delta, + ) + else: + _flash_attn_backward_cuda( + do, + q, + k, + v, + o, + lse, + dq, + dk_, + dv_, + 0.0, + softmax_scale, + False, + (-1, -1), + None, + False, + None, + ) + # dq += dq_ + dk += dk_ + dv += dv_ diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py index c088a5ee..2a41c8fe 100644 --- a/bmtrain/optim/adam_offload.py +++ b/bmtrain/optim/adam_offload.py @@ -155,8 +155,8 @@ def step(self, closure=None, scale=1): ) total_numel += state["_param_fp16"].numel() if self.record_delta: - sum_delta += param._delta_info[2].item(); - sum_sq_delta += param._delta_info[3].item(); + sum_delta += param._delta_info[2].item() + sum_sq_delta += param._delta_info[3].item() # transfer parameters back to device asynchronously param.copy_(state["_param_fp16"], non_blocking=True) if self.record_delta: diff --git a/example/layers/attention.py b/example/layers/attention.py index ddf12adb..0ae31e64 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Literal import torch import bmtrain as bmt from bmtrain.nn import ( @@ -8,6 +8,8 @@ ) import math from bmtrain.global_var import config +from bmtrain.distributed import all2all_transpose +from bmtrain.nn import OpBurstAttn def inverse_permute(permute_dims): inverse_dims = [0] * len(permute_dims) @@ -15,24 +17,10 @@ def inverse_permute(permute_dims): inverse_dims[dim] = i return inverse_dims -def all2all_tensor(tensor, gather_dim, scatter_dim): - # Input shape: (B, S, N, D) | (B, N, S, D) - origin_size = list(tensor.size()) - output_size = origin_size.copy() - output_size[gather_dim] = origin_size[gather_dim] * bmt.config['sp_size'] - output_size[scatter_dim] = origin_size[scatter_dim] // bmt.config['sp_size'] - inv_order = inverse_permute([gather_dim, scatter_dim, 0, -1]) - tensor = tensor.permute(gather_dim, scatter_dim, 0, -1) - tensor = torch.cat(tensor.chunk(bmt.config['sp_size'], dim=1), dim=0).contiguous() - tensor = bmt.distributed.all_to_all(tensor, bmt.config['sp_comm']) - tensor = tensor.permute(inv_order).contiguous() - return tensor - - def all2all_qkv(q, k, v, seq_dim, head_dim): - q = all2all_tensor(q, seq_dim, head_dim) - k = all2all_tensor(k, seq_dim, head_dim) - v = all2all_tensor(v, seq_dim, head_dim) + q = all2all_transpose(q, seq_dim, head_dim) + k = all2all_transpose(k, seq_dim, head_dim) + v = all2all_transpose(v, seq_dim, head_dim) return q, k, v @@ -40,7 +28,9 @@ class Attention(bmt.DistributedModule): def __init__(self, dim_model : int, dim_head : int, num_heads : int, bias : bool = True, - dtype = None + dtype : Optional[torch.dtype] = None, + sp_method : Optional[Literal["all2all", "burst"]] = None, + ) -> None: super().__init__() @@ -55,6 +45,11 @@ def __init__(self, self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + + self.sp_method = sp_method + if bmt.config['sp_size'] > 1: + assert sp_method is not None 
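+        # sp_method selects how sequence parallelism is realized:
+        #   "all2all" - reshuffle heads across ranks with all2all_transpose and run
+        #               ordinary attention on the full sequence for a head subset;
+        #   "burst"   - keep the sequence sharded and call OpBurstAttn, which rings
+        #               K/V blocks between ranks instead of moving the heads.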
+ self.softmax = torch.nn.Softmax(dim=-1) self.num_heads = num_heads self.dim_head = dim_head @@ -92,11 +87,18 @@ def forward(self, h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head) if config['sp_size'] > 1: - seq_dim = 1 - head_dim = 2 - h_q, h_k, h_v = all2all_qkv(h_q, h_k, h_v, seq_dim, head_dim) - seq_q = h_q.size()[1] - seq_kv = h_k.size(1) + if self.sp_method == "all2all": + seq_dim = 1 + head_dim = 2 + h_q, h_k, h_v = all2all_qkv(h_q, h_k, h_v, seq_dim, head_dim) + seq_q = h_q.size()[1] + seq_kv = h_k.size(1) + elif self.sp_method == "burst": + o = OpBurstAttn(h_q, h_k, h_v, math.sqrt(self.dim_head), "none", optimize_bwd_comm=False) + return o + else: + raise ValueError("Invalid sp_method for sequence parallel, should be 'all2all' or 'burst'") + h_q = h_q.permute(0, 2, 1, 3).contiguous() h_k = h_k.permute(0, 2, 1, 3).contiguous() @@ -135,7 +137,7 @@ def forward(self, ) h_out = h_out.view(batch_size, -1, seq_q, self.dim_head).contiguous() if config['sp_size'] > 1: - h_out = all2all_tensor(h_out, 1, 2) + h_out = all2all_transpose(h_out, 1, 2) seq_q = h_out.size(2) h_out = h_out.permute(0, 2, 1, 3).contiguous() h_out = h_out.view(batch_size, seq_q, -1) diff --git a/example/models/gpt.py b/example/models/gpt.py index ae52cffb..cc5b6cc5 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -1,6 +1,6 @@ import torch import bmtrain as bmt -from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder +from layers import TransformerEncoder, Layernorm, Embedding from bmtrain.global_var import config class GPT(bmt.DistributedModule): diff --git a/pyproject.toml b/pyproject.toml index b563eb32..3cac8ef6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,3 +6,5 @@ requires = [ "cmake > 3.27.0" ] build-backend = "setuptools.build_meta" + + diff --git a/tests/test_all2all.py b/tests/test_all2all.py new file mode 100644 index 00000000..63122f5a --- /dev/null +++ b/tests/test_all2all.py @@ -0,0 +1,106 @@ +import os +import torch +import bmtrain as bmt +from example.layers.attention import all2all_tensor + +def print_rank(msg): + if bmt.rank() == 0: + print(msg) + +def check_helper(v1, v2, debug=False): + if debug: + print_rank(torch.max(torch.abs(v1 - v2))) + print_rank(torch.mean(torch.abs(v1 - v2))) + torch.testing.assert_close(v1, v2, rtol=1e-3, atol=1e-2) + + +def check_helper_list(l1, l2, end=False): + if bmt.rank() == 0: + for i in range(len(l1)): + check_helper(l1[i], l2[i]) + if end: + exit() + + +def check_is_nan(tensor): + if torch.isnan(tensor).any(): + print("nan detected") + exit() + + + +def test(q, k, v, func, grad_output): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + grad_output = grad_output.contiguous() + o = func(q, k, v) + gq, gk, gv = torch.autograd.grad(o, (q, k, v), grad_output) + return o, (gq, gk, gv) + +def test_msg(test_func, msg, *args, **kwargs): + try: + test_func(*args, **kwargs) + bmt.print_rank(msg, " Success") + except: + bmt.print_rank(msg, " Failed") + exit() + +def get_chunk(t, dim): + return t.chunk(bmt.config["sp_size"], dim=dim)[bmt.config['sp_rank']].contiguous() + +def ref_attn(q, k, v): + scale = q.shape[-1] ** -0.5 + s = q @ k.transpose(-2, -1) * scale + s = torch.softmax(s, dim=-1) + p = s @ v + return p + +def all2all_attn(q, k, v): + q = all2all_tensor(q, 2, 1) + k = all2all_tensor(k, 2, 1) + v = all2all_tensor(v, 2, 1) + o = ref_attn(q, k, v) + o = all2all_tensor(o, 1, 2) + return o + +def test_all2all(): + bmt.init_distributed(sp_size=2) + 
b, n, s, d = 2, 16, 1024, 32 + if bmt.rank() == 0: + qkv = torch.randn(b, n*3, s, d, dtype=torch.float16).cuda() + grad_output = torch.randn(b, n, s, d, dtype=torch.float16).cuda() + torch.save(qkv, "qkv.pt") + torch.save(grad_output, "grad.pt") + bmt.synchronize() + qkv = torch.load("qkv.pt", map_location="cuda") + grad_output = torch.load("grad.pt", map_location="cuda") + qkv1 = [t.clone().detach().requires_grad_() for t in qkv.chunk(3, dim=1)] + if bmt.rank() == 0: + os.remove("qkv.pt") + os.remove("grad.pt") + + o_ref, g_ref = test(qkv1[0], qkv1[1], qkv1[2], ref_attn, grad_output) + for i in range(3): + qkv1[i] = qkv1[i].chunk(bmt.world_size(), dim=2)[bmt.rank()] + grad_output = ( + grad_output + .chunk(bmt.world_size(), dim=2)[bmt.rank()] + .clone() + .detach() + .contiguous() + ) + o1, grad_qkv1 = test(qkv1[0], qkv1[1], qkv1[2], all2all_attn, grad_output) + o1 = o1.contiguous() + grad_qkv1 = [g.contiguous() for g in grad_qkv1] + o_ref = get_chunk(o_ref, dim=2) + g_ref = [get_chunk(g, dim=2) for g in g_ref] + test_msg(check_helper, "Output Correctness Check", o_ref, o1) + test_msg(check_helper, "Value Correctness Check", g_ref[2], grad_qkv1[2]) + test_msg(check_helper, "Key Correctness Check", g_ref[1], grad_qkv1[1]) + test_msg(check_helper, "Query Correctness Check", g_ref[0], grad_qkv1[0]) + +if __name__ == "__main__": + + test_all2all() + diff --git a/tests/test_burst.py b/tests/test_burst.py new file mode 100644 index 00000000..f95fbfdc --- /dev/null +++ b/tests/test_burst.py @@ -0,0 +1,88 @@ +import torch +import bmtrain as bmt +from flash_attn.flash_attn_interface import flash_attn_func as flash_cuda +import numpy as np + +OpBurstAttn = bmt.nn.OpBurstAttn +def ref_attn(q, k, v): + scale = q.shape[-1] ** -0.5 + s = q @ k.transpose(-2, -1) * scale + s = torch.softmax(s, dim=-1) + p = s @ v + return p + + +def flash(q, k, v): + return flash_cuda(q, k, v, causal=False, softmax_scale=None) + + +def burst(q, k, v): + res_burst = OpBurstAttn.apply(q, k, v, None, "cuda", False) + return res_burst + +def test_func(q, k, v, func, grad_output): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + grad_output = grad_output.contiguous() + o = func(q, k, v) + gq, gk, gv = torch.autograd.grad(o, (q, k, v), grad_output) + return o, (gq, gk, gv) + +def test_burst(): + dtype = torch.float32 + bmt.init_distributed(sp_size=4) + flash = "cuda" + seq_dim = 2 if not flash else 1 + def get_chunk(t, dim): + return t.chunk(bmt.config['sp_size'], dim=dim)[bmt.config['sp_rank']].contiguous() + + b, s, n, d = 2, 4096, 16, 32 + if bmt.config["sp_rank"] == 0: + qkv = torch.randn(b, n*3, s, d, dtype=dtype).cuda() + grad_output = torch.randn(b, n, s, d, dtype=dtype).cuda() + torch.save(qkv, "./qkv.pt") + torch.save(grad_output, "./grad.pt") + bmt.synchronize() + qkv = torch.load("qkv.pt", map_location="cuda") + grad_output = torch.load("grad.pt", map_location="cuda") + qkv1 = [t.clone().detach().requires_grad_() for t in qkv.chunk(3, dim=1)] + + o_ref, g_ref = test_func(qkv1[0], qkv1[1], qkv1[2], ref_attn, grad_output) + for i in range(3): + if flash is not None: + qkv1[i] = qkv1[i].transpose(1, 2) + qkv1[i] = qkv1[i].chunk(bmt.world_size(), dim=seq_dim)[bmt.rank()] + qkv1[i] = qkv1[i].clone().detach().requires_grad_() + if flash is not None: + grad_output = grad_output.transpose(1, 2) + + grad_output = ( + grad_output.chunk(bmt.world_size(), dim=seq_dim)[bmt.rank()] + .clone() + .detach() + .contiguous() + ) + o1, grad_qkv1 = test_func(qkv1[0], qkv1[1], qkv1[2], burst, grad_output) + if flash: + o1 
= o1.transpose(1, 2).contiguous() + grad_qkv1 = [g.transpose(1, 2).contiguous() for g in grad_qkv1] + if bmt.rank() == 0: + from IPython import embed;embed() + bmt.synchronize() + o_ref = get_chunk(o_ref, dim=2) + g_ref = [get_chunk(g, dim=2) for g in g_ref] + np.testing.assert_allclose( + o1.detach().cpu().numpy(), + o_ref.detach().cpu().numpy(), + atol=1e-2, + rtol=0, + ) + for i in range(3): + falsh_g_rank = g_ref[i].detach().cpu().numpy() + burst_g_rank = grad_qkv1[i].detach().cpu().numpy() + np.testing.assert_allclose(falsh_g_rank, burst_g_rank, atol=1e-2, rtol=0) + bmt.print_rank(f"passed {i}") + +if __name__ == "__main__": + test_burst() diff --git a/tests/test_loss_func.py b/tests/test_loss_func.py index a448b6d1..01fd817d 100644 --- a/tests/test_loss_func.py +++ b/tests/test_loss_func.py @@ -2,7 +2,6 @@ import torch import bmtrain as bmt -import torch import random import copy @@ -76,4 +75,4 @@ def test_other(dtype): test_other(torch.bfloat16) test_simple(torch.bfloat16) except NotImplementedError: - pass \ No newline at end of file + pass From 1f1431989d21b68ddd5a1a9dfca7ae36135edddf Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 7 Jun 2024 18:41:50 +0800 Subject: [PATCH 08/11] normal attn test pass --- bmtrain/nn/burst_attn.py | 4 +- bmtrain/nn/burst_utils.py | 2 +- example/test_all2all.py | 110 ++++++++++++++++++++++++++++++++++++++ tests/test_burst.py | 9 ++-- tests/test_lse.py | 61 +++++++++++++++++++++ 5 files changed, 177 insertions(+), 9 deletions(-) create mode 100644 example/test_all2all.py create mode 100644 tests/test_lse.py diff --git a/bmtrain/nn/burst_attn.py b/bmtrain/nn/burst_attn.py index 1193d126..b2993266 100644 --- a/bmtrain/nn/burst_attn.py +++ b/bmtrain/nn/burst_attn.py @@ -23,7 +23,7 @@ class OpBurstAttn(torch.autograd.Function): @staticmethod def forward( - ctx, q, k, v, softmax_scale=None, flash=None, optimize_bwd_comm=False + ctx, q, k, v, softmax_scale=None, flash=None, optimize_bwd_comm=False, return_softmax=False ): m_i = None acc_o = None @@ -76,7 +76,7 @@ def forward( if flash is not None: lse_i = lse_i.squeeze(dim=-1).transpose(1, 2).contiguous() ctx.save_for_backward(q, k, v, lse_i.contiguous(), acc_o) - return acc_o + return acc_o if not return_softmax else (acc_o, lse_i) @staticmethod def backward(ctx, grad_output): diff --git a/bmtrain/nn/burst_utils.py b/bmtrain/nn/burst_utils.py index 2217b190..9f09174f 100644 --- a/bmtrain/nn/burst_utils.py +++ b/bmtrain/nn/burst_utils.py @@ -147,7 +147,7 @@ def inter_normal_attn_backward( d_p = do @ v.transpose(-2, -1) softmax_scale = softmax_scale d_s = p * (d_p - delta) * softmax_scale - d_q += d_s @ k + d_q[:] = d_s @ k d_k += d_s.transpose(-2, -1) @ q diff --git a/example/test_all2all.py b/example/test_all2all.py new file mode 100644 index 00000000..3510bd52 --- /dev/null +++ b/example/test_all2all.py @@ -0,0 +1,110 @@ +from typing import Optional +import os +import torch +import bmtrain as bmt +from bmtrain.global_var import config +from layers.attention import all2all_tensor +import torch +import bmtrain as bmt + +def print_rank(msg): + if bmt.rank() == 0: + print(msg) + +def check_helper(v1, v2, debug=False): + if debug: + print_rank(torch.max(torch.abs(v1 - v2))) + print_rank(torch.mean(torch.abs(v1 - v2))) + torch.testing.assert_close(v1, v2, rtol=1e-3, atol=1e-2) + + +def check_helper_list(l1, l2, end=False): + if bmt.rank() == 0: + for i in range(len(l1)): + check_helper(l1[i], l2[i]) + if end: + exit() + + +def check_is_nan(tensor): + if torch.isnan(tensor).any(): + 
print("nan detected") + exit() + + + +def test(q, k, v, func, grad_output): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + grad_output = grad_output.contiguous() + o = func(q, k, v) + gq, gk, gv = torch.autograd.grad(o, (q, k, v), grad_output) + return o, (gq, gk, gv) + +def test_msg(test_func, msg, *args, **kwargs): + try: + test_func(*args, **kwargs) + bmt.print_rank(msg, " Success") + except: + bmt.print_rank(msg, " Failed") + exit() + +def get_chunk(t, dim): + return t.chunk(bmt.config["sp_size"], dim=dim)[bmt.config['sp_rank']].contiguous() + +def ref_attn(q, k, v): + scale = q.shape[-1] ** -0.5 + s = q @ k.transpose(-2, -1) * scale + s = torch.softmax(s, dim=-1) + p = s @ v + return p + +def all2all_attn(q, k, v): + q = all2all_tensor(q, 2, 1) + k = all2all_tensor(k, 2, 1) + v = all2all_tensor(v, 2, 1) + o = ref_attn(q, k, v) + o = all2all_tensor(o, 1, 2) + return o + +def test_all2all(): + bmt.init_distributed(sp_size=2) + b, n, s, d = 2, 16, 1024, 32 + if bmt.rank() == 0: + qkv = torch.randn(b, n*3, s, d, dtype=torch.float16).cuda() + grad_output = torch.randn(b, n, s, d, dtype=torch.float16).cuda() + torch.save(qkv, "qkv.pt") + torch.save(grad_output, "grad.pt") + bmt.synchronize() + qkv = torch.load("qkv.pt", map_location="cuda") + grad_output = torch.load("grad.pt", map_location="cuda") + qkv1 = [t.clone().detach().requires_grad_() for t in qkv.chunk(3, dim=1)] + if bmt.rank() == 0: + os.remove("qkv.pt") + os.remove("grad.pt") + + o_ref, g_ref = test(qkv1[0], qkv1[1], qkv1[2], ref_attn, grad_output) + for i in range(3): + qkv1[i] = qkv1[i].chunk(bmt.world_size(), dim=2)[bmt.rank()] + grad_output = ( + grad_output + .chunk(bmt.world_size(), dim=2)[bmt.rank()] + .clone() + .detach() + .contiguous() + ) + o1, grad_qkv1 = test(qkv1[0], qkv1[1], qkv1[2], all2all_attn, grad_output) + o1 = o1.contiguous() + grad_qkv1 = [g.contiguous() for g in grad_qkv1] + o_ref = get_chunk(o_ref, dim=2) + g_ref = [get_chunk(g, dim=2) for g in g_ref] + test_msg(check_helper, "Output Correctness Check", o_ref, o1) + test_msg(check_helper, "Value Correctness Check", g_ref[2], grad_qkv1[2]) + test_msg(check_helper, "Key Correctness Check", g_ref[1], grad_qkv1[1]) + test_msg(check_helper, "Query Correctness Check", g_ref[0], grad_qkv1[0]) + +if __name__ == "__main__": + + test_all2all() + diff --git a/tests/test_burst.py b/tests/test_burst.py index f95fbfdc..2e357f6d 100644 --- a/tests/test_burst.py +++ b/tests/test_burst.py @@ -17,7 +17,7 @@ def flash(q, k, v): def burst(q, k, v): - res_burst = OpBurstAttn.apply(q, k, v, None, "cuda", False) + res_burst = OpBurstAttn.apply(q, k, v, None, None, False) return res_burst def test_func(q, k, v, func, grad_output): @@ -30,9 +30,9 @@ def test_func(q, k, v, func, grad_output): return o, (gq, gk, gv) def test_burst(): - dtype = torch.float32 + dtype = torch.float16 bmt.init_distributed(sp_size=4) - flash = "cuda" + flash = None seq_dim = 2 if not flash else 1 def get_chunk(t, dim): return t.chunk(bmt.config['sp_size'], dim=dim)[bmt.config['sp_rank']].contiguous() @@ -67,9 +67,6 @@ def get_chunk(t, dim): if flash: o1 = o1.transpose(1, 2).contiguous() grad_qkv1 = [g.transpose(1, 2).contiguous() for g in grad_qkv1] - if bmt.rank() == 0: - from IPython import embed;embed() - bmt.synchronize() o_ref = get_chunk(o_ref, dim=2) g_ref = [get_chunk(g, dim=2) for g in g_ref] np.testing.assert_allclose( diff --git a/tests/test_lse.py b/tests/test_lse.py new file mode 100644 index 00000000..f0de3955 --- /dev/null +++ b/tests/test_lse.py @@ -0,0 +1,61 @@ 
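+# Sanity check for the returned log-sum-exp: compares the (output, lse) pair from
+# flash_attn_func(..., return_attn_probs=True) against OpBurstAttn with
+# return_softmax=True on a sequence-sharded copy of the same qkv; on a mismatch
+# rank 0 drops into an IPython shell for inspection.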
+import torch +import bmtrain as bmt +from flash_attn.flash_attn_interface import flash_attn_func as flash_cuda +import numpy as np + +OpBurstAttn = bmt.nn.OpBurstAttn + +def flash(q, k, v): + return flash_cuda(q, k, v, causal=False, softmax_scale=None, return_attn_probs=True) + + +def burst(q, k, v): + res_burst = OpBurstAttn.apply(q, k, v, None, None, False, True) + return res_burst + +def test_func(q, k, v, func, grad_output): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + grad_output = grad_output.contiguous() + o = func(q, k, v) + gq, gk, gv = torch.autograd.grad(o, (q, k, v), grad_output) + return o, (gq, gk, gv) + +def test_burst(): + dtype = torch.float16 + bmt.init_distributed(sp_size=4) + def get_chunk(t, dim): + return t.chunk(bmt.config['sp_size'], dim=dim)[bmt.config['sp_rank']].contiguous() + + b, s, n, d = 2, 4096, 16, 32 + if bmt.config["sp_rank"] == 0: + qkv = torch.randn(b, n*3, s, d, dtype=dtype).cuda() + torch.save(qkv, "./qkv.pt") + bmt.synchronize() + qkv = torch.load("qkv.pt", map_location="cuda") + qkv1 = [t.clone().detach().requires_grad_().transpose(1, 2).contiguous() for t in qkv.chunk(3, dim=1)] + qkv_burst_normal = [get_chunk(t, dim=2).clone().detach().requires_grad_() for t in qkv.chunk(3, dim=1)] + output, lse, _ = flash(qkv1[0], qkv1[1], qkv1[2]) + output_burst, lse_burst = burst(qkv_burst_normal[0], qkv_burst_normal[1], qkv_burst_normal[2]) + def test_allclose(t1, t2, atol, rtol): + t1 = t1.detach().cpu().numpy() + t2 = t2.detach().cpu().numpy() + assert np.testing.assert_allclose(t1, t2, atol=atol, rtol=rtol) + try: + output = get_chunk(output.transpose(1, 2), 2).contiguous() + lse = get_chunk(lse, 2).unsqueeze(dim=-1) + print(torch.allclose(output, output_burst)) + print(torch.allclose(lse, lse_burst)) + raise Exception + except Exception: + if bmt.rank() == 0: + from IPython import embed;embed() + bmt.synchronize() + + + + +if __name__ == "__main__": + test_burst() + From 5c9eacd9d85a7c021fff74361edb700daf8579f4 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 11 Jun 2024 14:52:32 +0800 Subject: [PATCH 09/11] add causal for burst_attn and add test for causal/no_causal --- bmtrain/nn/burst_attn.py | 78 +++++++++++++++++++++++++++------------ bmtrain/nn/burst_utils.py | 40 +++++++++++++++----- tests/test_burst.py | 76 +++++++++++++++++++++----------------- 3 files changed, 128 insertions(+), 66 deletions(-) diff --git a/bmtrain/nn/burst_attn.py b/bmtrain/nn/burst_attn.py index b2993266..ee3d880b 100644 --- a/bmtrain/nn/burst_attn.py +++ b/bmtrain/nn/burst_attn.py @@ -23,12 +23,17 @@ class OpBurstAttn(torch.autograd.Function): @staticmethod def forward( - ctx, q, k, v, softmax_scale=None, flash=None, optimize_bwd_comm=False, return_softmax=False + ctx, q, k, v, softmax_scale=None, flash=None, causal=False, optimize_bwd_comm=False, return_softmax=False, bias=None ): + assert flash in [None, "cuda", "triton"], "flash must be None, 'cuda', or 'triton'" + assert bias is None or flash != "cuda", "Flash Attn cuda impl does not support bias" + m_i = None acc_o = None lse_i = None - ctx.optimize_bwd_comm = optimize_bwd_comm or flash is None + ctx.optimize_bwd_comm = optimize_bwd_comm or flash != "cuda" + ctx.causal = causal + ctx.has_bias = bias is not None if softmax_scale is None: ctx.softmax_scale = 1 / math.sqrt(q.shape[-1]) else: @@ -47,22 +52,26 @@ def forward( bmt.config["sp_comm"], bmt.config["sp_rank"] ) ctx.burst_comm = burst_comm + sp_rank = bmt.config["sp_rank"] for r in range(1, sp_count + 1): + if causal and r 
> sp_rank + 1: + continue + causal_arg = causal if r == 1 else False bufs = burst_comm.ring_send_recv(k, v) burst_comm.commit() if ctx.flash: if ctx.flash == "triton": acc_o, m_i, lse_i = forward_func( - q, k, v, m_i, lse_i, acc_o, ctx.softmax_scale, None + q, k, v, m_i, lse_i, acc_o, ctx.softmax_scale, bias, causal=causal_arg ) else: acc_o, lse_i = forward_func( - q, k, v, acc_o, lse_i, ctx.softmax_scale + q, k, v, acc_o, lse_i, ctx.softmax_scale, causal=causal_arg ) else: acc_o, m_i, lse_i = forward_func( - q, k, v, m_i, lse_i, acc_o, ctx.softmax_scale, None + q, k, v, m_i, lse_i, acc_o, ctx.softmax_scale, bias, causal=causal_arg ) k, v = record_stream(*bufs) burst_comm.wait() @@ -73,14 +82,20 @@ def forward( o_scale = torch.exp(m_i - lse_i) acc_o = acc_o * o_scale acc_o = acc_o.to(dtype=q.dtype) - if flash is not None: - lse_i = lse_i.squeeze(dim=-1).transpose(1, 2).contiguous() - ctx.save_for_backward(q, k, v, lse_i.contiguous(), acc_o) + if flash == "cuda": + lse_i = lse_i.transpose(1, 2) + lse_i = lse_i.contiguous() + save_tensor = (q, k, v, lse_i, acc_o) if bias is None else (q, k, v, lse_i, acc_o, bias) + ctx.save_for_backward(*save_tensor) return acc_o if not return_softmax else (acc_o, lse_i) @staticmethod def backward(ctx, grad_output): - q, k, v, lse_i, o_i = ctx.saved_tensors + if ctx.has_bias: + q, k, v, lse_i, o_i, bias = ctx.saved_tensors + else: + q, k, v, lse_i, o_i = ctx.saved_tensors + bias = None d_q = torch.zeros_like(q) d_k = torch.zeros_like(k) d_v = torch.zeros_like(v) @@ -109,25 +124,42 @@ def backward(ctx, grad_output): dq = torch.zeros_like(d_q) for r in range(1, sp_count + 1): #j = (i + sp_count - r) % sp_count - + if ctx.causal and r > bmt.config['sp_rank']+1: + continue if r != sp_count: bufs = burst_comm.ring_send_recv(delta, grad_output, q, lse_i) if r != 1: dq_buf = burst_comm.ring_send_recv(d_q) burst_comm.commit() - backward_func( - grad_output, - q, - k, - v, - delta, - lse_i, - dq, - d_k, - d_v, - ctx.softmax_scale, - None, - ) + if ctx.flash == "cuda": + backward_func( + grad_output, + q, + k, + v, + delta, + lse_i, + dq, + d_k, + d_v, + ctx.softmax_scale, + causal=ctx.causal and r == 1, + ) + else: + backward_func( + grad_output, + q, + k, + v, + delta, + lse_i, + dq, + d_k, + d_v, + ctx.softmax_scale, + bias, + causal=ctx.causal and r == 1, + ) burst_comm.wait() if r != sp_count: delta, grad_output, q, lse_i = record_stream(*bufs) diff --git a/bmtrain/nn/burst_utils.py b/bmtrain/nn/burst_utils.py index 9f09174f..413100fe 100644 --- a/bmtrain/nn/burst_utils.py +++ b/bmtrain/nn/burst_utils.py @@ -90,9 +90,15 @@ def record_stream(*tensorlist): return tensorlist -def inter_normal_attn(q, k, v, m_i, lse_i, acc_o, softmax_scale=1.0, mask_bias=None): +def inter_normal_attn(q, k, v, m_i, lse_i, acc_o, softmax_scale=1.0, mask_bias=None, causal=False): m_i = m_i.to(q.dtype) if m_i is not None else None qk = q @ k.transpose(-2, -1) * softmax_scale + if causal: + tril_mask = torch.tril( + torch.ones((qk.size(-2), qk.size(-1)), device=qk.device, dtype=qk.dtype) + ) + mask_bias = mask_bias * tril_mask if mask_bias is not None else tril_mask + if mask_bias is not None: qk = torch.masked_fill( qk, @@ -126,10 +132,15 @@ def inter_normal_attn(q, k, v, m_i, lse_i, acc_o, softmax_scale=1.0, mask_bias=N def inter_normal_attn_backward( - do, q, k, v, delta, lse, d_q, d_k, d_v, softmax_scale, mask_bias + do, q, k, v, delta, lse, d_q, d_k, d_v, softmax_scale, mask_bias, causal ): # ensure q,k,v with shape [b,n,s,d] qk = q @ k.transpose(-2, -1) * softmax_scale + if causal: 
+ tril_mask = torch.tril( + torch.ones((qk.size(-2), qk.size(-1)), device=qk.device, dtype=qk.dtype) + ) + mask_bias = mask_bias * tril_mask if mask_bias is not None else tril_mask if mask_bias is not None: qk = torch.masked_fill( qk, @@ -152,10 +163,15 @@ def inter_normal_attn_backward( def inter_flash_attn_triton( - q, k, v, m_i, lse_i, acc_o, softmax_scale=1.0, mask_bias=None + q, k, v, m_i, lse_i, acc_o, softmax_scale=1.0, mask_bias=None, causal=False ): from .burst_lao import _flash_attn_forward b, s, n, d = q.shape + if causal: + tril_mask = torch.tril( + torch.ones((s, s), device=q.device, dtype=q.dtype) + ) + mask_bias = mask_bias * tril_mask if mask_bias is not None else tril_mask if m_i is None: m_i = ( -torch.ones((b, n, s), dtype=torch.float32, device="cuda") * torch.inf @@ -181,10 +197,16 @@ def inter_flash_attn_triton( def inter_flash_attn_backward_triton( - do, q, k, v, delta, lse, dq, dk, dv, softmax_scale, mask_bias + do, q, k, v, delta, lse, dq, dk, dv, softmax_scale, mask_bias, causal ): from .burst_lao import _flash_attn_backward # dq_ = torch.empty_like(q) + b, s, n, d = q.shape + if causal: + tril_mask = torch.tril( + torch.ones((s, s), device=q.device, dtype=q.dtype) + ) + mask_bias = mask_bias * tril_mask if mask_bias is not None else tril_mask dk_ = torch.empty_like(q) dv_ = torch.empty_like(q) _flash_attn_backward( @@ -205,14 +227,14 @@ def inter_flash_attn_backward_triton( dv += dv_ -def inter_flash_cuda_fwd(q, k, v, o, lse, softmax_scale=1.0): +def inter_flash_cuda_fwd(q, k, v, o, lse, softmax_scale=1.0, causal=False): o_i, _, _, _, _, lse_i, _, _ = _flash_attn_forward_cuda( q, k, v, 0.0, softmax_scale, - causal=False, + causal=causal, window_size=(-1, -1), alibi_slopes=None, return_softmax=False, @@ -225,7 +247,7 @@ def inter_flash_cuda_fwd(q, k, v, o, lse, softmax_scale=1.0): return o, lse -def inter_flash_cuda_bwd(do, q, k, v, o, lse, dq, dk, dv, softmax_scale, mask_bias): +def inter_flash_cuda_bwd(do, q, k, v, o, lse, dq, dk, dv, softmax_scale, causal): dk_ = torch.empty_like(q) dv_ = torch.empty_like(q) if len(o.shape) == 3: @@ -253,7 +275,7 @@ def inter_flash_cuda_bwd(do, q, k, v, o, lse, dq, dk, dv, softmax_scale, mask_bi dv_, 0.0, softmax_scale, - False, + causal, (-1, -1), None, False, @@ -273,7 +295,7 @@ def inter_flash_cuda_bwd(do, q, k, v, o, lse, dq, dk, dv, softmax_scale, mask_bi dv_, 0.0, softmax_scale, - False, + causal, (-1, -1), None, False, diff --git a/tests/test_burst.py b/tests/test_burst.py index 2e357f6d..6876f806 100644 --- a/tests/test_burst.py +++ b/tests/test_burst.py @@ -4,56 +4,58 @@ import numpy as np OpBurstAttn = bmt.nn.OpBurstAttn -def ref_attn(q, k, v): + +def ref_attn(q, k, v, causal=False): scale = q.shape[-1] ** -0.5 s = q @ k.transpose(-2, -1) * scale s = torch.softmax(s, dim=-1) + if causal: + s = torch.tril(s) p = s @ v return p -def flash(q, k, v): - return flash_cuda(q, k, v, causal=False, softmax_scale=None) - - -def burst(q, k, v): - res_burst = OpBurstAttn.apply(q, k, v, None, None, False) +def burst(q, k, v, flash, causal, softmax_scale=None): + # assert not causal, "causal not supported yet" + res_burst = OpBurstAttn.apply(q, k, v, softmax_scale, flash) return res_burst -def test_func(q, k, v, func, grad_output): - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - grad_output = grad_output.contiguous() - o = func(q, k, v) - gq, gk, gv = torch.autograd.grad(o, (q, k, v), grad_output) - return o, (gq, gk, gv) +def get_chunk(t, dim): + return t.chunk(bmt.config['sp_size'], 
dim=dim)[bmt.config['sp_rank']].contiguous() -def test_burst(): - dtype = torch.float16 +def test_main(): bmt.init_distributed(sp_size=4) - flash = None - seq_dim = 2 if not flash else 1 - def get_chunk(t, dim): - return t.chunk(bmt.config['sp_size'], dim=dim)[bmt.config['sp_rank']].contiguous() + test_burst(torch.float32, flash=None, causal=False) + test_burst(torch.float16, flash=None, causal=False) + test_burst(torch.float16, flash="cuda", causal=False) + test_burst(torch.float16, flash="triton", causal=False) + test_burst(torch.float32, flash=None, causal=True) + test_burst(torch.float16, flash=None, causal=True) + test_burst(torch.float16, flash="cuda", causal=True) + test_burst(torch.float16, flash="triton", causal=True) +def test_burst(dtype, flash, causal): + seq_dim = 2 if not flash else 1 b, s, n, d = 2, 4096, 16, 32 if bmt.config["sp_rank"] == 0: - qkv = torch.randn(b, n*3, s, d, dtype=dtype).cuda() + qkv_whole = torch.randn(b, n*3, s, d, dtype=dtype).cuda() grad_output = torch.randn(b, n, s, d, dtype=dtype).cuda() - torch.save(qkv, "./qkv.pt") + torch.save(qkv_whole, "./qkv.pt") torch.save(grad_output, "./grad.pt") bmt.synchronize() - qkv = torch.load("qkv.pt", map_location="cuda") + qkv_whole = torch.load("qkv.pt", map_location="cuda") grad_output = torch.load("grad.pt", map_location="cuda") - qkv1 = [t.clone().detach().requires_grad_() for t in qkv.chunk(3, dim=1)] + qkv = [t.clone().detach().requires_grad_() for t in qkv_whole.chunk(3, dim=1)] + + o_ref = ref_attn(qkv[0], qkv[1], qkv[2]) + g_ref = torch.autograd.grad(o_ref, qkv, grad_output) - o_ref, g_ref = test_func(qkv1[0], qkv1[1], qkv1[2], ref_attn, grad_output) for i in range(3): if flash is not None: - qkv1[i] = qkv1[i].transpose(1, 2) - qkv1[i] = qkv1[i].chunk(bmt.world_size(), dim=seq_dim)[bmt.rank()] - qkv1[i] = qkv1[i].clone().detach().requires_grad_() + qkv[i] = qkv[i].transpose(1, 2) + qkv[i] = qkv[i].chunk(bmt.world_size(), dim=seq_dim)[bmt.rank()] + qkv[i] = qkv[i].clone().detach().requires_grad_() + if flash is not None: grad_output = grad_output.transpose(1, 2) @@ -63,10 +65,11 @@ def get_chunk(t, dim): .detach() .contiguous() ) - o1, grad_qkv1 = test_func(qkv1[0], qkv1[1], qkv1[2], burst, grad_output) + o1 = burst(qkv[0], qkv[1], qkv[2], flash, causal, softmax_scale=None) + grad_qkv = torch.autograd.grad(o1, qkv, grad_output) if flash: o1 = o1.transpose(1, 2).contiguous() - grad_qkv1 = [g.transpose(1, 2).contiguous() for g in grad_qkv1] + grad_qkv = [g.transpose(1, 2).contiguous() for g in grad_qkv] o_ref = get_chunk(o_ref, dim=2) g_ref = [get_chunk(g, dim=2) for g in g_ref] np.testing.assert_allclose( @@ -77,9 +80,14 @@ def get_chunk(t, dim): ) for i in range(3): falsh_g_rank = g_ref[i].detach().cpu().numpy() - burst_g_rank = grad_qkv1[i].detach().cpu().numpy() - np.testing.assert_allclose(falsh_g_rank, burst_g_rank, atol=1e-2, rtol=0) + burst_g_rank = grad_qkv[i].detach().cpu().numpy() + try: + np.testing.assert_allclose(falsh_g_rank, burst_g_rank, atol=1e-2, rtol=0) + except Exception as e: + bmt.print_rank(e) + bmt.print_rank(f"passed {i}") + bmt.print_rank(f"dtype = {dtype}, flash = {flash}, causal = {causal} setting passed ") if __name__ == "__main__": - test_burst() + test_main() From 1614ad1f54d1b8be1d33311900b5f69504cca507 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 11 Jun 2024 15:26:13 +0800 Subject: [PATCH 10/11] clean unused import --- tests/test_burst.py | 1 - tests/test_lse.py | 61 --------------------------------------------- 2 files changed, 62 deletions(-) delete 
mode 100644 tests/test_lse.py

diff --git a/tests/test_burst.py b/tests/test_burst.py
index 6876f806..67ffe9a3 100644
--- a/tests/test_burst.py
+++ b/tests/test_burst.py
@@ -1,6 +1,5 @@
 import torch
 import bmtrain as bmt
-from flash_attn.flash_attn_interface import flash_attn_func as flash_cuda
 import numpy as np
 
 OpBurstAttn = bmt.nn.OpBurstAttn
diff --git a/tests/test_lse.py b/tests/test_lse.py
deleted file mode 100644
index f0de3955..00000000
--- a/tests/test_lse.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import torch
-import bmtrain as bmt
-from flash_attn.flash_attn_interface import flash_attn_func as flash_cuda
-import numpy as np
-
-OpBurstAttn = bmt.nn.OpBurstAttn
-
-def flash(q, k, v):
-    return flash_cuda(q, k, v, causal=False, softmax_scale=None, return_attn_probs=True)
-
-
-def burst(q, k, v):
-    res_burst = OpBurstAttn.apply(q, k, v, None, None, False, True)
-    return res_burst
-
-def test_func(q, k, v, func, grad_output):
-    q = q.contiguous()
-    k = k.contiguous()
-    v = v.contiguous()
-    grad_output = grad_output.contiguous()
-    o = func(q, k, v)
-    gq, gk, gv = torch.autograd.grad(o, (q, k, v), grad_output)
-    return o, (gq, gk, gv)
-
-def test_burst():
-    dtype = torch.float16
-    bmt.init_distributed(sp_size=4)
-    def get_chunk(t, dim):
-        return t.chunk(bmt.config['sp_size'], dim=dim)[bmt.config['sp_rank']].contiguous()
-
-    b, s, n, d = 2, 4096, 16, 32
-    if bmt.config["sp_rank"] == 0:
-        qkv = torch.randn(b, n*3, s, d, dtype=dtype).cuda()
-        torch.save(qkv, "./qkv.pt")
-    bmt.synchronize()
-    qkv = torch.load("qkv.pt", map_location="cuda")
-    qkv1 = [t.clone().detach().requires_grad_().transpose(1, 2).contiguous() for t in qkv.chunk(3, dim=1)]
-    qkv_burst_normal = [get_chunk(t, dim=2).clone().detach().requires_grad_() for t in qkv.chunk(3, dim=1)]
-    output, lse, _ = flash(qkv1[0], qkv1[1], qkv1[2])
-    output_burst, lse_burst = burst(qkv_burst_normal[0], qkv_burst_normal[1], qkv_burst_normal[2])
-    def test_allclose(t1, t2, atol, rtol):
-        t1 = t1.detach().cpu().numpy()
-        t2 = t2.detach().cpu().numpy()
-        assert np.testing.assert_allclose(t1, t2, atol=atol, rtol=rtol)
-    try:
-        output = get_chunk(output.transpose(1, 2), 2).contiguous()
-        lse = get_chunk(lse, 2).unsqueeze(dim=-1)
-        print(torch.allclose(output, output_burst))
-        print(torch.allclose(lse, lse_burst))
-        raise Exception
-    except Exception:
-        if bmt.rank() == 0:
-            from IPython import embed;embed()
-        bmt.synchronize()
-
-
-
-
-if __name__ == "__main__":
-    test_burst()
-

From 4e1bc988a60e882ad7bfdd61dffae8d54f646b2d Mon Sep 17 00:00:00 2001
From: MayDomine <1583143678@qq.com>
Date: Tue, 11 Jun 2024 15:29:02 +0800
Subject: [PATCH 11/11] move import statement

---
 bmtrain/nn/burst_utils.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/bmtrain/nn/burst_utils.py b/bmtrain/nn/burst_utils.py
index 413100fe..ebd9099a 100644
--- a/bmtrain/nn/burst_utils.py
+++ b/bmtrain/nn/burst_utils.py
@@ -1,11 +1,5 @@
 import bmtrain as bmt
 import torch
-from flash_attn.flash_attn_interface import (
-    _flash_attn_forward as _flash_attn_forward_cuda,
-)
-from flash_attn.flash_attn_interface import (
-    _flash_attn_backward as _flash_attn_backward_cuda,
-)
 import inspect
 
 class ops_wrapper:
@@ -228,6 +222,9 @@ def inter_flash_attn_backward_triton(
 def inter_flash_cuda_fwd(q, k, v, o, lse, softmax_scale=1.0, causal=False):
+    from flash_attn.flash_attn_interface import (
+        _flash_attn_forward as _flash_attn_forward_cuda,
+    )
     o_i, _, _, _, _, lse_i, _, _ = _flash_attn_forward_cuda(
         q,
         k,
         v,
         0.0,
         softmax_scale,
         causal=causal,
         window_size=(-1, -1),
         alibi_slopes=None,
         return_softmax=False,
@@ -250,6 +247,9 @@ def inter_flash_cuda_fwd(q, k, v, o, lse, softmax_scale=1.0, causal=False):
 def inter_flash_cuda_bwd(do, q, k, v, o, lse, dq, dk, dv, softmax_scale, causal):
     dk_ = torch.empty_like(q)
     dv_ = torch.empty_like(q)
+    from flash_attn.flash_attn_interface import (
+        _flash_attn_backward as _flash_attn_backward_cuda,
+    )
     if len(o.shape) == 3:
         # use sum(o_i * gradoutput) as delta and pass a empty out to flash backward
         # this feature requires a build of this PR: https://github.com/Dao-AILab/flash-attention/pull/905
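The causal path added to the ring loops skips any step with r > sp_rank + 1. The sketch below is one way to check that this is sound, under the assumption that at step r a rank is holding the key/value chunk that originated on rank (sp_rank - r + 1) % sp_count, which is the ordering the skip condition implies; the helper name and the standalone framing are illustrative and not part of the patch.

    def visible_kv_steps(sp_rank: int, sp_count: int, causal: bool):
        # Which ring steps a rank actually computes, and whose KV chunk it sees.
        # Assumed ordering: step r delivers the chunk of rank (sp_rank - r + 1) % sp_count.
        # Under a causal mask, chunks owned by later ranks are fully masked for this
        # rank's queries, so those steps carry no work and can be skipped.
        steps = []
        for r in range(1, sp_count + 1):
            owner = (sp_rank - r + 1) % sp_count
            if causal and r > sp_rank + 1:
                continue  # owner > sp_rank here, i.e. a chunk of future positions
            steps.append((r, owner))
        return steps

    if __name__ == "__main__":
        for rank in range(4):
            print(rank, visible_kv_steps(rank, 4, causal=True))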
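The forward pass keeps a running output acc_o plus running max and log-sum-exp statistics and rescales at the end with o_scale = torch.exp(m_i - lse_i). The same bookkeeping can be expressed as a pairwise merge of block-local results, which is handy for checking the accumulation numerically; this helper is a sketch of that identity, not the accumulator layout the patch uses.

    import torch

    def merge_partial_attn(o1, lse1, o2, lse2):
        # o1/o2: attention outputs computed over two disjoint key blocks, each
        # normalized within its own block; lse1/lse2: log-sum-exp of the scores
        # for the same blocks, broadcastable against the outputs.
        lse = torch.logaddexp(lse1, lse2)
        return o1 * torch.exp(lse1 - lse) + o2 * torch.exp(lse2 - lse), lse

    if __name__ == "__main__":
        torch.manual_seed(0)
        q = torch.randn(2, 4, 16, 8)
        k = torch.randn(2, 4, 32, 8)
        v = torch.randn(2, 4, 32, 8)
        s = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
        ref = torch.softmax(s, dim=-1) @ v
        s1, s2 = s.chunk(2, dim=-1)
        o1 = torch.softmax(s1, dim=-1) @ v[..., :16, :]
        o2 = torch.softmax(s2, dim=-1) @ v[..., 16:, :]
        lse1 = torch.logsumexp(s1, dim=-1, keepdim=True)
        lse2 = torch.logsumexp(s2, dim=-1, keepdim=True)
        merged, _ = merge_partial_attn(o1, lse1, o2, lse2)
        print(torch.allclose(merged, ref, atol=1e-5))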
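The new ref_attn in tests/test_burst.py applies torch.tril to the probabilities after the softmax; a causal reference is usually written by masking the scores above the diagonal before the softmax, so each row renormalizes over the visible keys only (tril after the softmax matches that only once the rows are renormalized). A minimal reference in the mask-before-softmax style, with an illustrative name:

    import torch

    def causal_ref_attn(q, k, v):
        # q, k, v: [batch, heads, seq, dim]
        scale = q.shape[-1] ** -0.5
        s = q @ k.transpose(-2, -1) * scale
        # mask out future positions before the softmax
        future = torch.triu(
            torch.ones(s.shape[-2], s.shape[-1], device=s.device, dtype=torch.bool),
            diagonal=1,
        )
        s = s.masked_fill(future, float("-inf"))
        return torch.softmax(s, dim=-1) @ v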
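The last patch moves the flash_attn imports from module scope into the two CUDA wrappers, so importing bmtrain.nn.burst_utils no longer requires flash-attn to be installed unless the "cuda" backend is actually used. A small sketch of the same lazy-import pattern with an explicit error message; the helper name and the message wording are illustrative, not part of the patch:

    def _require_flash_attn_backward():
        # Import inside the function so flash-attn stays an optional dependency;
        # the failure only surfaces when the "cuda" flash backend is selected.
        try:
            from flash_attn.flash_attn_interface import _flash_attn_backward
        except ImportError as exc:
            raise RuntimeError(
                "flash='cuda' requires the flash-attn package to be installed"
            ) from exc
        return _flash_attn_backward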