Skip to content

Commit

Permalink
Improve how field values from nested message literals are rendered as…
Browse files Browse the repository at this point in the history
… strings for flag set defaults
  • Loading branch information
kralicky committed Sep 26, 2023
1 parent 51d0e35 commit b144e1e
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 44 deletions.
31 changes: 0 additions & 31 deletions internal/codegen/cli/extensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@ package cli

import (
"fmt"
"slices"

"github.com/samber/lo"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoimpl"
"google.golang.org/protobuf/types/dynamicpb"
)

func getExtension[T proto.Message](desc protoreflect.Descriptor, ext *protoimpl.ExtensionInfo) (out T, ok bool) {
Expand Down Expand Up @@ -65,30 +61,3 @@ func applyOptions(desc protoreflect.Descriptor, out proto.Message) {
proto.Merge(out, opts)
}
}

func (f *FlagSetOptions) ForEachDefault(fieldMessage *protogen.Message, fn func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool) {
if f.Default == nil {
return
}
dm := dynamicpb.NewMessage(fieldMessage.Desc)
if err := f.Default.UnmarshalTo(dm.Interface()); err != nil {
panic(err)
}
orderedRange(dm, fn)
}

func orderedRange(dm protoreflect.Message, fn func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool) {
ordered := []lo.Tuple2[protoreflect.FieldDescriptor, protoreflect.Value]{}
dm.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
ordered = append(ordered, lo.T2(fd, v))
return true
})
slices.SortFunc(ordered, func(a, b lo.Tuple2[protoreflect.FieldDescriptor, protoreflect.Value]) int {
return int(a.A.Number() - b.A.Number())
})
for _, t := range ordered {
if !fn(t.A, t.B) {
return
}
}
}
88 changes: 76 additions & 12 deletions internal/codegen/cli/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@ import (
"google.golang.org/genproto/googleapis/api/annotations"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protopath"
"google.golang.org/protobuf/reflect/protorange"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/dynamicpb"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
)

func NewGenerator() *Generator {
Expand Down Expand Up @@ -758,19 +763,78 @@ func (cg *Generator) generateFlagSet(g *protogen.GeneratedFile, message *protoge
g.P("fs.AddFlagSet(in.", field.GoName, `.FlagSet(append(prefix,"`, kebabName, `")...))`)
}

flagSetOpts.ForEachDefault(field.Message, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
fdOpts := FlagOptions{}
applyOptions(fd, &fdOpts)
if fdOpts.Skip {
return true
if flagSetOpts.Default != nil {
dm := dynamicpb.NewMessage(field.Message.Desc)
if err := flagSetOpts.Default.UnmarshalTo(dm.Interface()); err != nil {
panic(err)
}
if flagSetOpts.NoPrefix {
g.P(_flagutil.Ident("SetDefValue"), `(fs, `, _strings.Ident("Join"), `(append(prefix, "`, formatKebab(fd.Name()), `"), "."), `, fmt.Sprintf("%q", v.String()), `)`)
} else {
g.P(_flagutil.Ident("SetDefValue"), `(fs, `, _strings.Ident("Join"), `(append(prefix, "`, kebabName, `", "`, formatKebab(fd.Name()), `"), "."), `, fmt.Sprintf("%q", v.String()), `)`)
}
return true
})

protorange.Options{
Stable: true,
}.Range(dm, func(vs protopath.Values) (retVal error) {
v := vs.Index(-1)
if v.Step.Kind() != protopath.FieldAccessStep {
return nil
}
fd := v.Step.FieldDescriptor()
fdOpts := FlagOptions{}
applyOptions(fd, &fdOpts)
if fdOpts.Skip {
return protorange.Break
}

var valueStr string
if fd.Kind() == protoreflect.MessageKind && !fd.IsMap() {
switch fd.Message().FullName() {
case "google.protobuf.Timestamp":
dm := v.Value.Message().Interface().(*dynamicpb.Message)
wire, _ := proto.Marshal(dm)
ts := &timestamppb.Timestamp{}
proto.Unmarshal(wire, ts)
valueStr = fmt.Sprintf("%q", ts.AsTime().Format(time.RFC3339))

retVal = protorange.Break
case "google.protobuf.Duration":
dm := v.Value.Message().Interface().(*dynamicpb.Message)
wire, _ := proto.Marshal(dm)
dur := &durationpb.Duration{}
proto.Unmarshal(wire, dur)
valueStr = fmt.Sprintf("%q", dur.AsDuration().String())

retVal = protorange.Break
default:
// recurse into nested messages
return nil
}
}

if valueStr == "" {
if fd.IsList() {
strs := []string{}
list := v.Value.List()
for i := 0; i < list.Len(); i++ {
strs = append(strs, fmt.Sprintf("%q", list.Get(i).String()))
}
valueStr = fmt.Sprintf("`[%s]`", strings.Join(strs, ","))
} else {
valueStr = fmt.Sprintf("%q", v.Value.String())
}
}

parts := []string{}
for _, part := range vs.Path[1:] {
parts = append(parts, formatKebab(part.FieldDescriptor().Name()))
}

if flagSetOpts.NoPrefix {
g.P(_flagutil.Ident("SetDefValue"), `(fs, `, _strings.Ident("Join"), `(append(prefix, "`, strings.Join(parts, "."), `"), "."), `, valueStr, `)`)
} else {
g.P(_flagutil.Ident("SetDefValue"), `(fs, `, _strings.Ident("Join"), `(append(prefix, "`, kebabName, `", "`, strings.Join(parts, "."), `"), "."), `, valueStr, `)`)
}

return
}, nil)
}
}
continue
}
Expand Down
2 changes: 1 addition & 1 deletion plugins/metrics/apis/cortexops/cortexops_cli.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit b144e1e

Please sign in to comment.