feat: optimize skiplist_mask with O(1) lookup table approach #146
Summary

This PR optimizes the `skiplist_mask` method in ColBERT by replacing the O(n×m) algorithm with an O(1) lookup-table approach, resulting in a 77.7x speedup on GPU and a 33.2x speedup on CPU.

Problem

The current `skiplist_mask` implementation has O(n×m) time complexity, where n is the number of input tokens and m is the skiplist size: for each token in the skiplist, it performs a `torch.where` operation across all n input tokens, creating m intermediate tensors and executing m comparison operations.

Solution

This PR introduces a lookup table (LUT) based approach with O(1) complexity per token:

- Build a boolean LUT of size `vocab_size` in which skiplist tokens are marked as False
- Index `lut[input_ids]` to produce the mask

Key Implementation Details
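The core of the approach can be sketched as follows (a simplified reconstruction; the function names and standalone layout are illustrative, not the PR's exact code):

```python
import torch

def build_skiplist_lut(skiplist: list, vocab_size: int,
                       device: torch.device = torch.device("cpu")) -> torch.Tensor:
    # One boolean entry per vocabulary id; skiplist tokens become False.
    lut = torch.ones(vocab_size, dtype=torch.bool, device=device)
    if skiplist:
        lut[torch.tensor(skiplist, dtype=torch.long, device=device)] = False
    return lut

def skiplist_mask_lut(input_ids: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
    # Single advanced-indexing pass: O(n) in the number of input tokens,
    # independent of the skiplist size.
    return lut[input_ids]
```

Because the LUT is built once per skiplist and cached, the per-call cost reduces to a single tensor indexing operation.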
Performance Results
Benchmarked on NVIDIA L4 GPU with typical ColBERT workloads:
Speed Improvements
Cache Effectiveness
Visualization
The PR includes comprehensive benchmarks showing:
Correctness
All outputs are identical to the original implementation, verified through:
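As an illustration of the kind of check involved, the two formulations can be cross-verified on random inputs (a hypothetical harness, not the PR's actual test file):

```python
import torch

def check_equivalence(vocab_size: int = 100, n: int = 64, m: int = 7,
                      trials: int = 20) -> None:
    for _ in range(trials):
        input_ids = torch.randint(0, vocab_size, (n,))
        skiplist = torch.randperm(vocab_size)[:m].tolist()
        # Original O(n*m) formulation: one torch.where pass per skiplist token.
        old = torch.ones(n, dtype=torch.bool)
        for t in skiplist:
            old = torch.where(input_ids == t, torch.zeros_like(old), old)
        # LUT O(n) formulation: one indexing pass.
        lut = torch.ones(vocab_size, dtype=torch.bool)
        lut[torch.tensor(skiplist)] = False
        assert torch.equal(old, lut[input_ids])
```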
Code Quality
Testing
A comprehensive benchmark script is included (`tests/test_skiplist_optimization.py`).

To run the benchmark:
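Assuming the script runs standalone from the repository root, the invocation would presumably be:

```shell
python tests/test_skiplist_optimization.py
```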
Impact
This optimization significantly improves ColBERT training and inference performance:
Backward Compatibility
The optimization is fully backward compatible:
Memory Considerations
The LUT cache uses minimal memory:
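As a rough illustration (assuming a BERT-style tokenizer of roughly 30k entries; the exact size depends on the model), a boolean LUT costs about one byte per vocabulary id:

```python
import torch

vocab_size = 30_522  # bert-base-uncased vocabulary size, as an example
lut = torch.ones(vocab_size, dtype=torch.bool)
# PyTorch stores torch.bool as one byte per element.
print(lut.element_size() * lut.nelement())  # 30522 bytes, i.e. ~30 KB
```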
References
Note to reviewers: The dramatic speedup comes from eliminating the nested loop structure. Instead of m passes over n tokens (O(n×m)), we now do a single indexing operation (O(n)) with a pre-computed lookup table.
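To see the effect locally, a minimal CPU micro-benchmark along these lines can be used (illustrative only; sizes are arbitrary, timings will vary, and GPU timing would additionally require `torch.cuda.synchronize`):

```python
import time
import torch

def bench(fn, *args, iters: int = 100) -> float:
    # Warm up once, then report the mean wall-clock time per call.
    fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters

vocab_size, n, m = 30_522, 4_096, 100
input_ids = torch.randint(0, vocab_size, (n,))
skiplist = torch.randperm(vocab_size)[:m].tolist()

def mask_loop(ids):
    # O(n*m): one torch.where pass per skiplist token.
    out = torch.ones(n, dtype=torch.bool)
    for t in skiplist:
        out = torch.where(ids == t, torch.zeros_like(out), out)
    return out

lut = torch.ones(vocab_size, dtype=torch.bool)
lut[torch.tensor(skiplist)] = False

def mask_lut(ids):
    # O(n): one indexing pass into the pre-computed LUT.
    return lut[ids]

print(f"loop: {bench(mask_loop, input_ids):.2e}s  lut: {bench(mask_lut, input_ids):.2e}s")
```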