From cc5160fbaa4bd9b63c85b1888ff17721ed19a1de Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Mon, 15 Jul 2024 15:41:10 +0800 Subject: [PATCH 1/2] use context to manage runner Signed-off-by: Ryan Leung --- pkg/mcs/scheduling/server/cluster.go | 6 +++--- pkg/ratelimit/runner.go | 24 ++++++++++++++++-------- pkg/ratelimit/runner_test.go | 7 ++++--- server/cluster/cluster.go | 6 +++--- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index b18db7c0798..58ec84157f0 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -99,9 +99,9 @@ func NewCluster(parentCtx context.Context, persistConfig *config.PersistConfig, clusterID: clusterID, checkMembershipCh: checkMembershipCh, - heartbeatRunner: ratelimit.NewConcurrentRunner(heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - miscRunner: ratelimit.NewConcurrentRunner(miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - logRunner: ratelimit.NewConcurrentRunner(logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + heartbeatRunner: ratelimit.NewConcurrentRunner(ctx, heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + miscRunner: ratelimit.NewConcurrentRunner(ctx, miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + logRunner: ratelimit.NewConcurrentRunner(ctx, logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), } c.coordinator = schedule.NewCoordinator(ctx, c, hbStreams) err = c.ruleManager.Initialize(persistConfig.GetMaxReplicas(), persistConfig.GetLocationLabels(), persistConfig.GetIsolationLevel()) diff --git a/pkg/ratelimit/runner.go b/pkg/ratelimit/runner.go index 2d88e36106e..1ac7ae899af 100644 --- a/pkg/ratelimit/runner.go +++ b/pkg/ratelimit/runner.go @@ -66,12 +66,13 @@ type taskID struct { } type ConcurrentRunner struct { + ctx context.Context + cancel context.CancelFunc name string limiter *ConcurrencyLimiter maxPendingDuration time.Duration taskChan chan *Task pendingMu sync.Mutex - stopChan chan struct{} wg sync.WaitGroup pendingTaskCount map[string]int pendingTasks []*Task @@ -80,8 +81,11 @@ type ConcurrentRunner struct { } // NewConcurrentRunner creates a new ConcurrentRunner. -func NewConcurrentRunner(name string, limiter *ConcurrencyLimiter, maxPendingDuration time.Duration) *ConcurrentRunner { +func NewConcurrentRunner(ctx context.Context, name string, limiter *ConcurrencyLimiter, maxPendingDuration time.Duration) *ConcurrentRunner { + ctx, cancel := context.WithCancel(ctx) s := &ConcurrentRunner{ + ctx: ctx, + cancel: cancel, name: name, limiter: limiter, maxPendingDuration: maxPendingDuration, @@ -104,7 +108,6 @@ func WithRetained(retained bool) TaskOption { // Start starts the runner. func (cr *ConcurrentRunner) Start() { - cr.stopChan = make(chan struct{}) cr.wg.Add(1) ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() @@ -118,11 +121,11 @@ func (cr *ConcurrentRunner) Start() { if err != nil { continue } - go cr.run(task, token) + go cr.run(cr.ctx, task, token) } else { - go cr.run(task, nil) + go cr.run(cr.ctx, task, nil) } - case <-cr.stopChan: + case <-cr.ctx.Done(): cr.pendingMu.Lock() cr.pendingTasks = make([]*Task, 0, initialCapacity) cr.pendingMu.Unlock() @@ -144,8 +147,13 @@ func (cr *ConcurrentRunner) Start() { }() } -func (cr *ConcurrentRunner) run(task *Task, token *TaskToken) { +func (cr *ConcurrentRunner) run(ctx context.Context, task *Task, token *TaskToken) { start := time.Now() + select { + case <-ctx.Done(): + return + default: + } task.f() if token != nil { cr.limiter.ReleaseToken(token) @@ -173,7 +181,7 @@ func (cr *ConcurrentRunner) processPendingTasks() { // Stop stops the runner. func (cr *ConcurrentRunner) Stop() { - close(cr.stopChan) + cr.cancel() cr.wg.Wait() } diff --git a/pkg/ratelimit/runner_test.go b/pkg/ratelimit/runner_test.go index 0335a78bcbe..a3eac7f238e 100644 --- a/pkg/ratelimit/runner_test.go +++ b/pkg/ratelimit/runner_test.go @@ -15,6 +15,7 @@ package ratelimit import ( + "context" "sync" "testing" "time" @@ -24,7 +25,7 @@ import ( func TestConcurrentRunner(t *testing.T) { t.Run("RunTask", func(t *testing.T) { - runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), time.Second) + runner := NewConcurrentRunner(context.TODO(), "test", NewConcurrencyLimiter(1), time.Second) runner.Start() defer runner.Stop() @@ -46,7 +47,7 @@ func TestConcurrentRunner(t *testing.T) { }) t.Run("MaxPendingDuration", func(t *testing.T) { - runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), 2*time.Millisecond) + runner := NewConcurrentRunner(context.TODO(), "test", NewConcurrencyLimiter(1), 2*time.Millisecond) runner.Start() defer runner.Stop() var wg sync.WaitGroup @@ -75,7 +76,7 @@ func TestConcurrentRunner(t *testing.T) { }) t.Run("DuplicatedTask", func(t *testing.T) { - runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), time.Minute) + runner := NewConcurrentRunner(context.TODO(), "test", NewConcurrencyLimiter(1), time.Minute) runner.Start() defer runner.Stop() for i := 1; i < 11; i++ { diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 93be9d1c076..812cbb437f0 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -204,9 +204,9 @@ func NewRaftCluster(ctx context.Context, clusterID uint64, basicCluster *core.Ba etcdClient: etcdClient, BasicCluster: basicCluster, storage: storage, - heartbeatRunner: ratelimit.NewConcurrentRunner(heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - miscRunner: ratelimit.NewConcurrentRunner(miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - logRunner: ratelimit.NewConcurrentRunner(logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + heartbeatRunner: ratelimit.NewConcurrentRunner(ctx, heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + miscRunner: ratelimit.NewConcurrentRunner(ctx, miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + logRunner: ratelimit.NewConcurrentRunner(ctx, logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), } } From 5941965e3ffcf694f395671217284e2f2a17730a Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Tue, 16 Jul 2024 15:42:51 +0800 Subject: [PATCH 2/2] fix Signed-off-by: Ryan Leung --- pkg/mcs/scheduling/server/cluster.go | 12 ++++++------ pkg/ratelimit/runner.go | 12 +++++------- pkg/ratelimit/runner_test.go | 12 ++++++------ server/cluster/cluster.go | 12 ++++++------ 4 files changed, 23 insertions(+), 25 deletions(-) diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index 58ec84157f0..24a75012331 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -99,9 +99,9 @@ func NewCluster(parentCtx context.Context, persistConfig *config.PersistConfig, clusterID: clusterID, checkMembershipCh: checkMembershipCh, - heartbeatRunner: ratelimit.NewConcurrentRunner(ctx, heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - miscRunner: ratelimit.NewConcurrentRunner(ctx, miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - logRunner: ratelimit.NewConcurrentRunner(ctx, logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + heartbeatRunner: ratelimit.NewConcurrentRunner(heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + miscRunner: ratelimit.NewConcurrentRunner(miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + logRunner: ratelimit.NewConcurrentRunner(logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), } c.coordinator = schedule.NewCoordinator(ctx, c, hbStreams) err = c.ruleManager.Initialize(persistConfig.GetMaxReplicas(), persistConfig.GetLocationLabels(), persistConfig.GetIsolationLevel()) @@ -549,9 +549,9 @@ func (c *Cluster) StartBackgroundJobs() { go c.runUpdateStoreStats() go c.runCoordinator() go c.runMetricsCollectionJob() - c.heartbeatRunner.Start() - c.miscRunner.Start() - c.logRunner.Start() + c.heartbeatRunner.Start(c.ctx) + c.miscRunner.Start(c.ctx) + c.logRunner.Start(c.ctx) c.running.Store(true) } diff --git a/pkg/ratelimit/runner.go b/pkg/ratelimit/runner.go index 1ac7ae899af..57a19e4e682 100644 --- a/pkg/ratelimit/runner.go +++ b/pkg/ratelimit/runner.go @@ -43,7 +43,7 @@ const ( // Runner is the interface for running tasks. type Runner interface { RunTask(id uint64, name string, f func(), opts ...TaskOption) error - Start() + Start(ctx context.Context) Stop() } @@ -81,11 +81,8 @@ type ConcurrentRunner struct { } // NewConcurrentRunner creates a new ConcurrentRunner. -func NewConcurrentRunner(ctx context.Context, name string, limiter *ConcurrencyLimiter, maxPendingDuration time.Duration) *ConcurrentRunner { - ctx, cancel := context.WithCancel(ctx) +func NewConcurrentRunner(name string, limiter *ConcurrencyLimiter, maxPendingDuration time.Duration) *ConcurrentRunner { s := &ConcurrentRunner{ - ctx: ctx, - cancel: cancel, name: name, limiter: limiter, maxPendingDuration: maxPendingDuration, @@ -107,7 +104,8 @@ func WithRetained(retained bool) TaskOption { } // Start starts the runner. -func (cr *ConcurrentRunner) Start() { +func (cr *ConcurrentRunner) Start(ctx context.Context) { + cr.ctx, cr.cancel = context.WithCancel(ctx) cr.wg.Add(1) ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() @@ -246,7 +244,7 @@ func (*SyncRunner) RunTask(_ uint64, _ string, f func(), _ ...TaskOption) error } // Start starts the runner. -func (*SyncRunner) Start() {} +func (*SyncRunner) Start(context.Context) {} // Stop stops the runner. func (*SyncRunner) Stop() {} diff --git a/pkg/ratelimit/runner_test.go b/pkg/ratelimit/runner_test.go index a3eac7f238e..d4aa0825e83 100644 --- a/pkg/ratelimit/runner_test.go +++ b/pkg/ratelimit/runner_test.go @@ -25,8 +25,8 @@ import ( func TestConcurrentRunner(t *testing.T) { t.Run("RunTask", func(t *testing.T) { - runner := NewConcurrentRunner(context.TODO(), "test", NewConcurrencyLimiter(1), time.Second) - runner.Start() + runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), time.Second) + runner.Start(context.TODO()) defer runner.Stop() var wg sync.WaitGroup @@ -47,8 +47,8 @@ func TestConcurrentRunner(t *testing.T) { }) t.Run("MaxPendingDuration", func(t *testing.T) { - runner := NewConcurrentRunner(context.TODO(), "test", NewConcurrencyLimiter(1), 2*time.Millisecond) - runner.Start() + runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), 2*time.Millisecond) + runner.Start(context.TODO()) defer runner.Stop() var wg sync.WaitGroup for i := 0; i < 10; i++ { @@ -76,8 +76,8 @@ func TestConcurrentRunner(t *testing.T) { }) t.Run("DuplicatedTask", func(t *testing.T) { - runner := NewConcurrentRunner(context.TODO(), "test", NewConcurrencyLimiter(1), time.Minute) - runner.Start() + runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), time.Minute) + runner.Start(context.TODO()) defer runner.Stop() for i := 1; i < 11; i++ { regionID := uint64(i) diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 812cbb437f0..ed1080f617a 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -204,9 +204,9 @@ func NewRaftCluster(ctx context.Context, clusterID uint64, basicCluster *core.Ba etcdClient: etcdClient, BasicCluster: basicCluster, storage: storage, - heartbeatRunner: ratelimit.NewConcurrentRunner(ctx, heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - miscRunner: ratelimit.NewConcurrentRunner(ctx, miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - logRunner: ratelimit.NewConcurrentRunner(ctx, logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + heartbeatRunner: ratelimit.NewConcurrentRunner(heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + miscRunner: ratelimit.NewConcurrentRunner(miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + logRunner: ratelimit.NewConcurrentRunner(logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), } } @@ -364,9 +364,9 @@ func (c *RaftCluster) Start(s Server) error { go c.startGCTuner() c.running = true - c.heartbeatRunner.Start() - c.miscRunner.Start() - c.logRunner.Start() + c.heartbeatRunner.Start(c.ctx) + c.miscRunner.Start(c.ctx) + c.logRunner.Start(c.ctx) return nil }