diff --git a/threadsafe.go b/threadsafe.go
index ad7a834..93f20c8 100644
--- a/threadsafe.go
+++ b/threadsafe.go
@@ -29,7 +29,7 @@ import "sync"
 
 type threadSafeSet[T comparable] struct {
 	sync.RWMutex
-	uss threadUnsafeSet[T]
+	uss *threadUnsafeSet[T]
 }
 
 func newThreadSafeSet[T comparable]() *threadSafeSet[T] {
@@ -123,7 +123,7 @@ func (t *threadSafeSet[T]) Union(other Set[T]) Set[T] {
 	t.RLock()
 	o.RLock()
 
-	unsafeUnion := t.uss.Union(o.uss).(threadUnsafeSet[T])
+	unsafeUnion := t.uss.Union(o.uss).(*threadUnsafeSet[T])
 	ret := &threadSafeSet[T]{uss: unsafeUnion}
 	t.RUnlock()
 	o.RUnlock()
@@ -136,7 +136,7 @@ func (t *threadSafeSet[T]) Intersect(other Set[T]) Set[T] {
 	t.RLock()
 	o.RLock()
 
-	unsafeIntersection := t.uss.Intersect(o.uss).(threadUnsafeSet[T])
+	unsafeIntersection := t.uss.Intersect(o.uss).(*threadUnsafeSet[T])
 	ret := &threadSafeSet[T]{uss: unsafeIntersection}
 	t.RUnlock()
 	o.RUnlock()
@@ -149,7 +149,7 @@ func (t *threadSafeSet[T]) Difference(other Set[T]) Set[T] {
 	t.RLock()
 	o.RLock()
 
-	unsafeDifference := t.uss.Difference(o.uss).(threadUnsafeSet[T])
+	unsafeDifference := t.uss.Difference(o.uss).(*threadUnsafeSet[T])
 	ret := &threadSafeSet[T]{uss: unsafeDifference}
 	t.RUnlock()
 	o.RUnlock()
@@ -162,7 +162,7 @@ func (t *threadSafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
 	t.RLock()
 	o.RLock()
 
-	unsafeDifference := t.uss.SymmetricDifference(o.uss).(threadUnsafeSet[T])
+	unsafeDifference := t.uss.SymmetricDifference(o.uss).(*threadUnsafeSet[T])
 	ret := &threadSafeSet[T]{uss: unsafeDifference}
 	t.RUnlock()
 	o.RUnlock()
@@ -177,7 +177,7 @@ func (t *threadSafeSet[T]) Clear() {
 
 func (t *threadSafeSet[T]) Remove(v T) {
 	t.Lock()
-	delete(t.uss, v)
+	delete(*t.uss, v)
 	t.Unlock()
 }
 
@@ -190,12 +190,12 @@ func (t *threadSafeSet[T]) RemoveAll(i ...T) {
 func (t *threadSafeSet[T]) Cardinality() int {
 	t.RLock()
 	defer t.RUnlock()
-	return len(t.uss)
+	return len(*t.uss)
 }
 
 func (t *threadSafeSet[T]) Each(cb func(T) bool) {
 	t.RLock()
-	for elem := range t.uss {
+	for elem := range *t.uss {
 		if cb(elem) {
 			break
 		}
@@ -208,7 +208,7 @@ func (t *threadSafeSet[T]) Iter() <-chan T {
 	go func() {
 		t.RLock()
 
-		for elem := range t.uss {
+		for elem := range *t.uss {
 			ch <- elem
 		}
 		close(ch)
@@ -224,7 +224,7 @@ func (t *threadSafeSet[T]) Iterator() *Iterator[T] {
 	go func() {
 		t.RLock()
 	L:
-		for elem := range t.uss {
+		for elem := range *t.uss {
 			select {
 			case <-stopCh:
 				break L
@@ -253,7 +253,7 @@ func (t *threadSafeSet[T]) Equal(other Set[T]) bool {
 func (t *threadSafeSet[T]) Clone() Set[T] {
 	t.RLock()
 
-	unsafeClone := t.uss.Clone().(threadUnsafeSet[T])
+	unsafeClone := t.uss.Clone().(*threadUnsafeSet[T])
 	ret := &threadSafeSet[T]{uss: unsafeClone}
 	t.RUnlock()
 	return ret
@@ -275,7 +275,7 @@ func (t *threadSafeSet[T]) Pop() (T, bool) {
 func (t *threadSafeSet[T]) ToSlice() []T {
 	keys := make([]T, 0, t.Cardinality())
 	t.RLock()
-	for elem := range t.uss {
+	for elem := range *t.uss {
 		keys = append(keys, elem)
 	}
 	t.RUnlock()
diff --git a/threadsafe_test.go b/threadsafe_test.go
index 071cdb5..ca998c9 100644
--- a/threadsafe_test.go
+++ b/threadsafe_test.go
@@ -584,17 +584,7 @@ func Test_UnmarshalJSON(t *testing.T) {
 		t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
 	}
 }
-func TestThreadUnsafeSet_UnmarshalJSON(t *testing.T) {
-	expected := NewThreadUnsafeSet[int64](1, 2, 3)
-	actual := NewThreadUnsafeSet[int64]()
-	err := actual.UnmarshalJSON([]byte(`[1, 2, 3]`))
-	if err != nil {
-		t.Errorf("Error should be nil: %v", err)
-	}
-	if !expected.Equal(actual) {
-		t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
-	}
-}
+
 func Test_MarshalJSON(t *testing.T) {
 	expected := NewSet(
 		[]string{
diff --git a/threadunsafe.go b/threadunsafe.go
index 8b17b01..7e3243b 100644
--- a/threadunsafe.go
+++ b/threadunsafe.go
@@ -34,14 +34,16 @@ import (
 type threadUnsafeSet[T comparable] map[T]struct{}
 
 // Assert concrete type:threadUnsafeSet adheres to Set interface.
-var _ Set[string] = (threadUnsafeSet[string])(nil)
+var _ Set[string] = (*threadUnsafeSet[string])(nil)
 
-func newThreadUnsafeSet[T comparable]() threadUnsafeSet[T] {
-	return make(threadUnsafeSet[T])
+func newThreadUnsafeSet[T comparable]() *threadUnsafeSet[T] {
+	t := make(threadUnsafeSet[T])
+	return &t
 }
 
-func newThreadUnsafeSetWithSize[T comparable](cardinality int) threadUnsafeSet[T] {
-	return make(threadUnsafeSet[T], cardinality)
+func newThreadUnsafeSetWithSize[T comparable](cardinality int) *threadUnsafeSet[T] {
+	t := make(threadUnsafeSet[T], cardinality)
+	return &t
 }
 
 func (s threadUnsafeSet[T]) Add(v T) bool {
@@ -50,57 +52,57 @@ func (s threadUnsafeSet[T]) Add(v T) bool {
 	return prevLen != len(s)
 }
 
-func (s threadUnsafeSet[T]) Append(v ...T) int {
-	prevLen := len(s)
+func (s *threadUnsafeSet[T]) Append(v ...T) int {
+	prevLen := len(*s)
 	for _, val := range v {
-		(s)[val] = struct{}{}
+		(*s)[val] = struct{}{}
 	}
-	return len(s) - prevLen
+	return len(*s) - prevLen
 }
 
 // private version of Add which doesn't return a value
-func (s threadUnsafeSet[T]) add(v T) {
-	s[v] = struct{}{}
+func (s *threadUnsafeSet[T]) add(v T) {
+	(*s)[v] = struct{}{}
 }
 
-func (s threadUnsafeSet[T]) Cardinality() int {
-	return len(s)
+func (s *threadUnsafeSet[T]) Cardinality() int {
+	return len(*s)
 }
 
-func (s threadUnsafeSet[T]) Clear() {
+func (s *threadUnsafeSet[T]) Clear() {
 	// Constructions like this are optimised by compiler, and replaced by
 	// mapclear() function, defined in
 	// https://github.com/golang/go/blob/29bbca5c2c1ad41b2a9747890d183b6dd3a4ace4/src/runtime/map.go#L993)
-	for key := range s {
-		delete(s, key)
+	for key := range *s {
+		delete(*s, key)
 	}
 }
 
-func (s threadUnsafeSet[T]) Clone() Set[T] {
+func (s *threadUnsafeSet[T]) Clone() Set[T] {
 	clonedSet := newThreadUnsafeSetWithSize[T](s.Cardinality())
-	for elem := range s {
+	for elem := range *s {
 		clonedSet.add(elem)
 	}
 	return clonedSet
 }
 
-func (s threadUnsafeSet[T]) Contains(v ...T) bool {
+func (s *threadUnsafeSet[T]) Contains(v ...T) bool {
 	for _, val := range v {
-		if _, ok := s[val]; !ok {
+		if _, ok := (*s)[val]; !ok {
 			return false
 		}
 	}
 	return true
 }
 
-func (s threadUnsafeSet[T]) ContainsOne(v T) bool {
-	_, ok := s[v]
+func (s *threadUnsafeSet[T]) ContainsOne(v T) bool {
+	_, ok := (*s)[v]
 	return ok
 }
 
-func (s threadUnsafeSet[T]) ContainsAny(v ...T) bool {
+func (s *threadUnsafeSet[T]) ContainsAny(v ...T) bool {
 	for _, val := range v {
-		if _, ok := s[val]; ok {
+		if _, ok := (*s)[val]; ok {
 			return true
 		}
 	}
@@ -108,16 +110,16 @@ func (s threadUnsafeSet[T]) ContainsAny(v ...T) bool {
 }
 
 // private version of Contains for a single element v
-func (s threadUnsafeSet[T]) contains(v T) (ok bool) {
-	_, ok = s[v]
+func (s *threadUnsafeSet[T]) contains(v T) (ok bool) {
+	_, ok = (*s)[v]
 	return ok
 }
 
-func (s threadUnsafeSet[T]) Difference(other Set[T]) Set[T] {
-	o := other.(threadUnsafeSet[T])
+func (s *threadUnsafeSet[T]) Difference(other Set[T]) Set[T] {
+	o := other.(*threadUnsafeSet[T])
 
 	diff := newThreadUnsafeSet[T]()
-	for elem := range s {
+	for elem := range *s {
 		if !o.contains(elem) {
 			diff.add(elem)
 		}
@@ -125,21 +127,21 @@ func (s threadUnsafeSet[T]) Difference(other Set[T]) Set[T] {
 	return diff
 }
 
-func (s threadUnsafeSet[T]) Each(cb func(T) bool) {
-	for elem := range s {
+func (s *threadUnsafeSet[T]) Each(cb func(T) bool) {
+	for elem := range *s {
 		if cb(elem) {
 			break
 		}
 	}
 }
 
-func (s threadUnsafeSet[T]) Equal(other Set[T]) bool {
-	o := other.(threadUnsafeSet[T])
+func (s *threadUnsafeSet[T]) Equal(other Set[T]) bool {
+	o := other.(*threadUnsafeSet[T])
 
 	if s.Cardinality() != other.Cardinality() {
 		return false
 	}
-	for elem := range s {
+	for elem := range *s {
 		if !o.contains(elem) {
 			return false
 		}
@@ -147,19 +149,19 @@ func (s threadUnsafeSet[T]) Equal(other Set[T]) bool {
 	return true
 }
 
-func (s threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] {
-	o := other.(threadUnsafeSet[T])
+func (s *threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] {
+	o := other.(*threadUnsafeSet[T])
 
 	intersection := newThreadUnsafeSet[T]()
 	// loop over smaller set
 	if s.Cardinality() < other.Cardinality() {
-		for elem := range s {
+		for elem := range *s {
 			if o.contains(elem) {
 				intersection.add(elem)
 			}
 		}
 	} else {
-		for elem := range o {
+		for elem := range *o {
 			if s.contains(elem) {
 				intersection.add(elem)
 			}
@@ -168,24 +170,24 @@ func (s threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] {
 	return intersection
 }
 
-func (s threadUnsafeSet[T]) IsEmpty() bool {
+func (s *threadUnsafeSet[T]) IsEmpty() bool {
 	return s.Cardinality() == 0
 }
 
-func (s threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool {
+func (s *threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool {
 	return s.Cardinality() < other.Cardinality() && s.IsSubset(other)
 }
 
-func (s threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool {
+func (s *threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool {
 	return s.Cardinality() > other.Cardinality() && s.IsSuperset(other)
 }
 
-func (s threadUnsafeSet[T]) IsSubset(other Set[T]) bool {
-	o := other.(threadUnsafeSet[T])
+func (s *threadUnsafeSet[T]) IsSubset(other Set[T]) bool {
+	o := other.(*threadUnsafeSet[T])
 	if s.Cardinality() > other.Cardinality() {
 		return false
 	}
-	for elem := range s {
+	for elem := range *s {
 		if !o.contains(elem) {
 			return false
 		}
@@ -193,14 +195,14 @@ func (s threadUnsafeSet[T]) IsSubset(other Set[T]) bool {
 	return true
 }
 
-func (s threadUnsafeSet[T]) IsSuperset(other Set[T]) bool {
+func (s *threadUnsafeSet[T]) IsSuperset(other Set[T]) bool {
 	return other.IsSubset(s)
 }
 
-func (s threadUnsafeSet[T]) Iter() <-chan T {
+func (s *threadUnsafeSet[T]) Iter() <-chan T {
 	ch := make(chan T)
 	go func() {
-		for elem := range s {
+		for elem := range *s {
 			ch <- elem
 		}
 		close(ch)
@@ -209,12 +211,12 @@ func (s threadUnsafeSet[T]) Iter() <-chan T {
 	return ch
 }
 
-func (s threadUnsafeSet[T]) Iterator() *Iterator[T] {
+func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] {
 	iterator, ch, stopCh := newIterator[T]()
 
 	go func() {
 	L:
-		for elem := range s {
+		for elem := range *s {
 			select {
 			case <-stopCh:
 				break L
@@ -229,9 +231,9 @@ func (s threadUnsafeSet[T]) Iterator() *Iterator[T] {
 
 // Pop returns a popped item in case set is not empty, or nil-value of T
 // if set is already empty
-func (s threadUnsafeSet[T]) Pop() (v T, ok bool) {
-	for item := range s {
-		delete(s, item)
+func (s *threadUnsafeSet[T]) Pop() (v T, ok bool) {
+	for item := range *s {
+		delete(*s, item)
 		return item, true
 	}
 	return v, false
@@ -256,16 +258,16 @@ func (s threadUnsafeSet[T]) String() string {
 	return fmt.Sprintf("Set{%s}", strings.Join(items, ", "))
 }
 
-func (s threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
-	o := other.(threadUnsafeSet[T])
+func (s *threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
+	o := other.(*threadUnsafeSet[T])
 
 	sd := newThreadUnsafeSet[T]()
-	for elem := range s {
+	for elem := range *s {
 		if !o.contains(elem) {
 			sd.add(elem)
 		}
 	}
-	for elem := range o {
+	for elem := range *o {
 		if !s.contains(elem) {
 			sd.add(elem)
 		}
@@ -283,7 +285,7 @@ func (s threadUnsafeSet[T]) ToSlice() []T {
 }
 
 func (s threadUnsafeSet[T]) Union(other Set[T]) Set[T] {
-	o := other.(threadUnsafeSet[T])
+	o := other.(*threadUnsafeSet[T])
 
 	n := s.Cardinality()
 	if o.Cardinality() > n {
@@ -294,10 +296,10 @@ func (s threadUnsafeSet[T]) Union(other Set[T]) Set[T] {
 	for elem := range s {
 		unionedSet.add(elem)
 	}
-	for elem := range o {
+	for elem := range *o {
 		unionedSet.add(elem)
 	}
-	return unionedSet
+	return &unionedSet
 }
 
 // MarshalJSON creates a JSON array from the set, it marshals all elements
@@ -318,7 +320,7 @@ func (s threadUnsafeSet[T]) MarshalJSON() ([]byte, error) {
 
 // UnmarshalJSON recreates a set from a JSON array, it only decodes
 // primitive types. Numbers are decoded as json.Number.
-func (s threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
+func (s *threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
 	var i []T
 	err := json.Unmarshal(b, &i)
 	if err != nil {
diff --git a/threadunsafe_test.go b/threadunsafe_test.go
new file mode 100644
index 0000000..c670305
--- /dev/null
+++ b/threadunsafe_test.go
@@ -0,0 +1,164 @@
+/*
+Open Source Initiative OSI - The MIT License (MIT):Licensing
+
+The MIT License (MIT)
+Copyright (c) 2013 - 2022 Ralph Caraveo (deckarep@gmail.com)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+*/
+
+package mapset
+
+import (
+	"encoding/json"
+	"testing"
+)
+
+func TestThreadUnsafeSet_MarshalJSON(t *testing.T) {
+	expected := NewThreadUnsafeSet[int64](1, 2, 3)
+	actual := newThreadUnsafeSet[int64]()
+
+	// test Marshal from Set method
+	b, err := expected.MarshalJSON()
+	if err != nil {
+		t.Errorf("Error should be nil: %v", err)
+	}
+
+	err = json.Unmarshal(b, actual)
+	if err != nil {
+		t.Errorf("Error should be nil: %v", err)
+	}
+
+	if !expected.Equal(actual) {
+		t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
+	}
+
+	// test Marshal from json package
+	b, err = json.Marshal(expected)
+	if err != nil {
+		t.Errorf("Error should be nil: %v", err)
+	}
+
+	err = json.Unmarshal(b, actual)
+	if err != nil {
+		t.Errorf("Error should be nil: %v", err)
+	}
+
+	if !expected.Equal(actual) {
+		t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
+	}
+}
+
+func TestThreadUnsafeSet_UnmarshalJSON(t *testing.T) {
+	expected := NewThreadUnsafeSet[int64](1, 2, 3)
+	actual := NewThreadUnsafeSet[int64]()
+
+	// test Unmarshal from Set method
+	err := actual.UnmarshalJSON([]byte(`[1, 2, 3]`))
+	if err != nil {
+		t.Errorf("Error should be nil: %v", err)
+	}
+	if !expected.Equal(actual) {
+		t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
+	}
+
+	// test Unmarshal from json package
+	actual = NewThreadUnsafeSet[int64]()
+	err = json.Unmarshal([]byte(`[1, 2, 3]`), actual)
+	if err != nil {
+		t.Errorf("Error should be nil: %v", err)
+	}
+	if !expected.Equal(actual) {
+		t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
+	}
+}
+
+func TestThreadUnsafeSet_MarshalJSON_Struct(t *testing.T) {
+	expected := &testStruct{"test", NewThreadUnsafeSet("a")}
+
+	b, err := json.Marshal(&testStruct{"test", NewThreadUnsafeSet("a")})
+	if err != nil {
+		t.Errorf("Error should be nil: %v", err)
+	}
+
+	actual := &testStruct{}
+	err = json.Unmarshal(b, actual)
+	if err != nil {
+		t.Errorf("Error should be nil: %v", err)
+	}
+
+	if !expected.Set.Equal(actual.Set) {
+		t.Errorf("Expected no difference, got: %v", expected.Set.Difference(actual.Set))
+	}
+}
+func TestThreadUnsafeSet_UnmarshalJSON_Struct(t *testing.T) {
+	expected := &testStruct{"test", NewThreadUnsafeSet("a", "b", "c")}
+	actual := &testStruct{}
+
+	err := json.Unmarshal([]byte(`{"other":"test", "set":["a", "b", "c"]}`), actual)
+	if err != nil {
+		t.Errorf("Error should be nil: %v", err)
+	}
+	if !expected.Set.Equal(actual.Set) {
+		t.Errorf("Expected no difference, got: %v", expected.Set.Difference(actual.Set))
+	}
+
+	expectedComplex := NewThreadUnsafeSet(struct{ Val string }{Val: "a"}, struct{ Val string }{Val: "b"})
+	actualComplex := NewThreadUnsafeSet[struct{ Val string }]()
+
+	err = actualComplex.UnmarshalJSON([]byte(`[{"Val": "a"}, {"Val": "b"}]`))
+	if err != nil {
+		t.Errorf("Error should be nil: %v", err)
+	}
+	if !expectedComplex.Equal(actualComplex) {
+		t.Errorf("Expected no difference, got: %v", expectedComplex.Difference(actualComplex))
+	}
+
+	actualComplex = NewThreadUnsafeSet[struct{ Val string }]()
+	err = json.Unmarshal([]byte(`[{"Val": "a"}, {"Val": "b"}]`), actualComplex)
+	if err != nil {
+		t.Errorf("Error should be nil: %v", err)
+	}
+	if !expectedComplex.Equal(actualComplex) {
+		t.Errorf("Expected no difference, got: %v", expectedComplex.Difference(actualComplex))
+	}
+}
+
+// this serves as an example of how to correctly unmarshal a struct with a Set property
+type testStruct struct {
+	Other string
+	Set   Set[string]
+}
+
+func (t *testStruct) UnmarshalJSON(b []byte) error {
+	raw := struct {
+		Other string
+		Set   []string
+	}{}
+
+	err := json.Unmarshal(b, &raw)
+	if err != nil {
+		return err
+	}
+
+	t.Other = raw.Other
+	t.Set = NewThreadUnsafeSet(raw.Set...)
+
+	return nil
+}