diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 03f8eb941c..3f54868280 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -14,7 +14,7 @@ import os import tempfile from math import log2 -from typing import Any, Callable, List, Optional, Tuple, Type +from typing import Any, Callable, List, Optional, Tuple, Type, Union import torch # usort:skip import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers @@ -35,11 +35,25 @@ ) from torch import distributed as dist, nn, Tensor # usort:skip +from dataclasses import dataclass + from torch.autograd.profiler import record_function +from ..cache import get_unique_indices_v2 + from .common import ASSOC +@dataclass +class IterData: + indices: Tensor + offsets: Tensor + lxu_cache_locations: Tensor + lxu_cache_ptrs: Tensor + actions_count_gpu: Tensor + cache_set_inverse_indices: Tensor + + class SSDTableBatchedEmbeddingBags(nn.Module): D_offsets: Tensor lxu_cache_weights: Tensor @@ -112,6 +126,8 @@ def __init__( gather_ssd_cache_stats: Optional[bool] = False, stats_reporter_config: Optional[TBEStatsReporterConfig] = None, l2_cache_size: int = 0, + # Set to True to enable pipeline prefetching + prefetch_pipeline: bool = False, ) -> None: super(SSDTableBatchedEmbeddingBags, self).__init__() @@ -182,14 +198,42 @@ def __init__( ) self.register_buffer( "lxu_cache_state", - torch.zeros(cache_sets, ASSOC, dtype=torch.int64).fill_(-1), + torch.zeros( + cache_sets, ASSOC, device=self.current_device, dtype=torch.int64 + ).fill_(-1), ) self.register_buffer( - "lru_state", torch.zeros(cache_sets, ASSOC, dtype=torch.int64) + "lru_state", + torch.zeros( + cache_sets, ASSOC, device=self.current_device, dtype=torch.int64 + ), ) self.step = 0 + # Set prefetch pipeline + self.prefetch_pipeline: bool = prefetch_pipeline + self.prefetch_stream: Optional[torch.cuda.Stream] = None + + # Cache locking counter for pipeline prefetching + if self.prefetch_pipeline: + self.register_buffer( + "lxu_cache_locking_counter", + torch.zeros( + cache_sets, + ASSOC, + device=self.current_device, + dtype=torch.int32, + ), + persistent=True, + ) + else: + self.register_buffer( + "lxu_cache_locking_counter", + torch.zeros([0, 0], dtype=torch.int32, device=self.current_device), + persistent=False, + ) + assert ssd_cache_location in ( EmbeddingLocation.MANAGED, EmbeddingLocation.DEVICE, @@ -342,8 +386,11 @@ def __init__( self.ssd_event_backward = torch.cuda.Event() # SSD get's input copy completion event self.ssd_event_get_inputs_cpy = torch.cuda.Event() - # SSD scratch pad index queue insert completion event - self.ssd_event_sp_idxq_insert = torch.cuda.Event() + if self.prefetch_pipeline: + # SSD scratch pad index queue insert completion event + self.ssd_event_sp_idxq_insert: torch.cuda.streams.Event = torch.cuda.Event() + # SSD scratch pad index queue lookup completion event + self.ssd_event_sp_idxq_lookup: torch.cuda.streams.Event = torch.cuda.Event() self.timesteps_prefetched: List[int] = [] # TODO: add type annotation @@ -351,10 +398,19 @@ def __init__( self.ssd_prefetch_data = [] # Scratch pad value queue - self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = [] - # pyre-ignore[4] - # Scratch pad index queue - self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue(-1) + self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = [] + # Scratch pad eviction data queue + self.ssd_scratch_pad_eviction_data: List[ + Tuple[Tensor, Tensor, Tensor, bool] + ] = [] + self.ssd_location_update_data: List[Tuple[Tensor, Tensor]] = [] + + if self.prefetch_pipeline: + # pyre-ignore[4] + # Scratch pad index queue + self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue( + -1 + ) if weight_decay_mode == WeightDecayMode.COUNTER or counter_based_regularization: raise AssertionError( @@ -425,12 +481,24 @@ def __init__( dtype=torch.float32, ) + # For storing current iteration data + self.current_iter_data: Optional[IterData] = None + # add placeholder require_grad param to enable autograd without nn.parameter # this is needed to enable int8 embedding weights for SplitTableBatchedEmbedding self.placeholder_autograd_tensor = nn.Parameter( torch.zeros(0, device=self.current_device, dtype=torch.float) ) + # Register backward hook for evicting rows from a scratch pad to SSD + # post backward + self.placeholder_autograd_tensor.register_hook(self._evict_from_scratch_pad) + + if self.prefetch_pipeline: + self.register_full_backward_pre_hook( + self._update_cache_counter_and_pointers + ) + assert optimizer in ( OptimType.EXACT_ROWWISE_ADAGRAD, ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags" @@ -542,28 +610,31 @@ def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor: t_cpu.copy_(t, non_blocking=True) return t_cpu - def to_pinned_cpu_on_stream_wait_on_current_stream( + def to_pinned_cpu_on_stream_wait_on_another_stream( self, tensors: List[Tensor], stream: torch.cuda.Stream, + stream_to_wait_on: torch.cuda.Stream, post_event: Optional[torch.cuda.Event] = None, ) -> List[Tensor]: """ - Transfer input tensors from GPU to CPU using a pinned host buffer. - The transfer is carried out on the given stream (`stream`) after all - the kernels in the default stream (`current_stream`) are complete. + Transfer input tensors from GPU to CPU using a pinned host + buffer. The transfer is carried out on the given stream + (`stream`) after all the kernels in the other stream + (`stream_to_wait_on`) are complete. Args: - tensors (List[Tensor]): The list of tensors to be transferred + tensors (List[Tensor]): The list of tensors to be + transferred stream (Stream): The stream to run memory copy + stream_to_wait_on (Stream): The stream to wait on post_event (Event): The post completion event Returns: The list of pinned CPU tensors """ - current_stream = torch.cuda.current_stream() with torch.cuda.stream(stream): - stream.wait_stream(current_stream) + stream.wait_stream(stream_to_wait_on) cpu_tensors = [] for t in tensors: t.record_stream(stream) @@ -626,115 +697,202 @@ def evict( if post_event is not None: stream.record_event(post_event) - def _evict_from_scratch_pad(self, return_on_empty: bool) -> None: - scratch_pad_len = len(self.ssd_scratch_pads) + def _evict_from_scratch_pad(self, grad: Tensor) -> None: + """ + Evict conflict missed rows from a scratch pad + (`inserted_rows`) on the `ssd_eviction_stream`. This is a hook + that is invoked right after TBE backward. - if not return_on_empty: - assert scratch_pad_len > 0, "There must be at least one scratch pad" - elif scratch_pad_len == 0: - return + Conflict missed indices are specified in + `post_bwd_evicted_indices_cpu`. Indices that are not -1 and + their positions < `actions_count_cpu` (i.e., rows + `post_bwd_evicted_indices_cpu[:actions_count_cpu] != -1` in + post_bwd_evicted_indices_cpu) will be evicted. + + Args: + grad (Tensor): Unused gradient tensor + + Returns: + None + """ + with record_function("## ssd_evict_from_scratch_pad_pipeline ##"): + current_stream = torch.cuda.current_stream() + current_stream.record_event(self.ssd_event_backward) + + assert ( + len(self.ssd_scratch_pad_eviction_data) > 0 + ), "There must be at least one scratch pad" + + ( + inserted_rows, + post_bwd_evicted_indices_cpu, + actions_count_cpu, + do_evict, + ) = self.ssd_scratch_pad_eviction_data.pop(0) + + if not do_evict: + return - (inserted_rows, post_bwd_evicted_indices_cpu, actions_count_cpu, do_evict) = ( - self.ssd_scratch_pads.pop(0) - ) - if do_evict: - torch.cuda.current_stream().record_event(self.ssd_event_backward) self.evict( rows=inserted_rows, indices_cpu=post_bwd_evicted_indices_cpu, actions_count_cpu=actions_count_cpu, stream=self.ssd_eviction_stream, pre_event=self.ssd_event_backward, - post_event=None, + post_event=self.ssd_event_evict, is_rows_uvm=True, name="scratch_pad", ) - def _compute_cache_ptrs( + if self.prefetch_stream: + self.prefetch_stream.wait_stream(current_stream) + + def _update_cache_counter_and_pointers( self, - linear_cache_indices: torch.Tensor, - assigned_cache_slots: torch.Tensor, - linear_index_inverse_indices: torch.Tensor, - unique_indices_count_cumsum: torch.Tensor, - cache_set_inverse_indices: torch.Tensor, - inserted_rows: torch.Tensor, - unique_indices_length: torch.Tensor, - inserted_indices: torch.Tensor, - actions_count_cpu: torch.Tensor, - lxu_cache_locations: torch.Tensor, - ) -> torch.Tensor: - with record_function("## ssd_generate_row_addrs ##"): - lxu_cache_ptrs, post_bwd_evicted_indices = ( - torch.ops.fbgemm.ssd_generate_row_addrs( - lxu_cache_locations, - assigned_cache_slots, - linear_index_inverse_indices, - unique_indices_count_cumsum, - cache_set_inverse_indices, - self.lxu_cache_weights, - inserted_rows, - unique_indices_length, - inserted_indices, - ) + module: nn.Module, + grad_input: Union[Tuple[Tensor, ...], Tensor], + ) -> None: + """ + Update cache line locking counter and pointers before backward + TBE. This is a hook that is called before the backward of TBE + + Update cache line counter: + + We ensure that cache prefetching does not execute concurrently + with the backward TBE. Therefore, it is safe to unlock the + cache lines used in current iteration before backward TBE. + + Update pointers: + + Now some rows that are used in both the current iteration and + the next iteration are moved (1) from the current iteration's + scratch pad into the next iteration's scratch pad or (2) from + the current iteration's scratch pad into the L1 cache + + To ensure that the TBE backward kernel accesses valid data, + here we update the pointers of these rows in the current + iteration's `lxu_cache_ptrs` to point to either L1 cache or + the next iteration scratch pad + + Args: + module (nn.Module): Unused + grad_input (Union[Tuple[Tensor, ...], Tensor]): Unused + + Returns: + None + """ + if self.prefetch_stream: + # Ensure that prefetch is done + torch.cuda.current_stream().wait_stream(self.prefetch_stream) + + assert self.current_iter_data is not None, "current_iter_data must be set" + + curr_data: IterData = self.current_iter_data + + if curr_data.lxu_cache_locations.numel() == 0: + return + + with record_function("## ssd_update_cache_counter_and_pointers ##"): + # Unlock the cache lines + torch.ops.fbgemm.lxu_cache_locking_counter_decrement( + self.lxu_cache_locking_counter, + curr_data.lxu_cache_locations, ) - # Transfer post_bwd_evicted_indices from GPU to CPU right away to - # increase a chance of overlapping with compute in the default stream - (post_bwd_evicted_indices_cpu,) = ( - self.to_pinned_cpu_on_stream_wait_on_current_stream( - tensors=[post_bwd_evicted_indices], - stream=self.ssd_eviction_stream, - post_event=None, + # Recompute linear_cache_indices to save memory + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( + self.hash_size_cumsum, + curr_data.indices, + curr_data.offsets, + ) + ( + linear_unique_indices, + linear_unique_indices_length, + unique_indices_count, + linear_index_inverse_indices, + ) = get_unique_indices_v2( + linear_cache_indices, + self.total_hash_size, + compute_count=True, + compute_inverse_indices=True, + ) + unique_indices_count_cumsum = torch.ops.fbgemm.asynchronous_complete_cumsum( + unique_indices_count ) - ) - # Insert conflict miss indices in the index queue for future lookup - # post_bwd_evicted_indices_cpu is transferred on the ssd_eviction_stream stream - # actions_count_cpu is transferred on the ssd_memcpy_stream stream - with torch.cuda.stream(self.ssd_eviction_stream): - # Ensure that actions_count_cpu transfer is done - self.ssd_eviction_stream.wait_event(self.ssd_event_get_inputs_cpy) - self.record_function_via_dummy_profile( - "## ssd_scratch_pad_idx_queue_insert ##", - self.scratch_pad_idx_queue.insert_cuda, - post_bwd_evicted_indices_cpu, - actions_count_cpu, + # Look up the cache to check which indices in the scratch + # pad are moved to L1 + torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + self.lxu_cache_state, + self.total_hash_size, + gather_cache_stats=False, # not collecting cache stats + lxu_cache_locations_output=curr_data.lxu_cache_locations, ) - self.ssd_eviction_stream.record_event(self.ssd_event_sp_idxq_insert) - with record_function("## ssd_scratch_pads ##"): - # Store scratch pad info for post backward eviction - self.ssd_scratch_pads.append( - ( - inserted_rows, - post_bwd_evicted_indices_cpu, - actions_count_cpu, - linear_cache_indices.numel() > 0, - ) + if len(self.ssd_location_update_data) == 0: + return + + (sp_curr_next_map, inserted_rows_next) = self.ssd_location_update_data.pop( + 0 ) - # pyre-fixme[7]: Expected `Tensor` but got `Tuple[typing.Any, Tensor, - # typing.Any, Tensor]`. - return ( - lxu_cache_ptrs, - inserted_rows, - post_bwd_evicted_indices_cpu, - actions_count_cpu, - ) + # Update poitners + torch.ops.fbgemm.ssd_update_row_addrs( + ssd_row_addrs_curr=curr_data.lxu_cache_ptrs, + inserted_ssd_weights_curr_next_map=sp_curr_next_map, + lxu_cache_locations_curr=curr_data.lxu_cache_locations, + linear_index_inverse_indices_curr=linear_index_inverse_indices, + unique_indices_count_cumsum_curr=unique_indices_count_cumsum, + cache_set_inverse_indices_curr=curr_data.cache_set_inverse_indices, + lxu_cache_weights=self.lxu_cache_weights, + inserted_ssd_weights_next=inserted_rows_next, + unique_indices_length_curr=curr_data.actions_count_gpu, + ) + + def prefetch( # noqa C901 + self, + indices: Tensor, + offsets: Tensor, + forward_stream: Optional[torch.cuda.Stream] = None, + ) -> Optional[Tensor]: + if self.prefetch_stream is None and forward_stream is not None: + # Set the prefetch stream to the current stream + self.prefetch_stream = torch.cuda.current_stream() + assert ( + self.prefetch_stream != forward_stream + ), "prefetch_stream and forward_stream should not be the same stream" + + # TODO: Refactor prefetch + current_stream = torch.cuda.current_stream() + + # If a prefetch stream is used, then record tensors on the stream + indices.record_stream(current_stream) + offsets.record_stream(current_stream) - def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: - with record_function("## ssd_prefetch ##"): + with record_function("## ssd_prefetch {} ##".format(self.timestep)): if self.gather_ssd_cache_stats: self.local_ssd_cache_stats.zero_() (indices, offsets) = indices.long(), offsets.long() + # Linearize indices linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( self.hash_size_cumsum, indices, offsets, ) + self.timestep += 1 self.timesteps_prefetched.append(self.timestep) + + # Lookup and virtually insert indices into L1. After this operator, + # we know: + # (1) which cache lines can be evicted + # (2) which rows are already in cache (hit) + # (3) which rows are missed and can be inserted later (missed, but + # not conflict missed) + # (4) which rows are missed but CANNOT be inserted later (conflict + # missed) ( inserted_indices, evicted_indices, @@ -753,33 +911,58 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: self.lru_state, self.gather_ssd_cache_stats, self.local_ssd_cache_stats, + lock_cache_line=self.prefetch_pipeline, + lxu_cache_locking_counter=self.lxu_cache_locking_counter, ) - # Transfer evicted indices from GPU to CPU right away to increase a - # chance of overlapping with compute on the default stream - (evicted_indices_cpu,) = ( - self.to_pinned_cpu_on_stream_wait_on_current_stream( - tensors=[evicted_indices], - stream=self.ssd_eviction_stream, - post_event=None, + + # Compute cache locations (rows that are hit are missed but can be + # inserted will have cache locations != -1) + with record_function("## ssd_tbe_lxu_cache_lookup ##"): + lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + self.lxu_cache_state, + self.total_hash_size, + self.gather_ssd_cache_stats, + self.local_ssd_cache_stats, ) - ) - actions_count_cpu, inserted_indices_cpu = ( - self.to_pinned_cpu_on_stream_wait_on_current_stream( - tensors=[ - actions_count_gpu, - inserted_indices, - ], - stream=self.ssd_memcpy_stream, - post_event=self.ssd_event_get_inputs_cpy, + with record_function("## ssd_d2h_inserted_indices ##"): + # Transfer actions_count and insert_indices right away to + # incrase an overlap opportunity + actions_count_cpu, inserted_indices_cpu = ( + self.to_pinned_cpu_on_stream_wait_on_another_stream( + tensors=[ + actions_count_gpu, + inserted_indices, + ], + stream=self.ssd_memcpy_stream, + stream_to_wait_on=current_stream, + post_event=self.ssd_event_get_inputs_cpy, + ) ) - ) - assigned_cache_slots = assigned_cache_slots.long() - evicted_rows = self.lxu_cache_weights[ - assigned_cache_slots.clamp(min=0).long(), : - ] + with record_function("## ssd_d2h_evicted_indices ##"): + # Transfer evicted indices from GPU to CPU right away to increase a + # chance of overlapping with compute on the default stream + (evicted_indices_cpu,) = ( + self.to_pinned_cpu_on_stream_wait_on_another_stream( + tensors=[evicted_indices], + stream=self.ssd_eviction_stream, + stream_to_wait_on=current_stream, + post_event=None, + ) + ) + + # Copy rows to be evicted into a separate buffer (will be evicted + # later in the prefetch step) + with record_function("## ssd_compute_evicted_rows ##"): + assigned_cache_slots = assigned_cache_slots.long() + evicted_rows = self.lxu_cache_weights[ + assigned_cache_slots.clamp(min=0).long(), : + ] + # Allocation a scratch pad for the current iteration. The scratch + # pad is a UVA tensor if linear_cache_indices.numel() > 0: inserted_rows = torch.ops.fbgemm.new_managed_tensor( torch.zeros( @@ -796,34 +979,47 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: device=self.current_device, ) - current_stream = torch.cuda.current_stream() - if len(self.ssd_scratch_pads) > 0: + if self.prefetch_pipeline and len(self.ssd_scratch_pads) > 0: + # Look up all missed indices from the previous iteration's + # scratch pad (do this only if pipeline prefetching is being + # used) with record_function("## ssd_lookup_scratch_pad ##"): - current_stream.wait_event(self.ssd_event_sp_idxq_insert) - current_stream.wait_event(self.ssd_event_get_inputs_cpy) - + # Get the previous scratch pad ( inserted_rows_prev, post_bwd_evicted_indices_cpu_prev, actions_count_cpu_prev, - do_evict_prev, ) = self.ssd_scratch_pads.pop(0) # Inserted indices that are found in the scratch pad # from the previous iteration - sp_locations_cpu = torch.empty( + sp_prev_curr_map_cpu = torch.empty( inserted_indices_cpu.shape, dtype=inserted_indices_cpu.dtype, pin_memory=True, ) + # Conflict missed indices from the previous iteration that + # overlap with the current iterations's inserted indices + sp_curr_prev_map_cpu = torch.empty( + post_bwd_evicted_indices_cpu_prev.shape, + dtype=torch.int, + pin_memory=True, + ).fill_(-1) + + # Ensure that the necessary D2H transfers are done + current_stream.wait_event(self.ssd_event_get_inputs_cpy) + # Ensure that the previous iteration's scratch pad indices + # insertion is complete + current_stream.wait_event(self.ssd_event_sp_idxq_insert) + # Before entering this function: inserted_indices_cpu # contains all linear indices that are missed from the # L1 cache # # After this function: inserted indices that are found # in the scratch pad from the previous iteration are - # stored in sp_locations_cpu, while the rests are + # stored in sp_prev_curr_map_cpu, while the rests are # stored in inserted_indices_cpu # # An invalid index is -1 or its position > @@ -831,44 +1027,55 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: self.record_function_via_dummy_profile( "## ssd_lookup_mask_and_pop_front ##", self.scratch_pad_idx_queue.lookup_mask_and_pop_front_cuda, - sp_locations_cpu, - post_bwd_evicted_indices_cpu_prev, - inserted_indices_cpu, - actions_count_cpu, + sp_prev_curr_map_cpu, # scratch_pad_prev_curr_map + sp_curr_prev_map_cpu, # scratch_pad_curr_prev_map + post_bwd_evicted_indices_cpu_prev, # scratch_pad_indices_prev + inserted_indices_cpu, # inserted_indices_curr + actions_count_cpu, # count_curr ) - # Transfer sp_locations_cpu to GPU - sp_locations_gpu = sp_locations_cpu.cuda(non_blocking=True) + # Mark scratch pad index queue lookup completion + current_stream.record_event(self.ssd_event_sp_idxq_lookup) + + # Transfer sp_prev_curr_map_cpu to GPU + sp_prev_curr_map_gpu = sp_prev_curr_map_cpu.cuda(non_blocking=True) + # Transfer sp_curr_prev_map_cpu to GPU + sp_curr_prev_map_gpu = sp_curr_prev_map_cpu.cuda(non_blocking=True) + + # Previously actions_count_gpu was recorded on another + # stream. Thus, we need to record it on this stream + actions_count_gpu.record_stream(current_stream) # Copy data from the previous iteration's scratch pad to # the current iteration's scratch pad torch.ops.fbgemm.masked_index_select( inserted_rows, - sp_locations_gpu, + sp_prev_curr_map_gpu, inserted_rows_prev, actions_count_gpu, ) - # Evict from scratch pad - if do_evict_prev: - torch.cuda.current_stream().record_event( - self.ssd_event_backward - ) - self.evict( - rows=inserted_rows_prev, - indices_cpu=post_bwd_evicted_indices_cpu_prev, - actions_count_cpu=actions_count_cpu_prev, - stream=self.ssd_eviction_stream, - pre_event=self.ssd_event_backward, - post_event=None, - is_rows_uvm=True, - name="scratch_pad", + # Record the tensors that will be pushed into a queue + # on the forward stream + if forward_stream: + sp_curr_prev_map_gpu.record_stream(forward_stream) + + # Store info for evicting the previous iteration's + # scratch pad after the corresponding backward pass is + # done + self.ssd_location_update_data.append( + ( + sp_curr_prev_map_gpu, + inserted_rows, ) + ) - # Ensure the previous iterations l3_db.set(..) has completed. + # Ensure the previous iterations eviction is complete current_stream.wait_event(self.ssd_event_evict) + # Ensure that D2H is done current_stream.wait_event(self.ssd_event_get_inputs_cpy) + # Fetch data from SSD if linear_cache_indices.numel() > 0: self.record_function_via_dummy_profile( "## ssd_get ##", @@ -877,8 +1084,11 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: inserted_rows, actions_count_cpu, ) + + # Record an event to mark the completion of `get_cuda` current_stream.record_event(self.ssd_event_get) + # Copy rows from the current iteration's scratch pad to L1 torch.ops.fbgemm.masked_index_put( self.lxu_cache_weights, assigned_cache_slots, @@ -894,36 +1104,105 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: actions_count_cpu=actions_count_cpu, stream=self.ssd_eviction_stream, pre_event=self.ssd_event_get, - post_event=self.ssd_event_evict, + # Record completion event after scratch pad eviction + # instead since that happens after L1 eviction + post_event=None, is_rows_uvm=False, name="cache", ) - with record_function("## ssd_tbe_lxu_cache_lookup ##"): - lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( - linear_cache_indices, - self.lxu_cache_state, - self.total_hash_size, - self.gather_ssd_cache_stats, - self.local_ssd_cache_stats, + # Generate row addresses (pointing to either L1 or the current + # iteration's scratch pad) + with record_function("## ssd_generate_row_addrs ##"): + lxu_cache_ptrs, post_bwd_evicted_indices = ( + torch.ops.fbgemm.ssd_generate_row_addrs( + lxu_cache_locations, + assigned_cache_slots, + linear_index_inverse_indices, + unique_indices_count_cumsum, + cache_set_inverse_indices, + self.lxu_cache_weights, + inserted_rows, + unique_indices_length, + inserted_indices, + ) + ) + + with record_function("## ssd_d2h_post_bwd_evicted_indices ##"): + # Transfer post_bwd_evicted_indices from GPU to CPU right away to + # increase a chance of overlapping with compute in the default stream + (post_bwd_evicted_indices_cpu,) = ( + self.to_pinned_cpu_on_stream_wait_on_another_stream( + tensors=[post_bwd_evicted_indices], + stream=self.ssd_eviction_stream, + stream_to_wait_on=current_stream, + post_event=None, + ) + ) + + if self.prefetch_pipeline: + # Insert the current iteration's conflict miss indices in the index + # queue for future lookup. + # + # post_bwd_evicted_indices_cpu is transferred on the + # ssd_eviction_stream stream so it does not need stream + # synchronization + # + # actions_count_cpu is transferred on the ssd_memcpy_stream stream. + # Thus, we have to explicitly sync the stream + with torch.cuda.stream(self.ssd_eviction_stream): + # Ensure that actions_count_cpu transfer is done + self.ssd_eviction_stream.wait_event(self.ssd_event_get_inputs_cpy) + # Ensure that the scratch pad index queue look up is complete + self.ssd_eviction_stream.wait_event(self.ssd_event_sp_idxq_lookup) + self.record_function_via_dummy_profile( + "## ssd_scratch_pad_idx_queue_insert ##", + self.scratch_pad_idx_queue.insert_cuda, + post_bwd_evicted_indices_cpu, + actions_count_cpu, + ) + # Mark the completion of scratch pad index insertion + self.ssd_eviction_stream.record_event(self.ssd_event_sp_idxq_insert) + + prefetch_data = ( + lxu_cache_ptrs, + inserted_rows, + post_bwd_evicted_indices_cpu, + actions_count_cpu, + actions_count_gpu, + lxu_cache_locations, + cache_set_inverse_indices, + ) + + # Record tensors on the forward stream + if forward_stream is not None: + for t in prefetch_data: + if t.is_cuda: + t.record_stream(forward_stream) + + # Store scratch pad info for the lookup in the next iteration + # prefetch + self.ssd_scratch_pads.append( + ( + inserted_rows, + post_bwd_evicted_indices_cpu, + actions_count_cpu, ) + ) - # TODO: keep only necessary tensors - self.ssd_prefetch_data.append( + # Store scratch pad info for post backward eviction + self.ssd_scratch_pad_eviction_data.append( ( - linear_cache_indices, - assigned_cache_slots, - linear_index_inverse_indices, - unique_indices_count_cumsum, - cache_set_inverse_indices, inserted_rows, - unique_indices_length, - inserted_indices, + post_bwd_evicted_indices_cpu, actions_count_cpu, - lxu_cache_locations, + linear_cache_indices.numel() > 0, ) ) + # Store data for forward + self.ssd_prefetch_data.append(prefetch_data) + if self.gather_ssd_cache_stats: self.ssd_cache_stats = torch.add( self.ssd_cache_stats, self.local_ssd_cache_stats @@ -945,13 +1224,25 @@ def forward( self.prefetch(indices, offsets) assert len(self.ssd_prefetch_data) > 0 - prefetch_data = self.ssd_prefetch_data.pop(0) ( lxu_cache_ptrs, inserted_rows, post_bwd_evicted_indices_cpu, actions_count_cpu, - ) = self._compute_cache_ptrs(*prefetch_data) + actions_count_gpu, + lxu_cache_locations, + cache_set_inverse_indices, + ) = self.ssd_prefetch_data.pop(0) + + # Storing current iteration data for future use + self.current_iter_data = IterData( + indices, + offsets, + lxu_cache_locations, + lxu_cache_ptrs, + actions_count_gpu, + cache_set_inverse_indices, + ) common_args = invokers.lookup_args_ssd.CommonArgs( placeholder_autograd_tensor=self.placeholder_autograd_tensor, @@ -1113,9 +1404,6 @@ def flush(self) -> None: active_slots_mask_cpu.view(-1) ) - # Evict data from scratch pad if there is scratch pad in the queue - self._evict_from_scratch_pad(return_on_empty=True) - torch.cuda.current_stream().wait_stream(self.ssd_eviction_stream) self.ssd_db.set_cuda( diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_scratch_pad_indices_queue.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_scratch_pad_indices_queue.cpp index f96cd5424d..c174f2c196 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_scratch_pad_indices_queue.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_scratch_pad_indices_queue.cpp @@ -53,15 +53,20 @@ class SSDScratchPadIndicesQueueImpl } void lookup_mask_and_pop_front_cuda( - const Tensor& scratch_pad_locations, - const Tensor& scratch_pad_indices, - const Tensor& ssd_indices, - const Tensor& count) { + const Tensor& scratch_pad_prev_curr_map, + const Tensor& scratch_pad_curr_prev_map, + const Tensor& scratch_pad_indices_prev, + const Tensor& inserted_indices_curr, + const Tensor& count_curr) { // take reference to self to avoid lifetime issues. auto self = shared_from_this(); std::function* functor = new std::function([=]() { self->lookup_mask_and_pop_front( - scratch_pad_locations, scratch_pad_indices, ssd_indices, count); + scratch_pad_prev_curr_map, + scratch_pad_curr_prev_map, + scratch_pad_indices_prev, + inserted_indices_curr, + count_curr); }); AT_CUDA_CHECK(cudaStreamAddCallback( at::cuda::getCurrentCUDAStream(), @@ -101,91 +106,103 @@ class SSDScratchPadIndicesQueueImpl index_loc_map_queue.push(std::move(map)); } - /// Looks up `ssd_indices` in the front hash map in the queue. This - /// is equivalent to looking up indices in the front scratch pad. + /// Looks up `inserted_indices_curr` in the front hash map in the + /// queue. This is equivalent to looking up indices in the front + /// scratch pad (i.e., the previous scratch pad). /// /// If an index is found: /// - /// - Sets the corresponding `scratch_pad_locations` to the location - /// of the index in the scratch pad (the value in the hash map) + /// - Sets the corresponding `scratch_pad_prev_curr_map` to the + /// location of the index in the front scratch pad (the value in the + /// hash map). scratch_pad_prev_curr_map[i] is the location in the + /// previous scratch pad of the the current scratch pad's index i /// - /// - Sets the corresponding `ssd_indices` to the sentinel value - /// (this is to prevent looking up this index from SSD) + /// - Sets the corresponding `inserted_indices_curr` to the sentinel + /// value (this is to prevent looking up this index from SSD) /// - /// - Sets the `scratch_pad_indices` to the sentinel value (this is - /// to prevent evicting the corresponding row from the scratch pad). + /// - Sets the `scratch_pad_indices_prev` to the sentinel value + /// (this is to prevent evicting the corresponding row from the + /// previous scratch pad). /// - /// Else: Sets the corresponding `scratch_pad_locations` to the + /// Else: Sets the corresponding `scratch_pad_prev_curr_map` to the /// sentinel value (to indicate that the index is not found in the - /// scratch pad) + /// previous scratch pad). /// /// Once the process above is done, pop the hash map from the queue. /// - /// @param scratch_pad_locations The 1D output tensor that has the - /// same size as `ssd_indices`. It - /// contains locations of the - /// corresponding indices in the - /// scratch pad if they are found or - /// sentinel values - /// @param scratch_pad_indices The 1D tensor that contains scratch - /// pad indices, i.e., conflict missed - /// indices from the previous iteration. - /// The indices and their locations must - /// match the keys and values in the - /// front hash map. After this function, - /// the indices that are found will be - /// set to sentinel values to prevent - /// them from getting evicted - /// @param ssd_indices The 1D tensor that contains indices that are - /// missed from the L1 cache, i.e., all missed - /// indices (including conflict misses). After - /// this function, the indices that are found - /// will be set to sentinel values to prevent - /// them from being looked up in from SSD - /// @param count The tensor that contains the number of indices to - /// be processed + /// @param scratch_pad_prev_curr_map The 1D output tensor that has + /// the same size as `inserted_indices_curr`. It contains + /// locations of the corresponding indices (in the + /// current scratch pad) in the previous scratch pad if + /// they are found or sentinel values + /// @param scratch_pad_curr_prev_map The 1D output tensor that has + /// the same size as `scratch_pad_indices_prev`. It + /// contains locations of the corresponding indices (in + /// the previous scratch pad) in the current scratch pad + /// if they are found or sentinel values + /// @param scratch_pad_indices_prev The 1D tensor that contains + /// scratch pad indices, i.e., conflict missed indices + /// from the previous iteration. The indices and their + /// locations must match the keys and values in the front + /// hash map. After this function, the indices that are + /// found will be set to sentinel values to prevent them + /// from getting evicted + /// @param inserted_indices_curr The 1D tensor that contains indices + /// that are missed from the L1 cache, i.e., all missed + /// indices (including conflict misses) from the current + /// iteration. After this function, the indices that are + /// found will be set to sentinel values to prevent them + /// from being looked up in from SSD + /// @param count_curr The tensor that contains the number of indices + /// to be processed. /// /// @return Outputs are passed by reference. void lookup_mask_and_pop_front( - const Tensor& scratch_pad_locations, - const Tensor& scratch_pad_indices, - const Tensor& ssd_indices, - const Tensor& count) { + const Tensor& scratch_pad_prev_curr_map, + const Tensor& scratch_pad_curr_prev_map, + const Tensor& scratch_pad_indices_prev, + const Tensor& inserted_indices_curr, + const Tensor& count_curr) { TORCH_CHECK( index_loc_map_queue.size() > 0, "index_loc_map_queue must not be empty"); - const auto count_ = count.item(); - TORCH_CHECK(ssd_indices.numel() >= count_); - TORCH_CHECK(scratch_pad_locations.numel() == ssd_indices.numel()); + const auto count_ = count_curr.item(); + TORCH_CHECK(inserted_indices_curr.numel() >= count_); + TORCH_CHECK( + scratch_pad_prev_curr_map.numel() == inserted_indices_curr.numel()); auto& map = index_loc_map_queue.front(); - auto loc_acc = scratch_pad_locations.accessor(); - auto sp_acc = scratch_pad_indices.accessor(); - auto ssd_acc = ssd_indices.accessor(); + auto sp_prev_curr_map_acc = + scratch_pad_prev_curr_map.accessor(); + auto sp_indices_acc = scratch_pad_indices_prev.accessor(); + auto inserted_indices_acc = inserted_indices_curr.accessor(); + auto sp_curr_prev_map_acc = scratch_pad_curr_prev_map.accessor(); // Concurrent lookup is OK since it is read-only at::parallel_for(0, count_, 1, [&](int64_t start, int64_t end) { for (const auto i : c10::irange(start, end)) { - const auto val = ssd_acc[i]; + const auto val = inserted_indices_acc[i]; const auto val_loc = map->find(val); // If index is found in the map if (val_loc != map->end()) { const auto loc = val_loc->second; - // Store the location in the scratch pad - loc_acc[i] = loc; - // Set the scratch pad index as the sentinel value to - // prevent it from being evicted - sp_acc[loc] = sentinel_val_; + // Store the previous scratch pad location + sp_prev_curr_map_acc[i] = loc; + // Store the current scratch pad location + sp_curr_prev_map_acc[loc] = i; // Set the SSD index as the sentinel value to prevent it // from being looked up in SSD - ssd_acc[i] = sentinel_val_; + inserted_indices_acc[i] = sentinel_val_; + // Set the scratch pad index as the sentinel value to + // prevent it from being evicted + sp_indices_acc[loc] = sentinel_val_; } else { // Set the location to the sentinel value to indicate that // the index is not found in the scratch pad - loc_acc[i] = sentinel_val_; + sp_prev_curr_map_acc[i] = sentinel_val_; } } }); @@ -208,12 +225,17 @@ class SSDScratchPadIndicesQueue : public torch::jit::CustomClassHolder { } void lookup_mask_and_pop_front_cuda( - const Tensor& scratch_pad_locations, - const Tensor& scratch_pad_indices, - const Tensor& ssd_indices, - const Tensor& count) { + const Tensor& scratch_pad_prev_curr_map, + const Tensor& scratch_pad_curr_prev_map, + const Tensor& scratch_pad_indices_prev, + const Tensor& inserted_indices_curr, + const Tensor& count_curr) { impl_->lookup_mask_and_pop_front_cuda( - scratch_pad_locations, scratch_pad_indices, ssd_indices, count); + scratch_pad_prev_curr_map, + scratch_pad_curr_prev_map, + scratch_pad_indices_prev, + inserted_indices_curr, + count_curr); } int64_t size() { diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 7f37896ae8..441f9fb034 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -352,6 +352,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { (2 * (count_ + dbs_.size() - 1) / dbs_.size()) * (sizeof(int64_t) + sizeof(scalar_t) * D)); for (auto i = 0; i < count_; ++i) { + // TODO: Check whether this is OK + if (indices_acc[i] == -1) { + continue; + } if (db_shard(indices_acc[i], dbs_.size()) != shard) { continue; } diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index 025063d837..78970150c9 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -8,7 +8,8 @@ # pyre-ignore-all-errors[3,6,56] import unittest -from typing import List, Optional, Tuple + +from typing import Any, Dict, List, Optional, Tuple import hypothesis.strategies as st import numpy as np @@ -22,15 +23,45 @@ from .. import common # noqa E402 from ..common import open_source - if open_source: # pyre-ignore[21] from test_utils import gpu_unavailable, running_on_github else: from fbgemm_gpu.test.test_utils import gpu_unavailable, running_on_github +from enum import Enum + MAX_EXAMPLES = 40 +MAX_PIPELINE_EXAMPLES = 10 + +default_st: Dict["str", Any] = { + "T": st.integers(min_value=1, max_value=10), + "D": st.integers(min_value=2, max_value=128), + "B": st.integers(min_value=1, max_value=128), + "log_E": st.integers(min_value=3, max_value=5), + "L": st.integers(min_value=0, max_value=20), + "weighted": st.booleans(), + "cache_set_scale": st.sampled_from([0.0, 0.005, 1]), + "pooling_mode": st.sampled_from( + [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] + ), + "weights_precision": st.sampled_from([SparseType.FP32, SparseType.FP16]), + "output_dtype": st.sampled_from([SparseType.FP32, SparseType.FP16]), + "share_table": st.booleans(), +} + + +class PrefetchLocation(Enum): + BEFORE_FWD = 1 + BETWEEN_FWD_BWD = 2 + + +class FlushLocation(Enum): + AFTER_FWD = 1 + AFTER_BWD = 2 + BEFORE_TRAINING = 3 + ALL = 4 @unittest.skipIf(*running_on_github) @@ -152,6 +183,7 @@ def generate_ssd_tbes( output_dtype: SparseType = SparseType.FP32, stochastic_rounding: bool = True, share_table: bool = False, + prefetch_pipeline: bool = False, ) -> Tuple[SSDTableBatchedEmbeddingBags, List[torch.nn.EmbeddingBag]]: """ Generate embedding modules (i,e., SSDTableBatchedEmbeddingBags and @@ -225,6 +257,7 @@ def generate_ssd_tbes( weights_precision=weights_precision, output_dtype=output_dtype, stochastic_rounding=stochastic_rounding, + prefetch_pipeline=prefetch_pipeline, ).cuda() # A list to keep the CPU tensor alive until `set` (called inside @@ -350,21 +383,7 @@ def execute_ssd_forward_( ) return output_ref_list, output - @given( - T=st.integers(min_value=1, max_value=10), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - weighted=st.booleans(), - cache_set_scale=st.sampled_from([0.0, 0.005, 1]), - pooling_mode=st.sampled_from( - [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] - ), - weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), - share_table=st.booleans(), - ) + @given(**default_st) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_ssd_forward( self, @@ -426,21 +445,7 @@ def test_ssd_forward( weighted, ) - @given( - T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - weighted=st.booleans(), - cache_set_scale=st.sampled_from([0.0, 0.005, 1]), - pooling_mode=st.sampled_from( - [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] - ), - weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), - share_table=st.booleans(), - ) + @given(**default_st) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_ssd_backward_adagrad( self, @@ -576,23 +581,7 @@ def test_ssd_backward_adagrad( rtol=tolerance, ) - @given( - T=st.integers(min_value=1, max_value=10), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - weighted=st.booleans(), - cache_set_scale=st.sampled_from([0.0, 0.005, 1]), - pooling_mode=st.sampled_from( - [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] - ), - weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), - share_table=st.booleans(), - ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) - def test_ssd_cache( + def execute_ssd_cache_pipeline_( # noqa C901 self, T: int, D: int, @@ -605,8 +594,15 @@ def test_ssd_cache( weights_precision: SparseType, output_dtype: SparseType, share_table: bool, + prefetch_pipeline: bool, + # If True, prefetch will be invoked by the user. + explicit_prefetch: bool, + prefetch_location: Optional[PrefetchLocation], + use_prefetch_stream: bool, + flush_location: Optional[FlushLocation], ) -> None: - assume(not weighted or pooling_mode == PoolingMode.SUM) + # If using pipeline prefetching, explicit prefetching must be True + assert not prefetch_pipeline or explicit_prefetch lr = 0.5 eps = 0.2 @@ -636,6 +632,7 @@ def test_ssd_cache( # functionality of the cache stochastic_rounding=False, share_table=share_table, + prefetch_pipeline=prefetch_pipeline, ) optimizer_states_ref = [ @@ -650,48 +647,73 @@ def test_ssd_cache( else 1.0e-2 ) + batches = [] for it in range(10): - ( - indices_list, - per_sample_weights_list, - indices, - offsets, - per_sample_weights, - ) = self.generate_inputs_( - B, - L, - Es, - emb.feature_table_map, - weights_precision=weights_precision, + batches.append( + self.generate_inputs_( + B, + L, + Es, + emb.feature_table_map, + weights_precision=weights_precision, + ) ) - assert emb.timestep == it - emb.prefetch(indices, offsets) + prefetch_stream = ( + torch.cuda.Stream() if use_prefetch_stream else torch.cuda.current_stream() + ) + forward_stream = torch.cuda.current_stream() if use_prefetch_stream else None + + iters = 10 - linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( - emb.hash_size_cumsum, - indices, - offsets, - ) + force_flush = flush_location == FlushLocation.ALL + + if force_flush or flush_location == FlushLocation.BEFORE_TRAINING: + emb.flush() + + # pyre-ignore[53] + def _prefetch(b_it: int) -> int: + if not explicit_prefetch or b_it >= iters: + return b_it + 1 - # Verify that prefetching twice avoids any actions. ( _, _, + indices, + offsets, _, - actions_count_gpu, - _, - _, - _, - _, - ) = torch.ops.fbgemm.ssd_cache_populate_actions( # noqa - linear_cache_indices, - emb.total_hash_size, - emb.lxu_cache_state, - emb.timestep, - 0, # prefetch_dist - emb.lru_state, - ) + ) = batches[b_it] + with torch.cuda.stream(prefetch_stream): + emb.prefetch(indices, offsets, forward_stream=forward_stream) + return b_it + 1 + + if prefetch_pipeline: + # Prefetch the first iteration + _prefetch(0) + b_it = 1 + else: + b_it = 0 + + for it in range(iters): + ( + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + ) = batches[it] + + # Ensure that prefetch i is done before forward i + if use_prefetch_stream: + assert forward_stream is not None + forward_stream.wait_stream(prefetch_stream) + + # Prefetch before forward + if ( + not prefetch_pipeline + or prefetch_location == PrefetchLocation.BEFORE_FWD + ): + b_it = _prefetch(b_it) # Execute forward output_ref_list, output = self.execute_ssd_forward_( @@ -709,6 +731,9 @@ def test_ssd_cache( it=it, ) + if force_flush or flush_location == FlushLocation.AFTER_FWD: + emb.flush() + # Generate output gradient output_grad_list = [torch.randn_like(out) for out in output_ref_list] @@ -728,9 +753,19 @@ def test_ssd_cache( D * 4, ) + # Prefetch between forward and backward + if ( + prefetch_pipeline + and prefetch_location == PrefetchLocation.BETWEEN_FWD_BWD + ): + b_it = _prefetch(b_it) + # Execute TBE SSD backward output.backward(grad_test) + if force_flush or flush_location == FlushLocation.AFTER_BWD: + emb.flush() + # Compare optimizer states split_optimizer_states = [s for (s,) in emb.debug_split_optimizer_states()] for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map): @@ -779,3 +814,136 @@ def test_ssd_cache( atol=tolerance, rtol=tolerance, ) + + @given( + flush_location=st.sampled_from(FlushLocation), + prefetch_pipeline=st.booleans(), + explicit_prefetch=st.booleans(), + prefetch_location=st.sampled_from(PrefetchLocation), + use_prefetch_stream=st.booleans(), + **default_st, + ) + @settings( + verbosity=Verbosity.verbose, max_examples=MAX_PIPELINE_EXAMPLES, deadline=None + ) + def test_ssd_cache_flush(self, **kwargs: Any): + """ + Test the correctness of the SSD cache prefetch workflow with + excessive flushing + """ + assume(not kwargs["weighted"] or kwargs["pooling_mode"] == PoolingMode.SUM) + assume(kwargs["prefetch_pipeline"] and kwargs["explicit_prefetch"]) + assume(not kwargs["use_prefetch_stream"] or kwargs["prefetch_pipeline"]) + self.execute_ssd_cache_pipeline_( + **kwargs, + ) + + @given(**default_st) + @settings( + verbosity=Verbosity.verbose, max_examples=MAX_PIPELINE_EXAMPLES, deadline=None + ) + def test_ssd_cache_implicit_prefetch(self, **kwargs: Any): + """ + Test the correctness of the SSD cache prefetch workflow + without pipeline prefetching and with implicit prefetching. + Implicit prefetching relies on TBE forward to invoke prefetch. + """ + assume(not kwargs["weighted"] or kwargs["pooling_mode"] == PoolingMode.SUM) + self.execute_ssd_cache_pipeline_( + prefetch_pipeline=False, + explicit_prefetch=False, + prefetch_location=None, + use_prefetch_stream=False, + flush_location=None, + **kwargs, + ) + + @given(**default_st) + @settings( + verbosity=Verbosity.verbose, max_examples=MAX_PIPELINE_EXAMPLES, deadline=None + ) + def test_ssd_cache_explicit_prefetch(self, **kwargs: Any): + """ + Test the correctness of the SSD cache prefetch workflow + without pipeline prefetching and with explicit prefetching + (the user explicitly invokes prefetch). Each prefetch invoked + before a forward TBE fetches data for that specific iteration. + + For example: + + ------------------------- Timeline ------------------------> + pf(i) -> fwd(i) -> ... -> pf(i+1) -> fwd(i+1) -> ... + + Note: + - pf(i) = prefetch of iteration i + - fwd(i) = forward TBE of iteration i + """ + assume(not kwargs["weighted"] or kwargs["pooling_mode"] == PoolingMode.SUM) + self.execute_ssd_cache_pipeline_( + prefetch_pipeline=False, + explicit_prefetch=True, + prefetch_location=None, + use_prefetch_stream=False, + flush_location=None, + **kwargs, + ) + + @given(use_prefetch_stream=st.booleans(), **default_st) + @settings( + verbosity=Verbosity.verbose, max_examples=MAX_PIPELINE_EXAMPLES, deadline=None + ) + def test_ssd_cache_pipeline_before_fwd(self, **kwargs: Any): + """ + Test the correctness of the SSD cache prefetch workflow with + pipeline prefetching when cache prefetching of the next + iteration is invoked before the forward TBE of the current + iteration. + + For example: + + ------------------------- Timeline ------------------------> + pf(i+1) -> fwd(i) -> ... -> pf(i+2) -> fwd(i+1) -> ... + + Note: + - pf(i) = prefetch of iteration i + - fwd(i) = forward TBE of iteration i + """ + assume(not kwargs["weighted"] or kwargs["pooling_mode"] == PoolingMode.SUM) + self.execute_ssd_cache_pipeline_( + prefetch_pipeline=True, + explicit_prefetch=True, + prefetch_location=PrefetchLocation.BEFORE_FWD, + flush_location=None, + **kwargs, + ) + + @given(use_prefetch_stream=st.booleans(), **default_st) + @settings( + verbosity=Verbosity.verbose, max_examples=MAX_PIPELINE_EXAMPLES, deadline=None + ) + def test_ssd_cache_pipeline_between_fwd_bwd(self, **kwargs: Any): + """ + Test the correctness of the SSD cache prefetch workflow with + pipeline prefetching when cache prefetching of the next + iteration is invoked after the forward TBE and before the + backward TBE of the current iteration. + + For example: + + ------------------------- Timeline ------------------------> + fwd(i) -> pf(i+1) -> bwd(i) -> ... -> fwd(i+1) -> pf(i+2) -> bwd(i+1) -> ... + + Note: + - pf(i) = prefetch of iteration i + - fwd(i) = forward TBE of iteration i + - bwd(i) = backward TBE of iteration i + """ + + assume(not kwargs["weighted"] or kwargs["pooling_mode"] == PoolingMode.SUM) + self.execute_ssd_cache_pipeline_( + prefetch_pipeline=True, + explicit_prefetch=True, + prefetch_location=PrefetchLocation.BETWEEN_FWD_BWD, + flush_location=None, + **kwargs, + ) diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py index 72948ba528..ef434dd8e1 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py @@ -259,7 +259,8 @@ def test_scratch_pad_indices_queue( # Run reference # Prepare inputs for the reference run - sp_locations_ref = torch.zeros_like(lookup_indices) + sp_prev_curr_map_ref = torch.zeros_like(lookup_indices) + sp_curr_prev_map_ref = torch.empty_like(indices, dtype=torch.int).fill_(-1) sp_indices = indices.clone().tolist() ssd_indices = lookup_indices.clone().tolist() @@ -273,30 +274,33 @@ def test_scratch_pad_indices_queue( ssd_idx = ssd_indices[i] if ssd_idx in sp_map: loc = sp_map[ssd_idx] - sp_locations_ref[i] = loc + sp_prev_curr_map_ref[i] = loc + sp_curr_prev_map_ref[loc] = i sp_indices[loc] = sentinel_value ssd_indices[i] = sentinel_value else: - sp_locations_ref[i] = sentinel_value + sp_prev_curr_map_ref[i] = sentinel_value all_lookup_outputs_ref.append( ( - sp_locations_ref, + sp_prev_curr_map_ref, torch.as_tensor(sp_indices), torch.as_tensor(ssd_indices), ) ) # Run test - sp_locations = torch.zeros_like(lookup_indices) + sp_prev_curr_map = torch.zeros_like(lookup_indices) + sp_curr_prev_map = torch.empty_like(indices, dtype=torch.int).fill_(-1) sp_idx_queue.lookup_mask_and_pop_front_cuda( - sp_locations, + sp_prev_curr_map, + sp_curr_prev_map, indices, lookup_indices, lookup_count, ) - all_lookup_outputs.append((sp_locations, indices, lookup_indices)) + all_lookup_outputs.append((sp_prev_curr_map, indices, lookup_indices)) # Ensure that the lookups are done torch.cuda.synchronize() @@ -304,7 +308,12 @@ def test_scratch_pad_indices_queue( # Compare results for test, ref in zip(all_lookup_outputs, all_lookup_outputs_ref): for name, test_, ref_ in zip( - ["scratch_pad_locations", "scratch_pad_indices", "ssd_indices"], + [ + "scratch_pad_prev_curr_map", + "scratch_pad_curr_prev_map", + "scratch_pad_indices", + "ssd_indices", + ], test, ref, ):