From 7177fd7e9fc0363efab233099aa048c0aa6c7046 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Sedl=C3=A1=C4=8Dek?= Date: Thu, 22 Aug 2024 16:24:44 +0200 Subject: [PATCH] db.In, db.NotIn iterating slices (#21) --- Makefile | 5 ++++ db/cond.go | 55 +++++++++++++++++++++++++++++++++++++----- db/cond_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 116 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index f2eae69..e497a98 100644 --- a/Makefile +++ b/Makefile @@ -42,8 +42,13 @@ test: test-clean test-all: test-clean GOGC=off go test $(TEST_FLAGS) $(MOD_VENDOR) -run=$(TEST) ./... +test-all-tparse: test-clean + GOGC=off go test $(TEST_FLAGS) $(MOD_VENDOR) -run=$(TEST) ./... -json | tparse --follow + test-with-reset: db-reset test-all +test-with-reset-tparse: db-reset test-all-tparse + test-clean: GOGC=off go clean -testcache diff --git a/db/cond.go b/db/cond.go index c808c1c..b9ce159 100644 --- a/db/cond.go +++ b/db/cond.go @@ -2,6 +2,7 @@ package db import ( "fmt" + "reflect" "strings" "github.com/Masterminds/squirrel" @@ -80,7 +81,6 @@ func (n *binaryExprNode) ToSql() (string, []interface{}, error) { func compileNodes(nodes []squirrel.Sqlizer) (q string, args []interface{}, err error) { for i, node := range nodes { qn, argsn, err := node.ToSql() - if err != nil { return "", nil, fmt.Errorf("error compiling node %d: %w", i, err) } @@ -203,12 +203,12 @@ func NotILike(v interface{}) squirrel.Sqlizer { // In represents an IN operator. The value must be variadic. func In[T interface{}](v ...T) squirrel.Sqlizer { - return Func[T]("IN", v...) + return Func("IN", v...) } // NotIn represents a NOT IN operator. The value must be variadic. func NotIn[T interface{}](v ...T) squirrel.Sqlizer { - return Func[T]("NOT IN", v...) + return Func("NOT IN", v...) } // Raw represents a raw SQL expression. @@ -226,15 +226,58 @@ func Func[T interface{}](name string, params ...T) squirrel.Sqlizer { } places := make([]string, len(params)) - args := make([]interface{}, 0, len(params)) + // iterating through slices + if reflect.TypeOf(params[0]).Kind() == reflect.Slice { + elements := 0 + for _, subSlice := range params { + v := reflect.ValueOf(subSlice) + elements += v.Len() + } + + args := make([]interface{}, 0, elements) + + for i, subSlice := range params { + subSliceVal := reflect.ValueOf(subSlice) + subPlaces := make([]string, subSliceVal.Len()) + + for j := 0; j < subSliceVal.Len(); j++ { + val := subSliceVal.Index(j).Interface() + if sqlizer, ok := interface{}(val).(squirrel.Sqlizer); ok { + paramSQL, paramArgs, err := sqlizer.ToSql() + if err != nil { + return "", nil, fmt.Errorf("%s: error compiling argument %d: %w", name, i, err) + } + + subPlaces[j] = paramSQL + args = append(args, paramArgs...) + } else if reflect.TypeOf(val).Kind() == reflect.Slice { + v := reflect.ValueOf(val) + for k := 0; k < v.Len(); k++ { + subPlaces[j] = paramPlaceholder + args = append(args, v.Index(k).Interface()) + } + } else { + subPlaces[j] = paramPlaceholder + args = append(args, val) + } + } + + places[i] = "(" + strings.Join(subPlaces, ",") + ")" + } + + return name + " (" + strings.Join(places, ",") + ")", args, nil + } + + args := make([]interface{}, 0, len(params)) for i, param := range params { if sqlizer, ok := interface{}(param).(squirrel.Sqlizer); ok { - paramSql, paramArgs, err := sqlizer.ToSql() + paramSQL, paramArgs, err := sqlizer.ToSql() if err != nil { return "", nil, fmt.Errorf("%s: error compiling argument %d: %w", name, i, err) } - places[i] = paramSql + + places[i] = paramSQL args = append(args, paramArgs...) } else { places[i] = paramPlaceholder diff --git a/db/cond_test.go b/db/cond_test.go index 45d1042..5be3f2a 100644 --- a/db/cond_test.go +++ b/db/cond_test.go @@ -4,13 +4,13 @@ import ( "testing" sq "github.com/Masterminds/squirrel" - "github.com/goware/pgkit/v2/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/goware/pgkit/v2/db" ) func TestCond(t *testing.T) { - t.Run("equal to", func(t *testing.T) { cond := db.Cond{"one": 1} s, args, err := cond.ToSql() @@ -19,6 +19,14 @@ func TestCond(t *testing.T) { assert.Equal(t, "one = ?", s) }) + t.Run("equal to with multiple parameters", func(t *testing.T) { + cond := db.And{db.Cond{"one": 1}, db.Cond{"two": 2}} + s, args, err := cond.ToSql() + require.NoError(t, err) + assert.Equal(t, []interface{}{1, 2}, args) + assert.Equal(t, "(one = ? AND two = ?)", s) + }) + t.Run("equal to (inverted)", func(t *testing.T) { cond := db.Cond{1: "one"} s, args, err := cond.ToSql() @@ -64,6 +72,16 @@ func TestCond(t *testing.T) { }) t.Run("IN with slice", func(t *testing.T) { + sl1 := []int{1, 2, 3} + cond := db.Cond{"list": db.In(sl1...)} + s, args, err := cond.ToSql() + require.NoError(t, err) + + assert.Equal(t, []interface{}{1, 2, 3}, args) + assert.Equal(t, "list IN (?, ?, ?)", s) + }) + + t.Run("IN with slice variadic", func(t *testing.T) { cond := db.Cond{"list": db.In(1, 2, 3)} s, args, err := cond.ToSql() require.NoError(t, err) @@ -72,6 +90,48 @@ func TestCond(t *testing.T) { assert.Equal(t, "list IN (?, ?, ?)", s) }) + t.Run("multiple IN with slice", func(t *testing.T) { + sl1 := []int{1, 2, 3} + sl2 := []int{4, 5, 6} + cond := db.Cond{"list": db.In([]interface{}{sl1, sl2}...)} + s, args, err := cond.ToSql() + require.NoError(t, err) + + assert.Equal(t, []interface{}{1, 2, 3, 4, 5, 6}, args) + assert.Equal(t, "list IN ((?,?,?),(?,?,?))", s) + }) + + t.Run("multiple IN with slice AND where ID", func(t *testing.T) { + cond := db.And{db.Cond{"list": db.In([][]string{{"1", "2", "3"}, {"3", "4", "5"}}...)}, db.Cond{"id": 1}} + s, args, err := cond.ToSql() + require.NoError(t, err) + + assert.Equal(t, []interface{}{"1", "2", "3", "3", "4", "5", 1}, args) + assert.Equal(t, "(list IN ((?,?,?),(?,?,?)) AND id = ?)", s) + }) + + t.Run("multiple IN with struct", func(t *testing.T) { + randomStruct := []struct { + Id uint64 + Name string + }{ + {Id: 1, Name: "Lukas"}, + {Id: 2, Name: "David"}, + } + + data := [][]interface{}{} + for _, s := range randomStruct { + data = append(data, []interface{}{s.Id, s.Name}) + } + + cond := db.Cond{"list": db.In(data...)} + s, args, err := cond.ToSql() + require.NoError(t, err) + + assert.Equal(t, []interface{}{uint64(1), "Lukas", uint64(2), "David"}, args) + assert.Equal(t, "list IN ((?,?),(?,?))", s) + }) + t.Run("NOT IN", func(t *testing.T) { cond := db.Cond{"list": db.NotIn("Czech Republic", "Slovakia")} s, args, err := cond.ToSql()