
kv cache inference logic with SWA #6

Open
sachin7695 wants to merge 1 commit into priyammaz:main from sachin7695:kv-cache-swa

Conversation

@sachin7695

I tried to implement a key-value cache with sliding window attention for the inference stage, for token-by-token generation.
I also added testing code to exercise inference with the KV cache. This is a simplified approach meant to show how a KV cache can be combined with SWA (the original design uses a rolling buffer), so it is a naive implementation.
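For reference, here is a minimal sketch of the naive idea (hypothetical names, not the exact code in this PR): each generation step appends the new token's key/value to the cache and then truncates the cache to the last window_size positions, instead of maintaining a true rolling buffer.

import torch

def naive_swa_cache_step(k_cache, v_cache, k_new, v_new, window_size):
    ### Append the new token's key/value (B x H x 1 x E) to the cache ###
    k_cache = torch.cat([k_cache, k_new], dim=-2)
    v_cache = torch.cat([v_cache, v_new], dim=-2)
    ### Naive sliding window: keep only the most recent window_size positions ###
    ### (simple truncation, not a rolling buffer) ###
    if k_cache.shape[-2] > window_size:
        k_cache = k_cache[..., -window_size:, :]
        v_cache = v_cache[..., -window_size:, :]
    return k_cache, v_cache

B, H, E, window = 1, 4, 16, 8
k_cache, v_cache = torch.empty(B, H, 0, E), torch.empty(B, H, 0, E)
for _ in range(20):  # token-by-token generation loop
    k_new, v_new = torch.randn(B, H, 1, E), torch.randn(B, H, 1, E)
    k_cache, v_cache = naive_swa_cache_step(k_cache, v_cache, k_new, v_new, window)
print(k_cache.shape)  # torch.Size([1, 4, 8, 16]), capped at the window size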

@priyammaz
Owner

I will look at this during the weekend! I am excited to add it in! At first glance it looks good, we will have to make a test showing inference time gains from this!

priyammaz self-assigned this Oct 9, 2025
@sachin7695
Author

I will look at this during the weekend! I am excited to add it in! At first glance it looks good, we will have to make a test showing inference time gains from this!

Sure thing! Have a look, and let's discuss anything that does not make sense; I will fix it. I will also try to add a test script for this, especially one that shows quantitatively how much faster this is than vanilla SWA attention without a KV cache.
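One way such a test could look (a hypothetical timing sketch with random tensors rather than the actual model, so the absolute numbers are not meaningful; real gains depend on the model and hardware): run token-by-token generation once recomputing attention over the full prefix at every step, and once reusing a cached sliding window, then compare wall-clock time.

import time
import torch
import torch.nn.functional as F

B, H, E, window, steps = 1, 8, 64, 256, 512
q_all = torch.randn(B, H, steps, E)
k_all = torch.randn(B, H, steps, E)
v_all = torch.randn(B, H, steps, E)

### Without a KV cache: every step re-runs attention over the full prefix ###
start = time.perf_counter()
for t in range(1, steps + 1):
    F.scaled_dot_product_attention(
        q_all[..., :t, :], k_all[..., :t, :], v_all[..., :t, :], is_causal=True
    )
no_cache_time = time.perf_counter() - start

### With a KV cache + naive SWA truncation: only the newest query attends over the window ###
k_cache = torch.empty(B, H, 0, E)
v_cache = torch.empty(B, H, 0, E)
start = time.perf_counter()
for t in range(steps):
    k_cache = torch.cat([k_cache, k_all[..., t : t + 1, :]], dim=-2)[..., -window:, :]
    v_cache = torch.cat([v_cache, v_all[..., t : t + 1, :]], dim=-2)[..., -window:, :]
    F.scaled_dot_product_attention(q_all[..., t : t + 1, :], k_cache, v_cache)
cached_time = time.perf_counter() - start

print(f"no cache: {no_cache_time:.3f}s | kv cache + swa: {cached_time:.3f}s")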

@priyammaz
Owner

As far as I can tell this is working great! But let's do one thing (since this is for learning purposes, it will make the code more readable).

Here is the Cache class I wrote for the Llama4 implementation (it closely follows the Hugging Face cache):

import torch

class Cache:
    """
    KV Cache Method that is close to the Huggingface DynamicCache
    https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py
    """

    def __init__(self, config):

        ### Counter for Number of Tokens in Cache ###
        self._seen_tokens = 0

        ### Key/Value Cache (List of Tensor, where list is over model layers) ###
        self.key_cache = [torch.tensor([]) for _ in range(config.num_hidden_layers)]
        self.value_cache = [torch.tensor([]) for _ in range(config.num_hidden_layers)]

    def __repr__(self):
        return f"DynamicCache(Num_Layers: {len(self.key_cache)} | Cached Tokens: {self.get_seq_len()})"
        
    def update(self, key_states, value_states, layer_idx):
        
        ### Only increment the seen-token counter on the first layer ###
        ### key_states (B x H x L x E)
        ### value_states (B x H x L x E)

        if layer_idx == 0:  
            self._seen_tokens += key_states.shape[-2]

        ### Append new key/value states to the key/value cache ###
        ### (assign directly on the first update so device/dtype follow the incoming states) ###
        if self.key_cache[layer_idx].numel() == 0:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
    
    def get_seq_len(self, layer_idx=0):
        return self.key_cache[layer_idx].shape[-2] if self.key_cache[layer_idx].numel() != 0 else 0

Let's port this code to use something like this instead. Basically, a container that holds the KVs for ALL layers outside of the model! I am working on multi-head latent attention, and we can easily add that in here too: sliding window + latent attention caching!
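Roughly, each attention layer's forward pass could then use the shared cache like this (a hypothetical sketch: attend_with_cache, window_size, and the truncation step are my assumptions, not part of the class above):

import torch.nn.functional as F

def attend_with_cache(query_states, key_states, value_states, cache, layer_idx, window_size):
    ### Append this layer's new key/value states and get back the full cached history ###
    keys, values = cache.update(key_states, value_states, layer_idx)

    ### Naive sliding window: only attend over the last window_size cached positions ###
    keys = keys[..., -window_size:, :]
    values = values[..., -window_size:, :]

    return F.scaled_dot_product_attention(query_states, keys, values)

Note that with this glue code the container itself still grows without bound; if we want it bounded, the truncation could instead happen inside update, which is closer to the rolling-buffer idea.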

@priyammaz
Owner

I will start this in 3 days, but feel free to do it earlier if you have the time!
