diff --git a/cmd/lekko/gen.go b/cmd/lekko/gen.go index c734fa84..ed54445c 100644 --- a/cmd/lekko/gen.go +++ b/cmd/lekko/gen.go @@ -16,6 +16,7 @@ package main import ( "bytes" + "encoding/json" "fmt" "os" "os/exec" @@ -25,11 +26,15 @@ import ( "text/template" featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1" + rulesv1beta3 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/rules/v1beta3" + "github.com/lainio/err2/try" "github.com/lekkodev/cli/pkg/repo" "github.com/lekkodev/cli/pkg/secrets" "github.com/pkg/errors" "github.com/spf13/cobra" + strcase "github.com/stoewer/go-strcase" "golang.org/x/mod/modfile" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" ) @@ -56,6 +61,9 @@ func genGoCmd() *cobra.Command { if err != nil { return errors.Wrap(err, "new repo") } + _, nsMDs := try.To2(r.ParseMetadata(cmd.Context())) + + staticCtxType := unpackProtoType(moduleRoot, nsMDs[ns].ContextProto) ffs, err := r.GetFeatureFiles(cmd.Context(), ns) if err != nil { return err @@ -66,6 +74,9 @@ func genGoCmd() *cobra.Command { var protoAsByteStrings []string var codeStrings []string protoImportSet := make(map[string]*protoImport) + if staticCtxType != nil { + protoImportSet[staticCtxType.ImportPath] = staticCtxType + } for _, ff := range ffs { fff, err := os.ReadFile(wd + "/" + ns + "/" + ff.CompiledProtoBinFileName) if err != nil { @@ -75,13 +86,13 @@ func genGoCmd() *cobra.Command { if err := proto.Unmarshal(fff, f); err != nil { return err } - codeString, err := genGoForFeature(f, ns) + codeString, err := genGoForFeature(f, ns, staticCtxType) if err != nil { return err } if f.Type == featurev1beta1.FeatureType_FEATURE_TYPE_PROTO { protoImport := unpackProtoType(moduleRoot, f.Tree.Default.TypeUrl) - protoImportSet[protoImport.ImportPath] = &protoImport + protoImportSet[protoImport.ImportPath] = protoImport } protoAsBytes := fmt.Sprintf("\t\t\"%s\": []byte{", f.Key) for idx, b := range fff { @@ -166,7 +177,6 @@ var StaticConfig = map[string]map[string][]byte{ "--include-imports", wd) // #nosec G204 pCmd.Dir = "." - fmt.Println("executing in wd: " + wd + " command: " + pCmd.String()) if out, err := pCmd.CombinedOutput(); err != nil { fmt.Println("this is the error probably") fmt.Println(string(out)) @@ -210,17 +220,18 @@ var genCmd = &cobra.Command{ Short: "generate library code from configs", } -func genGoForFeature(f *featurev1beta1.Feature, ns string) (string, error) { +func genGoForFeature(f *featurev1beta1.Feature, ns string, staticCtxType *protoImport) (string, error) { const defaultTemplateBody = `// {{$.Description}} func (c *LekkoClient) {{$.FuncName}}(ctx context.Context) ({{$.RetType}}, error) { return c.{{$.GetFunction}}(ctx, "{{$.Namespace}}", "{{$.Key}}") } // {{$.Description}} -func (c *SafeLekkoClient) {{$.FuncName}}(ctx context.Context) {{$.RetType}} { +{{if $.NaturalLanguage}}func (c *SafeLekkoClient) {{$.FuncName}}(ctx *{{$.StaticType}}) {{$.RetType}} { +{{range $.NaturalLanguage}}{{ . }} +{{end}}{{else}}func (c *SafeLekkoClient) {{$.FuncName}}(ctx context.Context) {{$.RetType}} { return c.{{$.GetFunction}}(ctx, "{{$.Namespace}}", "{{$.Key}}") -} -` +{{end}}}` const protoTemplateBody = `// {{$.Description}} func (c *LekkoClient) {{$.FuncName}}(ctx context.Context) (*{{$.RetType}}, error) { @@ -255,6 +266,8 @@ func (c *SafeLekkoClient) {{$.FuncName}}(ctx context.Context, result interface{} var retType string var getFunction string templateBody := defaultTemplateBody + var natty []string + switch f.Type { case 1: retType = "bool" @@ -268,6 +281,7 @@ func (c *SafeLekkoClient) {{$.FuncName}}(ctx context.Context, result interface{} case 4: retType = "string" getFunction = "GetString" + natty = translateFeature(f) case 5: getFunction = "GetJSON" templateBody = jsonTemplateBody @@ -281,12 +295,14 @@ func (c *SafeLekkoClient) {{$.FuncName}}(ctx context.Context, result interface{} } data := struct { - Description string - FuncName string - GetFunction string - RetType string - Namespace string - Key string + Description string + FuncName string + GetFunction string + RetType string + Namespace string + Key string + NaturalLanguage []string + StaticType string }{ f.Description, funcName, @@ -294,6 +310,8 @@ func (c *SafeLekkoClient) {{$.FuncName}}(ctx context.Context, result interface{} retType, ns, f.Key, + natty, + fmt.Sprintf("%s.%s", staticCtxType.PackageAlias, staticCtxType.Type), } templ, err := template.New("go func").Parse(templateBody) if err != nil { @@ -310,26 +328,82 @@ type protoImport struct { Type string } -func unpackProtoType(moduleRoot string, typeURL string) protoImport { +// This function handles both the google.protobuf.Any.TypeURL variable +// which has the format of `types.googleapis.com/fully.qualified.Proto` +// and purely `fully.qualified.Proto` +// +// return nil if typeURL is empty. Panics on any problems like the rest of the file. +func unpackProtoType(moduleRoot string, typeURL string) *protoImport { + if typeURL == "" { + return nil + } anyURLSplit := strings.Split(typeURL, "/") - if anyURLSplit[0] != "type.googleapis.com" { - panic("invalid any type url: " + typeURL) + fqType := anyURLSplit[0] + if len(anyURLSplit) > 1 { + if anyURLSplit[0] != "type.googleapis.com" { + panic("invalid any type url: " + typeURL) + } + fqType = anyURLSplit[1] } + // turn default.config.v1beta1.DBConfig into: // moduleRoot/internal/lekko/proto/default/config/v1beta1 - typeParts := strings.Split(anyURLSplit[1], ".") + typeParts := strings.Split(fqType, ".") importPath := strings.Join(append([]string{moduleRoot + "/internal/lekko/proto"}, typeParts[:len(typeParts)-1]...), "/") prefix := fmt.Sprintf(`%s%s`, typeParts[len(typeParts)-3], typeParts[len(typeParts)-2]) // TODO do google.protobuf.X - switch anyURLSplit[1] { + switch fqType { case "google.protobuf.Duration": importPath = "google.golang.org/protobuf/types/known/durationpb" prefix = "durationpb" default: } - return protoImport{PackageAlias: prefix, ImportPath: importPath, Type: typeParts[len(typeParts)-1]} + return &protoImport{PackageAlias: prefix, ImportPath: importPath, Type: typeParts[len(typeParts)-1]} +} + +func translateFeature(f *featurev1beta1.Feature) []string { + var buffer []string + for i, constraint := range f.Tree.Constraints { + ifToken := "} else if" + if i == 0 { + ifToken = "if" + } + rule := translateRule(constraint.GetRuleAstNew()) + buffer = append(buffer, fmt.Sprintf("\t%s %s {", ifToken, rule)) + + // TODO this doesn't work for proto, but let's try + + buffer = append(buffer, fmt.Sprintf("\t\treturn %s", try.To1(protojson.Marshal(try.To1(constraint.Value.UnmarshalNew()))))) + } + if len(f.Tree.Constraints) > 0 { + buffer = append(buffer, "\t}") + } + buffer = append(buffer, fmt.Sprintf("\treturn %s", try.To1(protojson.Marshal(try.To1(f.Tree.Default.UnmarshalNew()))))) + return buffer +} + +func translateRule(rule *rulesv1beta3.Rule) string { + if rule == nil { + return "" + } + switch v := rule.GetRule().(type) { + case *rulesv1beta3.Rule_Atom: + switch v.Atom.GetComparisonOperator() { + case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS: + b, err := json.Marshal(v.Atom.ComparisonValue) + if err != nil { + panic(err) + } + return fmt.Sprintf("ctx.%s == %s", strcase.UpperCamelCase(v.Atom.ContextKey), string(b)) + case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN: + // TODO, probably logical to have this here but we need slice syntax, use slices as of golang 1.21 + } + case *rulesv1beta3.Rule_LogicalExpression: + // TODO do some ands and ors + } + return "" } diff --git a/go.mod b/go.mod index 0fa39d1b..75c28565 100644 --- a/go.mod +++ b/go.mod @@ -12,12 +12,14 @@ require ( github.com/go-git/go-billy/v5 v5.4.1 github.com/go-git/go-git/v5 v5.8.0 github.com/google/go-github/v52 v52.0.0 + github.com/lainio/err2 v0.9.51 github.com/lekkodev/go-sdk v0.2.6-0.20230830172236-f072eb8bf64e github.com/lekkodev/rules v1.5.3-0.20230724195144-d0ed93c3e218 github.com/migueleliasweb/go-github-mock v0.0.16 github.com/mitchellh/go-homedir v1.1.0 github.com/olekukonko/tablewriter v0.0.5 github.com/spf13/cobra v1.5.0 + github.com/stoewer/go-strcase v1.2.0 github.com/stretchr/testify v1.8.0 github.com/whilp/git-urls v1.0.0 golang.org/x/mod v0.8.0 diff --git a/go.sum b/go.sum index 21d5be02..623968ac 100644 --- a/go.sum +++ b/go.sum @@ -156,6 +156,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= +github.com/lainio/err2 v0.9.51 h1:R9lzMUxvP0T2U1C5DnpQG0FMMYooqGjUypP6aVmw+eY= +github.com/lainio/err2 v0.9.51/go.mod h1:glTVV2qNFbBy6WzZFDP2G5BqMiZI58cudp588cEgCuM= github.com/lekkodev/go-sdk v0.2.6-0.20230830172236-f072eb8bf64e h1:UO23VqLwbW0NC8yP5dY7S9jBX/liGodXNSRnGBq02Js= github.com/lekkodev/go-sdk v0.2.6-0.20230830172236-f072eb8bf64e/go.mod h1:zJ3izZC3/2MvgKrM1O4kV3PcaBHhCmmFziH1ls1APFI= github.com/lekkodev/rules v1.5.3-0.20230724195144-d0ed93c3e218 h1:ULjfHubYgiEHrGdwcfNpAg+DNQCWMaU/zBNPUQNDNBE= @@ -208,6 +210,7 @@ github.com/spf13/cobra v1.5.0 h1:X+jTBEBqF0bHN+9cSMgmfuvv2VHJ9ezmFNf9Y/XstYU= github.com/spf13/cobra v1.5.0/go.mod h1:dWXEIy2H428czQCjInthrTRUg7yKbok+2Qi/yBIJoUM= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= diff --git a/pkg/metadata/metadata.go b/pkg/metadata/metadata.go index 62b3259c..5a7bfaf2 100644 --- a/pkg/metadata/metadata.go +++ b/pkg/metadata/metadata.go @@ -50,8 +50,9 @@ type RootConfigRepoMetadata struct { type NamespaceConfigRepoMetadata struct { // This version refers to the version of the configuration in the repo itself. // TODO we should move this to a separate version number. - Version string `json:"version,omitempty" yaml:"version,omitempty"` - Name string `json:"name,omitempty" yaml:"name,omitempty"` + Version string `json:"version,omitempty" yaml:"version,omitempty"` + Name string `json:"name,omitempty" yaml:"name,omitempty"` + ContextProto string `json:"contextProto,omitempty" yaml:"contextProto,omitempty"` } const DefaultRootConfigRepoMetadataFileName = "lekko.root.yaml" diff --git a/pkg/repo/feature.go b/pkg/repo/feature.go index c774d6f9..3b71b37f 100644 --- a/pkg/repo/feature.go +++ b/pkg/repo/feature.go @@ -25,6 +25,7 @@ import ( "sync" featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1" + rulesv1beta3 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/rules/v1beta3" "github.com/go-git/go-git/v5/plumbing" "github.com/lekkodev/cli/pkg/encoding" "github.com/lekkodev/cli/pkg/feature" @@ -39,6 +40,7 @@ import ( "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/structpb" ) // Provides functionality needed for accessing and making changes to Lekko configuration. @@ -67,7 +69,110 @@ type ConfigurationStore interface { RestoreWorkingDirectory(hash string) error } -func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry.Types, namespace, featureName string, nv feature.NamespaceVersion) (*feature.CompiledFeature, error) { +func checkRuleFitsContextType(rule *rulesv1beta3.Rule, contextType protoreflect.MessageType) error { + var err error + switch r := rule.Rule.(type) { + case *rulesv1beta3.Rule_BoolConst: + return nil + case *rulesv1beta3.Rule_Not: + return checkRuleFitsContextType(r.Not, contextType) + case *rulesv1beta3.Rule_LogicalExpression: + for _, lr := range r.LogicalExpression.GetRules() { + err = checkRuleFitsContextType(lr, contextType) + if err != nil { + return err + } + } + case *rulesv1beta3.Rule_Atom: + field := contextType.Descriptor().Fields().ByName(protoreflect.Name(r.Atom.GetContextKey())) + if field == nil { + return errors.Errorf("`%s` field not found in context type", r.Atom.GetContextKey()) + } + fieldType := field.Kind().String() + + switch r.Atom.ComparisonOperator { + case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_PRESENT: + return nil + case + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_NOT_EQUALS: + switch r.Atom.GetComparisonValue().Kind.(type) { + case *structpb.Value_BoolValue: + if fieldType != "string" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + + case *structpb.Value_NumberValue: + if fieldType != "int64" && fieldType != "double" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + case *structpb.Value_StringValue: + if fieldType != "string" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + default: + panic("This should never happen") + } + case + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN_OR_EQUALS, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN_OR_EQUALS: + if fieldType != "int64" && fieldType != "double" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + return nil + case + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_STARTS_WITH, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_ENDS_WITH, + rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINS: + if fieldType != "string" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + return nil + case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN: + listRuleVal, ok := r.Atom.GetComparisonValue().Kind.(*structpb.Value_ListValue) + if !ok { + panic("This should never happen") + } + for _, elem := range listRuleVal.ListValue.Values { + switch elem.Kind.(type) { + case *structpb.Value_BoolValue: + if fieldType != "string" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + + case *structpb.Value_NumberValue: + if fieldType != "int64" && fieldType != "double" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + case *structpb.Value_StringValue: + if fieldType != "string" { + return errors.Errorf("%s field has invalid type", r.Atom.GetContextKey()) + } + default: + panic("This should never happen") + } + } + return nil + } + return nil + case *rulesv1beta3.Rule_CallExpression: + switch f := r.CallExpression.Function.(type) { + case *rulesv1beta3.CallExpression_Bucket_: + contextKey := f.Bucket.ContextKey + if contextType.Descriptor().Fields().ByName(protoreflect.Name(contextKey)) == nil { + return errors.Errorf("%s field not found in context type", contextKey) + } + return nil + } + default: + panic("This should never happen") + } + return nil +} + +func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry.Types, nsContextTypes map[string]protoreflect.MessageType, namespace, featureName string, nv feature.NamespaceVersion) (*feature.CompiledFeature, error) { if !isValidName(namespace) { return nil, errors.Errorf("invalid name '%s'", namespace) } @@ -92,6 +197,14 @@ func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry if err != nil { return nil, errors.Wrap(err, "compile") } + if nsContextTypes[namespace] != nil { + for _, override := range f.Feature.Overrides { + err = checkRuleFitsContextType(override.RuleASTV3, nsContextTypes[namespace]) + if err != nil { + return nil, err + } + } + } return f, nil } @@ -303,6 +416,39 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu if err := req.Validate(); err != nil { return nil, errors.Wrap(err, "validate request") } + + _, nsMDs, err := r.ParseMetadata(ctx) + if err != nil { + return nil, errors.Wrap(err, "parse metadata") + } + registry, err := r.registry(ctx, req.Registry) + if err != nil { + return nil, errors.Wrap(err, "registry") + } + nsContextTypes := make(map[string]protoreflect.MessageType) + for ns, nsMd := range nsMDs { + if nsMd.ContextProto != "" { + r.Logf("%s: %s\n", ns, nsMd.ContextProto) + ct, err := registry.FindMessageByName(protoreflect.FullName(nsMd.ContextProto)) + if err != nil { + return nil, err + } + for i := 0; i < ct.Descriptor().Fields().Len(); i++ { + f := ct.Descriptor().Fields().Get(i) + switch f.Kind().String() { + case + "bool", + "int64", + "double", + "string": + default: + return nil, errors.New("Invalid context type thingy make this better") + } + } + nsContextTypes[ns] = ct + } + } + // Step 1: collect. Find all features vffs, numNamespaces, err := r.findVersionedFeatureFiles(ctx, req.NamespaceFilter, req.FeatureFilter, req.Verify) if err != nil { @@ -326,10 +472,6 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu } r.Logf("Found %d configs across %d namespaces\n", len(results), numNamespaces) r.Logf("Compiling...\n") - registry, err := r.registry(ctx, req.Registry) - if err != nil { - return nil, errors.Wrap(err, "registry") - } concurrency := 50 if len(results) < 50 { concurrency = len(results) @@ -344,7 +486,7 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu nsNameToSegments := make(map[string]map[string]string) for _, fcr := range results { if fcr.FeatureName == "segments" { - cf, err := r.CompileFeature(ctx, registry, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion) + cf, err := r.CompileFeature(ctx, registry, nsContextTypes, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion) if err != nil { fcr.CompilationError = err } @@ -385,7 +527,7 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu } fcr.FormattingDiffExists = fmtDiffExists // compile feature - cf, err := r.CompileFeature(ctx, registry, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion) + cf, err := r.CompileFeature(ctx, registry, nsContextTypes, fcr.NamespaceName, fcr.FeatureName, fcr.NamespaceVersion) fcr.CompiledFeature = cf fcr.CompilationError = err }