Skip to content

Commit 1965d0a

Browse files
committed
support _limit in update
1 parent b74c7dc commit 1965d0a

File tree

4 files changed

+89
-26
lines changed

4 files changed

+89
-26
lines changed

builder/builder.go

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,18 @@ var (
2121
errHavingUnsupportedOperator = errors.New(`[builder] "_having" contains unsupported operator`)
2222
errLockModeValueType = errors.New(`[builder] the value of "_lockMode" must be of string type`)
2323
errNotAllowedLockMode = errors.New(`[builder] the value of "_lockMode" is not allowed`)
24+
errUpdateLimitType = errors.New(`[builder] the value of "_limit" in update query must be one of int,uint,int64,uint64`)
2425

2526
errWhereInterfaceSliceType = `[builder] the value of "xxx %s" must be of []interface{} type`
2627
errEmptySliceCondition = `[builder] the value of "%s" must contain at least one element`
28+
29+
defaultIgnoreKeys = map[string]struct{}{
30+
"_orderby": struct{}{},
31+
"_groupby": struct{}{},
32+
"_having": struct{}{},
33+
"_limit": struct{}{},
34+
"_lockMode": struct{}{},
35+
}
2736
)
2837

2938
type whereMapSet struct {
@@ -59,37 +68,31 @@ func BuildSelect(table string, where map[string]interface{}, selectField []strin
5968
var groupBy string
6069
var having map[string]interface{}
6170
var lockMode string
62-
copiedWhere := copyWhere(where)
63-
if val, ok := copiedWhere["_orderby"]; ok {
71+
if val, ok := where["_orderby"]; ok {
6472
s, ok := val.(string)
6573
if !ok {
6674
err = errOrderByValueType
6775
return
6876
}
6977
orderBy = strings.TrimSpace(s)
70-
delete(copiedWhere, "_orderby")
7178
}
72-
if val, ok := copiedWhere["_groupby"]; ok {
79+
if val, ok := where["_groupby"]; ok {
7380
s, ok := val.(string)
7481
if !ok {
7582
err = errGroupByValueType
7683
return
7784
}
7885
groupBy = strings.TrimSpace(s)
79-
delete(copiedWhere, "_groupby")
8086
if "" != groupBy {
81-
if h, ok := copiedWhere["_having"]; ok {
87+
if h, ok := where["_having"]; ok {
8288
having, err = resolveHaving(h)
8389
if nil != err {
8490
return
8591
}
8692
}
8793
}
8894
}
89-
if _, ok := copiedWhere["_having"]; ok {
90-
delete(copiedWhere, "_having")
91-
}
92-
if val, ok := copiedWhere["_limit"]; ok {
95+
if val, ok := where["_limit"]; ok {
9396
arr, ok := val.([]uint)
9497
if !ok {
9598
err = errLimitValueType
@@ -108,9 +111,8 @@ func BuildSelect(table string, where map[string]interface{}, selectField []strin
108111
begin: begin,
109112
step: step,
110113
}
111-
delete(copiedWhere, "_limit")
112114
}
113-
if val, ok := copiedWhere["_lockMode"]; ok {
115+
if val, ok := where["_lockMode"]; ok {
114116
s, ok := val.(string)
115117
if !ok {
116118
err = errLockModeValueType
@@ -121,14 +123,13 @@ func BuildSelect(table string, where map[string]interface{}, selectField []strin
121123
err = errNotAllowedLockMode
122124
return
123125
}
124-
delete(copiedWhere, "_lockMode")
125126
}
126-
conditions, err := getWhereConditions(copiedWhere)
127+
conditions, err := getWhereConditions(where, defaultIgnoreKeys)
127128
if nil != err {
128129
return
129130
}
130131
if having != nil {
131-
havingCondition, err1 := getWhereConditions(having)
132+
havingCondition, err1 := getWhereConditions(having, defaultIgnoreKeys)
132133
if nil != err1 {
133134
err = err1
134135
return
@@ -169,16 +170,31 @@ func resolveHaving(having interface{}) (map[string]interface{}, error) {
169170

170171
// BuildUpdate work as its name says
171172
func BuildUpdate(table string, where map[string]interface{}, update map[string]interface{}) (string, []interface{}, error) {
172-
conditions, err := getWhereConditions(where)
173+
var limit uint
174+
if v, ok := where["_limit"]; ok {
175+
switch val := v.(type) {
176+
case int:
177+
limit = uint(val)
178+
case uint:
179+
limit = val
180+
case int64:
181+
limit = uint(val)
182+
case uint64:
183+
limit = uint(val)
184+
default:
185+
return "", nil, errUpdateLimitType
186+
}
187+
}
188+
conditions, err := getWhereConditions(where, defaultIgnoreKeys)
173189
if nil != err {
174190
return "", nil, err
175191
}
176-
return buildUpdate(table, update, conditions...)
192+
return buildUpdate(table, update, limit, conditions...)
177193
}
178194

179195
// BuildDelete work as its name says
180196
func BuildDelete(table string, where map[string]interface{}) (string, []interface{}, error) {
181-
conditions, err := getWhereConditions(where)
197+
conditions, err := getWhereConditions(where, defaultIgnoreKeys)
182198
if nil != err {
183199
return "", nil, err
184200
}
@@ -209,7 +225,7 @@ func isStringInSlice(str string, arr []string) bool {
209225
return false
210226
}
211227

212-
func getWhereConditions(where map[string]interface{}) ([]Comparable, error) {
228+
func getWhereConditions(where map[string]interface{}, ignoreKeys map[string]struct{}) ([]Comparable, error) {
213229
if len(where) == 0 {
214230
return nil, nil
215231
}
@@ -218,6 +234,9 @@ func getWhereConditions(where map[string]interface{}) ([]Comparable, error) {
218234
var field, operator string
219235
var err error
220236
for key, val := range where {
237+
if _, ok := ignoreKeys[key]; ok {
238+
continue
239+
}
221240
if key == "_or" {
222241
var (
223242
orWheres []map[string]interface{}
@@ -231,7 +250,7 @@ func getWhereConditions(where map[string]interface{}) ([]Comparable, error) {
231250
if orWhere == nil {
232251
continue
233252
}
234-
orNestWhere, err := getWhereConditions(orWhere)
253+
orNestWhere, err := getWhereConditions(orWhere, ignoreKeys)
235254
if nil != err {
236255
return nil, err
237256
}

builder/builder_test.go

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,46 @@ func Test_BuildUpdate(t *testing.T) {
447447
err: nil,
448448
},
449449
},
450+
{
451+
in: inStruct{
452+
table: "tb",
453+
where: map[string]interface{}{
454+
"foo": "bar",
455+
"age >=": 23,
456+
"sex in": []interface{}{"male", "female"},
457+
"_limit": 10,
458+
},
459+
setData: map[string]interface{}{
460+
"score": 50,
461+
"district": "010",
462+
},
463+
},
464+
out: outStruct{
465+
cond: "UPDATE tb SET district=?,score=? WHERE (foo=? AND sex IN (?,?) AND age>=?) LIMIT ?",
466+
vals: []interface{}{"010", 50, "bar", "male", "female", 23, 10},
467+
err: nil,
468+
},
469+
},
470+
{
471+
in: inStruct{
472+
table: "tb",
473+
where: map[string]interface{}{
474+
"foo": "bar",
475+
"age >=": 23,
476+
"sex in": []interface{}{"male", "female"},
477+
"_limit": 5.5,
478+
},
479+
setData: map[string]interface{}{
480+
"score": 50,
481+
"district": "010",
482+
},
483+
},
484+
out: outStruct{
485+
cond: "",
486+
vals: nil,
487+
err: errUpdateLimitType,
488+
},
489+
},
450490
}
451491
ass := assert.New(t)
452492
for _, tc := range data {
@@ -1245,13 +1285,13 @@ func TestNotLike_1(t *testing.T) {
12451285
func TestFixBug_insert_quote_field(t *testing.T) {
12461286
cond, vals, err := BuildInsert("tb", []map[string]interface{}{
12471287
{
1248-
"id": 1,
1288+
"id": 1,
12491289
"`order`": 2,
1250-
"`id`": 3, // I know this is forbidden, but just for test
1290+
"`id`": 3, // I know this is forbidden, but just for test
12511291
},
12521292
})
12531293
ass := assert.New(t)
12541294
ass.NoError(err)
12551295
ass.Equal("INSERT INTO tb (`id`,`order`,id) VALUES (?,?,?)", cond)
1256-
ass.Equal([]interface{}{3,2,1}, vals)
1257-
}
1296+
ass.Equal([]interface{}{3, 2, 1}, vals)
1297+
}

builder/dao.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ func buildInsert(table string, setMap []map[string]interface{}, insertType inser
397397
return fmt.Sprintf(format, insertType, quoteField(table), strings.Join(fields, ","), strings.Join(sets, ",")), vals, nil
398398
}
399399

400-
func buildUpdate(table string, update map[string]interface{}, conditions ...Comparable) (string, []interface{}, error) {
400+
func buildUpdate(table string, update map[string]interface{}, limit uint, conditions ...Comparable) (string, []interface{}, error) {
401401
format := "UPDATE %s SET %s"
402402
keys, vals := resolveKV(update)
403403
var sets string
@@ -411,6 +411,10 @@ func buildUpdate(table string, update map[string]interface{}, conditions ...Comp
411411
cond = fmt.Sprintf("%s WHERE %s", cond, whereString)
412412
vals = append(vals, whereVals...)
413413
}
414+
if limit > 0 {
415+
cond += " LIMIT ?"
416+
vals = append(vals, int(limit))
417+
}
414418
return cond, vals, nil
415419
}
416420

builder/dao_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ func TestBuildUpdate(t *testing.T) {
285285
}
286286
ass := assert.New(t)
287287
for _, tc := range data {
288-
cond, vals, err := buildUpdate(tc.table, tc.data, tc.conditions...)
288+
cond, vals, err := buildUpdate(tc.table, tc.data, 0, tc.conditions...)
289289
ass.Equal(tc.outErr, err)
290290
ass.Equal(tc.outStr, cond)
291291
ass.Equal(tc.outVals, vals)

0 commit comments

Comments
 (0)