diff --git a/pkg/gen/golang.go b/pkg/gen/golang.go index c6701a25..3d381e87 100644 --- a/pkg/gen/golang.go +++ b/pkg/gen/golang.go @@ -584,15 +584,14 @@ func (g *goGenerator) tryStoreUsedVariable(usedVariables map[string]string, k st usedVariables[k] = t return } - // TODO: test with err2 handlers to handle more gracefully assert.Equal(t, existT) } +// Recursively translate a rule, which is an n-ary tree. See lekko.rules.v1beta3.Rule. func (g *goGenerator) translateRule(rule *rulesv1beta3.Rule, staticContext bool, usedVariables map[string]string, usedStrings, usedSlices *bool) string { if rule == nil { return "" } - // TODO: Do we actually want to case context keys in terms of cross language? switch v := rule.GetRule().(type) { case *rulesv1beta3.Rule_Atom: switch v.Atom.GetComparisonOperator() { @@ -686,7 +685,24 @@ func (g *goGenerator) translateRule(rule *rulesv1beta3.Rule, staticContext bool, return fmt.Sprintf("slices.Contains([]%s{%s}, %s)", sliceType, strings.Join(elements, ", "), strcase.ToLowerCamel(v.Atom.ContextKey)) } // TODO, probably logical to have this here but we need slice syntax, use slices as of golang 1.21 + default: + panic(fmt.Errorf("unsupported operator %+v", v.Atom.ComparisonOperator)) } + case *rulesv1beta3.Rule_Not: + ruleStrFmt := "!%s" + // For some cases, we want to wrap the generated Go expression string in parens + switch rule := v.Not.Rule.(type) { + case *rulesv1beta3.Rule_LogicalExpression: + ruleStrFmt = "!(%s)" + case *rulesv1beta3.Rule_Atom: + if rule.Atom.ComparisonOperator != rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN && + rule.Atom.ComparisonOperator != rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINS && + rule.Atom.ComparisonOperator != rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_STARTS_WITH && + rule.Atom.ComparisonOperator != rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_ENDS_WITH { + ruleStrFmt = "!(%s)" + } + } + return fmt.Sprintf(ruleStrFmt, g.translateRule(v.Not, staticContext, usedVariables, usedStrings, usedSlices)) case *rulesv1beta3.Rule_LogicalExpression: operator := " && " switch v.LogicalExpression.GetLogicalOperator() { @@ -695,14 +711,21 @@ func (g *goGenerator) translateRule(rule *rulesv1beta3.Rule, staticContext bool, } var result []string for _, rule := range v.LogicalExpression.Rules { - // worry about inner parens later - result = append(result, g.translateRule(rule, staticContext, usedVariables, usedStrings, usedSlices)) + ruleStrFmt := "%s" + // If child is a nested logical expression, wrap in parens + if l, nested := rule.Rule.(*rulesv1beta3.Rule_LogicalExpression); nested { + // Exception: if current level is || and child is &&, we don't need parens + // This technically depends on dev preference, we should pick one version and stick with it for canonicity + if !(v.LogicalExpression.LogicalOperator == rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_OR && l.LogicalExpression.LogicalOperator == rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_AND) { + ruleStrFmt = "(%s)" + } + } + result = append(result, fmt.Sprintf(ruleStrFmt, g.translateRule(rule, staticContext, usedVariables, usedStrings, usedSlices))) } return strings.Join(result, operator) + default: + panic(fmt.Errorf("unsupported type of rule %+v", v)) } - - fmt.Printf("Need to learn how to: %+v\n", rule.GetRule()) - return "" } func (g *goGenerator) translateProtoFieldValue(parent protoreflect.Message, f protoreflect.FieldDescriptor, val protoreflect.Value) string { diff --git a/pkg/sync/golang.go b/pkg/sync/golang.go index 490ba41c..4f14046e 100644 --- a/pkg/sync/golang.go +++ b/pkg/sync/golang.go @@ -781,12 +781,26 @@ func (g *goSyncer) callExprToRule(expr *ast.CallExpr) *rulesv1beta3.Rule { } } +func (g *goSyncer) unaryExprToRule(expr *ast.UnaryExpr) *rulesv1beta3.Rule { + switch expr.Op { + case token.NOT: + rule := g.exprToRule(expr.X) + return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Not{Not: rule}} + default: + panic(fmt.Errorf("unsupported unary expression %+v", expr)) + } +} + func (g *goSyncer) exprToRule(expr ast.Expr) *rulesv1beta3.Rule { switch node := expr.(type) { case *ast.BinaryExpr: return g.binaryExprToRule(node) case *ast.CallExpr: return g.callExprToRule(node) + case *ast.ParenExpr: + return g.exprToRule(node.X) + case *ast.UnaryExpr: + return g.unaryExprToRule(node) default: panic(fmt.Errorf("unsupported expression type for rule: %T", node)) }