
Commit c68f69d

sarunya authored and facebook-github-bot committed
Add SSDScratchPadIndicesQueue lookup in frontend
Differential Revision: D60413116
1 parent 28bdcf3 commit c68f69d

File tree: 1 file changed (+100, -13 lines)


fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 100 additions & 13 deletions
@@ -338,17 +338,21 @@ def __init__(
         self.ssd_event_evict = torch.cuda.Event()
         # SSD backward completion event
         self.ssd_event_backward = torch.cuda.Event()
-        # SSD scratch pad eviction completion event
-        self.ssd_event_evict_sp = torch.cuda.Event()
         # SSD get's input copy completion event
         self.ssd_event_get_inputs_cpy = torch.cuda.Event()
+        # SSD scratch pad index queue insert completion event
+        self.ssd_event_sp_idxq_insert = torch.cuda.Event()

         self.timesteps_prefetched: List[int] = []
-        self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
         # TODO: add type annotation
         # pyre-fixme[4]: Attribute must be annotated.
         self.ssd_prefetch_data = []

+        # Scratch pad value queue
+        self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
+        # Scratch pad index queue
+        self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue(-1)
+
         if weight_decay_mode == WeightDecayMode.COUNTER or counter_based_regularization:
             raise AssertionError(
                 "weight_decay_mode = WeightDecayMode.COUNTER is not supported for SSD TBE."
@@ -424,10 +428,6 @@ def __init__(
             torch.zeros(0, device=self.current_device, dtype=torch.float)
         )

-        # Register backward hook for evicting rows from a scratch pad to SSD
-        # post backward
-        self.placeholder_autograd_tensor.register_hook(self._evict_from_scratch_pad)
-
         assert optimizer in (
             OptimType.EXACT_ROWWISE_ADAGRAD,
         ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
@@ -624,8 +624,14 @@ def evict(
         # actions_count_cpu.record_stream(self.ssd_eviction_stream)
         stream.record_event(post_event)

-    def _evict_from_scratch_pad(self, grad: Tensor) -> None:
-        assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad"
+    def _evict_from_scratch_pad(self, return_on_empty: bool) -> None:
+        scratch_pad_len = len(self.ssd_scratch_pads)
+
+        if not return_on_empty:
+            assert scratch_pad_len > 0, "There must be at least one scratch pad"
+        elif scratch_pad_len == 0:
+            return
+
         (inserted_rows, post_bwd_evicted_indices_cpu, actions_count_cpu, do_evict) = (
             self.ssd_scratch_pads.pop(0)
         )
@@ -637,7 +643,7 @@ def _evict_from_scratch_pad(self, grad: Tensor) -> None:
             actions_count_cpu=actions_count_cpu,
             stream=self.ssd_eviction_stream,
             pre_event=self.ssd_event_backward,
-            post_event=self.ssd_event_evict_sp,
+            post_event=None,
             is_rows_uvm=True,
             name="scratch_pad",
         )
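Note: post_event becomes None because nothing waits on ssd_event_evict_sp anymore; prefetch now synchronizes on ssd_event_sp_idxq_insert instead. The sketch below shows how an optional post_event plausibly behaves, assuming evict() skips recording when it is None, as this diff implies (evict_sketch is illustrative, not FBGEMM code):

    from typing import Optional

    import torch

    def evict_sketch(
        stream: torch.cuda.Stream,
        pre_event: torch.cuda.Event,
        post_event: Optional[torch.cuda.Event],
    ) -> None:
        # Order the eviction work after its producer ...
        stream.wait_event(pre_event)
        # ... enqueue the actual eviction kernels / copies on `stream` ...
        # ... and publish completion only if a consumer will wait on it.
        if post_event is not None:
            stream.record_event(post_event)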
@@ -680,6 +686,20 @@ def _compute_cache_ptrs(
             )
         )

+        # Insert conflict miss indices in the index queue for future lookup
+        # post_bwd_evicted_indices_cpu is transferred on the ssd_eviction_stream stream
+        # actions_count_cpu is transferred on the ssd_memcpy_stream stream
+        with torch.cuda.stream(self.ssd_eviction_stream):
+            # Ensure that actions_count_cpu transfer is done
+            self.ssd_eviction_stream.wait_event(self.ssd_event_get_inputs_cpy)
+            self.record_function_via_dummy_profile(
+                "## ssd_scratch_pad_idx_queue_insert ##",
+                self.scratch_pad_idx_queue.insert_cuda,
+                post_bwd_evicted_indices_cpu,
+                actions_count_cpu,
+            )
+            self.ssd_eviction_stream.record_event(self.ssd_event_sp_idxq_insert)
+
         with record_function("## ssd_scratch_pads ##"):
             # Store scratch pad info for post backward eviction
             self.ssd_scratch_pads.append(
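Note: the added block is a produce-on-a-side-stream pattern: wait for the input copy event, run the index queue insert on the eviction stream, then record ssd_event_sp_idxq_insert for the next iteration's prefetch to wait on. A minimal self-contained sketch of the same pattern (all names here are illustrative; assumes a CUDA device):

    import torch

    if torch.cuda.is_available():
        side_stream = torch.cuda.Stream()
        inputs_ready = torch.cuda.Event()
        work_done = torch.cuda.Event()

        x = torch.randn(1024, device="cuda")
        torch.cuda.current_stream().record_event(inputs_ready)

        with torch.cuda.stream(side_stream):
            # Do not start until the producer's work is complete
            side_stream.wait_event(inputs_ready)
            y = x * 2  # stand-in for scratch_pad_idx_queue.insert_cuda(...)
            # Publish completion for future consumers (cf. prefetch below)
            side_stream.record_event(work_done)

        torch.cuda.current_stream().wait_event(work_done)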
@@ -775,12 +795,76 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
             )

         current_stream = torch.cuda.current_stream()
-
-        inserted_indices_cpu = self.to_pinned_cpu(inserted_indices)
+        if len(self.ssd_scratch_pads) > 0:
+            with record_function("## ssd_lookup_scratch_pad ##"):
+                current_stream.wait_event(self.ssd_event_sp_idxq_insert)
+                current_stream.wait_event(self.ssd_event_get_inputs_cpy)
+
+                (
+                    inserted_rows_prev,
+                    post_bwd_evicted_indices_cpu_prev,
+                    actions_count_cpu_prev,
+                    do_evict_prev,
+                ) = self.ssd_scratch_pads.pop(0)
+
+                # Inserted indices that are found in the scratch pad
+                # from the previous iteration
+                sp_locations_cpu = torch.empty(
+                    inserted_indices_cpu.shape,
+                    dtype=inserted_indices_cpu.dtype,
+                    pin_memory=True,
+                )
+
+                # Before entering this function: inserted_indices_cpu
+                # contains all linear indices that are missed from the
+                # L1 cache
+                #
+                # After this function: inserted indices that are found
+                # in the scratch pad from the previous iteration are
+                # stored in sp_locations_cpu, while the rest are
+                # stored in inserted_indices_cpu
+                #
+                # An invalid index is -1 or its position >
+                # actions_count_cpu
+                self.record_function_via_dummy_profile(
+                    "## ssd_lookup_mask_and_pop_front ##",
+                    self.scratch_pad_idx_queue.lookup_mask_and_pop_front_cuda,
+                    sp_locations_cpu,
+                    post_bwd_evicted_indices_cpu_prev,
+                    inserted_indices_cpu,
+                    actions_count_cpu,
+                )
+
+                # Transfer sp_locations_cpu to GPU
+                sp_locations_gpu = sp_locations_cpu.cuda(non_blocking=True)
+
+                # Copy data from the previous iteration's scratch pad to
+                # the current iteration's scratch pad
+                torch.ops.fbgemm.masked_index_select(
+                    inserted_rows,
+                    sp_locations_gpu,
+                    inserted_rows_prev,
+                    actions_count_gpu,
+                )
+
+                # Evict from scratch pad
+                if do_evict_prev:
+                    torch.cuda.current_stream().record_event(
+                        self.ssd_event_backward
+                    )
+                    self.evict(
+                        rows=inserted_rows_prev,
+                        indices_cpu=post_bwd_evicted_indices_cpu_prev,
+                        actions_count_cpu=actions_count_cpu_prev,
+                        stream=self.ssd_eviction_stream,
+                        pre_event=self.ssd_event_backward,
+                        post_event=None,
+                        is_rows_uvm=True,
+                        name="scratch_pad",
+                    )

         # Ensure the previous iterations l3_db.set(..) has completed.
         current_stream.wait_event(self.ssd_event_evict)
-        current_stream.wait_event(self.ssd_event_evict_sp)
         current_stream.wait_event(self.ssd_event_get_inputs_cpy)

         if linear_cache_indices.numel() > 0:
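Note: the comment block above is the full contract of the lookup. The real lookup_mask_and_pop_front_cuda is implemented in C++ behind torch.classes.fbgemm; the pure-Python model below only restates the described semantics (the dict-based queue entry and every name are illustrative):

    import torch

    def lookup_mask_and_pop_front_model(
        front_entry: dict,                  # linear index -> scratch pad row
        inserted_indices_cpu: torch.Tensor,
        actions_count: int,
    ) -> torch.Tensor:
        # Indices found in the previous iteration's scratch pad are
        # written to sp_locations; found entries are masked out of
        # inserted_indices_cpu so they are not fetched from SSD again.
        # -1 (or a position >= actions_count) marks an invalid index.
        sp_locations = torch.full_like(inserted_indices_cpu, -1)
        for pos in range(min(actions_count, inserted_indices_cpu.numel())):
            row = front_entry.get(int(inserted_indices_cpu[pos]))
            if row is not None:
                sp_locations[pos] = row
                inserted_indices_cpu[pos] = -1  # serviced from the scratch pad
        return sp_locations

On a found index the model reports the scratch pad row and masks the miss, which matches how sp_locations_gpu then drives masked_index_select above.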
@@ -1027,6 +1111,9 @@ def flush(self) -> None:
             active_slots_mask_cpu.view(-1)
         )

+        # Evict data from scratch pad if there is a scratch pad in the queue
+        self._evict_from_scratch_pad(return_on_empty=True)
+
         torch.cuda.current_stream().wait_stream(self.ssd_eviction_stream)

         self.ssd_db.set_cuda(
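Note: flush() drains any queued scratch pad first (return_on_empty=True makes it a no-op when the queue is empty), then wait_stream joins the eviction stream before ssd_db.set_cuda. For reference, wait_stream orders the current stream after all work already submitted to the other stream; a minimal sketch (assumes a CUDA device):

    import torch

    if torch.cuda.is_available():
        eviction_like = torch.cuda.Stream()

        with torch.cuda.stream(eviction_like):
            pending = torch.ones(4, device="cuda") * 2  # stand-in for eviction work

        # The current stream will not run past this point until all work
        # queued on eviction_like so far has finished
        torch.cuda.current_stream().wait_stream(eviction_like)
        result = pending + 1  # now safe to consume on the current stream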
