diff --git a/bench_test.go b/bench_test.go index 5d4744a..ed98e7e 100644 --- a/bench_test.go +++ b/bench_test.go @@ -180,6 +180,99 @@ func BenchmarkContains100Unsafe(b *testing.B) { benchContains(b, 100, NewThreadUnsafeSet[int]()) } +func benchContainsOne(b *testing.B, n int, s Set[int]) { + nums := nrand(n) + for _, v := range nums { + s.Add(v) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.ContainsOne(-1) + } +} + +func BenchmarkContainsOne1Safe(b *testing.B) { + benchContainsOne(b, 1, NewSet[int]()) +} + +func BenchmarkContainsOne1Unsafe(b *testing.B) { + benchContainsOne(b, 1, NewThreadUnsafeSet[int]()) +} + +func BenchmarkContainsOne10Safe(b *testing.B) { + benchContainsOne(b, 10, NewSet[int]()) +} + +func BenchmarkContainsOne10Unsafe(b *testing.B) { + benchContainsOne(b, 10, NewThreadUnsafeSet[int]()) +} + +func BenchmarkContainsOne100Safe(b *testing.B) { + benchContainsOne(b, 100, NewSet[int]()) +} + +func BenchmarkContainsOne100Unsafe(b *testing.B) { + benchContainsOne(b, 100, NewThreadUnsafeSet[int]()) +} + +// In this scenario, Contains argument escapes to the heap, while ContainsOne does not. +func benchContainsComparison(b *testing.B, n int, s Set[int]) { + nums := nrand(n) + for _, v := range nums { + s.Add(v) + } + + b.Run("Contains", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, v := range nums { + s.Contains(v) // 1 allocation, v is moved to the heap + } + } + }) + b.Run("Contains slice", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for i := range nums { + s.Contains(nums[i : i+1]...) // no allocations, using heap-allocated slice + } + } + }) + b.Run("ContainsOne", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, v := range nums { + s.ContainsOne(v) // no allocations, using stack-allocated v + } + } + }) +} + +func BenchmarkContainsComparison1Unsafe(b *testing.B) { + benchContainsComparison(b, 1, NewThreadUnsafeSet[int]()) +} + +func BenchmarkContainsComparison1Safe(b *testing.B) { + benchContainsComparison(b, 1, NewSet[int]()) +} + +func BenchmarkContainsComparison10Unsafe(b *testing.B) { + benchContainsComparison(b, 10, NewThreadUnsafeSet[int]()) +} + +func BenchmarkContainsComparison10Safe(b *testing.B) { + benchContainsComparison(b, 10, NewSet[int]()) +} + +func BenchmarkContainsComparison100Unsafe(b *testing.B) { + benchContainsComparison(b, 100, NewThreadUnsafeSet[int]()) +} + +func BenchmarkContainsComparison100Safe(b *testing.B) { + benchContainsComparison(b, 100, NewSet[int]()) +} + func benchEqual(b *testing.B, n int, s, t Set[int]) { nums := nrand(n) for _, v := range nums { diff --git a/set.go b/set.go index 28d2ada..292089d 100644 --- a/set.go +++ b/set.go @@ -62,6 +62,13 @@ type Set[T comparable] interface { // are all in the set. Contains(val ...T) bool + // ContainsOne returns whether the given item + // is in the set. + // + // Contains may cause the argument to escape to the heap. + // See: https://github.com/deckarep/golang-set/issues/118 + ContainsOne(val T) bool + // ContainsAny returns whether at least one of the // given items are in the set. ContainsAny(val ...T) bool diff --git a/set_test.go b/set_test.go index 23e8f09..a21153d 100644 --- a/set_test.go +++ b/set_test.go @@ -318,6 +318,54 @@ func Test_ContainsMultipleUnsafeSet(t *testing.T) { } } +func Test_ContainsOneSet(t *testing.T) { + a := NewSet[int]() + + a.Add(71) + + if !a.ContainsOne(71) { + t.Error("ContainsSet should contain 71") + } + + a.Remove(71) + + if a.ContainsOne(71) { + t.Error("ContainsSet should not contain 71") + } + + a.Add(13) + a.Add(7) + a.Add(1) + + if !(a.ContainsOne(13) && a.ContainsOne(7) && a.ContainsOne(1)) { + t.Error("ContainsSet should contain 13, 7, 1") + } +} + +func Test_ContainsOneUnsafeSet(t *testing.T) { + a := NewThreadUnsafeSet[int]() + + a.Add(71) + + if !a.ContainsOne(71) { + t.Error("ContainsSet should contain 71") + } + + a.Remove(71) + + if a.ContainsOne(71) { + t.Error("ContainsSet should not contain 71") + } + + a.Add(13) + a.Add(7) + a.Add(1) + + if !(a.ContainsOne(13) && a.ContainsOne(7) && a.ContainsOne(1)) { + t.Error("ContainsSet should contain 13, 7, 1") + } +} + func Test_ContainsAnySet(t *testing.T) { a := NewSet[int]() diff --git a/threadsafe.go b/threadsafe.go index 6086f31..ad7a834 100644 --- a/threadsafe.go +++ b/threadsafe.go @@ -66,6 +66,14 @@ func (t *threadSafeSet[T]) Contains(v ...T) bool { return ret } +func (t *threadSafeSet[T]) ContainsOne(v T) bool { + t.RLock() + ret := t.uss.ContainsOne(v) + t.RUnlock() + + return ret +} + func (t *threadSafeSet[T]) ContainsAny(v ...T) bool { t.RLock() ret := t.uss.ContainsAny(v...) diff --git a/threadsafe_test.go b/threadsafe_test.go index 399fd30..071cdb5 100644 --- a/threadsafe_test.go +++ b/threadsafe_test.go @@ -172,6 +172,27 @@ func Test_ContainsConcurrent(t *testing.T) { wg.Wait() } +func Test_ContainsOneConcurrent(t *testing.T) { + runtime.GOMAXPROCS(2) + + s := NewSet[int]() + ints := rand.Perm(N) + for _, v := range ints { + s.Add(v) + } + + var wg sync.WaitGroup + for _, v := range ints { + number := v + wg.Add(1) + go func() { + s.ContainsOne(number) + wg.Done() + }() + } + wg.Wait() +} + func Test_ContainsAnyConcurrent(t *testing.T) { runtime.GOMAXPROCS(2) diff --git a/threadunsafe.go b/threadunsafe.go index 228bada..8b17b01 100644 --- a/threadunsafe.go +++ b/threadunsafe.go @@ -93,6 +93,11 @@ func (s threadUnsafeSet[T]) Contains(v ...T) bool { return true } +func (s threadUnsafeSet[T]) ContainsOne(v T) bool { + _, ok := s[v] + return ok +} + func (s threadUnsafeSet[T]) ContainsAny(v ...T) bool { for _, val := range v { if _, ok := s[val]; ok {