Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ability to declare static context keys per namespace #279

Merged
merged 3 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pkg/metadata/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ type RootConfigRepoMetadata struct {
type NamespaceConfigRepoMetadata struct {
// This version refers to the version of the configuration in the repo itself.
// TODO we should move this to a separate version number.
Version string `json:"version,omitempty" yaml:"version,omitempty"`
Name string `json:"name,omitempty" yaml:"name,omitempty"`
Version string `json:"version,omitempty" yaml:"version,omitempty"`
Name string `json:"name,omitempty" yaml:"name,omitempty"`
ContextProto string `json:"contextProto,omitempty" yaml:"contextProto,omitempty"`
}

const DefaultRootConfigRepoMetadataFileName = "lekko.root.yaml"
Expand Down
155 changes: 148 additions & 7 deletions pkg/repo/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"sync"

featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1"
rulesv1beta3 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/rules/v1beta3"
"github.com/go-git/go-git/v5/plumbing"
"github.com/lekkodev/cli/pkg/encoding"
"github.com/lekkodev/cli/pkg/feature"
Expand All @@ -39,6 +40,7 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"
)

// Provides functionality needed for accessing and making changes to Lekko configuration.
Expand Down Expand Up @@ -67,7 +69,110 @@ type ConfigurationStore interface {
RestoreWorkingDirectory(hash string) error
}

func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry.Types, namespace, featureName string, nv feature.NamespaceVersion) (*feature.CompiledFeature, error) {
func checkRuleFitsContextType(rule *rulesv1beta3.Rule, contextType protoreflect.MessageType) error {
var err error
switch r := rule.Rule.(type) {
case *rulesv1beta3.Rule_BoolConst:
return nil
case *rulesv1beta3.Rule_Not:
return checkRuleFitsContextType(r.Not, contextType)
case *rulesv1beta3.Rule_LogicalExpression:
for _, lr := range r.LogicalExpression.GetRules() {
err = checkRuleFitsContextType(lr, contextType)
if err != nil {
return err
}
}
case *rulesv1beta3.Rule_Atom:
field := contextType.Descriptor().Fields().ByName(protoreflect.Name(r.Atom.GetContextKey()))
if field == nil {
return errors.Errorf("`%s` field not found in context type", r.Atom.GetContextKey())
}
fieldType := field.Kind().String()

switch r.Atom.ComparisonOperator {
case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_PRESENT:
return nil
case
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_NOT_EQUALS:
switch r.Atom.GetComparisonValue().Kind.(type) {
case *structpb.Value_BoolValue:
if fieldType != "string" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}

case *structpb.Value_NumberValue:
if fieldType != "int64" && fieldType != "double" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
case *structpb.Value_StringValue:
if fieldType != "string" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
default:
panic("This should never happen")
}
case
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN_OR_EQUALS,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN_OR_EQUALS:
if fieldType != "int64" && fieldType != "double" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
return nil
case
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_STARTS_WITH,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_ENDS_WITH,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINS:
if fieldType != "string" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
return nil
case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN:
listRuleVal, ok := r.Atom.GetComparisonValue().Kind.(*structpb.Value_ListValue)
if !ok {
panic("This should never happen")
}
for _, elem := range listRuleVal.ListValue.Values {
switch elem.Kind.(type) {
case *structpb.Value_BoolValue:
if fieldType != "string" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}

case *structpb.Value_NumberValue:
if fieldType != "int64" && fieldType != "double" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
case *structpb.Value_StringValue:
if fieldType != "string" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
default:
panic("This should never happen")
}
}
return nil
}
return nil
case *rulesv1beta3.Rule_CallExpression:
switch f := r.CallExpression.Function.(type) {
case *rulesv1beta3.CallExpression_Bucket_:
contextKey := f.Bucket.ContextKey
if contextType.Descriptor().Fields().ByName(protoreflect.Name(contextKey)) == nil {
return errors.Errorf("%s field not found in context type", contextKey)
}
return nil
}
default:
panic("This should never happen")
}
return nil
}

func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry.Types, nsContextTypes map[string]protoreflect.MessageType, namespace, featureName string, nv feature.NamespaceVersion) (*feature.CompiledFeature, error) {
if !isValidName(namespace) {
return nil, errors.Errorf("invalid name '%s'", namespace)
}
Expand All @@ -92,6 +197,14 @@ func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry
if err != nil {
return nil, errors.Wrap(err, "compile")
}
if nsContextTypes[namespace] != nil {
for _, override := range f.Feature.Overrides {
err = checkRuleFitsContextType(override.RuleASTV3, nsContextTypes[namespace])
if err != nil {
return nil, err
}
}
}
return f, nil
}

Expand Down Expand Up @@ -303,6 +416,38 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
if err := req.Validate(); err != nil {
return nil, errors.Wrap(err, "validate request")
}

_, nsMDs, err := r.ParseMetadata(ctx)
if err != nil {
return nil, errors.Wrap(err, "parse metadata")
}
registry, err := r.registry(ctx, req.Registry)
if err != nil {
return nil, errors.Wrap(err, "registry")
}
nsContextTypes := make(map[string]protoreflect.MessageType)
for ns, nsMd := range nsMDs {
if nsMd.ContextProto != "" {
ct, err := registry.FindMessageByName(protoreflect.FullName(nsMd.ContextProto))
if err != nil {
return nil, err
}
for i := 0; i < ct.Descriptor().Fields().Len(); i++ {
f := ct.Descriptor().Fields().Get(i)
switch f.Kind().String() {
case
"bool",
"int64",
"double",
"string":
default:
return nil, errors.Errorf("proto message cannot be used as a context message because type: %v of key: %s is not allowed", f.Kind(), f.Name())
}
}
nsContextTypes[ns] = ct
}
}

// Step 1: collect. Find all features
vffs, numNamespaces, err := r.findVersionedFeatureFiles(ctx, req.NamespaceFilter, req.FeatureFilter, req.Verify)
if err != nil {
Expand All @@ -326,10 +471,6 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
}
r.Logf("Found %d configs across %d namespaces\n", len(results), numNamespaces)
r.Logf("Compiling...\n")
registry, err := r.registry(ctx, req.Registry)
if err != nil {
return nil, errors.Wrap(err, "registry")
}
concurrency := 50
if len(results) < 50 {
concurrency = len(results)
Expand All @@ -344,7 +485,7 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
nsNameToSegments := make(map[string]map[string]string)
for _, fcr := range results {
if fcr.FeatureName == "segments" {
cf, err := r.CompileFeature(ctx, registry, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion)
cf, err := r.CompileFeature(ctx, registry, nsContextTypes, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion)
if err != nil {
fcr.CompilationError = err
}
Expand Down Expand Up @@ -385,7 +526,7 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
}
fcr.FormattingDiffExists = fmtDiffExists
// compile feature
cf, err := r.CompileFeature(ctx, registry, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion)
cf, err := r.CompileFeature(ctx, registry, nsContextTypes, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion)
fcr.CompiledFeature = cf
fcr.CompilationError = err
}
Expand Down