diff --git a/cmd/lekko/sync.go b/cmd/lekko/sync.go index fc33130a..76bf8e2b 100644 --- a/cmd/lekko/sync.go +++ b/cmd/lekko/sync.go @@ -24,8 +24,6 @@ import ( "path/filepath" "strings" - "golang.org/x/mod/modfile" - "github.com/iancoleman/strcase" "github.com/lainio/err2" "github.com/lainio/err2/try" @@ -39,7 +37,6 @@ import ( "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/dynamicpb" "google.golang.org/protobuf/types/known/anypb" - "google.golang.org/protobuf/types/known/durationpb" bffv1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/bff/v1beta1" featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1" @@ -69,15 +66,7 @@ func syncGoCmd() *cobra.Command { Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - b, err := os.ReadFile("go.mod") - if err != nil { - return errors.Wrap(err, "find go.mod in working directory") - } - mf, err := modfile.ParseLax("go.mod", b, nil) - if err != nil { - return err - } - + var err error if len(repoPath) == 0 { repoPath, err = repo.PrepareGithubRepo() if err != nil { @@ -85,11 +74,12 @@ func syncGoCmd() *cobra.Command { } } f = args[0] - syncer, err := sync.NewGoSyncer(mf.Module.Mod.Path, f) + syncer := sync.NewGoSyncer() + repoContents, err := syncer.Sync(f) if err != nil { - return errors.Wrap(err, "initialize code syncer") + return err } - _, err = syncer.Sync(ctx, &repoPath) + err = sync.WriteContentsToLocalRepo(ctx, repoContents, repoPath) if err != nil { return err } @@ -222,30 +212,24 @@ func isSame(ctx context.Context, existing map[string]map[string]*featurev1beta1. if err != nil { return false, err } - b, err := os.ReadFile("go.mod") - if err != nil { - return false, err - } - mf, err := modfile.ParseLax("go.mod", b, nil) - if err != nil { - return false, err - } dot := try.To1(dotlekko.ReadDotLekko("")) nlProject := try.To1(native.DetectNativeLang("")) files := try.To1(native.ListNativeConfigFiles(dot.LekkoPath, nlProject.Language)) var notEqual bool + var relPaths []string for _, f := range files { - relativePath, err := filepath.Rel(wd, f) + relPath, err := filepath.Rel(wd, f) if err != nil { return false, err } - //fmt.Printf("%s\n\n", mf.Module.Mod.Path) - g := sync.NewGoSyncerLite(mf.Module.Mod.Path, relativePath) - namespace, err := g.Sync(ctx, nil) - if err != nil { - return false, err - } - //fmt.Printf("%#v\n", namespace) + relPaths = append(relPaths, relPath) + } + g := sync.NewGoSyncer() + repoContents, err := g.Sync(relPaths...) 
+ if err != nil { + return false, err + } + for _, namespace := range repoContents.Namespaces { existingNs, ok := existing[namespace.Name] if !ok { // New namespace not in existing @@ -485,7 +469,7 @@ func convertLangCmd() *cobra.Command { } fmt.Println(out) } else { - privateFile := goToGo(ctx, f) + privateFile := goToGo(ctx, inputFile) fmt.Println(privateFile) } return nil @@ -497,34 +481,25 @@ func convertLangCmd() *cobra.Command { return cmd } -func goToGo(ctx context.Context, f []byte) string { - registry, err := prototypes.RegisterDynamicTypes(nil) +func goToGo(ctx context.Context, filePath string) string { + syncer := sync.NewGoSyncer() + repoContents, err := syncer.Sync(filePath) if err != nil { - panic(err) + panic(errors.Wrap(err, "sync")) } - err = registry.AddFileDescriptor(durationpb.File_google_protobuf_duration_proto, false) - if err != nil { - panic(err) - } - syncer := sync.NewGoSyncerLite("", "") - namespace, err := syncer.SourceToNamespace(ctx, f) - if err != nil { - panic(err) + if len(repoContents.Namespaces) != 1 { + panic("expected 1 namespace") } + namespace := repoContents.Namespaces[0] //fmt.Printf("%+v\n", namespace) //fmt.Printf("%+v\n", registry.Types) //fmt.Print("ON TO GENERATION\n") // code gen based off that namespace object - g, err := gen.NewGoGenerator("", "/tmp", "", "", namespace.Name) // type registry? - if err != nil { - panic(err) - } - tr, err := syncer.GetTypeRegistry() - g.TypeRegistry = tr + g, err := gen.NewGoGenerator("", "/tmp", "", repoContents) if err != nil { panic(err) } - _, privateFile, err := g.GenNamespaceFiles(ctx, namespace.Features, nil) + _, privateFile, err := g.GenNamespaceFiles(ctx, namespace.Name, namespace.Features, nil) if err != nil { panic(err) } diff --git a/cmd/lekko/sync_test.go b/cmd/lekko/sync_test.go index 0dd25043..b564dad7 100644 --- a/cmd/lekko/sync_test.go +++ b/cmd/lekko/sync_test.go @@ -17,11 +17,20 @@ package main import ( //"bytes" "context" + "fmt" + "io/fs" "os" "os/exec" + "path/filepath" + "strings" //"os/exec" "testing" + + "github.com/lekkodev/cli/pkg/gen" + "github.com/lekkodev/cli/pkg/sync" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" //"google.golang.org/protobuf/types/descriptorpb" ) @@ -95,61 +104,6 @@ func Test_writeProtoFiles(t *testing.T) { } */ -func Test_goToGo(t *testing.T) { - t.Run("simple", func(t *testing.T) { - ctx := context.Background() - f, err := os.ReadFile("./testdata/simple.go") - if err != nil { - panic(err) - } - if got := goToGo(ctx, f); got != string(f) { - diff, err := DiffStyleOutput(string(f), got) - if err != nil { - panic(err) - } - t.Errorf("Difference Found: %s\n", diff) - } - }) - t.Run("withcontext", func(t *testing.T) { - ctx := context.Background() - f, err := os.ReadFile("./testdata/withcontext.go") - if err != nil { - panic(err) - } - if got := goToGo(ctx, f); got != string(f) { - t.Errorf("goToGo() = \n===\n%v+++, want \n===\n%v+++", got, string(f)) - } - }) - - t.Run("one_two", func(t *testing.T) { - ctx := context.Background() - f, err := os.ReadFile("./testdata/twostructs.go") - if err != nil { - panic(err) - } - if got := goToGo(ctx, f); got != string(f) { - t.Errorf("goToGo() = \n===\n%v+++, want \n===\n%v+++", got, string(f)) - } - }) -} - -func Test_Gertrude(t *testing.T) { - t.Run("gertrude", func(t *testing.T) { - ctx := context.Background() - f, err := os.ReadFile("./testdata/gertrude.go") - if err != nil { - panic(err) - } - if got := goToGo(ctx, f); got != string(f) { - diff, err := DiffStyleOutput(string(f), 
got) - if err != nil { - panic(err) - } - t.Errorf("Difference Found: %s\n", diff) - } - }) -} - func DiffStyleOutput(a, b string) (string, error) { // Create temporary files to hold the input strings fileA, err := os.CreateTemp("", "fileA") @@ -189,21 +143,72 @@ func DiffStyleOutput(a, b string) (string, error) { return string(output), nil } -func TestDefault(t *testing.T) { - t.Run("default", func(t *testing.T) { - ctx := context.Background() - f, err := os.ReadFile("./testdata/default.go") - if err != nil { - panic(err) - } - if got := goToGo(ctx, f); got != string(f) { - diff, err := DiffStyleOutput(string(f), got) - if err != nil { - panic(err) - } - t.Errorf("Difference Found: %s\n", diff) +// Test code -> repo -> code (compare) -> repo (compare) +func TestGoSyncToGenToSync(t *testing.T) { + if err := filepath.WalkDir("./testdata", func(path string, d fs.DirEntry, err error) error { + if !d.IsDir() && strings.HasSuffix(d.Name(), ".go") { + t.Run(strings.TrimSuffix(d.Name(), ".go"), func(t *testing.T) { + ctx := context.Background() + orig, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read test file %s: %v", path, err) + } + + s1 := sync.NewGoSyncer() + r1, err := s1.Sync(path) + if err != nil { + t.Fatalf("sync 1: %v", err) + } + + tmpd, err := os.MkdirTemp("", "test") + if err != nil { + t.Fatalf("tmp dir: %v", err) + } + defer os.RemoveAll(tmpd) + + g, err := gen.NewGoGenerator("test", tmpd, "", r1) + if err != nil { + t.Fatalf("initialize gen: %v", err) + } + namespace := r1.Namespaces[0] + _, private, err := g.GenNamespaceFiles(ctx, namespace.Name, namespace.Features, nil) + if err != nil { + t.Fatalf("gen: %v", err) + } + privatePath := filepath.Join(tmpd, fmt.Sprintf("%s.go", namespace.Name)) + if err := os.WriteFile(privatePath, []byte(private), 0600); err != nil { + t.Fatalf("write private %s: %v", privatePath, err) + } + + if string(orig) != private { + diff, err := DiffStyleOutput(string(orig), private) + if err != nil { + t.Fatalf("diff: %v", err) + } + t.Fatalf("mismatch in generated code: %s", diff) + } + + s2 := sync.NewGoSyncer() + r2, err := s2.Sync(privatePath) + if err != nil { + t.Fatalf("sync 2: %v", err) + } + // NOTE: Because Anys contained serialized values, their serialization needs to be + // deterministic for this check to always pass even if their deserialized values are equal. + // The problem is that the determinism is not canonical across languages. + // We should keep this in mind. + if !proto.Equal(r1, r2) { + r1json := protojson.Format(r1) + r2json := protojson.Format(r2) + diff, _ := DiffStyleOutput(r1json, r2json) + t.Fatalf("mismatch in repo contents: %s", diff) + } + }) } - }) + return nil + }); err != nil { + t.Fatalf("walk: %v", err) + } } /* diff --git a/go.mod b/go.mod index 7236fd54..ee0be1ba 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ replace github.com/bazelbuild/buildtools => github.com/lekkodev/buildtools v0.0. 
require ( buf.build/gen/go/lekkodev/cli/bufbuild/connect-go v1.10.0-20240528213244-5fdc18b47eea.1 - buf.build/gen/go/lekkodev/cli/protocolbuffers/go v1.34.2-20240801183157-5b3ff9a64103.2 + buf.build/gen/go/lekkodev/cli/protocolbuffers/go v1.34.2-20240923164736-6b09ba83efbf.2 github.com/AlecAivazis/survey/v2 v2.3.6 github.com/atotto/clipboard v0.1.4 github.com/bazelbuild/buildtools v0.0.0-20220907133145-b9bfff5d7f91 diff --git a/go.sum b/go.sum index 3062dd0f..62275806 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ buf.build/gen/go/lekkodev/cli/bufbuild/connect-go v1.10.0-20240528213244-5fdc18b47eea.1 h1:JqArhl+OClAdLQis1N2N6WmLv96CbOaNrQEYWL2ntlI= buf.build/gen/go/lekkodev/cli/bufbuild/connect-go v1.10.0-20240528213244-5fdc18b47eea.1/go.mod h1:gkMKhhTCMDLJVmimyqao6P3g7jB7wDe3r7u8hV2iShE= -buf.build/gen/go/lekkodev/cli/protocolbuffers/go v1.34.2-20240801183157-5b3ff9a64103.2 h1:ASguPgU4ltdoHHy+YTlpg79TDwhHpWBk3MKmxfbaE7Y= -buf.build/gen/go/lekkodev/cli/protocolbuffers/go v1.34.2-20240801183157-5b3ff9a64103.2/go.mod h1:j/ek65dWz+D5GM7p9QUiHQj5X5gtRUMfGl1+GpSfm6g= +buf.build/gen/go/lekkodev/cli/protocolbuffers/go v1.34.2-20240923164736-6b09ba83efbf.2 h1:JEiSyCH0wTycYIeIpm8xs/HcyPgHtJpdKBrUer8Jq6E= +buf.build/gen/go/lekkodev/cli/protocolbuffers/go v1.34.2-20240923164736-6b09ba83efbf.2/go.mod h1:j/ek65dWz+D5GM7p9QUiHQj5X5gtRUMfGl1+GpSfm6g= buf.build/gen/go/lekkodev/sdk/protocolbuffers/go v1.34.2-20230810202034-1c821065b9a0.2 h1:ZEir2Lbw+XH5Dlnqiv0FUc8hC7QdUMpGSCIiREMriJ0= buf.build/gen/go/lekkodev/sdk/protocolbuffers/go v1.34.2-20230810202034-1c821065b9a0.2/go.mod h1:YAvVDcY/tXuUXkpfm3LHCD6vz9SPv73CktuPGgqzJkI= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= diff --git a/pkg/gen/gen.go b/pkg/gen/gen.go index 60e2f3f8..32df5c09 100644 --- a/pkg/gen/gen.go +++ b/pkg/gen/gen.go @@ -18,8 +18,11 @@ import ( "context" "os" "path/filepath" + "sort" + featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1" "golang.org/x/mod/modfile" + "google.golang.org/protobuf/proto" "github.com/lainio/err2" "github.com/lainio/err2/try" @@ -88,6 +91,7 @@ func GenNative(ctx context.Context, project *native.Project, lekkoPath, repoPath } } +// TODO: move to golang and split repo/repoless func genFormattedGo(ctx context.Context, project *native.Project, repoPath, lekkoPath string, opts GenOptions) (err error) { defer err2.Handle(&err) var moduleRoot string @@ -100,13 +104,49 @@ func genFormattedGo(ctx context.Context, project *native.Project, repoPath, lekk moduleRoot = mf.Module.Mod.Path } outputPath := filepath.Join(opts.CodeRepoPath, lekkoPath) + generator := try.To1(NewGoGeneratorFromLocal(ctx, moduleRoot, outputPath, lekkoPath, repoPath)) for _, namespace := range opts.Namespaces { - generator := try.To1(NewGoGenerator(moduleRoot, outputPath, lekkoPath, repoPath, namespace)) if opts.InitMode { - try.To(generator.Init(ctx)) + try.To(generator.Init(ctx, namespace)) } else { - try.To(generator.Gen(ctx)) + try.To(generator.Gen(ctx, namespace)) } } return nil } + +// Reads repository contents from a local config repository. 
+func ReadRepoContents(ctx context.Context, repoPath string) (repoContents *featurev1beta1.RepositoryContents, err error) { + defer err2.Handle(&err) + r, err := repo.NewLocal(repoPath, nil) + if err != nil { + return nil, errors.Wrap(err, "read config repository") + } + rootMD, nsMDs := try.To2(r.ParseMetadata(ctx)) + repoContents = &featurev1beta1.RepositoryContents{} + repoContents.FileDescriptorSet = try.To1(r.GetFileDescriptorSet(ctx, rootMD.ProtoDirectory)) + for nsName := range nsMDs { + ns := &featurev1beta1.Namespace{Name: nsName} + ffs, err := r.GetFeatureFiles(ctx, nsName) + if err != nil { + return nil, errors.Wrapf(err, "read files for ns %s", nsName) + } + // Sort configs in alphabetical order + sort.SliceStable(ffs, func(i, j int) bool { + return ffs[i].CompiledProtoBinFileName < ffs[j].CompiledProtoBinFileName + }) + for _, ff := range ffs { + fc, err := r.GetFeatureContents(ctx, nsName, ff.Name) + if err != nil { + return nil, errors.Wrapf(err, "read contents for %s/%s", nsName, ff.Name) + } + f := &featurev1beta1.Feature{} + if err := proto.Unmarshal(fc.Proto, f); err != nil { + return nil, errors.Wrapf(err, "unmarshal %s/%s", nsName, ff.Name) + } + ns.Features = append(ns.Features, f) + } + repoContents.Namespaces = append(repoContents.Namespaces, ns) + } + return repoContents, nil +} diff --git a/pkg/gen/golang.go b/pkg/gen/golang.go index 1a467fbd..ebdd5ec6 100644 --- a/pkg/gen/golang.go +++ b/pkg/gen/golang.go @@ -36,6 +36,7 @@ import ( "github.com/lainio/err2/assert" "github.com/lainio/err2/try" "github.com/lekkodev/cli/pkg/dotlekko" + protoutils "github.com/lekkodev/cli/pkg/proto" "github.com/lekkodev/cli/pkg/repo" "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" @@ -48,31 +49,45 @@ import ( "github.com/lekkodev/cli/pkg/native" ) -// TODO: this can hold more state to clean up functions a bit, like storing usedVariables, etc. type goGenerator struct { moduleRoot string // e.g. github.com/lekkodev/cli outputPath string // Location for destination file, can be absolute or relative. Its suffix should be the same as lekkoPath. In most cases can be same as lekkoPath. lekkoPath string // Location relative to project root where Lekko files are stored, e.g. internal/lekko. - repoPath string - namespace string + repoContents *featurev1beta1.RepositoryContents TypeRegistry *protoregistry.Types } -func NewGoGenerator(moduleRoot, outputPath, lekkoPath, repoPath, namespace string) (*goGenerator, error) { - // Validate namespace - if !regexp.MustCompile("[a-z]+").MatchString(namespace) { - return nil, fmt.Errorf("namespace must be a lowercase alphabetic string: %s", namespace) - } - if namespace == "proto" { - return nil, errors.New("'proto' is a reserved name") +// Initializes a new generator from config repository contents. +func NewGoGenerator(moduleRoot, outputPath, lekkoPath string, repoContents *featurev1beta1.RepositoryContents) (*goGenerator, error) { + typeRegistry, err := protoutils.FileDescriptorSetToTypeRegistry(repoContents.FileDescriptorSet) + if err != nil { + return nil, errors.Wrap(err, "convert fds to type registry") } + return &goGenerator{ + moduleRoot: moduleRoot, + outputPath: outputPath, + lekkoPath: filepath.Clean(lekkoPath), + repoContents: repoContents, + TypeRegistry: typeRegistry, + }, nil +} +// Initializes a new generator, parsing config repository contents from a local repository. 
+func NewGoGeneratorFromLocal(ctx context.Context, moduleRoot, outputPath, lekkoPath string, repoPath string) (*goGenerator, error) { + repoContents, err := ReadRepoContents(ctx, repoPath) + if err != nil { + return nil, errors.Wrapf(err, "read contents from %s", repoPath) + } + typeRegistry, err := protoutils.FileDescriptorSetToTypeRegistry(repoContents.FileDescriptorSet) + if err != nil { + return nil, errors.Wrap(err, "convert fds to type registry") + } return &goGenerator{ - moduleRoot: moduleRoot, - outputPath: outputPath, - lekkoPath: filepath.Clean(lekkoPath), - repoPath: repoPath, - namespace: namespace, + moduleRoot: moduleRoot, + outputPath: outputPath, + lekkoPath: lekkoPath, + repoContents: repoContents, + TypeRegistry: typeRegistry, }, nil } @@ -91,14 +106,14 @@ func structpbValueToKindStringGo(v *structpb.Value) string { } // Initialize a blank Lekko config function file -func (g *goGenerator) Init(ctx context.Context) error { +func (g *goGenerator) Init(ctx context.Context, namespaceName string) error { const templateBody = `package lekko{{$.Namespace}} // This is an example description for an example config func getExample() bool { return true }` - fullOutputPath := path.Join(g.outputPath, g.namespace, fmt.Sprintf("%s.go", g.namespace)) + fullOutputPath := path.Join(g.outputPath, namespaceName, fmt.Sprintf("%s.go", namespaceName)) if _, err := os.Stat(fullOutputPath); err == nil { return fmt.Errorf("file %s already exists", fullOutputPath) } @@ -106,7 +121,7 @@ func getExample() bool { data := struct { Namespace string }{ - Namespace: g.namespace, + Namespace: namespaceName, } var contents bytes.Buffer templ := template.Must(template.New(fullOutputPath).Parse(templateBody)) @@ -117,7 +132,7 @@ func getExample() bool { if err != nil { return errors.Wrapf(err, "format %s", fullOutputPath) } - try.To(os.MkdirAll(path.Join(g.outputPath, g.namespace), 0770)) + try.To(os.MkdirAll(path.Join(g.outputPath, namespaceName), 0770)) f, err := os.Create(fullOutputPath) if err != nil { return errors.Wrapf(err, "create %s", fullOutputPath) @@ -130,7 +145,7 @@ func getExample() bool { // This is an attempt to pull out a simpler component that is more re-usable - the other one should probably be removed/changed, but that depends on // how far this change goes -func (g *goGenerator) GenNamespaceFiles(ctx context.Context, features []*featurev1beta1.Feature, staticCtxType *ProtoImport) (string, string, error) { +func (g *goGenerator) GenNamespaceFiles(ctx context.Context, namespaceName string, features []*featurev1beta1.Feature, staticCtxType *ProtoImport) (string, string, error) { // For each namespace, we want to generate under lekko/: // lekko/ // / @@ -202,7 +217,7 @@ import ( } } if ctxType == nil { - messagePath := fmt.Sprintf("%s.config.v1beta1.%sArgs", g.namespace, strcase.ToCamel(f.Key)) + messagePath := fmt.Sprintf("%s.config.v1beta1.%sArgs", namespaceName, strcase.ToCamel(f.Key)) mt, err := g.TypeRegistry.FindMessageByName(protoreflect.FullName(messagePath)) if err == nil { privateFuncStrings = append(privateFuncStrings, DescriptorToStructDeclaration(mt.Descriptor())) @@ -229,9 +244,9 @@ import ( importProtoReflect = true } - generated, err := g.genGoForFeature(ctx, nil, f, g.namespace, ctxType) + generated, err := g.genGoForFeature(ctx, nil, f, namespaceName, ctxType) if err != nil { - return "", "", errors.Wrapf(err, "generate code for %s/%s", g.namespace, f.Key) + return "", "", errors.Wrapf(err, "generate code for %s/%s", namespaceName, f.Key) } publicFuncStrings =
append(publicFuncStrings, generated.public) privateFuncStrings = append(privateFuncStrings, generated.private) @@ -272,7 +287,7 @@ import ( ImportProtoReflect bool }{ protoImports, - g.namespace, + namespaceName, publicFuncStrings, privateFuncStrings, addStringsImport, @@ -281,11 +296,11 @@ import ( importProtoReflect, } - public, err := getContents(publicFileTemplateBody, fmt.Sprintf("%s_gen.go", g.namespace), data) + public, err := renderGoTemplate(publicFileTemplateBody, fmt.Sprintf("%s_gen.go", namespaceName), data) if err != nil { return "", "", err } - private, err := getContents(privateFileTemplateBody, fmt.Sprintf("%s.go", g.namespace), data) + private, err := renderGoTemplate(privateFileTemplateBody, fmt.Sprintf("%s.go", namespaceName), data) if err != nil { return "", "", err } @@ -293,7 +308,7 @@ import ( return public, private, nil } -func getContents(templateBody string, fileName string, data any) (string, error) { +func renderGoTemplate(templateBody string, fileName string, data any) (string, error) { var contents bytes.Buffer templ := template.Must(template.New(fileName).Parse(templateBody)) if err := templ.Execute(&contents, data); err != nil { @@ -308,58 +323,43 @@ func getContents(templateBody string, fileName string, data any) (string, error) return string(formatted), nil } -func (g *goGenerator) Gen(ctx context.Context) (err error) { +// Generates public and private function files for the namespace as well as the overall client file. +// Writes outputs to the output paths specified in the generator. +// TODO: since generator takes in whole repo contents now, could generate for all/filtered namespaces +func (g *goGenerator) Gen(ctx context.Context, namespaceName string) (err error) { defer err2.Handle(&err) - r, err := repo.NewLocal(g.repoPath, nil) - if err != nil { - return errors.Wrap(err, "read config repository") - } - rootMD, nsMDs := try.To2(r.ParseMetadata(ctx)) - if g.TypeRegistry == nil { - // TODO this feels weird and there is a global set we should be able to add to but I'll worrry about it later?
- g.TypeRegistry = try.To1(r.BuildDynamicTypeRegistry(ctx, rootMD.ProtoDirectory)) - } - nsMD, ok := nsMDs[g.namespace] - if !ok { - return fmt.Errorf("%s is not a namespace in the config repository", g.namespace) + // Validate namespace + if !regexp.MustCompile("[a-z]+").MatchString(namespaceName) { + return errors.Errorf("namespace must be a lowercase alphabetic string: %s", namespaceName) } - staticCtxType := UnpackProtoType(g.moduleRoot, g.lekkoPath, nsMD.ContextProto) - ffs, err := r.GetFeatureFiles(ctx, g.namespace) - if err != nil { - return err + if namespaceName == "proto" { + return errors.Errorf("'%s' is a reserved name", namespaceName) } - // Sort configs in alphabetical order - sort.SliceStable(ffs, func(i, j int) bool { - return ffs[i].CompiledProtoBinFileName < ffs[j].CompiledProtoBinFileName - }) - var features []*featurev1beta1.Feature - for _, ff := range ffs { - fff, err := os.ReadFile(path.Join(g.repoPath, g.namespace, ff.CompiledProtoBinFileName)) - if err != nil { - return err + var namespace *featurev1beta1.Namespace + for _, ns := range g.repoContents.Namespaces { + if ns.Name == namespaceName { + namespace = ns } - f := &featurev1beta1.Feature{} - if err := proto.Unmarshal(fff, f); err != nil { - return err - } - features = append(features, f) + } + if namespace == nil { + return errors.Errorf("namespace '%s' not found", namespaceName) } // Create intermediate directories for output - if err := os.MkdirAll(path.Join(g.outputPath, g.namespace), 0770); err != nil { + if err := os.MkdirAll(path.Join(g.outputPath, namespaceName), 0770); err != nil { return err } - public, private, err := g.GenNamespaceFiles(ctx, features, staticCtxType) + public, private, err := g.GenNamespaceFiles(ctx, namespaceName, namespace.Features, nil) if err != nil { return err } - if f, err := os.Create(path.Join(g.outputPath, g.namespace, fmt.Sprintf("%s_gen.go", g.namespace))); err != nil { + if f, err := os.Create(path.Join(g.outputPath, namespaceName, fmt.Sprintf("%s_gen.go", namespaceName))); err != nil { return err } else { if _, err := f.WriteString(public); err != nil { return errors.Wrap(err, fmt.Sprintf("write formatted contents to %s", f.Name())) } } - if f, err := os.Create(path.Join(g.outputPath, g.namespace, fmt.Sprintf("%s.go", g.namespace))); err != nil { + if f, err := os.Create(path.Join(g.outputPath, namespaceName, fmt.Sprintf("%s.go", namespaceName))); err != nil { return err } else { if _, err := f.WriteString(private); err != nil { diff --git a/pkg/gen/ts.go b/pkg/gen/ts.go index 161338a4..1ca1b643 100644 --- a/pkg/gen/ts.go +++ b/pkg/gen/ts.go @@ -167,6 +167,7 @@ func GenTS(ctx context.Context, repoPath, ns string, getWriter func() (io.Writer return err } } + // TODO: generate from contents, maybe split repo/repoless r, err := repo.NewLocal(repoPath, nil) if err != nil { return errors.Wrap(err, "new repo") diff --git a/pkg/sync/proto.go b/pkg/proto/proto.go similarity index 86% rename from pkg/sync/proto.go rename to pkg/proto/proto.go index 97ca0faa..49d54a96 100644 --- a/pkg/sync/proto.go +++ b/pkg/proto/proto.go @@ -12,16 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package sync +package proto import ( "fmt" "strings" - featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1" - rulesv1beta2 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/rules/v1beta2" - rulesv1beta3 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/rules/v1beta3" - "github.com/pkg/errors" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" @@ -32,11 +28,17 @@ import ( "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/wrapperspb" + + featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1" + rulesv1beta2 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/rules/v1beta2" + rulesv1beta3 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/rules/v1beta3" ) -// Initializes an empty file descriptor set and starts with well-known types +// Initializes an empty file descriptor set and starts with well-known types. +// TODO: currently used for sync, but consider if this should be used in all other places (e.g. repo init, compile, etc.) func NewDefaultFileDescriptorSet() *descriptorpb.FileDescriptorSet { fds := &descriptorpb.FileDescriptorSet{} + fds.File = append(fds.File, protodesc.ToFileDescriptorProto(descriptorpb.File_google_protobuf_descriptor_proto)) fds.File = append(fds.File, protodesc.ToFileDescriptorProto(wrapperspb.File_google_protobuf_wrappers_proto)) fds.File = append(fds.File, protodesc.ToFileDescriptorProto(structpb.File_google_protobuf_struct_proto)) fds.File = append(fds.File, protodesc.ToFileDescriptorProto(durationpb.File_google_protobuf_duration_proto)) @@ -86,9 +88,19 @@ func FileRegistryToFileDescriptorSet(registry *protoregistry.Files) *descriptorp return fds } +func FileDescriptorSetToTypeRegistry(fds *descriptorpb.FileDescriptorSet) (*protoregistry.Types, error) { + fr, err := protodesc.NewFiles(fds) + if err != nil { + return nil, errors.Wrap(err, "convert to file registry") + } + tr, err := FileRegistryToTypeRegistry(fr) + return tr, errors.Wrap(err, "get type registry from file registry") +} + // TODO: Canonical proto formatter that doesn't rely on `buf format`, which we should use for all languages -func FileDescriptorToProtoString(fd protoreflect.FileDescriptor) (string, error) { +// Print the contents of a file descriptor in a format suitable for e.g. 
writing to a .proto file +func PrintFileDescriptor(fd protoreflect.FileDescriptor) (string, error) { var sb strings.Builder // Preamble sb.WriteString("syntax = \"proto3\";\n\n") @@ -101,7 +113,7 @@ func FileDescriptorToProtoString(fd protoreflect.FileDescriptor) (string, error) // Messages for i := range fd.Messages().Len() { md := fd.Messages().Get(i) - mds, err := MessageDescriptorToProtoString(md, 0) + mds, err := PrintMessageDescriptor(md, 0) if err != nil { return "", errors.Wrapf(err, "stringify message descriptor %s", md.FullName()) } @@ -111,7 +123,7 @@ func FileDescriptorToProtoString(fd protoreflect.FileDescriptor) (string, error) return sb.String(), nil } -func MessageDescriptorToProtoString(md protoreflect.MessageDescriptor, indentLevel int) (string, error) { +func PrintMessageDescriptor(md protoreflect.MessageDescriptor, indentLevel int) (string, error) { indent := strings.Repeat(" ", indentLevel) var sb strings.Builder sb.WriteString(fmt.Sprintf("%smessage %s {\n", indent, md.Name())) @@ -147,7 +159,7 @@ func MessageDescriptorToProtoString(md protoreflect.MessageDescriptor, indentLev if _, ok := mapFieldTypes[string(nestedMds.Get(i).FullName())]; ok { continue } - s, err := MessageDescriptorToProtoString(nestedMds.Get(i), indentLevel+1) + s, err := PrintMessageDescriptor(nestedMds.Get(i), indentLevel+1) if err != nil { return "", errors.Wrapf(err, "stringify nested message %s", nestedMds.Get(i).FullName()) } diff --git a/pkg/sync/golang.go b/pkg/sync/golang.go index ef943cff..f20c9d11 100644 --- a/pkg/sync/golang.go +++ b/pkg/sync/golang.go @@ -15,13 +15,11 @@ package sync import ( - "bytes" "context" "fmt" "go/ast" "go/parser" "go/token" - "io" "io/fs" "os" "path/filepath" @@ -35,24 +33,15 @@ import ( "github.com/pkg/errors" "golang.org/x/mod/modfile" - "path" "strconv" rulesv1beta3 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/rules/v1beta3" "github.com/lainio/err2/assert" - "github.com/lainio/err2/try" - "github.com/lekkodev/cli/pkg/feature" - "github.com/lekkodev/cli/pkg/repo" - "github.com/lekkodev/cli/pkg/star" - "github.com/lekkodev/cli/pkg/star/static" - "github.com/lekkodev/go-sdk/pkg/eval" + protoutils "github.com/lekkodev/cli/pkg/proto" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" - "google.golang.org/protobuf/types/dynamicpb" - "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -79,28 +68,7 @@ func BisyncGo(ctx context.Context, outputPath, lekkoPath, repoPath string) ([]st return filepath.SkipDir } // Sync and gen - only target /.go files - // Semi-duplicated logic from Syncer initializer if !d.IsDir() && strings.TrimSuffix(d.Name(), ".go") == filepath.Base(filepath.Dir(p)) { - syncer, err := NewGoSyncer(mf.Module.Mod.Path, p) - if err != nil { - return errors.Wrap(err, "initialize code syncer") - } - if _, err := syncer.Sync(ctx, &repoPath); err != nil { - return errors.Wrapf(err, "sync %s", p) - } - namespace := filepath.Base(filepath.Dir(p)) - generator, err := gen.NewGoGenerator(mf.Module.Mod.Path, outputPath, lekkoPath, repoPath, namespace) - if err != nil { - return errors.Wrap(err, "initialize code generator") - } - typeRegistry, err := syncer.GetTypeRegistry() - if err != nil { - return errors.Wrap(err, "get post-sync type registry") - } - 
generator.TypeRegistry = typeRegistry - if err := generator.Gen(ctx); err != nil { - return errors.Wrapf(err, "generate code for %s", namespace) - } files = append(files, p) fmt.Printf("Successfully bisynced %s\n", logging.Bold(p)) } @@ -109,15 +77,26 @@ func BisyncGo(ctx context.Context, outputPath, lekkoPath, repoPath string) ([]st }); err != nil { return nil, err } + syncer := NewGoSyncer() + repoContents, err := syncer.Sync(files...) + if err != nil { + return nil, errors.Wrap(err, "sync") + } + if err := WriteContentsToLocalRepo(ctx, repoContents, repoPath); err != nil { + return nil, err + } + generator, err := gen.NewGoGenerator(mf.Module.Mod.Path, outputPath, lekkoPath, repoContents) + if err != nil { + return nil, errors.Wrap(err, "initialize code generator") + } + for _, namespace := range repoContents.Namespaces { + if err := generator.Gen(ctx, namespace.Name); err != nil { + return nil, errors.Wrapf(err, "generate code for %s", namespace.Name) + } + } return files, nil } -// TODO - make this our proto rep? -type Namespace struct { - Name string - Features []*featurev1beta1.Feature -} - func GetDependencies(descriptor *descriptorpb.DescriptorProto) []string { dependencies := make(map[string]struct{}) @@ -153,11 +132,12 @@ func GetDependencies(descriptor *descriptorpb.DescriptorProto) []string { return depList } -func (g *goSyncer) registerMessage(mdp *descriptorpb.DescriptorProto, namespace string) error { +// Registers a message descriptor for a namespace in the passed FDS +func registerMessage(fds *descriptorpb.FileDescriptorSet, mdp *descriptorpb.DescriptorProto, namespace string) error { filePath := fmt.Sprintf("%s/config/v1beta1/%s.proto", namespace, namespace) // Try to find existing file descriptor var fdp *descriptorpb.FileDescriptorProto - for _, file := range g.FDS.File { + for _, file := range fds.File { if file.GetName() == filePath { fdp = file } @@ -168,7 +148,7 @@ func (g *goSyncer) registerMessage(mdp *descriptorpb.DescriptorProto, namespace Name: proto.String(filePath), Package: proto.String(fmt.Sprintf("%s.config.v1beta1", namespace)), } - g.FDS.File = append(g.FDS.File, fdp) + fds.File = append(fds.File, fdp) } // Add message descriptor proto (and check for duplicate register) for _, message := range fdp.MessageType { @@ -196,10 +176,22 @@ func (g *goSyncer) registerMessage(mdp *descriptorpb.DescriptorProto, namespace return nil } -func (g *goSyncer) AstToNamespace(ctx context.Context, pf *ast.File, fset *token.FileSet) (*Namespace, error) { +// Translates Go code to representation of a Lekko repository's contents +type goSyncer struct { + fset *token.FileSet +} + +func NewGoSyncer() *goSyncer { + return &goSyncer{ + fset: token.NewFileSet(), + } +} + +// As a side effect, mutates the passed FileDescriptorSet to register types parsed from the AST. 
+func (g *goSyncer) AstToNamespace(pf *ast.File, fds *descriptorpb.FileDescriptorSet) (*featurev1beta1.Namespace, error) { // TODO: instead of panicking everywhere, collect errors (maybe using go/analysis somehow) // so we can report them properly (and not look sketchy) - namespace := Namespace{} + namespace := &featurev1beta1.Namespace{} // First pass to get general metadata and register all types ast.Inspect(pf, func(n ast.Node) bool { switch x := n.(type) { @@ -209,24 +201,9 @@ func (g *goSyncer) AstToNamespace(ctx context.Context, pf *ast.File, fset *token panic("packages for lekko must start with 'lekko'") } namespace.Name = x.Name.Name[5:] - g.Namespace = namespace.Name if len(namespace.Name) == 0 { panic("namespace name cannot be empty") } - // Analyze imports to create mapping of proto packages - // Assumes proto packages are under /proto - // and that proto package follows folder structure (e.g. default/config/v1beta1 <-> default.config.v1beta1) - protoDir := filepath.Join(g.moduleRoot, g.lekkoPath, "proto") - for _, is := range x.Imports { - if strings.Contains(is.Path.Value, protoDir) { - if is.Name == nil { - panic("protobuf imports must explicitly specify package aliases") - } - relProtoDir := try.To1(filepath.Rel(protoDir, strings.Trim(is.Path.Value, "\"'"))) - protoPackage := strings.ReplaceAll(relProtoDir, "/", ".") - g.protoPackages[is.Name.Name] = protoPackage - } - } return true case *ast.GenDecl: // TODO: try to handle doc comments using x.Doc and protoreflect.SourceLocation @@ -236,20 +213,16 @@ func (g *goSyncer) AstToNamespace(ctx context.Context, pf *ast.File, fset *token } typeSpec, ok := spec.(*ast.TypeSpec) if !ok { - // TODO: try refactoring so that we can give accurate positions for all errors easily - p := fset.Position(x.Pos()) - panic(fmt.Sprintf("error at %d:%d: only type declarations are supported", p.Line, p.Column)) + panic(g.posErr(x, "only type declarations are supported")) } structType, ok := typeSpec.Type.(*ast.StructType) if !ok { - p := fset.Position(typeSpec.Pos()) - panic(fmt.Sprintf("error at %d:%d: only struct type declarations are supported", p.Line, p.Column)) + panic(g.posErr(typeSpec, "only struct type declarations are supported")) } d := g.structToDescriptor(typeSpec.Name.Name, structType) - err := g.registerMessage(d, namespace.Name) + err := registerMessage(fds, d, namespace.Name) if err != nil { - p := fset.Position(typeSpec.Pos()) - panic(fmt.Sprintf("error at %d:%d: failed to register type for struct", p.Line, p.Column)) + panic(g.posErr(typeSpec, "failed to register type for struct")) } } return true @@ -257,11 +230,10 @@ func (g *goSyncer) AstToNamespace(ctx context.Context, pf *ast.File, fset *token return false } }) - // At this point, we should have processed all types - cache - if tr, err := g.GetTypeRegistry(); err != nil { + // At this point, we should have processed all types + tr, err := protoutils.FileDescriptorSetToTypeRegistry(fds) + if err != nil { return nil, errors.Wrap(err, "pre-process type registry") - } else { - g.typeRegistry = tr } // Second pass to handle all functions ast.Inspect(pf, func(n ast.Node) bool { @@ -292,7 +264,7 @@ func (g *goSyncer) AstToNamespace(ctx context.Context, pf *ast.File, fset *token assert.INotNil(param.Type, "must have a parameter type") typeIdent, ok := param.Type.(*ast.Ident) if !ok { - panic("parameter type must be an identifier") + panic(g.posErr(param, errors.New("parameter type must be an identifier"))) } contextKeys[param.Names[0].Name] = typeIdent.Name } @@ -300,10 +272,10 @@ 
func (g *goSyncer) AstToNamespace(ctx context.Context, pf *ast.File, fset *token results := x.Type.Results.List if results == nil { - panic("must have a return type") + panic(g.posErr(x, "must have a return type")) } if len(results) != 1 { - panic("must have one return type") + panic(g.posErr(x, "must have exactly one return type")) } switch t := results[0].Type.(type) { @@ -319,291 +291,55 @@ func (g *goSyncer) AstToNamespace(ctx context.Context, pf *ast.File, fset *token feature.Type = featurev1beta1.FeatureType_FEATURE_TYPE_STRING default: // TODO - check if it is one of our structs to allow non * - panic(fmt.Errorf("unsupported primitive return type %s", t.Name)) + panic(g.posErr(t, fmt.Sprintf("unsupported primitive return type %s", t.Name))) } case *ast.StarExpr: feature.Type = featurev1beta1.FeatureType_FEATURE_TYPE_PROTO default: - panic(fmt.Errorf("unsupported return type expression %+v", t)) + panic(g.posErr(t, fmt.Errorf("unsupported return type expression %+v", t))) } for _, stmt := range x.Body.List { switch n := stmt.(type) { case *ast.ReturnStmt: if feature.Tree.Default != nil { - panic("unexpected default value already processed") + panic(g.posErr(n, "unexpected default value already processed")) } // TODO also need to take care of the possibility that the default is in an else - feature.Tree.DefaultNew = g.exprToAny(n.Results[0], feature.Type) // can this be multiple things? + feature.Tree.DefaultNew = g.exprToAny(n.Results[0], feature.Type, namespace.Name, tr) // can this be multiple things? case *ast.IfStmt: - feature.Tree.Constraints = append(feature.Tree.Constraints, g.ifToConstraints(n, feature.Type, contextKeys)...) + feature.Tree.Constraints = append(feature.Tree.Constraints, g.ifToConstraints(n, feature.Type, contextKeys, namespace.Name, tr)...) default: - panic("only if and return statements allowed in function body") + panic(g.posErr(n, "only if and return statements allowed in function body")) } } return false } - panic(fmt.Sprintf("sync %s: only functions like 'getConfig' are supported", x.Name.Name)) + panic(g.posErr(x.Name, "only function names like 'getConfig' are supported")) } return true }) // TODO static context - return &namespace, nil -} - -func (g *goSyncer) SourceToNamespace(ctx context.Context, src []byte) (*Namespace, error) { - if bytes.Contains(src, []byte("<<<<<<<")) { - return nil, fmt.Errorf("%s has unresolved merge conflicts", g.filePath) - } - fset := token.NewFileSet() - fset.AddFile(g.filePath, fset.Base(), len(src)) - pf, err := parser.ParseFile(fset, g.filePath, src, parser.ParseComments) - if err != nil { - return nil, err - } - - return g.AstToNamespace(ctx, pf, fset) -} - -// Translates Go code to Protobuf/Starlark and writes changes to local config repository -type goSyncer struct { - moduleRoot string // e.g. github.com/lekkodev/cli - lekkoPath string - filePath string // Path to Go source file to sync - - FDS *descriptorpb.FileDescriptorSet - typeRegistry *protoregistry.Types - protoPackages map[string]string // Map of local package names to protobuf packages (e.g. configv1beta1 -> default.config.v1beta1) - Namespace string -} - -func NewGoSyncer(moduleRoot, filePath string) (*goSyncer, error) { - // Validate filePath ends with /.go - namespace := filepath.Dir(filePath) - if filepath.Base(filepath.Dir(filePath)) != strings.TrimSuffix(filepath.Base(filePath), ".go") { - return nil, fmt.Errorf("files to be synced by Lekko must have same name as parent directory (e.g. 
internal/lekko/default/default.go): %s", filePath) - } - // Validate namespace regex - if !regexp.MustCompile("[a-z]+").MatchString(namespace) { - return nil, fmt.Errorf("files to be synced by Lekko must have lowercase alphabetic names: %s", filePath) - } - - return &goSyncer{ - moduleRoot: moduleRoot, - // Assumes target file is at // - lekkoPath: filepath.Clean(filepath.Dir(filepath.Dir(filePath))), - filePath: filepath.Clean(filePath), - protoPackages: make(map[string]string), - FDS: NewDefaultFileDescriptorSet(), - Namespace: namespace, - }, nil -} - -func NewGoSyncerLite(moduleRoot string, filePath string) *goSyncer { - return &goSyncer{ - moduleRoot: moduleRoot, - lekkoPath: filepath.Clean(filepath.Dir(filepath.Dir(filePath))), - filePath: filepath.Clean(filePath), - protoPackages: make(map[string]string), - FDS: NewDefaultFileDescriptorSet(), - } + return namespace, nil } -// Convert source code to a namespace representation. -// If `repoPath` is passed, also propagates changes to the local config repository at that path. -func (g *goSyncer) Sync(ctx context.Context, repoPath *string) (*Namespace, error) { - src, err := os.ReadFile(g.filePath) - if err != nil { - return nil, errors.Wrap(err, fmt.Sprintf("open %s", g.filePath)) - } - namespace, err := g.SourceToNamespace(ctx, src) - if err != nil { - return nil, err - } - - if repoPath != nil { - r, err := repo.NewLocal(*repoPath, nil) - if err != nil { - return nil, err - } - // Discard logs, mainly for silencing compilation later - // TODO: Maybe a verbose flag - r.ConfigureLogger(&repo.LoggingConfiguration{ - Writer: io.Discard, - }) - rootMD, _, err := r.ParseMetadata(ctx) - if err != nil { - return nil, err - } - nsExists := false - // Need to keep track of which configs were synced - // Any configs that were already present but not synced should be removed - toRemove := make(map[string]struct{}) // Set of config names in existing namespace - for _, nsFromMeta := range rootMD.Namespaces { - if namespace.Name == nsFromMeta { - nsExists = true - ffs, err := r.GetFeatureFiles(ctx, namespace.Name) - if err != nil { - return nil, errors.Wrap(err, "read existing configs") - } - for _, ff := range ffs { - toRemove[ff.Name] = struct{}{} - } - break - } - } - if !nsExists { - if err := r.AddNamespace(ctx, namespace.Name); err != nil { - return nil, errors.Wrap(err, "add namespace") - } - } - - typeRegistry, err := g.GetTypeRegistry() +// Translate a collection of Go files to a representation of repository contents. +// Files -> repo instead of file -> namespace because FDS is shared repo-wide. +// Takes file paths instead of contents for more helpful error reporting. 
+func (g *goSyncer) Sync(filePaths ...string) (*featurev1beta1.RepositoryContents, error) { + ret := &featurev1beta1.RepositoryContents{FileDescriptorSet: protoutils.NewDefaultFileDescriptorSet()} + for _, filePath := range filePaths { + astf, err := parser.ParseFile(g.fset, filePath, nil, parser.ParseComments|parser.AllErrors|parser.SkipObjectResolution) if err != nil { - return nil, errors.Wrap(err, "get type registry") - } - - for _, configProto := range namespace.Features { - // create a new starlark file from a template (based on the config type) - var starBytes []byte - starImports := make([]*featurev1beta1.ImportStatement, 0) - if configProto.Type == featurev1beta1.FeatureType_FEATURE_TYPE_PROTO { - typeURL := configProto.GetTree().GetDefaultNew().GetTypeUrl() - messageType, found := strings.CutPrefix(typeURL, "type.googleapis.com/") - if !found { - return nil, fmt.Errorf("can't parse type url: %s", typeURL) - } - starInputs, err := r.BuildProtoStarInputsWithTypes(ctx, messageType, feature.LatestNamespaceVersion(), typeRegistry) - if err != nil { - return nil, err - } - starBytes, err = star.RenderExistingProtoTemplate(*starInputs, feature.LatestNamespaceVersion()) - if err != nil { - return nil, err - } - for importPackage, importAlias := range starInputs.Packages { - starImports = append(starImports, &featurev1beta1.ImportStatement{ - Lhs: &featurev1beta1.IdentExpr{ - Token: importAlias, - }, - Operator: "=", - Rhs: &featurev1beta1.ImportExpr{ - Dot: &featurev1beta1.DotExpr{ - X: "proto", - Name: "package", - }, - Args: []string{importPackage}, - }, - }) - } - } else { - starBytes, err = star.GetTemplate(eval.ConfigTypeFromProto(configProto.Type), feature.LatestNamespaceVersion(), nil) - if err != nil { - return nil, err - } - } - if configProto.Tree.Default == nil { - configProto.Tree.Default = &anypb.Any{ - TypeUrl: configProto.Tree.DefaultNew.GetTypeUrl(), - Value: configProto.Tree.DefaultNew.GetValue(), - } - } - // mutate starlark with the actual config - walker := static.NewWalker("", starBytes, typeRegistry, feature.NamespaceVersionV1Beta7) - newBytes, err := walker.Mutate(&featurev1beta1.StaticFeature{ - Key: configProto.Key, - Type: configProto.GetType(), - Feature: &featurev1beta1.FeatureStruct{ - Description: configProto.GetDescription(), - }, - FeatureOld: configProto, - Imports: starImports, - }) - if err != nil { - return nil, errors.Wrap(err, "walker mutate") - } - configFile := feature.NewFeatureFile(namespace.Name, configProto.Key) - // write starlark to disk - if err := r.WriteFile(path.Join(namespace.Name, configFile.StarlarkFileName), newBytes, 0600); err != nil { - return nil, errors.Wrap(err, "write after mutation") - } - delete(toRemove, configProto.Key) + return nil, errors.Wrapf(err, "parse %s", filePath) } - // Remove leftovers - for configName := range toRemove { - if err := r.RemoveFeature(ctx, namespace.Name, configName); err != nil { - return nil, errors.Wrapf(err, "remove %s", configName) - } - } - // Write types to files & rebuild in-repo type registry - if err := g.writeTypesToRepo(ctx, r); err != nil { - return nil, errors.Wrap(err, "write type files") - } - repoReg, err := r.ReBuildDynamicTypeRegistry(ctx, rootMD.ProtoDirectory, false) + ns, err := g.AstToNamespace(astf, ret.FileDescriptorSet) if err != nil { - return nil, errors.Wrap(err, "final rebuild type registry") + return nil, errors.Wrapf(err, "translate %s", filePath) } - repoReg.RangeMessages(func(mt protoreflect.MessageType) bool { - _ = g.typeRegistry.RegisterMessage(mt) - return true 
- }) - - // Final compile to verify healthy sync - if _, err := r.Compile(ctx, &repo.CompileRequest{ - Registry: g.typeRegistry, - IgnoreBackwardsCompatibility: true, - }); err != nil { - return nil, errors.Wrap(err, "final compile") - } - } - - return namespace, nil -} - -// Gets the type registry of the syncer, converted from the internal fds. -func (g *goSyncer) GetTypeRegistry() (*protoregistry.Types, error) { - fr, err := protodesc.NewFiles(g.FDS) - if err != nil { - return nil, errors.Wrap(err, "convert to file registry") - } - tr, err := FileRegistryToTypeRegistry(fr) - if err != nil { - return nil, errors.Wrap(err, "get type registry from file registry") + ret.Namespaces = append(ret.Namespaces, ns) } - return tr, nil -} -func (g *goSyncer) writeTypesToRepo(ctx context.Context, r repo.ConfigurationRepository) error { - rootMD, _, err := r.ParseMetadata(ctx) - if err != nil { - return errors.Wrap(err, "parse repository metadata") - } - fr, err := protodesc.NewFiles(g.FDS) - if err != nil { - return errors.Wrap(err, "convert to file registry") - } - var writeErr error - fr.RangeFiles(func(fd protoreflect.FileDescriptor) bool { - // Ignore well-known types since they shouldn't be written as files - if strings.HasPrefix(string(fd.FullName()), "google.protobuf") { - return true - } - // Ignore our types since they shouldn't be written as files - if strings.HasPrefix(string(fd.FullName()), "lekko.") { - return true - } - contents, err := FileDescriptorToProtoString(fd) - if err != nil { - writeErr = errors.Wrapf(err, "stringify file descriptor %s", fd.FullName()) - return false - } - path := filepath.Join(rootMD.ProtoDirectory, fd.Path()) - if err := r.WriteFile(path, []byte(contents), 0600); err != nil { - writeErr = errors.Wrapf(err, "write to %s", path) - return false - } - return true - }) - return writeErr + return ret, nil } // TODO - is this only used for context keys, or other things? @@ -620,7 +356,7 @@ func (g *goSyncer) exprToValue(expr ast.Expr) string { } // TODO -- We know the return type.. -func (g *goSyncer) exprToAny(expr ast.Expr, want featurev1beta1.FeatureType) *featurev1beta1.Any { +func (g *goSyncer) exprToAny(expr ast.Expr, want featurev1beta1.FeatureType, namespace string, typeRegistry *protoregistry.Types) *featurev1beta1.Any { switch node := expr.(type) { case *ast.UnaryExpr: switch node.Op { @@ -628,12 +364,12 @@ func (g *goSyncer) exprToAny(expr ast.Expr, want featurev1beta1.FeatureType) *fe switch x := node.X.(type) { case *ast.CompositeLit: // TODO - this is the one place we set the values for return types - message, overrides := g.compositeLitToProto(x) + message, overrides := g.compositeLitToProto(x, namespace, typeRegistry) protoMsg, ok := message.(protoreflect.ProtoMessage) if !ok { panic("This should never happen") } - value, err := proto.Marshal(protoMsg) + value, err := proto.MarshalOptions{Deterministic: true}.Marshal(protoMsg) if err != nil { panic(err) } @@ -654,11 +390,11 @@ func (g *goSyncer) exprToAny(expr ast.Expr, want featurev1beta1.FeatureType) *fe // TODO -- do the other stuff... configKey := strcase.ToKebab(fun.Name[3:]) protoMsg := &featurev1beta1.ConfigCall{ - Namespace: g.Namespace, + Namespace: namespace, Key: configKey, // TODO do we know the location of this? 
} - value, err := proto.Marshal(protoMsg) + value, err := proto.MarshalOptions{Deterministic: true}.Marshal(protoMsg) if err != nil { panic(err) } @@ -671,10 +407,10 @@ func (g *goSyncer) exprToAny(expr ast.Expr, want featurev1beta1.FeatureType) *fe } default: // TODO - value := g.primitiveToProtoValue(expr) + value := g.primitiveToProtoValue(expr, namespace) switch typedValue := value.(type) { case string: - value, err := proto.MarshalOptions{}.Marshal(&wrapperspb.StringValue{Value: typedValue}) + value, err := proto.MarshalOptions{Deterministic: true}.Marshal(&wrapperspb.StringValue{Value: typedValue}) if err != nil { panic(err) } @@ -686,7 +422,7 @@ func (g *goSyncer) exprToAny(expr ast.Expr, want featurev1beta1.FeatureType) *fe // A value parsed as an integer might actually be for a float config switch want { case featurev1beta1.FeatureType_FEATURE_TYPE_INT: - value, err := proto.MarshalOptions{}.Marshal(&wrapperspb.Int64Value{Value: typedValue}) + value, err := proto.MarshalOptions{Deterministic: true}.Marshal(&wrapperspb.Int64Value{Value: typedValue}) if err != nil { panic(err) } @@ -696,7 +432,7 @@ func (g *goSyncer) exprToAny(expr ast.Expr, want featurev1beta1.FeatureType) *fe } case featurev1beta1.FeatureType_FEATURE_TYPE_FLOAT: // TODO: handle precision boundaries properly - value, err := proto.MarshalOptions{}.Marshal(&wrapperspb.DoubleValue{Value: float64(typedValue)}) + value, err := proto.MarshalOptions{Deterministic: true}.Marshal(&wrapperspb.DoubleValue{Value: float64(typedValue)}) if err != nil { panic(err) } @@ -708,7 +444,7 @@ func (g *goSyncer) exprToAny(expr ast.Expr, want featurev1beta1.FeatureType) *fe panic(fmt.Errorf("unexpected primitive %v for target return type %v", typedValue, want)) } case float64: - value, err := proto.MarshalOptions{}.Marshal(&wrapperspb.DoubleValue{Value: typedValue}) + value, err := proto.MarshalOptions{Deterministic: true}.Marshal(&wrapperspb.DoubleValue{Value: typedValue}) if err != nil { panic(err) } @@ -717,7 +453,7 @@ func (g *goSyncer) exprToAny(expr ast.Expr, want featurev1beta1.FeatureType) *fe Value: value, } case bool: - value, err := proto.MarshalOptions{}.Marshal(&wrapperspb.BoolValue{Value: typedValue}) + value, err := proto.MarshalOptions{Deterministic: true}.Marshal(&wrapperspb.BoolValue{Value: typedValue}) if err != nil { panic(err) } @@ -731,54 +467,21 @@ func (g *goSyncer) exprToAny(expr ast.Expr, want featurev1beta1.FeatureType) *fe } } -// e.g. configv1beta1.Message -> [configv1beta1, Message] -func exprToNameParts(expr ast.Expr) []string { - switch node := expr.(type) { - case *ast.Ident: - return []string{node.Name} - case *ast.SelectorExpr: - return append(exprToNameParts(node.X), exprToNameParts(node.Sel)...) - default: - panic(fmt.Errorf("invalid expression for name %+v", node)) - } -} - -func (g *goSyncer) compositeLitToMessageType(x *ast.CompositeLit) protoreflect.MessageType { - var protoPackage string +// TODO: Handling for duration and nested types in general are really complex, buggy and not well tested. +// We should probably start with spec'ing out the type constructs we're willing to support in all native languages. 
+func (g *goSyncer) compositeLitToMessageType(x *ast.CompositeLit, namespace string, typeRegistry *protoregistry.Types) protoreflect.MessageType { var fullName protoreflect.FullName innerExpr, ok := x.Type.(*ast.SelectorExpr) if ok { innerIdent, ok := innerExpr.X.(*ast.Ident) if ok && innerIdent.Name == "durationpb" { - mt, err := g.typeRegistry.FindMessageByName(protoreflect.FullName("google.protobuf").Append(protoreflect.Name(innerExpr.Sel.Name))) + mt, err := typeRegistry.FindMessageByName(protoreflect.FullName("google.protobuf").Append(protoreflect.Name(innerExpr.Sel.Name))) if err == nil { return mt } panic(err) } - parts := exprToNameParts(x.Type) - assert.Equal(len(parts), 2, fmt.Sprintf("expected message name to be 2 parts: %v", parts)) - protoPackage, ok = g.protoPackages[parts[0]] - assert.Equal(ok, true, fmt.Sprintf("unknown package %v", parts[0])) - fullName = protoreflect.FullName(protoPackage).Append(protoreflect.Name(parts[1])) - mt, err := g.typeRegistry.FindMessageByName(fullName) - if errors.Is(err, protoregistry.NotFound) { - // Check if nested type (e.g. Outer_Inner) (only works 2 levels for now) - if strings.Contains(parts[1], "_") { - names := strings.Split(parts[1], "_") - assert.Equal(len(names), 2, fmt.Sprintf("only singly nested messages are supported: %v", parts[1])) - if outerDescriptor, err := g.typeRegistry.FindMessageByName(protoreflect.FullName(protoPackage).Append(protoreflect.Name(names[0]))); err == nil { - if innerDescriptor := outerDescriptor.Descriptor().Messages().ByName(protoreflect.Name(names[1])); innerDescriptor != nil { - return dynamicpb.NewMessageType(innerDescriptor) - } - } - } - panic(fmt.Errorf("missing proto type in registry %s", fullName)) - } else if err != nil { - panic(errors.Wrap(err, "error while finding message type registry")) - } else { - return mt - } + panic(errors.New("unsupported selector expression for composite literal type")) } else { // it should be an ident for a bare raw struct ident, ok := x.Type.(*ast.Ident) @@ -786,18 +489,17 @@ func (g *goSyncer) compositeLitToMessageType(x *ast.CompositeLit) protoreflect.M panic("Unknown syntax") } // TODO - fix this - this is gross af - namespace := g.Namespace fullName = protoreflect.FullName(fmt.Sprintf("%s.config.v1beta1", namespace)).Append(protoreflect.Name(ident.Name)) - mt, err := g.typeRegistry.FindMessageByName(fullName) + mt, err := typeRegistry.FindMessageByName(fullName) if err != nil { - panic(errors.Wrap(err, "error while finding message type registry")) + panic(errors.Wrapf(err, "find %s in type registry", fullName)) } else { return mt } } } -func (g *goSyncer) primitiveToProtoValue(expr ast.Expr) any { +func (g *goSyncer) primitiveToProtoValue(expr ast.Expr, namespace string) any { switch x := expr.(type) { case *ast.BasicLit: switch x.Kind { @@ -842,7 +544,7 @@ func (g *goSyncer) primitiveToProtoValue(expr ast.Expr) any { // TODO -- do the other stuff... configKey := strcase.ToKebab(fun.Name[3:]) return &featurev1beta1.ConfigCall{ - Namespace: g.Namespace, + Namespace: namespace, Key: configKey, // TODO do we know the location of this? 
} @@ -854,9 +556,9 @@ func (g *goSyncer) primitiveToProtoValue(expr ast.Expr) any { } } -func (g *goSyncer) compositeLitToProto(x *ast.CompositeLit) (protoreflect.Message, []*featurev1beta1.ValueOveride) { +func (g *goSyncer) compositeLitToProto(x *ast.CompositeLit, namespace string, typeRegistry *protoregistry.Types) (protoreflect.Message, []*featurev1beta1.ValueOveride) { var overrides []*featurev1beta1.ValueOveride - mt := g.compositeLitToMessageType(x) + mt := g.compositeLitToMessageType(x, namespace, typeRegistry) msg := mt.New() for _, v := range x.Elts { kv, ok := v.(*ast.KeyValueExpr) @@ -874,7 +576,7 @@ func (g *goSyncer) compositeLitToProto(x *ast.CompositeLit) (protoreflect.Messag case token.AND: switch ix := node.X.(type) { case *ast.CompositeLit: - innerMessage, calls := g.compositeLitToProto(ix) + innerMessage, calls := g.compositeLitToProto(ix, namespace, typeRegistry) if len(calls) > 0 { fmt.Printf("%+v\n", calls) } @@ -893,7 +595,7 @@ func (g *goSyncer) compositeLitToProto(x *ast.CompositeLit) (protoreflect.Messag case *ast.Ident: // Primitive type array for _, elt := range node.Elts { - eltVal := g.primitiveToProtoValue(elt) + eltVal := g.primitiveToProtoValue(elt, namespace) lVal.Append(protoreflect.ValueOf(eltVal)) } case *ast.StarExpr: @@ -926,7 +628,7 @@ func (g *goSyncer) compositeLitToProto(x *ast.CompositeLit) (protoreflect.Messag default: panic(fmt.Errorf("unsupported slice element type %+v", elt)) } - innerMessage, innerOverrides := g.compositeLitToProto(cl) + innerMessage, innerOverrides := g.compositeLitToProto(cl, namespace, typeRegistry) overrides = append(overrides, innerOverrides...) lVal.Append(protoreflect.ValueOf(innerMessage)) } @@ -955,7 +657,7 @@ func (g *goSyncer) compositeLitToProto(x *ast.CompositeLit) (protoreflect.Messag assert.Equal(ok, true, "expected basic literal for map key") key := protoreflect.ValueOfString(strings.Trim(basicLit.Value, "\"")).MapKey() // For now, assume all map values are primitives - value := g.primitiveToProtoValue(pair.Value) + value := g.primitiveToProtoValue(pair.Value, namespace) msg.Mutable(field).Map().Set(key, protoreflect.ValueOf(value)) } default: @@ -963,7 +665,7 @@ func (g *goSyncer) compositeLitToProto(x *ast.CompositeLit) (protoreflect.Messag } default: // Value is not a composite literal - try handling as a primitive - value := g.primitiveToProtoValue(node) + value := g.primitiveToProtoValue(node, namespace) if field.Kind() == protoreflect.EnumKind { // Special handling for enums intValue, ok := value.(int64) @@ -990,14 +692,14 @@ func (g *goSyncer) compositeLitToProto(x *ast.CompositeLit) (protoreflect.Messag return msg, overrides } -func (g *goSyncer) exprToComparisonValue(expr ast.Expr) *structpb.Value { +func (g *goSyncer) exprToComparisonValue(expr ast.Expr, namespace string) *structpb.Value { switch node := expr.(type) { case *ast.CompositeLit: _, ok := node.Type.(*ast.ArrayType) assert.Equal(ok, true, "only slices are allowed for composite literals in comparisons") var list []*structpb.Value for _, elt := range node.Elts { - list = append(list, g.exprToComparisonValue(elt)) + list = append(list, g.exprToComparisonValue(elt, namespace)) } return &structpb.Value{ Kind: &structpb.Value_ListValue{ @@ -1008,7 +710,7 @@ func (g *goSyncer) exprToComparisonValue(expr ast.Expr) *structpb.Value { } default: // If not composite lit, must(/should) be primitive - value := g.primitiveToProtoValue(expr) + value := g.primitiveToProtoValue(expr, namespace) ret := &structpb.Value{} switch typedValue := value.(type) { 
         case string:
@@ -1034,18 +736,18 @@ func (g *goSyncer) exprToComparisonValue(expr ast.Expr) *structpb.Value {
     }
 }

-func (g *goSyncer) binaryExprToRule(expr *ast.BinaryExpr, contextKeys map[string]string) *rulesv1beta3.Rule {
+func (g *goSyncer) binaryExprToRule(expr *ast.BinaryExpr, contextKeys map[string]string, namespace string) *rulesv1beta3.Rule {
     switch expr.Op {
     case token.LAND:
         var rules []*rulesv1beta3.Rule
-        left := g.exprToRule(expr.X, contextKeys)
+        left := g.exprToRule(expr.X, contextKeys, namespace)
         l, ok := left.Rule.(*rulesv1beta3.Rule_LogicalExpression)
         if ok && l.LogicalExpression.LogicalOperator == rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_AND {
             rules = append(rules, l.LogicalExpression.Rules...)
         } else {
             rules = append(rules, left)
         }
-        right := g.exprToRule(expr.Y, contextKeys)
+        right := g.exprToRule(expr.Y, contextKeys, namespace)
         r, ok := right.Rule.(*rulesv1beta3.Rule_LogicalExpression)
         if ok && r.LogicalExpression.LogicalOperator == rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_AND {
             rules = append(rules, r.LogicalExpression.Rules...)
@@ -1055,14 +757,14 @@ func (g *goSyncer) binaryExprToRule(expr *ast.BinaryExpr, contextKeys map[string
         return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_LogicalExpression{LogicalExpression: &rulesv1beta3.LogicalExpression{LogicalOperator: rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_AND, Rules: rules}}}
     case token.LOR:
         var rules []*rulesv1beta3.Rule
-        left := g.exprToRule(expr.X, contextKeys)
+        left := g.exprToRule(expr.X, contextKeys, namespace)
         l, ok := left.Rule.(*rulesv1beta3.Rule_LogicalExpression)
         if ok && l.LogicalExpression.LogicalOperator == rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_OR {
             rules = append(rules, l.LogicalExpression.Rules...)
         } else {
             rules = append(rules, left)
         }
-        right := g.exprToRule(expr.Y, contextKeys)
+        right := g.exprToRule(expr.Y, contextKeys, namespace)
         r, ok := right.Rule.(*rulesv1beta3.Rule_LogicalExpression)
         if ok && r.LogicalExpression.LogicalOperator == rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_OR {
             rules = append(rules, r.LogicalExpression.Rules...)
@@ -1071,23 +773,23 @@ func (g *goSyncer) binaryExprToRule(expr *ast.BinaryExpr, contextKeys map[string
         }
         return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_LogicalExpression{LogicalExpression: &rulesv1beta3.LogicalExpression{LogicalOperator: rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_OR, Rules: rules}}}
     case token.EQL:
-        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y)}}}
+        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y, namespace)}}}
     case token.LSS:
-        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y)}}}
+        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y, namespace)}}}
     case token.GTR:
-        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y)}}}
+        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y, namespace)}}}
     case token.NEQ:
-        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_NOT_EQUALS, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y)}}}
+        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_NOT_EQUALS, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y, namespace)}}}
     case token.LEQ:
-        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN_OR_EQUALS, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y)}}}
+        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN_OR_EQUALS, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y, namespace)}}}
     case token.GEQ:
-        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN_OR_EQUALS, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y)}}}
+        return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN_OR_EQUALS, ContextKey: g.exprToValue(expr.X), ComparisonValue: g.exprToComparisonValue(expr.Y, namespace)}}}
     default:
         panic(fmt.Errorf("unexpected token in binary expression %v", expr.Op))
     }
 }

-func (g *goSyncer) callExprToRule(expr *ast.CallExpr) *rulesv1beta3.Rule {
+func (g *goSyncer) callExprToRule(expr *ast.CallExpr, namespace string) *rulesv1beta3.Rule {
     // TODO check Fun
     selectorExpr, ok := expr.Fun.(*ast.SelectorExpr)
     assert.Equal(ok, true)
@@ -1097,18 +799,18 @@ func (g *goSyncer) callExprToRule(expr *ast.CallExpr) *rulesv1beta3.Rule {
     case "slices":
         switch selectorExpr.Sel.Name {
         case "Contains":
-            return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN, ContextKey: g.exprToValue(expr.Args[1]), ComparisonValue: g.exprToComparisonValue(expr.Args[0])}}}
+            return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN, ContextKey: g.exprToValue(expr.Args[1]), ComparisonValue: g.exprToComparisonValue(expr.Args[0], namespace)}}}
         default:
             panic(fmt.Errorf("unsupported slices operator %s", selectorExpr.Sel.Name))
         }
     case "strings":
         switch selectorExpr.Sel.Name {
         case "Contains":
-            return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINS, ContextKey: g.exprToValue(expr.Args[0]), ComparisonValue: g.exprToComparisonValue(expr.Args[1])}}}
+            return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINS, ContextKey: g.exprToValue(expr.Args[0]), ComparisonValue: g.exprToComparisonValue(expr.Args[1], namespace)}}}
         case "HasPrefix":
-            return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_STARTS_WITH, ContextKey: g.exprToValue(expr.Args[0]), ComparisonValue: g.exprToComparisonValue(expr.Args[1])}}}
+            return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_STARTS_WITH, ContextKey: g.exprToValue(expr.Args[0]), ComparisonValue: g.exprToComparisonValue(expr.Args[1], namespace)}}}
         case "HasSuffix":
-            return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_ENDS_WITH, ContextKey: g.exprToValue(expr.Args[0]), ComparisonValue: g.exprToComparisonValue(expr.Args[1])}}}
+            return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Atom{Atom: &rulesv1beta3.Atom{ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_ENDS_WITH, ContextKey: g.exprToValue(expr.Args[0]), ComparisonValue: g.exprToComparisonValue(expr.Args[1], namespace)}}}
         default:
            panic(fmt.Errorf("unsupported strings operator %s", selectorExpr.Sel.Name))
         }
@@ -1117,10 +819,10 @@ func (g *goSyncer) callExprToRule(expr *ast.CallExpr) *rulesv1beta3.Rule {
     }
 }

-func (g *goSyncer) unaryExprToRule(expr *ast.UnaryExpr, contextKeys map[string]string) *rulesv1beta3.Rule {
+func (g *goSyncer) unaryExprToRule(expr *ast.UnaryExpr, contextKeys map[string]string, namespace string) *rulesv1beta3.Rule {
     switch expr.Op {
     case token.NOT:
-        rule := g.exprToRule(expr.X, contextKeys)
+        rule := g.exprToRule(expr.X, contextKeys, namespace)
         if atom := rule.GetAtom(); atom != nil {
             boolValue, isBool := atom.ComparisonValue.GetKind().(*structpb.Value_BoolValue)
             if isBool && atom.ComparisonOperator == rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS {
@@ -1149,18 +851,18 @@ func (g *goSyncer) identToRule(ident *ast.Ident, contextKeys map[string]string)
     panic(fmt.Errorf("not a boolean expression: %+v", ident))
 }

-func (g *goSyncer) exprToRule(expr ast.Expr, contextKeys map[string]string) *rulesv1beta3.Rule {
+func (g *goSyncer) exprToRule(expr ast.Expr, contextKeys map[string]string, namespace string) *rulesv1beta3.Rule {
     switch node := expr.(type) {
     case *ast.Ident:
         return g.identToRule(node, contextKeys)
     case *ast.BinaryExpr:
-        return g.binaryExprToRule(node, contextKeys)
+        return g.binaryExprToRule(node, contextKeys, namespace)
     case *ast.CallExpr:
-        return g.callExprToRule(node)
+        return g.callExprToRule(node, namespace)
     case *ast.ParenExpr:
-        return g.exprToRule(node.X, contextKeys)
+        return g.exprToRule(node.X, contextKeys, namespace)
     case *ast.UnaryExpr:
-        return g.unaryExprToRule(node, contextKeys)
+        return g.unaryExprToRule(node, contextKeys, namespace)
     case *ast.SelectorExpr: // TODO - make sure this is args
         return g.identToRule(node.Sel, contextKeys)
     default:
@@ -1168,17 +870,17 @@ func (g *goSyncer) exprToRule(expr ast.Expr, contextKeys map[string]string) *rul
     }
 }

-func (g *goSyncer) ifToConstraints(ifStmt *ast.IfStmt, want featurev1beta1.FeatureType, contextKeys map[string]string) []*featurev1beta1.Constraint {
+func (g *goSyncer) ifToConstraints(ifStmt *ast.IfStmt, want featurev1beta1.FeatureType, contextKeys map[string]string, namespace string, typeRegistry *protoregistry.Types) []*featurev1beta1.Constraint {
     constraint := &featurev1beta1.Constraint{}
-    constraint.RuleAstNew = g.exprToRule(ifStmt.Cond, contextKeys)
+    constraint.RuleAstNew = g.exprToRule(ifStmt.Cond, contextKeys, namespace)
     assert.Equal(len(ifStmt.Body.List), 1, "if statements can only contain one return statement")
     returnStmt, ok := ifStmt.Body.List[0].(*ast.ReturnStmt) // TODO
     assert.Equal(ok, true, "if statements can only contain return statements")
-    constraint.ValueNew = g.exprToAny(returnStmt.Results[0], want) // TODO
-    if ifStmt.Else != nil { // TODO bare else?
+    constraint.ValueNew = g.exprToAny(returnStmt.Results[0], want, namespace, typeRegistry) // TODO
+    if ifStmt.Else != nil { // TODO bare else?
         elseIfStmt, ok := ifStmt.Else.(*ast.IfStmt)
         assert.Equal(ok, true, "bare else statements are not supported, must be else if")
-        return append([]*featurev1beta1.Constraint{constraint}, g.ifToConstraints(elseIfStmt, want, contextKeys)...)
+        return append([]*featurev1beta1.Constraint{constraint}, g.ifToConstraints(elseIfStmt, want, contextKeys, namespace, typeRegistry)...)
     }
     return []*featurev1beta1.Constraint{constraint}
 }
@@ -1326,6 +1028,21 @@ func (g *goSyncer) structToDescriptor(structName string, structType *ast.StructT
     return descriptor
 }

+// Wrap an error related to an AST node with positional information
+func (g *goSyncer) posErr(node ast.Node, err any) error {
+    var inner error
+    switch e := err.(type) {
+    case string:
+        inner = errors.New(e)
+    case error:
+        inner = e
+    default:
+        panic("invalid inner error type")
+    }
+    p := g.fset.Position(node.Pos())
+    return errors.Wrapf(inner, "error at %s:%d:%d", p.Filename, p.Line, p.Column)
+}
+
 func StructToMap(structType *ast.StructType) map[string]string {
     ret := make(map[string]string)
     for _, field := range structType.Fields.List {
diff --git a/pkg/sync/repo.go b/pkg/sync/repo.go
new file mode 100644
index 00000000..25b65736
--- /dev/null
+++ b/pkg/sync/repo.go
@@ -0,0 +1,204 @@
+// Copyright 2022 Lekko Technologies, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sync
+
+import (
+    "context"
+    "io"
+    "path/filepath"
+    "strings"
+
+    featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1"
+    "github.com/lekkodev/cli/pkg/feature"
+    "github.com/lekkodev/cli/pkg/proto"
+    "github.com/lekkodev/cli/pkg/repo"
+    "github.com/lekkodev/cli/pkg/star"
+    "github.com/lekkodev/cli/pkg/star/static"
+    "github.com/lekkodev/go-sdk/pkg/eval"
+    "github.com/pkg/errors"
+    "google.golang.org/protobuf/reflect/protodesc"
+    "google.golang.org/protobuf/reflect/protoreflect"
+    "google.golang.org/protobuf/types/descriptorpb"
+    "google.golang.org/protobuf/types/known/anypb"
+)
+
+func WriteContentsToLocalRepo(ctx context.Context, contents *featurev1beta1.RepositoryContents, repoPath string) error {
+    // NOTE: For now, this function still needs a proper Lekko repository as a prereq,
+    // because it's uncertain if we'll ever need functionality to create a local repository
+    // from scratch from native lang code for whatever reason
+    r, err := repo.NewLocal(repoPath, nil)
+    if err != nil {
+        return errors.Wrap(err, "prepare repo")
+    }
+    // Discard logs, mainly for silencing compilation later
+    // TODO: Maybe a verbose flag
+    r.ConfigureLogger(&repo.LoggingConfiguration{
+        Writer: io.Discard,
+    })
+    rootMD, _, err := r.ParseMetadata(ctx)
+    if err != nil {
+        return err
+    }
+    typeRegistry, err := proto.FileDescriptorSetToTypeRegistry(contents.FileDescriptorSet)
+    if err != nil {
+        return errors.Wrap(err, "type registry from contents fds")
+    }
+    for _, ns := range contents.Namespaces {
+        nsExists := false
+        // Any configs that were already present but not in incoming contents should be removed
+        toRemove := make(map[string]struct{}) // Set of config names in existing namespace
+        for _, nsFromMeta := range rootMD.Namespaces {
+            if ns.Name == nsFromMeta {
+                nsExists = true
+                ffs, err := r.GetFeatureFiles(ctx, ns.Name)
+                if err != nil {
+                    return errors.Wrapf(err, "read existing lekkos in namespace %s", ns.Name)
+                }
+                for _, ff := range ffs {
+                    toRemove[ff.Name] = struct{}{}
+                }
+                break
+            }
+        }
+        if !nsExists {
+            if err := r.AddNamespace(ctx, ns.Name); err != nil {
+                return errors.Wrapf(err, "add namespace %s", ns.Name)
+            }
+        }
+
+        for _, f := range ns.Features {
+            // create a new starlark file from a template (based on the config type)
+            var starBytes []byte
+            starImports := make([]*featurev1beta1.ImportStatement, 0)
+            if f.Type == featurev1beta1.FeatureType_FEATURE_TYPE_PROTO {
+                typeURL := f.GetTree().GetDefaultNew().GetTypeUrl()
+                messageType, found := strings.CutPrefix(typeURL, "type.googleapis.com/")
+                if !found {
+                    return errors.Errorf("can't parse type url: %s", typeURL)
+                }
+                starInputs, err := r.BuildProtoStarInputsWithTypes(ctx, messageType, feature.LatestNamespaceVersion(), typeRegistry)
+                if err != nil {
+                    return err
+                }
+                starBytes, err = star.RenderExistingProtoTemplate(*starInputs, feature.LatestNamespaceVersion())
+                if err != nil {
+                    return err
+                }
+                for importPackage, importAlias := range starInputs.Packages {
+                    starImports = append(starImports, &featurev1beta1.ImportStatement{
+                        Lhs: &featurev1beta1.IdentExpr{
+                            Token: importAlias,
+                        },
+                        Operator: "=",
+                        Rhs: &featurev1beta1.ImportExpr{
+                            Dot: &featurev1beta1.DotExpr{
+                                X: "proto",
+                                Name: "package",
+                            },
+                            Args: []string{importPackage},
+                        },
+                    })
+                }
+            } else {
+                starBytes, err = star.GetTemplate(eval.ConfigTypeFromProto(f.Type), feature.LatestNamespaceVersion(), nil)
+                if err != nil {
+                    return err
+                }
+            }
+            if f.Tree.Default == nil {
+                f.Tree.Default = &anypb.Any{
+                    TypeUrl: f.Tree.DefaultNew.GetTypeUrl(),
+                    Value: f.Tree.DefaultNew.GetValue(),
+                }
+            }
+            // mutate starlark with the actual config
+            walker := static.NewWalker("", starBytes, typeRegistry, feature.NamespaceVersionV1Beta7)
+            newBytes, err := walker.Mutate(&featurev1beta1.StaticFeature{
+                Key: f.Key,
+                Type: f.GetType(),
+                Feature: &featurev1beta1.FeatureStruct{
+                    Description: f.GetDescription(),
+                },
+                FeatureOld: f,
+                Imports: starImports,
+            })
+            if err != nil {
+                return errors.Wrap(err, "walker mutate")
+            }
+            configFile := feature.NewFeatureFile(ns.Name, f.Key)
+            // write starlark to disk
+            if err := r.WriteFile(filepath.Join(ns.Name, configFile.StarlarkFileName), newBytes, 0600); err != nil {
+                return errors.Wrap(err, "write after mutation")
+            }
+            delete(toRemove, f.Key)
+        }
+        // Remove leftovers
+        for configName := range toRemove {
+            if err := r.RemoveFeature(ctx, ns.Name, configName); err != nil {
+                return errors.Wrapf(err, "remove %s/%s", ns.Name, configName)
+            }
+        }
+    }
+    // Write types to files & rebuild in-repo type registry
+    if err := WriteTypesToRepo(ctx, contents.FileDescriptorSet, r); err != nil {
+        return errors.Wrap(err, "write type files")
+    }
+    if _, err := r.ReBuildDynamicTypeRegistry(ctx, rootMD.ProtoDirectory, false); err != nil {
+        return errors.Wrap(err, "final rebuild type registry")
+    }
+
+    // Final compile to verify overall health
+    if _, err := r.Compile(ctx, &repo.CompileRequest{
+        IgnoreBackwardsCompatibility: true,
+    }); err != nil {
+        return errors.Wrap(err, "final compile")
+    }
+
+    return nil
+}
+
+func WriteTypesToRepo(ctx context.Context, fds *descriptorpb.FileDescriptorSet, r repo.ConfigurationRepository) error {
+    rootMD, _, err := r.ParseMetadata(ctx)
+    if err != nil {
+        return errors.Wrap(err, "parse repository metadata")
+    }
+    fr, err := protodesc.NewFiles(fds)
+    if err != nil {
+        return errors.Wrap(err, "convert to file registry")
+    }
+    var writeErr error
+    fr.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
+        // Ignore well-known types since they shouldn't be written as files
+        if strings.HasPrefix(string(fd.FullName()), "google.protobuf") {
+            return true
+        }
+        // Ignore our types since they shouldn't be written as files
+        if strings.HasPrefix(string(fd.FullName()), "lekko.") {
+            return true
+        }
+        contents, err := proto.PrintFileDescriptor(fd)
+        if err != nil {
+            writeErr = errors.Wrapf(err, "stringify file descriptor %s", fd.FullName())
+            return false
+        }
+        path := filepath.Join(rootMD.ProtoDirectory, fd.Path())
+        if err := r.WriteFile(path, []byte(contents), 0600); err != nil {
+            writeErr = errors.Wrapf(err, "write to %s", path)
+            return false
+        }
+        return true
+    })
+    return writeErr
+}
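Reviewer note (not part of the patch): a minimal sketch of how the new pkg/sync API above is meant to be driven. The package name and applySync helper are hypothetical, and contents is assumed to come from a syncer with a populated FileDescriptorSet and namespaces.

package example // hypothetical package, for illustration only

import (
    "context"
    "fmt"

    featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1"
    "github.com/lekkodev/cli/pkg/sync"
)

// applySync writes syncer-produced repository contents into an existing local
// Lekko config repo checkout at repoPath.
func applySync(ctx context.Context, contents *featurev1beta1.RepositoryContents, repoPath string) error {
    for _, ns := range contents.Namespaces {
        // Log what will be written; WriteContentsToLocalRepo adds/updates these
        // namespaces and configs, removes leftover configs, writes proto type
        // files, and finishes with a compile of the whole repo.
        fmt.Printf("syncing namespace %s (%d configs)\n", ns.Name, len(ns.Features))
    }
    return sync.WriteContentsToLocalRepo(ctx, contents, repoPath)
}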