Skip to content

Commit c680c16

Browse files
Add foreign key support for insert on duplicate key update (vitessio#14638)
Signed-off-by: Harshit Gangal <harshit@planetscale.com>
1 parent 548c7d8 commit c680c16

File tree

9 files changed

+1161
-60
lines changed

9 files changed

+1161
-60
lines changed

go/test/endtoend/vtgate/foreignkey/fk_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,37 @@ func TestReplaceWithFK(t *testing.T) {
12081208
utils.AssertMatches(t, conn, `select * from u_t2`, `[[INT64(1) NULL] [INT64(2) NULL]]`)
12091209
}
12101210

1211+
// TestInsertWithFKOnDup tests that insertion with on duplicate key update works as expected.
1212+
func TestInsertWithFKOnDup(t *testing.T) {
1213+
mcmp, closer := start(t)
1214+
defer closer()
1215+
1216+
utils.Exec(t, mcmp.VtConn, "use `uks`")
1217+
1218+
// insert some data.
1219+
mcmp.Exec(`insert into u_t1(id, col1) values (100, 1), (200, 2), (300, 3), (400, 4)`)
1220+
mcmp.Exec(`insert into u_t2(id, col2) values (1000, 1), (2000, 2), (3000, 3), (4000, 4)`)
1221+
1222+
// updating child to an existing value in parent.
1223+
mcmp.Exec(`insert into u_t2(id, col2) values (4000, 50) on duplicate key update col2 = 1`)
1224+
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) INT64(1)] [INT64(2000) INT64(2)] [INT64(3000) INT64(3)] [INT64(4000) INT64(1)]]`)
1225+
1226+
// updating parent, value not referred in child.
1227+
mcmp.Exec(`insert into u_t1(id, col1) values (400, 50) on duplicate key update col1 = values(col1)`)
1228+
mcmp.AssertMatches(`select * from u_t1 order by id`, `[[INT64(100) INT64(1)] [INT64(200) INT64(2)] [INT64(300) INT64(3)] [INT64(400) INT64(50)]]`)
1229+
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) INT64(1)] [INT64(2000) INT64(2)] [INT64(3000) INT64(3)] [INT64(4000) INT64(1)]]`)
1230+
1231+
// updating parent, child updated to null.
1232+
mcmp.Exec(`insert into u_t1(id, col1) values (100, 75) on duplicate key update col1 = values(col1)`)
1233+
mcmp.AssertMatches(`select * from u_t1 order by id`, `[[INT64(100) INT64(75)] [INT64(200) INT64(2)] [INT64(300) INT64(3)] [INT64(400) INT64(50)]]`)
1234+
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) NULL] [INT64(2000) INT64(2)] [INT64(3000) INT64(3)] [INT64(4000) NULL]]`)
1235+
1236+
// inserting multiple rows in parent, some child rows updated to null.
1237+
mcmp.Exec(`insert into u_t1(id, col1) values (100, 42),(600, 2),(300, 24),(200, 2) on duplicate key update col1 = values(col1)`)
1238+
mcmp.AssertMatches(`select * from u_t1 order by id`, `[[INT64(100) INT64(42)] [INT64(200) INT64(2)] [INT64(300) INT64(24)] [INT64(400) INT64(50)] [INT64(600) INT64(2)]]`)
1239+
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) NULL] [INT64(2000) INT64(2)] [INT64(3000) NULL] [INT64(4000) NULL]]`)
1240+
}
1241+
12111242
// TestDDLFk tests that table is created with fk constraint when foreign_key_checks is off.
12121243
func TestDDLFk(t *testing.T) {
12131244
mcmp, closer := start(t)

go/vt/vtgate/engine/cached_size.go

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

go/vt/vtgate/engine/upsert.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
Copyright 2023 The Vitess Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package engine
18+
19+
import (
20+
"context"
21+
"fmt"
22+
23+
"vitess.io/vitess/go/sqltypes"
24+
querypb "vitess.io/vitess/go/vt/proto/query"
25+
topodatapb "vitess.io/vitess/go/vt/proto/topodata"
26+
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
27+
"vitess.io/vitess/go/vt/vterrors"
28+
)
29+
30+
var _ Primitive = (*Upsert)(nil)
31+
32+
// Upsert Primitive will execute the insert primitive first and
33+
// if there is `Duplicate Key` error, it executes the update primitive.
34+
type Upsert struct {
35+
Upserts []upsert
36+
37+
txNeeded
38+
}
39+
40+
type upsert struct {
41+
Insert Primitive
42+
Update Primitive
43+
}
44+
45+
// AddUpsert appends to the Upsert Primitive.
46+
func (u *Upsert) AddUpsert(ins, upd Primitive) {
47+
u.Upserts = append(u.Upserts, upsert{
48+
Insert: ins,
49+
Update: upd,
50+
})
51+
}
52+
53+
// RouteType implements Primitive interface type.
54+
func (u *Upsert) RouteType() string {
55+
return "UPSERT"
56+
}
57+
58+
// GetKeyspaceName implements Primitive interface type.
59+
func (u *Upsert) GetKeyspaceName() string {
60+
if len(u.Upserts) > 0 {
61+
return u.Upserts[0].Insert.GetKeyspaceName()
62+
}
63+
return ""
64+
}
65+
66+
// GetTableName implements Primitive interface type.
67+
func (u *Upsert) GetTableName() string {
68+
if len(u.Upserts) > 0 {
69+
return u.Upserts[0].Insert.GetTableName()
70+
}
71+
return ""
72+
}
73+
74+
// GetFields implements Primitive interface type.
75+
func (u *Upsert) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
76+
return nil, vterrors.VT13001("unexpected to receive GetFields call for insert on duplicate key update query")
77+
}
78+
79+
// TryExecute implements Primitive interface type.
80+
func (u *Upsert) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
81+
result := &sqltypes.Result{}
82+
for _, up := range u.Upserts {
83+
qr, err := execOne(ctx, vcursor, bindVars, wantfields, up)
84+
if err != nil {
85+
return nil, err
86+
}
87+
result.RowsAffected += qr.RowsAffected
88+
}
89+
return result, nil
90+
}
91+
92+
func execOne(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, up upsert) (*sqltypes.Result, error) {
93+
insQr, err := vcursor.ExecutePrimitive(ctx, up.Insert, bindVars, wantfields)
94+
if err == nil {
95+
return insQr, nil
96+
}
97+
if vterrors.Code(err) != vtrpcpb.Code_ALREADY_EXISTS {
98+
return nil, err
99+
}
100+
updQr, err := vcursor.ExecutePrimitive(ctx, up.Update, bindVars, wantfields)
101+
if err != nil {
102+
return nil, err
103+
}
104+
// To match mysql, need to report +1 on rows affected if there is any change.
105+
if updQr.RowsAffected > 0 {
106+
updQr.RowsAffected += 1
107+
}
108+
return updQr, nil
109+
}
110+
111+
// TryStreamExecute implements Primitive interface type.
112+
func (u *Upsert) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
113+
qr, err := u.TryExecute(ctx, vcursor, bindVars, wantfields)
114+
if err != nil {
115+
return err
116+
}
117+
return callback(qr)
118+
}
119+
120+
// Inputs implements Primitive interface type.
121+
func (u *Upsert) Inputs() ([]Primitive, []map[string]any) {
122+
var inputs []Primitive
123+
var inputsMap []map[string]any
124+
for i, up := range u.Upserts {
125+
inputs = append(inputs, up.Insert, up.Update)
126+
inputsMap = append(inputsMap,
127+
map[string]any{inputName: fmt.Sprintf("Insert-%d", i+1)},
128+
map[string]any{inputName: fmt.Sprintf("Update-%d", i+1)})
129+
}
130+
return inputs, inputsMap
131+
}
132+
133+
func (u *Upsert) description() PrimitiveDescription {
134+
return PrimitiveDescription{
135+
OperatorType: "Upsert",
136+
TargetTabletType: topodatapb.TabletType_PRIMARY,
137+
}
138+
}

go/vt/vtgate/planbuilder/operator_transformers.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ func transformToLogicalPlan(ctx *plancontext.PlanningContext, op operators.Opera
6666
return transformFkVerify(ctx, op)
6767
case *operators.InsertSelection:
6868
return transformInsertionSelection(ctx, op)
69+
case *operators.Upsert:
70+
return transformUpsert(ctx, op)
6971
case *operators.HashJoin:
7072
return transformHashJoin(ctx, op)
7173
case *operators.Sequential:
@@ -75,6 +77,31 @@ func transformToLogicalPlan(ctx *plancontext.PlanningContext, op operators.Opera
7577
return nil, vterrors.VT13001(fmt.Sprintf("unknown type encountered: %T (transformToLogicalPlan)", op))
7678
}
7779

80+
func transformUpsert(ctx *plancontext.PlanningContext, op *operators.Upsert) (logicalPlan, error) {
81+
u := &upsert{}
82+
for _, source := range op.Sources {
83+
iLp, uLp, err := transformOneUpsert(ctx, source)
84+
if err != nil {
85+
return nil, err
86+
}
87+
u.insert = append(u.insert, iLp)
88+
u.update = append(u.update, uLp)
89+
}
90+
return u, nil
91+
}
92+
93+
func transformOneUpsert(ctx *plancontext.PlanningContext, source operators.UpsertSource) (iLp, uLp logicalPlan, err error) {
94+
iLp, err = transformToLogicalPlan(ctx, source.Insert)
95+
if err != nil {
96+
return
97+
}
98+
if ins, ok := iLp.(*insert); ok {
99+
ins.eInsert.PreventAutoCommit = true
100+
}
101+
uLp, err = transformToLogicalPlan(ctx, source.Update)
102+
return
103+
}
104+
78105
func transformSequential(ctx *plancontext.PlanningContext, op *operators.Sequential) (logicalPlan, error) {
79106
var lps []logicalPlan
80107
for _, source := range op.Sources {

0 commit comments

Comments
 (0)