Skip to content

Commit 2768bfb

Browse files
authored
Fix data race-related test failures (PR #171)
1 parent 885fd25 commit 2768bfb

File tree

6 files changed

+48
-28
lines changed

6 files changed

+48
-28
lines changed

.github/workflows/test.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ jobs:
77
go-version: [1.x.x]
88
platform: [ubuntu-latest]
99
runs-on: ${{ matrix.platform }}
10+
name: Go ${{ matrix.go-version }} (${{ matrix.platform }})
1011
steps:
12+
- name: Checkout code
13+
uses: actions/checkout@v2
1114
- name: Install Go
12-
uses: actions/setup-go@v1
15+
uses: actions/setup-go@v2
1316
with:
1417
go-version: ${{ matrix.go-version }}
15-
- name: Checkout code
16-
uses: actions/checkout@v1
1718
- name: Test
18-
run: go test --timeout 360s -v ./...
19+
run: go test -race -timeout 360s -v ./...

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@ _testmain.go
2424
*.test
2525
*.prof
2626

27-
# ignore vendor - they are only needed for tests.
27+
# ignore vendor - they are only needed for tests.
2828
vendor/
2929

3030
# ignore bazel things that are local only
3131
bazel-*
32+
33+
# Visual Studio Code
34+
.vscode/*

client.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"strconv"
99
"strings"
1010
"sync"
11+
"sync/atomic"
1112
"time"
1213

1314
"github.com/lytics/grid/v3/codec"
@@ -426,7 +427,7 @@ func (c *Client) broadcast(ctx context.Context, cancel context.CancelFunc, g *Gr
426427
receivers := g.Members()
427428

428429
var broadcastErr error
429-
successes := 0
430+
var successes int32
430431
mu := new(sync.Mutex)
431432
wg := new(sync.WaitGroup)
432433
for _, rec := range receivers {
@@ -442,7 +443,7 @@ func (c *Client) broadcast(ctx context.Context, cancel context.CancelFunc, g *Gr
442443
// if this request was successful and the group is configured to Fastest,
443444
// then cancel the context so other requests are terminated
444445
cancel()
445-
successes++
446+
atomic.AddInt32(&successes, 1)
446447
}
447448

448449
mu.Lock()
@@ -457,7 +458,7 @@ func (c *Client) broadcast(ctx context.Context, cancel context.CancelFunc, g *Gr
457458

458459
// if the group is configured to Fastest, and we had at least one successful
459460
// request, then don't return an error
460-
if g.fastest && broadcastErr != nil && successes > 0 {
461+
if g.fastest && broadcastErr != nil && atomic.LoadInt32(&successes) > 0 {
461462
broadcastErr = nil
462463
}
463464
return res, broadcastErr

context_test.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,32 @@ package grid
33
import (
44
"context"
55
"net"
6+
"sync"
67
"testing"
78
"time"
89

910
"github.com/lytics/grid/v3/testetcd"
1011
)
1112

1213
type contextActor struct {
14+
mu sync.RWMutex
1315
started chan bool
1416
ctx context.Context
1517
}
1618

1719
func (a *contextActor) Act(c context.Context) {
20+
a.mu.Lock()
1821
a.ctx = c
22+
a.mu.Unlock()
1923
a.started <- true
2024
}
2125

26+
func (a *contextActor) Context() context.Context {
27+
a.mu.RLock()
28+
defer a.mu.RUnlock()
29+
return a.ctx
30+
}
31+
2232
func TestContextError(t *testing.T) {
2333
// Create a context that is not valid to use
2434
// with the grid context methods. The context
@@ -88,23 +98,23 @@ func TestValidContext(t *testing.T) {
8898
case <-a.started:
8999
server.Stop()
90100

91-
id, err := ContextActorID(a.ctx)
101+
id, err := ContextActorID(a.Context())
92102
if err != nil {
93103
t.Fatal(err)
94104
}
95105
if id == "" {
96106
t.Fatal("expected non-zero value")
97107
}
98108

99-
name, err := ContextActorName(a.ctx)
109+
name, err := ContextActorName(a.Context())
100110
if err != nil {
101111
t.Fatal(err)
102112
}
103113
if name == "" {
104114
t.Fatal("expected non-zero value")
105115
}
106116

107-
namespace, err := ContextActorNamespace(a.ctx)
117+
namespace, err := ContextActorNamespace(a.Context())
108118
if err != nil {
109119
t.Fatal(err)
110120
}

registry/registry.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func (we *WatchEvent) String() string {
9292

9393
// Registry for discovery.
9494
type Registry struct {
95-
mu sync.Mutex
95+
mu sync.RWMutex
9696
done chan bool
9797
exited chan bool
9898
kv etcdv3.KV
@@ -202,6 +202,7 @@ func (rr *Registry) Start(addr net.Addr) error {
202202

203203
// Ensure that we're the owner of the address by taking an etcd lock
204204
tctx, cancel := context.WithTimeout(context.TODO(), rr.LeaseDuration*2) // retry until Lease is up...
205+
defer cancel()
205206
err = rr.waitForAddress(tctx, address)
206207
if err != nil {
207208
return err
@@ -263,17 +264,24 @@ func (rr *Registry) Start(addr net.Addr) error {
263264

264265
// Address of this registry in the format of <ip>:<port>
265266
func (rr *Registry) Address() string {
267+
rr.mu.RLock()
268+
defer rr.mu.RUnlock()
266269
return rr.address
267270
}
268271

269272
// Registry name, which is a human readable all ASCII
270273
// transformation of the network address.
271274
func (rr *Registry) Registry() string {
275+
rr.mu.RLock()
276+
defer rr.mu.RUnlock()
272277
return rr.name
273278
}
274279

275280
// Stop Registry.
276281
func (rr *Registry) Stop() error {
282+
rr.mu.Lock()
283+
defer rr.mu.Unlock()
284+
277285
if rr.leaseID < 0 {
278286
return nil
279287
}
@@ -296,8 +304,8 @@ func (rr *Registry) Stop() error {
296304

297305
// Watch a prefix in the registry.
298306
func (rr *Registry) Watch(c context.Context, prefix string) ([]*Registration, <-chan *WatchEvent, error) {
299-
rr.mu.Lock()
300-
defer rr.mu.Unlock()
307+
rr.mu.RLock()
308+
defer rr.mu.RUnlock()
301309

302310
getRes, err := rr.kv.Get(c, prefix, etcdv3.WithPrefix())
303311
if err != nil {
@@ -390,8 +398,8 @@ func (rr *Registry) Watch(c context.Context, prefix string) ([]*Registration, <-
390398

391399
// FindRegistrations associated with the prefix.
392400
func (rr *Registry) FindRegistrations(c context.Context, prefix string) ([]*Registration, error) {
393-
rr.mu.Lock()
394-
defer rr.mu.Unlock()
401+
rr.mu.RLock()
402+
defer rr.mu.RUnlock()
395403

396404
getRes, err := rr.kv.Get(c, prefix, etcdv3.WithPrefix())
397405
if err != nil {
@@ -411,8 +419,8 @@ func (rr *Registry) FindRegistrations(c context.Context, prefix string) ([]*Regi
411419

412420
// FindRegistration associated with the given key.
413421
func (rr *Registry) FindRegistration(c context.Context, key string) (*Registration, error) {
414-
rr.mu.Lock()
415-
defer rr.mu.Unlock()
422+
rr.mu.RLock()
423+
defer rr.mu.RUnlock()
416424

417425
getRes, err := rr.kv.Get(c, key, etcdv3.WithLimit(1))
418426
if err != nil {
@@ -434,8 +442,8 @@ func (rr *Registry) FindRegistration(c context.Context, key string) (*Registrati
434442
// Hence, registration can be used for mutual-exclusion.
435443
func (rr *Registry) Register(c context.Context, key string, annotations ...string) error {
436444
sort.Strings(annotations)
437-
rr.mu.Lock()
438-
defer rr.mu.Unlock()
445+
rr.mu.RLock()
446+
defer rr.mu.RUnlock()
439447

440448
if rr.leaseID < 0 {
441449
return ErrNotStarted
@@ -475,8 +483,8 @@ func (rr *Registry) Register(c context.Context, key string, annotations ...strin
475483

476484
// Deregister under the given key.
477485
func (rr *Registry) Deregister(c context.Context, key string) error {
478-
rr.mu.Lock()
479-
defer rr.mu.Unlock()
486+
rr.mu.RLock()
487+
defer rr.mu.RUnlock()
480488

481489
if rr.leaseID < 0 {
482490
return ErrNotStarted

registry/registry_test.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,21 +156,18 @@ func TestWaitForLeaseThatDoesExpires(t *testing.T) {
156156
}
157157
r1.LeaseDuration = 10 * time.Second
158158

159-
_, err = kv.Put(context.Background(), registryLockKey(address), "")
160-
if err != nil {
159+
if _, err := kv.Put(context.Background(), registryLockKey(address), ""); err != nil {
161160
t.Fatal(err)
162161
}
163162
time.AfterFunc(5*time.Second, func() {
164163
// cleanup lock so that the registry can startup.
165-
_, err = kv.Delete(context.Background(), registryLockKey(address))
166-
if err != nil {
164+
if _, err := kv.Delete(context.Background(), registryLockKey(address)); err != nil {
167165
t.Fatal(err)
168166
}
169167
})
170168

171169
st := time.Now()
172-
err = r1.Start(addr)
173-
if err != nil {
170+
if err := r1.Start(addr); err != nil {
174171
t.Fatalf("unexpected error: err: %v", err)
175172
}
176173
// ensure that we waited 10 seconds...

0 commit comments

Comments
 (0)