🚀 Feature
Create a comprehensive integration test that demonstrates all symbolic value features working together in a realistic transformer KV-cache management scenario.
Depends On
- Support symbolic integer scalars #2735
- Support symbolic arithmetic operations #2736
- Support torch.sym_max/sym_min #2737
- Support symbolic values in torch.arange and other tensor operations #2738
- Support symbolic tensor shapes and dynamic dimensions #2741
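Taken together, these dependencies map onto a minimal function like the sketch below. This is hypothetical until the issues above land; `symbolic_patterns` and its arguments are illustrative names, not existing Thunder API.

```python
import torch
import thunder

# Hypothetical sketch: each line exercises one of the dependencies listed above.
@thunder.jit
def symbolic_patterns(x: torch.Tensor, offset: int) -> torch.Tensor:
    n = x.shape[0]                                    # symbolic integer scalar (#2735)
    end = offset + n                                  # symbolic arithmetic (#2736)
    start = max(offset - 4, 0)                        # Python max over symbolic ints -> sym_max (#2737)
    idx = torch.arange(start, end, device=x.device)   # torch.arange with symbolic bounds (#2738)
    return x.new_zeros(end)[idx]                      # tensor with a symbolic shape (#2741)
```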
Test Case
```python
import torch
import thunder

@thunder.jit
def kv_cache_update(
    input_ids: torch.Tensor,      # [1, seq_len] - seq_len is symbolic
    cumulative_length: int,       # symbolic scalar
    kv_cache: torch.Tensor,       # [1, 8, 1056, 128]
    sliding_window: int = 1056,
) -> tuple[torch.Tensor, int, torch.Tensor]:
    """
    Realistic KV-cache update demonstrating all symbolic patterns.
    This mirrors the pattern from the Llama-4 model.
    """
    batch_size, seq_len = input_ids.shape

    # Pattern 1: Symbolic arithmetic
    new_cumulative_length = cumulative_length + seq_len

    # Pattern 2: torch.arange with symbolic bounds
    cache_position = torch.arange(
        cumulative_length,
        new_cumulative_length,
        device=input_ids.device
    )

    # Pattern 3: sym_max for sliding window
    kv_offset = max(cumulative_length - sliding_window + 1, 0)

    # Pattern 4: Symbolic tensor shapes
    # Compute query states with symbolic seq_len dimension
    hidden_dim = 5120
    embedding_weight = torch.randn(
        50000, hidden_dim,
        device='cuda',
        dtype=torch.bfloat16
    )
    hidden_states = torch.nn.functional.embedding(input_ids, embedding_weight)
    # hidden_states.shape = [1, s50, 5120]

    # Compute keys for cache update
    key_proj = torch.randn(1024, hidden_dim, device='cuda', dtype=torch.bfloat16)
    keys = torch.nn.functional.linear(hidden_states, key_proj)
    # keys.shape = [1, s50, 1024]
    keys = keys.view(batch_size, seq_len, 8, 128).transpose(1, 2)
    # keys.shape = [1, 8, s50, 128]

    # Update KV-cache using symbolic cache_position
    kv_cache = kv_cache.index_copy_(2, cache_position, keys)

    return kv_cache, new_cumulative_length, cache_position


kv_cache = torch.zeros(1, 8, 1056, 128, device='cuda', dtype=torch.bfloat16)
cumulative_length = 0

# Prefill phase - long sequence
input_ids_prefill = torch.randint(0, 50000, (1, 1024), device='cuda')
kv_cache, cumulative_length, cache_pos = kv_cache_update(
    input_ids_prefill, cumulative_length, kv_cache
)
assert cumulative_length == 1024
assert cache_pos.shape == (1024,)

# Decode phase - single tokens
for _ in range(32):
    input_ids_decode = torch.randint(0, 50000, (1, 1), device='cuda')
    kv_cache, cumulative_length, cache_pos = kv_cache_update(
        input_ids_decode, cumulative_length, kv_cache
    )
    assert cache_pos.shape == (1,)

assert cumulative_length == 1056
```
Validation Criteria
[ ] Test compiles successfully with @thunder.jit
[ ] Works with variable sequence lengths without recompilation (see the sketch after this list)
[ ] Performance comparable to torch.compile
[ ] Symbolic values tracked correctly throughout
[ ] Generated code is efficient without recompilation
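A sketch of how the variable-sequence-length criterion could be exercised on top of the test case above. It only checks the deterministic outputs (`cache_position` and the running length), since the projection weights inside `kv_cache_update` are random; verifying the "without recompilation" part would additionally need Thunder's trace/cache introspection, which is not shown here.

```python
import torch

kv_cache = torch.zeros(1, 8, 1056, 128, device='cuda', dtype=torch.bfloat16)
cumulative_length = 0
for seq_len in (1024, 1, 7, 16, 1):   # prefill followed by mixed-size decode steps
    input_ids = torch.randint(0, 50000, (1, seq_len), device='cuda')
    kv_cache, new_length, cache_pos = kv_cache_update(input_ids, cumulative_length, kv_cache)
    # The running length and cache positions are deterministic, so they can be
    # checked exactly for every sequence length.
    assert new_length == cumulative_length + seq_len
    assert torch.equal(cache_pos, torch.arange(cumulative_length, new_length, device='cuda'))
    cumulative_length = new_length
```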
Performance Target
Should match or exceed torch.compile performance for this pattern.
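A sketch of how the comparison could be measured for the latency-critical single-token decode step. It assumes a `kv_cache_update_eager` function with the same body as the test case above but without the `@thunder.jit` decorator, so that torch.compile can wrap it; the helper names are illustrative.

```python
import torch
import torch.utils.benchmark as benchmark

# Assumption: kv_cache_update_eager is the undecorated version of kv_cache_update.
kv_cache_update_tc = torch.compile(kv_cache_update_eager, dynamic=True)

def bench(fn, label):
    kv = torch.zeros(1, 8, 1056, 128, device='cuda', dtype=torch.bfloat16)
    ids = torch.randint(0, 50000, (1, 1), device='cuda')  # single-token decode step
    fn(ids, 1024, kv)  # warm-up / trigger compilation
    timer = benchmark.Timer(
        stmt="fn(ids, 1024, kv)",
        globals={"fn": fn, "ids": ids, "kv": kv},
        label=label,
    )
    print(timer.timeit(100))

bench(kv_cache_update, "thunder.jit")      # jitted test-case function from above
bench(kv_cache_update_tc, "torch.compile")
```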
Documentation
- Once passing, add to Thunder documentation as an example of:
  - Dynamic shape handling
  - KV-cache management patterns