Conversation
I will look at this over the weekend! I am excited to add it in! At first glance it looks good; we will have to add a test showing the inference-time gains from this!
Sure thing!! Have a look and let's discuss if there is anything that does not make sense; I will fix it!! I will also try to add a test script for this, specifically one that shows in a quantitative way how much faster this is than vanilla attention (no KV cache) with SWA.
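(For context, a rough sketch of what such a timing script could look like; the forward signature `model(ids, cache=cache)` and the greedy decoding loop here are assumptions for illustration, not code from this PR.)

```python
import time

import torch


@torch.no_grad()
def time_generation(model, prompt_ids, max_new_tokens, cache=None):
    """Wall-clock timing of token-by-token generation.

    Assumes a hypothetical forward signature `model(ids, cache=cache)` that
    returns logits and writes KVs into `cache` when one is given.
    """
    ids = prompt_ids
    start = time.perf_counter()
    for step in range(max_new_tokens):
        if cache is None:
            # Vanilla path: recompute attention over the whole sequence.
            logits = model(ids)
        elif step == 0:
            # Cached path, prefill: run the full prompt once to fill the cache.
            logits = model(ids, cache=cache)
        else:
            # Cached path, decode: only the newest token goes through the model.
            logits = model(ids[:, -1:], cache=cache)
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
    return time.perf_counter() - start
```

Comparing the return value with and without a cache would then give the speedup number directly.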
As far as I can tell this is working great! But let's do one thing (as this is for learning purposes, it will make the code more readable). Here is my `Cache` class that I wrote for the Llama4 implementation (it closely follows the Hugging Face cache):

```python
import torch


class Cache:
    """
    KV cache method that is close to the Hugging Face DynamicCache:
    https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py
    """
    def __init__(self, config):
        ### Counter for the number of tokens in the cache ###
        self._seen_tokens = 0
        ### Key/value cache (lists of tensors, one entry per model layer) ###
        self.key_cache = [torch.tensor([]) for _ in range(config.num_hidden_layers)]
        self.value_cache = [torch.tensor([]) for _ in range(config.num_hidden_layers)]

    def __repr__(self):
        return f"DynamicCache(Num_Layers: {len(self.key_cache)} | Cached Tokens: {self.get_seq_len()})"

    def update(self, key_states, value_states, layer_idx):
        ### key_states   (B x H x L x E) ###
        ### value_states (B x H x L x E) ###
        ### Only increment the seen-token counter on the first layer ###
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]
        ### Append the new key/value states to the key/value cache ###
        self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
        self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_len(self, layer_idx=0):
        return self.key_cache[layer_idx].shape[-2] if self.key_cache[layer_idx].numel() != 0 else 0
```

Let's port this code to use something like this instead: basically a container that holds the KVs for ALL layers outside of the model! I am working on multi-head latent attention, and we can easily add that in here too! Sliding window + latent attention caching!
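(For illustration, one way the sliding-window behaviour from this PR could be folded into such a container; the subclass below and its `window_size` argument are a sketch, not the final design.)

```python
class SlidingWindowCache(Cache):
    """Naive sliding-window variant of the Cache above: after appending the
    new KVs, drop everything older than the last `window_size` positions.
    (A true rolling buffer would instead write into a fixed-size tensor at
    slot `seen_tokens % window_size`, so nothing ever has to be re-sliced.)
    """
    def __init__(self, config, window_size):
        super().__init__(config)
        self.window_size = window_size

    def update(self, key_states, value_states, layer_idx):
        keys, values = super().update(key_states, value_states, layer_idx)
        if keys.shape[-2] > self.window_size:
            ### Keep only the most recent window_size tokens for this layer ###
            keys = keys[..., -self.window_size:, :]
            values = values[..., -self.window_size:, :]
            self.key_cache[layer_idx] = keys
            self.value_cache[layer_idx] = values
        return keys, values
```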
I will start this in 3 days, but feel free to do it earlier if you have the time!
I tried to implement a key-value cache with sliding window attention for the inference stage, i.e. token-by-token generation.
Added testing code to exercise inference with the KV cache!! This is a simplified approach to understanding how a KV cache can be implemented for SWA (the original implementation uses a rolling buffer). This is a "naive implementation".
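(A minimal sketch of how an attention forward pass consumes such a cache during decoding; the function name, argument layout, and the assumption of one new token per step are for illustration only, not code from this PR.)

```python
import torch


def cached_sliding_window_attention(query, key, value, layer_idx, cache):
    ### query/key/value hold ONLY the newly fed token(s): (B x H x L_new x E) ###
    ### The cache returns the concatenated (and window-trimmed) history ###
    key, value = cache.update(key, value, layer_idx)       # (B x H x L_total x E)
    scores = query @ key.transpose(-2, -1) / key.shape[-1] ** 0.5
    # With one token per step (L_new == 1) every cached position is visible,
    # so no causal mask is needed during incremental decoding.
    attn = torch.softmax(scores, dim=-1)
    return attn @ value                                     # (B x H x L_new x E)
```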