This work reduces the quadratic time complexity of self-attention.
It reduces the per-token attention cost from $O(Nd)$, where $N$ is the number of cached keys and $d$ the head dimension, to a per-query cost that is independent of $N$.
Our method acts as a memory module. A memory module is built from a set of past keys, for example the keys from previous turns in a multi-turn conversation. These keys are indexed once; at generation time, dot-product attention is computed against the keys retrieved from the index, and the result is combined with attention over the remaining keys using each part's attention output and the log-sum-exp of its attention scores.
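Concretely, attention over a union of key sets can be reassembled from each set's attention output and the log-sum-exp of its scores. Below is a minimal NumPy sketch, assuming a single query and scaled dot-product scores; `partial_attention`, `merge`, and the toy shapes are illustrative names and sizes, not part of any released API.

```python
import numpy as np

def partial_attention(q, K, V):
    """Attention of query q over one block (K, V); returns the block output and
    the log-sum-exp of its scaled scores so results can be merged later."""
    scores = K @ q / np.sqrt(q.shape[-1])   # (N,) scaled dot-product scores
    lse = np.logaddexp.reduce(scores)       # log sum_j exp(score_j)
    weights = np.exp(scores - lse)          # softmax over this block only
    return weights @ V, lse                 # (d,), scalar

def merge(o_mem, lse_mem, o_loc, lse_loc):
    """Combine the memory-module output with the local output; equivalent to a
    single softmax over the union of both key sets."""
    m = max(lse_mem, lse_loc)               # subtract the max for stability
    w_mem, w_loc = np.exp(lse_mem - m), np.exp(lse_loc - m)
    return (w_mem * o_mem + w_loc * o_loc) / (w_mem + w_loc)

# Toy usage with hypothetical shapes: d = 64, 128 memory keys, 16 local keys.
rng = np.random.default_rng(0)
d = 64
q = rng.standard_normal(d)
K_mem, V_mem = rng.standard_normal((128, d)), rng.standard_normal((128, d))
K_loc, V_loc = rng.standard_normal((16, d)), rng.standard_normal((16, d))

o_mem, lse_mem = partial_attention(q, K_mem, V_mem)
o_loc, lse_loc = partial_attention(q, K_loc, V_loc)
o = merge(o_mem, lse_mem, o_loc, lse_loc)
```

This is the same merge rule used by blockwise and flash-style attention kernels, so the memory module's contribution can be folded in without re-normalizing over the full key set.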
Consider a single attention head with query $q \in \mathbb{R}^d$ and keys $k_1, \dots, k_N \in \mathbb{R}^d$, stacked as $K \in \mathbb{R}^{N \times d}$. Standard dot-product attention scores the query against every key. This results in a time complexity of $O(Nd)$ per query.
Empirical observations indicate that performance remains invariant when attention is restricted to the top-$M$ keys with the largest attention scores, for $M \ll N$.
We approximate this top-$M$ selection with a locality-sensitive hashing (LSH) index, reducing the query-time cost to $O((C+Z)d)$, where $C$ is the number of hash buckets and $Z$ the number of keys stored per bucket.
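To make the top-$M$ restriction concrete, the sketch below computes exact top-$M$ attention by scoring all keys and keeping only the $M$ largest; note that this exact variant still costs $O(Nd)$, which is precisely the scoring step the LSH index replaces. The names and toy sizes are illustrative, and the random data only demonstrates the mechanics, not the empirical invariance observed with trained models.

```python
import numpy as np

def full_attention(q, K, V):
    """Standard softmax attention over all N keys: O(N d) per query."""
    s = K @ q / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max()); w /= w.sum()
    return w @ V

def top_m_attention(q, K, V, M):
    """Exact top-M restriction: score all keys, keep the M largest scores, and
    renormalise the softmax over that subset only."""
    s = K @ q / np.sqrt(q.shape[-1])        # still O(N d): the step LSH avoids
    idx = np.argpartition(s, -M)[-M:]       # indices of the M largest scores
    s_top = s[idx]
    w = np.exp(s_top - s_top.max()); w /= w.sum()
    return w @ V[idx]

rng = np.random.default_rng(0)
N, d, M = 4096, 64, 32
q = rng.standard_normal(d)
K, V = rng.standard_normal((N, d)), rng.standard_normal((N, d))
print(np.linalg.norm(full_attention(q, K, V) - top_m_attention(q, K, V, M)))
```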
For a given head, the keys $K$ are indexed as follows (a code sketch of the full procedure follows the list):

- Projection Computation: Compute $Y = A K^\top \in \mathbb{R}^{C \times N}$, where $A \in \mathbb{R}^{C \times d}$ is the LSH projection matrix (detailed later).
- Bucket Assignment: For each bucket $i \in [1, C]$, identify the top-$Z$ key indices $b_{i,1:Z} = \text{argtop}_Z(Y_{i,:})$, where $\text{argtop}_Z$ denotes the indices of the $Z$ largest values, and assign key $k_{b_{i,j}}$ to bucket $i$. This differs from standard cross-polytope LSH, which assigns each key only to its single maximizing bucket.
- Computing Attention: For a query $q$, compute $y_q = A q \in \mathbb{R}^C$ and the hash $h_q = \text{argmax}_i(y_{q,i})$. Retrieve the keys $\{k_j \mid j \in b_{h_q}\}$ and apply attention over these keys only.
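The following is a minimal NumPy sketch of the indexing and lookup steps above. `build_index` and `query_attention` are illustrative names, the random row-normalised projection matrix is a stand-in for the data-driven $A$ described in the next subsection, and the returned log-sum-exp is what the memory-module merge described earlier consumes.

```python
import numpy as np

def build_index(K, A, Z):
    """Bucket assignment: project all keys (Y = A K^T) and keep, for each
    bucket i, the indices of the Z keys with the largest projection Y[i, :]."""
    Y = A @ K.T                                    # (C, N) projections
    return np.argpartition(Y, -Z, axis=1)[:, -Z:]  # (C, Z) key indices

def query_attention(q, K, V, A, buckets):
    """Hash the query, retrieve its bucket, and attend over those keys only."""
    y_q = A @ q                                    # (C,) query projections
    h_q = int(np.argmax(y_q))                      # bucket with max projection
    idx = buckets[h_q]                             # the Z candidate key indices
    s = K[idx] @ q / np.sqrt(q.shape[-1])          # scores over Z keys only
    lse = np.logaddexp.reduce(s)                   # kept for downstream merging
    return np.exp(s - lse) @ V[idx], lse

# Toy usage; a random projection stands in for the data-driven A built below.
rng = np.random.default_rng(0)
N, d, C, Z = 4096, 64, 128, 32
K, V = rng.standard_normal((N, d)), rng.standard_normal((N, d))
A = rng.standard_normal((C, d))
A /= np.linalg.norm(A, axis=1, keepdims=True)
buckets = build_index(K, A, Z)
q = rng.standard_normal(d)
out, lse = query_attention(q, K, V, A, buckets)
```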
Time complexity: building the index costs $O(NCd)$ for the projection $Y = AK^\top$ plus the per-bucket top-$Z$ selection; at query time, hashing the query costs $O(Cd)$ and attention over the retrieved bucket costs $O(Zd)$, giving the $O((C+Z)d)$ per-query cost above.
To enhance collision sensitivity tailored to model-specific token distributions, the projection matrix $A$ is constructed from the model's own token embeddings rather than sampled randomly. The construction process comprises the following steps (a code sketch follows the list):
- Embedding Projection: Let $X = [x_1, \dots, x_N] \in \mathbb{R}^{N \times d_m}$ represent the embeddings of $N$ prior tokens, where $d_m$ denotes the model embedding dimension. Each embedding $x_j$ is projected onto the unit sphere in query space using a function $f: \mathbb{R}^{d_m} \to S^{d-1}$, yielding projected vectors $f(x_j)$.
- Clustering: Perform K-means clustering on the projected embeddings $\{f(x_j)\}_{j=1}^N$ to obtain $C$ centroids $c_1, \dots, c_C \in S^{d-1}$. The clustering employs a dot-product-based distance, minimizing $1 - c \cdot f(x_j)$ for centroid $c$, which corresponds to maximizing cosine similarity on the unit sphere.
- Matrix Formation: Form the projection matrix $A \in \mathbb{R}^{C \times d}$ by stacking the centroids as rows: $A = [c_1^\top; \dots; c_C^\top]$.
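Below is a minimal NumPy sketch of this construction, assuming the direct query-weight choice of $f$ described below and a fixed number of spherical K-means iterations; `build_projection_matrix`, the toy sizes, and the iteration count are illustrative choices, not the reference implementation.

```python
import numpy as np

def build_projection_matrix(X, W_q, C, iters=20, seed=0):
    """Construct A (C x d) by spherical K-means over token embeddings that have
    been projected into query space and normalised onto the unit sphere."""
    rng = np.random.default_rng(seed)
    F = X @ W_q.T                                   # f(x_j) before normalisation
    F /= np.linalg.norm(F, axis=1, keepdims=True)   # rows now lie on S^{d-1}
    centroids = F[rng.choice(len(F), size=C, replace=False)]  # init from data
    for _ in range(iters):
        assign = np.argmax(F @ centroids.T, axis=1)  # max cosine similarity
        for i in range(C):
            members = F[assign == i]
            if len(members):                         # keep old centroid if empty
                c = members.mean(axis=0)
                centroids[i] = c / np.linalg.norm(c)
    return centroids                                 # A, shape (C, d)

# Hypothetical sizes: d_m = 512 model dim, d = 64 head dim, N = 4096 tokens.
rng = np.random.default_rng(1)
N, d_m, d, C = 4096, 512, 64, 128
X = rng.standard_normal((N, d_m))
W_q = rng.standard_normal((d, d_m))
A = build_projection_matrix(X, W_q, C)
```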
The computational complexity of K-means is approximately $O(NCd)$ per iteration.
The function $f$, which maps token embeddings onto the unit sphere in query space, can be instantiated in several ways:
- A direct projection using the model's query weights: $f(x) = \frac{W_q x}{\|W_q x\|_2}$, where $W_q$ are the query projection weights of the attention head.
- A learned linear transformation.
- A multi-layer perceptron (MLP).
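For concreteness, the three instantiations of $f$ can be sketched as follows; the learned weights here are untrained placeholders, and the single-hidden-layer MLP shape is an assumption, since the architecture is not specified above.

```python
import numpy as np

def f_query(x, W_q):
    """Option 1: reuse the head's query projection, then normalise."""
    y = W_q @ x
    return y / np.linalg.norm(y)

def f_linear(x, B):
    """Option 2: a learned d x d_m linear map (B would be trained; here it is
    only a placeholder with the right shape)."""
    y = B @ x
    return y / np.linalg.norm(y)

def f_mlp(x, W1, W2):
    """Option 3: a small learned MLP (weights are placeholders; assumed to have
    one ReLU hidden layer for illustration)."""
    h = np.maximum(W1 @ x, 0.0)
    y = W2 @ h
    return y / np.linalg.norm(y)

# Tiny check with hypothetical sizes: the output always lies on the unit sphere.
rng = np.random.default_rng(2)
d_m, d = 512, 64
x = rng.standard_normal(d_m)
W_q = rng.standard_normal((d, d_m))
print(np.linalg.norm(f_query(x, W_q)))  # prints 1.0
```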
Empirical results indicate that a learned
For a batch of
For FlashAttention, the HBM access cost is