Integration Test: KV-Cache management with symbolic values #2742

@IvanYashchuk

Description

🚀 Feature

Create a comprehensive integration test that demonstrates all symbolic value features working together in a realistic transformer KV-cache management scenario.

Depends On

Test Case

import torch
import thunder

@thunder.jit
def kv_cache_update(
    input_ids: torch.Tensor,  # [1, seq_len] - seq_len is symbolic
    cumulative_length: int,    # symbolic scalar
    kv_cache: torch.Tensor,    # [1, 8, 1056, 128]
    sliding_window: int = 1056,
) -> tuple[torch.Tensor, int, torch.Tensor]:
    """
    Realistic KV-cache update demonstrating all symbolic patterns.
    
    This mirrors the pattern from Llama-4 model.
    """
    batch_size, seq_len = input_ids.shape
    
    # Pattern 1: Symbolic arithmetic
    new_cumulative_length = cumulative_length + seq_len
   
    # Pattern 2: torch.arange with symbolic bounds 
    cache_position = torch.arange(
        cumulative_length,
        new_cumulative_length,
        device=input_ids.device
    )
    
    # Pattern 3: sym_max for sliding window 
    kv_offset = max(cumulative_length - sliding_window + 1, 0)
    
    # Pattern 4: Symbolic tensor shapes 
    # Compute query states with symbolic seq_len dimension
    hidden_dim = 5120
    embedding_weight = torch.randn(
        50000, hidden_dim, 
        device='cuda', 
        dtype=torch.bfloat16
    )
    hidden_states = torch.nn.functional.embedding(input_ids, embedding_weight)
    # hidden_states.shape = [1, s50, 5120]
    
    # Compute keys for cache update
    key_proj = torch.randn(1024, hidden_dim, device='cuda', dtype=torch.bfloat16)
    keys = torch.nn.functional.linear(hidden_states, key_proj)
    # keys.shape = [1, s50, 1024]
    keys = keys.view(batch_size, seq_len, 8, 128).transpose(1, 2)
    # keys.shape = [1, 8, s50, 128]
    
    # Update KV-cache using symbolic cache_position
    kv_cache = kv_cache.index_copy_(2, cache_position, keys)
    
    return kv_cache, new_cumulative_length, cache_position


kv_cache = torch.zeros(1, 8, 1056, 128, device='cuda', dtype=torch.bfloat16)
cumulative_length = 0

# Prefill phase - long sequence
input_ids_prefill = torch.randint(0, 50000, (1, 1024), device='cuda')
kv_cache, cumulative_length, cache_pos = kv_cache_update(
    input_ids_prefill, cumulative_length, kv_cache
)
assert cumulative_length == 1024
assert cache_pos.shape == (1024,)

# Decode phase - single tokens
for _ in range(32):
    input_ids_decode = torch.randint(0, 50000, (1, 1), device='cuda')
    kv_cache, cumulative_length, cache_pos = kv_cache_update(
        input_ids_decode, cumulative_length, kv_cache
    )
    assert cache_pos.shape == (1,)

assert cumulative_length == 1056

Validation Criteria

[ ] Test compiles successfully with @thunder.jit
[ ] Works with variable sequence lengths without recompilation (see the sketch after this list)
[ ] Performance comparable to torch.compile
[ ] Symbolic values tracked correctly throughout
[ ] Generated code is efficient without recompilation
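One way to exercise the recompilation and symbolic-tracking criteria is sketched below. It leans on Thunder's introspection entry points thunder.compile_stats and thunder.last_traces; the cache_misses attribute used here is an assumption about the stats object, so the exact names may need adjusting.

import thunder

def check_single_compilation(jitted_fn):
    # Assumption: compile_stats exposes cache hit/miss counters; with symbolic
    # values the prefill call and every decode call should reuse a single
    # compiled entry, i.e. at most one cache miss overall.
    stats = thunder.compile_stats(jitted_fn)
    assert stats.cache_misses <= 1, f"recompiled {stats.cache_misses} times"
    # Manually inspecting the final trace of the last call should show a
    # symbolic extent for seq_len rather than the literal 1024 or 1.
    print(thunder.last_traces(jitted_fn)[-1])

# Run after the prefill + decode loop above:
# check_single_compilation(kv_cache_update)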

Performance Target

Should match or exceed torch.compile performance for this pattern.
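A minimal benchmark sketch for this target, assuming the setup from the test case above. kv_cache_update_eager is a hypothetical undecorated copy of kv_cache_update used as the torch.compile baseline; the timing harness and the 10% tolerance are illustrative, not part of the proposed test.

import torch

def time_decode_steps(fn, steps=128, warmup=16):
    # Times the per-token decode path with CUDA events; warmup iterations
    # absorb compilation and caching overhead.
    kv_cache = torch.zeros(1, 8, 1056, 128, device='cuda', dtype=torch.bfloat16)
    cumulative_length = 0
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for i in range(warmup + steps):
        if i == warmup:
            torch.cuda.synchronize()
            start.record()
        ids = torch.randint(0, 50000, (1, 1), device='cuda')
        kv_cache, cumulative_length, _ = fn(ids, cumulative_length, kv_cache)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / steps  # ms per decode step

# thunder_ms = time_decode_steps(kv_cache_update)
# inductor_ms = time_decode_steps(torch.compile(kv_cache_update_eager, dynamic=True))
# assert thunder_ms <= 1.1 * inductor_ms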

Documentation

  • Once passing, add to Thunder documentation as an example of:
    • Dynamic shape handling
    • KV-cache management patterns
