Skip to content

Commit

Permalink
Add the ability to declare static context keys per namespace
Browse files Browse the repository at this point in the history
Bask in eternal beauty of the protobuf API
  • Loading branch information
lekko-jonathan committed Jan 9, 2024
1 parent a72c40b commit a44e9ed
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 9 deletions.
5 changes: 3 additions & 2 deletions pkg/metadata/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ type RootConfigRepoMetadata struct {
type NamespaceConfigRepoMetadata struct {
// This version refers to the version of the configuration in the repo itself.
// TODO we should move this to a separate version number.
Version string `json:"version,omitempty" yaml:"version,omitempty"`
Name string `json:"name,omitempty" yaml:"name,omitempty"`
Version string `json:"version,omitempty" yaml:"version,omitempty"`
Name string `json:"name,omitempty" yaml:"name,omitempty"`
ContextProto string `json:"contextProto,omitempty" yaml:"contextProto,omitempty"`
}

const DefaultRootConfigRepoMetadataFileName = "lekko.root.yaml"
Expand Down
157 changes: 150 additions & 7 deletions pkg/repo/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"sync"

featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1"
rulesv1beta3 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/rules/v1beta3"
"github.com/go-git/go-git/v5/plumbing"
"github.com/lekkodev/cli/pkg/encoding"
"github.com/lekkodev/cli/pkg/feature"
Expand All @@ -39,6 +40,7 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"
)

// Provides functionality needed for accessing and making changes to Lekko configuration.
Expand Down Expand Up @@ -67,7 +69,110 @@ type ConfigurationStore interface {
RestoreWorkingDirectory(hash string) error
}

func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry.Types, namespace, featureName string, nv feature.NamespaceVersion) (*feature.CompiledFeature, error) {
func checkRuleFitsContextType(rule *rulesv1beta3.Rule, contextType protoreflect.MessageType) error {
var err error
switch r := rule.Rule.(type) {
case *rulesv1beta3.Rule_BoolConst:
return nil
case *rulesv1beta3.Rule_Not:
return checkRuleFitsContextType(r.Not, contextType)
case *rulesv1beta3.Rule_LogicalExpression:
for _, lr := range r.LogicalExpression.GetRules() {
err = checkRuleFitsContextType(lr, contextType)
if err != nil {
return err
}
}
case *rulesv1beta3.Rule_Atom:
field := contextType.Descriptor().Fields().ByName(protoreflect.Name(r.Atom.GetContextKey()))
if field == nil {
return errors.Errorf("`%s` field not found in context type", r.Atom.GetContextKey())
}
fieldType := field.Kind().String()

switch r.Atom.ComparisonOperator {
case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_PRESENT:
return nil
case
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_NOT_EQUALS:
switch r.Atom.GetComparisonValue().Kind.(type) {
case *structpb.Value_BoolValue:
if fieldType != "string" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}

case *structpb.Value_NumberValue:
if fieldType != "int64" && fieldType != "double" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
case *structpb.Value_StringValue:
if fieldType != "string" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
default:
panic("This should never happen")
}
case
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN_OR_EQUALS,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN_OR_EQUALS:
if fieldType != "int64" && fieldType != "double" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
return nil
case
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_STARTS_WITH,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_ENDS_WITH,
rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINS:
if fieldType != "string" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
return nil
case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN:
listRuleVal, ok := r.Atom.GetComparisonValue().Kind.(*structpb.Value_ListValue)
if !ok {
panic("This should never happen")
}
for _, elem := range listRuleVal.ListValue.Values {
switch elem.Kind.(type) {
case *structpb.Value_BoolValue:
if fieldType != "string" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}

case *structpb.Value_NumberValue:
if fieldType != "int64" && fieldType != "double" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
case *structpb.Value_StringValue:
if fieldType != "string" {
return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey())
}
default:
panic("This should never happen")
}
}
return nil
}
return nil
case *rulesv1beta3.Rule_CallExpression:
switch f := r.CallExpression.Function.(type) {
case *rulesv1beta3.CallExpression_Bucket_:
contextKey := f.Bucket.ContextKey
if contextType.Descriptor().Fields().ByName(protoreflect.Name(contextKey)) == nil {
return errors.Errorf("%s field not found in context type", contextKey)
}
return nil
}
default:
panic("This should never happen")
}
return nil
}

func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry.Types, nsContextTypes map[string]protoreflect.MessageType, namespace, featureName string, nv feature.NamespaceVersion) (*feature.CompiledFeature, error) {
if !isValidName(namespace) {
return nil, errors.Errorf("invalid name '%s'", namespace)
}
Expand All @@ -92,6 +197,14 @@ func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry
if err != nil {
return nil, errors.Wrap(err, "compile")
}
if nsContextTypes[namespace] != nil {
for _, override := range f.Feature.Overrides {
err = checkRuleFitsContextType(override.RuleASTV3, nsContextTypes[namespace])
if err != nil {
return nil, err
}
}
}
return f, nil
}

Expand Down Expand Up @@ -303,6 +416,40 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
if err := req.Validate(); err != nil {
return nil, errors.Wrap(err, "validate request")
}

_, nsMDs, err := r.ParseMetadata(ctx)
if err != nil {
return nil, errors.Wrap(err, "parse metadata")
}
registry, err := r.registry(ctx, req.Registry)
if err != nil {
return nil, errors.Wrap(err, "registry")
}
nsContextTypes := make(map[string]protoreflect.MessageType)
for ns, nsMd := range nsMDs {
if nsMd.ContextProto != "" {
r.Logf("%s: %s\n", ns, nsMd.ContextProto)
ct, err := registry.FindMessageByName(protoreflect.FullName(nsMd.ContextProto))
if err != nil {
return nil, err
}
for i := 0; i < ct.Descriptor().Fields().Len(); i++ {
f := ct.Descriptor().Fields().Get(i)
switch f.Kind().String() {
case
"bool",
"int64",
"double",
"string":
default:
return nil, errors.New("Invalid context type thingy make this better")
}
}
nsContextTypes[ns] = ct
}
}
r.Logf("%#v\n", nsContextTypes)

// Step 1: collect. Find all features
vffs, numNamespaces, err := r.findVersionedFeatureFiles(ctx, req.NamespaceFilter, req.FeatureFilter, req.Verify)
if err != nil {
Expand All @@ -326,10 +473,6 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
}
r.Logf("Found %d configs across %d namespaces\n", len(results), numNamespaces)
r.Logf("Compiling...\n")
registry, err := r.registry(ctx, req.Registry)
if err != nil {
return nil, errors.Wrap(err, "registry")
}
concurrency := 50
if len(results) < 50 {
concurrency = len(results)
Expand All @@ -344,7 +487,7 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
nsNameToSegments := make(map[string]map[string]string)
for _, fcr := range results {
if fcr.FeatureName == "segments" {
cf, err := r.CompileFeature(ctx, registry, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion)
cf, err := r.CompileFeature(ctx, registry, nsContextTypes, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion)
if err != nil {
fcr.CompilationError = err
}
Expand Down Expand Up @@ -385,7 +528,7 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
}
fcr.FormattingDiffExists = fmtDiffExists
// compile feature
cf, err := r.CompileFeature(ctx, registry, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion)
cf, err := r.CompileFeature(ctx, registry, nsContextTypes, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion)
fcr.CompiledFeature = cf
fcr.CompilationError = err
}
Expand Down

0 comments on commit a44e9ed

Please sign in to comment.