From 6f761f8f2d149c5a7504bf85e328f103de6fcab5 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Mon, 12 Aug 2024 18:19:09 -0700 Subject: [PATCH] Enable pipeline prefetching (#2963) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2963 This diff enables pipeline cache prefetching for SSD-TBE. This allows prefetch for the next iteration's batch to be carried out while the computation of the current batch is going on. We have done the following to guarantee cache consistency when pipeline prefetching is enabled: (1) Enable cache line locking (implemented in D46172802, D47638502, D60812956) to ensure that cache lines are not prematurely evicted by the prefetch when the previous iteration's computation is not complete. (2) Lookup L1 cache, the previous iteration's scratch pad (let's call it SP(i-1)), and SSD/L2 cache. Move rows from SSD/L2 and/or SP(i-1) to either L1 or the current iteration's scratch pad (let's call it SP(i)). Then we update the row pointers of the previous iteration's indices based on the new locations, i.e., L1 or SP(i). The detailed explaination of the process is shown in the figure below: {F1802341461} https://internalfb.com/excalidraw/EX264315 (3) Ensure proper synchronizations between streams and events - Ensure that prefetch of iteration i is complete before backward TBE of iteration i-1 - Ensure that prefetch of iteration i+1 starts after the backward TBE of iteration i is complete The following is how prefetch operators run on GPU streams/CPU: {F1802798301} **Usage:** ``` # Initialize the module with prefetch_pipeline=True # prefetch_stream is the CUDA stream for prefetching (optional) emb = SSDTableBatchedEmbeddingBags( embedding_specs=..., prefetch_pipeline=True, # It is recommended to set the stream priority to low prefetch_stream=torch.cuda.Stream(), ).cuda() # When calling prefetch, make sure to pass the forward stream if using prefetch_stream so that TBE records tensors on streams properly with torch.cuda.stream(prefetch_stream): emb.prefetch( indices, offsets, forward_stream=forward_stream ) ``` Differential Revision: D60727327 --- fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 612 +++++++++++++----- .../ssd_scratch_pad_indices_queue.cpp | 144 +++-- .../ssd_table_batched_embeddings.h | 4 + .../tbe/ssd/ssd_split_tbe_training_test.py | 336 +++++++--- fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py | 25 +- 5 files changed, 803 insertions(+), 318 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 41c99c9675..16d6a61d69 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -14,7 +14,7 @@ import os import tempfile from math import log2 -from typing import Any, Callable, List, Optional, Tuple, Type +from typing import Any, Callable, List, Optional, Tuple, Type, Union import torch # usort:skip @@ -36,11 +36,23 @@ ) from torch import distributed as dist, nn, Tensor # usort:skip +from dataclasses import dataclass + from torch.autograd.profiler import record_function from .common import ASSOC +@dataclass +class IterData: + indices: Tensor + offsets: Tensor + lxu_cache_locations: Tensor + lxu_cache_ptrs: Tensor + actions_count_gpu: Tensor + cache_set_inverse_indices: Tensor + + class SSDTableBatchedEmbeddingBags(nn.Module): D_offsets: Tensor lxu_cache_weights: Tensor @@ -113,6 +125,9 @@ def __init__( gather_ssd_cache_stats: Optional[bool] = False, stats_reporter_config: Optional[TBEStatsReporterConfig] = None, l2_cache_size: int = 0, + # Set to True to enable pipeline prefetching + prefetch_pipeline: bool = False, + prefetch_stream: Optional[torch.cuda.Stream] = None, ) -> None: super(SSDTableBatchedEmbeddingBags, self).__init__() @@ -183,14 +198,42 @@ def __init__( ) self.register_buffer( "lxu_cache_state", - torch.zeros(cache_sets, ASSOC, dtype=torch.int64).fill_(-1), + torch.zeros( + cache_sets, ASSOC, device=self.current_device, dtype=torch.int64 + ).fill_(-1), ) self.register_buffer( - "lru_state", torch.zeros(cache_sets, ASSOC, dtype=torch.int64) + "lru_state", + torch.zeros( + cache_sets, ASSOC, device=self.current_device, dtype=torch.int64 + ), ) self.step = 0 + # Set prefetch pipeline + self.prefetch_pipeline: bool = prefetch_pipeline + self.prefetch_stream: Optional[torch.cuda.Stream] = prefetch_stream + + # Cache locking counter for pipeline prefetching + if self.prefetch_pipeline: + self.register_buffer( + "lxu_cache_locking_counter", + torch.zeros( + cache_sets, + ASSOC, + device=self.current_device, + dtype=torch.int32, + ), + persistent=True, + ) + else: + self.register_buffer( + "lxu_cache_locking_counter", + torch.zeros([0, 0], dtype=torch.int32, device=self.current_device), + persistent=False, + ) + assert ssd_cache_location in ( EmbeddingLocation.MANAGED, EmbeddingLocation.DEVICE, @@ -343,8 +386,11 @@ def __init__( self.ssd_event_backward = torch.cuda.Event() # SSD get's input copy completion event self.ssd_event_get_inputs_cpy = torch.cuda.Event() - # SSD scratch pad index queue insert completion event - self.ssd_event_sp_idxq_insert = torch.cuda.Event() + if self.prefetch_pipeline: + # SSD scratch pad index queue insert completion event + self.ssd_event_sp_idxq_insert: torch.cuda.streams.Event = torch.cuda.Event() + # SSD scratch pad index queue lookup completion event + self.ssd_event_sp_idxq_lookup: torch.cuda.streams.Event = torch.cuda.Event() self.timesteps_prefetched: List[int] = [] # TODO: add type annotation @@ -352,10 +398,19 @@ def __init__( self.ssd_prefetch_data = [] # Scratch pad value queue - self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = [] - # pyre-ignore[4] - # Scratch pad index queue - self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue(-1) + self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = [] + # Scratch pad eviction data queue + self.ssd_scratch_pad_eviction_data: List[ + Tuple[Tensor, Tensor, Tensor, bool] + ] = [] + self.ssd_location_update_data: List[Tuple[Tensor, Tensor]] = [] + + if self.prefetch_pipeline: + # pyre-ignore[4] + # Scratch pad index queue + self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue( + -1 + ) if weight_decay_mode == WeightDecayMode.COUNTER or counter_based_regularization: raise AssertionError( @@ -426,12 +481,24 @@ def __init__( dtype=torch.float32, ) + # For storing current iteration data + self.current_iter_data: Optional[IterData] = None + # add placeholder require_grad param to enable autograd without nn.parameter # this is needed to enable int8 embedding weights for SplitTableBatchedEmbedding self.placeholder_autograd_tensor = nn.Parameter( torch.zeros(0, device=self.current_device, dtype=torch.float) ) + # Register backward hook for evicting rows from a scratch pad to SSD + # post backward + self.placeholder_autograd_tensor.register_hook(self._evict_from_scratch_pad) + + if self.prefetch_pipeline: + self.register_full_backward_pre_hook( + self._update_cache_counter_and_pointers + ) + assert optimizer in ( OptimType.EXACT_ROWWISE_ADAGRAD, ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags" @@ -543,28 +610,31 @@ def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor: t_cpu.copy_(t, non_blocking=True) return t_cpu - def to_pinned_cpu_on_stream_wait_on_current_stream( + def to_pinned_cpu_on_stream_wait_on_another_stream( self, tensors: List[Tensor], stream: torch.cuda.Stream, + stream_to_wait_on: torch.cuda.Stream, post_event: Optional[torch.cuda.Event] = None, ) -> List[Tensor]: """ - Transfer input tensors from GPU to CPU using a pinned host buffer. - The transfer is carried out on the given stream (`stream`) after all - the kernels in the default stream (`current_stream`) are complete. + Transfer input tensors from GPU to CPU using a pinned host + buffer. The transfer is carried out on the given stream + (`stream`) after all the kernels in the other stream + (`stream_to_wait_on`) are complete. Args: - tensors (List[Tensor]): The list of tensors to be transferred + tensors (List[Tensor]): The list of tensors to be + transferred stream (Stream): The stream to run memory copy + stream_to_wait_on (Stream): The stream to wait on post_event (Event): The post completion event Returns: The list of pinned CPU tensors """ - current_stream = torch.cuda.current_stream() with torch.cuda.stream(stream): - stream.wait_stream(current_stream) + stream.wait_stream(stream_to_wait_on) cpu_tensors = [] for t in tensors: t.record_stream(stream) @@ -627,115 +697,195 @@ def evict( if post_event is not None: stream.record_event(post_event) - def _evict_from_scratch_pad(self, return_on_empty: bool) -> None: - scratch_pad_len = len(self.ssd_scratch_pads) + def _evict_from_scratch_pad(self, grad: Tensor) -> None: + """ + Evict conflict missed rows from a scratch pad + (`inserted_rows`) on the `ssd_eviction_stream`. This is a hook + that is invoked right after TBE backward. - if not return_on_empty: - assert scratch_pad_len > 0, "There must be at least one scratch pad" - elif scratch_pad_len == 0: - return + Conflict missed indices are specified in + `post_bwd_evicted_indices_cpu`. Indices that are not -1 and + their positions < `actions_count_cpu` (i.e., rows + `post_bwd_evicted_indices_cpu[:actions_count_cpu] != -1` in + post_bwd_evicted_indices_cpu) will be evicted. + + Args: + grad (Tensor): Unused gradient tensor + + Returns: + None + """ + with record_function("## ssd_evict_from_scratch_pad_pipeline ##"): + current_stream = torch.cuda.current_stream() + current_stream.record_event(self.ssd_event_backward) + + assert ( + len(self.ssd_scratch_pad_eviction_data) > 0 + ), "There must be at least one scratch pad" + + ( + inserted_rows, + post_bwd_evicted_indices_cpu, + actions_count_cpu, + do_evict, + ) = self.ssd_scratch_pad_eviction_data.pop(0) + + if not do_evict: + return - (inserted_rows, post_bwd_evicted_indices_cpu, actions_count_cpu, do_evict) = ( - self.ssd_scratch_pads.pop(0) - ) - if do_evict: - torch.cuda.current_stream().record_event(self.ssd_event_backward) self.evict( rows=inserted_rows, indices_cpu=post_bwd_evicted_indices_cpu, actions_count_cpu=actions_count_cpu, stream=self.ssd_eviction_stream, pre_event=self.ssd_event_backward, - post_event=None, + post_event=self.ssd_event_evict, is_rows_uvm=True, name="scratch_pad", ) - def _compute_cache_ptrs( + if self.prefetch_stream: + self.prefetch_stream.wait_stream(current_stream) + + def _update_cache_counter_and_pointers( self, - linear_cache_indices: torch.Tensor, - assigned_cache_slots: torch.Tensor, - linear_index_inverse_indices: torch.Tensor, - unique_indices_count_cumsum: torch.Tensor, - cache_set_inverse_indices: torch.Tensor, - inserted_rows: torch.Tensor, - unique_indices_length: torch.Tensor, - inserted_indices: torch.Tensor, - actions_count_cpu: torch.Tensor, - lxu_cache_locations: torch.Tensor, - ) -> torch.Tensor: - with record_function("## ssd_generate_row_addrs ##"): - lxu_cache_ptrs, post_bwd_evicted_indices = ( - torch.ops.fbgemm.ssd_generate_row_addrs( - lxu_cache_locations, - assigned_cache_slots, - linear_index_inverse_indices, - unique_indices_count_cumsum, - cache_set_inverse_indices, - self.lxu_cache_weights, - inserted_rows, - unique_indices_length, - inserted_indices, - ) + module: nn.Module, + grad_input: Union[Tuple[Tensor, ...], Tensor], + ) -> None: + """ + Update cache line locking counter and pointers before backward + TBE. This is a hook that is called before the backward of TBE + + Update cache line counter: + + We ensure that cache prefetching does not execute concurrently + with the backward TBE. Therefore, it is safe to unlock the + cache lines used in current iteration before backward TBE. + + Update pointers: + + Now some rows that are used in both the current iteration and + the next iteration are moved (1) from the current iteration's + scratch pad into the next iteration's scratch pad or (2) from + the current iteration's scratch pad into the L1 cache + + To ensure that the TBE backward kernel accesses valid data, + here we update the pointers of these rows in the current + iteration's `lxu_cache_ptrs` to point to either L1 cache or + the next iteration scratch pad + + Args: + module (nn.Module): Unused + grad_input (Union[Tuple[Tensor, ...], Tensor]): Unused + + Returns: + None + """ + if self.prefetch_stream: + # Ensure that prefetch is done + torch.cuda.current_stream().wait_stream(self.prefetch_stream) + + assert self.current_iter_data is not None, "current_iter_data must be set" + + curr_data: IterData = self.current_iter_data + + if curr_data.lxu_cache_locations.numel() == 0: + return + + with record_function("## ssd_update_cache_counter_and_pointers ##"): + # Unlock the cache lines + torch.ops.fbgemm.lxu_cache_locking_counter_decrement( + self.lxu_cache_locking_counter, + curr_data.lxu_cache_locations, ) - # Transfer post_bwd_evicted_indices from GPU to CPU right away to - # increase a chance of overlapping with compute in the default stream - (post_bwd_evicted_indices_cpu,) = ( - self.to_pinned_cpu_on_stream_wait_on_current_stream( - tensors=[post_bwd_evicted_indices], - stream=self.ssd_eviction_stream, - post_event=None, + # Recompute linear_cache_indices to save memory + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( + self.hash_size_cumsum, + curr_data.indices, + curr_data.offsets, + ) + ( + linear_unique_indices, + linear_unique_indices_length, + unique_indices_count, + linear_index_inverse_indices, + ) = torch.ops.fbgemm.get_unique_indices_v2( + linear_cache_indices, + self.total_hash_size, + compute_count=True, + compute_inverse_indices=True, + ) + unique_indices_count_cumsum = torch.ops.fbgemm.asynchronous_complete_cumsum( + unique_indices_count ) - ) - # Insert conflict miss indices in the index queue for future lookup - # post_bwd_evicted_indices_cpu is transferred on the ssd_eviction_stream stream - # actions_count_cpu is transferred on the ssd_memcpy_stream stream - with torch.cuda.stream(self.ssd_eviction_stream): - # Ensure that actions_count_cpu transfer is done - self.ssd_eviction_stream.wait_event(self.ssd_event_get_inputs_cpy) - self.record_function_via_dummy_profile( - "## ssd_scratch_pad_idx_queue_insert ##", - self.scratch_pad_idx_queue.insert_cuda, - post_bwd_evicted_indices_cpu, - actions_count_cpu, + # Look up the cache to check which indices in the scratch + # pad are moved to L1 + torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + self.lxu_cache_state, + self.total_hash_size, + gather_cache_stats=False, # not collecting cache stats + lxu_cache_locations_output=curr_data.lxu_cache_locations, ) - self.ssd_eviction_stream.record_event(self.ssd_event_sp_idxq_insert) - with record_function("## ssd_scratch_pads ##"): - # Store scratch pad info for post backward eviction - self.ssd_scratch_pads.append( - ( - inserted_rows, - post_bwd_evicted_indices_cpu, - actions_count_cpu, - linear_cache_indices.numel() > 0, - ) + if len(self.ssd_location_update_data) == 0: + return + + (sp_curr_next_map, inserted_rows_next) = self.ssd_location_update_data.pop( + 0 ) - # pyre-fixme[7]: Expected `Tensor` but got `Tuple[typing.Any, Tensor, - # typing.Any, Tensor]`. - return ( - lxu_cache_ptrs, - inserted_rows, - post_bwd_evicted_indices_cpu, - actions_count_cpu, - ) + # Update poitners + torch.ops.fbgemm.ssd_update_row_addrs( + ssd_row_addrs_curr=curr_data.lxu_cache_ptrs, + inserted_ssd_weights_curr_next_map=sp_curr_next_map, + lxu_cache_locations_curr=curr_data.lxu_cache_locations, + linear_index_inverse_indices_curr=linear_index_inverse_indices, + unique_indices_count_cumsum_curr=unique_indices_count_cumsum, + cache_set_inverse_indices_curr=curr_data.cache_set_inverse_indices, + lxu_cache_weights=self.lxu_cache_weights, + inserted_ssd_weights_next=inserted_rows_next, + unique_indices_length_curr=curr_data.actions_count_gpu, + ) - def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: - with record_function("## ssd_prefetch ##"): + def prefetch( + self, + indices: Tensor, + offsets: Tensor, + forward_stream: Optional[torch.cuda.Stream] = None, + ) -> Optional[Tensor]: + # TODO: Refactor prefetch + current_stream = torch.cuda.current_stream() + + # If a prefetch stream is used, then record tensors on the stream + indices.record_stream(current_stream) + offsets.record_stream(current_stream) + + with record_function("## ssd_prefetch {} ##".format(self.timestep)): if self.gather_ssd_cache_stats: self.local_ssd_cache_stats.zero_() (indices, offsets) = indices.long(), offsets.long() + # Linearize indices linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( self.hash_size_cumsum, indices, offsets, ) + self.timestep += 1 self.timesteps_prefetched.append(self.timestep) + + # Lookup and virtually insert indices into L1. After this operator, + # we know: + # (1) which cache lines can be evicted + # (2) which rows are already in cache (hit) + # (3) which rows are missed and can be inserted later (missed, but + # not conflict missed) + # (4) which rows are missed but CANNOT be inserted later (conflict + # missed) ( inserted_indices, evicted_indices, @@ -754,33 +904,58 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: self.lru_state, self.gather_ssd_cache_stats, self.local_ssd_cache_stats, + lock_cache_line=self.prefetch_pipeline, + lxu_cache_locking_counter=self.lxu_cache_locking_counter, ) - # Transfer evicted indices from GPU to CPU right away to increase a - # chance of overlapping with compute on the default stream - (evicted_indices_cpu,) = ( - self.to_pinned_cpu_on_stream_wait_on_current_stream( - tensors=[evicted_indices], - stream=self.ssd_eviction_stream, - post_event=None, + + # Compute cache locations (rows that are hit are missed but can be + # inserted will have cache locations != -1) + with record_function("## ssd_tbe_lxu_cache_lookup ##"): + lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + self.lxu_cache_state, + self.total_hash_size, + self.gather_ssd_cache_stats, + self.local_ssd_cache_stats, ) - ) - actions_count_cpu, inserted_indices_cpu = ( - self.to_pinned_cpu_on_stream_wait_on_current_stream( - tensors=[ - actions_count_gpu, - inserted_indices, - ], - stream=self.ssd_memcpy_stream, - post_event=self.ssd_event_get_inputs_cpy, + with record_function("## ssd_d2h_inserted_indices ##"): + # Transfer actions_count and insert_indices right away to + # incrase an overlap opportunity + actions_count_cpu, inserted_indices_cpu = ( + self.to_pinned_cpu_on_stream_wait_on_another_stream( + tensors=[ + actions_count_gpu, + inserted_indices, + ], + stream=self.ssd_memcpy_stream, + stream_to_wait_on=current_stream, + post_event=self.ssd_event_get_inputs_cpy, + ) ) - ) - assigned_cache_slots = assigned_cache_slots.long() - evicted_rows = self.lxu_cache_weights[ - assigned_cache_slots.clamp(min=0).long(), : - ] + with record_function("## ssd_d2h_evicted_indices ##"): + # Transfer evicted indices from GPU to CPU right away to increase a + # chance of overlapping with compute on the default stream + (evicted_indices_cpu,) = ( + self.to_pinned_cpu_on_stream_wait_on_another_stream( + tensors=[evicted_indices], + stream=self.ssd_eviction_stream, + stream_to_wait_on=current_stream, + post_event=None, + ) + ) + + # Copy rows to be evicted into a separate buffer (will be evicted + # later in the prefetch step) + with record_function("## ssd_compute_evicted_rows ##"): + assigned_cache_slots = assigned_cache_slots.long() + evicted_rows = self.lxu_cache_weights[ + assigned_cache_slots.clamp(min=0).long(), : + ] + # Allocation a scratch pad for the current iteration. The scratch + # pad is a UVA tensor if linear_cache_indices.numel() > 0: inserted_rows = torch.ops.fbgemm.new_managed_tensor( torch.zeros( @@ -797,34 +972,47 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: device=self.current_device, ) - current_stream = torch.cuda.current_stream() - if len(self.ssd_scratch_pads) > 0: + if self.prefetch_pipeline and len(self.ssd_scratch_pads) > 0: + # Look up all missed indices from the previous iteration's + # scratch pad (do this only if pipeline prefetching is being + # used) with record_function("## ssd_lookup_scratch_pad ##"): - current_stream.wait_event(self.ssd_event_sp_idxq_insert) - current_stream.wait_event(self.ssd_event_get_inputs_cpy) - + # Get the previous scratch pad ( inserted_rows_prev, post_bwd_evicted_indices_cpu_prev, actions_count_cpu_prev, - do_evict_prev, ) = self.ssd_scratch_pads.pop(0) # Inserted indices that are found in the scratch pad # from the previous iteration - sp_locations_cpu = torch.empty( + sp_prev_curr_map_cpu = torch.empty( inserted_indices_cpu.shape, dtype=inserted_indices_cpu.dtype, pin_memory=True, ) + # Conflict missed indices from the previous iteration that + # overlap with the current iterations's inserted indices + sp_curr_prev_map_cpu = torch.empty( + post_bwd_evicted_indices_cpu_prev.shape, + dtype=torch.int, + pin_memory=True, + ).fill_(-1) + + # Ensure that the necessary D2H transfers are done + current_stream.wait_event(self.ssd_event_get_inputs_cpy) + # Ensure that the previous iteration's scratch pad indices + # insertion is complete + current_stream.wait_event(self.ssd_event_sp_idxq_insert) + # Before entering this function: inserted_indices_cpu # contains all linear indices that are missed from the # L1 cache # # After this function: inserted indices that are found # in the scratch pad from the previous iteration are - # stored in sp_locations_cpu, while the rests are + # stored in sp_prev_curr_map_cpu, while the rests are # stored in inserted_indices_cpu # # An invalid index is -1 or its position > @@ -832,44 +1020,55 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: self.record_function_via_dummy_profile( "## ssd_lookup_mask_and_pop_front ##", self.scratch_pad_idx_queue.lookup_mask_and_pop_front_cuda, - sp_locations_cpu, - post_bwd_evicted_indices_cpu_prev, - inserted_indices_cpu, - actions_count_cpu, + sp_prev_curr_map_cpu, # scratch_pad_prev_curr_map + sp_curr_prev_map_cpu, # scratch_pad_curr_prev_map + post_bwd_evicted_indices_cpu_prev, # scratch_pad_indices_prev + inserted_indices_cpu, # inserted_indices_curr + actions_count_cpu, # count_curr ) - # Transfer sp_locations_cpu to GPU - sp_locations_gpu = sp_locations_cpu.cuda(non_blocking=True) + # Mark scratch pad index queue lookup completion + current_stream.record_event(self.ssd_event_sp_idxq_lookup) + + # Transfer sp_prev_curr_map_cpu to GPU + sp_prev_curr_map_gpu = sp_prev_curr_map_cpu.cuda(non_blocking=True) + # Transfer sp_curr_prev_map_cpu to GPU + sp_curr_prev_map_gpu = sp_curr_prev_map_cpu.cuda(non_blocking=True) + + # Previously actions_count_gpu was recorded on another + # stream. Thus, we need to record it on this stream + actions_count_gpu.record_stream(current_stream) # Copy data from the previous iteration's scratch pad to # the current iteration's scratch pad torch.ops.fbgemm.masked_index_select( inserted_rows, - sp_locations_gpu, + sp_prev_curr_map_gpu, inserted_rows_prev, actions_count_gpu, ) - # Evict from scratch pad - if do_evict_prev: - torch.cuda.current_stream().record_event( - self.ssd_event_backward - ) - self.evict( - rows=inserted_rows_prev, - indices_cpu=post_bwd_evicted_indices_cpu_prev, - actions_count_cpu=actions_count_cpu_prev, - stream=self.ssd_eviction_stream, - pre_event=self.ssd_event_backward, - post_event=None, - is_rows_uvm=True, - name="scratch_pad", + # Record the tensors that will be pushed into a queue + # on the forward stream + if forward_stream: + sp_curr_prev_map_gpu.record_stream(forward_stream) + + # Store info for evicting the previous iteration's + # scratch pad after the corresponding backward pass is + # done + self.ssd_location_update_data.append( + ( + sp_curr_prev_map_gpu, + inserted_rows, ) + ) - # Ensure the previous iterations l3_db.set(..) has completed. + # Ensure the previous iterations eviction is complete current_stream.wait_event(self.ssd_event_evict) + # Ensure that D2H is done current_stream.wait_event(self.ssd_event_get_inputs_cpy) + # Fetch data from SSD if linear_cache_indices.numel() > 0: self.record_function_via_dummy_profile( "## ssd_get ##", @@ -878,8 +1077,11 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: inserted_rows, actions_count_cpu, ) + + # Record an event to mark the completion of `get_cuda` current_stream.record_event(self.ssd_event_get) + # Copy rows from the current iteration's scratch pad to L1 torch.ops.fbgemm.masked_index_put( self.lxu_cache_weights, assigned_cache_slots, @@ -895,36 +1097,105 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: actions_count_cpu=actions_count_cpu, stream=self.ssd_eviction_stream, pre_event=self.ssd_event_get, - post_event=self.ssd_event_evict, + # Record completion event after scratch pad eviction + # instead since that happens after L1 eviction + post_event=None, is_rows_uvm=False, name="cache", ) - with record_function("## ssd_tbe_lxu_cache_lookup ##"): - lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( - linear_cache_indices, - self.lxu_cache_state, - self.total_hash_size, - self.gather_ssd_cache_stats, - self.local_ssd_cache_stats, + # Generate row addresses (pointing to either L1 or the current + # iteration's scratch pad) + with record_function("## ssd_generate_row_addrs ##"): + lxu_cache_ptrs, post_bwd_evicted_indices = ( + torch.ops.fbgemm.ssd_generate_row_addrs( + lxu_cache_locations, + assigned_cache_slots, + linear_index_inverse_indices, + unique_indices_count_cumsum, + cache_set_inverse_indices, + self.lxu_cache_weights, + inserted_rows, + unique_indices_length, + inserted_indices, + ) + ) + + with record_function("## ssd_d2h_post_bwd_evicted_indices ##"): + # Transfer post_bwd_evicted_indices from GPU to CPU right away to + # increase a chance of overlapping with compute in the default stream + (post_bwd_evicted_indices_cpu,) = ( + self.to_pinned_cpu_on_stream_wait_on_another_stream( + tensors=[post_bwd_evicted_indices], + stream=self.ssd_eviction_stream, + stream_to_wait_on=current_stream, + post_event=None, + ) + ) + + if self.prefetch_pipeline: + # Insert the current iteration's conflict miss indices in the index + # queue for future lookup. + # + # post_bwd_evicted_indices_cpu is transferred on the + # ssd_eviction_stream stream so it does not need stream + # synchronization + # + # actions_count_cpu is transferred on the ssd_memcpy_stream stream. + # Thus, we have to explicitly sync the stream + with torch.cuda.stream(self.ssd_eviction_stream): + # Ensure that actions_count_cpu transfer is done + self.ssd_eviction_stream.wait_event(self.ssd_event_get_inputs_cpy) + # Ensure that the scratch pad index queue look up is complete + self.ssd_eviction_stream.wait_event(self.ssd_event_sp_idxq_lookup) + self.record_function_via_dummy_profile( + "## ssd_scratch_pad_idx_queue_insert ##", + self.scratch_pad_idx_queue.insert_cuda, + post_bwd_evicted_indices_cpu, + actions_count_cpu, + ) + # Mark the completion of scratch pad index insertion + self.ssd_eviction_stream.record_event(self.ssd_event_sp_idxq_insert) + + prefetch_data = ( + lxu_cache_ptrs, + inserted_rows, + post_bwd_evicted_indices_cpu, + actions_count_cpu, + actions_count_gpu, + lxu_cache_locations, + cache_set_inverse_indices, + ) + + # Record tensors on the forward stream + if forward_stream is not None: + for t in prefetch_data: + if t.is_cuda: + t.record_stream(forward_stream) + + # Store scratch pad info for the lookup in the next iteration + # prefetch + self.ssd_scratch_pads.append( + ( + inserted_rows, + post_bwd_evicted_indices_cpu, + actions_count_cpu, ) + ) - # TODO: keep only necessary tensors - self.ssd_prefetch_data.append( + # Store scratch pad info for post backward eviction + self.ssd_scratch_pad_eviction_data.append( ( - linear_cache_indices, - assigned_cache_slots, - linear_index_inverse_indices, - unique_indices_count_cumsum, - cache_set_inverse_indices, inserted_rows, - unique_indices_length, - inserted_indices, + post_bwd_evicted_indices_cpu, actions_count_cpu, - lxu_cache_locations, + linear_cache_indices.numel() > 0, ) ) + # Store data for forward + self.ssd_prefetch_data.append(prefetch_data) + if self.gather_ssd_cache_stats: self.ssd_cache_stats = torch.add( self.ssd_cache_stats, self.local_ssd_cache_stats @@ -946,13 +1217,25 @@ def forward( self.prefetch(indices, offsets) assert len(self.ssd_prefetch_data) > 0 - prefetch_data = self.ssd_prefetch_data.pop(0) ( lxu_cache_ptrs, inserted_rows, post_bwd_evicted_indices_cpu, actions_count_cpu, - ) = self._compute_cache_ptrs(*prefetch_data) + actions_count_gpu, + lxu_cache_locations, + cache_set_inverse_indices, + ) = self.ssd_prefetch_data.pop(0) + + # Storing current iteration data for future use + self.current_iter_data = IterData( + indices, + offsets, + lxu_cache_locations, + lxu_cache_ptrs, + actions_count_gpu, + cache_set_inverse_indices, + ) common_args = invokers.lookup_args_ssd.CommonArgs( placeholder_autograd_tensor=self.placeholder_autograd_tensor, @@ -1114,9 +1397,6 @@ def flush(self) -> None: active_slots_mask_cpu.view(-1) ) - # Evict data from scratch pad if there is scratch pad in the queue - self._evict_from_scratch_pad(return_on_empty=True) - torch.cuda.current_stream().wait_stream(self.ssd_eviction_stream) self.ssd_db.set_cuda( diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_scratch_pad_indices_queue.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_scratch_pad_indices_queue.cpp index f96cd5424d..c174f2c196 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_scratch_pad_indices_queue.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_scratch_pad_indices_queue.cpp @@ -53,15 +53,20 @@ class SSDScratchPadIndicesQueueImpl } void lookup_mask_and_pop_front_cuda( - const Tensor& scratch_pad_locations, - const Tensor& scratch_pad_indices, - const Tensor& ssd_indices, - const Tensor& count) { + const Tensor& scratch_pad_prev_curr_map, + const Tensor& scratch_pad_curr_prev_map, + const Tensor& scratch_pad_indices_prev, + const Tensor& inserted_indices_curr, + const Tensor& count_curr) { // take reference to self to avoid lifetime issues. auto self = shared_from_this(); std::function* functor = new std::function([=]() { self->lookup_mask_and_pop_front( - scratch_pad_locations, scratch_pad_indices, ssd_indices, count); + scratch_pad_prev_curr_map, + scratch_pad_curr_prev_map, + scratch_pad_indices_prev, + inserted_indices_curr, + count_curr); }); AT_CUDA_CHECK(cudaStreamAddCallback( at::cuda::getCurrentCUDAStream(), @@ -101,91 +106,103 @@ class SSDScratchPadIndicesQueueImpl index_loc_map_queue.push(std::move(map)); } - /// Looks up `ssd_indices` in the front hash map in the queue. This - /// is equivalent to looking up indices in the front scratch pad. + /// Looks up `inserted_indices_curr` in the front hash map in the + /// queue. This is equivalent to looking up indices in the front + /// scratch pad (i.e., the previous scratch pad). /// /// If an index is found: /// - /// - Sets the corresponding `scratch_pad_locations` to the location - /// of the index in the scratch pad (the value in the hash map) + /// - Sets the corresponding `scratch_pad_prev_curr_map` to the + /// location of the index in the front scratch pad (the value in the + /// hash map). scratch_pad_prev_curr_map[i] is the location in the + /// previous scratch pad of the the current scratch pad's index i /// - /// - Sets the corresponding `ssd_indices` to the sentinel value - /// (this is to prevent looking up this index from SSD) + /// - Sets the corresponding `inserted_indices_curr` to the sentinel + /// value (this is to prevent looking up this index from SSD) /// - /// - Sets the `scratch_pad_indices` to the sentinel value (this is - /// to prevent evicting the corresponding row from the scratch pad). + /// - Sets the `scratch_pad_indices_prev` to the sentinel value + /// (this is to prevent evicting the corresponding row from the + /// previous scratch pad). /// - /// Else: Sets the corresponding `scratch_pad_locations` to the + /// Else: Sets the corresponding `scratch_pad_prev_curr_map` to the /// sentinel value (to indicate that the index is not found in the - /// scratch pad) + /// previous scratch pad). /// /// Once the process above is done, pop the hash map from the queue. /// - /// @param scratch_pad_locations The 1D output tensor that has the - /// same size as `ssd_indices`. It - /// contains locations of the - /// corresponding indices in the - /// scratch pad if they are found or - /// sentinel values - /// @param scratch_pad_indices The 1D tensor that contains scratch - /// pad indices, i.e., conflict missed - /// indices from the previous iteration. - /// The indices and their locations must - /// match the keys and values in the - /// front hash map. After this function, - /// the indices that are found will be - /// set to sentinel values to prevent - /// them from getting evicted - /// @param ssd_indices The 1D tensor that contains indices that are - /// missed from the L1 cache, i.e., all missed - /// indices (including conflict misses). After - /// this function, the indices that are found - /// will be set to sentinel values to prevent - /// them from being looked up in from SSD - /// @param count The tensor that contains the number of indices to - /// be processed + /// @param scratch_pad_prev_curr_map The 1D output tensor that has + /// the same size as `inserted_indices_curr`. It contains + /// locations of the corresponding indices (in the + /// current scratch pad) in the previous scratch pad if + /// they are found or sentinel values + /// @param scratch_pad_curr_prev_map The 1D output tensor that has + /// the same size as `scratch_pad_indices_prev`. It + /// contains locations of the corresponding indices (in + /// the previous scratch pad) in the current scratch pad + /// if they are found or sentinel values + /// @param scratch_pad_indices_prev The 1D tensor that contains + /// scratch pad indices, i.e., conflict missed indices + /// from the previous iteration. The indices and their + /// locations must match the keys and values in the front + /// hash map. After this function, the indices that are + /// found will be set to sentinel values to prevent them + /// from getting evicted + /// @param inserted_indices_curr The 1D tensor that contains indices + /// that are missed from the L1 cache, i.e., all missed + /// indices (including conflict misses) from the current + /// iteration. After this function, the indices that are + /// found will be set to sentinel values to prevent them + /// from being looked up in from SSD + /// @param count_curr The tensor that contains the number of indices + /// to be processed. /// /// @return Outputs are passed by reference. void lookup_mask_and_pop_front( - const Tensor& scratch_pad_locations, - const Tensor& scratch_pad_indices, - const Tensor& ssd_indices, - const Tensor& count) { + const Tensor& scratch_pad_prev_curr_map, + const Tensor& scratch_pad_curr_prev_map, + const Tensor& scratch_pad_indices_prev, + const Tensor& inserted_indices_curr, + const Tensor& count_curr) { TORCH_CHECK( index_loc_map_queue.size() > 0, "index_loc_map_queue must not be empty"); - const auto count_ = count.item(); - TORCH_CHECK(ssd_indices.numel() >= count_); - TORCH_CHECK(scratch_pad_locations.numel() == ssd_indices.numel()); + const auto count_ = count_curr.item(); + TORCH_CHECK(inserted_indices_curr.numel() >= count_); + TORCH_CHECK( + scratch_pad_prev_curr_map.numel() == inserted_indices_curr.numel()); auto& map = index_loc_map_queue.front(); - auto loc_acc = scratch_pad_locations.accessor(); - auto sp_acc = scratch_pad_indices.accessor(); - auto ssd_acc = ssd_indices.accessor(); + auto sp_prev_curr_map_acc = + scratch_pad_prev_curr_map.accessor(); + auto sp_indices_acc = scratch_pad_indices_prev.accessor(); + auto inserted_indices_acc = inserted_indices_curr.accessor(); + auto sp_curr_prev_map_acc = scratch_pad_curr_prev_map.accessor(); // Concurrent lookup is OK since it is read-only at::parallel_for(0, count_, 1, [&](int64_t start, int64_t end) { for (const auto i : c10::irange(start, end)) { - const auto val = ssd_acc[i]; + const auto val = inserted_indices_acc[i]; const auto val_loc = map->find(val); // If index is found in the map if (val_loc != map->end()) { const auto loc = val_loc->second; - // Store the location in the scratch pad - loc_acc[i] = loc; - // Set the scratch pad index as the sentinel value to - // prevent it from being evicted - sp_acc[loc] = sentinel_val_; + // Store the previous scratch pad location + sp_prev_curr_map_acc[i] = loc; + // Store the current scratch pad location + sp_curr_prev_map_acc[loc] = i; // Set the SSD index as the sentinel value to prevent it // from being looked up in SSD - ssd_acc[i] = sentinel_val_; + inserted_indices_acc[i] = sentinel_val_; + // Set the scratch pad index as the sentinel value to + // prevent it from being evicted + sp_indices_acc[loc] = sentinel_val_; } else { // Set the location to the sentinel value to indicate that // the index is not found in the scratch pad - loc_acc[i] = sentinel_val_; + sp_prev_curr_map_acc[i] = sentinel_val_; } } }); @@ -208,12 +225,17 @@ class SSDScratchPadIndicesQueue : public torch::jit::CustomClassHolder { } void lookup_mask_and_pop_front_cuda( - const Tensor& scratch_pad_locations, - const Tensor& scratch_pad_indices, - const Tensor& ssd_indices, - const Tensor& count) { + const Tensor& scratch_pad_prev_curr_map, + const Tensor& scratch_pad_curr_prev_map, + const Tensor& scratch_pad_indices_prev, + const Tensor& inserted_indices_curr, + const Tensor& count_curr) { impl_->lookup_mask_and_pop_front_cuda( - scratch_pad_locations, scratch_pad_indices, ssd_indices, count); + scratch_pad_prev_curr_map, + scratch_pad_curr_prev_map, + scratch_pad_indices_prev, + inserted_indices_curr, + count_curr); } int64_t size() { diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 7f37896ae8..441f9fb034 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -352,6 +352,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { (2 * (count_ + dbs_.size() - 1) / dbs_.size()) * (sizeof(int64_t) + sizeof(scalar_t) * D)); for (auto i = 0; i < count_; ++i) { + // TODO: Check whether this is OK + if (indices_acc[i] == -1) { + continue; + } if (db_shard(indices_acc[i], dbs_.size()) != shard) { continue; } diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index 025063d837..b0c3fc060c 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -8,7 +8,8 @@ # pyre-ignore-all-errors[3,6,56] import unittest -from typing import List, Optional, Tuple + +from typing import Any, Dict, List, Optional, Tuple import hypothesis.strategies as st import numpy as np @@ -22,15 +23,45 @@ from .. import common # noqa E402 from ..common import open_source - if open_source: # pyre-ignore[21] from test_utils import gpu_unavailable, running_on_github else: from fbgemm_gpu.test.test_utils import gpu_unavailable, running_on_github +from enum import Enum + MAX_EXAMPLES = 40 +MAX_PIPELINE_EXAMPLES = 10 + +default_st: Dict["str", Any] = { + "T": st.integers(min_value=1, max_value=10), + "D": st.integers(min_value=2, max_value=128), + "B": st.integers(min_value=1, max_value=128), + "log_E": st.integers(min_value=3, max_value=5), + "L": st.integers(min_value=0, max_value=20), + "weighted": st.booleans(), + "cache_set_scale": st.sampled_from([0.0, 0.005, 1]), + "pooling_mode": st.sampled_from( + [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] + ), + "weights_precision": st.sampled_from([SparseType.FP32, SparseType.FP16]), + "output_dtype": st.sampled_from([SparseType.FP32, SparseType.FP16]), + "share_table": st.booleans(), +} + + +class PrefetchLocation(Enum): + BEFORE_FWD = 1 + BETWEEN_FWD_BWD = 2 + + +class FlushLocation(Enum): + AFTER_FWD = 1 + AFTER_BWD = 2 + BEFORE_TRAINING = 3 + ALL = 4 @unittest.skipIf(*running_on_github) @@ -152,6 +183,8 @@ def generate_ssd_tbes( output_dtype: SparseType = SparseType.FP32, stochastic_rounding: bool = True, share_table: bool = False, + prefetch_pipeline: bool = False, + use_prefetch_stream: bool = False, ) -> Tuple[SSDTableBatchedEmbeddingBags, List[torch.nn.EmbeddingBag]]: """ Generate embedding modules (i,e., SSDTableBatchedEmbeddingBags and @@ -225,6 +258,8 @@ def generate_ssd_tbes( weights_precision=weights_precision, output_dtype=output_dtype, stochastic_rounding=stochastic_rounding, + prefetch_pipeline=prefetch_pipeline, + prefetch_stream=torch.cuda.Stream() if use_prefetch_stream else None, ).cuda() # A list to keep the CPU tensor alive until `set` (called inside @@ -350,21 +385,7 @@ def execute_ssd_forward_( ) return output_ref_list, output - @given( - T=st.integers(min_value=1, max_value=10), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - weighted=st.booleans(), - cache_set_scale=st.sampled_from([0.0, 0.005, 1]), - pooling_mode=st.sampled_from( - [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] - ), - weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), - share_table=st.booleans(), - ) + @given(**default_st) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_ssd_forward( self, @@ -426,21 +447,7 @@ def test_ssd_forward( weighted, ) - @given( - T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - weighted=st.booleans(), - cache_set_scale=st.sampled_from([0.0, 0.005, 1]), - pooling_mode=st.sampled_from( - [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] - ), - weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), - share_table=st.booleans(), - ) + @given(**default_st) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_ssd_backward_adagrad( self, @@ -576,23 +583,7 @@ def test_ssd_backward_adagrad( rtol=tolerance, ) - @given( - T=st.integers(min_value=1, max_value=10), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - weighted=st.booleans(), - cache_set_scale=st.sampled_from([0.0, 0.005, 1]), - pooling_mode=st.sampled_from( - [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] - ), - weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), - share_table=st.booleans(), - ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) - def test_ssd_cache( + def execute_ssd_cache_pipeline_( # noqa C901 self, T: int, D: int, @@ -605,8 +596,15 @@ def test_ssd_cache( weights_precision: SparseType, output_dtype: SparseType, share_table: bool, + prefetch_pipeline: bool, + # If True, prefetch will be invoked by the user. + explicit_prefetch: bool, + prefetch_location: Optional[PrefetchLocation], + use_prefetch_stream: bool, + flush_location: Optional[FlushLocation], ) -> None: - assume(not weighted or pooling_mode == PoolingMode.SUM) + # If using pipeline prefetching, explicit prefetching must be True + assert not prefetch_pipeline or explicit_prefetch lr = 0.5 eps = 0.2 @@ -636,6 +634,8 @@ def test_ssd_cache( # functionality of the cache stochastic_rounding=False, share_table=share_table, + prefetch_pipeline=prefetch_pipeline, + use_prefetch_stream=use_prefetch_stream, ) optimizer_states_ref = [ @@ -650,48 +650,72 @@ def test_ssd_cache( else 1.0e-2 ) + batches = [] for it in range(10): - ( - indices_list, - per_sample_weights_list, - indices, - offsets, - per_sample_weights, - ) = self.generate_inputs_( - B, - L, - Es, - emb.feature_table_map, - weights_precision=weights_precision, + batches.append( + self.generate_inputs_( + B, + L, + Es, + emb.feature_table_map, + weights_precision=weights_precision, + ) ) - assert emb.timestep == it - emb.prefetch(indices, offsets) + prefetch_stream = emb.prefetch_stream if use_prefetch_stream else None + forward_stream = torch.cuda.current_stream() if use_prefetch_stream else None - linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( - emb.hash_size_cumsum, - indices, - offsets, - ) + iters = 10 + + force_flush = flush_location == FlushLocation.ALL + + if force_flush or flush_location == FlushLocation.BEFORE_TRAINING: + emb.flush() + + # pyre-ignore[53] + def _prefetch(b_it: int) -> int: + if not explicit_prefetch or b_it >= iters: + return b_it + 1 - # Verify that prefetching twice avoids any actions. ( _, _, + indices, + offsets, _, - actions_count_gpu, - _, - _, - _, - _, - ) = torch.ops.fbgemm.ssd_cache_populate_actions( # noqa - linear_cache_indices, - emb.total_hash_size, - emb.lxu_cache_state, - emb.timestep, - 0, # prefetch_dist - emb.lru_state, - ) + ) = batches[b_it] + # print("prefetch {} indices {}".format(b_it, indices.unique())) + with torch.cuda.stream(prefetch_stream): + emb.prefetch(indices, offsets, forward_stream=forward_stream) + return b_it + 1 + + if prefetch_pipeline: + # Prefetch the first iteration + _prefetch(0) + b_it = 1 + else: + b_it = 0 + + for it in range(iters): + ( + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + ) = batches[it] + + # Ensure that prefetch i is done before forward i + if prefetch_stream: + assert forward_stream is not None + forward_stream.wait_stream(prefetch_stream) + + # Prefetch before forward + if ( + not prefetch_pipeline + or prefetch_location == PrefetchLocation.BEFORE_FWD + ): + b_it = _prefetch(b_it) # Execute forward output_ref_list, output = self.execute_ssd_forward_( @@ -709,6 +733,9 @@ def test_ssd_cache( it=it, ) + if force_flush or flush_location == FlushLocation.AFTER_FWD: + emb.flush() + # Generate output gradient output_grad_list = [torch.randn_like(out) for out in output_ref_list] @@ -728,9 +755,19 @@ def test_ssd_cache( D * 4, ) + # Prefetch between forward and backward + if ( + prefetch_pipeline + and prefetch_location == PrefetchLocation.BETWEEN_FWD_BWD + ): + b_it = _prefetch(b_it) + # Execute TBE SSD backward output.backward(grad_test) + if force_flush or flush_location == FlushLocation.AFTER_BWD: + emb.flush() + # Compare optimizer states split_optimizer_states = [s for (s,) in emb.debug_split_optimizer_states()] for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map): @@ -779,3 +816,136 @@ def test_ssd_cache( atol=tolerance, rtol=tolerance, ) + + @given( + flush_location=st.sampled_from(FlushLocation), + prefetch_pipeline=st.booleans(), + explicit_prefetch=st.booleans(), + prefetch_location=st.sampled_from(PrefetchLocation), + use_prefetch_stream=st.booleans(), + **default_st, + ) + @settings( + verbosity=Verbosity.verbose, max_examples=MAX_PIPELINE_EXAMPLES, deadline=None + ) + def test_ssd_cache_flush(self, **kwargs: Any): + """ + Test the correctness of the SSD cache prefetch workflow with + excessive flushing + """ + assume(not kwargs["weighted"] or kwargs["pooling_mode"] == PoolingMode.SUM) + assume(kwargs["prefetch_pipeline"] and kwargs["explicit_prefetch"]) + assume(not kwargs["use_prefetch_stream"] or kwargs["prefetch_pipeline"]) + self.execute_ssd_cache_pipeline_( + **kwargs, + ) + + @given(**default_st) + @settings( + verbosity=Verbosity.verbose, max_examples=MAX_PIPELINE_EXAMPLES, deadline=None + ) + def test_ssd_cache_implicit_prefetch(self, **kwargs: Any): + """ + Test the correctness of the SSD cache prefetch workflow + without pipeline prefetching and with implicit prefetching. + Implicit prefetching relies on TBE forward to invoke prefetch. + """ + assume(not kwargs["weighted"] or kwargs["pooling_mode"] == PoolingMode.SUM) + self.execute_ssd_cache_pipeline_( + prefetch_pipeline=False, + explicit_prefetch=False, + prefetch_location=None, + use_prefetch_stream=False, + flush_location=None, + **kwargs, + ) + + @given(**default_st) + @settings( + verbosity=Verbosity.verbose, max_examples=MAX_PIPELINE_EXAMPLES, deadline=None + ) + def test_ssd_cache_explicit_prefetch(self, **kwargs: Any): + """ + Test the correctness of the SSD cache prefetch workflow + without pipeline prefetching and with explicit prefetching + (the user explicitly invokes prefetch). Each prefetch invoked + before a forward TBE fetches data for that specific iteration. + + For example: + + ------------------------- Timeline ------------------------> + pf(i) -> fwd(i) -> ... -> pf(i+1) -> fwd(i+1) -> ... + + Note: + - pf(i) = prefetch of iteration i + - fwd(i) = forward TBE of iteration i + """ + assume(not kwargs["weighted"] or kwargs["pooling_mode"] == PoolingMode.SUM) + self.execute_ssd_cache_pipeline_( + prefetch_pipeline=False, + explicit_prefetch=True, + prefetch_location=None, + use_prefetch_stream=False, + flush_location=None, + **kwargs, + ) + + @given(use_prefetch_stream=st.booleans(), **default_st) + @settings( + verbosity=Verbosity.verbose, max_examples=MAX_PIPELINE_EXAMPLES, deadline=None + ) + def test_ssd_cache_pipeline_before_fwd(self, **kwargs: Any): + """ + Test the correctness of the SSD cache prefetch workflow with + pipeline prefetching when cache prefetching of the next + iteration is invoked before the forward TBE of the current + iteration. + + For example: + + ------------------------- Timeline ------------------------> + pf(i+1) -> fwd(i) -> ... -> pf(i+2) -> fwd(i+1) -> ... + + Note: + - pf(i) = prefetch of iteration i + - fwd(i) = forward TBE of iteration i + """ + assume(not kwargs["weighted"] or kwargs["pooling_mode"] == PoolingMode.SUM) + self.execute_ssd_cache_pipeline_( + prefetch_pipeline=True, + explicit_prefetch=True, + prefetch_location=PrefetchLocation.BEFORE_FWD, + flush_location=None, + **kwargs, + ) + + @given(use_prefetch_stream=st.booleans(), **default_st) + @settings( + verbosity=Verbosity.verbose, max_examples=MAX_PIPELINE_EXAMPLES, deadline=None + ) + def test_ssd_cache_pipeline_between_fwd_bwd(self, **kwargs: Any): + """ + Test the correctness of the SSD cache prefetch workflow with + pipeline prefetching when cache prefetching of the next + iteration is invoked after the forward TBE and before the + backward TBE of the current iteration. + + For example: + + ------------------------- Timeline ------------------------> + fwd(i) -> pf(i+1) -> bwd(i) -> ... -> fwd(i+1) -> pf(i+2) -> bwd(i+1) -> ... + + Note: + - pf(i) = prefetch of iteration i + - fwd(i) = forward TBE of iteration i + - bwd(i) = backward TBE of iteration i + """ + + assume(not kwargs["weighted"] or kwargs["pooling_mode"] == PoolingMode.SUM) + self.execute_ssd_cache_pipeline_( + prefetch_pipeline=True, + explicit_prefetch=True, + prefetch_location=PrefetchLocation.BETWEEN_FWD_BWD, + flush_location=None, + **kwargs, + ) diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py index 72948ba528..ef434dd8e1 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py @@ -259,7 +259,8 @@ def test_scratch_pad_indices_queue( # Run reference # Prepare inputs for the reference run - sp_locations_ref = torch.zeros_like(lookup_indices) + sp_prev_curr_map_ref = torch.zeros_like(lookup_indices) + sp_curr_prev_map_ref = torch.empty_like(indices, dtype=torch.int).fill_(-1) sp_indices = indices.clone().tolist() ssd_indices = lookup_indices.clone().tolist() @@ -273,30 +274,33 @@ def test_scratch_pad_indices_queue( ssd_idx = ssd_indices[i] if ssd_idx in sp_map: loc = sp_map[ssd_idx] - sp_locations_ref[i] = loc + sp_prev_curr_map_ref[i] = loc + sp_curr_prev_map_ref[loc] = i sp_indices[loc] = sentinel_value ssd_indices[i] = sentinel_value else: - sp_locations_ref[i] = sentinel_value + sp_prev_curr_map_ref[i] = sentinel_value all_lookup_outputs_ref.append( ( - sp_locations_ref, + sp_prev_curr_map_ref, torch.as_tensor(sp_indices), torch.as_tensor(ssd_indices), ) ) # Run test - sp_locations = torch.zeros_like(lookup_indices) + sp_prev_curr_map = torch.zeros_like(lookup_indices) + sp_curr_prev_map = torch.empty_like(indices, dtype=torch.int).fill_(-1) sp_idx_queue.lookup_mask_and_pop_front_cuda( - sp_locations, + sp_prev_curr_map, + sp_curr_prev_map, indices, lookup_indices, lookup_count, ) - all_lookup_outputs.append((sp_locations, indices, lookup_indices)) + all_lookup_outputs.append((sp_prev_curr_map, indices, lookup_indices)) # Ensure that the lookups are done torch.cuda.synchronize() @@ -304,7 +308,12 @@ def test_scratch_pad_indices_queue( # Compare results for test, ref in zip(all_lookup_outputs, all_lookup_outputs_ref): for name, test_, ref_ in zip( - ["scratch_pad_locations", "scratch_pad_indices", "ssd_indices"], + [ + "scratch_pad_prev_curr_map", + "scratch_pad_curr_prev_map", + "scratch_pad_indices", + "ssd_indices", + ], test, ref, ):