This repository has been archived by the owner on Dec 23, 2024. It is now read-only.

Commit

resolve issue #46
* simplify RouterMapCallRWImpl: don't use errgroup, channels, and worker goroutines;
  just send all the requests and wait for their results (similar to the Lua implementation)
nurzhan-saktaganov committed Aug 26, 2024
1 parent 989d44a commit 863742d
Showing 2 changed files with 95 additions and 203 deletions.
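The commit message describes the new request flow: fire one request per master replicaset up front, then block on each returned future in turn. Below is a minimal, self-contained sketch of that fan-out/collect pattern, assuming a connection type whose Do method returns a future with a blocking Get (as in go-tarantool's pool API); the `conn`, `future`, and `fanOut` names are illustrative stand-ins, not identifiers from this repository.

```go
package sketch

// future is a stand-in for a request future with a blocking Get.
type future interface {
	Get() ([]interface{}, error)
}

// conn is a stand-in for a replicaset connection.
type conn interface {
	Do(req interface{}) future
}

// fanOut sends one request to every replicaset connection up front and only
// then waits for the responses, instead of pushing futures through channels
// to a pool of worker goroutines.
func fanOut(conns map[string]conn, req interface{}) (map[string][]interface{}, error) {
	futures := make(map[string]future, len(conns))
	for id, c := range conns {
		futures[id] = c.Do(req) // fire the request, don't block yet
	}

	results := make(map[string][]interface{}, len(futures))
	for id, f := range futures {
		resp, err := f.Get() // now block until this replicaset answers
		if err != nil {
			return nil, err
		}
		results[id] = resp
	}
	return results, nil
}
```

The same two-phase shape (send all, then collect all) is what the new RouterMapCallRWImpl below uses for both its ref and map stages.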
api.go (288 changes: 95 additions & 193 deletions)
@@ -3,13 +3,10 @@ package vshard_router //nolint:revive
import (
    "context"
    "fmt"
-    "sync"
-    "sync/atomic"
    "time"

    "github.com/google/uuid"
    "github.com/mitchellh/mapstructure"
-    "golang.org/x/sync/errgroup"

    "github.com/tarantool/go-tarantool/v2"
    "github.com/tarantool/go-tarantool/v2/pool"
@@ -221,24 +218,6 @@ func (r *Router) RouterCallImpl(ctx context.Context,
}
}

// call function "storage_unref" if map_callrw is failed or successed
func (r *Router) callStorageUnref(idToReplicasetRef map[uuid.UUID]*Replicaset, refID int64) {
req := tarantool.NewCallRequest("vshard.storage._call")
req = req.Args([]interface{}{"storage_unref", refID})

for _, replicaset := range idToReplicasetRef {
conn := replicaset.conn

future := conn.Do(req, pool.RW)
future.SetError(nil)
}
}

type replicasetFuture struct {
id uuid.UUID
future *tarantool.Future
}

// RouterMapCallRWImpl perform call function on all masters in the cluster
// with a guarantee that in case of success it was executed with all
// buckets being accessible for reads and writes.
@@ -248,221 +227,144 @@ func (r *Router) RouterMapCallRWImpl(
args interface{},
opts CallOpts,
) (map[uuid.UUID]interface{}, error) {
-    if opts.Timeout == 0 {
-        opts.Timeout = CallTimeoutMin
-    }
+    const vshardStorageServiceCall = "vshard.storage._call"

-    timeout := opts.Timeout
+    timeout := CallTimeoutMin
+    if opts.Timeout > 0 {
+        timeout = opts.Timeout
+    }

     timeStart := time.Now()

     refID := r.refID.Add(1)

     r.idToReplicasetMutex.RLock()
     idToReplicasetRef := r.idToReplicaset
     r.idToReplicasetMutex.RUnlock()

-    defer r.callStorageUnref(idToReplicasetRef, refID)
+    defer func() {
+        // call function "storage_unref" if map_callrw is failed or successed
+        storageUnrefReq := tarantool.NewCallRequest(vshardStorageServiceCall)
+        storageUnrefReq = storageUnrefReq.Args([]interface{}{"storage_unref", refID})
+
+        for _, rs := range idToReplicasetRef {
+            future := rs.conn.Do(storageUnrefReq, pool.RW)
+            future.SetError(nil)
+        }
+    }()

-    mapCallCtx, cancel := context.WithTimeout(ctx, timeout)
+    ctx, cancel := context.WithTimeout(ctx, timeout)
+    defer cancel()

-    req := tarantool.NewCallRequest("vshard.storage._call")
-    req = req.Context(ctx)
-
-    // ref stage: send
-
-    req = req.Args([]interface{}{
-        "storage_ref",
-        refID,
-        timeout,
-    })
-
-    g, gctx := errgroup.WithContext(mapCallCtx)
-    rsFutures := make(chan replicasetFuture)
-
-    g.Go(func() error {
-        defer close(rsFutures)
-
-        for id, replicaset := range idToReplicasetRef {
-            conn := replicaset.conn
-
-            future := conn.Do(req, pool.RW)
-            if _, err := future.Get(); err != nil {
-                cancel()
-
-                return fmt.Errorf("rs {%s} storage_ref err: %s", id.String(), err.Error())
-            }
-
-            select {
-            case <-gctx.Done():
-                return gctx.Err()
-            case rsFutures <- replicasetFuture{
-                id:     id,
-                future: future,
-            }:
-            }
-        }
-
-        return nil
-    })
-
-    // ref stage collect
-
-    totalBucketCount := int32(0)
-
-    for i := 0; i < int(r.nWorkers); i++ {
-        g.Go(func() error {
-            for rsFuture := range rsFutures {
-                future := rsFuture.future
-
-                respData, err := future.Get()
-                if err != nil {
-                    cancel()
-
-                    return err
-                }
-
-                if respData[0] == nil {
-                    vshardErr := &StorageCallAssertError{}
-
-                    err = mapstructure.Decode(respData[1], vshardErr)
-                    if err != nil {
-                        cancel()
-
-                        return err
-                    }
-
-                    cancel()
-
-                    return vshardErr
-                }
-
-                var bucketCount int32
-                err = future.GetTyped(&[]interface{}{&bucketCount})
-                if err != nil {
-                    cancel()
-
-                    return err
-                }
-
-                atomic.AddInt32(&totalBucketCount, bucketCount)
-            }
-
-            return nil
-        })
-    }
-
-    if err := g.Wait(); err != nil {
-        return nil, err
-    }
-
-    if totalBucketCount != int32(r.cfg.TotalBucketCount) {
-        return nil, fmt.Errorf("unknown bucket counts %d", totalBucketCount)
-    }
+    // ref stage

+    storageRefReq := tarantool.NewCallRequest(vshardStorageServiceCall)
+    storageRefReq = storageRefReq.Context(ctx)
+    storageRefReq = storageRefReq.Args([]interface{}{"storage_ref", refID, timeout})
+
+    type replicasetFuture struct {
+        uuid   uuid.UUID
+        future *tarantool.Future
+    }
+
+    var rsFutures = make([]replicasetFuture, 0, len(idToReplicasetRef))
+
+    // ref stage: send concurrent ref requests
+    for uuid, rs := range idToReplicasetRef {
+        rsFutures = append(rsFutures, replicasetFuture{
+            uuid:   uuid,
+            future: rs.conn.Do(storageRefReq, pool.RW),
+        })
+    }
+
+    // ref stage: get their responses
+    var totalBucketCount uint64
+    for _, rsFuture := range rsFutures {
+        respData, err := rsFuture.future.Get()
+        if err != nil {
+            return nil, fmt.Errorf("rs {%s} storage_ref err: %v", rsFuture.uuid, err)
+        }
+
+        if respData[0] == nil {
+            vshardErr := &StorageCallAssertError{}
+
+            err = mapstructure.Decode(respData[1], vshardErr)
+            if err != nil {
+                return nil, err
+            }
+
+            return nil, vshardErr
+        }
+
+        var bucketCount uint64
+        err = rsFuture.future.GetTyped(&[]interface{}{&bucketCount})
+        if err != nil {
+            return nil, err
+        }
+
+        totalBucketCount += bucketCount
+    }
+
+    if totalBucketCount != r.cfg.TotalBucketCount {
+        return nil, fmt.Errorf("total bucket count got %d, expected %d", totalBucketCount, r.cfg.TotalBucketCount)
+    }

-    // map stage: send
+    // map stage

-    g, gctx = errgroup.WithContext(mapCallCtx)
-    rsFutures = make(chan replicasetFuture)
-    req = req.Args([]interface{}{"storage_map", refID, fnc, args})
-
-    g.Go(func() error {
-        defer close(rsFutures)
-
-        for id, replicaset := range idToReplicasetRef {
-            conn := replicaset.conn
-
-            future := conn.Do(req, pool.RW)
-            if _, err := future.Get(); err != nil {
-                cancel()
-
-                return fmt.Errorf("rs {%s} storage_map err: %s", id.String(), err.Error())
-            }
-
-            select {
-            case <-gctx.Done():
-                return gctx.Err()
-            case rsFutures <- replicasetFuture{
-                id:     id,
-                future: future,
-            }:
-            }
-        }
-
-        return nil
-    })
-
-    // map stage: collect
-
-    var idToResultMutex sync.Mutex
-    idToResult := make(map[uuid.UUID]interface{})
-
-    for i := 0; i < int(r.nWorkers); i++ {
-        g.Go(func() error {
-            for rsFuture := range rsFutures {
-                future := rsFuture.future
-
-                respData, err := future.Get()
-                if err != nil {
-                    cancel()
-
-                    return err
-                }
-
-                if len(respData) != 2 {
-                    err = fmt.Errorf("invalid length of response data: must be = 2, current: %d", len(respData))
-                    cancel()
-
-                    return err
-                }
-
-                if respData[0] == nil {
-                    vshardErr := &StorageCallAssertError{}
-
-                    err = mapstructure.Decode(respData[1], vshardErr)
-                    if err != nil {
-                        cancel()
-
-                        return err
-                    }
-
-                    cancel()
-
-                    return vshardErr
-                }
-
-                isVShardRespOk := false
-
-                err = future.GetTyped(&[]interface{}{&isVShardRespOk})
-                if err != nil {
-                    cancel()
-
-                    return err
-                }
-
-                if !isVShardRespOk { // error
-                    errorResp := &StorageCallAssertError{}
-
-                    err = future.GetTyped(&[]interface{}{&isVShardRespOk, errorResp})
-                    if err != nil {
-                        err = fmt.Errorf("cant get typed vshard err with err: %s", err)
-                    }
-
-                    cancel()
-
-                    return err
-                }
-
-                idToResultMutex.Lock()
-                idToResult[rsFuture.id] = respData[1]
-                idToResultMutex.Unlock()
-            }
-
-            return nil
-        })
-    }
-
-    if err := g.Wait(); err != nil {
-        return nil, err
-    }
+    storageMapReq := tarantool.NewCallRequest(vshardStorageServiceCall)
+    storageMapReq = storageMapReq.Context(ctx)
+    storageMapReq = storageMapReq.Args([]interface{}{"storage_map", refID, fnc, args})
+
+    rsFutures = rsFutures[0:0]
+
+    // map stage: send concurrent map requests
+    for uuid, rs := range idToReplicasetRef {
+        rsFutures = append(rsFutures, replicasetFuture{
+            uuid:   uuid,
+            future: rs.conn.Do(storageMapReq, pool.RW),
+        })
+    }
+
+    // map stage: get their responses
+    idToResult := make(map[uuid.UUID]interface{})
+    for _, rsFuture := range rsFutures {
+        respData, err := rsFuture.future.Get()
+        if err != nil {
+            return nil, fmt.Errorf("rs {%s} storage_map err: %v", rsFuture.uuid, err)
+        }
+
+        if len(respData) != 2 {
+            return nil, fmt.Errorf("invalid length of response data: must be = 2, current: %d", len(respData))
+        }
+
+        if respData[0] == nil {
+            vshardErr := &StorageCallAssertError{}
+
+            err = mapstructure.Decode(respData[1], vshardErr)
+            if err != nil {
+                return nil, err
+            }
+
+            return nil, vshardErr
+        }
+
+        var isVShardRespOk bool
+        err = rsFuture.future.GetTyped(&[]interface{}{&isVShardRespOk})
+        if err != nil {
+            return nil, err
+        }
+
+        if !isVShardRespOk { // error
+            errorResp := &StorageCallAssertError{}
+
+            err = rsFuture.future.GetTyped(&[]interface{}{&isVShardRespOk, errorResp})
+            if err != nil {
+                return nil, fmt.Errorf("cant get typed vshard err with err: %v", err)
+            }
+
+            return nil, errorResp
+        }
+
+        idToResult[rsFuture.uuid] = respData[1]
+    }

r.metrics().RequestDuration(time.Since(timeStart), true, true)
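For context on how the simplified method is driven from application code, here is a hedged usage sketch. The Go-side signature (RouterMapCallRWImpl(ctx, fnc, args, opts) returning map[uuid.UUID]interface{}) and the CallOpts.Timeout field come from the diff above; the module path, the router value, and the "storage_sum" storage-side function are assumptions made only for illustration.

```go
package example

import (
	"context"
	"log"
	"time"

	// Module path is an assumption; adjust it to wherever this package actually lives.
	vshardrouter "github.com/KaymeKaydex/go-vshard-router"
)

func mapCallExample(ctx context.Context, router *vshardrouter.Router) {
	// "storage_sum" is a hypothetical Lua function registered on every storage master.
	results, err := router.RouterMapCallRWImpl(ctx, "storage_sum", []interface{}{1, 2}, vshardrouter.CallOpts{
		Timeout: 3 * time.Second, // falls back to CallTimeoutMin when zero
	})
	if err != nil {
		log.Printf("map_callrw failed: %v", err)
		return
	}

	// One entry per master replicaset, keyed by replicaset UUID.
	for rsUUID, rsResult := range results {
		log.Printf("replicaset %s returned %v", rsUUID, rsResult)
	}
}
```

Because the requests to all masters are already in flight before the first Get blocks, the call still fans out concurrently even though no worker goroutines remain.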
vshard.go (10 changes: 0 additions & 10 deletions)
@@ -58,9 +58,6 @@ type Router struct {
// and therefore is global and monotonically growing.
refID atomic.Int64

-    // worker's count to proceed channel of replicaset's futures
-    nWorkers int32

cancelDiscovery func()
}

@@ -156,13 +153,6 @@ func NewRouter(ctx context.Context, cfg Config) (*Router, error) {
router.cancelDiscovery = cancelFunc
}

-    nWorkers := int32(2)
-    if cfg.NWorkers > 0 {
-        nWorkers = cfg.NWorkers
-    }

-    router.nWorkers = nWorkers

return router, err
}

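Since RouterMapCallRWImpl no longer spins up worker goroutines, the router drops its nWorkers field and the NWorkers setting stops influencing this call path. A hedged configuration sketch follows; only NWorkers and TotalBucketCount appear in the diffs above, every other aspect of Config (and the module path) is assumed for illustration.

```go
package example

import (
	"context"

	// Module path is an assumption, as in the earlier sketch.
	vshardrouter "github.com/KaymeKaydex/go-vshard-router"
)

func newRouterExample(ctx context.Context) (*vshardrouter.Router, error) {
	cfg := vshardrouter.Config{
		TotalBucketCount: 10000, // checked against the summed storage_ref bucket counts
		// NWorkers: 4, // previously sized the worker pool; after this commit it no
		// longer affects RouterMapCallRWImpl
	}

	// NewRouter(ctx, cfg) matches the signature shown in the vshard.go hunk.
	return vshardrouter.NewRouter(ctx, cfg)
}
```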
