@@ -166,6 +166,9 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
         evicted_indices,
     pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         actions_count,
+    const bool lock_cache_line,
+    pta::PackedTensorAccessor32<int32_t, 2, at::RestrictPtrTraits>
+        lxu_cache_locking_counter,
     TORCH_DSA_KERNEL_ARGS) {
   // Number of cache sets
   const int32_t C = lxu_cache_state.size(0);
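The two new kernel arguments implement cache-line locking. `lxu_cache_locking_counter` is a per-slot `[C][kWarpSize]` int32 counter, and the kernel treats any slot with a positive count as pinned against eviction. A minimal host-side sketch of that contract follows; the `LockingCounter` struct is hypothetical, and only the shape and the "counter > 0 means pinned" rule come from this diff (the matching decrement path lives outside it):

```cpp
// Host-side sketch of the locking-counter contract (LockingCounter is
// hypothetical; the [C][kWarpSize] shape and the "positive count means
// pinned" rule mirror the kernel above).
#include <cstdint>
#include <vector>

struct LockingCounter {
  int32_t num_sets;             // C, from lxu_cache_state.size(0)
  int32_t slots_per_set;        // kWarpSize slots per cache set
  std::vector<int32_t> counts;  // flattened [C][kWarpSize], all zero

  LockingCounter(int32_t c, int32_t w)
      : num_sets(c), slots_per_set(w), counts(c * w, 0) {}

  int32_t& at(int32_t set, int32_t slot) {
    return counts[set * slots_per_set + slot];
  }

  // A positive count pins the slot: the eviction kernel treats it as
  // unavailable until a later decrement releases it.
  bool is_locked(int32_t set, int32_t slot) {
    return at(set, slot) > 0;
  }
};
```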
@@ -216,51 +219,65 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
       SL += 1;
     }

-    // now, we need to insert the (unique!) values in indices[n:n + SL] into
+    // Now, we need to insert the (unique!) values in indices[n:n + SL] into
     // our slots.
     const int32_t slot = threadIdx.x;
     const int64_t slot_time = lru_state[cache_set][slot];
-    int64_t costs[1] = {slot_time};
+
+    // Check if the slot is locked
+    const bool is_slot_locked =
+        lock_cache_line && (lxu_cache_locking_counter[cache_set][slot] > 0);
+    // Check if the slot holds a row that was inserted as a cache hit.
+    const int64_t slot_idx = lxu_cache_state[cache_set][slot];
+    const bool slot_has_idx = slot_idx != -1 && slot_time == time_stamp;
+    // Check if the slot is unavailable: it is either locked or holds a
+    // cache-hit inserted row
+    const bool is_slot_unavailable = is_slot_locked || slot_has_idx;
+
+    // Set the slot cost: if the slot is unavailable, set it to the
+    // maximum timestamp, which is the current timestamp. After sorting,
+    // the unavailable slots sink to the bottom, while the available
+    // slots bubble to the top
+    const int64_t slot_cost = is_slot_unavailable ? time_stamp : slot_time;
+
+    // Prepare key-value pair for sorting
+    int64_t costs[1] = {slot_cost};
     int32_t slots[1] = {slot};

+    // Sort the slots based on their costs
     BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-    const int32_t sorted_slot = slots[0];
-    const int64_t sorted_time = costs[0];
+
+    // Get the sorted results
+    const int32_t insert_slot = slots[0];
+    const int64_t insert_cost = costs[0];

     auto l = threadIdx.x;

+    // Get the current index
+    const int64_t current_idx = shfl_sync(slot_idx, insert_slot);
+
     // Insert rows
     if (l < SL) {
       // Insert indices
-      const int32_t insert_slot = sorted_slot;
-      const int64_t insert_time = sorted_time;
-
       const int64_t insert_idx = cache_set_sorted_indices[n + l];
-      const int64_t current_idx = lxu_cache_state[cache_set][insert_slot];
-
-#if 0
-      // TODO: Check whether to uncomment this
-      // Only check insert_time if tag is for valid entry
-      if (current_idx != -1) {
-        // We need to ensure if prefetching (prefetch_dist) batches ahead
-        // No entries that are younger than (time_stamp - prefetch_dist) are
-        // evicted from the cache. This will break the guarantees required
-        // for the SSD embedding.
-        // If you hit this assert, increase the cache size.
-        CUDA_KERNEL_ASSERT2(insert_time < (time_stamp - prefetch_dist));
-      }
-#endif

-      if (current_idx != -1 && insert_time == time_stamp) {
-        // Skip this slot as the inserted row was a cache hit
-        // This is conflict miss
+      if (insert_cost == time_stamp) {
+        // Skip this slot as it is not available
         evicted_indices[n + l] = -1;
         assigned_cache_slots[n + l] = -1;
       } else {
         evicted_indices[n + l] = current_idx; // -1 if not set, >= 0 if valid.
         assigned_cache_slots[n + l] = cache_set * kWarpSize + insert_slot;
+
+        // TODO: Check if we can do contiguous writes here.
+        // Update cache states
         lxu_cache_state[cache_set][insert_slot] = insert_idx;
         lru_state[cache_set][insert_slot] = time_stamp;
+
+        // Lock cache line
+        if (lock_cache_line) {
+          lxu_cache_locking_counter[cache_set][insert_slot] += 1;
+        }
       }
     }

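The hunk above turns eviction into a pure sort problem: each of the `kWarpSize` lanes prices its slot at its last-use timestamp, or at `time_stamp` (the maximum possible cost) if the slot is locked or was just filled by a cache hit. After the warp-wide bitonic sort, lane `l < SL` claims the `l`-th cheapest slot; if even that slot's cost equals `time_stamp`, no evictable slot remains and the row is a conflict miss. Below is a standalone host-side sketch of the same selection rule, with `std::sort` standing in for the warp-level `BitonicSort` and made-up timestamps:

```cpp
// Host-side model of the slot-selection rule (illustrative only; the
// real kernel holds one (cost, slot) pair per warp lane).
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <utility>

constexpr int kWarpSize = 32;

int main() {
  const int64_t time_stamp = 100;  // current prefetch iteration
  std::array<std::pair<int64_t, int32_t>, kWarpSize> slots;
  for (int32_t s = 0; s < kWarpSize; ++s) {
    // Demo pattern: only slots 0..3 are evictable; the rest are treated
    // as locked or as already holding a row inserted this iteration.
    const bool unavailable = (s >= 4);
    const int64_t slot_time = 40 + s;  // stale LRU timestamps
    slots[s] = {unavailable ? time_stamp : slot_time, s};
  }
  // Ascending sort: older (evictable) slots bubble to the front,
  // unavailable slots (cost == time_stamp) sink to the back.
  std::sort(slots.begin(), slots.end());
  const int SL = 5;  // unique rows to insert into this cache set
  for (int l = 0; l < SL; ++l) {
    const auto [cost, slot] = slots[l];
    if (cost == time_stamp) {
      std::printf("lane %d: conflict miss, no evictable slot\n", l);
    } else {
      std::printf("lane %d: insert into slot %d (evicts entry last used at %lld)\n",
                  l, slot, static_cast<long long>(cost));
    }
  }
  return 0;
}
```

With these inputs, lanes 0 through 3 claim slots 0 through 3 and lane 4 reports a conflict miss, matching the two branches in the kernel.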
@@ -280,9 +297,11 @@ ssd_cache_populate_actions_cuda(
     int64_t prefetch_dist,
     Tensor lru_state,
     bool gather_cache_stats,
-    std::optional<Tensor> ssd_cache_stats) {
+    std::optional<Tensor> ssd_cache_stats,
+    const bool lock_cache_line,
+    const c10::optional<Tensor>& lxu_cache_locking_counter) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
-      linear_indices, lxu_cache_state, lru_state);
+      linear_indices, lxu_cache_state, lru_state, lxu_cache_locking_counter);

   CUDA_DEVICE_GUARD(linear_indices);

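On the host side, the counter arrives as an optional tensor, so existing callers that never lock cache lines are unaffected; when `lock_cache_line` is true, the caller must supply a real int32 tensor shaped like the cache. A hypothetical caller-side helper, sketched under the assumption that `lxu_cache_state` is `[C][kWarpSize]` as the kernel's indexing implies:

```cpp
// Hypothetical helper (not FBGEMM code): allocates the locking counter
// a caller would pass when lock_cache_line is true. The dtype mirrors
// the at::kInt fallback below; at::zeros starts every slot unlocked.
#include <ATen/ATen.h>

at::Tensor make_locking_counter(const at::Tensor& lxu_cache_state) {
  const auto C = lxu_cache_state.size(0);  // number of cache sets
  const auto W = lxu_cache_state.size(1);  // kWarpSize slots per set
  return at::zeros({C, W}, lxu_cache_state.options().dtype(at::kInt));
}
```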
@@ -332,9 +351,17 @@ ssd_cache_populate_actions_cuda(
         /*cache_set_inverse_indices=*/at::empty({0}, int_options));
   }

+  Tensor lxu_cache_locking_counter_;
+  if (lock_cache_line) {
+    TORCH_CHECK(lxu_cache_locking_counter.has_value());
+    lxu_cache_locking_counter_ = lxu_cache_locking_counter.value();
+  } else {
+    lxu_cache_locking_counter_ =
+        at::empty({0, 0}, lxu_cache_state.options().dtype(at::kInt));
+  }
+
   auto actions_count = at::empty({1}, int_options);
   // Find uncached indices
-  Tensor lxu_cache_locking_counter = at::empty({0, 0}, int_options);
   auto
       [sorted_cache_sets,
        cache_set_sorted_unique_indices,
@@ -348,8 +375,8 @@ ssd_cache_populate_actions_cuda(
           lru_state,
           gather_cache_stats,
           ssd_cache_stats_,
-          /*lock_cache_line=*/false,
-          lxu_cache_locking_counter,
+          lock_cache_line,
+          lxu_cache_locking_counter_,
           /*compute_inverse_indices=*/true);

   TORCH_CHECK(cache_set_inverse_indices.has_value());
@@ -373,7 +400,10 @@ ssd_cache_populate_actions_cuda(
         MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32),
         MAKE_PTA_WITH_NAME(func_name, assigned_cache_slots, int32_t, 1, 32),
         MAKE_PTA_WITH_NAME(func_name, evicted_indices, int64_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name, actions_count, int32_t, 1, 32));
+        MAKE_PTA_WITH_NAME(func_name, actions_count, int32_t, 1, 32),
+        lock_cache_line,
+        MAKE_PTA_WITH_NAME(
+            func_name, lxu_cache_locking_counter_, int32_t, 2, 32));

   return std::make_tuple(
       cache_set_sorted_unique_indices,
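Note that the launch always passes a counter accessor, even with locking disabled: the `{0, 0}` placeholder built earlier is never dereferenced because every counter access in the kernel is gated on `lock_cache_line`. An illustrative CUDA sketch of that guard pattern (not the FBGEMM kernel; the real code uses a plain `+= 1` since at most one lane owns a given slot after the sort):

```cuda
// Guard-pattern demo: an empty/unused counter buffer is safe as long as
// every access is predicated on the feature flag.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void lock_slot(bool lock_cache_line, int32_t* counter,
                          int width, int set, int slot) {
  if (lock_cache_line) {
    // Only executed when locking is on, so a disabled launch may pass
    // an empty buffer. atomicAdd is the conservative choice here.
    atomicAdd(&counter[set * width + slot], 1);
  }
}

int main() {
  constexpr int C = 4, W = 32;
  int32_t* counter;
  cudaMalloc(&counter, C * W * sizeof(int32_t));
  cudaMemset(counter, 0, C * W * sizeof(int32_t));
  lock_slot<<<1, 1>>>(/*lock_cache_line=*/true, counter, W, /*set=*/2,
                      /*slot=*/7);
  int32_t host[C * W];
  cudaMemcpy(host, counter, sizeof(host), cudaMemcpyDeviceToHost);
  std::printf("counter[2][7] = %d\n", host[2 * W + 7]);  // prints 1
  cudaFree(counter);
  return 0;
}
```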