@@ -25,6 +25,7 @@ import (
25
25
"sync"
26
26
27
27
featurev1beta1 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/feature/v1beta1"
28
+ rulesv1beta3 "buf.build/gen/go/lekkodev/cli/protocolbuffers/go/lekko/rules/v1beta3"
28
29
"github.com/go-git/go-git/v5/plumbing"
29
30
"github.com/lekkodev/cli/pkg/encoding"
30
31
"github.com/lekkodev/cli/pkg/feature"
@@ -39,6 +40,7 @@ import (
39
40
"google.golang.org/protobuf/reflect/protoregistry"
40
41
"google.golang.org/protobuf/types/descriptorpb"
41
42
"google.golang.org/protobuf/types/known/anypb"
43
+ "google.golang.org/protobuf/types/known/structpb"
42
44
)
43
45
44
46
// Provides functionality needed for accessing and making changes to Lekko configuration.
@@ -67,7 +69,110 @@ type ConfigurationStore interface {
67
69
RestoreWorkingDirectory (hash string ) error
68
70
}
69
71
70
- func (r * repository ) CompileFeature (ctx context.Context , registry * protoregistry.Types , namespace , featureName string , nv feature.NamespaceVersion ) (* feature.CompiledFeature , error ) {
72
+ func checkRuleFitsContextType (rule * rulesv1beta3.Rule , contextType protoreflect.MessageType ) error {
73
+ var err error
74
+ switch r := rule .Rule .(type ) {
75
+ case * rulesv1beta3.Rule_BoolConst :
76
+ return nil
77
+ case * rulesv1beta3.Rule_Not :
78
+ return checkRuleFitsContextType (r .Not , contextType )
79
+ case * rulesv1beta3.Rule_LogicalExpression :
80
+ for _ , lr := range r .LogicalExpression .GetRules () {
81
+ err = checkRuleFitsContextType (lr , contextType )
82
+ if err != nil {
83
+ return err
84
+ }
85
+ }
86
+ case * rulesv1beta3.Rule_Atom :
87
+ field := contextType .Descriptor ().Fields ().ByName (protoreflect .Name (r .Atom .GetContextKey ()))
88
+ if field == nil {
89
+ return errors .Errorf ("`%s` field not found in context type" , r .Atom .GetContextKey ())
90
+ }
91
+ fieldType := field .Kind ().String ()
92
+
93
+ switch r .Atom .ComparisonOperator {
94
+ case rulesv1beta3 .ComparisonOperator_COMPARISON_OPERATOR_PRESENT :
95
+ return nil
96
+ case
97
+ rulesv1beta3 .ComparisonOperator_COMPARISON_OPERATOR_EQUALS ,
98
+ rulesv1beta3 .ComparisonOperator_COMPARISON_OPERATOR_NOT_EQUALS :
99
+ switch r .Atom .GetComparisonValue ().Kind .(type ) {
100
+ case * structpb.Value_BoolValue :
101
+ if fieldType != "string" {
102
+ return errors .Errorf ("%s field has invalid type" , r .Atom .GetContextKey ())
103
+ }
104
+
105
+ case * structpb.Value_NumberValue :
106
+ if fieldType != "int64" && fieldType != "double" {
107
+ return errors .Errorf ("%s field has invalid type" , r .Atom .GetContextKey ())
108
+ }
109
+ case * structpb.Value_StringValue :
110
+ if fieldType != "string" {
111
+ return errors .Errorf ("%s field has invalid type" , r .Atom .GetContextKey ())
112
+ }
113
+ default :
114
+ panic ("This should never happen" )
115
+ }
116
+ case
117
+ rulesv1beta3 .ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN ,
118
+ rulesv1beta3 .ComparisonOperator_COMPARISON_OPERATOR_LESS_THAN_OR_EQUALS ,
119
+ rulesv1beta3 .ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN ,
120
+ rulesv1beta3 .ComparisonOperator_COMPARISON_OPERATOR_GREATER_THAN_OR_EQUALS :
121
+ if fieldType != "int64" && fieldType != "double" {
122
+ return errors .Errorf ("%s field has invalid type" , r .Atom .GetContextKey ())
123
+ }
124
+ return nil
125
+ case
126
+ rulesv1beta3 .ComparisonOperator_COMPARISON_OPERATOR_STARTS_WITH ,
127
+ rulesv1beta3 .ComparisonOperator_COMPARISON_OPERATOR_ENDS_WITH ,
128
+ rulesv1beta3 .ComparisonOperator_COMPARISON_OPERATOR_CONTAINS :
129
+ if fieldType != "string" {
130
+ return errors .Errorf ("%s field has invalid type" , r .Atom .GetContextKey ())
131
+ }
132
+ return nil
133
+ case rulesv1beta3 .ComparisonOperator_COMPARISON_OPERATOR_CONTAINED_WITHIN :
134
+ listRuleVal , ok := r .Atom .GetComparisonValue ().Kind .(* structpb.Value_ListValue )
135
+ if ! ok {
136
+ panic ("This should never happen" )
137
+ }
138
+ for _ , elem := range listRuleVal .ListValue .Values {
139
+ switch elem .Kind .(type ) {
140
+ case * structpb.Value_BoolValue :
141
+ if fieldType != "string" {
142
+ return errors .Errorf ("%s field has invalid type" , r .Atom .GetContextKey ())
143
+ }
144
+
145
+ case * structpb.Value_NumberValue :
146
+ if fieldType != "int64" && fieldType != "double" {
147
+ return errors .Errorf ("%s field has invalid type" , r .Atom .GetContextKey ())
148
+ }
149
+ case * structpb.Value_StringValue :
150
+ if fieldType != "string" {
151
+ return errors .Errorf ("%s field has invalid type" , r .Atom .GetContextKey ())
152
+ }
153
+ default :
154
+ panic ("This should never happen" )
155
+ }
156
+ }
157
+ return nil
158
+ }
159
+ return nil
160
+ case * rulesv1beta3.Rule_CallExpression :
161
+ switch f := r .CallExpression .Function .(type ) {
162
+ case * rulesv1beta3.CallExpression_Bucket_ :
163
+ contextKey := f .Bucket .ContextKey
164
+ if contextType .Descriptor ().Fields ().ByName (protoreflect .Name (contextKey )) == nil {
165
+ return errors .Errorf ("%s field not found in context type" , contextKey )
166
+ }
167
+ return nil
168
+ }
169
+ default :
170
+ panic ("This should never happen" )
171
+ }
172
+ return nil
173
+ }
174
+
175
+ func (r * repository ) CompileFeature (ctx context.Context , registry * protoregistry.Types , nsContextTypes map [string ]protoreflect.MessageType , namespace , featureName string , nv feature.NamespaceVersion ) (* feature.CompiledFeature , error ) {
71
176
if ! isValidName (namespace ) {
72
177
return nil , errors .Errorf ("invalid name '%s'" , namespace )
73
178
}
@@ -92,6 +197,14 @@ func (r *repository) CompileFeature(ctx context.Context, registry *protoregistry
92
197
if err != nil {
93
198
return nil , errors .Wrap (err , "compile" )
94
199
}
200
+ if nsContextTypes [namespace ] != nil {
201
+ for _ , override := range f .Feature .Overrides {
202
+ err = checkRuleFitsContextType (override .RuleASTV3 , nsContextTypes [namespace ])
203
+ if err != nil {
204
+ return nil , err
205
+ }
206
+ }
207
+ }
95
208
return f , nil
96
209
}
97
210
@@ -303,6 +416,38 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
303
416
if err := req .Validate (); err != nil {
304
417
return nil , errors .Wrap (err , "validate request" )
305
418
}
419
+
420
+ _ , nsMDs , err := r .ParseMetadata (ctx )
421
+ if err != nil {
422
+ return nil , errors .Wrap (err , "parse metadata" )
423
+ }
424
+ registry , err := r .registry (ctx , req .Registry )
425
+ if err != nil {
426
+ return nil , errors .Wrap (err , "registry" )
427
+ }
428
+ nsContextTypes := make (map [string ]protoreflect.MessageType )
429
+ for ns , nsMd := range nsMDs {
430
+ if nsMd .ContextProto != "" {
431
+ ct , err := registry .FindMessageByName (protoreflect .FullName (nsMd .ContextProto ))
432
+ if err != nil {
433
+ return nil , err
434
+ }
435
+ for i := 0 ; i < ct .Descriptor ().Fields ().Len (); i ++ {
436
+ f := ct .Descriptor ().Fields ().Get (i )
437
+ switch f .Kind ().String () {
438
+ case
439
+ "bool" ,
440
+ "int64" ,
441
+ "double" ,
442
+ "string" :
443
+ default :
444
+ return nil , errors .Errorf ("proto message cannot be used as a context message because type: %v of key: %s is not allowed" , f .Kind (), f .Name ())
445
+ }
446
+ }
447
+ nsContextTypes [ns ] = ct
448
+ }
449
+ }
450
+
306
451
// Step 1: collect. Find all features
307
452
vffs , numNamespaces , err := r .findVersionedFeatureFiles (ctx , req .NamespaceFilter , req .FeatureFilter , req .Verify )
308
453
if err != nil {
@@ -326,10 +471,6 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
326
471
}
327
472
r .Logf ("Found %d configs across %d namespaces\n " , len (results ), numNamespaces )
328
473
r .Logf ("Compiling...\n " )
329
- registry , err := r .registry (ctx , req .Registry )
330
- if err != nil {
331
- return nil , errors .Wrap (err , "registry" )
332
- }
333
474
concurrency := 50
334
475
if len (results ) < 50 {
335
476
concurrency = len (results )
@@ -344,7 +485,7 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
344
485
nsNameToSegments := make (map [string ]map [string ]string )
345
486
for _ , fcr := range results {
346
487
if fcr .FeatureName == "segments" {
347
- cf , err := r .CompileFeature (ctx , registry , fcr .NamespaceName , fcr .FeatureName , fcr .NamespaceVersion )
488
+ cf , err := r .CompileFeature (ctx , registry , nsContextTypes , fcr .NamespaceName , fcr .FeatureName , fcr .NamespaceVersion )
348
489
if err != nil {
349
490
fcr .CompilationError = err
350
491
}
@@ -385,7 +526,7 @@ func (r *repository) Compile(ctx context.Context, req *CompileRequest) ([]*Featu
385
526
}
386
527
fcr .FormattingDiffExists = fmtDiffExists
387
528
// compile feature
388
- cf , err := r .CompileFeature (ctx , registry , fcr .NamespaceName , fcr .FeatureName , fcr .NamespaceVersion )
529
+ cf , err := r .CompileFeature (ctx , registry , nsContextTypes , fcr .NamespaceName , fcr .FeatureName , fcr .NamespaceVersion )
389
530
fcr .CompiledFeature = cf
390
531
fcr .CompilationError = err
391
532
}
0 commit comments