@@ -87,11 +87,8 @@ function transform_gpu!(def, constargs, force_inbounds, unsafe_indicies)
8787end
8888
8989struct WorkgroupLoop
90- indicies:: Vector{Any}
9190 stmts:: Vector{Any}
9291 allocations:: Vector{Any}
93- private_allocations:: Vector{Any}
94- private:: Set{Symbol}
9592 terminated_in_sync:: Bool
9693end
9794
@@ -112,26 +109,18 @@ function find_sync(stmt)
112109end
113110
114111# TODO proper handling of LineInfo
115- function split (
116- stmts,
117- indicies = Any[], private = Set {Symbol} (),
118- )
112+ function split (stmts)
119113 # 1. Split the code into blocks separated by `@synchronize`
120- # 2. Aggregate `@index` expressions
121- # 3. Hoist allocations
122- # 4. Hoist uniforms
123114
124115 current = Any[]
125116 allocations = Any[]
126- private_allocations = Any[]
127117 new_stmts = Any[]
128118 for stmt in stmts
129119 has_sync = find_sync (stmt)
130120 if has_sync
131- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private) , is_sync (stmt))
121+ loop = WorkgroupLoop (current, allocations, is_sync (stmt))
132122 push! (new_stmts, emit (loop))
133123 allocations = Any[]
134- private_allocations = Any[]
135124 current = Any[]
136125 is_sync (stmt) && continue
137126
@@ -143,7 +132,7 @@ function split(
143132 function recurse (expr:: Expr )
144133 expr = unblock (expr)
145134 if is_scope_construct (expr) && any (find_sync, expr. args)
146- new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private) ))
135+ new_args = unblock (split (expr. args))
147136 return Expr (expr. head, new_args... )
148137 else
149138 return Expr (expr. head, map (recurse, expr. args)... )
@@ -157,14 +146,10 @@ function split(
157146 push! (allocations, stmt)
158147 continue
159148 elseif @capture (stmt, @private lhs_ = rhs_)
160- push! (private, lhs)
161- push! (private_allocations, :($ lhs = $ rhs))
149+ push! (allocations, :($ lhs = $ rhs))
162150 continue
163151 elseif @capture (stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
164- if @capture (rhs, @index (args__))
165- push! (indicies, stmt)
166- continue
167- elseif @capture (rhs, @localmem (args__) | @uniform (args__))
152+ if @capture (rhs, @localmem (args__) | @uniform (args__))
168153 push! (allocations, stmt)
169154 continue
170155 elseif @capture (rhs, @private (T_, dims_))
@@ -176,7 +161,6 @@ function split(
176161 end
177162 alloc = :($ Scratchpad (__ctx__, $ T, Val ($ dims)))
178163 push! (allocations, :($ lhs = $ alloc))
179- push! (private, lhs)
180164 continue
181165 end
182166 end
@@ -186,7 +170,7 @@ function split(
186170
187171 # everything since the last `@synchronize`
188172 if ! isempty (current)
189- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private) , false )
173+ loop = WorkgroupLoop (current, allocations, false )
190174 push! (new_stmts, emit (loop))
191175 end
192176 return new_stmts
@@ -198,9 +182,7 @@ function emit(loop)
198182 body = Expr (:block , loop. stmts... )
199183 loopexpr = quote
200184 $ (loop. allocations... )
201- $ (loop. private_allocations... )
202185 if __active_lane__
203- $ (loop. indicies... )
204186 $ (unblock (body))
205187 end
206188 end
0 commit comments