Commit 19990d8

sryap authored and facebook-github-bot committed
Enable cache line locking support in SSD kernel (#2949)
Summary:
X-link: facebookresearch/FBGEMM#51
Pull Request resolved: #2949

This diff enables cache line locking in `ssd_cache_actions_insert_kernel`. To support the locking logic, we updated the cache slot cost computation within each cache set.

Before this diff, cache slots were ranked by their costs (i.e., their timestamps, retrieved from the cache state), and the slots with the lowest costs (oldest timestamps) were evicted first. However, when cache line locking is enabled, the slot with the lowest timestamp cannot be used if it is locked. Therefore, we assign unavailable cache slots (those that are locked or already occupied by another index in the same batch) the highest possible cost, which is the current timestamp. A slot whose cost equals the current timestamp is never evicted.

Reviewed By: ehsanardestani

Differential Revision: D60812956

fbshipit-source-id: 77e959ecdb260a76fe6ee0457508c241fad1376b
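The cost rule described above fits in a few lines. Below is a minimal standalone C++ sketch of it; the helper name slot_cost is hypothetical, and the real computation lives inside ssd_cache_actions_insert_kernel in the diff below.

#include <cstdint>

// Hypothetical helper illustrating the slot-cost rule. A slot is
// "unavailable" if it is locked, or if it already holds a row inserted in
// the current batch (its LRU timestamp equals the current timestamp).
// Unavailable slots get the maximum possible cost -- the current
// timestamp -- so they sink to the bottom of the sort, and a slot whose
// cost equals the current timestamp is never chosen for eviction.
int64_t slot_cost(
    bool is_slot_locked, // lock_cache_line && locking counter > 0
    int64_t slot_idx, // index stored in the slot; -1 if empty
    int64_t slot_time, // timestamp from lru_state
    int64_t time_stamp) { // current timestamp
  const bool slot_has_idx = slot_idx != -1 && slot_time == time_stamp;
  const bool is_slot_unavailable = is_slot_locked || slot_has_idx;
  return is_slot_unavailable ? time_stamp : slot_time;
}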
1 parent 440cbb0 commit 19990d8

File tree

2 files changed: +66, -32 lines


fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu

Lines changed: 60 additions & 30 deletions
@@ -166,6 +166,9 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
         evicted_indices,
     pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         actions_count,
+    const bool lock_cache_line,
+    pta::PackedTensorAccessor32<int32_t, 2, at::RestrictPtrTraits>
+        lxu_cache_locking_counter,
     TORCH_DSA_KERNEL_ARGS) {
   // Number of cache sets
   const int32_t C = lxu_cache_state.size(0);
@@ -216,51 +219,65 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
       SL += 1;
     }

-    // now, we need to insert the (unique!) values in indices[n:n + SL] into
+    // Now, we need to insert the (unique!) values in indices[n:n + SL] into
     // our slots.
     const int32_t slot = threadIdx.x;
     const int64_t slot_time = lru_state[cache_set][slot];
-    int64_t costs[1] = {slot_time};
+
+    // Check if the slot is locked
+    const bool is_slot_locked =
+        lock_cache_line && (lxu_cache_locking_counter[cache_set][slot] > 0);
+    // Check if the slot has the inserted row that was a cache hit.
+    const int64_t slot_idx = lxu_cache_state[cache_set][slot];
+    const bool slot_has_idx = slot_idx != -1 && slot_time == time_stamp;
+    // Check if the slot is unavailable: either it is locked or contains
+    // a cache hit inserted row
+    const bool is_slot_unavailable = is_slot_locked || slot_has_idx;
+
+    // Set the slot cost: if the slot is not available, set it to the
+    // maximum timestamp, which is the current timestamp. After sorting,
+    // the unavailable slots will be at the bottom, while the available
+    // slots will be bubbled to the top
+    const int64_t slot_cost = is_slot_unavailable ? time_stamp : slot_time;
+
+    // Prepare key-value pair for sorting
+    int64_t costs[1] = {slot_cost};
     int32_t slots[1] = {slot};

+    // Sort the slots based on their costs
     BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-    const int32_t sorted_slot = slots[0];
-    const int64_t sorted_time = costs[0];
+
+    // Get the sorted results
+    const int32_t insert_slot = slots[0];
+    const int64_t insert_cost = costs[0];

     auto l = threadIdx.x;

+    // Get the current index
+    const int64_t current_idx = shfl_sync(slot_idx, insert_slot);
+
     // Insert rows
     if (l < SL) {
       // Insert indices
-      const int32_t insert_slot = sorted_slot;
-      const int64_t insert_time = sorted_time;
-
       const int64_t insert_idx = cache_set_sorted_indices[n + l];
-      const int64_t current_idx = lxu_cache_state[cache_set][insert_slot];
-
-#if 0
-      // TODO: Check whether to uncomment this
-      // Only check insert_time if tag is for valid entry
-      if (current_idx != -1) {
-        // We need to ensure if prefetching (prefetch_dist) batches ahead
-        // No entries that are younger than (time_stamp - prefetch_dist) are
-        // evicted from the cache. This will break the guarantees required
-        // for the SSD embedding.
-        // If you hit this assert, increase the cache size.
-        CUDA_KERNEL_ASSERT2(insert_time < (time_stamp - prefetch_dist));
-      }
-#endif

-      if (current_idx != -1 && insert_time == time_stamp) {
-        // Skip this slot as the inserted row was a cache hit
-        // This is conflict miss
+      if (insert_cost == time_stamp) {
+        // Skip this slot as it is not available
         evicted_indices[n + l] = -1;
         assigned_cache_slots[n + l] = -1;
       } else {
         evicted_indices[n + l] = current_idx; // -1 if not set, >= 0 if valid.
         assigned_cache_slots[n + l] = cache_set * kWarpSize + insert_slot;
+
+        // TODO: Check if we can do contiguous writes here.
+        // Update cache states
         lxu_cache_state[cache_set][insert_slot] = insert_idx;
         lru_state[cache_set][insert_slot] = time_stamp;
+
+        // Lock cache line
+        if (lock_cache_line) {
+          lxu_cache_locking_counter[cache_set][insert_slot] += 1;
+        }
       }
     }
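One detail worth noting in the hunk above: current_idx is now computed with shfl_sync(slot_idx, insert_slot) before the if (l < SL) branch, so every lane of the warp participates in the shuffle and each lane reads the winning slot's stored index from another lane's register rather than reloading lxu_cache_state. A standalone CUDA sketch of the underlying warp-shuffle broadcast, using the raw __shfl_sync intrinsic rather than FBGEMM's shfl_sync wrapper:

#include <cstdio>
#include <cuda_runtime.h>

// Every lane holds its own value; each lane then reads lane `src`'s value
// through a register-to-register warp shuffle (no shared/global memory).
__global__ void broadcast_from_lane(int src, int* out) {
  const int lane = threadIdx.x & 31;
  const int my_val = lane * 10; // per-lane value standing in for slot_idx
  // Full-warp mask: all 32 lanes participate and receive lane src's value.
  out[lane] = __shfl_sync(0xffffffffu, my_val, src);
}

int main() {
  int* out;
  cudaMallocManaged(&out, 32 * sizeof(int));
  broadcast_from_lane<<<1, 32>>>(/*src=*/5, out);
  cudaDeviceSynchronize();
  printf("lane 0 read %d (lane 5's value)\n", out[0]); // prints 50
  cudaFree(out);
  return 0;
}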
@@ -280,9 +297,11 @@ ssd_cache_populate_actions_cuda(
     int64_t prefetch_dist,
     Tensor lru_state,
     bool gather_cache_stats,
-    std::optional<Tensor> ssd_cache_stats) {
+    std::optional<Tensor> ssd_cache_stats,
+    const bool lock_cache_line,
+    const c10::optional<Tensor>& lxu_cache_locking_counter) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
-      linear_indices, lxu_cache_state, lru_state);
+      linear_indices, lxu_cache_state, lru_state, lxu_cache_locking_counter);

   CUDA_DEVICE_GUARD(linear_indices);
@@ -332,9 +351,17 @@ ssd_cache_populate_actions_cuda(
         /*cache_set_inverse_indices=*/at::empty({0}, int_options));
   }

+  Tensor lxu_cache_locking_counter_;
+  if (lock_cache_line) {
+    TORCH_CHECK(lxu_cache_locking_counter.has_value());
+    lxu_cache_locking_counter_ = lxu_cache_locking_counter.value();
+  } else {
+    lxu_cache_locking_counter_ =
+        at::empty({0, 0}, lxu_cache_state.options().dtype(at::kInt));
+  }
+
   auto actions_count = at::empty({1}, int_options);
   // Find uncached indices
-  Tensor lxu_cache_locking_counter = at::empty({0, 0}, int_options);
   auto
       [sorted_cache_sets,
        cache_set_sorted_unique_indices,
@@ -348,8 +375,8 @@ ssd_cache_populate_actions_cuda(
           lru_state,
           gather_cache_stats,
           ssd_cache_stats_,
-          /*lock_cache_line=*/false,
-          lxu_cache_locking_counter,
+          lock_cache_line,
+          lxu_cache_locking_counter_,
           /*compute_inverse_indices=*/true);

   TORCH_CHECK(cache_set_inverse_indices.has_value());
@@ -373,7 +400,10 @@ ssd_cache_populate_actions_cuda(
         MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32),
         MAKE_PTA_WITH_NAME(func_name, assigned_cache_slots, int32_t, 1, 32),
         MAKE_PTA_WITH_NAME(func_name, evicted_indices, int64_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name, actions_count, int32_t, 1, 32));
+        MAKE_PTA_WITH_NAME(func_name, actions_count, int32_t, 1, 32),
+        lock_cache_line,
+        MAKE_PTA_WITH_NAME(
+            func_name, lxu_cache_locking_counter_, int32_t, 2, 32));

   return std::make_tuple(
       cache_set_sorted_unique_indices,

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 6 additions & 2 deletions
@@ -26,7 +26,9 @@ ssd_cache_populate_actions_cuda(
     int64_t prefetch_dist,
     Tensor lru_state,
     bool gather_cache_stats,
-    std::optional<Tensor> ssd_cache_stats);
+    std::optional<Tensor> ssd_cache_stats,
+    const bool lock_cache_line,
+    const c10::optional<Tensor>& lxu_cache_locking_counter);

 /// @ingroup embedding-ssd
 ///
@@ -298,7 +300,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "    int prefetch_dist, "
       "    Tensor lru_state, "
       "    bool gather_cache_stats=False, "
-      "    Tensor? ssd_cache_stats=None"
+      "    Tensor? ssd_cache_stats=None, "
+      "    bool lock_cache_line=False, "
+      "    Tensor? lxu_cache_locking_counter=None"
       ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
   DISPATCH_TO_CUDA(
       "ssd_cache_populate_actions", ssd_cache_populate_actions_cuda);
