Add ssd_update_row_addrs #2953

Closed
wants to merge 3 commits
127 changes: 107 additions & 20 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -341,17 +341,22 @@ def __init__(
self.ssd_event_evict = torch.cuda.Event()
# SSD backward completion event
self.ssd_event_backward = torch.cuda.Event()
# SSD scratch pad eviction completion event
self.ssd_event_evict_sp = torch.cuda.Event()
# Completion event for copying the inputs for SSD get
self.ssd_event_get_inputs_cpy = torch.cuda.Event()
# SSD scratch pad index queue insert completion event
self.ssd_event_sp_idxq_insert = torch.cuda.Event()

self.timesteps_prefetched: List[int] = []
self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
# TODO: add type annotation
# pyre-fixme[4]: Attribute must be annotated.
self.ssd_prefetch_data = []

# Scratch pad value queue
self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
# pyre-ignore[4]
# Scratch pad index queue
self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue(-1)
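# (The -1 constructor argument above is presumably the sentinel value for
# invalid indices, matching the -1 convention in the lookup comments below)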

if weight_decay_mode == WeightDecayMode.COUNTER or counter_based_regularization:
raise AssertionError(
"weight_decay_mode = WeightDecayMode.COUNTER is not supported for SSD TBE."
@@ -427,10 +432,6 @@ def __init__(
torch.zeros(0, device=self.current_device, dtype=torch.float)
)

# Register backward hook for evicting rows from a scratch pad to SSD
# post backward
self.placeholder_autograd_tensor.register_hook(self._evict_from_scratch_pad)

assert optimizer in (
OptimType.EXACT_ROWWISE_ADAGRAD,
), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
@@ -578,8 +579,8 @@ def evict(
indices_cpu: Tensor,
actions_count_cpu: Tensor,
stream: torch.cuda.Stream,
pre_event: torch.cuda.Event,
post_event: torch.cuda.Event,
pre_event: Optional[torch.cuda.Event],
post_event: Optional[torch.cuda.Event],
is_rows_uvm: bool,
name: Optional[str] = "",
) -> None:
@@ -607,7 +608,8 @@ def evict(
"""
with record_function(f"## ssd_evict_{name} ##"):
with torch.cuda.stream(stream):
stream.wait_event(pre_event)
if pre_event is not None:
stream.wait_event(pre_event)

rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)

@@ -622,13 +624,17 @@ def evict(
self.timestep,
)

# TODO: is this needed?
# Need a way to synchronize
# actions_count_cpu.record_stream(self.ssd_eviction_stream)
stream.record_event(post_event)
if post_event is not None:
stream.record_event(post_event)

def _evict_from_scratch_pad(self, return_on_empty: bool) -> None:
scratch_pad_len = len(self.ssd_scratch_pads)

if not return_on_empty:
assert scratch_pad_len > 0, "There must be at least one scratch pad"
elif scratch_pad_len == 0:
return

def _evict_from_scratch_pad(self, grad: Tensor) -> None:
assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad"
(inserted_rows, post_bwd_evicted_indices_cpu, actions_count_cpu, do_evict) = (
self.ssd_scratch_pads.pop(0)
)
@@ -640,7 +646,7 @@ def _evict_from_scratch_pad(self, grad: Tensor) -> None:
actions_count_cpu=actions_count_cpu,
stream=self.ssd_eviction_stream,
pre_event=self.ssd_event_backward,
post_event=self.ssd_event_evict_sp,
post_event=None,
is_rows_uvm=True,
name="scratch_pad",
)
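With the backward hook removed, scratch pad eviction is now triggered explicitly, and the new return_on_empty flag lets callers that cannot guarantee a queued scratch pad (such as flush below) return quietly instead of asserting. A usage sketch, with emb standing in for a hypothetical SSDTableBatchedEmbeddingBags instance:

emb._evict_from_scratch_pad(return_on_empty=True)   # no-op if nothing is queued
emb._evict_from_scratch_pad(return_on_empty=False)  # asserts a scratch pad is queued

Passing post_event=None above also means no completion event is recorded for this eviction; downstream work now orders against ssd_eviction_stream itself (see the wait_stream call in flush).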
@@ -683,6 +689,20 @@ def _compute_cache_ptrs(
)
)

# Insert conflict miss indices into the index queue for future lookup
# post_bwd_evicted_indices_cpu is transferred on ssd_eviction_stream
# actions_count_cpu is transferred on ssd_memcpy_stream
with torch.cuda.stream(self.ssd_eviction_stream):
# Ensure that actions_count_cpu transfer is done
self.ssd_eviction_stream.wait_event(self.ssd_event_get_inputs_cpy)
self.record_function_via_dummy_profile(
"## ssd_scratch_pad_idx_queue_insert ##",
self.scratch_pad_idx_queue.insert_cuda,
post_bwd_evicted_indices_cpu,
actions_count_cpu,
)
self.ssd_eviction_stream.record_event(self.ssd_event_sp_idxq_insert)
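# For illustration only (hypothetical values, and assuming the queue maps
# each valid index to its row position in the scratch pad): with
#   post_bwd_evicted_indices_cpu = [12, -1, 40, 7] and actions_count = 3,
# the insert records {12: 0, 40: 2}; the -1 entry and the position beyond
# actions_count are skipped. The next prefetch can then look these indices
# up to find their rows in this iteration's scratch pad.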

with record_function("## ssd_scratch_pads ##"):
# Store scratch pad info for post backward eviction
self.ssd_scratch_pads.append(
@@ -778,12 +798,76 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
)

current_stream = torch.cuda.current_stream()

inserted_indices_cpu = self.to_pinned_cpu(inserted_indices)
if len(self.ssd_scratch_pads) > 0:
with record_function("## ssd_lookup_scratch_pad ##"):
current_stream.wait_event(self.ssd_event_sp_idxq_insert)
current_stream.wait_event(self.ssd_event_get_inputs_cpy)

(
inserted_rows_prev,
post_bwd_evicted_indices_cpu_prev,
actions_count_cpu_prev,
do_evict_prev,
) = self.ssd_scratch_pads.pop(0)

# Buffer for the locations, in the previous iteration's
# scratch pad, of the inserted indices found there
sp_locations_cpu = torch.empty(
inserted_indices_cpu.shape,
dtype=inserted_indices_cpu.dtype,
pin_memory=True,
)

# Before entering this function: inserted_indices_cpu
# contains all linear indices that missed the L1 cache
#
# After this function: the locations of inserted indices
# that are found in the scratch pad from the previous
# iteration are stored in sp_locations_cpu, while the
# rest remain in inserted_indices_cpu
#
# An index is invalid if it is -1 or its position is
# beyond actions_count_cpu
self.record_function_via_dummy_profile(
"## ssd_lookup_mask_and_pop_front ##",
self.scratch_pad_idx_queue.lookup_mask_and_pop_front_cuda,
sp_locations_cpu,
post_bwd_evicted_indices_cpu_prev,
inserted_indices_cpu,
actions_count_cpu,
)
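# For illustration only (hypothetical values, assuming -1 marks a miss):
# with
#   post_bwd_evicted_indices_cpu_prev = [5, 9, -1, 7] and
#   inserted_indices_cpu              = [9, 3, 7, 2],
# the lookup writes sp_locations_cpu = [1, -1, 3, -1]: indices 9 and 7 are
# found at rows 1 and 3 of the previous scratch pad and are masked out of
# inserted_indices_cpu, while 3 and 2 miss and will be read from SSD as
# usual.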

# Transfer sp_locations_cpu to GPU
sp_locations_gpu = sp_locations_cpu.cuda(non_blocking=True)

# Copy data from the previous iteration's scratch pad to
# the current iteration's scratch pad
torch.ops.fbgemm.masked_index_select(
inserted_rows,
sp_locations_gpu,
inserted_rows_prev,
actions_count_gpu,
)
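# Assumed semantics of the masked copy (consistent with the comments
# above): for each position i < actions_count with a valid location
# loc = sp_locations_gpu[i] >= 0, row inserted_rows_prev[loc] is copied
# into inserted_rows[i]; positions with loc == -1 are left untouched and
# are filled from SSD instead.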

# Evict from scratch pad
if do_evict_prev:
torch.cuda.current_stream().record_event(
self.ssd_event_backward
)
self.evict(
rows=inserted_rows_prev,
indices_cpu=post_bwd_evicted_indices_cpu_prev,
actions_count_cpu=actions_count_cpu_prev,
stream=self.ssd_eviction_stream,
pre_event=self.ssd_event_backward,
post_event=None,
is_rows_uvm=True,
name="scratch_pad",
)

# Ensure that the previous iteration's l3_db.set(..) has completed.
current_stream.wait_event(self.ssd_event_evict)
current_stream.wait_event(self.ssd_event_evict_sp)
current_stream.wait_event(self.ssd_event_get_inputs_cpy)

if linear_cache_indices.numel() > 0:
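Taken together, the index queue insert in _compute_cache_ptrs and the lookup above form a FIFO handshake around each scratch pad. A CPU-only sketch of that contract, using a hypothetical simplified model rather than the real SSDScratchPadIndicesQueue:

from collections import deque
from typing import Deque, Dict, List

class ScratchPadIndicesQueueModel:
    """Toy FIFO model: one {linear index -> scratch pad row} map per iteration."""

    def __init__(self, invalid_index: int = -1) -> None:
        self.invalid_index = invalid_index
        self.queue: Deque[Dict[int, int]] = deque()

    def insert(self, indices: List[int], actions_count: int) -> None:
        # Keep only valid entries: not the sentinel, and within actions_count
        self.queue.append({
            idx: row
            for row, idx in enumerate(indices[:actions_count])
            if idx != self.invalid_index
        })

    def lookup_and_pop_front(self, lookup: List[int]) -> List[int]:
        # Map each index to its row in the oldest scratch pad, -1 on a miss
        front = self.queue.popleft()
        return [front.get(idx, self.invalid_index) for idx in lookup]

q = ScratchPadIndicesQueueModel()
q.insert([12, -1, 40, 7], actions_count=3)  # records {12: 0, 40: 2}
q.lookup_and_pop_front([40, 3, 12])         # returns [2, -1, 0]

Unlike this model, the real queue also masks the hits out of the lookup buffer in place, so only true misses proceed to the SSD read.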
@@ -1030,6 +1114,9 @@ def flush(self) -> None:
active_slots_mask_cpu.view(-1)
)

# Evict data from the scratch pad if there is one in the queue
self._evict_from_scratch_pad(return_on_empty=True)

torch.cuda.current_stream().wait_stream(self.ssd_eviction_stream)

self.ssd_db.set_cuda(