commit causal_conv1d.py
fattorib committed Nov 10, 2024
1 parent e59e0a4 commit c3af234
Showing 1 changed file with 381 additions and 0 deletions: hawk/causal_conv1d.py

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from torch.autograd import Function


@triton.jit
def silu(x):
return x * tl.sigmoid(x)


@triton.jit
def dsilu(x):
    # derivative of SiLU: d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
    sig = tl.sigmoid(x)
    return sig + (x * sig * (1 - sig))


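# Forward kernel: one program per (batch, channel) row. Each program accumulates
#   y[s] = sum_{k=0}^{kernel_t-1} w[k] * x[s - (kernel_t - 1 - k)] + b
# in fp32, masking out-of-range taps to zero (the causal left padding), then
# optionally applies SiLU and stores the row back in bf16.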
@triton.jit
def causal_conv_forward_kernel(
x_ptr,
w_ptr,
b_ptr,
y_ptr,
x_stride_b,
x_stride_d,
w_stride_d,
seqlen: tl.constexpr,
kernel_t: tl.constexpr,
act: tl.constexpr,
):
b_pid = tl.program_id(0)
d_pid = tl.program_id(1)

# offset to beginning of block
x_ptr += b_pid * x_stride_b + d_pid * x_stride_d
y_ptr += b_pid * x_stride_b + d_pid * x_stride_d
w_ptr += d_pid * w_stride_d

b = tl.load(b_ptr + d_pid)

y = tl.zeros([seqlen], dtype=tl.float32)

for t in range(0, kernel_t):
        # Successively load x in shifts

w_load = (kernel_t - 1) - t

t_off = t

w = tl.load(w_ptr + w_load)

# NOTE: Not vectorized since these are loads from global

x = tl.load(
x_ptr - t_off + tl.arange(0, seqlen),
mask=tl.arange(0, seqlen) >= t_off,
other=0.0,
)
y += x * w
y += b

if act == 1:
tl.store(y_ptr + tl.arange(0, seqlen), silu(y).to(tl.bfloat16))
else:
tl.store(y_ptr + tl.arange(0, seqlen), y.to(tl.bfloat16))


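# Backward kernel: let z = conv(x) + b be the pre-activation and g = dL/dy. Then
#   dL/dz[s] = g[s] * dsilu(z[s])                        (or g[s] if act == 0)
#   dL/db    = sum_s dL/dz[s]
#   dL/dw[k] = sum_s dL/dz[s] * x[s - (kernel_t - 1 - k)]
#   dL/dx[s] = sum_k w[k] * dL/dz[s + (kernel_t - 1 - k)]
# z is recomputed from x, dL/dz is staged in a scratch buffer so the second loop can
# re-read it in shifted form, and the per-(batch, channel) weight/bias gradients are
# reduced over the batch dimension on the host.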
@triton.jit
def causal_conv_backward_kernel(
x_ptr,
w_ptr,
b_ptr,
grad_out_ptr,
scratch_ptr,
dx_ptr,
dw_ptr,
db_ptr,
x_stride_b,
x_stride_d,
w_stride_d,
dw_stride_b,
dw_stride_d,
db_stride_b,
seqlen: tl.constexpr,
kernel_t: tl.constexpr,
    act: tl.constexpr,
):
b_pid = tl.program_id(0)
d_pid = tl.program_id(1)

# offset to beginning of block
x_ptr += b_pid * x_stride_b + d_pid * x_stride_d
dx_ptr += b_pid * x_stride_b + d_pid * x_stride_d
w_ptr += d_pid * w_stride_d

grad_out_ptr += b_pid * x_stride_b + d_pid * x_stride_d
scratch_ptr += b_pid * x_stride_b + d_pid * x_stride_d

b = tl.load(b_ptr + d_pid)

y_fwd_no_act = tl.zeros([seqlen], dtype=tl.float32)

x_grad = tl.zeros_like(y_fwd_no_act)

grad_out = tl.load(grad_out_ptr + tl.arange(0, seqlen))

x = tl.load(x_ptr + tl.arange(0, seqlen), cache_modifier=".ca")

for t in range(0, kernel_t):
        # Successively load x in shifts
w_load = (kernel_t - 1) - t

w = tl.load(w_ptr + w_load)

offs_t = tl.arange(0, seqlen) - t

# NOTE: Not vectorized since these are loads from global

x = tl.load(
x_ptr + offs_t,
mask=offs_t >= 0,
other=0.0,
)
y_fwd_no_act += x * w

y_fwd_no_act += b

if act:
y_grad = grad_out * dsilu(y_fwd_no_act.to(tl.float32)) # [seqlen]
else:
y_grad = grad_out.to(tl.float32)

    # NOTE: .wb is the cache modifier we want for the scratchpad, though it doesn't help much in practice
tl.store(
scratch_ptr + tl.arange(0, seqlen), y_grad.to(tl.bfloat16), cache_modifier=".wb"
)

x = tl.load(x_ptr + tl.arange(0, seqlen), cache_modifier=".ca").to(tl.float32)

db_ptr += b_pid * db_stride_b + d_pid
db_grad = tl.sum(y_grad)
tl.store(db_ptr, db_grad.to(tl.bfloat16))

dw_ptr += b_pid * dw_stride_b + d_pid * dw_stride_d

for t in range(0, kernel_t):
offs_t = t + tl.arange(0, seqlen)

# NOTE: Not vectorized since these are loads from global
        y_grad_rolled = tl.load(
            scratch_ptr + offs_t, mask=offs_t < seqlen, other=0.0
        ).to(tl.bfloat16)

w = tl.load(w_ptr + (kernel_t - 1 - t))

x_grad += y_grad_rolled * w
val = tl.sum(y_grad_rolled * x)

tl.store(dw_ptr + (kernel_t - 1 - t), val.to(tl.bfloat16))

    tl.store(dx_ptr + tl.arange(0, seqlen), x_grad.to(tl.bfloat16))


# docstring from https://github.com/Dao-AILab/causal-conv1d :)
def causal_conv1d_fn_fwd(x, weight, bias=None, act: int = 1):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
    act: 1 to apply the "silu" activation, 0 for no activation
out: (batch, dim, seqlen)
"""

y = torch.empty_like(x)

b = x.shape[0]
d = x.shape[1]

seqlen = x.shape[-1]
kernel_t = weight.shape[-1]

grid = (b, d)

if not x.is_contiguous():
x = x.contiguous()

match seqlen:
case _ if (seqlen) <= 2048:
warps = 1

case _ if (seqlen) <= 4096:
warps = 2

case _ if (seqlen) <= 8192:
warps = 4

case _:
warps = 8

causal_conv_forward_kernel[grid](
x,
weight,
bias,
y,
x.stride(0),
x.stride(1),
weight.stride(0),
seqlen,
kernel_t,
act=act,
enable_fp_fusion=True, # type: ignore
num_warps=warps, # type: ignore
)

return y


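# A minimal PyTorch reference for sanity checks (a sketch, not part of the kernel path).
# It realizes the same depthwise causal convolution via F.conv1d, assuming the
# (batch, dim, seqlen) / (dim, width) layouts documented above; the name
# causal_conv1d_ref is illustrative only.
def causal_conv1d_ref(x, weight, bias=None, act: int = 1):
    dim, width = weight.shape
    # depthwise conv (groups=dim) with left padding of width - 1, truncated back to seqlen
    y = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    y = y[..., : x.shape[-1]]
    return F.silu(y) if act == 1 else y

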
def causal_conv1d_fn_bwd(x, weight, bias, grad_out, act):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
grad_out: (batch, dim, seqlen)
act: int (act=1 for 'silu' activation)
dx: (batch, dim, seqlen)
    dw: (dim, 1, width)
dbias: (dim,)
"""

if not grad_out.is_contiguous():
grad_out = grad_out.contiguous()

assert grad_out.is_contiguous()
assert x.is_contiguous()

b = x.shape[0]
d = x.shape[1]

seqlen = x.shape[-1]
kernel_t = weight.shape[-1]

grid = (b, d)

x_grad = torch.empty_like(x)
    # per-batch gradient buffers; reduced over the batch dimension after the kernel launch
    bias_grad = torch.empty(
(b, bias.shape[0]),
dtype=bias.dtype,
device=bias.device,
)
weight_grad = torch.zeros(
size=(b, weight.shape[0], kernel_t),
dtype=weight.dtype,
device=weight.device,
)

match seqlen:
case _ if (seqlen) <= 2048:
warps = 1

case _ if (seqlen) <= 4096:
warps = 2

case _ if (seqlen) <= 8192:
warps = 4

case _:
warps = 8

grad_scratchpad = torch.empty_like(grad_out, dtype=torch.float32)

causal_conv_backward_kernel[grid](
x,
weight,
bias,
grad_out,
grad_scratchpad,
x_grad,
weight_grad,
bias_grad,
x.stride(0),
x.stride(1),
weight.stride(0),
weight_grad.stride(0),
weight_grad.stride(1),
bias_grad.stride(0),
seqlen,
kernel_t,
act=act,
enable_fp_fusion=True, # type: ignore
num_warps=warps, # type: ignore
)
return x_grad, weight_grad.sum(dim=0).unsqueeze(1), bias_grad.sum(dim=0)


class CausalConv(Function):
@staticmethod
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.bfloat16) # type: ignore
def forward(ctx, x, weight, bias, act):
ctx.save_for_backward(x, weight, bias)
ctx.act = act

return causal_conv1d_fn_fwd(x, weight, bias, act=act)

@staticmethod
@torch.amp.custom_bwd(device_type="cuda") # type: ignore
def backward(ctx, grad_output): # type: ignore
(x, weight, bias) = ctx.saved_tensors

x_grad, weight_grad, bias_grad = causal_conv1d_fn_bwd(
x, weight, bias, grad_output, act=ctx.act
)

return x_grad, weight_grad, bias_grad, None


def causal_conv(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, act: int):
return CausalConv.apply(x, weight, bias, act)


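# Example usage (a minimal sketch; shapes follow the docstrings above, the values
# are illustrative):
#
#   x = torch.randn(4, 128, 4096, device="cuda", dtype=torch.bfloat16)  # (batch, dim, seqlen)
#   w = torch.randn(128, 4, device="cuda", dtype=torch.bfloat16)        # (dim, width)
#   b = torch.zeros(128, device="cuda", dtype=torch.bfloat16)           # (dim,)
#   y = causal_conv(x, w, b, act=1)                                     # (batch, dim, seqlen)
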
# if __name__ == "__main__":
# b = 4
# d = 128
# l = 4096
# width = 4

# temporal_width = 4
# conv1d = torch.nn.Conv1d(
# in_channels=d,
# out_channels=d,
# bias=True,
# kernel_size=width,
# groups=d,
# padding=width - 1,
# )
# conv1d.cuda()

# from einops import rearrange

# x = torch.randn((b, l, d), device="cuda", dtype=torch.float32, requires_grad=True)

# with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
# x_ref = rearrange(x, "b l d_in -> b d_in l")
# out_ref = F.silu(conv1d(x_ref))[..., :l]
# out_ref = rearrange(out_ref, "b d_in l -> b l d_in")

# with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
# out_tl = causal_conv(x.mT.contiguous(), conv1d.weight, conv1d.bias, act=1)
# out_tl = out_tl.mT.contiguous()

# dy = torch.ones_like(out_ref)

# print(out_tl.shape, out_ref.shape)
# print(torch.linalg.norm(out_ref - out_tl) / torch.linalg.norm(out_ref))

# # # Backward pass - PyTorch
# out_ref.backward(dy, retain_graph=True)
# dx_torch = x.grad.clone()
# dw_torch = conv1d.weight.grad.clone()
# db_torch = conv1d.bias.grad.clone()

# # Reset grads for manual implementation
# x.grad = None
# conv1d.weight.grad = None
# conv1d.bias.grad = None

# # Backward pass - Manual implementation
# out_tl.backward(dy, retain_graph=True)

# def relative_error(x: torch.Tensor, y: torch.Tensor) -> float:
# return (torch.linalg.norm(x - y) / torch.linalg.norm(y)).item()

# dx_manual = x.grad.clone()
# dw_manual = conv1d.weight.grad.clone()
# db_manual = conv1d.bias.grad.clone()

# print(f"dx: {relative_error(dx_manual, dx_torch)}")
# print(f"dW: {relative_error(dw_manual, dw_torch)}")
# print(f"db: {relative_error(db_manual, db_torch)}")
