Skip to content

Commit 3a86afd

Browse files
committed
Improvements
1 parent 1290c64 commit 3a86afd

File tree

4 files changed

+135
-92
lines changed

4 files changed

+135
-92
lines changed

loopy/schedule/tools.py

Lines changed: 96 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,63 @@ def supporting_temporary_names(
178178
return frozenset(result)
179179

180180

181+
def _get_temporaries_accessed_in_schedule(
            kernel: LoopKernel,
            sched_idx_lower_bound: int,
            sched_idx_upper_bound: int
        ) -> frozenset[str]:
    """
    Return the names of all temporaries read or written by the subkernels
    called in the schedule range
    ``[sched_idx_lower_bound, sched_idx_upper_bound)``.

    Only schedule items outside of subkernels are inspected. Each
    :class:`~loopy.schedule.CallKernel` contributes the temporaries read
    and written in that subkernel; the scan then resumes after the end of
    that subkernel's block, so the items making up the subkernel body are
    never looked at individually.

    :arg sched_idx_lower_bound: index of the first schedule item inspected.
    :arg sched_idx_upper_bound: index one past the last schedule item
        inspected (exclusive).
    :raises NotImplementedError: if an item that is neither a kernel call
        nor a loop entry/exit is found outside of a subkernel.
    """
    from loopy.schedule import CallKernel, EnterLoop, LeaveLoop

    linearization = kernel.linearization
    assert linearization is not None

    block_boundaries = get_block_boundaries(linearization)

    accessed: set[str] = set()
    sched_index = sched_idx_lower_bound
    while sched_index < sched_idx_upper_bound:
        sched_item = linearization[sched_index]

        if isinstance(sched_item, CallKernel):
            accessed |= temporaries_written_in_subkernel(
                kernel, sched_item.kernel_name)
            accessed |= temporaries_read_in_subkernel(
                kernel, sched_item.kernel_name)

            # Skip the whole subkernel block. Walking into it would run
            # into its contents (and the item closing the block), none of
            # which are host-level items, and wrongly trip the error below.
            sched_index = block_boundaries[sched_index] + 1

        elif isinstance(sched_item, (EnterLoop, LeaveLoop)):
            # ignore further outside-kernel loops
            sched_index += 1

        else:
            raise NotImplementedError("kernel with non-CallKernel outermost")

    return frozenset(accessed)
210+
211+
212+
def _map_to_base_storage(kernel: LoopKernel, tv_names: Set[str]) -> Set[str]:
    """
    Map each temporary name in *tv_names* to the name of the storage that
    ultimately backs it.

    For every name, the ``base_storage`` attribute of the corresponding
    entry in ``kernel.temporary_variables`` is followed repeatedly until an
    entry whose ``base_storage`` is *None* is reached; that final name is
    put into the result. Names without base storage map to themselves.

    :returns: a newly created :class:`set` of the resolved names.
    """

    def resolve(name: str) -> str:
        # Follow the chain of base_storage links to its end.
        base = kernel.temporary_variables[name].base_storage
        while base is not None:
            name = base
            base = kernel.temporary_variables[name].base_storage
        return name

    return {resolve(name) for name in tv_names}
225+
226+
227+
@memoize_on_first_arg
181228
def get_sched_index_to_first_and_last_used(
182229
kernel: LoopKernel
183-
) -> tuple[dict[int, frozenset[str]], dict[int, frozenset[str]]]:
230+
) -> tuple[Mapping[int, Set[str]], Mapping[int, Set[str]]]:
184231
"""
185232
Returns the tuple (first_used, last_used), where first_used is
186-
a dict such that first_used[sched_index] is the set of all temporary
233+
a dict such that first_used[sched_index] is the set of all global temporary
187234
variable names first used at sched_index.
188235
189-
Likewise, last_used[sched_index] is the set of all temporary variable names
190-
last used at sched_index.
236+
Likewise, last_used[sched_index] is the set of all global temporary
237+
variable names last used at sched_index.
191238
"""
192239
from loopy.kernel.data import AddressSpace
193240
from loopy.schedule import CallKernel, EnterLoop
@@ -200,94 +247,65 @@ def get_sched_index_to_first_and_last_used(
200247
)
201248

202249
# Collapse into blocks
203-
def get_temporaries_in_bounds(
204-
linearization: Sequence[ScheduleItem],
205-
lower_bound: int,
206-
upper_bound: int
207-
) -> frozenset[str]:
208-
temporaries: frozenset[str] = frozenset()
209-
for sched_index in range(lower_bound, upper_bound+1):
210-
sched_item = linearization[sched_index]
211-
if isinstance(sched_item, CallKernel):
212-
temporaries = (
213-
temporaries_written_in_subkernel(kernel, sched_item.kernel_name)
214-
| temporaries_read_in_subkernel(
215-
kernel, sched_item.kernel_name
216-
)
217-
| (temporaries)
218-
)
219-
return temporaries & global_temporaries
220-
221250
block_boundaries = get_block_boundaries(kernel.linearization)
222251

223-
bounds: dict[int, frozenset[str]] = {}
252+
tvs_accessed_at: dict[int, frozenset[str]] = {}
224253
sched_index = 0
225254
while sched_index < len(kernel.linearization):
226255
sched_item = kernel.linearization[sched_index]
227-
if isinstance(sched_item, (EnterLoop, CallKernel)):
228-
if isinstance(sched_item, CallKernel):
229-
block_end = block_boundaries[sched_index]
230-
accessed_temporaries = (
231-
temporaries_written_in_subkernel(kernel, sched_item.kernel_name)
232-
| temporaries_read_in_subkernel(
233-
kernel, sched_item.kernel_name
234-
)
235-
)
236-
else:
237-
block_end = block_boundaries[sched_index]
238-
accessed_temporaries = get_temporaries_in_bounds(
239-
kernel.linearization, sched_index, block_end
256+
if isinstance(sched_item, CallKernel):
257+
block_end = block_boundaries[sched_index]
258+
tvs_accessed_at[sched_index] = (
259+
temporaries_written_in_subkernel(kernel, sched_item.kernel_name)
260+
| temporaries_read_in_subkernel(
261+
kernel, sched_item.kernel_name
240262
)
241-
bounds[sched_index] = accessed_temporaries
263+
) & global_temporaries
264+
242265
sched_index = block_end + 1
243-
else:
244-
sched_index += 1
245266

246-
def update_seen_storage_vars(
247-
seen_sv: set[str],
248-
new_temp_variables: frozenset[str]
249-
) -> frozenset[str]:
250-
new_storage_variables: set[str] = set()
251-
past_sv = frozenset(seen_sv)
252-
for new_tv_name in new_temp_variables:
253-
new_tv = kernel.temporary_variables[new_tv_name]
254-
if new_tv.base_storage is None:
255-
storage_var = new_tv_name
256-
else:
257-
storage_var = new_tv.base_storage
258-
new_storage_variables.add(storage_var)
259-
seen_sv.add(storage_var)
260-
new_sv = frozenset(new_storage_variables)
261-
return new_sv - past_sv
262-
# forward pass for first accesses
263-
first_accesses: dict[int, frozenset[str]] = {}
264-
seen_storage_variables: set[str] = set()
267+
elif isinstance(sched_item, EnterLoop):
268+
block_end = block_boundaries[sched_index]
269+
tvs_accessed_at[sched_index] = _get_temporaries_accessed_in_schedule(
270+
kernel, sched_index, block_end+1
271+
) & global_temporaries
272+
273+
sched_index = block_end + 1
274+
275+
else:
276+
raise ValueError(
277+
f"unexpected schedule item at outermost level: {type(sched_item)}")
278+
279+
storage_vars_accessed_at = {
280+
sched_index: _map_to_base_storage(kernel, accessed)
281+
for sched_index, accessed in tvs_accessed_at.items()
282+
}
283+
del tvs_accessed_at
284+
285+
# forward pass for first_accesses
286+
first_accesses: dict[int, Set[str]] = {}
287+
seen_storage_vars: set[str] = set()
265288
for sched_index in range(0, len(kernel.linearization)):
266-
if (sched_index not in bounds):
267-
continue
268-
sched_item = kernel.linearization[sched_index]
269-
new_temporary_variables = bounds[sched_index]
270-
new_storage_variables = update_seen_storage_vars(
271-
seen_storage_variables, new_temporary_variables
272-
)
289+
accessed = storage_vars_accessed_at.get(sched_index, set())
290+
new_storage_vars = accessed - seen_storage_vars
291+
seen_storage_vars.update(accessed)
273292

274-
if (len(new_storage_variables) > 0):
275-
first_accesses[sched_index] = new_storage_variables
293+
if new_storage_vars:
294+
first_accesses[sched_index] = new_storage_vars
276295

277-
last_accesses: dict[int, frozenset[str]] = {}
278-
seen_storage_variables.clear()
296+
# backward pass for last_accesses
297+
last_accesses: dict[int, Set[str]] = {}
298+
seen_storage_vars = set()
279299
for sched_index in range(len(kernel.linearization)-1, -1, -1):
280-
if (sched_index not in bounds):
281-
continue
282-
sched_item = kernel.linearization[sched_index]
283-
new_temporary_variables = bounds[sched_index]
284-
new_storage_variables = update_seen_storage_vars(
285-
seen_storage_variables, new_temporary_variables
286-
)
300+
accessed = storage_vars_accessed_at.get(sched_index, set())
301+
new_storage_vars = accessed - seen_storage_vars
302+
seen_storage_vars.update(accessed)
303+
304+
if new_storage_vars:
305+
last_accesses[sched_index] = new_storage_vars
287306

288-
if new_storage_variables:
289-
last_accesses[sched_index] = new_storage_variables
290307
return (first_accesses, last_accesses)
308+
291309
# }}}
292310

293311

loopy/target/__init__.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -252,20 +252,26 @@ def get_temporary_decls(self, codegen_state: CodeGenerationState,
252252
schedule_index: int) -> ASTType:
253253
raise NotImplementedError
254254

255+
@abstractmethod
255256
def get_temporary_var_declarator(self,
256-
codegen_state: CodeGenerationState,
257-
temp_var: TemporaryVariable) -> ASTType:
258-
raise NotImplementedError()
257+
codegen_state: CodeGenerationState,
258+
temp_var: TemporaryVariable
259+
) -> ASTType | None:
260+
...
259261

262+
@abstractmethod
260263
def get_temporary_var_deallocator(self,
261-
codegen_state: CodeGenerationState,
262-
temp_var: TemporaryVariable) -> ASTType:
263-
raise NotImplementedError()
264+
codegen_state: CodeGenerationState,
265+
temp_var: TemporaryVariable
266+
) -> ASTType | None:
267+
...
264268

269+
@abstractmethod
265270
def get_temporary_decl_at_index(
266-
self, codegen_state: CodeGenerationState,
267-
sched_index: int) -> tuple[ASTType | None, ASTType | None]:
268-
raise NotImplementedError()
271+
self, codegen_state: CodeGenerationState,
272+
sched_index: int
273+
) -> tuple[ASTType | None, ASTType | None]:
274+
...
269275

270276
def get_kernel_call(self, codegen_state: CodeGenerationState,
271277
subkernel_name: str,

loopy/target/pyopencl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@
8989

9090
# {{{ pyopencl function scopers
9191

92-
9392
class PyOpenCLCallable(ScalarCallable):
9493
"""
9594
Records information about the callables which are not covered by
@@ -912,14 +911,14 @@ def get_kernel_call(
912911
gsize: tuple[Expression, ...], lsize: tuple[Expression, ...]
913912
) -> genpy.Suite:
914913
from genpy import Assert, Assign, Comment, Line, Suite
915-
from pymbolic.mapper.stringifier import PREC_NONE
916914

917915
kernel = codegen_state.kernel
918-
ecm = self.get_expression_to_code_mapper(codegen_state)
919916

920917
from loopy.schedule.tools import get_subkernel_arg_info
921918
skai = get_subkernel_arg_info(kernel, subkernel_name)
922919

920+
ecm = self.get_expression_to_code_mapper(codegen_state)
921+
923922
if not gsize:
924923
gsize = (1,)
925924
if not lsize:
@@ -986,6 +985,7 @@ def get_kernel_call(
986985
overflow_args_code = Suite([])
987986

988987
import pyopencl.version as cl_ver
988+
from pymbolic.mapper.stringifier import PREC_NONE
989989
if cl_ver.VERSION < (2020, 2):
990990
from warnings import warn
991991
warn("Your kernel invocation will likely fail because your "

loopy/target/python.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from pymbolic.mapper.stringifier import PREC_NONE, StringifyMapper
3535

3636
from loopy.diagnostic import LoopyError
37-
from loopy.kernel.data import ValueArg
37+
from loopy.kernel.data import TemporaryVariable, ValueArg
3838
from loopy.kernel.function_interface import ScalarCallable
3939
from loopy.target import ASTBuilderBase
4040
from loopy.type_inference import TypeReader
@@ -339,7 +339,26 @@ def emit_assignment(self, codegen_state: CodeGenerationState, insn: Assignment):
339339
ecm(insn.assignee, prec=PREC_NONE, type_context=None),
340340
ecm(insn.expression, prec=PREC_NONE, type_context=None))
341341

342-
# }}}
342+
@override
343+
def get_temporary_var_declarator(self,
344+
codegen_state: CodeGenerationState,
345+
temp_var: TemporaryVariable
346+
) -> Generable | None:
347+
return None
348+
349+
@override
350+
def get_temporary_var_deallocator(self,
351+
codegen_state: CodeGenerationState,
352+
temp_var: TemporaryVariable
353+
) -> Generable | None:
354+
return None
355+
356+
@override
357+
def get_temporary_decl_at_index(
358+
self, codegen_state: CodeGenerationState,
359+
sched_index: int
360+
) -> tuple[Generable | None, Generable | None]:
361+
return None, None
343362

344363
# }}}
345364

0 commit comments

Comments
 (0)