diff --git a/CHANGELOG.md b/CHANGELOG.md index 9345487..05006f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ BUG FIXES: * RouterMapCallRWImpl: decode bucketCount into 32 bit integer instead of 16 bit * RouterMapCallRWImpl: fix concurrent access to idToResult map * BucketDiscovery: fix possible concurrent access to resultRs and err vars +* RouterMapCallRWImpl: compare totalBucketCount against r.cfg.TotalBucketCount +* issue #39: fixed concurrent access to routeMap: use consistent view (immutable object) + atomics FEATURES: diff --git a/api.go b/api.go index 3488816..d9b3bdf 100644 --- a/api.go +++ b/api.go @@ -361,7 +361,7 @@ func (r *Router) RouterMapCallRWImpl( return nil, err } - if totalBucketCount != r.knownBucketCount.Load() { + if totalBucketCount != int32(r.cfg.TotalBucketCount) { return nil, fmt.Errorf("unknown bucket counts %d", totalBucketCount) } diff --git a/api_test.go b/api_test.go index ce41308..7008c3d 100644 --- a/api_test.go +++ b/api_test.go @@ -3,6 +3,7 @@ package vshard_router // nolint: revive import ( "context" "fmt" + "sync/atomic" "testing" "time" @@ -50,7 +51,9 @@ func TestRouter_RouterCallImpl(t *testing.T) { Logger: &EmptyLogger{}, Metrics: &EmptyMetrics{}, }, - routeMap: make([]*Replicaset, 11), + view: &consistentView{ + routeMap: make([]atomic.Pointer[Replicaset], 11), + }, } futureError := fmt.Errorf("testErr") @@ -60,9 +63,9 @@ func TestRouter_RouterCallImpl(t *testing.T) { mPool := mockpool.NewPool(t) mPool.On("Do", mock.Anything, mock.Anything).Return(errFuture) - r.routeMap[5] = &Replicaset{ + r.view.routeMap[5].Store(&Replicaset{ conn: mPool, - } + }) _, _, err := r.RouterCallImpl(ctx, 5, CallOpts{Timeout: time.Second}, "test", []byte("test")) require.ErrorIs(t, futureError, err) diff --git a/discovery.go b/discovery.go index c19df33..da34f69 100644 --- a/discovery.go +++ b/discovery.go @@ -26,44 +26,17 @@ const ( DiscoveryModeOnce ) -type searchLock struct { - mu sync.RWMutex - perBucket []chan struct{} -} - -func (s *searchLock) WaitOnSearch(bucketID uint64) { - ch := s.perBucket[bucketID] - if ch == nil { - return - } - - <-ch -} - -func (s *searchLock) StartSearch(bucketID uint64) chan struct{} { - s.mu.Lock() - defer s.mu.Unlock() - - ch := make(chan struct{}) - s.perBucket[bucketID] = ch - - return ch -} - // BucketDiscovery search bucket in whole cluster func (r *Router) BucketDiscovery(ctx context.Context, bucketID uint64) (*Replicaset, error) { - r.searchLock.WaitOnSearch(bucketID) + view := r.getConsistentView() - rs := r.routeMap[bucketID] + rs := view.routeMap[bucketID].Load() if rs != nil { return rs, nil } // it`s ok if in the same time we have few active searches // mu per bucket is expansive - stopSearchCh := r.searchLock.StartSearch(bucketID) - defer close(stopSearchCh) - r.cfg.Logger.Info(ctx, fmt.Sprintf("Discovering bucket %d", bucketID)) idToReplicasetRef := r.getIDToReplicaset() @@ -85,6 +58,8 @@ func (r *Router) BucketDiscovery(ctx context.Context, bucketID uint64) (*Replica go func(rs *Replicaset, rsID uuid.UUID) { defer wg.Done() if _, err := rs.BucketStat(ctx, bucketID); err == nil { + // It's ok if several replicasets return ok to bucket_stat command for the same bucketID, + // just pick any of them. var res result res.rs, res.err = r.BucketSet(bucketID, rsID) resultAtomic.Store(&res) @@ -118,7 +93,9 @@ func (r *Router) BucketResolve(ctx context.Context, bucketID uint64) (*Replicase return nil, fmt.Errorf("bucket id is out of range: %d (total %d)", bucketID, r.cfg.TotalBucketCount) } - rs := r.routeMap[bucketID] + view := r.getConsistentView() + + rs := view.routeMap[bucketID].Load() if rs != nil { return rs, nil } @@ -134,11 +111,14 @@ func (r *Router) BucketResolve(ctx context.Context, bucketID uint64) (*Replicase // DiscoveryHandleBuckets arrange downloaded buckets to the route map so as they reference a given replicaset. func (r *Router) DiscoveryHandleBuckets(ctx context.Context, rs *Replicaset, buckets []uint64) { + view := r.getConsistentView() + count := rs.bucketCount.Load() + affected := make(map[*Replicaset]int) for _, bucketID := range buckets { - oldRs := r.routeMap[bucketID] + oldRs := view.routeMap[bucketID].Swap(rs) if oldRs != rs { count++ @@ -151,9 +131,8 @@ func (r *Router) DiscoveryHandleBuckets(ctx context.Context, rs *Replicaset, buc oldRs.bucketCount.Add(-1) } else { // router.known_bucket_count = router.known_bucket_count + 1 - r.knownBucketCount.Add(1) + view.knownBucketCount.Add(1) } - r.routeMap[bucketID] = rs } } @@ -177,10 +156,9 @@ func (r *Router) DiscoveryAllBuckets(ctx context.Context) error { r.log().Info(ctx, "start discovery all buckets") - knownBucket := atomic.Int32{} - errGr, ctx := errgroup.WithContext(ctx) + view := r.getConsistentView() idToReplicasetRef := r.getIDToReplicaset() for _, rs := range idToReplicasetRef { @@ -218,8 +196,9 @@ func (r *Router) DiscoveryAllBuckets(ctx context.Context) error { break } - r.routeMap[bucket] = rs - knownBucket.Add(1) + if old := view.routeMap[bucket].Swap(rs); old == nil { + view.knownBucketCount.Add(1) + } } // There are no more buckets @@ -239,8 +218,6 @@ func (r *Router) DiscoveryAllBuckets(ctx context.Context) error { } r.log().Info(ctx, fmt.Sprintf("discovery done since: %s", time.Since(t))) - r.knownBucketCount.Store(knownBucket.Load()) - return nil } diff --git a/discovery_test.go b/discovery_test.go index c93af5e..7b66586 100644 --- a/discovery_test.go +++ b/discovery_test.go @@ -2,40 +2,12 @@ package vshard_router //nolint:revive import ( "context" - "sync" + "sync/atomic" "testing" - "time" "github.com/stretchr/testify/require" ) -func TestSearchLock_WaitOnSearch(t *testing.T) { - lock := searchLock{ - mu: sync.RWMutex{}, - perBucket: make([]chan struct{}, 10), - } - - noLockStart := time.Now() - lock.WaitOnSearch(5) - require.True(t, time.Since(noLockStart) < time.Millisecond) - - lockStart := time.Now() - chStopSearch := lock.StartSearch(3) - - go func() { - time.Sleep(time.Millisecond * 10) - close(chStopSearch) - }() - - noLockStart = time.Now() - lock.WaitOnSearch(5) - require.True(t, time.Since(noLockStart) < time.Millisecond) - - lock.WaitOnSearch(3) - - require.True(t, time.Since(lockStart) < 12*time.Millisecond && time.Since(lockStart) > 9*time.Millisecond) -} - func TestRouter_BucketResolve_InvalidBucketID(t *testing.T) { ctx := context.TODO() @@ -44,7 +16,9 @@ func TestRouter_BucketResolve_InvalidBucketID(t *testing.T) { TotalBucketCount: uint64(10), Logger: &EmptyLogger{}, }, - routeMap: make([]*Replicaset, 11), + view: &consistentView{ + routeMap: make([]atomic.Pointer[Replicaset], 11), + }, } _, err := r.BucketResolve(ctx, 20) diff --git a/vshard.go b/vshard.go index 8ff116e..86b2c1b 100644 --- a/vshard.go +++ b/vshard.go @@ -19,6 +19,25 @@ var ( ErrTopologyProvider = fmt.Errorf("got error from topology provider") ) +// This data struct is instroduced by https://github.com/KaymeKaydex/go-vshard-router/issues/39. +// We use an array of atomics to lock-free handling elements of routeMap. +// knownBucketCount reflects a statistic over routeMap. +// knownBucketCount might be inconsistent for a few mksecs, because at first we change routeMap[bucketID], +// only after that we change knownBucketCount: this is not an atomic change of complex state. +// It it is not a problem at all. +// +// While changing `knownBucketCount` we heavily rely on commutative property of algebraic sum operation ("+"), +// due to this property we don't afraid any amount of concurrent modifications. +// See: https://en.wikipedia.org/wiki/Commutative_property +// +// Since RouteMapClean creates a new routeMap, we have to assign knownBucketCount := 0. +// But assign is not a commutative operation, therefore we have to create a completely new atomic variable, +// that reflects a statistic over newly created routeMap. +type consistentView struct { + routeMap []atomic.Pointer[Replicaset] + knownBucketCount atomic.Int32 +} + type Router struct { cfg Config @@ -32,10 +51,8 @@ type Router struct { idToReplicasetMutex sync.RWMutex idToReplicaset map[uuid.UUID]*Replicaset - routeMap []*Replicaset - searchLock searchLock - - knownBucketCount atomic.Int32 + viewMutex sync.RWMutex + view *consistentView // ----------------------- Map-Reduce ----------------------- // Storage Ref ID. It must be unique for each ref request @@ -55,6 +72,20 @@ func (r *Router) log() LogProvider { return r.cfg.Logger } +func (r *Router) getConsistentView() *consistentView { + r.viewMutex.RLock() + view := r.view + r.viewMutex.RUnlock() + + return view +} + +func (r *Router) setConsistentView(view *consistentView) { + r.viewMutex.Lock() + r.view = view + r.viewMutex.Unlock() +} + type Config struct { // Providers Logger LogProvider // Logger is not required @@ -112,11 +143,11 @@ func NewRouter(ctx context.Context, cfg Config) (*Router, error) { } router := &Router{ - cfg: cfg, - idToReplicaset: make(map[uuid.UUID]*Replicaset), - routeMap: make([]*Replicaset, cfg.TotalBucketCount+1), - searchLock: searchLock{mu: sync.RWMutex{}, perBucket: make([]chan struct{}, cfg.TotalBucketCount+1)}, - knownBucketCount: atomic.Int32{}, + cfg: cfg, + idToReplicaset: make(map[uuid.UUID]*Replicaset), + view: &consistentView{ + routeMap: make([]atomic.Pointer[Replicaset], cfg.TotalBucketCount+1), + }, } err = cfg.TopologyProvider.Init(router.Topology()) @@ -163,37 +194,42 @@ func (r *Router) BucketSet(bucketID uint64, rsID uuid.UUID) (*Replicaset, error) return nil, Errors[9] // NO_ROUTE_TO_BUCKET } - oldReplicaset := r.routeMap[bucketID] + view := r.getConsistentView() + oldReplicaset := view.routeMap[bucketID].Swap(rs) if oldReplicaset != rs { if oldReplicaset != nil { oldReplicaset.bucketCount.Add(-1) } else { - r.knownBucketCount.Add(1) + view.knownBucketCount.Add(1) } rs.bucketCount.Add(1) } - r.routeMap[bucketID] = rs - return rs, nil } func (r *Router) BucketReset(bucketID uint64) { - if bucketID > uint64(len(r.routeMap))+1 { + view := r.getConsistentView() + + if bucketID > uint64(len(view.routeMap))+1 { return } - r.knownBucketCount.Add(-1) - r.routeMap[bucketID] = nil + if old := view.routeMap[bucketID].Swap(nil); old != nil { + view.knownBucketCount.Add(-1) + } } func (r *Router) RouteMapClean() { idToReplicasetRef := r.getIDToReplicaset() - r.routeMap = make([]*Replicaset, r.cfg.TotalBucketCount+1) - r.knownBucketCount.Store(0) + newView := &consistentView{ + routeMap: make([]atomic.Pointer[Replicaset], r.cfg.TotalBucketCount+1), + } + + r.setConsistentView(newView) for _, rs := range idToReplicasetRef { rs.bucketCount.Store(0) diff --git a/vshard_test.go b/vshard_test.go index 3dc5911..6963322 100644 --- a/vshard_test.go +++ b/vshard_test.go @@ -1,6 +1,7 @@ package vshard_router //nolint:revive import ( + "sync/atomic" "testing" "github.com/stretchr/testify/require" @@ -31,8 +32,10 @@ func TestRouter_RouterBucketCount(t *testing.T) { func TestRouter_RouteMapClean(t *testing.T) { r := Router{ - cfg: Config{TotalBucketCount: 10}, - routeMap: make([]*Replicaset, 10), + cfg: Config{TotalBucketCount: 10}, + view: &consistentView{ + routeMap: make([]atomic.Pointer[Replicaset], 10), + }, } require.NotPanics(t, func() {