@@ -178,16 +178,63 @@ def supporting_temporary_names(
178178 return frozenset (result )
179179
180180
def _get_temporaries_accessed_in_schedule(
            kernel: LoopKernel,
            sched_idx_lower_bound: int,
            sched_idx_upper_bound: int
        ) -> frozenset[str]:
    """
    Return the names of all temporaries read or written by the subkernels
    called within the half-open range of schedule indices
    ``[sched_idx_lower_bound, sched_idx_upper_bound)`` of
    ``kernel.linearization``.

    Only items at the outside-kernel level are inspected: each
    ``CallKernel`` contributes the temporaries read and written by its
    subkernel, and the items nested inside that subkernel's block are
    skipped. ``EnterLoop``/``LeaveLoop`` items outside of subkernels are
    ignored.

    :arg kernel: a kernel whose ``linearization`` is not *None*.
    :arg sched_idx_lower_bound: first schedule index to inspect (inclusive).
    :arg sched_idx_upper_bound: schedule index at which to stop (exclusive).
    :raises NotImplementedError: if an item other than ``CallKernel``,
        ``EnterLoop`` or ``LeaveLoop`` is found outside of a subkernel.
    """
    from loopy.schedule import CallKernel, EnterLoop, LeaveLoop

    linearization = kernel.linearization
    assert linearization is not None

    # Maps the index of a block-opening item to the index of the item
    # closing that block.
    block_boundaries = get_block_boundaries(linearization)

    temporaries: frozenset[str] = frozenset()
    sched_index = sched_idx_lower_bound
    while sched_index < sched_idx_upper_bound:
        sched_item = linearization[sched_index]
        if isinstance(sched_item, CallKernel):
            temporaries = (
                temporaries_written_in_subkernel(kernel, sched_item.kernel_name)
                | temporaries_read_in_subkernel(
                    kernel, sched_item.kernel_name
                )
                | temporaries
            )
            # Jump past the subkernel's block: the items nested inside it
            # are already accounted for by the two queries above, and
            # visiting them one by one would trip the 'else' branch below.
            sched_index = block_boundaries[sched_index] + 1

        elif isinstance(sched_item, (EnterLoop, LeaveLoop)):
            # ignore further outside-kernel loops
            sched_index += 1

        else:
            raise NotImplementedError("kernel with non-CallKernel outermost")

    return temporaries
210+
211+
def _map_to_base_storage(kernel: LoopKernel, tv_names: Set[str]) -> Set[str]:
    """
    Return the set of storage names backing the temporaries in *tv_names*.

    For each name, the chain of ``base_storage`` references is followed
    through ``kernel.temporary_variables`` until an entry whose
    ``base_storage`` is *None* is reached; the name of that entry is
    what ends up in the result. A temporary without base storage maps to
    itself.
    """
    storage_names: set[str] = set()

    for name in tv_names:
        # Walk up the chain until reaching an entry that is not itself
        # placed in some other storage.
        while (
                parent := kernel.temporary_variables[name].base_storage
                ) is not None:
            name = parent

        storage_names.add(name)

    return storage_names
225+
226+
227+ @memoize_on_first_arg
181228def get_sched_index_to_first_and_last_used (
182229 kernel : LoopKernel
183- ) -> tuple [dict [int , frozenset [str ]], dict [int , frozenset [str ]]]:
230+ ) -> tuple [Mapping [int , Set [str ]], Mapping [int , Set [str ]]]:
184231 """
185232 Returns the tuple (first_used, last_used), where first_used is
186- a dict such that first_used[sched_index] is the set of all temporary
233+ a dict such that first_used[sched_index] is the set of all global temporary
187234 variable names first used at sched_index.
188235
189- Likewise, last_used[sched_index] is the set of all temporary variable names
190- last used at sched_index.
236+ Likewise, last_used[sched_index] is the set of all global temporary
237+ variable names last used at sched_index.
191238 """
192239 from loopy .kernel .data import AddressSpace
193240 from loopy .schedule import CallKernel , EnterLoop
@@ -200,94 +247,65 @@ def get_sched_index_to_first_and_last_used(
200247 )
201248
202249 # Collapse into blocks
203- def get_temporaries_in_bounds (
204- linearization : Sequence [ScheduleItem ],
205- lower_bound : int ,
206- upper_bound : int
207- ) -> frozenset [str ]:
208- temporaries : frozenset [str ] = frozenset ()
209- for sched_index in range (lower_bound , upper_bound + 1 ):
210- sched_item = linearization [sched_index ]
211- if isinstance (sched_item , CallKernel ):
212- temporaries = (
213- temporaries_written_in_subkernel (kernel , sched_item .kernel_name )
214- | temporaries_read_in_subkernel (
215- kernel , sched_item .kernel_name
216- )
217- | (temporaries )
218- )
219- return temporaries & global_temporaries
220-
221250 block_boundaries = get_block_boundaries (kernel .linearization )
222251
223- bounds : dict [int , frozenset [str ]] = {}
252+ tvs_accessed_at : dict [int , frozenset [str ]] = {}
224253 sched_index = 0
225254 while sched_index < len (kernel .linearization ):
226255 sched_item = kernel .linearization [sched_index ]
227- if isinstance (sched_item , (EnterLoop , CallKernel )):
228- if isinstance (sched_item , CallKernel ):
229- block_end = block_boundaries [sched_index ]
230- accessed_temporaries = (
231- temporaries_written_in_subkernel (kernel , sched_item .kernel_name )
232- | temporaries_read_in_subkernel (
233- kernel , sched_item .kernel_name
234- )
235- )
236- else :
237- block_end = block_boundaries [sched_index ]
238- accessed_temporaries = get_temporaries_in_bounds (
239- kernel .linearization , sched_index , block_end
256+ if isinstance (sched_item , CallKernel ):
257+ block_end = block_boundaries [sched_index ]
258+ tvs_accessed_at [sched_index ] = (
259+ temporaries_written_in_subkernel (kernel , sched_item .kernel_name )
260+ | temporaries_read_in_subkernel (
261+ kernel , sched_item .kernel_name
240262 )
241- bounds [sched_index ] = accessed_temporaries
263+ ) & global_temporaries
264+
242265 sched_index = block_end + 1
243- else :
244- sched_index += 1
245266
246- def update_seen_storage_vars (
247- seen_sv : set [str ],
248- new_temp_variables : frozenset [str ]
249- ) -> frozenset [str ]:
250- new_storage_variables : set [str ] = set ()
251- past_sv = frozenset (seen_sv )
252- for new_tv_name in new_temp_variables :
253- new_tv = kernel .temporary_variables [new_tv_name ]
254- if new_tv .base_storage is None :
255- storage_var = new_tv_name
256- else :
257- storage_var = new_tv .base_storage
258- new_storage_variables .add (storage_var )
259- seen_sv .add (storage_var )
260- new_sv = frozenset (new_storage_variables )
261- return new_sv - past_sv
262- # forward pass for first accesses
263- first_accesses : dict [int , frozenset [str ]] = {}
264- seen_storage_variables : set [str ] = set ()
267+ elif isinstance (sched_item , EnterLoop ):
268+ block_end = block_boundaries [sched_index ]
269+ tvs_accessed_at [sched_index ] = _get_temporaries_accessed_in_schedule (
270+ kernel , sched_index , block_end + 1
271+ ) & global_temporaries
272+
273+ sched_index = block_end + 1
274+
275+ else :
276+ raise ValueError (
277+ f"unexpected schedule item at outermost level: { type (sched_item )} " )
278+
279+ storage_vars_accessed_at = {
280+ sched_index : _map_to_base_storage (kernel , accessed )
281+ for sched_index , accessed in tvs_accessed_at .items ()
282+ }
283+ del tvs_accessed_at
284+
285+ # forward pass for first_accesses
286+ first_accesses : dict [int , Set [str ]] = {}
287+ seen_storage_vars : set [str ] = set ()
265288 for sched_index in range (0 , len (kernel .linearization )):
266- if (sched_index not in bounds ):
267- continue
268- sched_item = kernel .linearization [sched_index ]
269- new_temporary_variables = bounds [sched_index ]
270- new_storage_variables = update_seen_storage_vars (
271- seen_storage_variables , new_temporary_variables
272- )
289+ accessed = storage_vars_accessed_at .get (sched_index , set ())
290+ new_storage_vars = accessed - seen_storage_vars
291+ seen_storage_vars .update (accessed )
273292
274- if ( len ( new_storage_variables ) > 0 ) :
275- first_accesses [sched_index ] = new_storage_variables
293+ if new_storage_vars :
294+ first_accesses [sched_index ] = new_storage_vars
276295
277- last_accesses : dict [int , frozenset [str ]] = {}
278- seen_storage_variables .clear ()
296+ # backward pass for last_accesses
297+ last_accesses : dict [int , Set [str ]] = {}
298+ seen_storage_vars = set ()
279299 for sched_index in range (len (kernel .linearization )- 1 , - 1 , - 1 ):
280- if (sched_index not in bounds ):
281- continue
282- sched_item = kernel .linearization [sched_index ]
283- new_temporary_variables = bounds [sched_index ]
284- new_storage_variables = update_seen_storage_vars (
285- seen_storage_variables , new_temporary_variables
286- )
300+ accessed = storage_vars_accessed_at .get (sched_index , set ())
301+ new_storage_vars = accessed - seen_storage_vars
302+ seen_storage_vars .update (accessed )
303+
304+ if new_storage_vars :
305+ last_accesses [sched_index ] = new_storage_vars
287306
288- if new_storage_variables :
289- last_accesses [sched_index ] = new_storage_variables
290307 return (first_accesses , last_accesses )
308+
291309# }}}
292310
293311
0 commit comments