Skip to content

Commit deea764

Browse files
bgdnxtgfk
and
gfk
authoredDec 6, 2022
feat: mssql and pg merge query (uptrace#723)
* chore: support MergeQuery for MSSQL and PostgreSQL Co-authored-by: gfk <gfk@bb.io>
1 parent 996fead commit deea764

19 files changed

+424
-3
lines changed
 

‎db.go

+12
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ func (db *DB) NewValues(model interface{}) *ValuesQuery {
8282
return NewValuesQuery(db, model)
8383
}
8484

85+
func (db *DB) NewMerge() *MergeQuery {
86+
return NewMergeQuery(db)
87+
}
88+
8589
func (db *DB) NewSelect() *SelectQuery {
8690
return NewSelectQuery(db)
8791
}
@@ -330,6 +334,10 @@ func (c Conn) NewValues(model interface{}) *ValuesQuery {
330334
return NewValuesQuery(c.db, model).Conn(c)
331335
}
332336

337+
func (c Conn) NewMerge() *MergeQuery {
338+
return NewMergeQuery(c.db).Conn(c)
339+
}
340+
333341
func (c Conn) NewSelect() *SelectQuery {
334342
return NewSelectQuery(c.db).Conn(c)
335343
}
@@ -640,6 +648,10 @@ func (tx Tx) NewValues(model interface{}) *ValuesQuery {
640648
return NewValuesQuery(tx.db, model).Conn(tx)
641649
}
642650

651+
func (tx Tx) NewMerge() *MergeQuery {
652+
return NewMergeQuery(tx.db).Conn(tx)
653+
}
654+
643655
func (tx Tx) NewSelect() *SelectQuery {
644656
return NewSelectQuery(tx.db).Conn(tx)
645657
}

‎internal/dbtest/mssql_test.go

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package dbtest_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func TestMssqlMerge(t *testing.T) {
10+
db := mssql2019(t)
11+
defer db.Close()
12+
13+
type Model struct {
14+
ID int64 `bun:",pk,autoincrement"`
15+
16+
Name string
17+
Value string
18+
}
19+
20+
err := db.ResetModel(ctx, (*Model)(nil))
21+
require.NoError(t, err)
22+
23+
_, err = db.NewInsert().Model(&Model{Name: "A", Value: "hello"}).Exec(ctx)
24+
require.NoError(t, err)
25+
26+
newModels := []*Model{
27+
{
28+
Name: "A",
29+
Value: "world",
30+
},
31+
{
32+
Name: "B",
33+
Value: "test",
34+
},
35+
}
36+
37+
changes := []string{}
38+
_, err = db.NewMerge().
39+
Model(&Model{}).
40+
With("_data", db.NewValues(&newModels)).
41+
Using("_data").
42+
On("?TableAlias.name = _data.name").
43+
When("MATCHED THEN UPDATE SET ?TableAlias.value = _data.value").
44+
When("NOT MATCHED THEN INSERT (name, value) VALUES (_data.name, _data.value)").
45+
Returning("$action").
46+
Exec(ctx, &changes)
47+
require.NoError(t, err)
48+
49+
require.Len(t, changes, 2)
50+
require.Equal(t, "UPDATE", changes[0])
51+
require.Equal(t, "INSERT", changes[1])
52+
53+
}

‎internal/dbtest/query_test.go

+46
Original file line numberDiff line numberDiff line change
@@ -923,6 +923,52 @@ func TestQuery(t *testing.T) {
923923
}
924924
return db.NewSelect().Model(new(Model)).Relation("SoftDelete")
925925
},
926+
func(db *bun.DB) schema.QueryAppender {
927+
type Model struct {
928+
ID int64 `bun:",pk,autoincrement"`
929+
Name string
930+
Value string
931+
}
932+
933+
newModels := []*Model{
934+
{Name: "A", Value: "world"},
935+
{Name: "B", Value: "test"},
936+
}
937+
938+
return db.NewMerge().
939+
Model(new(Model)).
940+
With("_data", db.NewValues(&newModels)).
941+
Using("_data").
942+
On("?TableAlias.name = _data.name").
943+
WhenUpdate("MATCHED", func(q *bun.UpdateQuery) *bun.UpdateQuery {
944+
return q.Set("value = _data.value")
945+
}).
946+
WhenInsert("NOT MATCHED", func(q *bun.InsertQuery) *bun.InsertQuery {
947+
return q.Value("name", "_data.name").Value("value", "_data.value")
948+
}).
949+
Returning("$action")
950+
},
951+
func(db *bun.DB) schema.QueryAppender {
952+
type Model struct {
953+
ID int64 `bun:",pk,autoincrement"`
954+
Name string
955+
Value string
956+
}
957+
958+
newModels := []*Model{
959+
{Name: "A", Value: "world"},
960+
{Name: "B", Value: "test"},
961+
}
962+
963+
return db.NewMerge().
964+
Model(new(Model)).
965+
With("_data", db.NewValues(&newModels)).
966+
Using("_data").
967+
On("?TableAlias.name = _data.name").
968+
WhenDelete("MATCHED").
969+
When("NOT MATCHED THEN INSERT (name, value) VALUES (_data.name, _data.value)").
970+
Returning("$action")
971+
},
926972
}
927973

928974
timeRE := regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bun: merge not supported for current dialect
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bun: merge not supported for current dialect
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
WITH "_data" AS (SELECT * FROM (VALUES (NULL, 'A', 'world'), (NULL, 'B', 'test')) AS t ("id", "name", "value")) MERGE "models" AS "model" USING _data ON "model".name = _data.name WHEN MATCHED THEN UPDATE SET value = _data.value WHEN NOT MATCHED THEN INSERT ("name", "value") VALUES (_data.name, _data.value) OUTPUT $action;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
WITH "_data" AS (SELECT * FROM (VALUES (NULL, 'A', 'world'), (NULL, 'B', 'test')) AS t ("id", "name", "value")) MERGE "models" AS "model" USING _data ON "model".name = _data.name WHEN MATCHED THEN DELETE WHEN NOT MATCHED THEN INSERT (name, value) VALUES (_data.name, _data.value) OUTPUT $action;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bun: merge not supported for current dialect
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bun: merge not supported for current dialect
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bun: merge not supported for current dialect
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bun: merge not supported for current dialect
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
WITH "_data" ("id", "name", "value") AS (VALUES (NULL::BIGINT, 'A'::VARCHAR, 'world'::VARCHAR), (NULL::BIGINT, 'B'::VARCHAR, 'test'::VARCHAR)) MERGE INTO "models" AS "model" USING _data ON "model".name = _data.name WHEN MATCHED THEN UPDATE SET value = _data.value WHEN NOT MATCHED THEN INSERT ("id", "name", "value") VALUES (DEFAULT, _data.name, _data.value);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
WITH "_data" ("id", "name", "value") AS (VALUES (NULL::BIGINT, 'A'::VARCHAR, 'world'::VARCHAR), (NULL::BIGINT, 'B'::VARCHAR, 'test'::VARCHAR)) MERGE INTO "models" AS "model" USING _data ON "model".name = _data.name WHEN MATCHED THEN DELETE WHEN NOT MATCHED THEN INSERT (name, value) VALUES (_data.name, _data.value);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
WITH "_data" ("id", "name", "value") AS (VALUES (NULL::BIGINT, 'A'::VARCHAR, 'world'::VARCHAR), (NULL::BIGINT, 'B'::VARCHAR, 'test'::VARCHAR)) MERGE INTO "models" AS "model" USING _data ON "model".name = _data.name WHEN MATCHED THEN UPDATE SET value = _data.value WHEN NOT MATCHED THEN INSERT ("id", "name", "value") VALUES (DEFAULT, _data.name, _data.value);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
WITH "_data" ("id", "name", "value") AS (VALUES (NULL::BIGINT, 'A'::VARCHAR, 'world'::VARCHAR), (NULL::BIGINT, 'B'::VARCHAR, 'test'::VARCHAR)) MERGE INTO "models" AS "model" USING _data ON "model".name = _data.name WHEN MATCHED THEN DELETE WHEN NOT MATCHED THEN INSERT (name, value) VALUES (_data.name, _data.value);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bun: merge not supported for current dialect
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bun: merge not supported for current dialect

‎query_insert.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
192192
return nil, err
193193
}
194194

195-
b, err = q.appendColumnsValues(fmter, b)
195+
b, err = q.appendColumnsValues(fmter, b, false)
196196
if err != nil {
197197
return nil, err
198198
}
@@ -214,7 +214,7 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
214214
}
215215

216216
func (q *InsertQuery) appendColumnsValues(
217-
fmter schema.Formatter, b []byte,
217+
fmter schema.Formatter, b []byte, skipOutput bool,
218218
) (_ []byte, err error) {
219219
if q.hasMultiTables() {
220220
if q.columns != nil {
@@ -275,7 +275,7 @@ func (q *InsertQuery) appendColumnsValues(
275275
b = q.appendFields(fmter, b, fields)
276276
b = append(b, ")"...)
277277

278-
if q.hasFeature(feature.Output) && q.hasReturning() {
278+
if q.hasFeature(feature.Output) && q.hasReturning() && !skipOutput {
279279
b = append(b, " OUTPUT "...)
280280
b, err = q.appendOutput(fmter, b)
281281
if err != nil {

0 commit comments

Comments
 (0)