Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix All Race Conditions Detected by Tests #8

Merged
merged 3 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions timedmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package timedmap
import (
"reflect"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -19,7 +20,7 @@ type TimedMap struct {
cleanupTickTime time.Duration
cleanerTicker *time.Ticker
cleanerStopChan chan bool
cleanerRunning bool
cleanerRunning *uint32
}

type keyWrap struct {
Expand Down Expand Up @@ -121,6 +122,8 @@ func (tm *TimedMap) GetValue(key interface{}) interface{} {
if v == nil {
return nil
}
tm.mtx.RLock()
defer tm.mtx.RUnlock()
return v.value
}

Expand Down Expand Up @@ -190,7 +193,7 @@ func (tm *TimedMap) Size() int {
// If the cleanup loop is already running, it will be
// stopped and restarted using the new specification.
func (tm *TimedMap) StartCleanerInternal(interval time.Duration) {
if tm.cleanerRunning {
if atomic.LoadUint32(tm.cleanerRunning) != 0 {
tm.StopCleaner()
}
tm.cleanerTicker = time.NewTicker(interval)
Expand All @@ -205,7 +208,7 @@ func (tm *TimedMap) StartCleanerInternal(interval time.Duration) {
// If the cleanup loop is already running, it will be
// stopped and restarted using the new specification.
func (tm *TimedMap) StartCleanerExternal(initiator <-chan time.Time) {
if tm.cleanerRunning {
if atomic.LoadUint32(tm.cleanerRunning) != 0 {
tm.StopCleaner()
}
go tm.cleanupLoop(initiator)
Expand All @@ -216,7 +219,7 @@ func (tm *TimedMap) StartCleanerExternal(initiator <-chan time.Time) {
// where TimedMap is used that the data can be cleaned
// up correctly.
func (tm *TimedMap) StopCleaner() {
if !tm.cleanerRunning {
if atomic.LoadUint32(tm.cleanerRunning) == 0 {
return
}
tm.cleanerStopChan <- true
Expand All @@ -234,9 +237,9 @@ func (tm *TimedMap) Snapshot() map[interface{}]interface{} {
// cleanupLoop holds the loop executing the cleanup
// when initiated by tc.
func (tm *TimedMap) cleanupLoop(tc <-chan time.Time) {
tm.cleanerRunning = true
atomic.StoreUint32(tm.cleanerRunning, 1)
defer func() {
tm.cleanerRunning = false
atomic.StoreUint32(tm.cleanerRunning, 0)
}()

for {
Expand Down Expand Up @@ -285,6 +288,8 @@ func (tm *TimedMap) cleanUp() {
func (tm *TimedMap) set(key interface{}, sec int, val interface{}, expiresAfter time.Duration, cb ...callback) {
// re-use element when existent on this key
if v := tm.getRaw(key, sec); v != nil {
tm.mtx.Lock()
defer tm.mtx.Unlock()
v.value = val
v.expires = time.Now().Add(expiresAfter)
v.cbs = cb
Expand Down Expand Up @@ -315,9 +320,10 @@ func (tm *TimedMap) get(key interface{}, sec int) *element {
return nil
}

tm.mtx.Lock()
defer tm.mtx.Unlock()

if time.Now().After(v.expires) {
tm.mtx.Lock()
defer tm.mtx.Unlock()
tm.expireElement(key, sec, v)
return nil
}
Expand Down Expand Up @@ -371,7 +377,9 @@ func (tm *TimedMap) refresh(key interface{}, sec int, d time.Duration) error {
if v == nil {
return ErrKeyNotFound
}
tm.mtx.Lock()
v.expires = v.expires.Add(d)
tm.mtx.Unlock()
return nil
}

Expand All @@ -382,7 +390,9 @@ func (tm *TimedMap) setExpires(key interface{}, sec int, d time.Duration) error
if v == nil {
return ErrKeyNotFound
}
tm.mtx.Lock()
v.expires = time.Now().Add(d)
tm.mtx.Unlock()
return nil
}

Expand All @@ -408,6 +418,7 @@ func newTimedMap(
) *TimedMap {
tm := &TimedMap{
container: container,
cleanerRunning: new(uint32),
cleanerStopChan: make(chan bool),
elementPool: &sync.Pool{
New: func() interface{} {
Expand Down
15 changes: 8 additions & 7 deletions timedmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package timedmap

import (
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -19,7 +20,7 @@ func TestNew(t *testing.T) {
assert.NotNil(t, tm)
assert.EqualValues(t, 0, len(tm.container))
time.Sleep(10 * time.Millisecond)
assert.True(t, tm.cleanerRunning)
assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0)
}

func TestFromMap(t *testing.T) {
Expand Down Expand Up @@ -246,7 +247,7 @@ func TestStopCleaner(t *testing.T) {
time.Sleep(10 * time.Millisecond)
tm.StopCleaner()
time.Sleep(10 * time.Millisecond)
assert.False(t, tm.cleanerRunning)
assert.False(t, atomic.LoadUint32(tm.cleanerRunning) != 0)

assert.NotPanics(t, func() {
tm.StopCleaner()
Expand All @@ -259,7 +260,7 @@ func TestStartCleanerInternal(t *testing.T) {
tm := New(0)
time.Sleep(10 * time.Millisecond)

assert.False(t, tm.cleanerRunning)
assert.False(t, atomic.LoadUint32(tm.cleanerRunning) != 0)

// Ensure cleanup timer is not running
tm.set(1, 0, 1, 0)
Expand All @@ -268,7 +269,7 @@ func TestStartCleanerInternal(t *testing.T) {

tm.StartCleanerInternal(dCleanupTick)
time.Sleep(10 * time.Millisecond)
assert.True(t, tm.cleanerRunning)
assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0)

// Ensure cleanup timer is running
tm.set(1, 0, 1, 0)
Expand All @@ -294,7 +295,7 @@ func TestStartCleanerExternal(t *testing.T) {
tm := New(0)
time.Sleep(10 * time.Millisecond)

assert.False(t, tm.cleanerRunning)
assert.False(t, atomic.LoadUint32(tm.cleanerRunning) != 0)

// Ensure cleanup timer is not running
tm.set(1, 0, 1, 0)
Expand All @@ -305,7 +306,7 @@ func TestStartCleanerExternal(t *testing.T) {

tm.StartCleanerExternal(c)
time.Sleep(10 * time.Millisecond)
assert.True(t, tm.cleanerRunning)
assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0)

// Ensure cleanup is controlled by c
tm.set(1, 0, 1, 0)
Expand All @@ -323,7 +324,7 @@ func TestStartCleanerExternal(t *testing.T) {
tm := New(dCleanupTick)
time.Sleep(10 * time.Millisecond)

assert.True(t, tm.cleanerRunning)
assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0)
assert.NotNil(t, tm.cleanerTicker)

c := make(chan time.Time)
Expand Down
Loading