diff --git a/conv2d.py b/conv2d.py index 93b3142..e6740c1 100644 --- a/conv2d.py +++ b/conv2d.py @@ -243,7 +243,7 @@ def triton_conv2d_kernel( def triton_conv2d(input, filter): - n, _, h, w = input.shape + n, c, h, w = input.shape k, _, r, s = filter.shape p = h - r + 1 q = w - s + 1