
Commit 53db8ff

[Frontend] Fix tutorial PyTorch test (#307)
1 parent bc051cc commit 53db8ff

14 files changed: +213 -64 lines

.github/workflows/config.yml (+21 -2)

@@ -39,12 +39,31 @@ jobs:
       run: |
         source activate allo
         bash scripts/lint/task_lint.sh
-    - name: Testing
+    - name: Unit tests
       shell: bash
       run: |
         source activate allo
         export PATH=/root/llvm-project/build/bin:${PATH}
         export LLVM_BUILD_DIR=/root/llvm-project/build
         python3 -m pytest tests -v
+    - name: Tutorial
+      shell: bash
+      run: |
+        source activate allo
+        export LLVM_BUILD_DIR=/root/llvm-project/build
         python3 -m pytest tutorials -v
-        python3 -m pytest examples/polybench -v
+    - name: Benchmark
+      shell: bash
+      run: |
+        source activate allo
+        export LLVM_BUILD_DIR=/root/llvm-project/build
+        python3 -m pytest examples/polybench -v
+# no left space!
+# - name: PyTorch
+#   shell: bash
+#   run: |
+#     source activate allo
+#     export LLVM_BUILD_DIR=/root/llvm-project/build
+#     python3 -m pip install torch==2.5.1
+#     python3 examples/torch/toy.py
+#     python3 examples/torch/mlp.py

allo/backend/llvm.py (+1 -1)

@@ -1,6 +1,6 @@
 # Copyright Allo authors. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
-# pylint: disable=no-name-in-module, inconsistent-return-statements, too-many-function-args
+# pylint: disable=no-name-in-module, inconsistent-return-statements
 
 import os
 import ctypes

allo/frontend/pytorch.py (+54 -32)

@@ -1,8 +1,7 @@
 # Copyright Allo authors. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
-# pylint: disable=too-many-public-methods
+# pylint: disable=too-many-public-methods, too-many-instance-attributes
 
-import re
 import operator
 import inspect
 import math
@@ -23,11 +22,6 @@
 from ..customize import customize
 from ..ir.types import float32
 
-compose_mapping = {
-    "linear": nn.linear,
-    "relu": nn.relu,
-}
-
 
 def from_pytorch(
     model,
@@ -77,8 +71,7 @@ def from_pytorch(
     s = customize(code, global_vars=global_vars, enable_tensor=enable_tensor)
     # composition
     for func, idx, inst in builder.composition:
-        if func in compose_mapping:
-            s.compose(compose_mapping[func], id=idx, instantiate=inst)
+        s.compose(getattr(nn, func), id=idx, instantiate=inst)
     if verbose:
         print(s.module)
     if target == "mlir":
@@ -104,6 +97,7 @@ def __init__(self, gm, example_inputs, leaf_modules=None):
         self.subfunctions = []
         self.output = []
         self.composition = []
+        self.unique_id = {}
 
     def build(self):
         for node in self.gm.graph.nodes:
@@ -155,6 +149,13 @@ def __call__(self, node):
         self.code.append(ret)
         return ret
 
+    def get_unique_id(self, name):
+        if name not in self.unique_id:
+            self.unique_id[name] = 0
+            return 0
+        self.unique_id[name] += 1
+        return self.unique_id[name]
+
     def get_module(self, name):
         return dict(self.gm.named_modules())[name]
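
A minimal standalone sketch of how this new counter behaves (illustrative only; a plain dict stands in for self.unique_id): the first occurrence of an op name gets id 0, later occurrences 1, 2, and so on. Each id is then passed to s.compose(..., id=idx, ...) so repeated layers of the same kind get distinct kernel instances.

ids = {}

def get_unique_id(name):
    # first use of a name -> 0, later uses -> 1, 2, ...
    if name not in ids:
        ids[name] = 0
        return 0
    ids[name] += 1
    return ids[name]

assert [get_unique_id("linear"), get_unique_id("linear"), get_unique_id("relu")] == [0, 1, 0]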

@@ -180,8 +181,13 @@ def build_call_module(self, node):
             raise NotImplementedError("Unsupported module")
         if op == "linear":
             bias = True if module.bias is not None else None
-            return getattr(self, "build_linear")(node, bias)
-        return getattr(self, f"build_{op}")(node)
+            res = getattr(self, "build_linear")(node, bias)
+        else:
+            res = getattr(self, f"build_{op}")(node)
+        # append shape after the operation
+        if "tensor_meta" in node.meta:
+            res += f' # shape: {str(tuple(node.meta["tensor_meta"].shape))}'
+        return res
 
     def build_call_function(self, node):
         opcls = {
@@ -203,11 +209,12 @@ def build_call_function(self, node):
             torch.cat: "concat",
         }.get(node.target)
         # Only nodes with shape need to be built.
-        return (
-            getattr(self, f"build_{opcls}")(node)
-            if "tensor_meta" in node.meta
-            else None
-        )
+        if "tensor_meta" in node.meta:
+            res = getattr(self, f"build_{opcls}")(node)
+            # append shape after the operation
+            res += f' # shape: {str(tuple(node.meta["tensor_meta"].shape))}'
+            return res
+        return None
 
     def build_call_method(self, node):
         if node.target == "contiguous":
@@ -298,29 +305,44 @@ def build_softmax(self, node):
 
     def build_relu(self, node):
         inp = get_var_name(node.args[0])
-        bs, n = tuple(node.meta["tensor_meta"].shape)
-        match = re.search(r"\d+$", str(node.target).replace(".", "_"))
-        self.composition.append(
-            ("relu", match.group() if match else None, [float32, bs, n])
-        )
-        return f"{node.name} = nn.relu[float32, {bs}, {n}]({inp})"
+        shape = tuple(node.meta["tensor_meta"].shape)
+        name_id = self.get_unique_id("relu")
+        if len(shape) == 2:
+            n, d = shape
+            self.composition.append(("relu2d", name_id, [float32, n, d]))
+            return f'{node.name} = nn.relu2d[float32, {n}, {d}, "{name_id}"]({inp})'
+        if len(shape) == 4:
+            n, c, h, w = shape
+            self.composition.append(("relu4d", name_id, [float32, n, c, h, w]))
+            return f'{node.name} = nn.relu4d[float32, {n}, {c}, {h}, {w}, "{name_id}"]({inp})'
+        raise NotImplementedError("Unsupported shape for relu")
 
     def build_linear(self, node, bias):
         target_name = node.target.replace(".", "_")
         inp = get_var_name(node.args[0])
         weight = get_var_name(target_name + "_weight")
         if bias:
             bias = get_var_name(target_name + "_bias")
-        # output shape: bs * n
-        bs, n = tuple(node.meta["tensor_meta"].shape)
-        _, m = self.named_params[f"{str(node.target)}.weight"].shape
-        match = re.search(r"\d+$", target_name)
-        name = f', "{match.group()}"' if match else ""
-        # bs*m x (n*m)^T + (n*1) = bs*n
-        self.composition.append(
-            ("linear", match.group() if match else None, [float32, bs, n, m])
-        )
-        return f"{node.name} = nn.linear[float32, {bs}, {n}, {m}{name}]({inp}, {weight}, {bias})"
+        shape = tuple(node.meta["tensor_meta"].shape)
+        name_id = self.get_unique_id("linear")
+        if len(shape) == 2:
+            n, d = shape
+            _, m = self.named_params[f"{str(node.target)}.weight"].shape
+            # n*m x (m*d)^T + (n*1) = n*d
+            self.composition.append(("linear2d", name_id, [float32, n, d, m]))
+            return f'{node.name} = nn.linear2d[float32, {n}, {d}, {m}, "{name_id}"]({inp}, {weight}, {bias})'
+        if len(shape) == 3:
+            bs, l, m = shape
+            _, d = self.named_params[f"{str(node.target)}.weight"].shape
+            self.composition.append(
+                (
+                    "linear3d",
+                    name_id,
+                    [float32, bs, l, d, m],
+                )
+            )
+            return f'{node.name} = nn.linear3d[float32, {bs}, {l}, {d}, {m}, "{name_id}"]({inp}, {weight}, {bias})'
+        raise NotImplementedError("Unsupported shape for linear")
         return f"{node.name} = dsl.linear({inp}, {weight})"
 
     def build_gelu(self, node):
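
Taken together, the frontend now emits shape-specialized kernels (nn.linear2d / nn.relu2d for 2-D activations, nn.linear3d / nn.relu4d for 3-D and 4-D ones), each tagged with a unique instance id. A minimal usage sketch follows; the TinyMLP model and tensor sizes are made up, it assumes ReLU modules are traced through build_relu as above, and the Allo line quoted in the comment is only indicative.

import numpy as np
import torch
import torch.nn as torch_nn

import allo


class TinyMLP(torch_nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch_nn.Linear(8, 16)
        self.relu = torch_nn.ReLU()
        self.fc2 = torch_nn.Linear(16, 4)

    def forward(self, x):
        # (batch, features) tensors -> lowered to nn.linear2d / nn.relu2d
        return self.fc2(self.relu(self.fc1(x)))


model = TinyMLP().eval()
example_inputs = [torch.randn(2, 8)]
# verbose=True prints the generated Allo code, roughly:
#   fc1 = nn.linear2d[float32, 2, 16, 8, "0"](x, fc1_weight, fc1_bias)  # shape: (2, 16)
llvm_mod = allo.frontend.from_pytorch(model, example_inputs=example_inputs, verbose=True)
golden = model(*example_inputs)
res = llvm_mod(*[x.detach().numpy() for x in example_inputs])
np.testing.assert_allclose(res, golden.detach().numpy(), atol=1e-3)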

allo/library/__init__.py (+12 -6)

@@ -16,10 +16,14 @@
 )
 
 from .nn import (
-    linear,
-    schedule_linear,
-    relu,
-    schedule_relu,
+    linear2d,
+    linear3d,
+    schedule_linear2d,
+    schedule_linear3d,
+    relu2d,
+    relu4d,
+    schedule_relu2d,
+    schedule_relu4d,
     softmax,
     schedule_softmax,
     layer_norm,
@@ -42,8 +46,10 @@
 
 KERNEL2SCHEDULE.update(
     {
-        linear: schedule_linear,
-        relu: schedule_relu,
+        linear2d: schedule_linear2d,
+        linear3d: schedule_linear3d,
+        relu2d: schedule_relu2d,
+        relu4d: schedule_relu4d,
         softmax: schedule_softmax,
         layer_norm: schedule_layernorm,
         GeLU: schedule_gelu,

allo/library/nn.py (+50 -11)

@@ -6,7 +6,7 @@
 from .systolic import systolic
 
 
-def linear[Ty, M, N, K](X: "Ty[M, K]", W: "Ty[N, K]", b: "Ty[N]") -> "Ty[M, N]":
+def linear2d[Ty, M, N, K](X: "Ty[M, K]", W: "Ty[N, K]", b: "Ty[N]") -> "Ty[M, N]":
     # https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
     Z: Ty[M, N]
     buf: Ty[N]
@@ -23,22 +23,61 @@ def linear[Ty, M, N, K](X: "Ty[M, K]", W: "Ty[N, K]", b: "Ty[N]") -> "Ty[M, N]":
     return Z
 
 
-def schedule_linear(s):
-    s.pipeline("linear:j")
-    s.pipeline("linear:j_init")
-    s.pipeline("linear:j_back")
+def schedule_linear2d(s):
+    s.pipeline("linear2d:j")
+    s.pipeline("linear2d:j_init")
+    s.pipeline("linear2d:j_back")
     return s
 
 
-def relu[Ty, L, D](X: "Ty[L, D]") -> "Ty[L, D]":
-    Z: Ty[L, D]
-    for i, j in dsl.grid(L, D):
-        Z[i, j] = max(0.0, X[i, j])
+def linear3d[
+    Ty, B, L, D, M
+](X: "Ty[B, L, D]", W: "Ty[M, D]", bias: "Ty[M]") -> "Ty[B, L, M]":
+    # https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
+    Z: Ty[B, L, M]
+    buf: Ty[M]
+    for b in range(B):
+        for i in range(L):
+            for j_init in range(M):
+                buf[j_init] = 0
+            for k in range(D):
+                # reorder reduction loop outside, and pipeline
+                x: Ty = X[b, i, k]
+                for j in range(M):
+                    buf[j] += x * W[j, k]
+            for j_back in range(M):
+                Z[b, i, j_back] = buf[j_back] + bias[j_back]
+    return Z
+
+
+def schedule_linear3d(s):
+    s.pipeline("linear3d:j")
+    s.pipeline("linear3d:j_init")
+    s.pipeline("linear3d:j_back")
+    return s
+
+
+def relu2d[Ty, H, W](X: "Ty[H, W]") -> "Ty[H, W]":
+    Z: Ty[H, W]
+    for h, w in dsl.grid(H, W):
+        Z[h, w] = max(0.0, X[h, w])
+    return Z
+
+
+def schedule_relu2d(s):
+    s.pipeline("relu2d:w")
+    return s
+
+
+def relu4d[Ty, N, C, H, W](X: "Ty[N, C, H, W]") -> "Ty[N, C, H, W]":
+    Z: Ty[N, C, H, W]
+    for n, c, h, w in dsl.grid(N, C, H, W):
+        Z[n, c, h, w] = max(0.0, X[n, c, h, w])
     return Z
 
 
-def schedule_relu(s):
-    s.pipeline("relu:j")
+def schedule_relu4d(s):
+    s.pipeline("relu4d:w")
     return s
 
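
For reference, the new linear3d kernel has the same semantics as applying a PyTorch Linear layer to a batched 3-D input; a NumPy sketch of what it computes (shapes chosen arbitrarily for illustration):

import numpy as np

B, L, D, M = 2, 3, 4, 5                        # batch, sequence length, in-features, out-features
X = np.random.rand(B, L, D).astype(np.float32)
W = np.random.rand(M, D).astype(np.float32)    # weight stored as (out_features, in_features)
bias = np.random.rand(M).astype(np.float32)

# Z[b, i, j] = sum_k X[b, i, k] * W[j, k] + bias[j], i.e. X @ W.T + bias
Z = X @ W.T + bias
assert Z.shape == (B, L, M)
# relu2d / relu4d are elementwise max(0, x), i.e. np.maximum(X, 0.0) on 2-D / 4-D arrays.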

File renamed without changes.
File renamed without changes.

examples/torch/gpt2.py (+17 -6)

@@ -108,11 +108,17 @@ def forward(self, x):
         return output
 
 
-vocab_size = 50257
-n_embd = 768
-n_head = 12
-n_layers = 12
-n_seq = 1024
+# Large size
+# vocab_size = 50257
+# n_embd = 768
+# n_head = 12
+# n_layers = 12
+# n_seq = 1024
+vocab_size = 2
+n_embd = 4
+n_head = 2
+n_layers = 1
+n_seq = 4
 batch_size = 2
 
 module = GPT2(vocab_size, n_embd, n_head, n_layers).eval()
@@ -121,10 +127,15 @@ def forward(self, x):
 llvm_mod = allo.frontend.from_pytorch(
     module,
     example_inputs=example_inputs,
-    verbose=False,
+    verbose=True,
 )
 
 golden = module(*example_inputs)
 np_inputs = [x.detach().numpy() for x in example_inputs]
 res = llvm_mod(*np_inputs)
 np.testing.assert_allclose(res, golden.detach().numpy(), atol=1e-3)
+print("Test passed!")
+
+# generate HLS module
+mod = allo.frontend.from_pytorch(module, example_inputs=example_inputs, target="vhls")
+print(mod.hls_code)

mlir/include/allo/Dialect/Visitor.h (+2 -1)

@@ -43,7 +43,7 @@ class HLSCppVisitorBase {
           // Memref-related statements.
           memref::AllocOp, memref::AllocaOp, memref::LoadOp, memref::StoreOp,
           memref::GetGlobalOp, allo::GetGlobalFixedOp, memref::GlobalOp,
-          memref::DeallocOp, memref::DmaStartOp, memref::DmaWaitOp,
+          memref::DeallocOp, memref::DmaStartOp, memref::DmaWaitOp, memref::ReshapeOp,
          memref::ViewOp, memref::SubViewOp, memref::ReinterpretCastOp,
           memref::AtomicRMWOp,
           // Tensor-related statements.
@@ -144,6 +144,7 @@ class HLSCppVisitorBase {
   HANDLE(memref::ViewOp);
   HANDLE(memref::SubViewOp);
   HANDLE(memref::ReinterpretCastOp);
+  HANDLE(memref::ReshapeOp);
 
   // Tensor-related statements.
   HANDLE(tensor::ExtractOp);
