diff --git a/api.go b/api.go index 9997ec0..cb8130c 100644 --- a/api.go +++ b/api.go @@ -3,13 +3,10 @@ package vshard_router //nolint:revive import ( "context" "fmt" - "sync" - "sync/atomic" "time" "github.com/google/uuid" "github.com/mitchellh/mapstructure" - "golang.org/x/sync/errgroup" "github.com/tarantool/go-tarantool/v2" "github.com/tarantool/go-tarantool/v2/pool" @@ -221,24 +218,6 @@ func (r *Router) RouterCallImpl(ctx context.Context, } } -// call function "storage_unref" if map_callrw is failed or successed -func (r *Router) callStorageUnref(idToReplicasetRef map[uuid.UUID]*Replicaset, refID int64) { - req := tarantool.NewCallRequest("vshard.storage._call") - req = req.Args([]interface{}{"storage_unref", refID}) - - for _, replicaset := range idToReplicasetRef { - conn := replicaset.conn - - future := conn.Do(req, pool.RW) - future.SetError(nil) - } -} - -type replicasetFuture struct { - id uuid.UUID - future *tarantool.Future -} - // RouterMapCallRWImpl perform call function on all masters in the cluster // with a guarantee that in case of success it was executed with all // buckets being accessible for reads and writes. @@ -248,221 +227,144 @@ func (r *Router) RouterMapCallRWImpl( args interface{}, opts CallOpts, ) (map[uuid.UUID]interface{}, error) { - if opts.Timeout == 0 { - opts.Timeout = CallTimeoutMin + const vshardStorageServiceCall = "vshard.storage._call" + + timeout := CallTimeoutMin + if opts.Timeout > 0 { + timeout = opts.Timeout } - timeout := opts.Timeout timeStart := time.Now() - refID := r.refID.Add(1) r.idToReplicasetMutex.RLock() idToReplicasetRef := r.idToReplicaset r.idToReplicasetMutex.RUnlock() - defer r.callStorageUnref(idToReplicasetRef, refID) - - mapCallCtx, cancel := context.WithTimeout(ctx, timeout) - - req := tarantool.NewCallRequest("vshard.storage._call") - req = req.Context(ctx) - - // ref stage: send - - req = req.Args([]interface{}{ - "storage_ref", - refID, - timeout, - }) - - g, gctx := errgroup.WithContext(mapCallCtx) - rsFutures := make(chan replicasetFuture) + defer func() { + // call function "storage_unref" if map_callrw is failed or successed + storageUnrefReq := tarantool.NewCallRequest(vshardStorageServiceCall) + storageUnrefReq = storageUnrefReq.Args([]interface{}{"storage_unref", refID}) - g.Go(func() error { - defer close(rsFutures) - - for id, replicaset := range idToReplicasetRef { - conn := replicaset.conn - - future := conn.Do(req, pool.RW) - if _, err := future.Get(); err != nil { - cancel() - - return fmt.Errorf("rs {%s} storage_ref err: %s", id.String(), err.Error()) - } - - select { - case <-gctx.Done(): - return gctx.Err() - case rsFutures <- replicasetFuture{ - id: id, - future: future, - }: - } + for _, rs := range idToReplicasetRef { + future := rs.conn.Do(storageUnrefReq, pool.RW) + future.SetError(nil) } + }() - return nil - }) - - // ref stage collect - - totalBucketCount := int32(0) - - for i := 0; i < int(r.nWorkers); i++ { - g.Go(func() error { - for rsFuture := range rsFutures { - future := rsFuture.future + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() - respData, err := future.Get() - if err != nil { - cancel() + // ref stage - return err - } + storageRefReq := tarantool.NewCallRequest(vshardStorageServiceCall) + storageRefReq = storageRefReq.Context(ctx) + storageRefReq = storageRefReq.Args([]interface{}{"storage_ref", refID, timeout}) - if respData[0] == nil { - vshardErr := &StorageCallAssertError{} - - err = mapstructure.Decode(respData[1], vshardErr) - if err != nil { - cancel() - - return err - } - - cancel() - - return vshardErr - } - - var bucketCount int32 - err = future.GetTyped(&[]interface{}{&bucketCount}) - if err != nil { - cancel() - - return err - } - - atomic.AddInt32(&totalBucketCount, bucketCount) - } - - return nil - }) + type replicasetFuture struct { + uuid uuid.UUID + future *tarantool.Future } - if err := g.Wait(); err != nil { - return nil, err - } + var rsFutures = make([]replicasetFuture, 0, len(idToReplicasetRef)) - if totalBucketCount != int32(r.cfg.TotalBucketCount) { - return nil, fmt.Errorf("unknown bucket counts %d", totalBucketCount) + // ref stage: send concurrent ref requests + for uuid, rs := range idToReplicasetRef { + rsFutures = append(rsFutures, replicasetFuture{ + uuid: uuid, + future: rs.conn.Do(storageRefReq, pool.RW), + }) } - // map stage: send - - g, gctx = errgroup.WithContext(mapCallCtx) - rsFutures = make(chan replicasetFuture) - req = req.Args([]interface{}{"storage_map", refID, fnc, args}) - - g.Go(func() error { - defer close(rsFutures) - - for id, replicaset := range idToReplicasetRef { - conn := replicaset.conn + // ref stage: get their responses + var totalBucketCount uint64 + for _, rsFuture := range rsFutures { + respData, err := rsFuture.future.Get() + if err != nil { + return nil, fmt.Errorf("rs {%s} storage_ref err: %v", rsFuture.uuid, err) + } - future := conn.Do(req, pool.RW) - if _, err := future.Get(); err != nil { - cancel() + if respData[0] == nil { + vshardErr := &StorageCallAssertError{} - return fmt.Errorf("rs {%s} storage_map err: %s", id.String(), err.Error()) + err = mapstructure.Decode(respData[1], vshardErr) + if err != nil { + return nil, err } - select { - case <-gctx.Done(): - return gctx.Err() - case rsFutures <- replicasetFuture{ - id: id, - future: future, - }: - } + return nil, vshardErr } - return nil - }) - - // map stage: collect - - var idToResultMutex sync.Mutex - idToResult := make(map[uuid.UUID]interface{}) - - for i := 0; i < int(r.nWorkers); i++ { - g.Go(func() error { - for rsFuture := range rsFutures { - future := rsFuture.future - - respData, err := future.Get() - if err != nil { - cancel() - - return err - } - - if len(respData) != 2 { - err = fmt.Errorf("invalid length of response data: must be = 2, current: %d", len(respData)) - cancel() + var bucketCount uint64 + err = rsFuture.future.GetTyped(&[]interface{}{&bucketCount}) + if err != nil { + return nil, err + } - return err - } + totalBucketCount += bucketCount + } - if respData[0] == nil { - vshardErr := &StorageCallAssertError{} + if totalBucketCount != r.cfg.TotalBucketCount { + return nil, fmt.Errorf("total bucket count got %d, expected %d", totalBucketCount, r.cfg.TotalBucketCount) + } - err = mapstructure.Decode(respData[1], vshardErr) - if err != nil { - cancel() + // map stage - return err - } + storageMapReq := tarantool.NewCallRequest(vshardStorageServiceCall) + storageMapReq = storageMapReq.Context(ctx) + storageMapReq = storageMapReq.Args([]interface{}{"storage_map", refID, fnc, args}) - cancel() + rsFutures = rsFutures[0:0] - return vshardErr - } + // map stage: send concurrent map requests + for uuid, rs := range idToReplicasetRef { + rsFutures = append(rsFutures, replicasetFuture{ + uuid: uuid, + future: rs.conn.Do(storageMapReq, pool.RW), + }) + } - isVShardRespOk := false + // map stage: get their responses + idToResult := make(map[uuid.UUID]interface{}) + for _, rsFuture := range rsFutures { + respData, err := rsFuture.future.Get() + if err != nil { + return nil, fmt.Errorf("rs {%s} storage_map err: %v", rsFuture.uuid, err) + } - err = future.GetTyped(&[]interface{}{&isVShardRespOk}) - if err != nil { - cancel() + if len(respData) != 2 { + return nil, fmt.Errorf("invalid length of response data: must be = 2, current: %d", len(respData)) + } - return err - } + if respData[0] == nil { + vshardErr := &StorageCallAssertError{} - if !isVShardRespOk { // error - errorResp := &StorageCallAssertError{} + err = mapstructure.Decode(respData[1], vshardErr) + if err != nil { + return nil, err + } - err = future.GetTyped(&[]interface{}{&isVShardRespOk, errorResp}) - if err != nil { - err = fmt.Errorf("cant get typed vshard err with err: %s", err) - } + return nil, vshardErr + } - cancel() + var isVShardRespOk bool + err = rsFuture.future.GetTyped(&[]interface{}{&isVShardRespOk}) + if err != nil { + return nil, err + } - return err - } + if !isVShardRespOk { // error + errorResp := &StorageCallAssertError{} - idToResultMutex.Lock() - idToResult[rsFuture.id] = respData[1] - idToResultMutex.Unlock() + err = rsFuture.future.GetTyped(&[]interface{}{&isVShardRespOk, errorResp}) + if err != nil { + return nil, fmt.Errorf("cant get typed vshard err with err: %v", err) } - return nil - }) - } + return nil, errorResp + } - if err := g.Wait(); err != nil { - return nil, err + idToResult[rsFuture.uuid] = respData[1] } r.metrics().RequestDuration(time.Since(timeStart), true, true) diff --git a/vshard.go b/vshard.go index 298831b..de3f399 100644 --- a/vshard.go +++ b/vshard.go @@ -58,9 +58,6 @@ type Router struct { // and therefore is global and monotonically growing. refID atomic.Int64 - // worker's count to proceed channel of replicaset's futures - nWorkers int32 - cancelDiscovery func() } @@ -156,13 +153,6 @@ func NewRouter(ctx context.Context, cfg Config) (*Router, error) { router.cancelDiscovery = cancelFunc } - nWorkers := int32(2) - if cfg.NWorkers > 0 { - nWorkers = cfg.NWorkers - } - - router.nWorkers = nWorkers - return router, err }