Skip to content

refactor(pipeline): refactor the condition evaluation #952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pkg/data/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,14 @@ func (a Array) String() string {
}
return fmt.Sprintf("[%s]", strings.Join(segments, ", "))
}

func (a Array) ToJSONValue() (v any, err error) {
jsonArr := make([]any, len(a))
for i, v := range a {
jsonArr[i], err = v.ToJSONValue()
if err != nil {
return nil, err
}
}
return jsonArr, nil
}
4 changes: 4 additions & 0 deletions pkg/data/boolean.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@ func (b *booleanData) Equal(other format.Value) bool {
func (b *booleanData) Hash() string {
return fmt.Sprintf("%t", b.Raw)
}

func (b *booleanData) ToJSONValue() (v any, err error) {
return b.Raw, nil
}
5 changes: 5 additions & 0 deletions pkg/data/bytearray.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,8 @@ func (b *byteArrayData) Equal(other format.Value) bool {
}
return false
}

func (b *byteArrayData) ToJSONValue() (v any, err error) {
base64str := base64.StdEncoding.EncodeToString(b.Raw)
return base64str, nil
}
8 changes: 8 additions & 0 deletions pkg/data/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,11 @@ func (f *fileData) Equal(other format.Value) bool {
}
return false
}

func (f *fileData) ToJSONValue() (v any, err error) {
base64str, err := f.Base64()
if err != nil {
return nil, err
}
return base64str.String(), nil
}
1 change: 1 addition & 0 deletions pkg/data/format/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
type Value interface {
IsValue()
ToStructValue() (v *structpb.Value, err error)
ToJSONValue() (v any, err error)
Get(p *path.Path) (v Value, err error)
Equal(other Value) bool
String() string
Expand Down
12 changes: 12 additions & 0 deletions pkg/data/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,15 @@ func (m Map) String() string {
}
return fmt.Sprintf("{%s}", strings.Join(segments, ", "))
}

func (m Map) ToJSONValue() (v any, err error) {

jsonMap := make(map[string]any)
for k, v := range m {
jsonMap[k], err = v.ToJSONValue()
if err != nil {
return nil, err
}
}
return jsonMap, nil
}
4 changes: 4 additions & 0 deletions pkg/data/null.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,7 @@ func (n *nullData) Equal(other format.Value) bool {
func (n *nullData) String() string {
return "null"
}

func (n *nullData) ToJSONValue() (v any, err error) {
return nil, nil
}
4 changes: 4 additions & 0 deletions pkg/data/number.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,7 @@ func (n *numberData) Equal(other format.Value) bool {
}
return false
}

func (n *numberData) ToJSONValue() (v any, err error) {
return n.Float64(), nil
}
4 changes: 4 additions & 0 deletions pkg/data/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ func (s *stringData) Equal(other format.Value) bool {
}
return false
}

func (s *stringData) ToJSONValue() (v any, err error) {
return s.Raw, nil
}
284 changes: 0 additions & 284 deletions pkg/recipe/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,9 @@ import (
"context"
"encoding/json"
"fmt"
"reflect"
"slices"
"strconv"
"strings"

"go/ast"
"go/token"

"google.golang.org/protobuf/types/known/structpb"

"github.com/instill-ai/pipeline-backend/pkg/constant"
Expand Down Expand Up @@ -238,285 +233,6 @@ func Render(ctx context.Context, template format.Value, batchIdx int, wfm memory
}
}

func EvalCondition(expr ast.Expr, value map[string]any) (any, error) {
switch e := (expr).(type) {
case *ast.UnaryExpr:
xRes, err := EvalCondition(e.X, value)
if err != nil {
return nil, err
}

switch e.Op {
case token.NOT: // !
switch xVal := xRes.(type) {
case bool:
return !xVal, nil
}
case token.SUB: // -
switch xVal := xRes.(type) {
case int64:
return -xVal, nil
case float64:
return -xVal, nil
}
}
case *ast.BinaryExpr:

xRes, err := EvalCondition(e.X, value)
if err != nil {
return nil, err
}
yRes, err := EvalCondition(e.Y, value)
if err != nil {
return nil, err
}

switch e.Op {
case token.LAND: // &&

xBool := false
yBool := false
switch xVal := xRes.(type) {
case int64, float64:
xBool = (xVal != 0)
case string:
xBool = (xVal != "")
case bool:
xBool = xVal
}
switch yVal := yRes.(type) {
case int64, float64:
yBool = (yVal != 0)
case string:
yBool = (yVal != "")
case bool:
yBool = yVal
}
return xBool && yBool, nil
case token.LOR: // ||

xBool := false
yBool := false
switch xVal := xRes.(type) {
case int64, float64:
xBool = (xVal != 0)
case string:
xBool = (xVal != "")
case bool:
xBool = xVal
}
switch yVal := yRes.(type) {
case int64, float64:
yBool = (yVal != 0)
case string:
yBool = (yVal != "")
case bool:
yBool = yVal
}
return xBool || yBool, nil

case token.EQL: // ==
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal == yVal, nil
case float64:
return float64(xVal) == yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal == float64(yVal), nil
case float64:
return xVal == yVal, nil
}
}
return reflect.DeepEqual(xRes, yRes), nil
case token.NEQ: // !=
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal != yVal, nil
case float64:
return float64(xVal) != yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal != float64(yVal), nil
case float64:
return xVal != yVal, nil
}
}
return !reflect.DeepEqual(xRes, yRes), nil

case token.LSS: // <
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal < yVal, nil
case float64:
return float64(xVal) < yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal < float64(yVal), nil
case float64:
return xVal < yVal, nil
}
}
case token.GTR: // >
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal > yVal, nil
case float64:
return float64(xVal) > yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal > float64(yVal), nil
case float64:
return xVal > yVal, nil
}
}

case token.LEQ: // <=
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal <= yVal, nil
case float64:
return float64(xVal) <= yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal <= float64(yVal), nil
case float64:
return xVal <= yVal, nil
}
}
case token.GEQ: // >=
switch xVal := xRes.(type) {
case int64:
switch yVal := yRes.(type) {
case int64:
return xVal >= yVal, nil
case float64:
return float64(xVal) >= yVal, nil
}
case float64:
switch yVal := yRes.(type) {
case int64:
return xVal >= float64(yVal), nil
case float64:
return xVal >= yVal, nil
}
}
}

case *ast.ParenExpr:
return EvalCondition(e.X, value)
case *ast.SelectorExpr:
v, err := EvalCondition(e.X, value)
if err != nil {
return nil, err
}
// Convert InputsMemory and ComponentItemMemory into map[string]any.
// Ignore error handling here since all of them are JSON data.
b, _ := json.Marshal(v)
m := map[string]any{}
_ = json.Unmarshal(b, &m)
return m[e.Sel.String()], nil
case *ast.BasicLit:
if e.Kind == token.INT {
return strconv.ParseInt(e.Value, 10, 64)
}
if e.Kind == token.FLOAT {
return strconv.ParseFloat(e.Value, 64)
}
if e.Kind == token.STRING {
return e.Value[1 : len(e.Value)-1], nil
}
return e.Value, nil
case *ast.Ident:
if e.Name == "true" {
return true, nil
}
if e.Name == "false" {
return false, nil
}

return value[e.Name], nil

case *ast.IndexExpr:
v, err := EvalCondition(e.X, value)
if err != nil {
return nil, err
}
switch idxVal := e.Index.(type) {
case *ast.BasicLit:
// handle arr[index]
if idxVal.Kind == token.INT {
index, err := strconv.Atoi(idxVal.Value)
if err != nil {
return nil, err
}
return v.([]any)[index], nil
}
// handle obj[key]
if idxVal.Kind == token.STRING {
// key: remove ""
key := idxVal.Value[1 : len(idxVal.Value)-1]
return v.(map[string]any)[key], nil
}
}

}
return false, fmt.Errorf("condition error")
}

func SanitizeCondition(cond string) (string, map[string]string, map[string]string) {
varMapping := map[string]string{}
revVarMapping := map[string]string{}
varNameIdx := 0
for {
leftIdx := strings.Index(cond, "${")
if leftIdx == -1 {
break
}
rightIdx := strings.Index(cond, "}")

left := cond[:leftIdx]
v := cond[leftIdx+2 : rightIdx]
right := cond[rightIdx+1:]

srcName := strings.Split(strings.TrimSpace(v), ".")[0]
if varName, ok := revVarMapping[srcName]; ok {
varMapping[varName] = srcName
revVarMapping[srcName] = varName
cond = left + strings.ReplaceAll(v, srcName, varName) + right
} else {
varName := fmt.Sprintf("var%d", varNameIdx)
varMapping[varName] = srcName
revVarMapping[srcName] = varName
varNameIdx++
cond = left + strings.ReplaceAll(v, srcName, varName) + right
}

}

return cond, varMapping, revVarMapping
}

func GenerateDAG(componentMap datamodel.ComponentMap) (*dag, error) {

componentIDMap := make(map[string]bool)
Expand Down
Loading
Loading