diff --git a/.gitignore b/.gitignore index 5dd3e3b..716a605 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ # Test binary, built with `go test -c` *.test +test/* # Output of the go coverage tool, specifically when used with LiteIDE *.out @@ -43,7 +44,6 @@ go.work .vscode/* !.vscode/settings.json !.vscode/tasks.json -!.vscode/launch.json !.vscode/extensions.json !.vscode/*.code-snippets @@ -58,4 +58,5 @@ go.work .history .ionide + # End of https://www.toptal.com/developers/gitignore/api/go,visualstudiocode,git \ No newline at end of file diff --git a/container.go b/container.go index d5efb46..fdc5966 100644 --- a/container.go +++ b/container.go @@ -3,6 +3,7 @@ package zeus import ( "reflect" "slices" + "sync" "github.com/otoru/zeus/errs" "github.com/otoru/zeus/hooks" @@ -12,6 +13,7 @@ import ( type Container struct { providers map[reflect.Type]reflect.Value instances map[reflect.Type]reflect.Value + mu sync.RWMutex hooks Hooks } @@ -21,9 +23,9 @@ type Container struct { // // c := zeus.New() func New() *Container { - hooks := new(hooks.HooksImpl) - providers := make(map[reflect.Type]reflect.Value, 0) - instances := make(map[reflect.Type]reflect.Value, 0) + hooks := new(hooks.LifecycleHooks) + providers := make(map[reflect.Type]reflect.Value) + instances := make(map[reflect.Type]reflect.Value) container := new(Container) container.hooks = hooks @@ -41,13 +43,16 @@ func (c *Container) resolve(t reflect.Type, stack []reflect.Type) (reflect.Value return reflect.Value{}, errs.CyclicDependencyError{TypeName: t.Name()} } - if instance, exists := c.instances[t]; exists { + c.mu.RLock() + instance, hasInstance := c.instances[t] + provider, hasProvider := c.providers[t] + c.mu.RUnlock() + + if hasInstance { return instance, nil } - provider, ok := c.providers[t] - - if !ok { + if !hasProvider { return reflect.Value{}, errs.DependencyResolutionError{TypeName: t.Name()} } @@ -90,28 +95,36 @@ func (c *Container) resolve(t reflect.Type, stack []reflect.Type) (reflect.Value // // c := zeus.New() // c.Provide(func() int { return 42 }) -func (c *Container) Provide(factory interface{}) error { - factoryType := reflect.TypeOf(factory) +func (c *Container) Provide(factories ...interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() - if factoryType.Kind() != reflect.Func { - return errs.NotAFunctionError{} - } + for _, factory := range factories { + factoryType := reflect.TypeOf(factory) - if numOut := factoryType.NumOut(); numOut < 1 || numOut > 2 { - return errs.InvalidFactoryReturnError{NumReturns: numOut} - } + if factoryType.Kind() != reflect.Func { + return errs.NotAFunctionError{} + } - if factoryType.NumOut() == 2 && factoryType.Out(1).Name() != "error" { - return errs.UnexpectedReturnTypeError{TypeName: factoryType.Out(1).Name()} - } + if numOut := factoryType.NumOut(); numOut < 1 || numOut > 2 { + return errs.InvalidFactoryReturnError{NumReturns: numOut} + } + + if factoryType.NumOut() == 2 { + errorType := reflect.TypeOf((*error)(nil)).Elem() + if !factoryType.Out(1).Implements(errorType) { + return errs.UnexpectedReturnTypeError{TypeName: factoryType.Out(1).Name()} + } + } - serviceType := factoryType.Out(0) + serviceType := factoryType.Out(0) - if _, exists := c.providers[serviceType]; exists { - return errs.FactoryAlreadyProvidedError{TypeName: serviceType.Name()} - } + if _, exists := c.providers[serviceType]; exists { + return errs.FactoryAlreadyProvidedError{TypeName: serviceType.Name()} + } - c.providers[serviceType] = reflect.ValueOf(factory) + c.providers[serviceType] = reflect.ValueOf(factory) + } return nil } @@ -166,6 +179,10 @@ func (c *Container) Run(fn interface{}) error { errorSet.Add(err) } + if !errorSet.IsEmpty() { + return errorSet.Result() + } + results := reflect.ValueOf(fn).Call(dependencies) if fnType.NumOut() == 1 && !results[0].IsNil() { @@ -196,9 +213,12 @@ func (c *Container) Run(fn interface{}) error { // // Handle merge error // } func (c *Container) Merge(other *Container) error { + c.mu.Lock() + defer c.mu.Unlock() + for t, factory := range other.providers { if existingFactory, exists := c.providers[t]; exists { - if !reflect.DeepEqual(existingFactory, factory) { + if existingFactory.Pointer() != factory.Pointer() { return errs.FactoryAlreadyProvidedError{TypeName: t.Name()} } continue diff --git a/container_test.go b/container_test.go index 234858d..0251c09 100644 --- a/container_test.go +++ b/container_test.go @@ -251,10 +251,12 @@ func TestContainer(t *testing.T) { started = true return nil }) + h.OnStop(func() error { stopped = true return nil }) + return 42 }) @@ -272,6 +274,7 @@ func TestContainer(t *testing.T) { h.OnStart(func() error { return errors.New("start error") }) + return 42 }) diff --git a/hooks/hooks.go b/hooks/hooks.go index d9adef1..c16f3e7 100644 --- a/hooks/hooks.go +++ b/hooks/hooks.go @@ -1,10 +1,6 @@ package hooks -import ( - "sync" - - "github.com/otoru/zeus/errs" -) +import "sync" // Hooks defines an interface for lifecycle events. // It provides methods to register functions that should be executed @@ -16,10 +12,11 @@ type Hooks interface { Stop() error } -// HooksImpl is the default implementation of the Hooks interface. -type HooksImpl struct { +// LifecycleHooks is the default implementation of the Hooks interface. +type LifecycleHooks struct { onStart []func() error onStop []func() error + mu sync.Mutex } // OnStart adds a function to the list of functions to be executed at the start. @@ -29,7 +26,9 @@ type HooksImpl struct { // fmt.Println("Starting...") // return nil // }) -func (h *HooksImpl) OnStart(fn func() error) { +func (h *LifecycleHooks) OnStart(fn func() error) { + h.mu.Lock() + defer h.mu.Unlock() h.onStart = append(h.onStart, fn) } @@ -40,14 +39,16 @@ func (h *HooksImpl) OnStart(fn func() error) { // fmt.Println("Stopping...") // return nil // }) -func (h *HooksImpl) OnStop(fn func() error) { +func (h *LifecycleHooks) OnStop(fn func() error) { + h.mu.Lock() + defer h.mu.Unlock() h.onStop = append(h.onStop, fn) } // Start executes all the registered OnStart hooks. // It returns the first error encountered or nil if all hooks execute successfully. // This method is internally used by the Container's Run function. -func (h *HooksImpl) Start() error { +func (h *LifecycleHooks) Start() error { for _, hook := range h.onStart { if err := hook(); err != nil { return err @@ -59,20 +60,12 @@ func (h *HooksImpl) Start() error { // Stop executes all the registered OnStop hooks. // It returns the first error encountered or nil if all hooks execute successfully. // This method is internally used by the Container's Run function. -func (h *HooksImpl) Stop() error { - var wg sync.WaitGroup - errorSet := &errs.ErrorSet{} - +func (h *LifecycleHooks) Stop() error { for _, hook := range h.onStop { - wg.Add(1) - go func(hook func() error) { - defer wg.Done() - if err := hook(); err != nil { - errorSet.Add(err) - } - }(hook) + if err := hook(); err != nil { + return err + } } - wg.Wait() - return errorSet.Result() + return nil } diff --git a/hooks/hooks_test.go b/hooks/hooks_test.go new file mode 100644 index 0000000..590d6fc --- /dev/null +++ b/hooks/hooks_test.go @@ -0,0 +1,84 @@ +package hooks + +import ( + "errors" + "testing" + + "gotest.tools/v3/assert" +) + +func TestHooksImpl(t *testing.T) { + t.Run("OnStart", func(t *testing.T) { + h := &LifecycleHooks{} + + t.Run("should add function to onStart slice", func(t *testing.T) { + h.OnStart(func() error { + return nil + }) + assert.Equal(t, len(h.onStart), 1) + }) + }) + + t.Run("OnStop", func(t *testing.T) { + h := &LifecycleHooks{} + + t.Run("should add function to onStop slice", func(t *testing.T) { + h.OnStop(func() error { + return nil + }) + assert.Equal(t, len(h.onStop), 1) + }) + }) + + t.Run("Start", func(t *testing.T) { + t.Run("should execute all onStart hooks without error", func(t *testing.T) { + h := &LifecycleHooks{} + h.OnStart(func() error { + return nil + }) + h.OnStart(func() error { + return nil + }) + err := h.Start() + assert.NilError(t, err) + }) + + t.Run("should return error if any onStart hook fails", func(t *testing.T) { + h := &LifecycleHooks{} + h.OnStart(func() error { + return nil + }) + h.OnStart(func() error { + return errors.New("start error") + }) + err := h.Start() + assert.ErrorContains(t, err, "start error") + }) + }) + + t.Run("Stop", func(t *testing.T) { + t.Run("should execute all onStop hooks without error", func(t *testing.T) { + h := &LifecycleHooks{} + h.OnStop(func() error { + return nil + }) + h.OnStop(func() error { + return nil + }) + err := h.Stop() + assert.NilError(t, err) + }) + + t.Run("should return error if any onStop hook fails", func(t *testing.T) { + h := &LifecycleHooks{} + h.OnStop(func() error { + return nil + }) + h.OnStop(func() error { + return errors.New("stop error") + }) + err := h.Stop() + assert.ErrorContains(t, err, "stop error") + }) + }) +}