Skip to content

Commit

Permalink
Merge pull request #5 from Otoru/hotfix/new-tests
Browse files Browse the repository at this point in the history
✅ (hotfix/new-tests) Add tests to hooks
  • Loading branch information
Otoru authored Sep 16, 2023
2 parents f9e148b + e3ade2d commit c2147ff
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 48 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

# Test binary, built with `go test -c`
*.test
test/*

# Output of the go coverage tool, specifically when used with LiteIDE
*.out
Expand All @@ -43,7 +44,6 @@ go.work
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

Expand All @@ -58,4 +58,5 @@ go.work
.history
.ionide


# End of https://www.toptal.com/developers/gitignore/api/go,visualstudiocode,git
68 changes: 44 additions & 24 deletions container.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package zeus
import (
"reflect"
"slices"
"sync"

"github.com/otoru/zeus/errs"
"github.com/otoru/zeus/hooks"
Expand All @@ -12,6 +13,7 @@ import (
type Container struct {
providers map[reflect.Type]reflect.Value
instances map[reflect.Type]reflect.Value
mu sync.RWMutex
hooks Hooks
}

Expand All @@ -21,9 +23,9 @@ type Container struct {
//
// c := zeus.New()
func New() *Container {
hooks := new(hooks.HooksImpl)
providers := make(map[reflect.Type]reflect.Value, 0)
instances := make(map[reflect.Type]reflect.Value, 0)
hooks := new(hooks.LifecycleHooks)
providers := make(map[reflect.Type]reflect.Value)
instances := make(map[reflect.Type]reflect.Value)

container := new(Container)
container.hooks = hooks
Expand All @@ -41,13 +43,16 @@ func (c *Container) resolve(t reflect.Type, stack []reflect.Type) (reflect.Value
return reflect.Value{}, errs.CyclicDependencyError{TypeName: t.Name()}
}

if instance, exists := c.instances[t]; exists {
c.mu.RLock()
instance, hasInstance := c.instances[t]
provider, hasProvider := c.providers[t]
c.mu.RUnlock()

if hasInstance {
return instance, nil
}

provider, ok := c.providers[t]

if !ok {
if !hasProvider {
return reflect.Value{}, errs.DependencyResolutionError{TypeName: t.Name()}
}

Expand Down Expand Up @@ -90,28 +95,36 @@ func (c *Container) resolve(t reflect.Type, stack []reflect.Type) (reflect.Value
//
// c := zeus.New()
// c.Provide(func() int { return 42 })
func (c *Container) Provide(factory interface{}) error {
factoryType := reflect.TypeOf(factory)
func (c *Container) Provide(factories ...interface{}) error {
c.mu.Lock()
defer c.mu.Unlock()

if factoryType.Kind() != reflect.Func {
return errs.NotAFunctionError{}
}
for _, factory := range factories {
factoryType := reflect.TypeOf(factory)

if numOut := factoryType.NumOut(); numOut < 1 || numOut > 2 {
return errs.InvalidFactoryReturnError{NumReturns: numOut}
}
if factoryType.Kind() != reflect.Func {
return errs.NotAFunctionError{}
}

if factoryType.NumOut() == 2 && factoryType.Out(1).Name() != "error" {
return errs.UnexpectedReturnTypeError{TypeName: factoryType.Out(1).Name()}
}
if numOut := factoryType.NumOut(); numOut < 1 || numOut > 2 {
return errs.InvalidFactoryReturnError{NumReturns: numOut}
}

if factoryType.NumOut() == 2 {
errorType := reflect.TypeOf((*error)(nil)).Elem()
if !factoryType.Out(1).Implements(errorType) {
return errs.UnexpectedReturnTypeError{TypeName: factoryType.Out(1).Name()}
}
}

serviceType := factoryType.Out(0)
serviceType := factoryType.Out(0)

if _, exists := c.providers[serviceType]; exists {
return errs.FactoryAlreadyProvidedError{TypeName: serviceType.Name()}
}
if _, exists := c.providers[serviceType]; exists {
return errs.FactoryAlreadyProvidedError{TypeName: serviceType.Name()}
}

c.providers[serviceType] = reflect.ValueOf(factory)
c.providers[serviceType] = reflect.ValueOf(factory)
}

return nil
}
Expand Down Expand Up @@ -166,6 +179,10 @@ func (c *Container) Run(fn interface{}) error {
errorSet.Add(err)
}

if !errorSet.IsEmpty() {
return errorSet.Result()
}

results := reflect.ValueOf(fn).Call(dependencies)

if fnType.NumOut() == 1 && !results[0].IsNil() {
Expand Down Expand Up @@ -196,9 +213,12 @@ func (c *Container) Run(fn interface{}) error {
// // Handle merge error
// }
func (c *Container) Merge(other *Container) error {
c.mu.Lock()
defer c.mu.Unlock()

for t, factory := range other.providers {
if existingFactory, exists := c.providers[t]; exists {
if !reflect.DeepEqual(existingFactory, factory) {
if existingFactory.Pointer() != factory.Pointer() {
return errs.FactoryAlreadyProvidedError{TypeName: t.Name()}
}
continue
Expand Down
3 changes: 3 additions & 0 deletions container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,12 @@ func TestContainer(t *testing.T) {
started = true
return nil
})

h.OnStop(func() error {
stopped = true
return nil
})

return 42
})

Expand All @@ -272,6 +274,7 @@ func TestContainer(t *testing.T) {
h.OnStart(func() error {
return errors.New("start error")
})

return 42
})

Expand Down
39 changes: 16 additions & 23 deletions hooks/hooks.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
package hooks

import (
"sync"

"github.com/otoru/zeus/errs"
)
import "sync"

// Hooks defines an interface for lifecycle events.
// It provides methods to register functions that should be executed
Expand All @@ -16,10 +12,11 @@ type Hooks interface {
Stop() error
}

// HooksImpl is the default implementation of the Hooks interface.
type HooksImpl struct {
// LifecycleHooks is the default implementation of the Hooks interface.
type LifecycleHooks struct {
onStart []func() error
onStop []func() error
mu sync.Mutex
}

// OnStart adds a function to the list of functions to be executed at the start.
Expand All @@ -29,7 +26,9 @@ type HooksImpl struct {
// fmt.Println("Starting...")
// return nil
// })
func (h *HooksImpl) OnStart(fn func() error) {
func (h *LifecycleHooks) OnStart(fn func() error) {
h.mu.Lock()
defer h.mu.Unlock()
h.onStart = append(h.onStart, fn)
}

Expand All @@ -40,14 +39,16 @@ func (h *HooksImpl) OnStart(fn func() error) {
// fmt.Println("Stopping...")
// return nil
// })
func (h *HooksImpl) OnStop(fn func() error) {
func (h *LifecycleHooks) OnStop(fn func() error) {
h.mu.Lock()
defer h.mu.Unlock()
h.onStop = append(h.onStop, fn)
}

// Start executes all the registered OnStart hooks.
// It returns the first error encountered or nil if all hooks execute successfully.
// This method is internally used by the Container's Run function.
func (h *HooksImpl) Start() error {
func (h *LifecycleHooks) Start() error {
for _, hook := range h.onStart {
if err := hook(); err != nil {
return err
Expand All @@ -59,20 +60,12 @@ func (h *HooksImpl) Start() error {
// Stop executes all the registered OnStop hooks.
// It returns the first error encountered or nil if all hooks execute successfully.
// This method is internally used by the Container's Run function.
func (h *HooksImpl) Stop() error {
var wg sync.WaitGroup
errorSet := &errs.ErrorSet{}

func (h *LifecycleHooks) Stop() error {
for _, hook := range h.onStop {
wg.Add(1)
go func(hook func() error) {
defer wg.Done()
if err := hook(); err != nil {
errorSet.Add(err)
}
}(hook)
if err := hook(); err != nil {
return err
}
}

wg.Wait()
return errorSet.Result()
return nil
}
84 changes: 84 additions & 0 deletions hooks/hooks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package hooks

import (
"errors"
"testing"

"gotest.tools/v3/assert"
)

func TestHooksImpl(t *testing.T) {
t.Run("OnStart", func(t *testing.T) {
h := &LifecycleHooks{}

t.Run("should add function to onStart slice", func(t *testing.T) {
h.OnStart(func() error {
return nil
})
assert.Equal(t, len(h.onStart), 1)
})
})

t.Run("OnStop", func(t *testing.T) {
h := &LifecycleHooks{}

t.Run("should add function to onStop slice", func(t *testing.T) {
h.OnStop(func() error {
return nil
})
assert.Equal(t, len(h.onStop), 1)
})
})

t.Run("Start", func(t *testing.T) {
t.Run("should execute all onStart hooks without error", func(t *testing.T) {
h := &LifecycleHooks{}
h.OnStart(func() error {
return nil
})
h.OnStart(func() error {
return nil
})
err := h.Start()
assert.NilError(t, err)
})

t.Run("should return error if any onStart hook fails", func(t *testing.T) {
h := &LifecycleHooks{}
h.OnStart(func() error {
return nil
})
h.OnStart(func() error {
return errors.New("start error")
})
err := h.Start()
assert.ErrorContains(t, err, "start error")
})
})

t.Run("Stop", func(t *testing.T) {
t.Run("should execute all onStop hooks without error", func(t *testing.T) {
h := &LifecycleHooks{}
h.OnStop(func() error {
return nil
})
h.OnStop(func() error {
return nil
})
err := h.Stop()
assert.NilError(t, err)
})

t.Run("should return error if any onStop hook fails", func(t *testing.T) {
h := &LifecycleHooks{}
h.OnStop(func() error {
return nil
})
h.OnStop(func() error {
return errors.New("stop error")
})
err := h.Stop()
assert.ErrorContains(t, err, "stop error")
})
})
}

0 comments on commit c2147ff

Please sign in to comment.