-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptimizer.py
238 lines (188 loc) · 7.28 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# -*- coding: utf-8 -*-
from sys import argv
from re import sub
from preprocessor import process_file, connect_lines, split_text
# When true, replace_jr_to_jp() rewrites nearby `jp` instructions to `jr`, and
# replace_instructions() registers two extra IF/ELSE patterns in `replaces`.
auto_replace_jp_to_jr = True
# Peephole rewrites: regex pattern -> replacement text. replace_instructions()
# applies each entry with re.sub to the whole listing joined into one string,
# in dict insertion order.
replaces = {
# `ld a, 0` -> `sub a`.
# NOTE(review): the pattern has no trailing boundary, so it also matches the
# start of e.g. `ld a, 01`; and `sub a` presumably affects flags where
# `ld a, 0` does not - confirm both are acceptable.
r'ld\s+a,\s+0': 'sub a',
# `dec b` followed by `jr/jp nz, <label>` on the next line -> `djnz <label>`.
r'dec\s+b\n\s+j[rp]\s+nz,\s+(\w+)': r'djnz \1',
# `call <label>` immediately followed by `ret` -> `jp <label>`.
r'call\s+(\w+)\n\s+ret': r'jp \1',
# `jp <label>` immediately followed by the line `<label>:` -> drop the jump.
r'\s+jp\s+(\w+)\n\1:': r'\n\1:',
# `xor 255` -> `cpl`.
r'xor\s+255': 'cpl',
# Two adjacent 8-bit immediate loads into the halves of a register pair are
# merged into a single 16-bit load (both orderings, for de / hl / bc).
r'ld\s+d,\s+(\d+)\s+ld\s+e,\s+(\d+)': r'ld de, 256*\1 + \2',
r'ld\s+e,\s+(\d+)\s+ld\s+d,\s+(\d+)': r'ld de, 256*\2 + \1',
r'ld\s+h,\s+(\d+)\s+ld\s+l,\s+(\d+)': r'ld hl, 256*\1 + \2',
r'ld\s+l,\s+(\d+)\s+ld\s+h,\s+(\d+)': r'ld hl, 256*\2 + \1',
r'ld\s+b,\s+(\d+)\s+ld\s+c,\s+(\d+)': r'ld bc, 256*\1 + \2',
r'ld\s+c,\s+(\d+)\s+ld\s+b,\s+(\d+)': r'ld bc, 256*\2 + \1',
}
def optimize_z80_asm(lines):
    """Run every optimization pass over a Z80 assembly listing.

    Args:
        lines: list of source lines of the assembly listing.

    Returns:
        The optimized listing as produced by ``replace_instructions``.
    """
    # strip_comments builds a fresh list, so the in-place passes below work
    # on that copy rather than on the caller's list.
    working = strip_comments(lines)
    # The unused-function sweep is run twice; presumably the second sweep
    # catches functions that were only called from code removed by the first.
    for _ in range(2):
        delete_unused_functions(working)
    auto_inline(working)
    optimize_jumps(working)
    return replace_instructions(working)
def auto_inline(lines):
    """Inline suitable functions at their call sites, editing *lines* in place.

    Call statistics are gathered, narrowed by ``filter_condition_calls``,
    pruned by ``find_functions_to_inline``, and the survivors are expanded
    by ``replace_calls``.
    """
    candidates = filter_condition_calls(lines, collect_calls_data(lines))
    find_functions_to_inline(candidates)
    replace_calls(lines, candidates)
def collect_calls_data(lines):
    """Build call statistics for every label targeted by a ``call``.

    Returns:
        dict mapping label -> ``[number of calls, body length]``, where the
        body length is filled in by ``count_len_of_functions``.
    """
    label_data = {}
    for tokens in (raw.strip().split() for raw in lines):
        update_label_data(label_data, tokens)
    count_len_of_functions(lines, label_data)
    return label_data
def update_label_data(label_data, splited_line):
    """Record one ``call`` instruction in *label_data*.

    Lines that are not a ``call`` (including blank lines, which split into an
    empty list, and a bare ``call`` with no operand) are ignored.

    Args:
        label_data: dict mapping label -> ``[number of calls, body length]``;
            updated in place.
        splited_line: whitespace-split tokens of one source line.
    """
    if len(splited_line) < 2 or splited_line[0] != 'call':
        return
    operand = splited_line[1]
    if ',' not in operand:
        target = operand
    else:
        # Conditional call: either `call nz, label` (label is the next token)
        # or `call nz,label` (label follows the comma in the same token).
        target = operand.rpartition(',')[2]
        if not target:
            if len(splited_line) < 3:
                return
            target = splited_line[2]
    # New labels start at [0, 0]: no calls counted yet, body length unknown.
    entry = label_data.setdefault(target, [0, 0])
    entry[0] += 1
def find_functions_to_inline(label_data):
    """Drop labels that should not be inlined from *label_data* (in place).

    A label is dropped when it is called more than once and its body is
    longer than three lines; everything left is an inlining candidate.
    """
    rejected = [
        name for name, stats in label_data.items()
        if stats[0] > 1 and stats[1] > 3
    ]
    delete_labels(label_data, rejected)
def delete_labels(label_data, labels_to_delete):
    """Remove every key listed in *labels_to_delete* from *label_data*.

    Raises:
        KeyError: if a listed label is not present in *label_data*.
    """
    for name in labels_to_delete:
        label_data.pop(name)
def count_len_of_functions(lines, label_data):
    """Fill in the body length of every label in *label_data* (in place).

    The body length is the number of lines between the first line starting
    with ``<label>:`` and the first following line that is exactly ``ret``
    once stripped. Labels that are not found, or whose body never reaches a
    ``ret``, keep their current length.
    """
    for name, stats in label_data.items():
        header = name + ':'
        start = next(
            (idx for idx, text in enumerate(lines) if text.startswith(header)),
            None,
        )
        if start is None:
            continue
        body_len = 0
        for text in lines[start + 1:]:
            if text.strip() == 'ret':
                stats[1] = body_len
                break
            body_len += 1
def replace_calls(lines, label_data):
    """Inline every label in *label_data* that has a known, non-empty body.

    For each such label the body is copied, pasted over its call sites, and
    the original definition is then removed. *lines* is edited in place.
    """
    for label, stats in label_data.items():
        body_len = stats[1]  # len of func
        if body_len <= 0:
            continue
        body = copy_func_body(lines, label, body_len)
        insert_func_bodies(lines, label, body, stats[0])
        delete_function(lines, label)
def copy_func_body(lines, label, len_of_func_body):
    """Return a copy of the lines that follow the ``<label>:`` line.

    Args:
        lines: the assembly listing.
        label: name of the function to look up.
        len_of_func_body: number of lines to copy after the label line.

    Returns:
        A new list with the body lines, or ``None`` when no line starts with
        ``<label>:``.
    """
    header = label + ':'
    start = next(
        (idx for idx, text in enumerate(lines) if text.startswith(header)),
        None,
    )
    if start is None:
        return None
    first = start + 1
    return lines[first:first + len_of_func_body]
def insert_func_bodies(lines, label, foo_lines, num_of_occurencies):
    """Replace up to *num_of_occurencies* calls to *label* with *foo_lines*.

    Each round rescans *lines* from the top and expands the first matching
    call found by ``insert_body``; *lines* is edited in place.
    """
    for _ in range(num_of_occurencies):
        position = 0
        while position < len(lines):
            if insert_body(lines, lines[position], position, label, foo_lines):
                break
            position += 1
def insert_body(lines, line, n_line, label, foo_lines):
    """Replace an unconditional ``call <label>`` line with a function body.

    Args:
        lines: the assembly listing; edited in place on a match.
        line: the text of the line being inspected.
        n_line: index of *line* inside *lines*.
        label: the call target to look for.
        foo_lines: body lines to paste in place of the call.

    Returns:
        ``True`` if *line* was a matching call and has been replaced,
        ``False`` otherwise (including blank lines and a bare ``call``).
    """
    tokens = line.split()
    # Blank lines and operand-less instructions cannot be a matching call.
    if len(tokens) < 2 or tokens[0] != 'call' or tokens[1] != label:
        return False
    del lines[n_line]
    lines[n_line:n_line] = foo_lines
    return True
def delete_function(lines, label):
    """Remove the definition of *label* from *lines* in place.

    Everything from the first line starting with ``<label>:`` up to and
    including the first following line that strips to ``ret`` is deleted.
    If no ``ret`` follows, deletion runs to the end of the listing; if the
    label is absent nothing changes.
    """
    header = label + ':'
    start = next(
        (idx for idx, text in enumerate(lines) if text.startswith(header)),
        None,
    )
    if start is None:
        return
    stop = len(lines)
    for idx in range(start + 1, len(lines)):
        if lines[idx].strip() == 'ret':
            stop = idx + 1
            break
    del lines[start:stop]
def strip_comments(lines):
    """Return a new list without the lines that are entirely a comment.

    A line is dropped when its first non-whitespace character is ``;``.
    Trailing comments on instruction lines are left alone, and the input
    list is not modified.
    """
    return [text for text in lines if not text.lstrip().startswith(';')]
def delete_unused_functions(lines):
    """Delete every function whose label is never the target of a ``call``.

    *lines* is edited in place.
    """
    defined = collect_labels(lines)
    called = collect_calls_data(lines)
    delete_functions(lines, find_unused(defined, called))
def find_unused(labels, called_labels):
    """Return the entries of *labels* that do not appear in *called_labels*.

    Order and duplicates of *labels* are preserved.
    """
    return [name for name in labels if name not in called_labels]
def delete_functions(lines, unused_labels):
    """Remove the definition of each label in *unused_labels* from *lines*."""
    for name in unused_labels:
        delete_function(lines, name)
def collect_labels(lines):
    """Return the names of the function labels defined in *lines*.

    A line is taken as a function label when it does not start with a space,
    its second-to-last character is ``:`` (i.e. ``name:`` plus a newline),
    it is not an auto-generated label or ``main`` (``is_auto_label``), and it
    is not followed by a data directive (``is_data_label``).

    Lines shorter than two characters (such as a lone newline) are skipped
    instead of raising IndexError, and a label on the very last line is not
    passed to the data-label lookahead, which reads the following line.
    """
    labels = []
    last_index = len(lines) - 1
    for i, line in enumerate(lines):
        if len(line) < 2:
            continue
        if line[0] == ' ' or line[-2] != ':':
            continue
        if is_auto_label(line):
            continue
        if i < last_index and is_data_label(lines, i):
            continue
        labels.append(line.strip()[:-1])
    return labels
def is_auto_label(line):
    """Tell whether *line* is ``main:`` or a label of the form ``l<digits>:``.

    The line is expected to carry its trailing newline: the numeric part is
    taken as everything between the leading ``l`` and the last two characters.
    """
    if line == 'main:\n':
        return True
    return line[0] == 'l' and line[1:-2].isnumeric()
def is_data_label(lines, i):
    """Tell whether the label at ``lines[i]`` introduces data, not code.

    The label counts as a data label when the first token of the following
    line is one of ``db``, ``dw``, ``dd``, ``dh`` or ``ds``.

    Returns:
        ``True`` for a data label; ``False`` otherwise, including when there
        is no following line or the following line is blank.
    """
    if i + 1 >= len(lines):
        return False
    tokens = lines[i + 1].split()
    return bool(tokens) and tokens[0] in ('db', 'dw', 'dd', 'dh', 'ds')
def filter_condition_calls(lines, label_data):
    """Keep only the labels that are safe to inline.

    A label is kept when it is the target of at least one unconditional
    ``call <label>`` and of no conditional ``call <cc>, <label>``. Inlining
    removes the function definition afterwards, so a label that still has a
    conditional call site must not be inlined: that call would be left
    pointing at a label that no longer exists.

    Args:
        lines: the assembly listing.
        label_data: dict mapping label -> ``[number of calls, body length]``.

    Returns:
        A new dict with the retained entries (same value objects, same
        order). A warning is printed when any label was filtered out.
    """
    direct_targets = set()
    conditional_targets = set()
    # Single pass over the listing instead of rescanning it once per label.
    for line in lines:
        tokens = line.split()
        if len(tokens) < 2 or tokens[0] != 'call':
            continue
        operand = tokens[1]
        if ',' not in operand:
            direct_targets.add(operand)
            continue
        # `call nz, label` or `call nz,label`
        target = operand.rpartition(',')[2]
        if not target and len(tokens) > 2:
            target = tokens[2]
        conditional_targets.add(target)
    new_label_data = {
        label: data for label, data in label_data.items()
        if label in direct_targets and label not in conditional_targets
    }
    if len(label_data) != len(new_label_data):
        print("Optimizer warning: some functions were not inlined")
    return new_label_data
def optimize_jumps(lines):
    """Collapse a conditional jump over an unconditional jump, in place.

    The sequence::

        jp <cc>, A
        jp B
        A:

    is rewritten to ``jp <inverted cc>, B`` followed by ``A:``. Supported
    conditions are ``z``/``nz`` and ``c``/``nc``. Afterwards the listing is
    handed to ``replace_jr_to_jp`` together with the condition table.

    The third line must be exactly the label ``A:``; a plain prefix test
    would also accept e.g. ``l10:`` for target ``l1`` and retarget the jump
    wrongly. Blank lines are skipped rather than raising IndexError.
    """
    repl = {'z,': 'nz,', 'nz,': 'z,', 'c,': 'nc,', 'nc,': 'c,'}
    lines_to_delete = []
    # The pattern spans three lines, so the last two cannot start a match.
    for i in range(len(lines) - 2):
        cur = lines[i].split()
        if len(cur) != 3 or cur[0] != 'jp' or cur[1] not in repl:
            continue
        nxt = lines[i + 1].split()
        if len(nxt) != 2 or nxt[0] != 'jp':
            continue
        if not lines[i + 2].startswith(cur[2] + ':'):
            continue
        lines_to_delete.append(i + 1)
        lines[i] = ' ' + ' '.join(('jp', repl[cur[1]], nxt[1])) + '\n'
    # Delete from the bottom up so earlier indices stay valid.
    for i in sorted(lines_to_delete, reverse=True):
        del lines[i]
    replace_jr_to_jp(lines, repl)
def replace_jr_to_jp(lines, repl):
    """Rewrite ``jp`` to ``jr`` when the target label is close by, in place.

    Despite its name, this turns ``jp`` into ``jr``. It does nothing unless
    the module flag ``auto_replace_jp_to_jr`` is true. A jump is rewritten
    when its target label is defined within 29 lines before or after it.
    Three-token jumps are only considered when their condition is a key of
    *repl*.

    Args:
        lines: the assembly listing; edited in place.
        repl: mapping whose keys are the condition tokens (with trailing
            comma) eligible for rewriting.
    """
    if not auto_replace_jp_to_jr:
        return
    for i, line in enumerate(lines):
        splt = line.split()
        # Blank lines and operand-less `jp` have no target to look for.
        if len(splt) < 2 or splt[0] != 'jp':
            continue
        if len(splt) == 3 and splt[1] not in repl:
            continue
        # Match the full label definition, not just a prefix: target `l1`
        # must not be satisfied by a nearby `l10:`.
        target = splt[-1] + ':'
        for j in range(1, 30):
            ahead = i + j < len(lines) and lines[i + j].startswith(target)
            # `j <= i` keeps i - j >= 0 while still allowing line 0.
            behind = j <= i and lines[i - j].startswith(target)
            if ahead or behind:
                lines[i] = ' jr ' + ' '.join(splt[1:]) + '\n'
                break
def replace_instructions(lines):
    """Apply every peephole rewrite in ``replaces`` to the listing.

    When ``auto_replace_jp_to_jr`` is true, two extra patterns are first
    stored in the module-level ``replaces`` dict; they wrap jumps to
    ``l<digits>`` labels in an ``IF``/``ELSE``/``ENDIF`` block that selects
    ``jr`` or ``jp``. The listing is joined with ``connect_lines``, each
    pattern is substituted in dict order, and the text is split again with
    ``split_text``.

    Returns:
        Whatever ``split_text`` returns for the rewritten text.
    """
    if auto_replace_jp_to_jr:
        replaces[r'jp\s+(l\d+)'] = (
            r'IF \1 - $ < 127\n jr \1\n'
            r' ELSE\n jp \1\n ENDIF'
        )
        replaces[r'jp\s+([n]?[zc],\s+)(l\d+)'] = (
            r'IF \2 - $ < 127\n'
            r' jr \1\2\n ELSE\n jp \1\2\n ENDIF'
        )
    text = connect_lines(lines)
    for pattern, replacement in replaces.items():
        text = sub(pattern, replacement, text)
    return split_text(text)
if __name__ == '__main__':
    # Script entry point: the first two command-line arguments are handed to
    # process_file together with the optimizer callback.
    # NOTE(review): presumably these are the input and output paths - confirm
    # against preprocessor.process_file.
    first_arg, second_arg = argv[1], argv[2]
    process_file(first_arg, second_arg, optimize_z80_asm)