Skip to content

Commit f3c9251

Browse files
ands, ors, slices
1 parent cefa6a4 commit f3c9251

File tree

1 file changed

+70
-30
lines changed

1 file changed

+70
-30
lines changed

cmd/lekko/gen.go

Lines changed: 70 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ package main
1616

1717
import (
1818
"bytes"
19-
"encoding/json"
2019
"fmt"
2120
"os"
2221
"os/exec"
@@ -36,8 +35,13 @@ import (
3635
"golang.org/x/mod/modfile"
3736
"google.golang.org/protobuf/encoding/protojson"
3837
"google.golang.org/protobuf/proto"
38+
"google.golang.org/protobuf/reflect/protoregistry"
39+
"google.golang.org/protobuf/types/known/anypb"
40+
"google.golang.org/protobuf/types/known/structpb"
3941
)
4042

43+
var typeRegistry *protoregistry.Types
44+
4145
func genGoCmd() *cobra.Command {
4246
var ns string
4347
var wd string
@@ -61,8 +65,9 @@ func genGoCmd() *cobra.Command {
6165
if err != nil {
6266
return errors.Wrap(err, "new repo")
6367
}
64-
_, nsMDs := try.To2(r.ParseMetadata(cmd.Context()))
65-
68+
rootMD, nsMDs := try.To2(r.ParseMetadata(cmd.Context()))
69+
// TODO this feels weird and there is a global set we should be able to add to but I'll worrry about it later?
70+
typeRegistry = try.To1(r.BuildDynamicTypeRegistry(cmd.Context(), rootMD.ProtoDirectory))
6671
staticCtxType := unpackProtoType(moduleRoot, nsMDs[ns].ContextProto)
6772
ffs, err := r.GetFeatureFiles(cmd.Context(), ns)
6873
if err != nil {
@@ -113,6 +118,7 @@ import (
113118
{{range $.ProtoImports}}
114119
{{ . }}{{end}}
115120
"context"
121+
"golang.org/x/exp/slices"
116122
client "github.com/lekkodev/go-sdk/client"
117123
)
118124
@@ -162,7 +168,8 @@ var StaticConfig = map[string]map[string][]byte{
162168
{{range $.ProtoAsByteStrings}}{{ . }}{{end}} },
163169
}
164170
{{range $.CodeStrings}}
165-
{{ . }}{{end}}`
171+
{{ . }}
172+
{{end}}`
166173

167174
// buf generate --template '{"version":"v1","plugins":[{"plugin":"go","out":"gen/go"}]}'
168175
//
@@ -226,10 +233,8 @@ func (c *LekkoClient) {{$.FuncName}}(ctx context.Context) ({{$.RetType}}, error)
226233
}
227234
228235
// {{$.Description}}
229-
{{if $.NaturalLanguage}}func (c *SafeLekkoClient) {{$.FuncName}}(ctx *{{$.StaticType}}) {{$.RetType}} {
236+
func (c *SafeLekkoClient) {{$.FuncName}}(ctx *{{$.StaticType}}) {{$.RetType}} {
230237
{{range $.NaturalLanguage}}{{ . }}
231-
{{end}}{{else}}func (c *SafeLekkoClient) {{$.FuncName}}(ctx context.Context) {{$.RetType}} {
232-
return c.{{$.GetFunction}}(ctx, "{{$.Namespace}}", "{{$.Key}}")
233238
{{end}}}`
234239

235240
const protoTemplateBody = `// {{$.Description}}
@@ -272,30 +277,30 @@ func (c *SafeLekkoClient) {{$.FuncName}}(ctx context.Context, result interface{}
272277
StaticContextType string
273278
}
274279
var staticContextInfo *StaticContextInfo
280+
if staticCtxType != nil {
281+
staticContextInfo = &StaticContextInfo{
282+
Natty: translateFeature(f),
283+
StaticContextType: fmt.Sprintf("%s.%s", staticCtxType.PackageAlias, staticCtxType.Type),
284+
}
285+
}
275286

276287
switch f.Type {
277-
case 1:
288+
case featurev1beta1.FeatureType_FEATURE_TYPE_BOOL:
278289
retType = "bool"
279290
getFunction = "GetBool"
280-
case 2:
291+
case featurev1beta1.FeatureType_FEATURE_TYPE_INT:
281292
retType = "int64"
282293
getFunction = "GetInt"
283-
case 3:
294+
case featurev1beta1.FeatureType_FEATURE_TYPE_FLOAT:
284295
retType = "float64"
285296
getFunction = "GetFloat"
286-
case 4:
297+
case featurev1beta1.FeatureType_FEATURE_TYPE_STRING:
287298
retType = "string"
288299
getFunction = "GetString"
289-
if staticCtxType != nil {
290-
staticContextInfo = &StaticContextInfo{
291-
Natty: translateFeature(f),
292-
StaticContextType: fmt.Sprintf("%s.%s", staticCtxType.PackageAlias, staticCtxType.Type),
293-
}
294-
}
295-
case 5:
300+
case featurev1beta1.FeatureType_FEATURE_TYPE_JSON:
296301
getFunction = "GetJSON"
297302
templateBody = jsonTemplateBody
298-
case 6:
303+
case featurev1beta1.FeatureType_FEATURE_TYPE_PROTO:
299304
getFunction = "GetProto"
300305
templateBody = protoTemplateBody
301306
// we don't need the import path so sending in empty string
@@ -343,8 +348,8 @@ type protoImport struct {
343348
}
344349

345350
// This function handles both the google.protobuf.Any.TypeURL variable
346-
// which has the format of `types.googleapis.com/fully.qualified.Proto`
347-
// and purely `fully.qualified.Proto`
351+
// which has the format of `types.googleapis.com/fully.qualified.v1beta1.Proto`
352+
// and purely `fully.qualified.v1beta1.Proto`
348353
//
349354
// return nil if typeURL is empty. Panics on any problems like the rest of the file.
350355
func unpackProtoType(moduleRoot string, typeURL string) *protoImport {
@@ -389,13 +394,12 @@ func translateFeature(f *featurev1beta1.Feature) []string {
389394
buffer = append(buffer, fmt.Sprintf("\t%s %s {", ifToken, rule))
390395

391396
// TODO this doesn't work for proto, but let's try
392-
393-
buffer = append(buffer, fmt.Sprintf("\t\treturn %s", try.To1(protojson.Marshal(try.To1(constraint.Value.UnmarshalNew())))))
397+
buffer = append(buffer, fmt.Sprintf("\t\treturn %s", translateRetValue(constraint.Value)))
394398
}
395399
if len(f.Tree.Constraints) > 0 {
396400
buffer = append(buffer, "\t}")
397401
}
398-
buffer = append(buffer, fmt.Sprintf("\treturn %s", try.To1(protojson.Marshal(try.To1(f.Tree.Default.UnmarshalNew())))))
402+
buffer = append(buffer, fmt.Sprintf("\treturn %s", translateRetValue(f.GetTree().GetDefault())))
399403
return buffer
400404
}
401405

@@ -407,16 +411,52 @@ func translateRule(rule *rulesv1beta3.Rule) string {
407411
case *rulesv1beta3.Rule_Atom:
408412
switch v.Atom.GetComparisonOperator() {
409413
case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS:
410-
b, err := json.Marshal(v.Atom.ComparisonValue)
411-
if err != nil {
412-
panic(err)
413-
}
414-
return fmt.Sprintf("ctx.%s == %s", strcase.UpperCamelCase(v.Atom.ContextKey), string(b))
414+
return fmt.Sprintf("ctx.%s == %s", strcase.UpperCamelCase(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue))))
415415
case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN:
416+
sliceType := "string"
417+
switch v.Atom.ComparisonValue.GetListValue().GetValues()[0].GetKind().(type) {
418+
case *structpb.Value_NumberValue:
419+
// technically doubles may not work for ints....
420+
sliceType = "float64"
421+
case *structpb.Value_BoolValue:
422+
sliceType = "bool"
423+
case *structpb.Value_StringValue:
424+
// technically doubles may not work for ints....
425+
sliceType = "string"
426+
}
427+
var elements []string
428+
for _, comparisonVal := range v.Atom.ComparisonValue.GetListValue().GetValues() {
429+
elements = append(elements, string(try.To1(protojson.Marshal(comparisonVal))))
430+
}
431+
return fmt.Sprintf("slices.Contains([]%s{%s}, ctx.%s)", sliceType, strings.Join(elements, ", "), strcase.UpperCamelCase(v.Atom.ContextKey))
416432
// TODO, probably logical to have this here but we need slice syntax, use slices as of golang 1.21
417433
}
418434
case *rulesv1beta3.Rule_LogicalExpression:
419-
// TODO do some ands and ors
435+
operator := " && "
436+
switch v.LogicalExpression.GetLogicalOperator() {
437+
case rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_OR:
438+
operator = " || "
439+
}
440+
var result []string
441+
for _, rule := range v.LogicalExpression.Rules {
442+
// worry about inner parens later
443+
result = append(result, translateRule(rule))
444+
}
445+
return strings.Join(result, operator)
420446
}
447+
421448
return ""
422449
}
450+
451+
func translateRetValue(val *anypb.Any) string {
452+
// protos
453+
msg, err := anypb.UnmarshalNew(val, proto.UnmarshalOptions{Resolver: typeRegistry})
454+
if err != nil {
455+
panic(err)
456+
}
457+
res, err := protojson.MarshalOptions{Resolver: typeRegistry}.Marshal(msg)
458+
if err != nil {
459+
panic(err)
460+
}
461+
return string(res)
462+
}

0 commit comments

Comments
 (0)