diff --git a/pkg/gen/golang.go b/pkg/gen/golang.go index 3d381e87..9250773a 100644 --- a/pkg/gen/golang.go +++ b/pkg/gen/golang.go @@ -594,73 +594,52 @@ func (g *goGenerator) translateRule(rule *rulesv1beta3.Rule, staticContext bool, } switch v := rule.GetRule().(type) { case *rulesv1beta3.Rule_Atom: + var contextKeyName string + if staticContext { + contextKeyName = fmt.Sprintf("ctx.%s", strcase.ToCamel(v.Atom.ContextKey)) + } else { + contextKeyName = strcase.ToLowerCamel(v.Atom.ContextKey) + } + switch v.Atom.GetComparisonOperator() { case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS: - g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, structpbValueToKindStringGo(v.Atom.ComparisonValue)) - if staticContext { - return fmt.Sprintf("ctx.%s == %s", strcase.ToCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } else { - return fmt.Sprintf("%s == %s", strcase.ToLowerCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) + if b, ok := v.Atom.ComparisonValue.GetKind().(*structpb.Value_BoolValue); ok { + g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, "bool") + if b.BoolValue { + return contextKeyName + } else { + return fmt.Sprintf("!%s", contextKeyName) + } } + g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, structpbValueToKindStringGo(v.Atom.ComparisonValue)) + return fmt.Sprintf("%s == %s", contextKeyName, string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_NOT_EQUALS: g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, structpbValueToKindStringGo(v.Atom.ComparisonValue)) - if staticContext { - return fmt.Sprintf("ctx.%s != %s", strcase.ToCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } else { - return fmt.Sprintf("%s != %s", strcase.ToLowerCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } + return fmt.Sprintf("%s != %s", contextKeyName, string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN: g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, structpbValueToKindStringGo(v.Atom.ComparisonValue)) - if staticContext { - return fmt.Sprintf("ctx.%s < %s", strcase.ToCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } else { - return fmt.Sprintf("%s < %s", strcase.ToLowerCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } + return fmt.Sprintf("%s < %s", contextKeyName, string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN_OR_EQUALS: g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, structpbValueToKindStringGo(v.Atom.ComparisonValue)) - if staticContext { - return fmt.Sprintf("ctx.%s <= %s", strcase.ToCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } else { - return fmt.Sprintf("%s <= %s", strcase.ToLowerCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } + return fmt.Sprintf("%s <= %s", contextKeyName, string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN: g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, structpbValueToKindStringGo(v.Atom.ComparisonValue)) - if staticContext { - return fmt.Sprintf("ctx.%s > %s", strcase.ToCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } else { - return fmt.Sprintf("%s > %s", strcase.ToLowerCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } + return fmt.Sprintf("%s > %s", contextKeyName, string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN_OR_EQUALS: g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, structpbValueToKindStringGo(v.Atom.ComparisonValue)) - if staticContext { - return fmt.Sprintf("ctx.%s >= %s", strcase.ToCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } else { - return fmt.Sprintf("%s >= %s", strcase.ToLowerCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } + return fmt.Sprintf("%s >= %s", contextKeyName, string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINS: g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, structpbValueToKindStringGo(v.Atom.ComparisonValue)) *usedStrings = true - if staticContext { - return fmt.Sprintf("strings.Contains(ctx.%s, %s)", strcase.ToCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } else { - return fmt.Sprintf("strings.Contains(%s, %s)", strcase.ToLowerCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } + return fmt.Sprintf("strings.Contains(%s, %s)", contextKeyName, string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_STARTS_WITH: g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, structpbValueToKindStringGo(v.Atom.ComparisonValue)) *usedStrings = true - if staticContext { - return fmt.Sprintf("strings.HasPrefix(ctx.%s, %s)", strcase.ToCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } else { - return fmt.Sprintf("strings.HasPrefix(%s, %s)", strcase.ToLowerCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } + return fmt.Sprintf("strings.HasPrefix(%s, %s)", contextKeyName, string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_ENDS_WITH: g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, structpbValueToKindStringGo(v.Atom.ComparisonValue)) *usedStrings = true - if staticContext { - return fmt.Sprintf("strings.HasSuffix(ctx.%s, %s)", strcase.ToCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } else { - return fmt.Sprintf("strings.HasSuffix(%s, %s)", strcase.ToLowerCamel(v.Atom.ContextKey), string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) - } + return fmt.Sprintf("strings.HasSuffix(%s, %s)", contextKeyName, string(try.To1(protojson.Marshal(v.Atom.ComparisonValue)))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN: sliceType := "string" switch v.Atom.ComparisonValue.GetListValue().GetValues()[0].GetKind().(type) { @@ -679,11 +658,7 @@ func (g *goGenerator) translateRule(rule *rulesv1beta3.Rule, staticContext bool, } g.tryStoreUsedVariable(usedVariables, v.Atom.ContextKey, sliceType) *usedSlices = true - if staticContext { - return fmt.Sprintf("slices.Contains([]%s{%s}, ctx.%s)", sliceType, strings.Join(elements, ", "), strcase.ToCamel(v.Atom.ContextKey)) - } else { - return fmt.Sprintf("slices.Contains([]%s{%s}, %s)", sliceType, strings.Join(elements, ", "), strcase.ToLowerCamel(v.Atom.ContextKey)) - } + return fmt.Sprintf("slices.Contains([]%s{%s}, %s)", sliceType, strings.Join(elements, ", "), contextKeyName) // TODO, probably logical to have this here but we need slice syntax, use slices as of golang 1.21 default: panic(fmt.Errorf("unsupported operator %+v", v.Atom.ComparisonOperator)) diff --git a/pkg/gen/ts.go b/pkg/gen/ts.go index 59160dcc..790dffe1 100644 --- a/pkg/gen/ts.go +++ b/pkg/gen/ts.go @@ -402,42 +402,51 @@ func translateRuleTS(rule *rulesv1beta3.Rule, usedVariables map[string]string) s } switch v := rule.GetRule().(type) { case *rulesv1beta3.Rule_Atom: + contextKeyName := strcase.ToLowerCamel(v.Atom.ContextKey) usedVariables[v.Atom.ContextKey] = "string" // TODO - ugly as hell switch v.Atom.GetComparisonOperator() { case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS: + if b, ok := v.Atom.ComparisonValue.GetKind().(*structpb.Value_BoolValue); ok { + usedVariables[v.Atom.ContextKey] = "boolean" + if b.BoolValue { + return fmt.Sprintf("(%s)", contextKeyName) + } else { + return fmt.Sprintf("(!%s)", contextKeyName) + } + } usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("( %s === %s )", strcase.ToLowerCamel(v.Atom.ContextKey), try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return fmt.Sprintf("( %s === %s )", contextKeyName, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_NOT_EQUALS: usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("( %s !== %s )", strcase.ToLowerCamel(v.Atom.ContextKey), try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return fmt.Sprintf("( %s !== %s )", contextKeyName, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN: usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue.GetListValue().GetValues()[0]) var elements []string for _, comparisonVal := range v.Atom.ComparisonValue.GetListValue().GetValues() { elements = append(elements, string(try.To1(marshalOptions.Marshal(comparisonVal)))) } - return fmt.Sprintf("([%s].includes(%s))", strings.Join(elements, ", "), strcase.ToLowerCamel(v.Atom.ContextKey)) + return fmt.Sprintf("([%s].includes(%s))", strings.Join(elements, ", "), contextKeyName) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN: usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("(%s < %s)", strcase.ToLowerCamel(v.Atom.ContextKey), try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return fmt.Sprintf("(%s < %s)", contextKeyName, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN_OR_EQUALS: usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("(%s <= %s)", strcase.ToLowerCamel(v.Atom.ContextKey), try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return fmt.Sprintf("(%s <= %s)", contextKeyName, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN: usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("(%s > %s)", strcase.ToLowerCamel(v.Atom.ContextKey), try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return fmt.Sprintf("(%s > %s)", contextKeyName, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN_OR_EQUALS: usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("(%s >= %s)", strcase.ToLowerCamel(v.Atom.ContextKey), try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return fmt.Sprintf("(%s >= %s)", contextKeyName, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINS: usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("(%s.includes(%s))", strcase.ToLowerCamel(v.Atom.ContextKey), try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return fmt.Sprintf("(%s.includes(%s))", contextKeyName, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_STARTS_WITH: usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("(%s.startsWith(%s))", strcase.ToLowerCamel(v.Atom.ContextKey), try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return fmt.Sprintf("(%s.startsWith(%s))", contextKeyName, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_ENDS_WITH: usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("(%s.endsWith(%s))", strcase.ToLowerCamel(v.Atom.ContextKey), try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return fmt.Sprintf("(%s.endsWith(%s))", contextKeyName, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) } case *rulesv1beta3.Rule_LogicalExpression: operator := " && " @@ -451,6 +460,8 @@ func translateRuleTS(rule *rulesv1beta3.Rule, usedVariables map[string]string) s result = append(result, translateRuleTS(rule, usedVariables)) } return "(" + strings.Join(result, operator) + ")" + case *rulesv1beta3.Rule_Not: + return "!(" + translateRuleTS(v.Not, usedVariables) + ")" } fmt.Printf("Need to learn how to: %+v\n", rule.GetRule()) diff --git a/pkg/sync/golang.go b/pkg/sync/golang.go index da25546a..402fd7f2 100644 --- a/pkg/sync/golang.go +++ b/pkg/sync/golang.go @@ -167,6 +167,18 @@ func (g *goSyncer) FileLocationToNamespace(ctx context.Context) (*Namespace, err } privateName := x.Name.Name configName := strcase.ToKebab(privateName[3:]) + + contextKeys := make(map[string]string) + for _, param := range x.Type.Params.List { + assert.SNotEmpty(param.Names, "must have a parameter name") + assert.INotNil(param.Type, "must have a parameter type") + typeIdent, ok := param.Type.(*ast.Ident) + if !ok { + panic("parameter type must be an identifier") + } + contextKeys[param.Names[0].Name] = typeIdent.Name + } + results := x.Type.Results.List if results == nil { panic("must have a return type") @@ -204,7 +216,7 @@ func (g *goSyncer) FileLocationToNamespace(ctx context.Context) (*Namespace, err // TODO also need to take care of the possibility that the default is in an else feature.Tree.Default = g.exprToAny(n.Results[0], feature.Type) // can this be multiple things? case *ast.IfStmt: - feature.Tree.Constraints = append(feature.Tree.Constraints, g.ifToConstraints(n, feature.Type)...) + feature.Tree.Constraints = append(feature.Tree.Constraints, g.ifToConstraints(n, feature.Type, contextKeys)...) default: panic("only if and return statements allowed in function body") } @@ -702,18 +714,18 @@ func (g *goSyncer) exprToComparisonValue(expr ast.Expr) *structpb.Value { } } -func (g *goSyncer) binaryExprToRule(expr *ast.BinaryExpr) *rulesv1beta3.Rule { +func (g *goSyncer) binaryExprToRule(expr *ast.BinaryExpr, contextKeys map[string]string) *rulesv1beta3.Rule { switch expr.Op { case token.LAND: var rules []*rulesv1beta3.Rule - left := g.exprToRule(expr.X) + left := g.exprToRule(expr.X, contextKeys) l, ok := left.Rule.(*rulesv1beta3.Rule_LogicalExpression) if ok && l.LogicalExpression.LogicalOperator == rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_AND { rules = append(rules, l.LogicalExpression.Rules...) } else { rules = append(rules, left) } - right := g.exprToRule(expr.Y) + right := g.exprToRule(expr.Y, contextKeys) r, ok := right.Rule.(*rulesv1beta3.Rule_LogicalExpression) if ok && r.LogicalExpression.LogicalOperator == rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_AND { rules = append(rules, r.LogicalExpression.Rules...) @@ -723,14 +735,14 @@ func (g *goSyncer) binaryExprToRule(expr *ast.BinaryExpr) *rulesv1beta3.Rule { return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_LogicalExpression{LogicalExpression: &rulesv1beta3.LogicalExpression{LogicalOperator: rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_AND, Rules: rules}}} case token.LOR: var rules []*rulesv1beta3.Rule - left := g.exprToRule(expr.X) + left := g.exprToRule(expr.X, contextKeys) l, ok := left.Rule.(*rulesv1beta3.Rule_LogicalExpression) if ok && l.LogicalExpression.LogicalOperator == rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_OR { rules = append(rules, l.LogicalExpression.Rules...) } else { rules = append(rules, left) } - right := g.exprToRule(expr.Y) + right := g.exprToRule(expr.Y, contextKeys) r, ok := right.Rule.(*rulesv1beta3.Rule_LogicalExpression) if ok && r.LogicalExpression.LogicalOperator == rulesv1beta3.LogicalOperator_LOGICAL_OPERATOR_OR { rules = append(rules, r.LogicalExpression.Rules...) @@ -785,34 +797,58 @@ func (g *goSyncer) callExprToRule(expr *ast.CallExpr) *rulesv1beta3.Rule { } } -func (g *goSyncer) unaryExprToRule(expr *ast.UnaryExpr) *rulesv1beta3.Rule { +func (g *goSyncer) unaryExprToRule(expr *ast.UnaryExpr, contextKeys map[string]string) *rulesv1beta3.Rule { switch expr.Op { case token.NOT: - rule := g.exprToRule(expr.X) + rule := g.exprToRule(expr.X, contextKeys) + if atom := rule.GetAtom(); atom != nil { + boolValue, isBool := atom.ComparisonValue.GetKind().(*structpb.Value_BoolValue) + if isBool && atom.ComparisonOperator == rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS { + atom.ComparisonValue = structpb.NewBoolValue(!boolValue.BoolValue) + } + return rule + } return &rulesv1beta3.Rule{Rule: &rulesv1beta3.Rule_Not{Not: rule}} default: panic(fmt.Errorf("unsupported unary expression %+v", expr)) } } -func (g *goSyncer) exprToRule(expr ast.Expr) *rulesv1beta3.Rule { +func (g *goSyncer) identToRule(ident *ast.Ident, contextKeys map[string]string) *rulesv1beta3.Rule { + if contextKeyType, ok := contextKeys[ident.Name]; ok && contextKeyType == "bool" { + return &rulesv1beta3.Rule{ + Rule: &rulesv1beta3.Rule_Atom{ + Atom: &rulesv1beta3.Atom{ + ComparisonOperator: rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_EQUALS, + ContextKey: strcase.ToSnake(ident.Name), + ComparisonValue: structpb.NewBoolValue(true), + }, + }, + } + } + panic(fmt.Errorf("not a boolean expression: %+v", ident)) +} + +func (g *goSyncer) exprToRule(expr ast.Expr, contextKeys map[string]string) *rulesv1beta3.Rule { switch node := expr.(type) { + case *ast.Ident: + return g.identToRule(node, contextKeys) case *ast.BinaryExpr: - return g.binaryExprToRule(node) + return g.binaryExprToRule(node, contextKeys) case *ast.CallExpr: return g.callExprToRule(node) case *ast.ParenExpr: - return g.exprToRule(node.X) + return g.exprToRule(node.X, contextKeys) case *ast.UnaryExpr: - return g.unaryExprToRule(node) + return g.unaryExprToRule(node, contextKeys) default: panic(fmt.Errorf("unsupported expression type for rule: %T", node)) } } -func (g *goSyncer) ifToConstraints(ifStmt *ast.IfStmt, want featurev1beta1.FeatureType) []*featurev1beta1.Constraint { +func (g *goSyncer) ifToConstraints(ifStmt *ast.IfStmt, want featurev1beta1.FeatureType, contextKeys map[string]string) []*featurev1beta1.Constraint { constraint := &featurev1beta1.Constraint{} - constraint.RuleAstNew = g.exprToRule(ifStmt.Cond) + constraint.RuleAstNew = g.exprToRule(ifStmt.Cond, contextKeys) assert.Equal(len(ifStmt.Body.List), 1, "if statements can only contain one return statement") returnStmt, ok := ifStmt.Body.List[0].(*ast.ReturnStmt) // TODO assert.Equal(ok, true, "if statements can only contain return statements") @@ -820,7 +856,7 @@ func (g *goSyncer) ifToConstraints(ifStmt *ast.IfStmt, want featurev1beta1.Featu if ifStmt.Else != nil { // TODO bare else? elseIfStmt, ok := ifStmt.Else.(*ast.IfStmt) assert.Equal(ok, true, "bare else statements are not supported, must be else if") - return append([]*featurev1beta1.Constraint{constraint}, g.ifToConstraints(elseIfStmt, want)...) + return append([]*featurev1beta1.Constraint{constraint}, g.ifToConstraints(elseIfStmt, want, contextKeys)...) } return []*featurev1beta1.Constraint{constraint} }