
Commit 0184cb1

sryap authored and facebook-github-bot committed
Add ssd_update_row_addrs
Summary: When pipeline prefetching is enabled, data in a scratch pad of the current iteration can be moved to L1 or a scratch pad of the next iteration during the prefetch step. `ssd_update_row_addrs` updates the memory addresses of the relocated data so that they point to the correct location.

Differential Revision: D60983150
1 parent bbaead1 commit 0184cb1
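
The relocation described in the summary can be pictured with a small CPU-side sketch (hypothetical names and containers only; the real operator runs as a CUDA kernel over packed tensor accessors, as in the diff below). For each unique row of the current iteration whose data was moved during prefetch, the row's address is rewritten to point either into the L1 cache or into the next iteration's scratch pad, and the new address is propagated to every duplicate position of that row.

// Hypothetical CPU-side sketch of the relocation bookkeeping; not part of the commit.
#include <cstdint>
#include <vector>

void update_row_addrs_sketch(
    std::vector<uint64_t>& row_addrs_curr,      // address per lookup position (current iteration)
    const std::vector<int32_t>& curr_next_map,  // unique row -> slot in next scratch pad, -1 if not moved
    const std::vector<int32_t>& cache_loc_curr, // unique row -> L1 slot, -1 if not in L1
    const std::vector<std::vector<int32_t>>& dup_positions, // unique row -> positions of its duplicates
    const uint64_t l1_base_addr,
    const uint64_t next_scratch_pad_base_addr,
    const uint64_t row_bytes) {
  for (size_t n = 0; n < curr_next_map.size(); ++n) {
    const int32_t n_next = curr_next_map[n];
    if (n_next < 0) {
      continue; // data was not moved; the current address is still valid
    }
    const uint64_t new_addr = (cache_loc_curr[n] == -1)
        // moved into the next iteration's scratch pad
        ? next_scratch_pad_base_addr + uint64_t(n_next) * row_bytes
        // moved into the L1 cache
        : l1_base_addr + uint64_t(cache_loc_curr[n]) * row_bytes;
    for (const int32_t pos : dup_positions[n]) {
      row_addrs_curr[pos] = new_addr; // patch every duplicate of this row
    }
  }
}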

File tree

2 files changed: +171 -0 lines changed


fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu

Lines changed: 108 additions & 0 deletions
@@ -538,3 +538,111 @@ std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(
  return {ssd_row_addrs, post_bwd_evicted_indices};
}

__global__ __launch_bounds__(kMaxThreads) void ssd_update_row_addrs_kernel(
    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
        ssd_row_addrs_curr,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        ssd_curr_next_map,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        lxu_cache_locations_curr,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        linear_index_inverse_indices_curr,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        unique_indices_count_cumsum_curr,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        cache_set_inverse_indices_curr,
    const uint64_t lxu_cache_weights_addr,
    const uint64_t inserted_ssd_weights_addr_next,
    const int* N_unique_curr,
    const uint64_t cache_row_bytes // has to be 64 bits to prevent overflow
) {
  const auto n_curr = blockDim.y * blockIdx.x + threadIdx.y;
  if (n_curr >= *N_unique_curr) {
    return;
  }

  // Find the mapping between n_curr and n_next
  const auto n_next = ssd_curr_next_map[n_curr];
  // Return if the row is not used in both the current and next iterations
  if (n_next < 0) {
    return;
  }

  // Find out if the row gets moved to the next iteration's scratch pad or
  // L1 by checking lxu_cache_locations_curr
  const auto cache_set_id_curr = cache_set_inverse_indices_curr[n_curr];
  const auto segment_start_curr =
      unique_indices_count_cumsum_curr[cache_set_id_curr];
  const auto segment_end_curr =
      unique_indices_count_cumsum_curr[cache_set_id_curr + 1];
  const auto cache_loc_curr = lxu_cache_locations_curr
      [linear_index_inverse_indices_curr[segment_start_curr]];

  const uint64_t ptr_addr = (cache_loc_curr == -1)
      // The row is moved from the current iteration's scratch pad to the
      // next iteration's scratch pad
      ? (inserted_ssd_weights_addr_next + (n_next * cache_row_bytes))
      // The row is moved from the current iteration's scratch pad to the L1 cache
      : (lxu_cache_weights_addr + (cache_loc_curr * cache_row_bytes));

  // Set the pointer address
  for (auto l = segment_start_curr + threadIdx.x; l < segment_end_curr;
       l += blockDim.x) {
    auto dst = linear_index_inverse_indices_curr[l];
    *reinterpret_cast<uint64_t*>(&ssd_row_addrs_curr[dst]) = ptr_addr;
  }
}

void ssd_update_row_addrs_cuda(
    const Tensor& ssd_row_addrs_curr,
    const Tensor& ssd_curr_next_map,
    const Tensor& lxu_cache_locations_curr,
    const Tensor& linear_index_inverse_indices_curr,
    const Tensor& unique_indices_count_cumsum_curr,
    const Tensor& cache_set_inverse_indices_curr,
    const Tensor& lxu_cache_weights,
    const Tensor& inserted_ssd_weights_next,
    const Tensor& unique_indices_length_curr) {
  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
      ssd_row_addrs_curr,
      ssd_curr_next_map,
      lxu_cache_locations_curr,
      linear_index_inverse_indices_curr,
      unique_indices_count_cumsum_curr,
      cache_set_inverse_indices_curr,
      lxu_cache_weights,
      inserted_ssd_weights_next,
      unique_indices_length_curr);

  CUDA_DEVICE_GUARD(ssd_row_addrs_curr);

  const auto lxu_cache_weights_addr =
      reinterpret_cast<uint64_t>(lxu_cache_weights.data_ptr());
  const auto inserted_ssd_weights_addr_next =
      reinterpret_cast<uint64_t>(inserted_ssd_weights_next.data_ptr());
  const auto cache_row_bytes =
      lxu_cache_weights.size(1) * lxu_cache_weights.element_size();
  constexpr auto kNumWarps = kMaxThreads / kWarpSize;

  ssd_update_row_addrs_kernel<<<
      div_round_up(ssd_row_addrs_curr.numel(), kNumWarps),
      dim3(kWarpSize, kNumWarps),
      0,
      at::cuda::getCurrentCUDAStream()>>>(
      ssd_row_addrs_curr.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
      ssd_curr_next_map.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
      lxu_cache_locations_curr
          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
      linear_index_inverse_indices_curr
          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
      unique_indices_count_cumsum_curr
          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
      cache_set_inverse_indices_curr
          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
      lxu_cache_weights_addr,
      inserted_ssd_weights_addr_next,
      unique_indices_length_curr.data_ptr<int32_t>(),
      cache_row_bytes);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 63 additions & 0 deletions
@@ -148,6 +148,56 @@ std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(
    const Tensor& unique_indices_length,
    const Tensor& cache_set_sorted_unique_indices);

/// @ingroup embedding-ssd
///
/// @brief Update memory addresses for SSD TBE data
///
/// When pipeline prefetching is enabled, data in a scratch pad of the
/// current iteration can be moved to L1 or a scratch pad of the next
/// iteration during the prefetch step. This operator updates the
/// memory addresses of data that is relocated to the correct
/// location.
///
/// @param ssd_row_addrs_curr The tensor that contains the row
///            addresses for the current iteration
/// @param inserted_ssd_weights_curr_next_map The tensor that maps
///            each index in the current iteration to its location in
///            the scratch pad of the next iteration (-1 = the data
///            has not been moved).
///            inserted_ssd_weights_curr_next_map[i] is the location
///            of index i in the next iteration's scratch pad.
/// @param lxu_cache_locations_curr The tensor that contains cache
///            slots where data is stored for the *full* list of
///            indices for the current iteration. -1 is a sentinel
///            value that indicates that data is not in cache.
/// @param linear_index_inverse_indices_curr The tensor that contains
///            the original position of linear indices before being
///            sorted for the current iteration
/// @param unique_indices_count_cumsum_curr The tensor that contains
///            the exclusive prefix sum results of the counts of
///            unique indices for the current iteration
/// @param cache_set_inverse_indices_curr The tensor that contains the
///            original positions of cache sets before being sorted
///            for the current iteration
/// @param lxu_cache_weights The LXU cache tensor
/// @param inserted_ssd_weights_next The scratch pad tensor for the
///            next iteration
/// @param unique_indices_length_curr The tensor that contains the
///            number of unique indices (GPU tensor) for the current
///            iteration
///
/// @return None
void ssd_update_row_addrs_cuda(
    const Tensor& ssd_row_addrs_curr,
    const Tensor& inserted_ssd_weights_curr_next_map,
    const Tensor& lxu_cache_locations_curr,
    const Tensor& linear_index_inverse_indices_curr,
    const Tensor& unique_indices_count_cumsum_curr,
    const Tensor& cache_set_inverse_indices_curr,
    const Tensor& lxu_cache_weights,
    const Tensor& inserted_ssd_weights_next,
    const Tensor& unique_indices_length_curr);

namespace {
class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
 public:
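
To make the dtype expectations in the parameter list above concrete, here is a minimal, hypothetical smoke-test sketch that calls the function directly with trivially small CUDA tensors. The unique-index count is set to zero so the kernel exits immediately; all shapes and values are placeholders.

// Sketch only; assumes the declaration of ssd_update_row_addrs_cuda above is visible.
#include <ATen/ATen.h>

void ssd_update_row_addrs_smoke_test() {
  const auto cuda_long = at::TensorOptions().device(at::kCUDA).dtype(at::kLong);
  const auto cuda_int = at::TensorOptions().device(at::kCUDA).dtype(at::kInt);
  const auto cuda_float = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);

  // int64 row addresses for the current iteration (one placeholder slot)
  auto ssd_row_addrs_curr = at::zeros({1}, cuda_long);
  // -1 = data has not been moved
  auto inserted_ssd_weights_curr_next_map = at::full({1}, -1, cuda_int);
  // -1 = not in the L1 cache
  auto lxu_cache_locations_curr = at::full({1}, -1, cuda_int);
  auto linear_index_inverse_indices_curr = at::zeros({1}, cuda_int);
  auto unique_indices_count_cumsum_curr = at::zeros({2}, cuda_int);
  auto cache_set_inverse_indices_curr = at::zeros({1}, cuda_int);
  // 2-D weights: row count x embedding dimension (placeholder sizes)
  auto lxu_cache_weights = at::zeros({4, 8}, cuda_float);
  auto inserted_ssd_weights_next = at::zeros({4, 8}, cuda_float);
  // zero unique indices, so every kernel thread returns early
  auto unique_indices_length_curr = at::zeros({1}, cuda_int);

  ssd_update_row_addrs_cuda(
      ssd_row_addrs_curr,
      inserted_ssd_weights_curr_next_map,
      lxu_cache_locations_curr,
      linear_index_inverse_indices_curr,
      unique_indices_count_cumsum_curr,
      cache_set_inverse_indices_curr,
      lxu_cache_weights,
      inserted_ssd_weights_next,
      unique_indices_length_curr);
}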
@@ -315,5 +365,18 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
      " Tensor cache_set_sorted_unique_indices"
      ") -> (Tensor, Tensor)");
  DISPATCH_TO_CUDA("ssd_generate_row_addrs", ssd_generate_row_addrs_cuda);
  m.def(
      "ssd_update_row_addrs("
      " Tensor ssd_row_addrs_curr, "
      " Tensor inserted_ssd_weights_curr_next_map, "
      " Tensor lxu_cache_locations_curr, "
      " Tensor linear_index_inverse_indices_curr, "
      " Tensor unique_indices_count_cumsum_curr, "
      " Tensor cache_set_inverse_indices_curr, "
      " Tensor lxu_cache_weights, "
      " Tensor inserted_ssd_weights_next, "
      " Tensor unique_indices_length_curr"
      ") -> ()");
  DISPATCH_TO_CUDA("ssd_update_row_addrs", ssd_update_row_addrs_cuda);
}
} // namespace
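
For completeness, a sketch of invoking the op registered above through the PyTorch dispatcher from C++, assuming the fbgemm_gpu extension library has been loaded so that this TORCH_LIBRARY_FRAGMENT registration has run. The tensors t0 through t8 are placeholders for the nine inputs in the schema.

// Sketch only: looks up the "fbgemm::ssd_update_row_addrs" schema registered
// above and calls it with placeholder tensors.
#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>

void call_ssd_update_row_addrs(
    at::Tensor t0, at::Tensor t1, at::Tensor t2,
    at::Tensor t3, at::Tensor t4, at::Tensor t5,
    at::Tensor t6, at::Tensor t7, at::Tensor t8) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::ssd_update_row_addrs", "")
          .typed<void(
              at::Tensor, at::Tensor, at::Tensor,
              at::Tensor, at::Tensor, at::Tensor,
              at::Tensor, at::Tensor, at::Tensor)>();
  op.call(t0, t1, t2, t3, t4, t5, t6, t7, t8);
}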
