Summary:
Pull Request resolved: #2963
This diff enables pipeline cache prefetching for SSD-TBE. This allows
the prefetch of the next iteration's batch to be carried out while the
computation of the current batch is in progress.
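The general overlap pattern can be illustrated with a plain-Python sketch (this is a generic double-buffering illustration, not FBGEMM code; `fetch` and `compute` are hypothetical stand-ins for the prefetch and TBE computation):

```python
# Illustrative only: fetch batch i+1 in a background thread while
# batch i is being processed, mimicking the prefetch/compute overlap.
from concurrent.futures import ThreadPoolExecutor

def pipelined_run(batches, fetch, compute):
    results = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = pool.submit(fetch, batches[0])  # prefetch the first batch
        for i in range(len(batches)):
            data = pending.result()               # wait for prefetch(i)
            if i + 1 < len(batches):
                # Kick off prefetch(i+1); it overlaps with compute(i).
                pending = pool.submit(fetch, batches[i + 1])
            results.append(compute(data))         # computation for batch i
    return results
```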
We have done the following to guarantee cache consistency when
pipeline prefetching is enabled:
(1) Enable cache line locking (implemented in D46172802, D47638502,
D60812956) to ensure that cache lines are not prematurely evicted by
the prefetch when the previous iteration's computation is not
complete.
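A simplified sketch of the guarantee that cache line locking provides (the class and method names here are hypothetical; the actual implementation is in the diffs cited above):

```python
# Hypothetical per-cache-line lock counting (illustrative, not the real code).
class CacheLineLocks:
    def __init__(self, num_lines):
        self.lock_counts = [0] * num_lines

    def lock(self, line):
        # Called by prefetch for every line an in-flight iteration will use.
        self.lock_counts[line] += 1

    def unlock(self, line):
        # Called once the iteration's computation has consumed the line.
        self.lock_counts[line] -= 1

    def evictable(self, line):
        # A line may be evicted only when no in-flight iteration holds it.
        return self.lock_counts[line] == 0
```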
(2) Look up the L1 cache, the previous iteration's scratch pad (let's
call it SP(i-1)), and the SSD/L2 cache. Move rows from SSD/L2 and/or
SP(i-1) to either L1 or the current iteration's scratch pad (let's call
it SP(i)). Then update the row pointers of the previous iteration's
indices to reflect the new locations, i.e., L1 or SP(i). A detailed
explanation of the process is shown in the figure below:
{F1802341461}
https://internalfb.com/excalidraw/EX264315
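The lookup-and-move step can be sketched as follows (a dict-based simplification; the function and the in-memory "caches" are assumptions for illustration, not the real SSD-TBE data structures):

```python
# Illustrative three-level lookup: L1 -> SP(i-1) -> SSD/L2.
def resolve_row(idx, l1, sp_prev, ssd_l2, sp_curr):
    """Return the location of row `idx` for iteration i, moving it as needed."""
    if idx in l1:
        return ("L1", l1[idx])            # already resident in L1
    if idx in sp_prev:
        # Row lives in SP(i-1); move it to SP(i) so the previous
        # iteration's row pointers can be redirected to the new location.
        sp_curr[idx] = sp_prev.pop(idx)
        return ("SP_i", sp_curr[idx])
    # Fall back to SSD/L2 and copy the row into SP(i)
    # (the real code may also insert it into L1).
    sp_curr[idx] = ssd_l2[idx]
    return ("SP_i", sp_curr[idx])
```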
(3) Ensure proper synchronization between streams and events
- Ensure that prefetch of iteration i is complete before backward TBE
of iteration i-1
- Ensure that prefetch of iteration i+1 starts after the backward TBE
of iteration i is complete
The following is how prefetch operators run on GPU streams/CPU:
{F1802798301}
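The first ordering constraint can be mimicked in plain Python with a thread event (illustrative only; the real code would record and wait on CUDA events across the prefetch and compute streams):

```python
# Pure-Python analogue of "prefetch(i) completes before backward(i-1) starts".
import threading

def run_with_ordering():
    order = []
    prefetch_done = threading.Event()

    def prefetch():           # stands in for prefetch(i) on the prefetch stream
        order.append("prefetch_i")
        prefetch_done.set()   # analogous to event.record() on the prefetch stream

    def backward():           # stands in for backward TBE of iteration i-1
        prefetch_done.wait()  # analogous to stream.wait_event(event)
        order.append("backward_i-1")

    t1 = threading.Thread(target=backward)
    t2 = threading.Thread(target=prefetch)
    t1.start(); t2.start()
    t1.join(); t2.join()
    return order
```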
**Usage:**
```python
# Initialize the module with prefetch_pipeline=True
emb = SSDTableBatchedEmbeddingBags(
    embedding_specs=...,
    prefetch_pipeline=True,
).cuda()

# When calling prefetch, make sure to pass the forward stream if using
# prefetch_stream so that TBE records tensors on streams properly
with torch.cuda.stream(prefetch_stream):
    emb.prefetch(
        indices,
        offsets,
        forward_stream=forward_stream,
    )
```
Differential Revision: D60727327