diff --git a/pkg/plugins/gateway/scheduler/sessioninfo/cache.go b/pkg/plugins/gateway/scheduler/sessioninfo/cache.go new file mode 100644 index 000000000..650914901 --- /dev/null +++ b/pkg/plugins/gateway/scheduler/sessioninfo/cache.go @@ -0,0 +1,87 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sessioninfo + +import "time" + +// SessionCache is the interface for session state management. +// It provides thread-safe operations for tracking session states +// across multiple concurrent requests. +// +// Implementations: +// - MutexSessionCache: Simple mutex-based implementation for low to medium concurrency +// - ShardedSessionCache: High-performance sharded implementation for high concurrency +type SessionCache interface { + // GetOrCreateForScheduler returns the inherited CST and total wait time + // for a new job in the given session. If the session doesn't exist, + // it will be created with zero values. + // + // This is the primary method used by the scheduler to get scheduling + // information for a new request. + // + // Returns: + // - cst: The critical path service time of the session + // - waitTime: The total accumulated wait time of the session + GetOrCreateForScheduler(sessionID string) (cst, waitTime time.Duration) + + // UpdateState atomically updates the session state after a request completes. + // + // The update logic follows the ATLAS algorithm: + // - TotalWaitTime is accumulated: totalWaitTime += waitTime + // - CriticalPathServiceTime is updated to max(current, inheritedCST + executionTime) + // + // Parameters: + // - sessionID: The session identifier + // - inheritedCST: The CST value inherited when the request started + // - executionTime: The actual execution time of this request + // - waitTime: The time this request spent waiting in the queue + UpdateState(sessionID string, inheritedCST, executionTime, waitTime time.Duration) + + // UpdateAffinity updates the pod affinity hint for a session. + // This can be used to optimize cache hits by routing subsequent + // requests from the same session to the same pod. + // + // This method also updates the LastActivityTimestamp to prevent + // the session from being cleaned up by the cleanup routine. + // + // Parameters: + // - sessionID: The session identifier + // - podName: The name of the pod to set as affinity hint + UpdateAffinity(sessionID, podName string) + + // GetState retrieves a copy of the full session state. + // This method is primarily used for testing and debugging. + // + // Returns: + // - state: A copy of the session state + // - exists: false if the session doesn't exist + GetState(sessionID string) (state SessionState, exists bool) + + // StartCleanupRoutine starts a background goroutine that periodically + // cleans up stale sessions that have been inactive for longer than timeout. + // + // The cleanup runs at the specified interval. Calling this method multiple + // times will start multiple cleanup routines (caller should avoid this). 
+ // + // Parameters: + // - interval: How often to run the cleanup + // - timeout: Sessions inactive for longer than this will be removed + // + // Returns: + // - stop: A function that stops the cleanup routine when called + StartCleanupRoutine(interval, timeout time.Duration) (stop func()) +} diff --git a/pkg/plugins/gateway/scheduler/sessioninfo/cache_mutex.go b/pkg/plugins/gateway/scheduler/sessioninfo/cache_mutex.go new file mode 100644 index 000000000..08f19a21e --- /dev/null +++ b/pkg/plugins/gateway/scheduler/sessioninfo/cache_mutex.go @@ -0,0 +1,158 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sessioninfo + +import ( + "sync" + "time" +) + +// SessionState holds all the scheduling-relevant information for a single session +type SessionState struct { + SessionID string // The session ID + CriticalPathServiceTime time.Duration // The critical path service time + TotalWaitTime time.Duration // The total wait time (anti-starvation) + PodAffinity string // The pod affinity (later may needed) + LastActivityTimestamp time.Time // The last activity timestamp +} + +// MutexSessionCache is a thread-safe, in-memory store for session states +// using a sync.RWMutex. +type MutexSessionCache struct { + mu sync.RWMutex // Protects the sessions map + sessions map[string]*SessionState // sessionID -> *SessionState +} + +// NewMutexSessionCache creates a new in-memory session cache protected by a mutex. +func NewMutexSessionCache() *MutexSessionCache { + return &MutexSessionCache{ + sessions: make(map[string]*SessionState), + } +} + +// getState is a private helper that assumes a write lock is already held. +// It ensures a session state exists before any operation. +func (sc *MutexSessionCache) getState(sessionID string) *SessionState { + state, exists := sc.sessions[sessionID] + if !exists { + state = &SessionState{ + SessionID: sessionID, + LastActivityTimestamp: time.Now(), + } + sc.sessions[sessionID] = state + } + return state +} + +// GetState retrieves a copy of the state for a given sessionID +// for read-only purposes. +// It returns false if the session does not exist. +func (sc *MutexSessionCache) GetState(sessionID string) (SessionState, bool) { + sc.mu.RLock() + defer sc.mu.RUnlock() + + state, exists := sc.sessions[sessionID] + if !exists { + return SessionState{}, false + } + + // Return a copy to ensure + // the caller cannot modify the internal state without a lock, + // which would cause a data race. + return *state, true +} + +// GetOrCreateForScheduler is the primary method for the scheduler +// to get the necessary info. +// It returns the inherited CST and total wait time for a new job. +func (sc *MutexSessionCache) GetOrCreateForScheduler(sessionID string) ( + time.Duration, time.Duration) { + sc.mu.Lock() // Use a write lock because we might create a session. 
+	defer sc.mu.Unlock()
+
+	state := sc.getState(sessionID)
+	return state.CriticalPathServiceTime, state.TotalWaitTime
+}
+
+// UpdateState atomically updates the session state after a request is finished.
+func (sc *MutexSessionCache) UpdateState(sessionID string, inheritedCST,
+	executionTime, waitTime time.Duration) {
+	sc.mu.Lock()
+	defer sc.mu.Unlock()
+
+	state := sc.getState(sessionID)
+
+	// Accumulate the total wait time.
+	state.TotalWaitTime += waitTime
+
+	// Update CriticalPathServiceTime to max(current, inheritedCST+executionTime) (ATLAS logic).
+	newPathLength := inheritedCST + executionTime
+	if newPathLength > state.CriticalPathServiceTime {
+		state.CriticalPathServiceTime = newPathLength
+	}
+
+	state.LastActivityTimestamp = time.Now()
+}
+
+// UpdateAffinity updates the pod affinity for a session.
+func (sc *MutexSessionCache) UpdateAffinity(sessionID, podName string) {
+	sc.mu.Lock()
+	defer sc.mu.Unlock()
+
+	state := sc.getState(sessionID)
+	state.PodAffinity = podName
+	state.LastActivityTimestamp = time.Now()
+}
+
+// StartCleanupRoutine starts a background goroutine that periodically
+// cleans up stale sessions.
+// It returns a function that can be called to stop the routine.
+func (sc *MutexSessionCache) StartCleanupRoutine(interval,
+	timeout time.Duration) (stop func()) {
+	ticker := time.NewTicker(interval)
+	done := make(chan struct{})
+
+	go func() {
+		for {
+			select {
+			case <-ticker.C:
+				sc.cleanup(timeout)
+			case <-done:
+				ticker.Stop()
+				return
+			}
+		}
+	}()
+
+	return func() {
+		close(done)
+	}
+}
+
+// cleanup removes sessions that have been inactive for longer than the timeout.
+// It acquires the write lock itself, so callers must not hold the lock.
+func (sc *MutexSessionCache) cleanup(timeout time.Duration) {
+	sc.mu.Lock()
+	defer sc.mu.Unlock()
+
+	now := time.Now()
+	for sessionID, state := range sc.sessions {
+		if now.Sub(state.LastActivityTimestamp) > timeout {
+			delete(sc.sessions, sessionID)
+		}
+	}
+}
diff --git a/pkg/plugins/gateway/scheduler/sessioninfo/cache_mutex_test.go b/pkg/plugins/gateway/scheduler/sessioninfo/cache_mutex_test.go
new file mode 100644
index 000000000..1e0a65a65
--- /dev/null
+++ b/pkg/plugins/gateway/scheduler/sessioninfo/cache_mutex_test.go
@@ -0,0 +1,134 @@
+/*
+Copyright 2025 The Aibrix Team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package sessioninfo
+
+import (
+	"fmt"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// TestMutexCache_GetOrCreateForScheduler_NewSession verifies that a new session starts with zero CST and wait time.
+func TestMutexCache_GetOrCreateForScheduler_NewSession(t *testing.T) {
+	cache := NewMutexSessionCache()
+	cst, waitTime := cache.GetOrCreateForScheduler("session1")
+
+	assert.Equal(t, time.Duration(0), cst)
+	assert.Equal(t, time.Duration(0), waitTime)
+}
+
+// TestMutexCache_UpdateState_Single exercises UpdateState with two serial requests.
+func TestMutexCache_UpdateState_Single(t *testing.T) {
+	cache := NewMutexSessionCache()
+
+	// First update (like a serial request)
+	cache.UpdateState("session1", 0, 5*time.Second, 2*time.Second)
+	state, _ := cache.GetState("session1")
+	assert.Equal(t, 5*time.Second, state.CriticalPathServiceTime)
+	assert.Equal(t, 2*time.Second, state.TotalWaitTime)
+
+	// Second update (another serial request)
+	// InheritedCST should be the CST from the previous state (5s)
+	cache.UpdateState("session1", 5*time.Second, 3*time.Second, 1*time.Second)
+	state, _ = cache.GetState("session1")
+	assert.Equal(t, 8*time.Second, state.CriticalPathServiceTime) // 5s + 3s
+	assert.Equal(t, 3*time.Second, state.TotalWaitTime)           // 2s + 1s
+}
+
+// TestMutexCache_UpdateState_Concurrent exercises UpdateState under heavy concurrency.
+func TestMutexCache_UpdateState_Concurrent(t *testing.T) {
+	cache := NewMutexSessionCache()
+	concurrency := 1000
+	var wg sync.WaitGroup
+	wg.Add(concurrency)
+
+	// Simulate 1000 parallel requests for the same session finishing.
+	// All inherit CST=0, as they started when the session's CST was 0.
+	for i := 0; i < concurrency; i++ {
+		go func(execTimeMs int) {
+			defer wg.Done()
+			cache.UpdateState("session1", 0,
+				time.Duration(execTimeMs)*time.Millisecond,
+				10*time.Millisecond)
+		}(i + 1)
+	}
+	wg.Wait()
+
+	state, exists := cache.GetState("session1")
+	assert.True(t, exists)
+
+	// Final CST should be the max of all new path lengths,
+	// which is max(0+1ms, 0+2ms, ..., 0+1000ms) = 1000ms
+	assert.Equal(t, 1000*time.Millisecond, state.CriticalPathServiceTime)
+
+	// Total wait time should be the sum of all wait times:
+	// 1000 * 10ms = 10000ms
+	assert.Equal(t, 10000*time.Millisecond, state.TotalWaitTime)
+}
+
+// TestMutexCache_UpdateAffinity_Concurrent tests UpdateAffinity with concurrent writers.
+func TestMutexCache_UpdateAffinity_Concurrent(t *testing.T) {
+	cache := NewMutexSessionCache()
+	concurrency := 10
+	var wg sync.WaitGroup
+	wg.Add(concurrency)
+
+	for i := 0; i < concurrency; i++ {
+		go func(podNum int) {
+			defer wg.Done()
+			cache.UpdateAffinity("session1",
+				fmt.Sprintf("pod%d", podNum))
+		}(i)
+	}
+	wg.Wait()
+
+	state, exists := cache.GetState("session1")
+	assert.True(t, exists)
+	// The winning goroutine is nondeterministic, so we can't know the final value,
+	// but it must be one of the values we set.
+	assert.Contains(t, []string{"pod0", "pod1", "pod2", "pod3",
+		"pod4", "pod5", "pod6", "pod7", "pod8", "pod9"},
+		state.PodAffinity)
+}
+
+// TestMutexCache_Cleanup calls the cleanup helper directly to verify stale-session removal.
+func TestMutexCache_Cleanup(t *testing.T) { + cache := NewMutexSessionCache() + + // Create session1 + cache.UpdateState("session1", 0, 1*time.Second, 0) + + // Wait for 2 seconds, making session1 stale relative to a 1.5s timeout + time.Sleep(2 * time.Second) + + // Create/update session2, making it fresh + cache.UpdateState("session2", 0, 1*time.Second, 0) + + // Now, cleanup sessions older than 1.5 seconds + cache.cleanup(1500 * time.Millisecond) + + // session1 should be gone because it's ~2 seconds old + _, exists := cache.GetState("session1") + assert.False(t, exists, "session1 should be stale and cleaned up") + + // session2 should still exist because it's very fresh + _, exists = cache.GetState("session2") + assert.True(t, exists, "session2 should be fresh and remain") +} diff --git a/pkg/plugins/gateway/scheduler/sessioninfo/cache_sharded.go b/pkg/plugins/gateway/scheduler/sessioninfo/cache_sharded.go new file mode 100644 index 000000000..07f7be3ce --- /dev/null +++ b/pkg/plugins/gateway/scheduler/sessioninfo/cache_sharded.go @@ -0,0 +1,316 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sessioninfo + +import ( + "hash" + "hash/fnv" + "sync" + "time" +) + +var hasherPool = sync.Pool{ + New: func() interface{} { + return fnv.New64a() + }, +} + +// --- Internal channel communication structs --- +type cacheOp int // operation code for cacheRequest + +const ( + opGetForScheduler cacheOp = iota + opGetFullState + opUpdateState + opUpdateAffinity + opCleanup +) + +// cacheRequest is the message format for all shard channels. +type cacheRequest struct { + op cacheOp + sessionID string + updatePayload updatePayload + affinityPayload string + cleanupPayload cleanupPayload + schedulerInfoRespChan chan schedulerInfoResponse + fullStateResponseChan chan fullStateResponse +} + +// updatePayload is the payload for opUpdateState +type updatePayload struct { + inheritedCST time.Duration + executionTime time.Duration + waitTime time.Duration +} + +// schedulerInfoResponse is the response for opGetForScheduler +type schedulerInfoResponse struct { + cst time.Duration + waitTime time.Duration +} + +// cleanupPayload is the payload for opCleanup +type cleanupPayload struct { + timeout time.Duration +} + +// fullStateResponse is the response for opGetFullState +type fullStateResponse struct { + state *SessionState +} + +// --- Shard and ShardedCache implementation --- + +// shardCount is the number of independent shards used to reduce lock contention +// in high-concurrency scenarios. +// +// Purpose: +// - Reduces lock contention probability to ~1/shardCount +// - Enables parallel processing: multiple sessions can be accessed simultaneously +// - Each shard has its own goroutine and processes requests via a channel (Actor Model) +// - No locks needed within each shard (single-threaded access to its map) +// +// Why 256? 
+// - Power of 2: Enables fast bitwise AND for hash-to-shard mapping (hash & 255) +// - Balanced: Not too small (still contention) or too large (goroutine overhead) +// - Multi-core friendly: Modern servers have 32-128 cores; 256 shards can fully utilize them +// +// Performance Trade-offs: +// - Higher values: Better concurrency, more goroutines/memory overhead +// - Lower values: Less overhead, more lock contention +// - Recommended range: 64-512 (must be power of 2) +// +// When to adjust: +// - Increase (e.g., 512) for extremely high concurrency (10K+ QPS) +// - Decrease (e.g., 64, 128) for lower concurrency or memory-constrained environments +// +// Note: This is for single-process concurrency optimization, NOT distributed partitioning. +const shardCount = 256 // Must be a power of 2 for bitwise AND optimization + +// cacheShard is a single shard of the sharded cache. +type cacheShard struct { + sessions map[string]*SessionState // sessionID -> *SessionState + requests chan cacheRequest // Channel for requests to this shard + done chan struct{} // Channel for shutdown +} + +// run is the main loop for each shard goroutine. +func (s *cacheShard) run(wg *sync.WaitGroup) { + defer wg.Done() + for req := range s.requests { + switch req.op { + case opGetForScheduler: + // Get or create session for scheduler + state, exists := s.sessions[req.sessionID] + if !exists { + state = &SessionState{ + SessionID: req.sessionID, + LastActivityTimestamp: time.Now(), + } + s.sessions[req.sessionID] = state + } + req.schedulerInfoRespChan <- schedulerInfoResponse{ + cst: state.CriticalPathServiceTime, + waitTime: state.TotalWaitTime, + } + case opGetFullState: + // Only return existing state, don't create + state, exists := s.sessions[req.sessionID] + if !exists { + req.fullStateResponseChan <- fullStateResponse{state: nil} + continue + } + stateCopy := *state + req.fullStateResponseChan <- fullStateResponse{state: &stateCopy} + case opUpdateState: + // Get or create session for update + state, exists := s.sessions[req.sessionID] + if !exists { + state = &SessionState{ + SessionID: req.sessionID, + LastActivityTimestamp: time.Now(), + } + s.sessions[req.sessionID] = state + } + payload := req.updatePayload + state.TotalWaitTime += payload.waitTime + newPathLength := payload.inheritedCST + payload.executionTime + if newPathLength > state.CriticalPathServiceTime { + state.CriticalPathServiceTime = newPathLength + } + state.LastActivityTimestamp = time.Now() + case opUpdateAffinity: + // Get or create session for affinity update + state, exists := s.sessions[req.sessionID] + if !exists { + state = &SessionState{ + SessionID: req.sessionID, + LastActivityTimestamp: time.Now(), + } + s.sessions[req.sessionID] = state + } + state.PodAffinity = req.affinityPayload + state.LastActivityTimestamp = time.Now() + case opCleanup: + payload := req.cleanupPayload + now := time.Now() + for sessionID, state := range s.sessions { + if now.Sub(state.LastActivityTimestamp) > payload.timeout { + delete(s.sessions, sessionID) + } + } + } + } +} + +// ShardedSessionCache is a highly concurrent, channel-based session cache. +type ShardedSessionCache struct { + shards []*cacheShard + wg sync.WaitGroup +} + +// NewShardedSessionCache creates and starts all shard goroutines. 
+func NewShardedSessionCache() *ShardedSessionCache { + sc := &ShardedSessionCache{ + shards: make([]*cacheShard, shardCount), + } + for i := 0; i < shardCount; i++ { + shard := &cacheShard{ + sessions: make(map[string]*SessionState), + requests: make(chan cacheRequest, 128), // Buffered channel per shard + done: make(chan struct{}), + } + sc.shards[i] = shard + sc.wg.Add(1) + go shard.run(&sc.wg) + } + return sc +} + +// getShard returns the shard for a given sessionID. +func (sc *ShardedSessionCache) getShard(sessionID string) *cacheShard { + hasher := hasherPool.Get().(hash.Hash64) + defer hasherPool.Put(hasher) + hasher.Reset() + hasher.Write([]byte(sessionID)) + return sc.shards[hasher.Sum64()&uint64(shardCount-1)] +} + +// --- Public API --- + +// GetOrCreateForScheduler is the primary method for the scheduler +func (sc *ShardedSessionCache) GetOrCreateForScheduler(sessionID string) (time.Duration, time.Duration) { + shard := sc.getShard(sessionID) + respChan := make(chan schedulerInfoResponse, 1) + shard.requests <- cacheRequest{ + op: opGetForScheduler, + sessionID: sessionID, + schedulerInfoRespChan: respChan, + } + info := <-respChan + return info.cst, info.waitTime +} + +// UpdateState is the primary method for the executor +func (sc *ShardedSessionCache) UpdateState(sessionID string, inheritedCST, executionTime, waitTime time.Duration) { + shard := sc.getShard(sessionID) + shard.requests <- cacheRequest{ + op: opUpdateState, + sessionID: sessionID, + updatePayload: updatePayload{ + inheritedCST: inheritedCST, + executionTime: executionTime, + waitTime: waitTime, + }, + } +} + +// UpdateAffinity is the primary method for the executor +func (sc *ShardedSessionCache) UpdateAffinity(sessionID, podName string) { + shard := sc.getShard(sessionID) + shard.requests <- cacheRequest{ + op: opUpdateAffinity, + sessionID: sessionID, + affinityPayload: podName, + } +} + +// GetState is provided for testing and debugging. +// Returns a copy of the session state to ensure thread safety. +func (sc *ShardedSessionCache) GetState(sessionID string) (SessionState, bool) { + shard := sc.getShard(sessionID) + respChan := make(chan fullStateResponse, 1) + shard.requests <- cacheRequest{ + op: opGetFullState, + sessionID: sessionID, + fullStateResponseChan: respChan, + } + info := <-respChan + if info.state == nil { + return SessionState{}, false + } + // Return a copy to match the interface signature + return *info.state, true +} + +// StartCleanupRoutine starts a background goroutine that periodically +// cleans up stale sessions across all shards. +// Returns a stop function that can be called to halt the cleanup routine. +func (sc *ShardedSessionCache) StartCleanupRoutine(interval, timeout time.Duration) (stop func()) { + ticker := time.NewTicker(interval) + done := make(chan struct{}) + + go func() { + defer ticker.Stop() + for { + select { + case <-ticker.C: + // Send cleanup request to all shards + req := cacheRequest{ + op: opCleanup, + cleanupPayload: cleanupPayload{ + timeout: timeout, + }, + } + for _, shard := range sc.shards { + // Use select to avoid panic if channel is closed + select { + case shard.requests <- req: + case <-done: + // Stop signal received while sending + return + } + } + case <-done: + return + } + } + }() + + return func() { + close(done) + } +} + +// Close shuts down all shard goroutines, not elegantly yet. 
+func (sc *ShardedSessionCache) Close() { + for _, shard := range sc.shards { + close(shard.requests) + } + sc.wg.Wait() +} diff --git a/pkg/plugins/gateway/scheduler/sessioninfo/cache_sharded_test.go b/pkg/plugins/gateway/scheduler/sessioninfo/cache_sharded_test.go new file mode 100644 index 000000000..07e9447aa --- /dev/null +++ b/pkg/plugins/gateway/scheduler/sessioninfo/cache_sharded_test.go @@ -0,0 +1,150 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sessioninfo + +import ( + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// The sharded tests are structurally similar to the mutex tests, +// as they are testing the same public API contract. + +func TestShardedCache_GetOrCreateForScheduler_NewSession(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + cst, waitTime := cache.GetOrCreateForScheduler("session1") + + assert.Equal(t, time.Duration(0), cst) + assert.Equal(t, time.Duration(0), waitTime) +} + +func TestShardedCache_UpdateState_Single(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + cache.UpdateState("session1", 0, 5*time.Second, 2*time.Second) + assert.Eventually(t, func() bool { + cst, wait := cache.GetOrCreateForScheduler("session1") + return cst == 5*time.Second && wait == 2*time.Second + }, 250*time.Millisecond, 5*time.Millisecond) + + cache.UpdateState("session1", 5*time.Second, 3*time.Second, 1*time.Second) + assert.Eventually(t, func() bool { + cst, wait := cache.GetOrCreateForScheduler("session1") + return cst == 8*time.Second && wait == 3*time.Second + }, 250*time.Millisecond, 5*time.Millisecond) +} + +func TestShardedCache_UpdateState_Concurrent(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + concurrency := 100 + var done atomic.Int32 + for i := 0; i < concurrency; i++ { + go func(execTimeMs int) { + cache.UpdateState("session1", 0, time.Duration(execTimeMs)*time.Millisecond, 10*time.Millisecond) + done.Add(1) + }(i + 1) + } + assert.Eventually(t, func() bool { + return done.Load() == int32(concurrency) + }, time.Second, 10*time.Millisecond) + assert.Eventually(t, func() bool { + cst, wait := cache.GetOrCreateForScheduler("session1") + return cst == 100*time.Millisecond && wait == 1000*time.Millisecond + }, 250*time.Millisecond, 5*time.Millisecond) +} + +func TestShardedCache_UpdateAffinity_Concurrent(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + // This test is harder to write for the channel version without a proper GetState that returns affinity. + // We'll skip the detailed check for now as the main purpose is to test the update mechanism. 
+ concurrency := 10 + var wg sync.WaitGroup + wg.Add(concurrency) + + for i := 0; i < concurrency; i++ { + go func(podNum int) { + defer wg.Done() + cache.UpdateAffinity("session1", fmt.Sprintf("pod%d", podNum)) + }(i) + } + wg.Wait() + time.Sleep(50 * time.Millisecond) + + // We can't easily verify the result without a full GetState op. + // This test mainly serves to ensure no deadlocks occur. + t.Log("Concurrent affinity update test completed without deadlock.") +} + +func TestShardedCache_GetFullState(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + cache.UpdateState("session1", 0, 5*time.Second, 2*time.Second) + state, exists := cache.GetState("session1") + assert.True(t, exists) + assert.Equal(t, "session1", state.SessionID) + assert.Equal(t, 5*time.Second, state.CriticalPathServiceTime) + assert.Equal(t, 2*time.Second, state.TotalWaitTime) +} + +func TestShardedCache_Cleanup(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + // Ensure session1 and session2 hash to different shards if possible, or test with one + // For simplicity, we assume they might hash to the same shard, which is a valid test case. + + // Create session1 + cache.UpdateState("session1", 0, 1*time.Second, 0) + time.Sleep(10 * time.Millisecond) // wait for channel to process + + // Wait to make session1 stale + time.Sleep(2 * time.Second) + + // Create session2, making it fresh + cache.UpdateState("session2", 0, 1*time.Second, 0) + time.Sleep(10 * time.Millisecond) + + // Start cleanup routine with short interval + stop := cache.StartCleanupRoutine(100*time.Millisecond, 1500*time.Millisecond) + defer stop() + + // Wait for cleanup to run + time.Sleep(200 * time.Millisecond) + + // Check that session1 was cleaned up (should not exist) + state1, exists1 := cache.GetState("session1") + assert.False(t, exists1, "session1 should have been cleaned up") + assert.Equal(t, SessionState{}, state1, "Cleaned session should return zero value") + + // Check that session2 still exists (fresh) + state2, exists2 := cache.GetState("session2") + assert.True(t, exists2, "session2 should still exist") + assert.NotEqual(t, time.Duration(0), state2.CriticalPathServiceTime, "session2 should have non-zero CST") +} diff --git a/pkg/plugins/gateway/scheduler/sessioninfo/cache_test.go b/pkg/plugins/gateway/scheduler/sessioninfo/cache_test.go new file mode 100644 index 000000000..0419c4d24 --- /dev/null +++ b/pkg/plugins/gateway/scheduler/sessioninfo/cache_test.go @@ -0,0 +1,278 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package sessioninfo + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestInterfaceCompliance verifies that both implementations satisfy the SessionCache interface +func TestInterfaceCompliance(t *testing.T) { + var _ SessionCache = (*MutexSessionCache)(nil) + var _ SessionCache = (*ShardedSessionCache)(nil) + t.Log("Both MutexSessionCache and ShardedSessionCache implement SessionCache interface") +} + +// cacheTestSuite runs a comprehensive test suite against any SessionCache implementation +func cacheTestSuite(t *testing.T, name string, factory func() SessionCache, needsClose bool) { + t.Run(name, func(t *testing.T) { + t.Run("GetOrCreateForScheduler_NewSession", func(t *testing.T) { + cache := factory() + if needsClose { + defer cache.(*ShardedSessionCache).Close() + } + + cst, waitTime := cache.GetOrCreateForScheduler("session1") + assert.Equal(t, time.Duration(0), cst, "New session should have zero CST") + assert.Equal(t, time.Duration(0), waitTime, "New session should have zero wait time") + }) + + t.Run("UpdateState_SingleRequest", func(t *testing.T) { + cache := factory() + if needsClose { + defer cache.(*ShardedSessionCache).Close() + } + + // First request + cache.UpdateState("session1", 0, 5*time.Second, 2*time.Second) + + // For sharded cache, need to wait for async processing + if needsClose { + time.Sleep(50 * time.Millisecond) + } + + state, exists := cache.GetState("session1") + require.True(t, exists, "Session should exist after update") + assert.Equal(t, 5*time.Second, state.CriticalPathServiceTime, "CST should be 5s") + assert.Equal(t, 2*time.Second, state.TotalWaitTime, "Wait time should be 2s") + + // Second request (serial) + cache.UpdateState("session1", 5*time.Second, 3*time.Second, 1*time.Second) + + if needsClose { + time.Sleep(50 * time.Millisecond) + } + + state, exists = cache.GetState("session1") + require.True(t, exists) + assert.Equal(t, 8*time.Second, state.CriticalPathServiceTime, "CST should be 5s + 3s = 8s") + assert.Equal(t, 3*time.Second, state.TotalWaitTime, "Wait time should be 2s + 1s = 3s") + }) + + t.Run("UpdateState_ConcurrentRequests", func(t *testing.T) { + cache := factory() + if needsClose { + defer cache.(*ShardedSessionCache).Close() + } + + concurrency := 100 + var wg sync.WaitGroup + wg.Add(concurrency) + + // Simulate 100 parallel requests with different execution times + for i := 0; i < concurrency; i++ { + go func(execTimeMs int) { + defer wg.Done() + cache.UpdateState("session1", 0, + time.Duration(execTimeMs)*time.Millisecond, + 10*time.Millisecond) + }(i + 1) + } + wg.Wait() + + // Wait for async processing if needed + if needsClose { + time.Sleep(100 * time.Millisecond) + } + + state, exists := cache.GetState("session1") + require.True(t, exists) + + // CST should be max of all execution times: max(1ms, 2ms, ..., 100ms) = 100ms + assert.Equal(t, 100*time.Millisecond, state.CriticalPathServiceTime, + "CST should be the maximum execution time") + + // Total wait time should be sum: 100 * 10ms = 1000ms + assert.Equal(t, 1000*time.Millisecond, state.TotalWaitTime, + "Total wait time should be sum of all wait times") + }) + + t.Run("UpdateAffinity", func(t *testing.T) { + cache := factory() + if needsClose { + defer cache.(*ShardedSessionCache).Close() + } + + cache.UpdateAffinity("session1", "pod-1") + + if needsClose { + time.Sleep(50 * time.Millisecond) + } + + state, exists := cache.GetState("session1") + require.True(t, exists) + 
assert.Equal(t, "pod-1", state.PodAffinity, "Pod affinity should be set") + + // Update to different pod + cache.UpdateAffinity("session1", "pod-2") + + if needsClose { + time.Sleep(50 * time.Millisecond) + } + + state, exists = cache.GetState("session1") + require.True(t, exists) + assert.Equal(t, "pod-2", state.PodAffinity, "Pod affinity should be updated") + }) + + t.Run("GetState_NonExistentSession", func(t *testing.T) { + cache := factory() + if needsClose { + defer cache.(*ShardedSessionCache).Close() + } + + state, exists := cache.GetState("non-existent") + assert.False(t, exists, "Non-existent session should return false") + assert.Equal(t, SessionState{}, state, "Should return zero value for non-existent session") + }) + + t.Run("MultipleSessionsIsolation", func(t *testing.T) { + cache := factory() + if needsClose { + defer cache.(*ShardedSessionCache).Close() + } + + // Update multiple sessions + cache.UpdateState("session1", 0, 5*time.Second, 1*time.Second) + cache.UpdateState("session2", 0, 10*time.Second, 2*time.Second) + cache.UpdateState("session3", 0, 3*time.Second, 1*time.Second) + + if needsClose { + time.Sleep(50 * time.Millisecond) + } + + // Verify each session has independent state + state1, _ := cache.GetState("session1") + state2, _ := cache.GetState("session2") + state3, _ := cache.GetState("session3") + + assert.Equal(t, 5*time.Second, state1.CriticalPathServiceTime) + assert.Equal(t, 10*time.Second, state2.CriticalPathServiceTime) + assert.Equal(t, 3*time.Second, state3.CriticalPathServiceTime) + }) + + t.Run("StartCleanupRoutine", func(t *testing.T) { + cache := factory() + if needsClose { + defer cache.(*ShardedSessionCache).Close() + } + + // Create a session + cache.UpdateState("session1", 0, 1*time.Second, 0) + + if needsClose { + time.Sleep(50 * time.Millisecond) + } + + // Verify it exists + _, exists := cache.GetState("session1") + require.True(t, exists, "Session should exist before cleanup") + + // Start cleanup with very short timeout + stop := cache.StartCleanupRoutine(100*time.Millisecond, 50*time.Millisecond) + defer stop() + + // Wait for session to become stale and cleanup to run + time.Sleep(200 * time.Millisecond) + + // Session should be cleaned up + _, exists = cache.GetState("session1") + assert.False(t, exists, "Stale session should be cleaned up") + + // Create a new session after cleanup started + cache.UpdateState("session2", 0, 1*time.Second, 0) + + if needsClose { + time.Sleep(50 * time.Millisecond) + } + + // It should exist (fresh) + _, exists = cache.GetState("session2") + assert.True(t, exists, "Fresh session should not be cleaned up") + }) + }) +} + +// TestAllImplementations runs the test suite against all implementations +func TestAllImplementations(t *testing.T) { + cacheTestSuite(t, "MutexSessionCache", func() SessionCache { + return NewMutexSessionCache() + }, false) + + cacheTestSuite(t, "ShardedSessionCache", func() SessionCache { + return NewShardedSessionCache() + }, true) +} + +// BenchmarkImplementations compares performance of different implementations +func BenchmarkImplementations(b *testing.B) { + benchmarkCache := func(b *testing.B, name string, factory func() SessionCache, needsClose bool) { + b.Run(name, func(b *testing.B) { + cache := factory() + if needsClose { + defer cache.(*ShardedSessionCache).Close() + } + + b.Run("GetOrCreateForScheduler", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + sessionID := fmt.Sprintf("session-%d", i%1000) + 
cache.GetOrCreateForScheduler(sessionID) + i++ + } + }) + }) + + b.Run("UpdateState", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + sessionID := fmt.Sprintf("session-%d", i%1000) + cache.UpdateState(sessionID, 0, time.Millisecond, time.Millisecond) + i++ + } + }) + }) + }) + } + + benchmarkCache(b, "MutexSessionCache", func() SessionCache { + return NewMutexSessionCache() + }, false) + + benchmarkCache(b, "ShardedSessionCache", func() SessionCache { + return NewShardedSessionCache() + }, true) +} +
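
For reference, the sketch below shows how a caller is expected to wire the cache into the request path: read the inherited CST and accumulated wait time before scheduling, then fold the observed wait and execution times (and the chosen pod) back in once the request finishes. This is an illustrative sketch only, not part of this change: the handleRequest and exampleUsage helpers, the "pod-a" pod name, the cleanup intervals, and the timing values are made up, and it is written inside package sessioninfo to avoid assuming the module import path.

package sessioninfo

import (
	"fmt"
	"time"
)

// handleRequest sketches the intended call pattern around SessionCache:
// read the inherited state before scheduling, run the request, then fold
// the observed timings back into the session.
func handleRequest(cache SessionCache, sessionID string, run func() time.Duration) {
	// 1. Scheduler side: read (or lazily create) the session state.
	inheritedCST, totalWait := cache.GetOrCreateForScheduler(sessionID)
	fmt.Printf("scheduling %s: inherited CST=%v, accumulated wait=%v\n",
		sessionID, inheritedCST, totalWait)

	// 2. Dispatch the request; the queueing delay and execution time are
	//    measured by the caller (the 20ms wait below is a stand-in value).
	waitTime := 20 * time.Millisecond
	executionTime := run()

	// 3. Executor side: record where the request landed and fold the timings
	//    back in. CST becomes max(current, inheritedCST+executionTime) and
	//    the wait time is accumulated (ATLAS update rule).
	cache.UpdateAffinity(sessionID, "pod-a") // "pod-a" is a placeholder pod name
	cache.UpdateState(sessionID, inheritedCST, executionTime, waitTime)
}

// exampleUsage sketches the lifecycle: create the cache, start the cleanup
// routine, serve requests, then shut everything down.
func exampleUsage() {
	cache := NewShardedSessionCache()
	defer cache.Close()

	// Periodically drop sessions idle for more than 30 minutes (placeholder values).
	stopCleanup := cache.StartCleanupRoutine(5*time.Minute, 30*time.Minute)
	defer stopCleanup()

	handleRequest(cache, "session-42", func() time.Duration {
		time.Sleep(50 * time.Millisecond) // pretend to serve the request
		return 50 * time.Millisecond
	})
}

With this call pattern, parallel requests that all inherit CST=0 leave the session's CST at the longest observed execution time while TotalWaitTime accumulates every request's wait, which is exactly what the concurrent tests above assert.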