Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cbdb757
add localized allocation and deallocation
dsding2 Jun 2, 2025
2fee158
delete commented out code
dsding2 Jun 2, 2025
8ace895
deal with base storage
dsding2 Jun 4, 2025
c4e635c
ruff check fixes
dsding2 Jun 5, 2025
24b1a47
rework to push allocations outside of loops
dsding2 Jun 8, 2025
be78797
add types, fix ruff
dsding2 Jun 9, 2025
461558d
Merge remote-tracking branch 'upstream/main' into opencl_allocation
dsding2 Jun 13, 2025
0bcf4df
Merge branch 'main' into opencl_allocation
dsding2 Jun 17, 2025
0b6abdd
refactor to make more target-generic
dsding2 Jun 17, 2025
4f95a6b
resolve lingering merge issues
dsding2 Jun 17, 2025
bd98636
fix to only allocate global temporaries
dsding2 Jun 17, 2025
47dda68
move temp declarations to ASTBuilder
dsding2 Jun 19, 2025
e494a3b
Merge branch 'main' into opencl_allocation
dsding2 Jun 19, 2025
88c436f
fix typing
dsding2 Jun 19, 2025
dae91e2
fix typing hopefully
dsding2 Jun 23, 2025
1cfe83a
add basic test
dsding2 Jun 23, 2025
3c3bb78
Merge branch 'main' into opencl_allocation
dsding2 Jun 30, 2025
3ef324c
more typing/ruff fixes
dsding2 Jun 30, 2025
f708b66
fix tutorial.rst and add to baseline
dsding2 Jun 30, 2025
a0a8365
Merge branch 'main' into opencl_allocation
inducer Jul 5, 2025
f12ce9f
Merge branch 'main' into opencl_allocation
inducer Jul 10, 2025
452be6b
Merge branch 'main' into opencl_allocation
inducer Jul 10, 2025
3985576
Update loopy/schedule/tools.py
dsding2 Jul 11, 2025
95e119e
Apply suggested test changes
dsding2 Jul 11, 2025
5cbfbf1
implement rename and documentation suggestions
dsding2 Jul 11, 2025
612b238
ruff fixes, revert broken change
dsding2 Jul 12, 2025
4b0d754
Merge branch 'main' into opencl_allocation
inducer Jul 28, 2025
e591ae6
Merge branch 'main' into opencl_allocation
inducer Jul 31, 2025
1290c64
Merge branch 'main' into opencl_allocation
inducer Aug 28, 2025
b24fe99
Improvements
inducer Aug 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -8261,8 +8261,8 @@
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 56,
"endColumn": 17,
"lineCount": 8
"endColumn": 63,
"lineCount": 1
}
},
{
Expand All @@ -8273,6 +8273,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 52,
"endColumn": 59,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down
26 changes: 21 additions & 5 deletions loopy/codegen/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,23 @@ def generate_code_for_sched_index(
glob_grid, loc_grid = kernel.get_grid_sizes_for_insn_ids_as_exprs(
get_insn_ids_for_block_at(kernel.linearization, sched_index),
codegen_state.callables_table)
return merge_codegen_results(codegen_state, [
codegen_result,

prefixes, suffixes = (
codegen_state.ast_builder.get_temporary_decl_at_index(
codegen_state, sched_index
)
)
results = [
prefixes,
codegen_result,
codegen_state.ast_builder.get_kernel_call(
codegen_state,
sched_item.kernel_name,
glob_grid, loc_grid)
])
glob_grid, loc_grid),
suffixes
]
results = [r for r in results if r is not None]
return merge_codegen_results(codegen_state, results)
else:
# do not generate host code for non-entrypoint kernels
return codegen_result
Expand Down Expand Up @@ -136,7 +145,14 @@ def generate_code_for_sched_index(
"for '%s', tagged '%s'"
% (sched_item.iname, ", ".join(str(tag) for tag in tags)))

return func(codegen_state, sched_index)
prefixes, suffixes = (
codegen_state.ast_builder.get_temporary_decl_at_index(
codegen_state, sched_index
)
)
results = [prefixes, func(codegen_state, sched_index), suffixes]
results = [r for r in results if r is not None]
return merge_codegen_results(codegen_state, results)

elif isinstance(sched_item, Barrier):
# {{{ emit barrier code
Expand Down
131 changes: 131 additions & 0 deletions loopy/schedule/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,137 @@ def supporting_temporary_names(

return frozenset(result)


def _get_temporaries_accessed_in_schedule(
            kernel: LoopKernel,
            sched_idx_lower_bound: int,
            sched_idx_upper_bound: int
        ) -> frozenset[str]:
    """
    Return the names of all temporaries read or written by the sub-kernels
    called within a range of the linearization.

    :arg kernel: a kernel whose ``linearization`` is not *None*.
    :arg sched_idx_lower_bound: first schedule index inspected (inclusive).
    :arg sched_idx_upper_bound: end of the inspected range (exclusive).
    :returns: the union, over every :class:`loopy.schedule.CallKernel` in the
        range, of the temporaries written and read in that sub-kernel.
    :raises NotImplementedError: if an item other than a kernel call, a loop
        boundary or a barrier is found outside of a sub-kernel.
    """
    # NOTE(review): open question raised in review: "Can this handle
    # RunInstruction outermost?" -- it cannot: an instruction that is not
    # enclosed in a CallKernel block still raises NotImplementedError below.
    from loopy.schedule import (
        Barrier,
        CallKernel,
        EnterLoop,
        LeaveLoop,
        ReturnFromKernel,
    )

    linearization = kernel.linearization
    assert linearization is not None

    temporaries: frozenset[str] = frozenset()

    # The linearization is flat, so the range also contains the items that
    # make up the body of each sub-kernel. Those are already accounted for by
    # the per-subkernel queries below and must not be mistaken for
    # outermost-level items.
    inside_subkernel = False

    for sched_index in range(sched_idx_lower_bound, sched_idx_upper_bound):
        sched_item = linearization[sched_index]

        if inside_subkernel:
            if isinstance(sched_item, ReturnFromKernel):
                inside_subkernel = False
            continue

        if isinstance(sched_item, CallKernel):
            temporaries = (
                temporaries_written_in_subkernel(kernel, sched_item.kernel_name)
                | temporaries_read_in_subkernel(
                    kernel, sched_item.kernel_name
                )
                | temporaries
            )
            inside_subkernel = True

        elif isinstance(sched_item, (EnterLoop, LeaveLoop, Barrier)):
            # Outside-kernel loop boundaries and barriers access no
            # temporaries themselves.
            pass

        else:
            raise NotImplementedError("kernel with non-CallKernel outermost")

    return temporaries


def _map_to_base_storage(kernel: LoopKernel, tv_names: Set[str]) -> Set[str]:
result: set[str] = set()
for tv_name in tv_names:
while True:
tv = kernel.temporary_variables[tv_name]
if tv.base_storage is not None:
tv_name = tv.base_storage
else:
break

result.add(tv_name)

return result


@memoize_on_first_arg
def get_sched_index_to_first_and_last_used(
            kernel: LoopKernel
        ) -> tuple[Mapping[int, Set[str]], Mapping[int, Set[str]]]:
    """
    Return a tuple ``(first_used, last_used)`` describing where the storage
    of global temporaries is first and last needed.

    The linearization is split into outermost blocks (a kernel call or an
    outside-kernel loop). For each block, the global temporaries accessed in
    it are determined and mapped to their base storage.

    ``first_used[sched_index]`` is the set of storage names that are not
    accessed by any block before the one starting at *sched_index*.
    ``last_used[sched_index]`` is the set of storage names that are not
    accessed by any block after the one starting at *sched_index*.
    Schedule indices for which that set would be empty are not present.

    :raises ValueError: if an outermost schedule item is neither a kernel
        call, a loop entry, nor a barrier.
    """
    # NOTE(review): suggestion raised in review: rather than "implicitly"
    # attaching allocations to schedule items, maybe realize them explicitly
    # as a new kind of schedule item?
    from loopy.kernel.data import AddressSpace
    from loopy.schedule import CallKernel, EnterLoop, Barrier

    linearization = kernel.linearization
    assert linearization is not None

    global_temporaries = frozenset(
        tv.name for tv in kernel.temporary_variables.values()
        if tv.address_space == AddressSpace.GLOBAL
    )

    # Maps the index of a block-opening item to the index of its closing item.
    block_boundaries = get_block_boundaries(linearization)

    # {{{ collect global temporaries accessed per outermost block

    accessed_by_block: dict[int, frozenset[str]] = {}
    cursor = 0
    n_items = len(linearization)
    while cursor < n_items:
        item = linearization[cursor]

        if isinstance(item, Barrier):
            cursor += 1
            continue

        if isinstance(item, CallKernel):
            closing_index = block_boundaries[cursor]
            written = temporaries_written_in_subkernel(kernel, item.kernel_name)
            read = temporaries_read_in_subkernel(kernel, item.kernel_name)
            touched = written | read

        elif isinstance(item, EnterLoop):
            closing_index = block_boundaries[cursor]
            touched = _get_temporaries_accessed_in_schedule(
                kernel, cursor, closing_index+1
            )

        else:
            raise ValueError(
                f"unexpected schedule item at outermost level: {type(item)}")

        accessed_by_block[cursor] = touched & global_temporaries
        cursor = closing_index + 1

    # }}}

    storage_by_block = {
        block_start: _map_to_base_storage(kernel, touched)
        for block_start, touched in accessed_by_block.items()
    }

    def _newly_seen(block_starts) -> dict[int, Set[str]]:
        # For each block, in the given order, record the storage names that
        # did not occur in any previously visited block.
        seen: set[str] = set()
        fresh_at: dict[int, Set[str]] = {}
        for block_start in block_starts:
            storage = storage_by_block[block_start]
            fresh = storage - seen
            seen.update(storage)
            if fresh:
                fresh_at[block_start] = fresh
        return fresh_at

    ordered_starts = sorted(storage_by_block)

    # Walking forward yields first uses, walking backward yields last uses.
    first_accesses = _newly_seen(ordered_starts)
    last_accesses = _newly_seen(reversed(ordered_starts))

    return (first_accesses, last_accesses)

# }}}


Expand Down
43 changes: 43 additions & 0 deletions loopy/target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from loopy.codegen import CodeGenerationState, PreambleInfo
from loopy.codegen.result import CodeGenerationResult
from loopy.kernel import LoopKernel
from loopy.kernel.data import TemporaryVariable
from loopy.target.c import DTypeRegistry
from loopy.target.execution import ExecutorBase
from loopy.translation_unit import CallableId, CallablesTable, TranslationUnit
Expand Down Expand Up @@ -251,6 +252,27 @@ def get_temporary_decls(self, codegen_state: CodeGenerationState,
schedule_index: int) -> ASTType:
raise NotImplementedError

@abstractmethod
def get_temporary_var_declarator(self,
        codegen_state: CodeGenerationState,
        temp_var: TemporaryVariable
        ) -> ASTType | None:
    """
    Return a target AST node for the declaration of *temp_var*, or *None*.

    NOTE(review): presumably *None* means the target emits nothing for this
    temporary -- confirm against the implementations.
    """
    ...

@abstractmethod
def get_temporary_var_deallocator(self,
        codegen_state: CodeGenerationState,
        temp_var: TemporaryVariable
        ) -> ASTType | None:
    """
    Return a target AST node for the deallocation of *temp_var*, or *None*.

    NOTE(review): presumably the counterpart of
    :meth:`get_temporary_var_declarator` -- confirm the intended pairing.
    """
    ...

@abstractmethod
def get_temporary_decl_at_index(
        self, codegen_state: CodeGenerationState,
        sched_index: int
        ) -> tuple[ASTType | None, ASTType | None]:
    """
    Return a pair ``(prefix, suffix)`` for the schedule item at
    *sched_index*.

    Either entry may be *None*.

    NOTE(review): presumably *prefix* is emitted before and *suffix* after
    the code generated for that schedule item, carrying temporary
    allocations and deallocations respectively -- confirm against the
    caller in code generation.
    """
    ...

def get_kernel_call(self, codegen_state: CodeGenerationState,
subkernel_name: str,
gsize: tuple[Expression, ...],
Expand Down Expand Up @@ -365,6 +387,27 @@ def get_expression_to_code_mapper(self, codegen_state):
def get_kernel_call(self, codegen_state, name, gsize, lsize):
    """Always return *None*, regardless of the arguments."""
    return

@override
def get_temporary_var_declarator(
self, codegen_state: CodeGenerationState,
temp_var: TemporaryVariable
) -> None:
return None

@override
def get_temporary_var_deallocator(
self, codegen_state: CodeGenerationState,
temp_var: TemporaryVariable
) -> None:
return None

@override
def get_temporary_decl_at_index(
self, codegen_state: CodeGenerationState,
sched_index: int
) -> tuple[None, None]:
return (None, None)

@property
def ast_block_class(self):
return _DummyASTBlock
Expand Down
14 changes: 14 additions & 0 deletions loopy/target/c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from cgen import (
Block,
Collection,
Comment,
Const,
Declarator,
Generable,
Expand Down Expand Up @@ -1109,6 +1110,12 @@ def get_temporary_decls(self, codegen_state, schedule_index):

return result

@override
def get_temporary_decl_at_index(
self, codegen_state: CodeGenerationState, sched_index: int
) -> tuple[Generable | None, Generable | None]:
return (None, None)

@property
@override
def ast_block_class(self):
Expand Down Expand Up @@ -1242,6 +1249,7 @@ def arg_to_cgen_declarator(
raise ValueError(f"unexpected type of argument '{passed_name}': "
f"'{type(var_descr)}'")

@override
def get_temporary_var_declarator(self,
codegen_state: CodeGenerationState,
temp_var: TemporaryVariable) -> Declarator:
Expand Down Expand Up @@ -1274,6 +1282,12 @@ def get_temporary_var_declarator(self,
return self.wrap_decl_for_address_space(temp_var_decl,
temp_var.address_space)

@override
def get_temporary_var_deallocator(self,
        codegen_state: CodeGenerationState,
        temp_var: TemporaryVariable
        ) -> Generable:
    """
    Return a comment node stating that dynamic freeing of temporaries is
    not supported; nothing is actually freed for *temp_var*.
    """
    not_supported_note = Comment("Dynamic freeing of temp vars not supported")
    return not_supported_note
# }}}

@override
Expand Down
Loading
Loading