Skip to content

Commit

Permalink
Merge pull request #8 from ShivamKumar2002/master
Browse files Browse the repository at this point in the history
Fix All Race Conditions Detected by Tests
  • Loading branch information
zekroTJA authored Nov 29, 2023
2 parents 78c8426 + 422970f commit 5a95c12
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
27 changes: 19 additions & 8 deletions timedmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package timedmap
import (
"reflect"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -19,7 +20,7 @@ type TimedMap struct {
cleanupTickTime time.Duration
cleanerTicker *time.Ticker
cleanerStopChan chan bool
cleanerRunning bool
cleanerRunning *uint32
}

type keyWrap struct {
Expand Down Expand Up @@ -121,6 +122,8 @@ func (tm *TimedMap) GetValue(key interface{}) interface{} {
if v == nil {
return nil
}
tm.mtx.RLock()
defer tm.mtx.RUnlock()
return v.value
}

Expand Down Expand Up @@ -190,7 +193,7 @@ func (tm *TimedMap) Size() int {
// If the cleanup loop is already running, it will be
// stopped and restarted using the new specification.
func (tm *TimedMap) StartCleanerInternal(interval time.Duration) {
if tm.cleanerRunning {
if atomic.LoadUint32(tm.cleanerRunning) != 0 {
tm.StopCleaner()
}
tm.cleanerTicker = time.NewTicker(interval)
Expand All @@ -205,7 +208,7 @@ func (tm *TimedMap) StartCleanerInternal(interval time.Duration) {
// If the cleanup loop is already running, it will be
// stopped and restarted using the new specification.
func (tm *TimedMap) StartCleanerExternal(initiator <-chan time.Time) {
if tm.cleanerRunning {
if atomic.LoadUint32(tm.cleanerRunning) != 0 {
tm.StopCleaner()
}
go tm.cleanupLoop(initiator)
Expand All @@ -216,7 +219,7 @@ func (tm *TimedMap) StartCleanerExternal(initiator <-chan time.Time) {
// where TimedMap is used that the data can be cleaned
// up correctly.
func (tm *TimedMap) StopCleaner() {
if !tm.cleanerRunning {
if atomic.LoadUint32(tm.cleanerRunning) == 0 {
return
}
tm.cleanerStopChan <- true
Expand All @@ -234,9 +237,9 @@ func (tm *TimedMap) Snapshot() map[interface{}]interface{} {
// cleanupLoop holds the loop executing the cleanup
// when initiated by tc.
func (tm *TimedMap) cleanupLoop(tc <-chan time.Time) {
tm.cleanerRunning = true
atomic.StoreUint32(tm.cleanerRunning, 1)
defer func() {
tm.cleanerRunning = false
atomic.StoreUint32(tm.cleanerRunning, 0)
}()

for {
Expand Down Expand Up @@ -285,6 +288,8 @@ func (tm *TimedMap) cleanUp() {
func (tm *TimedMap) set(key interface{}, sec int, val interface{}, expiresAfter time.Duration, cb ...callback) {
// re-use element when existent on this key
if v := tm.getRaw(key, sec); v != nil {
tm.mtx.Lock()
defer tm.mtx.Unlock()
v.value = val
v.expires = time.Now().Add(expiresAfter)
v.cbs = cb
Expand Down Expand Up @@ -315,9 +320,10 @@ func (tm *TimedMap) get(key interface{}, sec int) *element {
return nil
}

tm.mtx.Lock()
defer tm.mtx.Unlock()

if time.Now().After(v.expires) {
tm.mtx.Lock()
defer tm.mtx.Unlock()
tm.expireElement(key, sec, v)
return nil
}
Expand Down Expand Up @@ -371,7 +377,9 @@ func (tm *TimedMap) refresh(key interface{}, sec int, d time.Duration) error {
if v == nil {
return ErrKeyNotFound
}
tm.mtx.Lock()
v.expires = v.expires.Add(d)
tm.mtx.Unlock()
return nil
}

Expand All @@ -382,7 +390,9 @@ func (tm *TimedMap) setExpires(key interface{}, sec int, d time.Duration) error
if v == nil {
return ErrKeyNotFound
}
tm.mtx.Lock()
v.expires = time.Now().Add(d)
tm.mtx.Unlock()
return nil
}

Expand All @@ -408,6 +418,7 @@ func newTimedMap(
) *TimedMap {
tm := &TimedMap{
container: container,
cleanerRunning: new(uint32),
cleanerStopChan: make(chan bool),
elementPool: &sync.Pool{
New: func() interface{} {
Expand Down
15 changes: 8 additions & 7 deletions timedmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package timedmap

import (
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -19,7 +20,7 @@ func TestNew(t *testing.T) {
assert.NotNil(t, tm)
assert.EqualValues(t, 0, len(tm.container))
time.Sleep(10 * time.Millisecond)
assert.True(t, tm.cleanerRunning)
assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0)
}

func TestFromMap(t *testing.T) {
Expand Down Expand Up @@ -246,7 +247,7 @@ func TestStopCleaner(t *testing.T) {
time.Sleep(10 * time.Millisecond)
tm.StopCleaner()
time.Sleep(10 * time.Millisecond)
assert.False(t, tm.cleanerRunning)
assert.False(t, atomic.LoadUint32(tm.cleanerRunning) != 0)

assert.NotPanics(t, func() {
tm.StopCleaner()
Expand All @@ -259,7 +260,7 @@ func TestStartCleanerInternal(t *testing.T) {
tm := New(0)
time.Sleep(10 * time.Millisecond)

assert.False(t, tm.cleanerRunning)
assert.False(t, atomic.LoadUint32(tm.cleanerRunning) != 0)

// Ensure cleanup timer is not running
tm.set(1, 0, 1, 0)
Expand All @@ -268,7 +269,7 @@ func TestStartCleanerInternal(t *testing.T) {

tm.StartCleanerInternal(dCleanupTick)
time.Sleep(10 * time.Millisecond)
assert.True(t, tm.cleanerRunning)
assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0)

// Ensure cleanup timer is running
tm.set(1, 0, 1, 0)
Expand All @@ -294,7 +295,7 @@ func TestStartCleanerExternal(t *testing.T) {
tm := New(0)
time.Sleep(10 * time.Millisecond)

assert.False(t, tm.cleanerRunning)
assert.False(t, atomic.LoadUint32(tm.cleanerRunning) != 0)

// Ensure cleanup timer is not running
tm.set(1, 0, 1, 0)
Expand All @@ -305,7 +306,7 @@ func TestStartCleanerExternal(t *testing.T) {

tm.StartCleanerExternal(c)
time.Sleep(10 * time.Millisecond)
assert.True(t, tm.cleanerRunning)
assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0)

// Ensure cleanup is controlled by c
tm.set(1, 0, 1, 0)
Expand All @@ -323,7 +324,7 @@ func TestStartCleanerExternal(t *testing.T) {
tm := New(dCleanupTick)
time.Sleep(10 * time.Millisecond)

assert.True(t, tm.cleanerRunning)
assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0)
assert.NotNil(t, tm.cleanerTicker)

c := make(chan time.Time)
Expand Down

0 comments on commit 5a95c12

Please sign in to comment.