Skip to content

Commit

Permalink
Use sync map and counter to overcome disconnect limitation
Browse files Browse the repository at this point in the history
  • Loading branch information
denniskniep committed May 25, 2024
1 parent 3d8cc85 commit 9133f2b
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 48 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ require (
go.uber.org/zap v1.26.0 // indirect
golang.org/x/mod v0.14.0 // indirect
golang.org/x/oauth2 v0.15.0 // indirect
golang.org/x/sync v0.6.0
golang.org/x/sys v0.18.0 // indirect
golang.org/x/term v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,7 @@ golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
61 changes: 37 additions & 24 deletions internal/controller/searchattribute/searchattribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/pkg/errors"
"golang.org/x/sync/syncmap"
"k8s.io/apimachinery/pkg/types"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -71,10 +72,11 @@ func Setup(mgr ctrl.Manager, o controller.Options) error {
r := managed.NewReconciler(mgr,
resource.ManagedKind(v1alpha1.SearchAttributeGroupVersionKind),
managed.WithExternalConnectDisconnecter(&connector{
kube: mgr.GetClient(),
usage: resource.NewProviderConfigUsageTracker(mgr.GetClient(), &apisv1alpha1.ProviderConfigUsage{}),
newServiceFn: temporal.NewSearchAttributeService,
logger: o.Logger.WithValues("controller", name)}),
externalClientsByCreds: syncmap.Map{},
kube: mgr.GetClient(),
usage: resource.NewProviderConfigUsageTracker(mgr.GetClient(), &apisv1alpha1.ProviderConfigUsage{}),
newServiceFn: temporal.NewSearchAttributeService,
logger: o.Logger.WithValues("controller", name)}),
managed.WithLogger(o.Logger.WithValues("controller", name)),
managed.WithReferenceResolver(managed.NewAPISimpleReferenceResolver(mgr.GetClient())),
managed.WithPollInterval(o.PollInterval),
Expand All @@ -96,7 +98,7 @@ type connector struct {
kube client.Client
usage resource.Tracker
logger logging.Logger
externalClientsByCreds map[string]*external
externalClientsByCreds syncmap.Map
newServiceFn func(creds []byte) (temporal.SearchAttributeService, error)
}

Expand Down Expand Up @@ -137,39 +139,49 @@ func (c *connector) Connect(ctx context.Context, mg resource.Managed) (managed.E
}

credHash := hash(creds)
if c.externalClientsByCreds == nil {
c.externalClientsByCreds = make(map[string]*external)

svc, err := c.newServiceFn(creds)
if err != nil {
return nil, errors.Wrap(err, errNewClient)
}

ext := c.externalClientsByCreds[credHash]
if ext != nil {
ext := &external{service: svc, logger: c.logger, id: uuid.New().String()}
value, ok := c.externalClientsByCreds.LoadOrStore(credHash, ext)
if ok {
ext.service.Close()
ext = value.(*external)
logger.Debug("Use existing " + ext.id)
return ext, nil
} else {
svc, err := c.newServiceFn(creds)
if err != nil {
return nil, errors.Wrap(err, errNewClient)
}

ext = &external{service: svc, logger: c.logger, id: uuid.New().String()}
c.externalClientsByCreds[credHash] = ext
logger.Debug("Connected " + ext.id)
}

ext.usageCounter++
return ext, nil
}

func (c *connector) Disconnect(ctx context.Context) error {
logger := c.logger.WithValues("method", "disconnect")
logger.Debug("Start Disconnect")

for credHash, ext := range c.externalClientsByCreds {
if ext != nil && ext.service != nil {
c.externalClientsByCreds.Range(func(key, value interface{}) bool {

ext := value.(*external)
ext.usageCounter--
if ext.usageCounter < 0 {
ext.usageCounter = 0
}

if ext.usageCounter == 0 && ext.service != nil {
ext.service.Close()
c.externalClientsByCreds.LoadAndDelete(key)
logger.Debug("Disconnected " + ext.id)
} else {
logger.Debug("Keep connection " + ext.id)
}
c.externalClientsByCreds[credHash] = nil
}

// this will continue iterating
return true
})

return nil
}
Expand All @@ -179,9 +191,10 @@ func (c *connector) Disconnect(ctx context.Context) error {
type external struct {
// A 'client' used to connect to the external resource API. In practice this
// would be something like an AWS SDK client.
service temporal.SearchAttributeService
logger logging.Logger
id string
service temporal.SearchAttributeService
logger logging.Logger
id string
usageCounter int
}

func (c *external) Observe(ctx context.Context, mg resource.Managed) (managed.ExternalObservation, error) {
Expand Down
60 changes: 36 additions & 24 deletions internal/controller/temporalnamespace/temporalnamespace.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/pkg/errors"
"golang.org/x/sync/syncmap"
"k8s.io/apimachinery/pkg/types"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -56,10 +57,11 @@ func Setup(mgr ctrl.Manager, o controller.Options) error {
r := managed.NewReconciler(mgr,
resource.ManagedKind(v1alpha1.TemporalNamespaceGroupVersionKind),
managed.WithExternalConnectDisconnecter(&connector{
kube: mgr.GetClient(),
usage: resource.NewProviderConfigUsageTracker(mgr.GetClient(), &apisv1alpha1.ProviderConfigUsage{}),
newServiceFn: temporal.NewNamespaceService,
logger: o.Logger.WithValues("controller", name)}),
externalClientsByCreds: syncmap.Map{},
kube: mgr.GetClient(),
usage: resource.NewProviderConfigUsageTracker(mgr.GetClient(), &apisv1alpha1.ProviderConfigUsage{}),
newServiceFn: temporal.NewNamespaceService,
logger: o.Logger.WithValues("controller", name)}),
managed.WithLogger(o.Logger.WithValues("controller", name)),
managed.WithPollInterval(o.PollInterval),
managed.WithRecorder(event.NewAPIRecorder(mgr.GetEventRecorderFor(name))),
Expand All @@ -80,7 +82,7 @@ type connector struct {
kube client.Client
usage resource.Tracker
logger logging.Logger
externalClientsByCreds map[string]*external
externalClientsByCreds syncmap.Map
newServiceFn func(creds []byte) (temporal.NamespaceService, error)
}

Expand Down Expand Up @@ -121,39 +123,48 @@ func (c *connector) Connect(ctx context.Context, mg resource.Managed) (managed.E
}

credHash := hash(creds)
if c.externalClientsByCreds == nil {
c.externalClientsByCreds = make(map[string]*external)
svc, err := c.newServiceFn(creds)
if err != nil {
return nil, errors.Wrap(err, errNewClient)
}

ext := c.externalClientsByCreds[credHash]
if ext != nil {
ext := &external{service: svc, logger: c.logger, id: uuid.New().String()}
value, ok := c.externalClientsByCreds.LoadOrStore(credHash, ext)
if ok {
ext.service.Close()
ext = value.(*external)
logger.Debug("Use existing " + ext.id)
return ext, nil
} else {
svc, err := c.newServiceFn(creds)
if err != nil {
return nil, errors.Wrap(err, errNewClient)
}

ext = &external{service: svc, logger: c.logger, id: uuid.New().String()}
c.externalClientsByCreds[credHash] = ext
logger.Debug("Connected " + ext.id)
}

ext.usageCounter++
return ext, nil
}

func (c *connector) Disconnect(ctx context.Context) error {
logger := c.logger.WithValues("method", "disconnect")
logger.Debug("Start Disconnect")

for credHash, ext := range c.externalClientsByCreds {
if ext.service != nil {
c.externalClientsByCreds.Range(func(key, value interface{}) bool {

ext := value.(*external)
ext.usageCounter--
if ext.usageCounter < 0 {
ext.usageCounter = 0
}

if ext.usageCounter == 0 && ext.service != nil {
ext.service.Close()
c.externalClientsByCreds.LoadAndDelete(key)
logger.Debug("Disconnected " + ext.id)
} else {
logger.Debug("Keep connection " + ext.id)
}
c.externalClientsByCreds[credHash] = nil
}

// this will continue iterating
return true
})

return nil
}
Expand All @@ -163,9 +174,10 @@ func (c *connector) Disconnect(ctx context.Context) error {
type external struct {
// A 'client' used to connect to the external resource API. In practice this
// would be something like an AWS SDK client.
service temporal.NamespaceService
logger logging.Logger
id string
service temporal.NamespaceService
logger logging.Logger
id string
usageCounter int
}

func (c *external) Observe(ctx context.Context, mg resource.Managed) (managed.ExternalObservation, error) {
Expand Down

0 comments on commit 9133f2b

Please sign in to comment.