diff --git a/extensions/fluxninja/heartbeats/services_test.go b/extensions/fluxninja/heartbeats/services_test.go index ef2e8cc0cf..f7475b3882 100644 --- a/extensions/fluxninja/heartbeats/services_test.go +++ b/extensions/fluxninja/heartbeats/services_test.go @@ -18,8 +18,8 @@ var _ = Describe("Services", func() { Context("Services", func() { It("reads same service from two entities", func() { - ec.Put(testEntity("1", "1.1.1.1", "some_name", []string{"baz"})) - ec.Put(testEntity("2", "1.1.1.2", "some_name", []string{"baz"})) + ec.PutForTest(testEntity("1", "1.1.1.1", "some_name", []string{"baz"})) + ec.PutForTest(testEntity("2", "1.1.1.2", "some_name", []string{"baz"})) sl := populateServicesList(ec) Expect(sl.Services).To(HaveLen(1)) Expect(sl.Services).To(ContainElement(&heartbeatv1.Service{ @@ -32,7 +32,7 @@ var _ = Describe("Services", func() { ip := "1.1.1.1" serviceNames := []string{"baz1", "baz2"} name := "entity_1234" - ec.Put(testEntity("1", ip, name, serviceNames)) + ec.PutForTest(testEntity("1", ip, name, serviceNames)) sl := populateServicesList(ec) Expect(sl.Services).To(HaveLen(2)) Expect(sl.Services).To(ContainElement(&heartbeatv1.Service{ @@ -49,7 +49,7 @@ var _ = Describe("Services", func() { ip := "1.1.1.1" serviceNames := []string{"baz"} name := "entity_1234" - ec.Put(testEntity("1", ip, name, serviceNames)) + ec.PutForTest(testEntity("1", ip, name, serviceNames)) ec.Clear() sl := populateServicesList(ec) Expect(sl.Services).To(HaveLen(0)) diff --git a/pkg/discovery/entities/entities-service.go b/pkg/discovery/entities/entities-service.go index caf1fab5f4..207b9f65c3 100644 --- a/pkg/discovery/entities/entities-service.go +++ b/pkg/discovery/entities/entities-service.go @@ -4,6 +4,8 @@ import ( "context" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" cmdv1 "github.com/fluxninja/aperture/v2/api/gen/proto/go/aperture/cmd/v1" @@ -45,12 +47,20 @@ func (c *EntitiesService) GetEntities(ctx context.Context, _ *emptypb.Empty) (*e // GetEntityByIPAddress returns an entity by IP address. func (c *EntitiesService) GetEntityByIPAddress(ctx context.Context, req *entitiesv1.GetEntityByIPAddressRequest) (*entitiesv1.Entity, error) { - return c.entityCache.GetByIP(req.GetIpAddress()) + e, err := c.entityCache.GetByIP(req.GetIpAddress()) + if err != nil { + return nil, status.Error(codes.NotFound, err.Error()) + } + return e.Clone(), nil } // GetEntityByName returns an entity by name. func (c *EntitiesService) GetEntityByName(ctx context.Context, req *entitiesv1.GetEntityByNameRequest) (*entitiesv1.Entity, error) { - return c.entityCache.GetByName(req.GetName()) + e, err := c.entityCache.GetByName(req.GetName()) + if err != nil { + return nil, status.Error(codes.NotFound, err.Error()) + } + return e.Clone(), nil } // ListDiscoveryEntities lists currently discovered entities by IP address. @@ -71,7 +81,7 @@ func (c *EntitiesService) ListDiscoveryEntity(ctx context.Context, req *cmdv1.Li return nil, err } return &cmdv1.ListDiscoveryEntityAgentResponse{ - Entity: entity, + Entity: entity.Clone(), }, nil case *cmdv1.ListDiscoveryEntityRequest_Name: entity, err := c.entityCache.GetByName(req.GetName()) @@ -79,7 +89,7 @@ func (c *EntitiesService) ListDiscoveryEntity(ctx context.Context, req *cmdv1.Li return nil, err } return &cmdv1.ListDiscoveryEntityAgentResponse{ - Entity: entity, + Entity: entity.Clone(), }, nil default: return nil, nil diff --git a/pkg/discovery/entities/entities.go b/pkg/discovery/entities/entities.go index 64575bf883..472ae012ce 100644 --- a/pkg/discovery/entities/entities.go +++ b/pkg/discovery/entities/entities.go @@ -25,10 +25,53 @@ func Module() fx.Option { ) } +// Entity is an immutable wrapper over *entitiesv1.Entity. +type Entity struct { + immutableEntity *entitiesv1.Entity +} + +// NewEntity creates a new immutable entity from the copy of given entity. +func NewEntity(entity *entitiesv1.Entity) Entity { + return Entity{immutableEntity: proto.Clone(entity).(*entitiesv1.Entity)} +} + +// NewEntityFromImmutable creates a new immutable entity, assuming given entity is immutable. +// +// This allows avoiding a copy compared to NewEntity. +func NewEntityFromImmutable(entity *entitiesv1.Entity) Entity { + return Entity{immutableEntity: entity} +} + +// Clone returns a mutable copy of the entity. +func (e Entity) Clone() *entitiesv1.Entity { + return proto.Clone(e.immutableEntity).(*entitiesv1.Entity) +} + +// UID returns the entity's UID. +func (e Entity) UID() string { return e.immutableEntity.Uid } + +// IPAddress returns the entity's IP address. +func (e Entity) IPAddress() string { return e.immutableEntity.IpAddress } + +// Name returns the entity's name. +func (e Entity) Name() string { return e.immutableEntity.Name } + +// Namespace returns the entity's namespace. +func (e Entity) Namespace() string { return e.immutableEntity.Namespace } + +// NodeName returns the entity's node name. +func (e Entity) NodeName() string { return e.immutableEntity.NodeName } + +// Services returns list of services the entity belongs to. +// +// The returned slice must not be modified. +func (e Entity) Services() []string { return e.immutableEntity.Services } + // Entities maps IP addresses and Entity names to entities. type Entities struct { sync.RWMutex - entities *entitiesv1.Entities + byIP map[string]Entity + byName map[string]Entity } // EntityTrackers allows to register a service discovery for entity cache @@ -99,89 +142,83 @@ func provideEntities(in FxIn) (*Entities, *EntityTrackers, error) { return entityCache, &EntityTrackers{trackers: in.EntityTrackers}, nil } -func (c *Entities) processUpdate(event notifiers.Event, unmarshaller config.Unmarshaller) { +func (e *Entities) processUpdate(event notifiers.Event, unmarshaller config.Unmarshaller) { log.Trace().Str("event", event.String()).Msg("Updating entity") - entity := &entitiesv1.Entity{} - if err := unmarshaller.UnmarshalKey("", entity); err != nil { + entityProto := &entitiesv1.Entity{} + if err := unmarshaller.UnmarshalKey("", entityProto); err != nil { log.Error().Err(err).Msg("Failed to unmarshal entity") return } - ip := entity.IpAddress - name := entity.Name + entity := NewEntityFromImmutable(entityProto) + ip := entity.IPAddress() + name := entity.Name() switch event.Type { case notifiers.Write: - log.Trace().Str("entity", entity.Uid).Str("ip", ip).Str("name", name).Msg("new entity") - c.Put(entity) + log.Trace().Str("entity", entity.UID()).Str("ip", ip).Str("name", name).Msg("new entity") + e.Put(entity) case notifiers.Remove: - log.Trace().Str("entity", entity.Uid).Str("ip", ip).Str("name", name).Msg("removing entity") - c.Remove(entity) + log.Trace().Str("entity", entity.UID()).Str("ip", ip).Str("name", name).Msg("removing entity") + e.Remove(entity) } } // NewEntities creates a new, empty Entities. func NewEntities() *Entities { - entities := &entitiesv1.Entities{ - EntitiesByIpAddress: &entitiesv1.Entities_Entities{ - Entities: make(map[string]*entitiesv1.Entity), - }, - EntitiesByName: &entitiesv1.Entities_Entities{ - Entities: make(map[string]*entitiesv1.Entity), - }, - } return &Entities{ - entities: entities, + byIP: make(map[string]Entity), + byName: make(map[string]Entity), } } +// PutForTest maps given IP address and name to the entity it currently represents. +func (e *Entities) PutForTest(entity *entitiesv1.Entity) { + e.Put(NewEntity(entity)) +} + // Put maps given IP address and name to the entity it currently represents. -func (c *Entities) Put(entity *entitiesv1.Entity) { - c.Lock() - defer c.Unlock() +func (e *Entities) Put(entity Entity) { + e.Lock() + defer e.Unlock() - entityIP := entity.IpAddress + entityIP := entity.IPAddress() if entityIP != "" { - c.entities.EntitiesByIpAddress.Entities[entityIP] = entity + // FIXME: would be nice to store Entity directly in the map, but that + // would require removing the reusal of proto-generated structs as + // containers. + e.byIP[entityIP] = entity } - entityName := entity.Name + entityName := entity.Name() if entityName != "" { - c.entities.EntitiesByName.Entities[entityName] = entity + e.byName[entityName] = entity } } // GetByIP retrieves entity with a given IP address. -func (c *Entities) GetByIP(entityIP string) (*entitiesv1.Entity, error) { - c.RLock() - defer c.RUnlock() - - if len(c.entities.EntitiesByIpAddress.Entities) == 0 { - return nil, errNoEntities - } - - v, ok := c.entities.EntitiesByIpAddress.Entities[entityIP] - if !ok { - return nil, errNotFound - } - - return proto.Clone(v).(*entitiesv1.Entity), nil +func (e *Entities) GetByIP(entityIP string) (Entity, error) { + return e.getFromMap(e.byIP, entityIP) } // GetByName retrieves entity with a given name. -func (c *Entities) GetByName(entityName string) (*entitiesv1.Entity, error) { - c.RLock() - defer c.RUnlock() +func (e *Entities) GetByName(entityName string) (Entity, error) { + return e.getFromMap(e.byName, entityName) +} + +func (e *Entities) getFromMap(m map[string]Entity, k string) (Entity, error) { + e.RLock() + defer e.RUnlock() - if len(c.entities.EntitiesByName.Entities) == 0 { - return nil, errNoEntities + if len(m) == 0 { + return Entity{}, errNoEntities } - v, ok := c.entities.EntitiesByName.Entities[entityName] + v, ok := m[k] if !ok { - return nil, errNotFound + return Entity{}, errNotFound } - return proto.Clone(v).(*entitiesv1.Entity), nil + return v, nil } var ( @@ -190,42 +227,52 @@ var ( ) // Clear removes all entities from the cache. -func (c *Entities) Clear() { - c.RLock() - defer c.RUnlock() - c.entities.EntitiesByIpAddress = &entitiesv1.Entities_Entities{ - Entities: make(map[string]*entitiesv1.Entity), - } - c.entities.EntitiesByName = &entitiesv1.Entities_Entities{ - Entities: make(map[string]*entitiesv1.Entity), - } +func (e *Entities) Clear() { + e.Lock() + defer e.Unlock() + e.byIP = make(map[string]Entity) + e.byName = make(map[string]Entity) } // Remove removes entity from the cache and returns `true` if any of IP address // or name mapping exists. // If no such entity was found, returns `false`. -func (c *Entities) Remove(entity *entitiesv1.Entity) bool { - c.Lock() - defer c.Unlock() +func (e *Entities) Remove(entity Entity) bool { + e.Lock() + defer e.Unlock() - entityIP := entity.IpAddress - _, okByIP := c.entities.EntitiesByIpAddress.Entities[entityIP] + entityIP := entity.IPAddress() + _, okByIP := e.byIP[entityIP] if okByIP { - delete(c.entities.EntitiesByIpAddress.Entities, entityIP) + delete(e.byIP, entityIP) } - entityName := entity.Name - _, okByName := c.entities.EntitiesByName.Entities[entityName] + entityName := entity.Name() + _, okByName := e.byName[entityName] if okByName { - delete(c.entities.EntitiesByName.Entities, entityName) + delete(e.byName, entityName) } return okByIP || okByName } // GetEntities returns *entitiesv1.EntitiyCache entities. -func (c *Entities) GetEntities() *entitiesv1.Entities { - c.RLock() - defer c.RUnlock() +func (e *Entities) GetEntities() *entitiesv1.Entities { + e.RLock() + defer e.RUnlock() + + // Not sure what caller will do with the result, let's clone + return &entitiesv1.Entities{ + EntitiesByIpAddress: cloneEntitiesMap(e.byIP), + EntitiesByName: cloneEntitiesMap(e.byName), + } +} - return proto.Clone(c.entities).(*entitiesv1.Entities) +func cloneEntitiesMap(m map[string]Entity) *entitiesv1.Entities_Entities { + clones := make(map[string]*entitiesv1.Entity, len(m)) + for k, entity := range m { + clones[k] = entity.Clone() + } + return &entitiesv1.Entities_Entities{ + Entities: clones, + } } diff --git a/pkg/discovery/entities/entities_test.go b/pkg/discovery/entities/entities_test.go index 0821aabbff..6acd960c6c 100644 --- a/pkg/discovery/entities/entities_test.go +++ b/pkg/discovery/entities/entities_test.go @@ -26,11 +26,10 @@ var _ = Describe("Cache", func() { Expect(actual).To(Equal(entity)) }) - It("returns nil when no entity found", func() { + It("returns err when no entity found", func() { ip := "1.2.3.4" - actual, err := ec.GetByIP(ip) - Expect(err).To(Not(BeNil())) - Expect(actual).To(BeNil()) + _, err := ec.GetByIP(ip) + Expect(err).To(HaveOccurred()) }) It("removes an entity properly", func() { @@ -42,9 +41,8 @@ var _ = Describe("Cache", func() { removed := ec.Remove(entity) Expect(removed).To(BeTrue()) - found, err := ec.GetByIP(ip) - Expect(err).To(Not(BeNil())) - Expect(found).To(BeNil()) + _, err := ec.GetByIP(ip) + Expect(err).To(HaveOccurred()) }) It("returns false if trying to remove a nonexistent entity", func() { @@ -76,11 +74,10 @@ var _ = Describe("Cache", func() { Expect(actual).To(Equal(entity)) }) - It("returns nil when no entity found", func() { + It("returns err when no entity found", func() { name := "some_name" - actual, err := ec.GetByName(name) - Expect(err).To(Not(BeNil())) - Expect(actual).To(BeNil()) + _, err := ec.GetByName(name) + Expect(err).To(HaveOccurred()) }) It("removes an entity properly", func() { @@ -92,9 +89,8 @@ var _ = Describe("Cache", func() { removed := ec.Remove(entity) Expect(removed).To(BeTrue()) - found, err := ec.GetByName(name) - Expect(err).To(Not(BeNil())) - Expect(found).To(BeNil()) + _, err := ec.GetByName(name) + Expect(err).To(HaveOccurred()) }) It("returns false if trying to remove a nonexistent entity", func() { @@ -120,17 +116,16 @@ var _ = Describe("Cache", func() { entity := testEntity("foo", "", "some_name", nil) ec.Put(entity) ec.Clear() - found, err := ec.GetByIP(ip) - Expect(err).To(Not(BeNil())) - Expect(found).To(BeNil()) + _, err := ec.GetByIP(ip) + Expect(err).To(HaveOccurred()) }) }) -func testEntity(uid, ipAddress, name string, services []string) *entitiesv1.Entity { - return &entitiesv1.Entity{ +func testEntity(uid, ipAddress, name string, services []string) entities.Entity { + return entities.NewEntity(&entitiesv1.Entity{ Uid: uid, IpAddress: ipAddress, Name: name, Services: services, - } + }) } diff --git a/pkg/policies/flowcontrol/service-getter/service-getter.go b/pkg/policies/flowcontrol/service-getter/service-getter.go index cc3208bb0f..746f683754 100644 --- a/pkg/policies/flowcontrol/service-getter/service-getter.go +++ b/pkg/policies/flowcontrol/service-getter/service-getter.go @@ -13,6 +13,8 @@ import ( ) // ServiceGetter can be used to query services based on client context. +// +// Caller should not modify slices returned from methods of ServiceGetter. type ServiceGetter interface { ServicesFromContext(ctx context.Context) []string ServicesFromSocketAddress(addr *corev3.SocketAddress) []string @@ -69,7 +71,7 @@ func (sg *ecServiceGetter) servicesFromContext(ctx context.Context) (svcs []stri return nil, false } - return entity.Services, true + return entity.Services(), true } // ServicesFromSocketAddress returns list of services associated with IP extracted from SocketAddress. @@ -91,7 +93,7 @@ func (sg *ecServiceGetter) ServicesFromAddress(addr string) []string { } return nil } - return entity.Services + return entity.Services() } var noEntitySampler = log.NewRatelimitingSampler() diff --git a/pkg/policies/flowcontrol/service-getter/service-getter_test.go b/pkg/policies/flowcontrol/service-getter/service-getter_test.go index a7471bdc24..90d6136c6f 100644 --- a/pkg/policies/flowcontrol/service-getter/service-getter_test.go +++ b/pkg/policies/flowcontrol/service-getter/service-getter_test.go @@ -10,12 +10,12 @@ import ( "google.golang.org/grpc/peer" entitiesv1 "github.com/fluxninja/aperture/v2/api/gen/proto/go/aperture/discovery/entities/v1" - "github.com/fluxninja/aperture/v2/pkg/discovery/entities" + discoveryentities "github.com/fluxninja/aperture/v2/pkg/discovery/entities" servicegetter "github.com/fluxninja/aperture/v2/pkg/policies/flowcontrol/service-getter" ) func TestServiceGetter(t *testing.T) { - entities := entities.NewEntities() + entities := discoveryentities.NewEntities() sg := servicegetter.FromEntities(entities) t.Run("ServicesFromContext with no peer information", func(t *testing.T) { @@ -39,13 +39,16 @@ func TestServiceGetter(t *testing.T) { t.Run("ServicesFromContext with valid IP address and entity", func(t *testing.T) { ip := "192.168.1.2" - entity := &entitiesv1.Entity{IpAddress: ip, Services: []string{"svc1", "svc2"}} + entity := discoveryentities.NewEntity(&entitiesv1.Entity{ + IpAddress: ip, + Services: []string{"svc1", "svc2"}, + }) entities.Put(entity) defer entities.Remove(entity) ctx := peerContext(ip) services := sg.ServicesFromContext(ctx) - assert.Equal(t, entity.Services, services) + assert.Equal(t, entity.Services(), services) }) t.Run("ServicesFromSocketAddress with invalid IP address", func(t *testing.T) { @@ -63,13 +66,16 @@ func TestServiceGetter(t *testing.T) { t.Run("ServicesFromSocketAddress with valid IP address and entity", func(t *testing.T) { ip := "192.168.1.4" - entity := &entitiesv1.Entity{IpAddress: ip, Services: []string{"svc3", "svc4"}} + entity := discoveryentities.NewEntity(&entitiesv1.Entity{ + IpAddress: ip, + Services: []string{"svc3", "svc4"}, + }) entities.Put(entity) defer entities.Remove(entity) addr := &corev3.SocketAddress{Address: ip} services := sg.ServicesFromSocketAddress(addr) - assert.Equal(t, entity.Services, services) + assert.Equal(t, entity.Services(), services) }) t.Run("ServicesFromAddress with invalid IP address", func(t *testing.T) { @@ -85,12 +91,15 @@ func TestServiceGetter(t *testing.T) { t.Run("ServicesFromAddress with valid IP address and entity", func(t *testing.T) { ip := "192.168.1.6" - entity := &entitiesv1.Entity{IpAddress: ip, Services: []string{"svc5", "svc6"}} + entity := discoveryentities.NewEntity(&entitiesv1.Entity{ + IpAddress: ip, + Services: []string{"svc5", "svc6"}, + }) entities.Put(entity) defer entities.Remove(entity) services := sg.ServicesFromAddress(ip) - assert.Equal(t, entity.Services, services) + assert.Equal(t, entity.Services(), services) }) } diff --git a/pkg/policies/flowcontrol/service/check/check_test.go b/pkg/policies/flowcontrol/service/check/check_test.go index cdf0eb2968..af7742e319 100644 --- a/pkg/policies/flowcontrol/service/check/check_test.go +++ b/pkg/policies/flowcontrol/service/check/check_test.go @@ -28,7 +28,7 @@ var ( var _ = BeforeEach(func() { entities := entities.NewEntities() - entities.Put(&entitiesv1.Entity{ + entities.PutForTest(&entitiesv1.Entity{ Uid: "", IpAddress: hardCodedIPAddress, Name: hardCodedEntityName, diff --git a/pkg/policies/flowcontrol/service/checkhttp/checkhttp_test.go b/pkg/policies/flowcontrol/service/checkhttp/checkhttp_test.go index f04a0316fb..3d6e644bba 100644 --- a/pkg/policies/flowcontrol/service/checkhttp/checkhttp_test.go +++ b/pkg/policies/flowcontrol/service/checkhttp/checkhttp_test.go @@ -67,7 +67,7 @@ var _ = Describe("CheckHTTP handler", func() { _, err := classifier.AddRules(context.TODO(), "test", &hardcodedRegoRules) Expect(err).NotTo(HaveOccurred()) entities := entities.NewEntities() - entities.Put(&entitiesv1.Entity{ + entities.PutForTest(&entitiesv1.Entity{ IpAddress: "1.2.3.4", Services: []string{service1Selector.Service}, }) diff --git a/pkg/policies/flowcontrol/service/envoy/authz_test.go b/pkg/policies/flowcontrol/service/envoy/authz_test.go index 880d113f54..22bedca5d0 100644 --- a/pkg/policies/flowcontrol/service/envoy/authz_test.go +++ b/pkg/policies/flowcontrol/service/envoy/authz_test.go @@ -67,7 +67,7 @@ var _ = Describe("Authorization handler", func() { _, err := classifier.AddRules(context.TODO(), "test", &hardcodedRegoRules) Expect(err).NotTo(HaveOccurred()) entities := entities.NewEntities() - entities.Put(&entitiesv1.Entity{ + entities.PutForTest(&entitiesv1.Entity{ IpAddress: "1.2.3.4", Services: []string{service1Selector.Service}, })