diff --git a/go.mod b/go.mod index ed0e4b2..f2d6be4 100644 --- a/go.mod +++ b/go.mod @@ -77,6 +77,7 @@ require ( go.uber.org/zap v1.26.0 // indirect golang.org/x/mod v0.14.0 // indirect golang.org/x/oauth2 v0.15.0 // indirect + golang.org/x/sync v0.6.0 golang.org/x/sys v0.18.0 // indirect golang.org/x/term v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/go.sum b/go.sum index f34f1e1..92f2881 100644 --- a/go.sum +++ b/go.sum @@ -1362,6 +1362,7 @@ golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/controller/searchattribute/searchattribute.go b/internal/controller/searchattribute/searchattribute.go index 9266e00..84f6f3c 100644 --- a/internal/controller/searchattribute/searchattribute.go +++ b/internal/controller/searchattribute/searchattribute.go @@ -25,6 +25,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/pkg/errors" + "golang.org/x/sync/syncmap" "k8s.io/apimachinery/pkg/types" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -71,10 +72,11 @@ func Setup(mgr ctrl.Manager, o controller.Options) error { r := managed.NewReconciler(mgr, resource.ManagedKind(v1alpha1.SearchAttributeGroupVersionKind), managed.WithExternalConnectDisconnecter(&connector{ - kube: mgr.GetClient(), - usage: resource.NewProviderConfigUsageTracker(mgr.GetClient(), &apisv1alpha1.ProviderConfigUsage{}), - newServiceFn: temporal.NewSearchAttributeService, - logger: o.Logger.WithValues("controller", name)}), + externalClientsByCreds: syncmap.Map{}, + kube: mgr.GetClient(), + usage: resource.NewProviderConfigUsageTracker(mgr.GetClient(), &apisv1alpha1.ProviderConfigUsage{}), + newServiceFn: temporal.NewSearchAttributeService, + logger: o.Logger.WithValues("controller", name)}), managed.WithLogger(o.Logger.WithValues("controller", name)), managed.WithReferenceResolver(managed.NewAPISimpleReferenceResolver(mgr.GetClient())), managed.WithPollInterval(o.PollInterval), @@ -96,7 +98,7 @@ type connector struct { kube client.Client usage resource.Tracker logger logging.Logger - externalClientsByCreds map[string]*external + externalClientsByCreds syncmap.Map newServiceFn func(creds []byte) (temporal.SearchAttributeService, error) } @@ -137,25 +139,23 @@ func (c *connector) Connect(ctx context.Context, mg resource.Managed) (managed.E } credHash := hash(creds) - if c.externalClientsByCreds == nil { - c.externalClientsByCreds = make(map[string]*external) + + svc, err := c.newServiceFn(creds) + if err != nil { + return nil, errors.Wrap(err, errNewClient) } - ext := c.externalClientsByCreds[credHash] - if ext != nil { + ext := &external{service: svc, logger: c.logger, id: uuid.New().String()} + value, ok := c.externalClientsByCreds.LoadOrStore(credHash, ext) + if ok { + ext.service.Close() + ext = value.(*external) logger.Debug("Use existing " + ext.id) - return ext, nil } else { - svc, err := c.newServiceFn(creds) - if err != nil { - return nil, errors.Wrap(err, errNewClient) - } - - ext = &external{service: svc, logger: c.logger, id: uuid.New().String()} - c.externalClientsByCreds[credHash] = ext logger.Debug("Connected " + ext.id) } + ext.usageCounter++ return ext, nil } @@ -163,13 +163,25 @@ func (c *connector) Disconnect(ctx context.Context) error { logger := c.logger.WithValues("method", "disconnect") logger.Debug("Start Disconnect") - for credHash, ext := range c.externalClientsByCreds { - if ext != nil && ext.service != nil { + c.externalClientsByCreds.Range(func(key, value interface{}) bool { + + ext := value.(*external) + ext.usageCounter-- + if ext.usageCounter < 0 { + ext.usageCounter = 0 + } + + if ext.usageCounter == 0 && ext.service != nil { ext.service.Close() + c.externalClientsByCreds.LoadAndDelete(key) logger.Debug("Disconnected " + ext.id) + } else { + logger.Debug("Keep connection " + ext.id) } - c.externalClientsByCreds[credHash] = nil - } + + // this will continue iterating + return true + }) return nil } @@ -179,9 +191,10 @@ func (c *connector) Disconnect(ctx context.Context) error { type external struct { // A 'client' used to connect to the external resource API. In practice this // would be something like an AWS SDK client. - service temporal.SearchAttributeService - logger logging.Logger - id string + service temporal.SearchAttributeService + logger logging.Logger + id string + usageCounter int } func (c *external) Observe(ctx context.Context, mg resource.Managed) (managed.ExternalObservation, error) { diff --git a/internal/controller/temporalnamespace/temporalnamespace.go b/internal/controller/temporalnamespace/temporalnamespace.go index 27d70e9..7e4ab63 100644 --- a/internal/controller/temporalnamespace/temporalnamespace.go +++ b/internal/controller/temporalnamespace/temporalnamespace.go @@ -11,6 +11,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/pkg/errors" + "golang.org/x/sync/syncmap" "k8s.io/apimachinery/pkg/types" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -56,10 +57,11 @@ func Setup(mgr ctrl.Manager, o controller.Options) error { r := managed.NewReconciler(mgr, resource.ManagedKind(v1alpha1.TemporalNamespaceGroupVersionKind), managed.WithExternalConnectDisconnecter(&connector{ - kube: mgr.GetClient(), - usage: resource.NewProviderConfigUsageTracker(mgr.GetClient(), &apisv1alpha1.ProviderConfigUsage{}), - newServiceFn: temporal.NewNamespaceService, - logger: o.Logger.WithValues("controller", name)}), + externalClientsByCreds: syncmap.Map{}, + kube: mgr.GetClient(), + usage: resource.NewProviderConfigUsageTracker(mgr.GetClient(), &apisv1alpha1.ProviderConfigUsage{}), + newServiceFn: temporal.NewNamespaceService, + logger: o.Logger.WithValues("controller", name)}), managed.WithLogger(o.Logger.WithValues("controller", name)), managed.WithPollInterval(o.PollInterval), managed.WithRecorder(event.NewAPIRecorder(mgr.GetEventRecorderFor(name))), @@ -80,7 +82,7 @@ type connector struct { kube client.Client usage resource.Tracker logger logging.Logger - externalClientsByCreds map[string]*external + externalClientsByCreds syncmap.Map newServiceFn func(creds []byte) (temporal.NamespaceService, error) } @@ -121,25 +123,22 @@ func (c *connector) Connect(ctx context.Context, mg resource.Managed) (managed.E } credHash := hash(creds) - if c.externalClientsByCreds == nil { - c.externalClientsByCreds = make(map[string]*external) + svc, err := c.newServiceFn(creds) + if err != nil { + return nil, errors.Wrap(err, errNewClient) } - ext := c.externalClientsByCreds[credHash] - if ext != nil { + ext := &external{service: svc, logger: c.logger, id: uuid.New().String()} + value, ok := c.externalClientsByCreds.LoadOrStore(credHash, ext) + if ok { + ext.service.Close() + ext = value.(*external) logger.Debug("Use existing " + ext.id) - return ext, nil } else { - svc, err := c.newServiceFn(creds) - if err != nil { - return nil, errors.Wrap(err, errNewClient) - } - - ext = &external{service: svc, logger: c.logger, id: uuid.New().String()} - c.externalClientsByCreds[credHash] = ext logger.Debug("Connected " + ext.id) } + ext.usageCounter++ return ext, nil } @@ -147,13 +146,25 @@ func (c *connector) Disconnect(ctx context.Context) error { logger := c.logger.WithValues("method", "disconnect") logger.Debug("Start Disconnect") - for credHash, ext := range c.externalClientsByCreds { - if ext.service != nil { + c.externalClientsByCreds.Range(func(key, value interface{}) bool { + + ext := value.(*external) + ext.usageCounter-- + if ext.usageCounter < 0 { + ext.usageCounter = 0 + } + + if ext.usageCounter == 0 && ext.service != nil { ext.service.Close() + c.externalClientsByCreds.LoadAndDelete(key) logger.Debug("Disconnected " + ext.id) + } else { + logger.Debug("Keep connection " + ext.id) } - c.externalClientsByCreds[credHash] = nil - } + + // this will continue iterating + return true + }) return nil } @@ -163,9 +174,10 @@ func (c *connector) Disconnect(ctx context.Context) error { type external struct { // A 'client' used to connect to the external resource API. In practice this // would be something like an AWS SDK client. - service temporal.NamespaceService - logger logging.Logger - id string + service temporal.NamespaceService + logger logging.Logger + id string + usageCounter int } func (c *external) Observe(ctx context.Context, mg resource.Managed) (managed.ExternalObservation, error) {