From 332b8fcbefabe946776938fa433d0d28bbe2fa2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Krasnoborski?= Date: Wed, 21 Jun 2023 10:03:58 +0200 Subject: [PATCH 1/3] Stop copying entity in GetByIP We're never mutate entities directly in the map, so that's ok. Also, added a type wrapper so it's less likely someone tries mutating the returned entity. --- pkg/discovery/entities/entities-service.go | 18 ++++- pkg/discovery/entities/entities.go | 79 ++++++++++++++++--- pkg/discovery/entities/entities_test.go | 57 ++++++------- .../service-getter/service-getter.go | 6 +- 4 files changed, 110 insertions(+), 50 deletions(-) diff --git a/pkg/discovery/entities/entities-service.go b/pkg/discovery/entities/entities-service.go index caf1fab5f4..e6d5247cc4 100644 --- a/pkg/discovery/entities/entities-service.go +++ b/pkg/discovery/entities/entities-service.go @@ -4,6 +4,8 @@ import ( "context" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" cmdv1 "github.com/fluxninja/aperture/v2/api/gen/proto/go/aperture/cmd/v1" @@ -45,12 +47,20 @@ func (c *EntitiesService) GetEntities(ctx context.Context, _ *emptypb.Empty) (*e // GetEntityByIPAddress returns an entity by IP address. func (c *EntitiesService) GetEntityByIPAddress(ctx context.Context, req *entitiesv1.GetEntityByIPAddressRequest) (*entitiesv1.Entity, error) { - return c.entityCache.GetByIP(req.GetIpAddress()) + e, err := c.entityCache.GetByIP(req.GetIpAddress()) + if err != nil { + return nil, status.Error(codes.NotFound, err.Error()) + } + return e.Copy(), nil } // GetEntityByName returns an entity by name. func (c *EntitiesService) GetEntityByName(ctx context.Context, req *entitiesv1.GetEntityByNameRequest) (*entitiesv1.Entity, error) { - return c.entityCache.GetByName(req.GetName()) + e, err := c.entityCache.GetByName(req.GetName()) + if err != nil { + return nil, status.Error(codes.NotFound, err.Error()) + } + return e.Copy(), nil } // ListDiscoveryEntities lists currently discovered entities by IP address. @@ -71,7 +81,7 @@ func (c *EntitiesService) ListDiscoveryEntity(ctx context.Context, req *cmdv1.Li return nil, err } return &cmdv1.ListDiscoveryEntityAgentResponse{ - Entity: entity, + Entity: entity.Copy(), }, nil case *cmdv1.ListDiscoveryEntityRequest_Name: entity, err := c.entityCache.GetByName(req.GetName()) @@ -79,7 +89,7 @@ func (c *EntitiesService) ListDiscoveryEntity(ctx context.Context, req *cmdv1.Li return nil, err } return &cmdv1.ListDiscoveryEntityAgentResponse{ - Entity: entity, + Entity: entity.Copy(), }, nil default: return nil, nil diff --git a/pkg/discovery/entities/entities.go b/pkg/discovery/entities/entities.go index 64575bf883..a485e8c8bf 100644 --- a/pkg/discovery/entities/entities.go +++ b/pkg/discovery/entities/entities.go @@ -25,6 +25,51 @@ func Module() fx.Option { ) } +// Entity is an immutable wrapper over *entitiesv1.Entity. +type Entity struct { + immutableEntity *entitiesv1.Entity +} + +// NewEntity creates a new immutable entity from the copy of given entity. +func NewEntity(entity *entitiesv1.Entity) Entity { + return Entity{immutableEntity: proto.Clone(entity).(*entitiesv1.Entity)} +} + +// NewEntity creates a new immutable entity, assuming given entity is immutable. +func NewEntityFromImmutable(entity *entitiesv1.Entity) Entity { + return Entity{immutableEntity: entity} +} + +// Copy returns a mutable copy of the entity. +func (e Entity) Copy() *entitiesv1.Entity { + return proto.Clone(e.immutableEntity).(*entitiesv1.Entity) +} + +// Borrow returns the inner *entitiesv1.Entity. +// +// The returned struct must not be mutated. +func (e Entity) Borrow() *entitiesv1.Entity { return e.immutableEntity } + +// UID. +func (e Entity) UID() string { return e.immutableEntity.Uid } + +// IP address. +func (e Entity) IPAddress() string { return e.immutableEntity.IpAddress } + +// Name. +func (e Entity) Name() string { return e.immutableEntity.Name } + +// Namespace +func (e Entity) Namespace() string { return e.immutableEntity.Namespace } + +// NodeName. +func (e Entity) NodeName() string { return e.immutableEntity.NodeName } + +// Services returns list of services the entity belongs to. +// +// The returned slice must not be modified. +func (e Entity) Services() []string { return e.immutableEntity.Services } + // Entities maps IP addresses and Entity names to entities. type Entities struct { sync.RWMutex @@ -136,52 +181,59 @@ func NewEntities() *Entities { // Put maps given IP address and name to the entity it currently represents. func (c *Entities) Put(entity *entitiesv1.Entity) { + c.PutFast(NewEntity(entity)) +} + +// PutFast maps given IP address and name to the entity it currently represents. +func (c *Entities) PutFast(entity Entity) { c.Lock() defer c.Unlock() - entityIP := entity.IpAddress + entityIP := entity.IPAddress() if entityIP != "" { - c.entities.EntitiesByIpAddress.Entities[entityIP] = entity + // FIXME: would be nice to store Entity directly in the map, but that + // would require removing the reusal of proto-generated structs as + // containers. + c.entities.EntitiesByIpAddress.Entities[entityIP] = entity.immutableEntity } - entityName := entity.Name + entityName := entity.Name() if entityName != "" { - c.entities.EntitiesByName.Entities[entityName] = entity + c.entities.EntitiesByName.Entities[entityName] = entity.immutableEntity } } // GetByIP retrieves entity with a given IP address. -func (c *Entities) GetByIP(entityIP string) (*entitiesv1.Entity, error) { +func (c *Entities) GetByIP(entityIP string) (Entity, error) { c.RLock() defer c.RUnlock() if len(c.entities.EntitiesByIpAddress.Entities) == 0 { - return nil, errNoEntities + return Entity{}, errNoEntities } v, ok := c.entities.EntitiesByIpAddress.Entities[entityIP] if !ok { - return nil, errNotFound + return Entity{}, errNotFound } - return proto.Clone(v).(*entitiesv1.Entity), nil + return NewEntityFromImmutable(v), nil } // GetByName retrieves entity with a given name. -func (c *Entities) GetByName(entityName string) (*entitiesv1.Entity, error) { +func (c *Entities) GetByName(entityName string) (Entity, error) { c.RLock() defer c.RUnlock() if len(c.entities.EntitiesByName.Entities) == 0 { - return nil, errNoEntities + return Entity{}, errNoEntities } v, ok := c.entities.EntitiesByName.Entities[entityName] if !ok { - return nil, errNotFound + return Entity{}, errNotFound } - - return proto.Clone(v).(*entitiesv1.Entity), nil + return NewEntityFromImmutable(v), nil } var ( @@ -227,5 +279,6 @@ func (c *Entities) GetEntities() *entitiesv1.Entities { c.RLock() defer c.RUnlock() + // FIXME: This Clone could be avoided, as we store immutable entities. return proto.Clone(c.entities).(*entitiesv1.Entities) } diff --git a/pkg/discovery/entities/entities_test.go b/pkg/discovery/entities/entities_test.go index 0821aabbff..df8de771a8 100644 --- a/pkg/discovery/entities/entities_test.go +++ b/pkg/discovery/entities/entities_test.go @@ -20,31 +20,29 @@ var _ = Describe("Cache", func() { ip := "1.2.3.4" name := "entity_1234" entity := testEntity("foo", ip, name, nil) - ec.Put(entity) + ec.PutFast(entity) actual, err := ec.GetByIP(ip) Expect(err).NotTo(HaveOccurred()) Expect(actual).To(Equal(entity)) }) - It("returns nil when no entity found", func() { + It("returns err when no entity found", func() { ip := "1.2.3.4" - actual, err := ec.GetByIP(ip) - Expect(err).To(Not(BeNil())) - Expect(actual).To(BeNil()) + _, err := ec.GetByIP(ip) + Expect(err).To(HaveOccurred()) }) It("removes an entity properly", func() { ip := "1.2.3.4" name := "entity_1234" entity := testEntity("foo", ip, name, nil) - ec.Put(entity) + ec.PutFast(entity) - removed := ec.Remove(entity) + removed := ec.Remove(entity.Borrow()) Expect(removed).To(BeTrue()) - found, err := ec.GetByIP(ip) - Expect(err).To(Not(BeNil())) - Expect(found).To(BeNil()) + _, err := ec.GetByIP(ip) + Expect(err).To(HaveOccurred()) }) It("returns false if trying to remove a nonexistent entity", func() { @@ -53,10 +51,10 @@ var _ = Describe("Cache", func() { name := "entity_1234" otherName := "other_entity_4321" entity := testEntity("foo", ip, name, nil) - ec.Put(entity) + ec.PutFast(entity) otherEntity := testEntity("foo2", otherIP, otherName, nil) - removed := ec.Remove(otherEntity) + removed := ec.Remove(otherEntity.Borrow()) Expect(removed).To(BeFalse()) found, err := ec.GetByIP(ip) @@ -70,31 +68,29 @@ var _ = Describe("Cache", func() { uid := "foo" name := "some_name" entity := testEntity(uid, "", name, nil) - ec.Put(entity) + ec.PutFast(entity) actual, err := ec.GetByName(name) Expect(err).NotTo(HaveOccurred()) Expect(actual).To(Equal(entity)) }) - It("returns nil when no entity found", func() { + It("returns err when no entity found", func() { name := "some_name" - actual, err := ec.GetByName(name) - Expect(err).To(Not(BeNil())) - Expect(actual).To(BeNil()) + _, err := ec.GetByName(name) + Expect(err).To(HaveOccurred()) }) It("removes an entity properly", func() { uid := "bar" name := "some_name" entity := testEntity(uid, "", name, nil) - ec.Put(entity) + ec.PutFast(entity) - removed := ec.Remove(entity) + removed := ec.Remove(entity.Borrow()) Expect(removed).To(BeTrue()) - found, err := ec.GetByName(name) - Expect(err).To(Not(BeNil())) - Expect(found).To(BeNil()) + _, err := ec.GetByName(name) + Expect(err).To(HaveOccurred()) }) It("returns false if trying to remove a nonexistent entity", func() { @@ -103,10 +99,10 @@ var _ = Describe("Cache", func() { otherUid := "baz" otherName := "another_name" entity := testEntity(uid, "1.1.1.1", name, nil) - ec.Put(entity) + ec.PutFast(entity) otherEntity := testEntity(otherUid, "1.1.1.2", otherName, nil) - removed := ec.Remove(otherEntity) + removed := ec.Remove(otherEntity.Borrow()) Expect(removed).To(BeFalse()) found, err := ec.GetByName(name) @@ -118,19 +114,18 @@ var _ = Describe("Cache", func() { It("clears all entities from the map", func() { ip := "1.2.3.4" entity := testEntity("foo", "", "some_name", nil) - ec.Put(entity) + ec.PutFast(entity) ec.Clear() - found, err := ec.GetByIP(ip) - Expect(err).To(Not(BeNil())) - Expect(found).To(BeNil()) + _, err := ec.GetByIP(ip) + Expect(err).To(HaveOccurred()) }) }) -func testEntity(uid, ipAddress, name string, services []string) *entitiesv1.Entity { - return &entitiesv1.Entity{ +func testEntity(uid, ipAddress, name string, services []string) entities.Entity { + return entities.NewEntity(&entitiesv1.Entity{ Uid: uid, IpAddress: ipAddress, Name: name, Services: services, - } + }) } diff --git a/pkg/policies/flowcontrol/service-getter/service-getter.go b/pkg/policies/flowcontrol/service-getter/service-getter.go index cc3208bb0f..746f683754 100644 --- a/pkg/policies/flowcontrol/service-getter/service-getter.go +++ b/pkg/policies/flowcontrol/service-getter/service-getter.go @@ -13,6 +13,8 @@ import ( ) // ServiceGetter can be used to query services based on client context. +// +// Caller should not modify slices returned from methods of ServiceGetter. type ServiceGetter interface { ServicesFromContext(ctx context.Context) []string ServicesFromSocketAddress(addr *corev3.SocketAddress) []string @@ -69,7 +71,7 @@ func (sg *ecServiceGetter) servicesFromContext(ctx context.Context) (svcs []stri return nil, false } - return entity.Services, true + return entity.Services(), true } // ServicesFromSocketAddress returns list of services associated with IP extracted from SocketAddress. @@ -91,7 +93,7 @@ func (sg *ecServiceGetter) ServicesFromAddress(addr string) []string { } return nil } - return entity.Services + return entity.Services() } var noEntitySampler = log.NewRatelimitingSampler() From 2ff7dfa89761398ba1c7aad6d25c20e8b7db4e62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Krasnoborski?= Date: Wed, 21 Jun 2023 11:54:28 +0200 Subject: [PATCH 2/3] stop using proto-generated structs in entities datastructure --- pkg/discovery/entities/entities-service.go | 8 +- pkg/discovery/entities/entities.go | 147 ++++++++++----------- 2 files changed, 76 insertions(+), 79 deletions(-) diff --git a/pkg/discovery/entities/entities-service.go b/pkg/discovery/entities/entities-service.go index e6d5247cc4..207b9f65c3 100644 --- a/pkg/discovery/entities/entities-service.go +++ b/pkg/discovery/entities/entities-service.go @@ -51,7 +51,7 @@ func (c *EntitiesService) GetEntityByIPAddress(ctx context.Context, req *entitie if err != nil { return nil, status.Error(codes.NotFound, err.Error()) } - return e.Copy(), nil + return e.Clone(), nil } // GetEntityByName returns an entity by name. @@ -60,7 +60,7 @@ func (c *EntitiesService) GetEntityByName(ctx context.Context, req *entitiesv1.G if err != nil { return nil, status.Error(codes.NotFound, err.Error()) } - return e.Copy(), nil + return e.Clone(), nil } // ListDiscoveryEntities lists currently discovered entities by IP address. @@ -81,7 +81,7 @@ func (c *EntitiesService) ListDiscoveryEntity(ctx context.Context, req *cmdv1.Li return nil, err } return &cmdv1.ListDiscoveryEntityAgentResponse{ - Entity: entity.Copy(), + Entity: entity.Clone(), }, nil case *cmdv1.ListDiscoveryEntityRequest_Name: entity, err := c.entityCache.GetByName(req.GetName()) @@ -89,7 +89,7 @@ func (c *EntitiesService) ListDiscoveryEntity(ctx context.Context, req *cmdv1.Li return nil, err } return &cmdv1.ListDiscoveryEntityAgentResponse{ - Entity: entity.Copy(), + Entity: entity.Clone(), }, nil default: return nil, nil diff --git a/pkg/discovery/entities/entities.go b/pkg/discovery/entities/entities.go index a485e8c8bf..e07379c685 100644 --- a/pkg/discovery/entities/entities.go +++ b/pkg/discovery/entities/entities.go @@ -35,13 +35,13 @@ func NewEntity(entity *entitiesv1.Entity) Entity { return Entity{immutableEntity: proto.Clone(entity).(*entitiesv1.Entity)} } -// NewEntity creates a new immutable entity, assuming given entity is immutable. +// NewEntityFromImmutable creates a new immutable entity, assuming given entity is immutable. func NewEntityFromImmutable(entity *entitiesv1.Entity) Entity { return Entity{immutableEntity: entity} } -// Copy returns a mutable copy of the entity. -func (e Entity) Copy() *entitiesv1.Entity { +// Clone returns a mutable copy of the entity. +func (e Entity) Clone() *entitiesv1.Entity { return proto.Clone(e.immutableEntity).(*entitiesv1.Entity) } @@ -50,19 +50,19 @@ func (e Entity) Copy() *entitiesv1.Entity { // The returned struct must not be mutated. func (e Entity) Borrow() *entitiesv1.Entity { return e.immutableEntity } -// UID. +// UID returns the entity's UID. func (e Entity) UID() string { return e.immutableEntity.Uid } -// IP address. +// IPAddress returns the entity's IP address. func (e Entity) IPAddress() string { return e.immutableEntity.IpAddress } -// Name. +// Name returns the entity's name. func (e Entity) Name() string { return e.immutableEntity.Name } -// Namespace +// Namespace returns the entity's namespace. func (e Entity) Namespace() string { return e.immutableEntity.Namespace } -// NodeName. +// NodeName returns the entity's node name. func (e Entity) NodeName() string { return e.immutableEntity.NodeName } // Services returns list of services the entity belongs to. @@ -73,7 +73,8 @@ func (e Entity) Services() []string { return e.immutableEntity.Services } // Entities maps IP addresses and Entity names to entities. type Entities struct { sync.RWMutex - entities *entitiesv1.Entities + byIP map[string]Entity + byName map[string]Entity } // EntityTrackers allows to register a service discovery for entity cache @@ -144,96 +145,83 @@ func provideEntities(in FxIn) (*Entities, *EntityTrackers, error) { return entityCache, &EntityTrackers{trackers: in.EntityTrackers}, nil } -func (c *Entities) processUpdate(event notifiers.Event, unmarshaller config.Unmarshaller) { +func (e *Entities) processUpdate(event notifiers.Event, unmarshaller config.Unmarshaller) { log.Trace().Str("event", event.String()).Msg("Updating entity") - entity := &entitiesv1.Entity{} - if err := unmarshaller.UnmarshalKey("", entity); err != nil { + entityProto := &entitiesv1.Entity{} + if err := unmarshaller.UnmarshalKey("", entityProto); err != nil { log.Error().Err(err).Msg("Failed to unmarshal entity") return } - ip := entity.IpAddress - name := entity.Name + entity := NewEntityFromImmutable(entityProto) + ip := entity.IPAddress() + name := entity.Name() switch event.Type { case notifiers.Write: - log.Trace().Str("entity", entity.Uid).Str("ip", ip).Str("name", name).Msg("new entity") - c.Put(entity) + log.Trace().Str("entity", entity.UID()).Str("ip", ip).Str("name", name).Msg("new entity") + e.PutFast(entity) case notifiers.Remove: - log.Trace().Str("entity", entity.Uid).Str("ip", ip).Str("name", name).Msg("removing entity") - c.Remove(entity) + log.Trace().Str("entity", entity.UID()).Str("ip", ip).Str("name", name).Msg("removing entity") + e.Remove(entity.Borrow()) } } // NewEntities creates a new, empty Entities. func NewEntities() *Entities { - entities := &entitiesv1.Entities{ - EntitiesByIpAddress: &entitiesv1.Entities_Entities{ - Entities: make(map[string]*entitiesv1.Entity), - }, - EntitiesByName: &entitiesv1.Entities_Entities{ - Entities: make(map[string]*entitiesv1.Entity), - }, - } return &Entities{ - entities: entities, + byIP: make(map[string]Entity), + byName: make(map[string]Entity), } } // Put maps given IP address and name to the entity it currently represents. -func (c *Entities) Put(entity *entitiesv1.Entity) { - c.PutFast(NewEntity(entity)) +func (e *Entities) Put(entity *entitiesv1.Entity) { + e.PutFast(NewEntity(entity)) } // PutFast maps given IP address and name to the entity it currently represents. -func (c *Entities) PutFast(entity Entity) { - c.Lock() - defer c.Unlock() +func (e *Entities) PutFast(entity Entity) { + e.Lock() + defer e.Unlock() entityIP := entity.IPAddress() if entityIP != "" { // FIXME: would be nice to store Entity directly in the map, but that // would require removing the reusal of proto-generated structs as // containers. - c.entities.EntitiesByIpAddress.Entities[entityIP] = entity.immutableEntity + e.byIP[entityIP] = entity } entityName := entity.Name() if entityName != "" { - c.entities.EntitiesByName.Entities[entityName] = entity.immutableEntity + e.byName[entityName] = entity } } // GetByIP retrieves entity with a given IP address. -func (c *Entities) GetByIP(entityIP string) (Entity, error) { - c.RLock() - defer c.RUnlock() - - if len(c.entities.EntitiesByIpAddress.Entities) == 0 { - return Entity{}, errNoEntities - } - - v, ok := c.entities.EntitiesByIpAddress.Entities[entityIP] - if !ok { - return Entity{}, errNotFound - } - - return NewEntityFromImmutable(v), nil +func (e *Entities) GetByIP(entityIP string) (Entity, error) { + return e.getFromMap(e.byIP, entityIP) } // GetByName retrieves entity with a given name. -func (c *Entities) GetByName(entityName string) (Entity, error) { - c.RLock() - defer c.RUnlock() +func (e *Entities) GetByName(entityName string) (Entity, error) { + return e.getFromMap(e.byName, entityName) +} + +func (e *Entities) getFromMap(m map[string]Entity, k string) (Entity, error) { + e.RLock() + defer e.RUnlock() - if len(c.entities.EntitiesByName.Entities) == 0 { + if len(m) == 0 { return Entity{}, errNoEntities } - v, ok := c.entities.EntitiesByName.Entities[entityName] + v, ok := m[k] if !ok { return Entity{}, errNotFound } - return NewEntityFromImmutable(v), nil + + return v, nil } var ( @@ -242,43 +230,52 @@ var ( ) // Clear removes all entities from the cache. -func (c *Entities) Clear() { - c.RLock() - defer c.RUnlock() - c.entities.EntitiesByIpAddress = &entitiesv1.Entities_Entities{ - Entities: make(map[string]*entitiesv1.Entity), - } - c.entities.EntitiesByName = &entitiesv1.Entities_Entities{ - Entities: make(map[string]*entitiesv1.Entity), - } +func (e *Entities) Clear() { + e.Lock() + defer e.Unlock() + e.byIP = make(map[string]Entity) + e.byName = make(map[string]Entity) } // Remove removes entity from the cache and returns `true` if any of IP address // or name mapping exists. // If no such entity was found, returns `false`. -func (c *Entities) Remove(entity *entitiesv1.Entity) bool { - c.Lock() - defer c.Unlock() +func (e *Entities) Remove(entity *entitiesv1.Entity) bool { + e.Lock() + defer e.Unlock() entityIP := entity.IpAddress - _, okByIP := c.entities.EntitiesByIpAddress.Entities[entityIP] + _, okByIP := e.byIP[entityIP] if okByIP { - delete(c.entities.EntitiesByIpAddress.Entities, entityIP) + delete(e.byIP, entityIP) } entityName := entity.Name - _, okByName := c.entities.EntitiesByName.Entities[entityName] + _, okByName := e.byName[entityName] if okByName { - delete(c.entities.EntitiesByName.Entities, entityName) + delete(e.byName, entityName) } return okByIP || okByName } // GetEntities returns *entitiesv1.EntitiyCache entities. -func (c *Entities) GetEntities() *entitiesv1.Entities { - c.RLock() - defer c.RUnlock() +func (e *Entities) GetEntities() *entitiesv1.Entities { + e.RLock() + defer e.RUnlock() + + // Not sure what caller will do with the result, let's clone + return &entitiesv1.Entities{ + EntitiesByIpAddress: cloneEntitiesMap(e.byIP), + EntitiesByName: cloneEntitiesMap(e.byName), + } +} - // FIXME: This Clone could be avoided, as we store immutable entities. - return proto.Clone(c.entities).(*entitiesv1.Entities) +func cloneEntitiesMap(m map[string]Entity) *entitiesv1.Entities_Entities { + clones := make(map[string]*entitiesv1.Entity, len(m)) + for k, entity := range m { + clones[k] = entity.Clone() + } + return &entitiesv1.Entities_Entities{ + Entities: clones, + } } From 8ee3b5ffa1c2941f53aec017974501a5b0be34a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Krasnoborski?= Date: Wed, 21 Jun 2023 12:30:19 +0200 Subject: [PATCH 3/3] Use entities.Entity as argument to entities.Remove --- .../fluxninja/heartbeats/services_test.go | 8 +++--- pkg/discovery/entities/entities.go | 27 +++++++++---------- pkg/discovery/entities/entities_test.go | 22 +++++++-------- .../service-getter/service-getter_test.go | 25 +++++++++++------ .../flowcontrol/service/check/check_test.go | 2 +- .../service/checkhttp/checkhttp_test.go | 2 +- .../flowcontrol/service/envoy/authz_test.go | 2 +- 7 files changed, 47 insertions(+), 41 deletions(-) diff --git a/extensions/fluxninja/heartbeats/services_test.go b/extensions/fluxninja/heartbeats/services_test.go index ef2e8cc0cf..f7475b3882 100644 --- a/extensions/fluxninja/heartbeats/services_test.go +++ b/extensions/fluxninja/heartbeats/services_test.go @@ -18,8 +18,8 @@ var _ = Describe("Services", func() { Context("Services", func() { It("reads same service from two entities", func() { - ec.Put(testEntity("1", "1.1.1.1", "some_name", []string{"baz"})) - ec.Put(testEntity("2", "1.1.1.2", "some_name", []string{"baz"})) + ec.PutForTest(testEntity("1", "1.1.1.1", "some_name", []string{"baz"})) + ec.PutForTest(testEntity("2", "1.1.1.2", "some_name", []string{"baz"})) sl := populateServicesList(ec) Expect(sl.Services).To(HaveLen(1)) Expect(sl.Services).To(ContainElement(&heartbeatv1.Service{ @@ -32,7 +32,7 @@ var _ = Describe("Services", func() { ip := "1.1.1.1" serviceNames := []string{"baz1", "baz2"} name := "entity_1234" - ec.Put(testEntity("1", ip, name, serviceNames)) + ec.PutForTest(testEntity("1", ip, name, serviceNames)) sl := populateServicesList(ec) Expect(sl.Services).To(HaveLen(2)) Expect(sl.Services).To(ContainElement(&heartbeatv1.Service{ @@ -49,7 +49,7 @@ var _ = Describe("Services", func() { ip := "1.1.1.1" serviceNames := []string{"baz"} name := "entity_1234" - ec.Put(testEntity("1", ip, name, serviceNames)) + ec.PutForTest(testEntity("1", ip, name, serviceNames)) ec.Clear() sl := populateServicesList(ec) Expect(sl.Services).To(HaveLen(0)) diff --git a/pkg/discovery/entities/entities.go b/pkg/discovery/entities/entities.go index e07379c685..472ae012ce 100644 --- a/pkg/discovery/entities/entities.go +++ b/pkg/discovery/entities/entities.go @@ -36,6 +36,8 @@ func NewEntity(entity *entitiesv1.Entity) Entity { } // NewEntityFromImmutable creates a new immutable entity, assuming given entity is immutable. +// +// This allows avoiding a copy compared to NewEntity. func NewEntityFromImmutable(entity *entitiesv1.Entity) Entity { return Entity{immutableEntity: entity} } @@ -45,11 +47,6 @@ func (e Entity) Clone() *entitiesv1.Entity { return proto.Clone(e.immutableEntity).(*entitiesv1.Entity) } -// Borrow returns the inner *entitiesv1.Entity. -// -// The returned struct must not be mutated. -func (e Entity) Borrow() *entitiesv1.Entity { return e.immutableEntity } - // UID returns the entity's UID. func (e Entity) UID() string { return e.immutableEntity.Uid } @@ -159,10 +156,10 @@ func (e *Entities) processUpdate(event notifiers.Event, unmarshaller config.Unma switch event.Type { case notifiers.Write: log.Trace().Str("entity", entity.UID()).Str("ip", ip).Str("name", name).Msg("new entity") - e.PutFast(entity) + e.Put(entity) case notifiers.Remove: log.Trace().Str("entity", entity.UID()).Str("ip", ip).Str("name", name).Msg("removing entity") - e.Remove(entity.Borrow()) + e.Remove(entity) } } @@ -174,13 +171,13 @@ func NewEntities() *Entities { } } -// Put maps given IP address and name to the entity it currently represents. -func (e *Entities) Put(entity *entitiesv1.Entity) { - e.PutFast(NewEntity(entity)) +// PutForTest maps given IP address and name to the entity it currently represents. +func (e *Entities) PutForTest(entity *entitiesv1.Entity) { + e.Put(NewEntity(entity)) } -// PutFast maps given IP address and name to the entity it currently represents. -func (e *Entities) PutFast(entity Entity) { +// Put maps given IP address and name to the entity it currently represents. +func (e *Entities) Put(entity Entity) { e.Lock() defer e.Unlock() @@ -240,17 +237,17 @@ func (e *Entities) Clear() { // Remove removes entity from the cache and returns `true` if any of IP address // or name mapping exists. // If no such entity was found, returns `false`. -func (e *Entities) Remove(entity *entitiesv1.Entity) bool { +func (e *Entities) Remove(entity Entity) bool { e.Lock() defer e.Unlock() - entityIP := entity.IpAddress + entityIP := entity.IPAddress() _, okByIP := e.byIP[entityIP] if okByIP { delete(e.byIP, entityIP) } - entityName := entity.Name + entityName := entity.Name() _, okByName := e.byName[entityName] if okByName { delete(e.byName, entityName) diff --git a/pkg/discovery/entities/entities_test.go b/pkg/discovery/entities/entities_test.go index df8de771a8..6acd960c6c 100644 --- a/pkg/discovery/entities/entities_test.go +++ b/pkg/discovery/entities/entities_test.go @@ -20,7 +20,7 @@ var _ = Describe("Cache", func() { ip := "1.2.3.4" name := "entity_1234" entity := testEntity("foo", ip, name, nil) - ec.PutFast(entity) + ec.Put(entity) actual, err := ec.GetByIP(ip) Expect(err).NotTo(HaveOccurred()) Expect(actual).To(Equal(entity)) @@ -36,9 +36,9 @@ var _ = Describe("Cache", func() { ip := "1.2.3.4" name := "entity_1234" entity := testEntity("foo", ip, name, nil) - ec.PutFast(entity) + ec.Put(entity) - removed := ec.Remove(entity.Borrow()) + removed := ec.Remove(entity) Expect(removed).To(BeTrue()) _, err := ec.GetByIP(ip) @@ -51,10 +51,10 @@ var _ = Describe("Cache", func() { name := "entity_1234" otherName := "other_entity_4321" entity := testEntity("foo", ip, name, nil) - ec.PutFast(entity) + ec.Put(entity) otherEntity := testEntity("foo2", otherIP, otherName, nil) - removed := ec.Remove(otherEntity.Borrow()) + removed := ec.Remove(otherEntity) Expect(removed).To(BeFalse()) found, err := ec.GetByIP(ip) @@ -68,7 +68,7 @@ var _ = Describe("Cache", func() { uid := "foo" name := "some_name" entity := testEntity(uid, "", name, nil) - ec.PutFast(entity) + ec.Put(entity) actual, err := ec.GetByName(name) Expect(err).NotTo(HaveOccurred()) Expect(actual).To(Equal(entity)) @@ -84,9 +84,9 @@ var _ = Describe("Cache", func() { uid := "bar" name := "some_name" entity := testEntity(uid, "", name, nil) - ec.PutFast(entity) + ec.Put(entity) - removed := ec.Remove(entity.Borrow()) + removed := ec.Remove(entity) Expect(removed).To(BeTrue()) _, err := ec.GetByName(name) @@ -99,10 +99,10 @@ var _ = Describe("Cache", func() { otherUid := "baz" otherName := "another_name" entity := testEntity(uid, "1.1.1.1", name, nil) - ec.PutFast(entity) + ec.Put(entity) otherEntity := testEntity(otherUid, "1.1.1.2", otherName, nil) - removed := ec.Remove(otherEntity.Borrow()) + removed := ec.Remove(otherEntity) Expect(removed).To(BeFalse()) found, err := ec.GetByName(name) @@ -114,7 +114,7 @@ var _ = Describe("Cache", func() { It("clears all entities from the map", func() { ip := "1.2.3.4" entity := testEntity("foo", "", "some_name", nil) - ec.PutFast(entity) + ec.Put(entity) ec.Clear() _, err := ec.GetByIP(ip) Expect(err).To(HaveOccurred()) diff --git a/pkg/policies/flowcontrol/service-getter/service-getter_test.go b/pkg/policies/flowcontrol/service-getter/service-getter_test.go index a7471bdc24..90d6136c6f 100644 --- a/pkg/policies/flowcontrol/service-getter/service-getter_test.go +++ b/pkg/policies/flowcontrol/service-getter/service-getter_test.go @@ -10,12 +10,12 @@ import ( "google.golang.org/grpc/peer" entitiesv1 "github.com/fluxninja/aperture/v2/api/gen/proto/go/aperture/discovery/entities/v1" - "github.com/fluxninja/aperture/v2/pkg/discovery/entities" + discoveryentities "github.com/fluxninja/aperture/v2/pkg/discovery/entities" servicegetter "github.com/fluxninja/aperture/v2/pkg/policies/flowcontrol/service-getter" ) func TestServiceGetter(t *testing.T) { - entities := entities.NewEntities() + entities := discoveryentities.NewEntities() sg := servicegetter.FromEntities(entities) t.Run("ServicesFromContext with no peer information", func(t *testing.T) { @@ -39,13 +39,16 @@ func TestServiceGetter(t *testing.T) { t.Run("ServicesFromContext with valid IP address and entity", func(t *testing.T) { ip := "192.168.1.2" - entity := &entitiesv1.Entity{IpAddress: ip, Services: []string{"svc1", "svc2"}} + entity := discoveryentities.NewEntity(&entitiesv1.Entity{ + IpAddress: ip, + Services: []string{"svc1", "svc2"}, + }) entities.Put(entity) defer entities.Remove(entity) ctx := peerContext(ip) services := sg.ServicesFromContext(ctx) - assert.Equal(t, entity.Services, services) + assert.Equal(t, entity.Services(), services) }) t.Run("ServicesFromSocketAddress with invalid IP address", func(t *testing.T) { @@ -63,13 +66,16 @@ func TestServiceGetter(t *testing.T) { t.Run("ServicesFromSocketAddress with valid IP address and entity", func(t *testing.T) { ip := "192.168.1.4" - entity := &entitiesv1.Entity{IpAddress: ip, Services: []string{"svc3", "svc4"}} + entity := discoveryentities.NewEntity(&entitiesv1.Entity{ + IpAddress: ip, + Services: []string{"svc3", "svc4"}, + }) entities.Put(entity) defer entities.Remove(entity) addr := &corev3.SocketAddress{Address: ip} services := sg.ServicesFromSocketAddress(addr) - assert.Equal(t, entity.Services, services) + assert.Equal(t, entity.Services(), services) }) t.Run("ServicesFromAddress with invalid IP address", func(t *testing.T) { @@ -85,12 +91,15 @@ func TestServiceGetter(t *testing.T) { t.Run("ServicesFromAddress with valid IP address and entity", func(t *testing.T) { ip := "192.168.1.6" - entity := &entitiesv1.Entity{IpAddress: ip, Services: []string{"svc5", "svc6"}} + entity := discoveryentities.NewEntity(&entitiesv1.Entity{ + IpAddress: ip, + Services: []string{"svc5", "svc6"}, + }) entities.Put(entity) defer entities.Remove(entity) services := sg.ServicesFromAddress(ip) - assert.Equal(t, entity.Services, services) + assert.Equal(t, entity.Services(), services) }) } diff --git a/pkg/policies/flowcontrol/service/check/check_test.go b/pkg/policies/flowcontrol/service/check/check_test.go index cdf0eb2968..af7742e319 100644 --- a/pkg/policies/flowcontrol/service/check/check_test.go +++ b/pkg/policies/flowcontrol/service/check/check_test.go @@ -28,7 +28,7 @@ var ( var _ = BeforeEach(func() { entities := entities.NewEntities() - entities.Put(&entitiesv1.Entity{ + entities.PutForTest(&entitiesv1.Entity{ Uid: "", IpAddress: hardCodedIPAddress, Name: hardCodedEntityName, diff --git a/pkg/policies/flowcontrol/service/checkhttp/checkhttp_test.go b/pkg/policies/flowcontrol/service/checkhttp/checkhttp_test.go index f04a0316fb..3d6e644bba 100644 --- a/pkg/policies/flowcontrol/service/checkhttp/checkhttp_test.go +++ b/pkg/policies/flowcontrol/service/checkhttp/checkhttp_test.go @@ -67,7 +67,7 @@ var _ = Describe("CheckHTTP handler", func() { _, err := classifier.AddRules(context.TODO(), "test", &hardcodedRegoRules) Expect(err).NotTo(HaveOccurred()) entities := entities.NewEntities() - entities.Put(&entitiesv1.Entity{ + entities.PutForTest(&entitiesv1.Entity{ IpAddress: "1.2.3.4", Services: []string{service1Selector.Service}, }) diff --git a/pkg/policies/flowcontrol/service/envoy/authz_test.go b/pkg/policies/flowcontrol/service/envoy/authz_test.go index 880d113f54..22bedca5d0 100644 --- a/pkg/policies/flowcontrol/service/envoy/authz_test.go +++ b/pkg/policies/flowcontrol/service/envoy/authz_test.go @@ -67,7 +67,7 @@ var _ = Describe("Authorization handler", func() { _, err := classifier.AddRules(context.TODO(), "test", &hardcodedRegoRules) Expect(err).NotTo(HaveOccurred()) entities := entities.NewEntities() - entities.Put(&entitiesv1.Entity{ + entities.PutForTest(&entitiesv1.Entity{ IpAddress: "1.2.3.4", Services: []string{service1Selector.Service}, })