Skip to content

Commit

Permalink
refactor(pipeline): refactor the condition evaluation (#952)
Browse files Browse the repository at this point in the history
Because

- The original implementation for pipeline condition evaluation is not
robust. It converts the condition string into Go code and evaluates it
using the AST. However, in our syntax, we typically use dot notation in
recipes, while Go maps require bracket notation, especially when
component or variable names are not valid Go identifiers.

This commit

- Refactors the condition evaluation.
  • Loading branch information
donch1989 authored Jan 14, 2025
1 parent fc47584 commit 4d932da
Show file tree
Hide file tree
Showing 13 changed files with 1,270 additions and 307 deletions.
11 changes: 11 additions & 0 deletions pkg/data/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,14 @@ func (a Array) String() string {
}
return fmt.Sprintf("[%s]", strings.Join(segments, ", "))
}

func (a Array) ToJSONValue() (v any, err error) {
jsonArr := make([]any, len(a))
for i, v := range a {
jsonArr[i], err = v.ToJSONValue()
if err != nil {
return nil, err
}
}
return jsonArr, nil
}
4 changes: 4 additions & 0 deletions pkg/data/boolean.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@ func (b *booleanData) Equal(other format.Value) bool {
func (b *booleanData) Hash() string {
return fmt.Sprintf("%t", b.Raw)
}

func (b *booleanData) ToJSONValue() (v any, err error) {
return b.Raw, nil
}
5 changes: 5 additions & 0 deletions pkg/data/bytearray.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,8 @@ func (b *byteArrayData) Equal(other format.Value) bool {
}
return false
}

func (b *byteArrayData) ToJSONValue() (v any, err error) {
base64str := base64.StdEncoding.EncodeToString(b.Raw)
return base64str, nil
}
8 changes: 8 additions & 0 deletions pkg/data/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,11 @@ func (f *fileData) Equal(other format.Value) bool {
}
return false
}

func (f *fileData) ToJSONValue() (v any, err error) {
base64str, err := f.Base64()
if err != nil {
return nil, err
}
return base64str.String(), nil
}
1 change: 1 addition & 0 deletions pkg/data/format/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
type Value interface {
IsValue()
ToStructValue() (v *structpb.Value, err error)
ToJSONValue() (v any, err error)
Get(p *path.Path) (v Value, err error)
Equal(other Value) bool
String() string
Expand Down
12 changes: 12 additions & 0 deletions pkg/data/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,15 @@ func (m Map) String() string {
}
return fmt.Sprintf("{%s}", strings.Join(segments, ", "))
}

func (m Map) ToJSONValue() (v any, err error) {

jsonMap := make(map[string]any)
for k, v := range m {
jsonMap[k], err = v.ToJSONValue()
if err != nil {
return nil, err
}
}
return jsonMap, nil
}
4 changes: 4 additions & 0 deletions pkg/data/null.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,7 @@ func (n *nullData) Equal(other format.Value) bool {
func (n *nullData) String() string {
return "null"
}

func (n *nullData) ToJSONValue() (v any, err error) {
return nil, nil
}
4 changes: 4 additions & 0 deletions pkg/data/number.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,7 @@ func (n *numberData) Equal(other format.Value) bool {
}
return false
}

func (n *numberData) ToJSONValue() (v any, err error) {
return n.Float64(), nil
}
4 changes: 4 additions & 0 deletions pkg/data/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ func (s *stringData) Equal(other format.Value) bool {
}
return false
}

func (s *stringData) ToJSONValue() (v any, err error) {
return s.Raw, nil
}
284 changes: 0 additions & 284 deletions pkg/recipe/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,9 @@ import (
"context"
"encoding/json"
"fmt"
"reflect"
"slices"
"strconv"
"strings"

"go/ast"
"go/token"

"google.golang.org/protobuf/types/known/structpb"

"github.com/instill-ai/pipeline-backend/pkg/constant"
Expand Down Expand Up @@ -238,285 +233,6 @@ func Render(ctx context.Context, template format.Value, batchIdx int, wfm memory
}
}

func EvalCondition(expr ast.Expr, value map[string]any) (any, error) {
switch e := (expr).(type) {
case *ast.UnaryExpr:
xRes, err := EvalCondition(e.X, value)
if err != nil {
return nil, err
}

switch e.Op {
case token.NOT: // !
switch xVal := xRes.(type) {
case bool:
return !xVal, nil
}
case token.SUB: // -
switch xVal := xRes.(type) {
case int64:
return -xVal, nil
case float64:
return -xVal, nil
}
}
case *ast.BinaryExpr:

xRes, err := EvalCondition(e.X, value)
if err != nil {
return nil, err
}
yRes, err := EvalCondition(e.Y, value)
if err != nil {
return nil, err
}

switch e.Op {
case token.LAND: // &&

xBool := false
yBool := false
switch xVal := xRes.(type) {
case int64, float64:
xBool = (xVal != 0)
case string:
xBool = (xVal != "")
case bool:
xBool = xVal
}
switch yVal := yRes.(type) {
case int64, float64:
yBool = (yVal != 0)
case string:
yBool = (yVal != "")
case bool:
yBool = yVal
}
return xBool && yBool, nil
case token.LOR: // ||

xBool := false
yBool := false
switch xVal := xRes.(type) {
case int64, float64:
xBool = (xVal != 0)
case string:
xBool = (xVal != "")
case bool:
xBool = xVal
}
switch yVal := yRes.(type) {
case int64, float64:
yBool = (yVal != 0)
case string:
yBool = (yVal != "")
case bool:
yBool = yVal
}
return xBool || yBool, nil

case token.EQL: // ==
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal == yVal, nil
case float64:
return float64(xVal) == yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal == float64(yVal), nil
case float64:
return xVal == yVal, nil
}
}
return reflect.DeepEqual(xRes, yRes), nil
case token.NEQ: // !=
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal != yVal, nil
case float64:
return float64(xVal) != yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal != float64(yVal), nil
case float64:
return xVal != yVal, nil
}
}
return !reflect.DeepEqual(xRes, yRes), nil

case token.LSS: // <
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal < yVal, nil
case float64:
return float64(xVal) < yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal < float64(yVal), nil
case float64:
return xVal < yVal, nil
}
}
case token.GTR: // >
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal > yVal, nil
case float64:
return float64(xVal) > yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal > float64(yVal), nil
case float64:
return xVal > yVal, nil
}
}

case token.LEQ: // <=
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal <= yVal, nil
case float64:
return float64(xVal) <= yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal <= float64(yVal), nil
case float64:
return xVal <= yVal, nil
}
}
case token.GEQ: // >=
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal >= yVal, nil
case float64:
return float64(xVal) >= yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal >= float64(yVal), nil
case float64:
return xVal >= yVal, nil
}
}
}

case *ast.ParenExpr:
return EvalCondition(e.X, value)
case *ast.SelectorExpr:
v, err := EvalCondition(e.X, value)
if err != nil {
return nil, err
}
// Convert InputsMemory and ComponentItemMemory into map[string]any.
// Ignore error handling here since all of them are JSON data.
b, _ := json.Marshal(v)
m := map[string]any{}
_ = json.Unmarshal(b, &m)
return m[e.Sel.String()], nil
case *ast.BasicLit:
if e.Kind == token.INT {
return strconv.ParseInt(e.Value, 10, 64)
}
if e.Kind == token.FLOAT {
return strconv.ParseFloat(e.Value, 64)
}
if e.Kind == token.STRING {
return e.Value[1 : len(e.Value)-1], nil
}
return e.Value, nil
case *ast.Ident:
if e.Name == "true" {
return true, nil
}
if e.Name == "false" {
return false, nil
}

return value[e.Name], nil

case *ast.IndexExpr:
v, err := EvalCondition(e.X, value)
if err != nil {
return nil, err
}
switch idxVal := e.Index.(type) {
case *ast.BasicLit:
// handle arr[index]
if idxVal.Kind == token.INT {
index, err := strconv.Atoi(idxVal.Value)
if err != nil {
return nil, err
}
return v.([]any)[index], nil
}
// handle obj[key]
if idxVal.Kind == token.STRING {
// key: remove ""
key := idxVal.Value[1 : len(idxVal.Value)-1]
return v.(map[string]any)[key], nil
}
}

}
return false, fmt.Errorf("condition error")
}

func SanitizeCondition(cond string) (string, map[string]string, map[string]string) {
varMapping := map[string]string{}
revVarMapping := map[string]string{}
varNameIdx := 0
for {
leftIdx := strings.Index(cond, "${")
if leftIdx == -1 {
break
}
rightIdx := strings.Index(cond, "}")

left := cond[:leftIdx]
v := cond[leftIdx+2 : rightIdx]
right := cond[rightIdx+1:]

srcName := strings.Split(strings.TrimSpace(v), ".")[0]
if varName, ok := revVarMapping[srcName]; ok {
varMapping[varName] = srcName
revVarMapping[srcName] = varName
cond = left + strings.ReplaceAll(v, srcName, varName) + right
} else {
varName := fmt.Sprintf("var%d", varNameIdx)
varMapping[varName] = srcName
revVarMapping[srcName] = varName
varNameIdx++
cond = left + strings.ReplaceAll(v, srcName, varName) + right
}

}

return cond, varMapping, revVarMapping
}

func GenerateDAG(componentMap datamodel.ComponentMap) (*dag, error) {

componentIDMap := make(map[string]bool)
Expand Down
Loading

0 comments on commit 4d932da

Please sign in to comment.