Add ssd_update_row_addrs #2953

Closed · wants to merge 3 commits

127 changes: 107 additions & 20 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -341,17 +341,22 @@ def __init__(
         self.ssd_event_evict = torch.cuda.Event()
         # SSD backward completion event
         self.ssd_event_backward = torch.cuda.Event()
-        # SSD scratch pad eviction completion event
-        self.ssd_event_evict_sp = torch.cuda.Event()
+        # SSD get's input copy completion event
+        self.ssd_event_get_inputs_cpy = torch.cuda.Event()
+        # SSD scratch pad index queue insert completion event
+        self.ssd_event_sp_idxq_insert = torch.cuda.Event()

         self.timesteps_prefetched: List[int] = []
-        self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
         # TODO: add type annotation
         # pyre-fixme[4]: Attribute must be annotated.
         self.ssd_prefetch_data = []

+        # Scratch pad value queue
+        self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
+
+        # Scratch pad index queue
+        # pyre-ignore[4]
+        self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue(-1)
+
         if weight_decay_mode == WeightDecayMode.COUNTER or counter_based_regularization:
             raise AssertionError(
                 "weight_decay_mode = WeightDecayMode.COUNTER is not supported for SSD TBE."
@@ -427,10 +432,6 @@ def __init__(
             torch.zeros(0, device=self.current_device, dtype=torch.float)
         )

-        # Register backward hook for evicting rows from a scratch pad to SSD
-        # post backward
-        self.placeholder_autograd_tensor.register_hook(self._evict_from_scratch_pad)
-
         assert optimizer in (
             OptimType.EXACT_ROWWISE_ADAGRAD,
         ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
@@ -578,8 +579,8 @@ def evict(
         indices_cpu: Tensor,
         actions_count_cpu: Tensor,
         stream: torch.cuda.Stream,
-        pre_event: torch.cuda.Event,
-        post_event: torch.cuda.Event,
+        pre_event: Optional[torch.cuda.Event],
+        post_event: Optional[torch.cuda.Event],
         is_rows_uvm: bool,
         name: Optional[str] = "",
     ) -> None:
@@ -607,7 +608,8 @@ def evict(
"""
with record_function(f"## ssd_evict_{name} ##"):
with torch.cuda.stream(stream):
stream.wait_event(pre_event)
if pre_event is not None:
stream.wait_event(pre_event)

rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)
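to_pinned_cpu itself is not part of this diff. A plausible sketch of what such a helper does — allocate page-locked host memory so the device-to-host copy can be asynchronous — with the signature inferred from the call site and the body a guess:

import torch

def to_pinned_cpu_sketch(t: torch.Tensor) -> torch.Tensor:
    # Pinned (page-locked) host memory lets copy_ with non_blocking=True
    # overlap with other GPU work instead of stalling the stream
    out = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
    out.copy_(t, non_blocking=True)
    return out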

@@ -622,13 +624,17 @@ def evict(
                     self.timestep,
                 )

-                # TODO: is this needed?
-                # Need a way to synchronize
-                # actions_count_cpu.record_stream(self.ssd_eviction_stream)
-                stream.record_event(post_event)
+                if post_event is not None:
+                    stream.record_event(post_event)

-    def _evict_from_scratch_pad(self, grad: Tensor) -> None:
-        assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad"
+    def _evict_from_scratch_pad(self, return_on_empty: bool) -> None:
+        scratch_pad_len = len(self.ssd_scratch_pads)
+
+        if not return_on_empty:
+            assert scratch_pad_len > 0, "There must be at least one scratch pad"
+        elif scratch_pad_len == 0:
+            return
+
         (inserted_rows, post_bwd_evicted_indices_cpu, actions_count_cpu, do_evict) = (
             self.ssd_scratch_pads.pop(0)
         )
@@ -640,7 +646,7 @@ def _evict_from_scratch_pad(self, grad: Tensor) -> None:
             actions_count_cpu=actions_count_cpu,
             stream=self.ssd_eviction_stream,
             pre_event=self.ssd_event_backward,
-            post_event=self.ssd_event_evict_sp,
+            post_event=None,
             is_rows_uvm=True,
             name="scratch_pad",
         )
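With both events now Optional, a call site can opt out of either side of the gating. The scratch pad path above passes post_event=None because nothing waits on ssd_event_evict_sp anymore; for contrast, a hypothetical call that gates only on the post side (argument values other than ssd_event_evict are invented for illustration — prefetch below waits on that event):

self.evict(
    rows=evicted_rows,                # hypothetical
    indices_cpu=evicted_indices_cpu,  # hypothetical
    actions_count_cpu=actions_count_cpu,
    stream=self.ssd_eviction_stream,
    pre_event=None,                   # no upstream gating needed
    post_event=self.ssd_event_evict,  # prefetch waits on this event
    is_rows_uvm=False,
    name="cache",
)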
@@ -683,6 +689,20 @@ def _compute_cache_ptrs(
                 )
             )

+        # Insert conflict miss indices in the index queue for future lookup
+        # post_bwd_evicted_indices_cpu is transferred on the ssd_eviction_stream stream
+        # actions_count_cpu is transferred on the ssd_memcpy_stream stream
+        with torch.cuda.stream(self.ssd_eviction_stream):
+            # Ensure that actions_count_cpu transfer is done
+            self.ssd_eviction_stream.wait_event(self.ssd_event_get_inputs_cpy)
+            self.record_function_via_dummy_profile(
+                "## ssd_scratch_pad_idx_queue_insert ##",
+                self.scratch_pad_idx_queue.insert_cuda,
+                post_bwd_evicted_indices_cpu,
+                actions_count_cpu,
+            )
+            self.ssd_eviction_stream.record_event(self.ssd_event_sp_idxq_insert)
+
         with record_function("## ssd_scratch_pads ##"):
             # Store scratch pad info for post backward eviction
             self.ssd_scratch_pads.append(
@@ -778,12 +798,76 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
             )

         current_stream = torch.cuda.current_stream()

+        inserted_indices_cpu = self.to_pinned_cpu(inserted_indices)
+        if len(self.ssd_scratch_pads) > 0:
+            with record_function("## ssd_lookup_scratch_pad ##"):
+                current_stream.wait_event(self.ssd_event_sp_idxq_insert)
+                current_stream.wait_event(self.ssd_event_get_inputs_cpy)
+
+                (
+                    inserted_rows_prev,
+                    post_bwd_evicted_indices_cpu_prev,
+                    actions_count_cpu_prev,
+                    do_evict_prev,
+                ) = self.ssd_scratch_pads.pop(0)
+
+                # Inserted indices that are found in the scratch pad
+                # from the previous iteration
+                sp_locations_cpu = torch.empty(
+                    inserted_indices_cpu.shape,
+                    dtype=inserted_indices_cpu.dtype,
+                    pin_memory=True,
+                )
+
+                # Before entering this function: inserted_indices_cpu
+                # contains all linear indices that are missed from the
+                # L1 cache
+                #
+                # After this function: inserted indices that are found
+                # in the scratch pad from the previous iteration are
+                # stored in sp_locations_cpu, while the rest are
+                # stored in inserted_indices_cpu
+                #
+                # An invalid index is -1 or its position >
+                # actions_count_cpu
+                self.record_function_via_dummy_profile(
+                    "## ssd_lookup_mask_and_pop_front ##",
+                    self.scratch_pad_idx_queue.lookup_mask_and_pop_front_cuda,
+                    sp_locations_cpu,
+                    post_bwd_evicted_indices_cpu_prev,
+                    inserted_indices_cpu,
+                    actions_count_cpu,
+                )
+
+                # Transfer sp_locations_cpu to GPU
+                sp_locations_gpu = sp_locations_cpu.cuda(non_blocking=True)
+
+                # Copy data from the previous iteration's scratch pad to
+                # the current iteration's scratch pad
+                torch.ops.fbgemm.masked_index_select(
+                    inserted_rows,
+                    sp_locations_gpu,
+                    inserted_rows_prev,
+                    actions_count_gpu,
+                )
+
+                # Evict from scratch pad
+                if do_evict_prev:
+                    torch.cuda.current_stream().record_event(
+                        self.ssd_event_backward
+                    )
+                    self.evict(
+                        rows=inserted_rows_prev,
+                        indices_cpu=post_bwd_evicted_indices_cpu_prev,
+                        actions_count_cpu=actions_count_cpu_prev,
+                        stream=self.ssd_eviction_stream,
+                        pre_event=self.ssd_event_backward,
+                        post_event=None,
+                        is_rows_uvm=True,
+                        name="scratch_pad",
+                    )
+
         # Ensure the previous iteration's l3_db.set(..) has completed.
         current_stream.wait_event(self.ssd_event_evict)
-        current_stream.wait_event(self.ssd_event_evict_sp)
+        current_stream.wait_event(self.ssd_event_get_inputs_cpy)

         if linear_cache_indices.numel() > 0:
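Taken together, insert_cuda (in _compute_cache_ptrs) and lookup_mask_and_pop_front_cuda implement a FIFO of index maps. A pure-Python model of the semantics spelled out in the comments above — a reading for illustration, not the fbgemm implementation; masked_index_select_ref likewise mirrors how the call site appears to use torch.ops.fbgemm.masked_index_select:

from collections import deque

import torch

class ScratchPadIndicesQueueModel:
    # FIFO of {linear index -> row position in that iteration's scratch pad}
    def __init__(self) -> None:
        self._queue: deque = deque()

    def insert(self, indices: torch.Tensor, count: torch.Tensor) -> None:
        n = int(count.item())
        self._queue.append(
            {int(v): pos for pos, v in enumerate(indices[:n].tolist()) if v != -1}
        )

    def lookup_mask_and_pop_front(
        self,
        sp_locations: torch.Tensor,
        prev_evicted_indices: torch.Tensor,
        inserted_indices: torch.Tensor,
        count: torch.Tensor,
    ) -> None:
        front = self._queue.popleft()
        for i in range(int(count.item())):
            pos = front.get(int(inserted_indices[i].item()), -1)
            sp_locations[i] = pos
            if pos != -1:
                # Row already lives in the previous scratch pad: copy it over
                # directly, skip the SSD lookup, and drop it from the previous
                # pad's eviction list
                inserted_indices[i] = -1
                prev_evicted_indices[pos] = -1

def masked_index_select_ref(
    out: torch.Tensor,
    locations: torch.Tensor,
    src: torch.Tensor,
    count: torch.Tensor,
) -> None:
    # Reading of the call site above: out[i] = src[locations[i]] for the
    # first `count` slots whose location is valid (>= 0)
    for i in range(int(count.item())):
        loc = int(locations[i].item())
        if loc >= 0:
            out[i] = src[loc]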
@@ -1030,6 +1114,9 @@ def flush(self) -> None:
             active_slots_mask_cpu.view(-1)
         )

+        # Evict data from the scratch pad if there is one in the queue
+        self._evict_from_scratch_pad(return_on_empty=True)
+
         torch.cuda.current_stream().wait_stream(self.ssd_eviction_stream)

         self.ssd_db.set_cuda(