diff --git a/sets/sets.go b/sets/sets.go index decf8e7..82863c1 100644 --- a/sets/sets.go +++ b/sets/sets.go @@ -72,9 +72,26 @@ func (s *Set[T]) Remove(item T) bool { // Difference returns a set containing the elements of s that are not in x. func (s *Set[T]) Difference(x *Set[T]) *Set[T] { result := New[T](max(0, s.Size()-x.Size())) - for i := range s.items { - if !x.Contains(i) { - result.items[i] = struct{}{} + for item := range s.items { + if !x.Contains(item) { + result.items[item] = struct{}{} + } + } + return result +} + +// Intersection returns a set containing the elements that are both in s and x. +func (s *Set[T]) Intersection(x *Set[T]) *Set[T] { + result := New[T](0) + // loop over the smaller set (thanks to https://github.com/deckarep/golang-set) + smaller := s + bigger := x + if smaller.Size() > bigger.Size() { + smaller, bigger = bigger, smaller + } + for item := range smaller.items { + if bigger.Contains(item) { + result.items[item] = struct{}{} } } return result diff --git a/sets/sets_test.go b/sets/sets_test.go index d9194fa..3a68e63 100644 --- a/sets/sets_test.go +++ b/sets/sets_test.go @@ -142,6 +142,65 @@ func TestDifference(t *testing.T) { } } +func TestIntersection(t *testing.T) { + type testCase struct { + name string + s *sets.Set[int] + x *sets.Set[int] + wantList []int + } + + test := func(t *testing.T, tc testCase) { + result := tc.s.Intersection(tc.x) + sorted := result.OrderedList() + + assert.DeepEqual(t, sorted, tc.wantList) + } + + testCases := []testCase{ + { + name: "both empty", + s: sets.From[int](), + x: sets.From[int](), + wantList: []int{}, + }, + { + name: "empty x returns empty", + s: sets.From(1, 2, 3), + x: sets.From[int](), + wantList: []int{}, + }, + { + name: "nothing in common returns empty", + s: sets.From(1, 2, 3), + x: sets.From(4, 5), + wantList: []int{}, + }, + { + name: "one in common", + s: sets.From(1, 2, 3), + x: sets.From(4, 2), + wantList: []int{2}, + }, + { + name: "s subset of x returns s", + s: sets.From(1, 2, 3), + x: sets.From(1, 2, 3, 12), + wantList: []int{1, 2, 3}, + }, + { + name: "x subset of s returns x", + s: sets.From(1, 2, 3, 12), + x: sets.From(1, 2, 3), + wantList: []int{1, 2, 3}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { test(t, tc) }) + } +} + func TestRemoveFound(t *testing.T) { type testCase struct { name string