diff --git a/timedmap.go b/timedmap.go index 5aca089..b74a860 100644 --- a/timedmap.go +++ b/timedmap.go @@ -3,6 +3,7 @@ package timedmap import ( "reflect" "sync" + "sync/atomic" "time" ) @@ -19,7 +20,7 @@ type TimedMap struct { cleanupTickTime time.Duration cleanerTicker *time.Ticker cleanerStopChan chan bool - cleanerRunning bool + cleanerRunning *uint32 } type keyWrap struct { @@ -121,6 +122,8 @@ func (tm *TimedMap) GetValue(key interface{}) interface{} { if v == nil { return nil } + tm.mtx.RLock() + defer tm.mtx.RUnlock() return v.value } @@ -190,7 +193,7 @@ func (tm *TimedMap) Size() int { // If the cleanup loop is already running, it will be // stopped and restarted using the new specification. func (tm *TimedMap) StartCleanerInternal(interval time.Duration) { - if tm.cleanerRunning { + if atomic.LoadUint32(tm.cleanerRunning) != 0 { tm.StopCleaner() } tm.cleanerTicker = time.NewTicker(interval) @@ -205,7 +208,7 @@ func (tm *TimedMap) StartCleanerInternal(interval time.Duration) { // If the cleanup loop is already running, it will be // stopped and restarted using the new specification. func (tm *TimedMap) StartCleanerExternal(initiator <-chan time.Time) { - if tm.cleanerRunning { + if atomic.LoadUint32(tm.cleanerRunning) != 0 { tm.StopCleaner() } go tm.cleanupLoop(initiator) @@ -216,7 +219,7 @@ func (tm *TimedMap) StartCleanerExternal(initiator <-chan time.Time) { // where TimedMap is used that the data can be cleaned // up correctly. func (tm *TimedMap) StopCleaner() { - if !tm.cleanerRunning { + if atomic.LoadUint32(tm.cleanerRunning) == 0 { return } tm.cleanerStopChan <- true @@ -234,9 +237,9 @@ func (tm *TimedMap) Snapshot() map[interface{}]interface{} { // cleanupLoop holds the loop executing the cleanup // when initiated by tc. func (tm *TimedMap) cleanupLoop(tc <-chan time.Time) { - tm.cleanerRunning = true + atomic.StoreUint32(tm.cleanerRunning, 1) defer func() { - tm.cleanerRunning = false + atomic.StoreUint32(tm.cleanerRunning, 0) }() for { @@ -285,6 +288,8 @@ func (tm *TimedMap) cleanUp() { func (tm *TimedMap) set(key interface{}, sec int, val interface{}, expiresAfter time.Duration, cb ...callback) { // re-use element when existent on this key if v := tm.getRaw(key, sec); v != nil { + tm.mtx.Lock() + defer tm.mtx.Unlock() v.value = val v.expires = time.Now().Add(expiresAfter) v.cbs = cb @@ -315,9 +320,10 @@ func (tm *TimedMap) get(key interface{}, sec int) *element { return nil } + tm.mtx.Lock() + defer tm.mtx.Unlock() + if time.Now().After(v.expires) { - tm.mtx.Lock() - defer tm.mtx.Unlock() tm.expireElement(key, sec, v) return nil } @@ -371,7 +377,9 @@ func (tm *TimedMap) refresh(key interface{}, sec int, d time.Duration) error { if v == nil { return ErrKeyNotFound } + tm.mtx.Lock() v.expires = v.expires.Add(d) + tm.mtx.Unlock() return nil } @@ -382,7 +390,9 @@ func (tm *TimedMap) setExpires(key interface{}, sec int, d time.Duration) error if v == nil { return ErrKeyNotFound } + tm.mtx.Lock() v.expires = time.Now().Add(d) + tm.mtx.Unlock() return nil } @@ -408,6 +418,7 @@ func newTimedMap( ) *TimedMap { tm := &TimedMap{ container: container, + cleanerRunning: new(uint32), cleanerStopChan: make(chan bool), elementPool: &sync.Pool{ New: func() interface{} { diff --git a/timedmap_test.go b/timedmap_test.go index 6aa3b10..2c63d6e 100644 --- a/timedmap_test.go +++ b/timedmap_test.go @@ -2,6 +2,7 @@ package timedmap import ( "sync" + "sync/atomic" "testing" "time" @@ -19,7 +20,7 @@ func TestNew(t *testing.T) { assert.NotNil(t, tm) assert.EqualValues(t, 0, len(tm.container)) time.Sleep(10 * time.Millisecond) - assert.True(t, tm.cleanerRunning) + assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0) } func TestFromMap(t *testing.T) { @@ -246,7 +247,7 @@ func TestStopCleaner(t *testing.T) { time.Sleep(10 * time.Millisecond) tm.StopCleaner() time.Sleep(10 * time.Millisecond) - assert.False(t, tm.cleanerRunning) + assert.False(t, atomic.LoadUint32(tm.cleanerRunning) != 0) assert.NotPanics(t, func() { tm.StopCleaner() @@ -259,7 +260,7 @@ func TestStartCleanerInternal(t *testing.T) { tm := New(0) time.Sleep(10 * time.Millisecond) - assert.False(t, tm.cleanerRunning) + assert.False(t, atomic.LoadUint32(tm.cleanerRunning) != 0) // Ensure cleanup timer is not running tm.set(1, 0, 1, 0) @@ -268,7 +269,7 @@ func TestStartCleanerInternal(t *testing.T) { tm.StartCleanerInternal(dCleanupTick) time.Sleep(10 * time.Millisecond) - assert.True(t, tm.cleanerRunning) + assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0) // Ensure cleanup timer is running tm.set(1, 0, 1, 0) @@ -294,7 +295,7 @@ func TestStartCleanerExternal(t *testing.T) { tm := New(0) time.Sleep(10 * time.Millisecond) - assert.False(t, tm.cleanerRunning) + assert.False(t, atomic.LoadUint32(tm.cleanerRunning) != 0) // Ensure cleanup timer is not running tm.set(1, 0, 1, 0) @@ -305,7 +306,7 @@ func TestStartCleanerExternal(t *testing.T) { tm.StartCleanerExternal(c) time.Sleep(10 * time.Millisecond) - assert.True(t, tm.cleanerRunning) + assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0) // Ensure cleanup is controlled by c tm.set(1, 0, 1, 0) @@ -323,7 +324,7 @@ func TestStartCleanerExternal(t *testing.T) { tm := New(dCleanupTick) time.Sleep(10 * time.Millisecond) - assert.True(t, tm.cleanerRunning) + assert.True(t, atomic.LoadUint32(tm.cleanerRunning) != 0) assert.NotNil(t, tm.cleanerTicker) c := make(chan time.Time)