diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 672a2732ee..41c99c9675 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -341,17 +341,22 @@ def __init__( self.ssd_event_evict = torch.cuda.Event() # SSD backward completion event self.ssd_event_backward = torch.cuda.Event() - # SSD scratch pad eviction completion event - self.ssd_event_evict_sp = torch.cuda.Event() # SSD get's input copy completion event self.ssd_event_get_inputs_cpy = torch.cuda.Event() + # SSD scratch pad index queue insert completion event + self.ssd_event_sp_idxq_insert = torch.cuda.Event() self.timesteps_prefetched: List[int] = [] - self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = [] # TODO: add type annotation # pyre-fixme[4]: Attribute must be annotated. self.ssd_prefetch_data = [] + # Scratch pad value queue + self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = [] + # pyre-ignore[4] + # Scratch pad index queue + self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue(-1) + if weight_decay_mode == WeightDecayMode.COUNTER or counter_based_regularization: raise AssertionError( "weight_decay_mode = WeightDecayMode.COUNTER is not supported for SSD TBE." @@ -427,10 +432,6 @@ def __init__( torch.zeros(0, device=self.current_device, dtype=torch.float) ) - # Register backward hook for evicting rows from a scratch pad to SSD - # post backward - self.placeholder_autograd_tensor.register_hook(self._evict_from_scratch_pad) - assert optimizer in ( OptimType.EXACT_ROWWISE_ADAGRAD, ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags" @@ -578,8 +579,8 @@ def evict( indices_cpu: Tensor, actions_count_cpu: Tensor, stream: torch.cuda.Stream, - pre_event: torch.cuda.Event, - post_event: torch.cuda.Event, + pre_event: Optional[torch.cuda.Event], + post_event: Optional[torch.cuda.Event], is_rows_uvm: bool, name: Optional[str] = "", ) -> None: @@ -607,7 +608,8 @@ def evict( """ with record_function(f"## ssd_evict_{name} ##"): with torch.cuda.stream(stream): - stream.wait_event(pre_event) + if pre_event is not None: + stream.wait_event(pre_event) rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows) @@ -622,13 +624,17 @@ def evict( self.timestep, ) - # TODO: is this needed? 
- # Need a way to synchronize - # actions_count_cpu.record_stream(self.ssd_eviction_stream) - stream.record_event(post_event) + if post_event is not None: + stream.record_event(post_event) + + def _evict_from_scratch_pad(self, return_on_empty: bool) -> None: + scratch_pad_len = len(self.ssd_scratch_pads) + + if not return_on_empty: + assert scratch_pad_len > 0, "There must be at least one scratch pad" + elif scratch_pad_len == 0: + return - def _evict_from_scratch_pad(self, grad: Tensor) -> None: - assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad" (inserted_rows, post_bwd_evicted_indices_cpu, actions_count_cpu, do_evict) = ( self.ssd_scratch_pads.pop(0) ) @@ -640,7 +646,7 @@ def _evict_from_scratch_pad(self, grad: Tensor) -> None: actions_count_cpu=actions_count_cpu, stream=self.ssd_eviction_stream, pre_event=self.ssd_event_backward, - post_event=self.ssd_event_evict_sp, + post_event=None, is_rows_uvm=True, name="scratch_pad", ) @@ -683,6 +689,20 @@ def _compute_cache_ptrs( ) ) + # Insert conflict miss indices in the index queue for future lookup + # post_bwd_evicted_indices_cpu is transferred on the ssd_eviction_stream stream + # actions_count_cpu is transferred on the ssd_memcpy_stream stream + with torch.cuda.stream(self.ssd_eviction_stream): + # Ensure that actions_count_cpu transfer is done + self.ssd_eviction_stream.wait_event(self.ssd_event_get_inputs_cpy) + self.record_function_via_dummy_profile( + "## ssd_scratch_pad_idx_queue_insert ##", + self.scratch_pad_idx_queue.insert_cuda, + post_bwd_evicted_indices_cpu, + actions_count_cpu, + ) + self.ssd_eviction_stream.record_event(self.ssd_event_sp_idxq_insert) + with record_function("## ssd_scratch_pads ##"): # Store scratch pad info for post backward eviction self.ssd_scratch_pads.append( @@ -778,12 +798,76 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: ) current_stream = torch.cuda.current_stream() - - inserted_indices_cpu = self.to_pinned_cpu(inserted_indices) + if len(self.ssd_scratch_pads) > 0: + with record_function("## ssd_lookup_scratch_pad ##"): + current_stream.wait_event(self.ssd_event_sp_idxq_insert) + current_stream.wait_event(self.ssd_event_get_inputs_cpy) + + ( + inserted_rows_prev, + post_bwd_evicted_indices_cpu_prev, + actions_count_cpu_prev, + do_evict_prev, + ) = self.ssd_scratch_pads.pop(0) + + # Inserted indices that are found in the scratch pad + # from the previous iteration + sp_locations_cpu = torch.empty( + inserted_indices_cpu.shape, + dtype=inserted_indices_cpu.dtype, + pin_memory=True, + ) + + # Before entering this function: inserted_indices_cpu + # contains all linear indices that are missed from the + # L1 cache + # + # After this function: inserted indices that are found + # in the scratch pad from the previous iteration are + # stored in sp_locations_cpu, while the rests are + # stored in inserted_indices_cpu + # + # An invalid index is -1 or its position > + # actions_count_cpu + self.record_function_via_dummy_profile( + "## ssd_lookup_mask_and_pop_front ##", + self.scratch_pad_idx_queue.lookup_mask_and_pop_front_cuda, + sp_locations_cpu, + post_bwd_evicted_indices_cpu_prev, + inserted_indices_cpu, + actions_count_cpu, + ) + + # Transfer sp_locations_cpu to GPU + sp_locations_gpu = sp_locations_cpu.cuda(non_blocking=True) + + # Copy data from the previous iteration's scratch pad to + # the current iteration's scratch pad + torch.ops.fbgemm.masked_index_select( + inserted_rows, + sp_locations_gpu, + inserted_rows_prev, + actions_count_gpu, + ) 
+ + # Evict from scratch pad + if do_evict_prev: + torch.cuda.current_stream().record_event( + self.ssd_event_backward + ) + self.evict( + rows=inserted_rows_prev, + indices_cpu=post_bwd_evicted_indices_cpu_prev, + actions_count_cpu=actions_count_cpu_prev, + stream=self.ssd_eviction_stream, + pre_event=self.ssd_event_backward, + post_event=None, + is_rows_uvm=True, + name="scratch_pad", + ) # Ensure the previous iterations l3_db.set(..) has completed. current_stream.wait_event(self.ssd_event_evict) - current_stream.wait_event(self.ssd_event_evict_sp) current_stream.wait_event(self.ssd_event_get_inputs_cpy) if linear_cache_indices.numel() > 0: @@ -1030,6 +1114,9 @@ def flush(self) -> None: active_slots_mask_cpu.view(-1) ) + # Evict data from scratch pad if there is scratch pad in the queue + self._evict_from_scratch_pad(return_on_empty=True) + torch.cuda.current_stream().wait_stream(self.ssd_eviction_stream) self.ssd_db.set_cuda( diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu index ec11e86e8b..10189fcf74 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu @@ -166,6 +166,9 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel( evicted_indices, pta::PackedTensorAccessor32 actions_count, + const bool lock_cache_line, + pta::PackedTensorAccessor32 + lxu_cache_locking_counter, TORCH_DSA_KERNEL_ARGS) { // Number of cache sets const int32_t C = lxu_cache_state.size(0); @@ -216,51 +219,65 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel( SL += 1; } - // now, we need to insert the (unique!) values in indices[n:n + SL] into + // Now, we need to insert the (unique!) values in indices[n:n + SL] into // our slots. const int32_t slot = threadIdx.x; const int64_t slot_time = lru_state[cache_set][slot]; - int64_t costs[1] = {slot_time}; + + // Check if the slot is locked + const bool is_slot_locked = + lock_cache_line && (lxu_cache_locking_counter[cache_set][slot] > 0); + // Check if the slot has the inserted row that was a cache hit. + const int64_t slot_idx = lxu_cache_state[cache_set][slot]; + const bool slot_has_idx = slot_idx != -1 && slot_time == time_stamp; + // Check if the slot is unavailable: either it is locked or contains + // a cache hit inserted row + const bool is_slot_unavailable = is_slot_locked || slot_has_idx; + + // Set the slot cost: if the slot is not available, set it to the + // maximum timestamp which is the current timestamp. After sorting, + // the unavailable slots will be in the bottom, while the available + // slots will be bubbled to the top + const int64_t slot_cost = is_slot_unavailable ? 
time_stamp : slot_time; + + // Prepare key-value pair for sorting + int64_t costs[1] = {slot_cost}; int32_t slots[1] = {slot}; + // Sort the slots based on their costs BitonicSort>::sort(costs, slots); - const int32_t sorted_slot = slots[0]; - const int64_t sorted_time = costs[0]; + + // Get the sorted results + const int32_t insert_slot = slots[0]; + const int64_t insert_cost = costs[0]; auto l = threadIdx.x; + // Get the current index + const int64_t current_idx = shfl_sync(slot_idx, insert_slot); + // Insert rows if (l < SL) { // Insert indices - const int32_t insert_slot = sorted_slot; - const int64_t insert_time = sorted_time; - const int64_t insert_idx = cache_set_sorted_indices[n + l]; - const int64_t current_idx = lxu_cache_state[cache_set][insert_slot]; - -#if 0 - // TODO: Check whether to uncomment this - // Only check insert_time if tag is for valid entry - if (current_idx != -1) { - // We need to ensure if prefetching (prefetch_dist) batches ahead - // No entries that are younger than (time_stamp - prefetch_dist) are - // evicted from the cache. This will break the guarantees required - // for the SSD embedding. - // If you hit this assert, increase the cache size. - CUDA_KERNEL_ASSERT2(insert_time < (time_stamp - prefetch_dist)); - } -#endif - if (current_idx != -1 && insert_time == time_stamp) { - // Skip this slot as the inserted row was a cache hit - // This is conflict miss + if (insert_cost == time_stamp) { + // Skip this slot as it is not available evicted_indices[n + l] = -1; assigned_cache_slots[n + l] = -1; } else { evicted_indices[n + l] = current_idx; // -1 if not set, >= 0 if valid. assigned_cache_slots[n + l] = cache_set * kWarpSize + insert_slot; + + // TODO: Check if we can do contiguous writes here. + // Update cache states lxu_cache_state[cache_set][insert_slot] = insert_idx; lru_state[cache_set][insert_slot] = time_stamp; + + // Lock cache line + if (lock_cache_line) { + lxu_cache_locking_counter[cache_set][insert_slot] += 1; + } } } @@ -280,9 +297,11 @@ ssd_cache_populate_actions_cuda( int64_t prefetch_dist, Tensor lru_state, bool gather_cache_stats, - std::optional ssd_cache_stats) { + std::optional ssd_cache_stats, + const bool lock_cache_line, + const c10::optional& lxu_cache_locking_counter) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( - linear_indices, lxu_cache_state, lru_state); + linear_indices, lxu_cache_state, lru_state, lxu_cache_locking_counter); CUDA_DEVICE_GUARD(linear_indices); @@ -332,9 +351,17 @@ ssd_cache_populate_actions_cuda( /*cache_set_inverse_indices=*/at::empty({0}, int_options)); } + Tensor lxu_cache_locking_counter_; + if (lock_cache_line) { + TORCH_CHECK(lxu_cache_locking_counter.has_value()); + lxu_cache_locking_counter_ = lxu_cache_locking_counter.value(); + } else { + lxu_cache_locking_counter_ = + at::empty({0, 0}, lxu_cache_state.options().dtype(at::kInt)); + } + auto actions_count = at::empty({1}, int_options); // Find uncached indices - Tensor lxu_cache_locking_counter = at::empty({0, 0}, int_options); auto [sorted_cache_sets, cache_set_sorted_unique_indices, @@ -348,8 +375,8 @@ ssd_cache_populate_actions_cuda( lru_state, gather_cache_stats, ssd_cache_stats_, - /*lock_cache_line=*/false, - lxu_cache_locking_counter, + lock_cache_line, + lxu_cache_locking_counter_, /*compute_inverse_indices=*/true); TORCH_CHECK(cache_set_inverse_indices.has_value()); @@ -373,7 +400,10 @@ ssd_cache_populate_actions_cuda( MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32), MAKE_PTA_WITH_NAME(func_name, assigned_cache_slots, int32_t, 1, 
          32),
          MAKE_PTA_WITH_NAME(func_name, evicted_indices, int64_t, 1, 32),
-         MAKE_PTA_WITH_NAME(func_name, actions_count, int32_t, 1, 32));
+         MAKE_PTA_WITH_NAME(func_name, actions_count, int32_t, 1, 32),
+         lock_cache_line,
+         MAKE_PTA_WITH_NAME(
+             func_name, lxu_cache_locking_counter_, int32_t, 2, 32));
 
   return std::make_tuple(
       cache_set_sorted_unique_indices,
@@ -508,3 +538,112 @@ std::tuple ssd_generate_row_addrs_cuda(
 
   return {ssd_row_addrs, post_bwd_evicted_indices};
 }
+
+__global__ __launch_bounds__(kMaxThreads) void ssd_update_row_addrs_kernel(
+    at::PackedTensorAccessor32
+        ssd_row_addrs_curr,
+    const at::PackedTensorAccessor32
+        ssd_curr_next_map,
+    const at::PackedTensorAccessor32
+        lxu_cache_locations_curr,
+    const at::PackedTensorAccessor32
+        linear_index_inverse_indices_curr,
+    const at::PackedTensorAccessor32
+        unique_indices_count_cumsum_curr,
+    const at::PackedTensorAccessor32
+        cache_set_inverse_indices_curr,
+    const uint64_t lxu_cache_weights_addr,
+    const uint64_t inserted_ssd_weights_addr_next,
+    const int* N_unique_curr,
+    const uint64_t cache_row_bytes // has to be 64 bits to prevent overflow
+) {
+  const auto n_curr = blockDim.y * blockIdx.x + threadIdx.y;
+  if (n_curr >= *N_unique_curr) {
+    return;
+  }
+
+  // Find mapping between n_curr and n_next
+  const auto n_next = ssd_curr_next_map[n_curr];
+
+  // Return if the row is not reused in the next iteration
+  if (n_next < 0) {
+    return;
+  }
+
+  // Find out if the row gets moved to the next iteration's scratch pad or
+  // L1 by checking the lxu_cache_locations_curr
+  const auto cache_set_id_curr = cache_set_inverse_indices_curr[n_curr];
+  const auto segment_start_curr =
+      unique_indices_count_cumsum_curr[cache_set_id_curr];
+  const auto segment_end_curr =
+      unique_indices_count_cumsum_curr[cache_set_id_curr + 1];
+  const auto cache_loc_curr = lxu_cache_locations_curr
+      [linear_index_inverse_indices_curr[segment_start_curr]];
+
+  const uint64_t ptr_addr = (cache_loc_curr == -1)
+      // The row is moved from the previous iteration's scratch pad to the
+      // next iteration's scratch pad
+      ?
(inserted_ssd_weights_addr_next + (n_next * cache_row_bytes)) + // The row is moved from the previous iteration's scratch pad to L1 cache + : (lxu_cache_weights_addr + (cache_loc_curr * cache_row_bytes)); + + // Set pointer address + for (auto l = segment_start_curr + threadIdx.x; l < segment_end_curr; + l += blockDim.x) { + auto dst = linear_index_inverse_indices_curr[l]; + *reinterpret_cast(&ssd_row_addrs_curr[dst]) = ptr_addr; + } +} + +void ssd_update_row_addrs_cuda( + const Tensor& ssd_row_addrs_curr, + const Tensor& ssd_curr_next_map, + const Tensor& lxu_cache_locations_curr, + const Tensor& linear_index_inverse_indices_curr, + const Tensor& unique_indices_count_cumsum_curr, + const Tensor& cache_set_inverse_indices_curr, + const Tensor& lxu_cache_weights, + const Tensor& inserted_ssd_weights_next, + const Tensor& unique_indices_length_curr) { + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( + ssd_row_addrs_curr, + ssd_curr_next_map, + lxu_cache_locations_curr, + linear_index_inverse_indices_curr, + unique_indices_count_cumsum_curr, + cache_set_inverse_indices_curr, + lxu_cache_weights, + inserted_ssd_weights_next, + unique_indices_length_curr); + + CUDA_DEVICE_GUARD(ssd_row_addrs_curr); + + const auto lxu_cache_weights_addr = + reinterpret_cast(lxu_cache_weights.data_ptr()); + const auto inserted_ssd_weights_addr_next = + reinterpret_cast(inserted_ssd_weights_next.data_ptr()); + const auto cache_row_bytes = + lxu_cache_weights.size(1) * lxu_cache_weights.element_size(); + constexpr auto kNumWarps = kMaxThreads / kWarpSize; + + ssd_update_row_addrs_kernel<<< + div_round_up(ssd_row_addrs_curr.numel(), kNumWarps), + dim3(kWarpSize, kNumWarps), + 0, + at::cuda::getCurrentCUDAStream()>>>( + ssd_row_addrs_curr.packed_accessor32(), + ssd_curr_next_map.packed_accessor32(), + lxu_cache_locations_curr + .packed_accessor32(), + linear_index_inverse_indices_curr + .packed_accessor32(), + unique_indices_count_cumsum_curr + .packed_accessor32(), + cache_set_inverse_indices_curr + .packed_accessor32(), + lxu_cache_weights_addr, + inserted_ssd_weights_addr_next, + unique_indices_length_curr.data_ptr(), + cache_row_bytes); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index b9abcbbd54..1cc67815a3 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -26,7 +26,9 @@ ssd_cache_populate_actions_cuda( int64_t prefetch_dist, Tensor lru_state, bool gather_cache_stats, - std::optional ssd_cache_stats); + std::optional ssd_cache_stats, + const bool lock_cache_line, + const c10::optional& lxu_cache_locking_counter); /// @ingroup embedding-ssd /// @@ -146,6 +148,56 @@ std::tuple ssd_generate_row_addrs_cuda( const Tensor& unique_indices_length, const Tensor& cache_set_sorted_unique_indices); +/// @ingroup embedding-ssd +/// +/// @brief Update memory addresses for SSD TBE data +/// +/// When pipeline prefetching is enabled, data in a scratch pad of the +/// current iteration can be moved to L1 or a scratch pad of the next +/// iteration during the prefetch step. This operator updates the +/// memory addresses of data that is relocated to the correct +/// location. 
+///
+/// @param ssd_row_addrs_curr The tensor that contains the row address
+///            of the current iteration
+/// @param inserted_ssd_weights_curr_next_map The tensor that contains
+///            the mapping from the location of each index in the
+///            current iteration to its location in the scratch pad of
+///            the next iteration (-1 = the data has not been moved).
+///            inserted_ssd_weights_curr_next_map[i] is the location
+///            of index i in the next iteration's scratch pad.
+/// @param lxu_cache_locations_curr The tensor that contains cache
+///            slots where data is stored for the *full* list of
+///            indices for the current iteration. -1 is a sentinel
+///            value that indicates that data is not in cache.
+/// @param linear_index_inverse_indices_curr The tensor that contains
+///            the original position of linear indices before being
+///            sorted for the current iteration
+/// @param unique_indices_count_cumsum_curr The tensor that contains
+///            the exclusive prefix sum results of the counts of
+///            unique indices for the current iteration
+/// @param cache_set_inverse_indices_curr The tensor that contains the
+///            original positions of cache sets before being sorted
+///            for the current iteration
+/// @param lxu_cache_weights The LXU cache tensor
+/// @param inserted_ssd_weights_next The scratch pad tensor for the
+///            next iteration
+/// @param unique_indices_length_curr The tensor that contains the
+///            number of unique indices (GPU tensor) for the current
+///            iteration
+///
+/// @return None
+void ssd_update_row_addrs_cuda(
+    const Tensor& ssd_row_addrs_curr,
+    const Tensor& inserted_ssd_weights_curr_next_map,
+    const Tensor& lxu_cache_locations_curr,
+    const Tensor& linear_index_inverse_indices_curr,
+    const Tensor& unique_indices_count_cumsum_curr,
+    const Tensor& cache_set_inverse_indices_curr,
+    const Tensor& lxu_cache_weights,
+    const Tensor& inserted_ssd_weights_next,
+    const Tensor& unique_indices_length_curr);
+
 namespace {
 class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
  public:
@@ -298,7 +350,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       " int prefetch_dist, "
       " Tensor lru_state, "
       " bool gather_cache_stats=False, "
-      " Tensor? ssd_cache_stats=None"
+      " Tensor? ssd_cache_stats=None, "
+      " bool lock_cache_line=False, "
+      " Tensor? lxu_cache_locking_counter=None"
       ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
   DISPATCH_TO_CUDA(
       "ssd_cache_populate_actions", ssd_cache_populate_actions_cuda);
@@ -315,5 +369,18 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       " Tensor cache_set_sorted_unique_indices"
       ") -> (Tensor, Tensor)");
   DISPATCH_TO_CUDA("ssd_generate_row_addrs", ssd_generate_row_addrs_cuda);
+  m.def(
+      "ssd_update_row_addrs("
+      " Tensor ssd_row_addrs_curr, "
+      " Tensor inserted_ssd_weights_curr_next_map, "
+      " Tensor lxu_cache_locations_curr, "
+      " Tensor linear_index_inverse_indices_curr, "
+      " Tensor unique_indices_count_cumsum_curr, "
+      " Tensor cache_set_inverse_indices_curr, "
+      " Tensor lxu_cache_weights, "
+      " Tensor inserted_ssd_weights_next, "
+      " Tensor unique_indices_length_curr"
+      ") -> ()");
+  DISPATCH_TO_CUDA("ssd_update_row_addrs", ssd_update_row_addrs_cuda);
 }
 } // namespace
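
For reviewers, the host-side flow that the training.py changes implement can be condensed as follows. This is a simplified sketch: stream and event synchronization, the record_function_via_dummy_profile wrappers, and the scratch pad bookkeeping list are omitted, and only calls that appear in this diff (insert_cuda, lookup_mask_and_pop_front_cuda, masked_index_select) are used.

# Condensed sketch of the scratch pad reuse flow (synchronization omitted).

# Iteration t (_compute_cache_ptrs): remember which conflict-miss indices
# live in the scratch pad so that iteration t+1 can look them up.
scratch_pad_idx_queue.insert_cuda(post_bwd_evicted_indices_cpu, actions_count_cpu)

# Iteration t+1 (prefetch): split the L1 misses into rows that are already
# in the previous scratch pad vs. rows that must be fetched from SSD.
sp_locations_cpu = torch.empty(
    inserted_indices_cpu.shape,
    dtype=inserted_indices_cpu.dtype,
    pin_memory=True,
)
scratch_pad_idx_queue.lookup_mask_and_pop_front_cuda(
    sp_locations_cpu,                   # out: locations in the previous scratch pad
    post_bwd_evicted_indices_cpu_prev,  # indices owned by the previous scratch pad
    inserted_indices_cpu,               # in/out: remaining indices still go to SSD
    actions_count_cpu,
)

# Copy the rows found in the previous scratch pad directly into the current
# scratch pad, bypassing the SSD read.
torch.ops.fbgemm.masked_index_select(
    inserted_rows,                             # current scratch pad
    sp_locations_cpu.cuda(non_blocking=True),  # locations in the previous scratch pad
    inserted_rows_prev,                        # previous scratch pad
    actions_count_gpu,
)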
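
The slot-selection change in ssd_cache_actions_insert_kernel is easiest to see in isolation: a slot is unavailable if it is locked or was already filled at the current timestamp, unavailable slots are given the maximum possible cost (the current timestamp) so the sort pushes them behind every available slot, and an insertion that lands on a slot whose cost equals the timestamp becomes a conflict miss. The following is an illustrative pure-Python model of that rule, not the CUDA kernel; names and data layout are simplified (for example, the slot_idx != -1 check is folded into the timestamp comparison).

def pick_insert_slots(slot_times, locked, time_stamp, num_inserts):
    # slot_times: per-slot last-access timestamp (one lru_state row)
    # locked:     per-slot flag derived from lxu_cache_locking_counter > 0
    # Unavailable slots get the maximum cost (the current timestamp) so that
    # sorting by cost pushes them behind every available slot.
    costs = [
        time_stamp if (locked[s] or slot_times[s] == time_stamp) else slot_times[s]
        for s in range(len(slot_times))
    ]
    order = sorted(range(len(costs)), key=lambda s: costs[s])

    decisions = []
    for l in range(num_inserts):
        slot = order[l]
        if costs[slot] == time_stamp:
            # No available slot is left for this row: it becomes a conflict
            # miss (the kernel sets evicted_indices and assigned_cache_slots
            # to -1) and the row stays in the scratch pad.
            decisions.append(None)
        else:
            decisions.append(slot)
    return decisions

# Example: 4 slots, timestamp 10; slot 0 is locked, slot 1 was filled at t=10.
# Two insertions land in slots 2 and 3 (the least recently used available slots).
print(pick_insert_slots([3, 10, 1, 2], [True, False, False, False], 10, 2))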
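
Similarly, the per-row decision in ssd_update_row_addrs_kernel reduces to choosing between two base addresses. The sketch below mirrors that arithmetic in Python for a single unique index; it is illustrative only, and the real kernel additionally fans the resulting pointer out to every duplicate occurrence via linear_index_inverse_indices_curr.

def updated_row_addr(
    n_curr,                          # position of the unique index in the current scratch pad
    ssd_curr_next_map,               # curr -> next scratch pad mapping, -1 if not reused
    cache_loc_curr,                  # L1 slot for this index, -1 if it is not in L1
    lxu_cache_weights_addr,          # base address of the L1 cache weights
    inserted_ssd_weights_addr_next,  # base address of the next iteration's scratch pad
    cache_row_bytes,                 # bytes per cache row (64-bit to avoid overflow)
):
    n_next = ssd_curr_next_map[n_curr]
    if n_next < 0:
        return None  # row is not reused in the next iteration; nothing to update
    if cache_loc_curr == -1:
        # Row moves from the current scratch pad to the next iteration's scratch pad
        return inserted_ssd_weights_addr_next + n_next * cache_row_bytes
    # Row is resident in L1; point at its cache slot instead
    return lxu_cache_weights_addr + cache_loc_curr * cache_row_bytes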