@@ -338,17 +338,21 @@ def __init__(
338
338
self .ssd_event_evict = torch .cuda .Event ()
339
339
# SSD backward completion event
340
340
self .ssd_event_backward = torch .cuda .Event ()
341
- # SSD scratch pad eviction completion event
342
- self .ssd_event_evict_sp = torch .cuda .Event ()
343
341
# SSD get inputs copy completion event
344
342
self .ssd_event_get_inputs_cpy = torch .cuda .Event ()
343
+ # SSD scratch pad index queue insert completion event
344
+ self .ssd_event_sp_idxq_insert = torch .cuda .Event ()
345
345
346
346
self .timesteps_prefetched : List [int ] = []
347
- self .ssd_scratch_pads : List [Tuple [Tensor , Tensor , Tensor , bool ]] = []
348
347
# TODO: add type annotation
349
348
# pyre-fixme[4]: Attribute must be annotated.
350
349
self .ssd_prefetch_data = []
351
350
351
+ # Scratch pad value queue
352
+ self .ssd_scratch_pads : List [Tuple [Tensor , Tensor , Tensor , bool ]] = []
353
+ # Scratch pad index queue
354
+ self .scratch_pad_idx_queue = torch .classes .fbgemm .SSDScratchPadIndicesQueue (- 1 )
355
+
352
356
if weight_decay_mode == WeightDecayMode .COUNTER or counter_based_regularization :
353
357
raise AssertionError (
354
358
"weight_decay_mode = WeightDecayMode.COUNTER is not supported for SSD TBE."
@@ -424,10 +428,6 @@ def __init__(
424
428
torch .zeros (0 , device = self .current_device , dtype = torch .float )
425
429
)
426
430
427
- # Register backward hook for evicting rows from a scratch pad to SSD
428
- # post backward
429
- self .placeholder_autograd_tensor .register_hook (self ._evict_from_scratch_pad )
430
-
431
431
assert optimizer in (
432
432
OptimType .EXACT_ROWWISE_ADAGRAD ,
433
433
), f"Optimizer { optimizer } is not supported by SSDTableBatchedEmbeddingBags"
@@ -624,8 +624,14 @@ def evict(
624
624
# actions_count_cpu.record_stream(self.ssd_eviction_stream)
625
625
stream .record_event (post_event )
626
626
627
- def _evict_from_scratch_pad (self , grad : Tensor ) -> None :
628
- assert len (self .ssd_scratch_pads ) > 0 , "There must be at least one scratch pad"
627
+ def _evict_from_scratch_pad (self , return_on_empty : bool ) -> None :
628
+ scratch_pad_len = len (self .ssd_scratch_pads )
629
+
630
+ if not return_on_empty :
631
+ assert scratch_pad_len > 0 , "There must be at least one scratch pad"
632
+ elif scratch_pad_len == 0 :
633
+ return
634
+
629
635
(inserted_rows , post_bwd_evicted_indices_cpu , actions_count_cpu , do_evict ) = (
630
636
self .ssd_scratch_pads .pop (0 )
631
637
)
@@ -637,7 +643,7 @@ def _evict_from_scratch_pad(self, grad: Tensor) -> None:
637
643
actions_count_cpu = actions_count_cpu ,
638
644
stream = self .ssd_eviction_stream ,
639
645
pre_event = self .ssd_event_backward ,
640
- post_event = self . ssd_event_evict_sp ,
646
+ post_event = None ,
641
647
is_rows_uvm = True ,
642
648
name = "scratch_pad" ,
643
649
)
@@ -680,6 +686,20 @@ def _compute_cache_ptrs(
680
686
)
681
687
)
682
688
689
+ # Insert conflict miss indices in the index queue for future lookup
690
+ # post_bwd_evicted_indices_cpu is transferred on the ssd_eviction_stream stream
691
+ # actions_count_cpu is transferred on the ssd_memcpy_stream stream
692
+ with torch .cuda .stream (self .ssd_eviction_stream ):
693
+ # Ensure that actions_count_cpu transfer is done
694
+ self .ssd_eviction_stream .wait_event (self .ssd_event_get_inputs_cpy )
695
+ self .record_function_via_dummy_profile (
696
+ "## ssd_scratch_pad_idx_queue_insert ##" ,
697
+ self .scratch_pad_idx_queue .insert_cuda ,
698
+ post_bwd_evicted_indices_cpu ,
699
+ actions_count_cpu ,
700
+ )
701
+ self .ssd_eviction_stream .record_event (self .ssd_event_sp_idxq_insert )
702
+
683
703
with record_function ("## ssd_scratch_pads ##" ):
684
704
# Store scratch pad info for post backward eviction
685
705
self .ssd_scratch_pads .append (
@@ -775,12 +795,76 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
775
795
)
776
796
777
797
current_stream = torch .cuda .current_stream ()
778
-
779
- inserted_indices_cpu = self .to_pinned_cpu (inserted_indices )
798
+ if len (self .ssd_scratch_pads ) > 0 :
799
+ with record_function ("## ssd_lookup_scratch_pad ##" ):
800
+ current_stream .wait_event (self .ssd_event_sp_idxq_insert )
801
+ current_stream .wait_event (self .ssd_event_get_inputs_cpy )
802
+
803
+ (
804
+ inserted_rows_prev ,
805
+ post_bwd_evicted_indices_cpu_prev ,
806
+ actions_count_cpu_prev ,
807
+ do_evict_prev ,
808
+ ) = self .ssd_scratch_pads .pop (0 )
809
+
810
+ # Inserted indices that are found in the scratch pad
811
+ # from the previous iteration
812
+ sp_locations_cpu = torch .empty (
813
+ inserted_indices_cpu .shape ,
814
+ dtype = inserted_indices_cpu .dtype ,
815
+ pin_memory = True ,
816
+ )
817
+
818
+ # Before entering this function: inserted_indices_cpu
819
+ # contains all linear indices that are missed from the
820
+ # L1 cache
821
+ #
822
+ # After this function: inserted indices that are found
823
+ # in the scratch pad from the previous iteration are
824
+ # stored in sp_locations_cpu, while the rest are
825
+ # stored in inserted_indices_cpu
826
+ #
827
+ # An index is invalid if it is -1 or if its position >
828
+ # actions_count_cpu
829
+ self .record_function_via_dummy_profile (
830
+ "## ssd_lookup_mask_and_pop_front ##" ,
831
+ self .scratch_pad_idx_queue .lookup_mask_and_pop_front_cuda ,
832
+ sp_locations_cpu ,
833
+ post_bwd_evicted_indices_cpu_prev ,
834
+ inserted_indices_cpu ,
835
+ actions_count_cpu ,
836
+ )
837
+
838
+ # Transfer sp_locations_cpu to GPU
839
+ sp_locations_gpu = sp_locations_cpu .cuda (non_blocking = True )
840
+
841
+ # Copy data from the previous iteration's scratch pad to
842
+ # the current iteration's scratch pad
843
+ torch .ops .fbgemm .masked_index_select (
844
+ inserted_rows ,
845
+ sp_locations_gpu ,
846
+ inserted_rows_prev ,
847
+ actions_count_gpu ,
848
+ )
849
+
850
+ # Evict from scratch pad
851
+ if do_evict_prev :
852
+ torch .cuda .current_stream ().record_event (
853
+ self .ssd_event_backward
854
+ )
855
+ self .evict (
856
+ rows = inserted_rows_prev ,
857
+ indices_cpu = post_bwd_evicted_indices_cpu_prev ,
858
+ actions_count_cpu = actions_count_cpu_prev ,
859
+ stream = self .ssd_eviction_stream ,
860
+ pre_event = self .ssd_event_backward ,
861
+ post_event = None ,
862
+ is_rows_uvm = True ,
863
+ name = "scratch_pad" ,
864
+ )
780
865
781
866
# Ensure the previous iterations l3_db.set(..) has completed.
782
867
current_stream .wait_event (self .ssd_event_evict )
783
- current_stream .wait_event (self .ssd_event_evict_sp )
784
868
current_stream .wait_event (self .ssd_event_get_inputs_cpy )
785
869
786
870
if linear_cache_indices .numel () > 0 :
@@ -1027,6 +1111,9 @@ def flush(self) -> None:
1027
1111
active_slots_mask_cpu .view (- 1 )
1028
1112
)
1029
1113
1114
+ # Evict data from the scratch pad if there is a scratch pad in the queue
1115
+ self ._evict_from_scratch_pad (return_on_empty = True )
1116
+
1030
1117
torch .cuda .current_stream ().wait_stream (self .ssd_eviction_stream )
1031
1118
1032
1119
self .ssd_db .set_cuda (
0 commit comments