diff --git a/pkg/metadata/metadata.go b/pkg/metadata/metadata.go index 62b3259c..5a7bfaf2 100644 --- a/pkg/metadata/metadata.go +++ b/pkg/metadata/metadata.go @@ -50,8 +50,9 @@ type RootConfigRepoMetadata struct { type NamespaceConfigRepoMetadata struct { // This version refers to the version of the configuration in the repo itself. // TODO we should move this to a separate version number. - Version string `json:"version,omitempty" yaml:"version,omitempty"` - Name string `json:"name,omitempty" yaml:"name,omitempty"` + Version string `json:"version,omitempty" yaml:"version,omitempty"` + Name string `json:"name,omitempty" yaml:"name,omitempty"` + ContextProto string `json:"contextProto,omitempty" yaml:"contextProto,omitempty"` } const DefaultRootConfigRepoMetadataFileName = "lekko.root.yaml" diff --git a/pkg/repo/feature.go b/pkg/repo/feature.go index c774d6f9..1745de1a 100644 --- a/pkg/repo/feature.go +++ b/pkg/repo/feature.go @@ -25,6 +25,7 @@ import ( "sync" featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1" + rulesv1beta3 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/rules/v1beta3" "github.com/go-git/go-git/v5/plumbing" "github.com/lekkodev/cli/pkg/encoding" "github.com/lekkodev/cli/pkg/feature" @@ -39,6 +40,7 @@ import ( "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/structpb" ) // Provides functionality needed for accessing and making changes to Lekko configuration. @@ -67,7 +69,110 @@ type ConfigurationStore interface { RestoreWorkingDirectory(hash string) error } -func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry.Types, namespace, featureName string, nv feature.NamespaceVersion) (*feature.CompiledFeature, error) { +func checkRuleFitsContextType(rule *rulesv1beta3.Rule, contextType protoreflect.MessageType) error { + var err error + switch r := rule.Rule.(type) { + case *rulesv1beta3.Rule_BoolConst: + return nil + case *rulesv1beta3.Rule_Not: + return checkRuleFitsContextType(r.Not, contextType) + case *rulesv1beta3.Rule_LogicalExpression: + for _, lr := range r.LogicalExpression.GetRules() { + err = checkRuleFitsContextType(lr, contextType) + if err != nil { + return err + } + } + case *rulesv1beta3.Rule_Atom: + field := contextType.Descriptor().Fields().ByName(protoreflect.Name(r.Atom.GetContextKey())) + if field == nil { + return errors.Errorf("`%s` field not found in context type", r.Atom.GetContextKey()) + } + fieldType := field.Kind().String() + + switch r.Atom.ComparisonOperator { + case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_PRESENT: + return nil + case + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_NOT_EQUALS: + switch r.Atom.GetComparisonValue().Kind.(type) { + case *structpb.Value_BoolValue: + if fieldType != "string" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + + case *structpb.Value_NumberValue: + if fieldType != "int64" && fieldType != "double" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + case *structpb.Value_StringValue: + if fieldType != "string" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + default: + panic("This should never happen") + } + case + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN_OR_EQUALS, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN_OR_EQUALS: + if fieldType != "int64" && fieldType != "double" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + return nil + case + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_STARTS_WITH, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_ENDS_WITH, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINS: + if fieldType != "string" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + return nil + case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN: + listRuleVal, ok := r.Atom.GetComparisonValue().Kind.(*structpb.Value_ListValue) + if !ok { + panic("This should never happen") + } + for _, elem := range listRuleVal.ListValue.Values { + switch elem.Kind.(type) { + case *structpb.Value_BoolValue: + if fieldType != "string" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + + case *structpb.Value_NumberValue: + if fieldType != "int64" && fieldType != "double" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + case *structpb.Value_StringValue: + if fieldType != "string" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + default: + panic("This should never happen") + } + } + return nil + } + return nil + case *rulesv1beta3.Rule_CallExpression: + switch f := r.CallExpression.Function.(type) { + case *rulesv1beta3.CallExpression_Bucket_: + contextKey := f.Bucket.ContextKey + if contextType.Descriptor().Fields().ByName(protoreflect.Name(contextKey)) == nil { + return errors.Errorf("%s field not found in context type", contextKey) + } + return nil + } + default: + panic("This should never happen") + } + return nil +} + +func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry.Types, nsContextTypes map[string]protoreflect.MessageType, namespace, featureName string, nv feature.NamespaceVersion) (*feature.CompiledFeature, error) { if !isValidName(namespace) { return nil, errors.Errorf("invalid name '%s'", namespace) } @@ -92,6 +197,14 @@ func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry if err != nil { return nil, errors.Wrap(err, "compile") } + if nsContextTypes[namespace] != nil { + for _, override := range f.Feature.Overrides { + err = checkRuleFitsContextType(override.RuleASTV3, nsContextTypes[namespace]) + if err != nil { + return nil, err + } + } + } return f, nil } @@ -303,6 +416,40 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu if err := req.Validate(); err != nil { return nil, errors.Wrap(err, "validate request") } + + _, nsMDs, err := r.ParseMetadata(ctx) + if err != nil { + return nil, errors.Wrap(err, "parse metadata") + } + registry, err := r.registry(ctx, req.Registry) + if err != nil { + return nil, errors.Wrap(err, "registry") + } + nsContextTypes := make(map[string]protoreflect.MessageType) + for ns, nsMd := range nsMDs { + if nsMd.ContextProto != "" { + r.Logf("%s: %s\n", ns, nsMd.ContextProto) + ct, err := registry.FindMessageByName(protoreflect.FullName(nsMd.ContextProto)) + if err != nil { + return nil, err + } + for i := 0; i < ct.Descriptor().Fields().Len(); i++ { + f := ct.Descriptor().Fields().Get(i) + switch f.Kind().String() { + case + "bool", + "int64", + "double", + "string": + default: + return nil, errors.New("Invalid context type thingy make this better") + } + } + nsContextTypes[ns] = ct + } + } + r.Logf("%#v\n", nsContextTypes) + // Step 1: collect. Find all features vffs, numNamespaces, err := r.findVersionedFeatureFiles(ctx, req.NamespaceFilter, req.FeatureFilter, req.Verify) if err != nil { @@ -326,10 +473,6 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu } r.Logf("Found %d configs across %d namespaces\n", len(results), numNamespaces) r.Logf("Compiling...\n") - registry, err := r.registry(ctx, req.Registry) - if err != nil { - return nil, errors.Wrap(err, "registry") - } concurrency := 50 if len(results) < 50 { concurrency = len(results) @@ -344,7 +487,7 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu nsNameToSegments := make(map[string]map[string]string) for _, fcr := range results { if fcr.FeatureName == "segments" { - cf, err := r.CompileFeature(ctx, registry, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion) + cf, err := r.CompileFeature(ctx, registry, nsContextTypes, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion) if err != nil { fcr.CompilationError = err } @@ -385,7 +528,7 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu } fcr.FormattingDiffExists = fmtDiffExists // compile feature - cf, err := r.CompileFeature(ctx, registry, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion) + cf, err := r.CompileFeature(ctx, registry, nsContextTypes, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion) fcr.CompiledFeature = cf fcr.CompilationError = err }