From e802cffff2653e3615f63f2e8622bf9dabceca5d Mon Sep 17 00:00:00 2001 From: Manan Gupta <35839558+GuptaManan100@users.noreply.github.com> Date: Wed, 1 Nov 2023 22:17:35 +0530 Subject: [PATCH] Add cycle detection for foreign keys (#14339) Signed-off-by: Manan Gupta --- go/test/endtoend/utils/utils.go | 27 ++++ go/test/endtoend/vtgate/foreignkey/fk_test.go | 23 +++ go/test/vschemawrapper/vschema_wrapper.go | 4 + go/vt/graph/graph.go | 119 ++++++++++++++ go/vt/graph/graph_test.go | 153 ++++++++++++++++++ go/vt/schemadiff/semantics.go | 4 + go/vt/vterrors/code.go | 3 +- .../vtgate/planbuilder/plancontext/vschema.go | 3 + .../testdata/vindex_func_cases.json | 18 +-- go/vt/vtgate/semantics/FakeSI.go | 12 ++ go/vt/vtgate/semantics/analyzer.go | 5 + go/vt/vtgate/semantics/analyzer_test.go | 22 +++ go/vt/vtgate/semantics/info_schema.go | 4 + go/vt/vtgate/semantics/semantic_state.go | 1 + go/vt/vtgate/vcursor_impl.go | 8 + go/vt/vtgate/vschema_manager.go | 46 ++++++ go/vt/vtgate/vschema_manager_test.go | 104 +++++++++++- 17 files changed, 544 insertions(+), 12 deletions(-) create mode 100644 go/vt/graph/graph.go create mode 100644 go/vt/graph/graph_test.go diff --git a/go/test/endtoend/utils/utils.go b/go/test/endtoend/utils/utils.go index aae791c3064..fa270ba30a0 100644 --- a/go/test/endtoend/utils/utils.go +++ b/go/test/endtoend/utils/utils.go @@ -253,6 +253,33 @@ func WaitForAuthoritative(t *testing.T, ks, tbl string, readVSchema func() (*int } } +// WaitForKsError waits for the ks error field to be populated and returns it. +func WaitForKsError(t *testing.T, vtgateProcess cluster.VtgateProcess, ks string) string { + timeout := time.After(60 * time.Second) + for { + select { + case <-timeout: + t.Fatalf("schema tracking did not find error in '%s'", ks) + return "" + default: + time.Sleep(1 * time.Second) + res, err := vtgateProcess.ReadVSchema() + require.NoError(t, err, res) + kss := convertToMap(*res)["keyspaces"] + ksMap := convertToMap(convertToMap(kss)[ks]) + ksErr, fieldPresent := ksMap["error"] + if !fieldPresent { + break + } + errString, isErr := ksErr.(string) + if !isErr { + break + } + return errString + } + } +} + // WaitForColumn waits for a table's column to be present func WaitForColumn(t *testing.T, vtgateProcess cluster.VtgateProcess, ks, tbl, col string) error { timeout := time.After(60 * time.Second) diff --git a/go/test/endtoend/vtgate/foreignkey/fk_test.go b/go/test/endtoend/vtgate/foreignkey/fk_test.go index c3be526e584..46bbc2ed433 100644 --- a/go/test/endtoend/vtgate/foreignkey/fk_test.go +++ b/go/test/endtoend/vtgate/foreignkey/fk_test.go @@ -774,3 +774,26 @@ func TestFkScenarios(t *testing.T) { }) } } + +func TestCyclicFks(t *testing.T) { + mcmp, closer := start(t) + defer closer() + _ = utils.Exec(t, mcmp.VtConn, "use `uks`") + + // Create a cyclic foreign key constraint. + utils.Exec(t, mcmp.VtConn, "alter table fk_t10 add constraint test_cyclic_fks foreign key (col) references fk_t12 (col) on delete cascade on update cascade") + defer func() { + utils.Exec(t, mcmp.VtConn, "alter table fk_t10 drop foreign key test_cyclic_fks") + }() + + // Wait for schema-tracking to be complete. + ksErr := utils.WaitForKsError(t, clusterInstance.VtgateProcess, unshardedKs) + // Make sure Vschema has the error for cyclic foreign keys. + assert.Contains(t, ksErr, "VT09019: uks has cyclic foreign keys") + + // Ensure that the Vitess database is originally empty + ensureDatabaseState(t, mcmp.VtConn, true) + + _, err := utils.ExecAllowError(t, mcmp.VtConn, "insert into fk_t10(id, col) values (1, 1)") + require.ErrorContains(t, err, "VT09019: uks has cyclic foreign keys") +} diff --git a/go/test/vschemawrapper/vschema_wrapper.go b/go/test/vschemawrapper/vschema_wrapper.go index 1656fafa41b..78cf0f8d41c 100644 --- a/go/test/vschemawrapper/vschema_wrapper.go +++ b/go/test/vschemawrapper/vschema_wrapper.go @@ -136,6 +136,10 @@ func (vw *VSchemaWrapper) ForeignKeyMode(keyspace string) (vschemapb.Keyspace_Fo return defaultFkMode, nil } +func (vw *VSchemaWrapper) KeyspaceError(keyspace string) error { + return nil +} + func (vw *VSchemaWrapper) AllKeyspace() ([]*vindexes.Keyspace, error) { if vw.Keyspace == nil { return nil, vterrors.VT13001("keyspace not available") diff --git a/go/vt/graph/graph.go b/go/vt/graph/graph.go new file mode 100644 index 00000000000..54668027008 --- /dev/null +++ b/go/vt/graph/graph.go @@ -0,0 +1,119 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package graph + +import ( + "fmt" + "slices" + "strings" +) + +// Graph is a generic graph implementation. +type Graph[C comparable] struct { + edges map[C][]C +} + +// NewGraph creates a new graph for the given comparable type. +func NewGraph[C comparable]() *Graph[C] { + return &Graph[C]{ + edges: map[C][]C{}, + } +} + +// AddVertex adds a vertex to the given Graph. +func (gr *Graph[C]) AddVertex(vertex C) { + _, alreadyExists := gr.edges[vertex] + if alreadyExists { + return + } + gr.edges[vertex] = []C{} +} + +// AddEdge adds an edge to the given Graph. +// It also makes sure that the vertices are added to the graph if not already present. +func (gr *Graph[C]) AddEdge(start C, end C) { + gr.AddVertex(start) + gr.AddVertex(end) + gr.edges[start] = append(gr.edges[start], end) +} + +// PrintGraph is used to print the graph. This is only used for testing. +func (gr *Graph[C]) PrintGraph() string { + adjacencyLists := []string{} + for vertex, edges := range gr.edges { + adjacencyList := fmt.Sprintf("%v -", vertex) + for _, end := range edges { + adjacencyList += fmt.Sprintf(" %v", end) + } + adjacencyLists = append(adjacencyLists, adjacencyList) + } + slices.Sort(adjacencyLists) + return strings.Join(adjacencyLists, "\n") +} + +// Empty checks whether the graph is empty or not. +func (gr *Graph[C]) Empty() bool { + return len(gr.edges) == 0 +} + +// HasCycles checks whether the given graph has a cycle or not. +// We are using a well-known DFS based colouring algorithm to check for cycles. +// Look at https://cp-algorithms.com/graph/finding-cycle.html for more details on the algorithm. +func (gr *Graph[C]) HasCycles() bool { + // If the graph is empty, then we don't need to check anything. + if gr.Empty() { + return false + } + // Initialize the coloring map. + // 0 represents white. + // 1 represents grey. + // 2 represents black. + color := map[C]int{} + for vertex := range gr.edges { + // If any vertex is still white, we initiate a new DFS. + if color[vertex] == 0 { + if gr.hasCyclesDfs(color, vertex) { + return true + } + } + } + return false +} + +// hasCyclesDfs is a utility function for checking for cycles in a graph. +// It runs a dfs from the given vertex marking each vertex as grey. During the dfs, +// if we encounter a grey vertex, we know we have a cycle. We mark the visited vertices black +// on finishing the dfs. +func (gr *Graph[C]) hasCyclesDfs(color map[C]int, vertex C) bool { + // Mark the vertex grey. + color[vertex] = 1 + // Go over all the edges. + for _, end := range gr.edges[vertex] { + // If we encounter a white vertex, we continue the dfs. + if color[end] == 0 { + if gr.hasCyclesDfs(color, end) { + return true + } + } else if color[end] == 1 { + // We encountered a grey vertex, we have a cycle. + return true + } + } + // Mark the vertex black before finishing + color[vertex] = 2 + return false +} diff --git a/go/vt/graph/graph_test.go b/go/vt/graph/graph_test.go new file mode 100644 index 00000000000..bc334c7d225 --- /dev/null +++ b/go/vt/graph/graph_test.go @@ -0,0 +1,153 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package graph + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestIntegerGraph tests that a graph with Integers can be created and all graph functions work as intended. +func TestIntegerGraph(t *testing.T) { + testcases := []struct { + name string + edges [][2]int + wantedGraph string + wantEmpty bool + wantHasCycles bool + }{ + { + name: "empty graph", + edges: nil, + wantedGraph: "", + wantEmpty: true, + wantHasCycles: false, + }, { + name: "non-cyclic graph", + edges: [][2]int{ + {1, 2}, + {2, 3}, + {1, 4}, + {2, 5}, + {4, 5}, + }, + wantedGraph: `1 - 2 4 +2 - 3 5 +3 - +4 - 5 +5 -`, + wantEmpty: false, + wantHasCycles: false, + }, { + name: "cyclic graph", + edges: [][2]int{ + {1, 2}, + {2, 3}, + {1, 4}, + {2, 5}, + {4, 5}, + {5, 6}, + {6, 1}, + }, + wantedGraph: `1 - 2 4 +2 - 3 5 +3 - +4 - 5 +5 - 6 +6 - 1`, + wantEmpty: false, + wantHasCycles: true, + }, + } + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + graph := NewGraph[int]() + for _, edge := range tt.edges { + graph.AddEdge(edge[0], edge[1]) + } + require.Equal(t, tt.wantedGraph, graph.PrintGraph()) + require.Equal(t, tt.wantEmpty, graph.Empty()) + require.Equal(t, tt.wantHasCycles, graph.HasCycles()) + }) + } +} + +// TestStringGraph tests that a graph with strings can be created and all graph functions work as intended. +func TestStringGraph(t *testing.T) { + testcases := []struct { + name string + edges [][2]string + wantedGraph string + wantEmpty bool + wantHasCycles bool + }{ + { + name: "empty graph", + edges: nil, + wantedGraph: "", + wantEmpty: true, + wantHasCycles: false, + }, { + name: "non-cyclic graph", + edges: [][2]string{ + {"A", "B"}, + {"B", "C"}, + {"A", "D"}, + {"B", "E"}, + {"D", "E"}, + }, + wantedGraph: `A - B D +B - C E +C - +D - E +E -`, + wantEmpty: false, + wantHasCycles: false, + }, { + name: "cyclic graph", + edges: [][2]string{ + {"A", "B"}, + {"B", "C"}, + {"A", "D"}, + {"B", "E"}, + {"D", "E"}, + {"E", "F"}, + {"F", "A"}, + }, + wantedGraph: `A - B D +B - C E +C - +D - E +E - F +F - A`, + wantEmpty: false, + wantHasCycles: true, + }, + } + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + graph := NewGraph[string]() + for _, edge := range tt.edges { + graph.AddEdge(edge[0], edge[1]) + } + require.Equal(t, tt.wantedGraph, graph.PrintGraph()) + require.Equal(t, tt.wantEmpty, graph.Empty()) + require.Equal(t, tt.wantHasCycles, graph.HasCycles()) + }) + } +} diff --git a/go/vt/schemadiff/semantics.go b/go/vt/schemadiff/semantics.go index 1a9acfecad9..da9c6b1e2a9 100644 --- a/go/vt/schemadiff/semantics.go +++ b/go/vt/schemadiff/semantics.go @@ -60,6 +60,10 @@ func (si *declarativeSchemaInformation) ForeignKeyMode(keyspace string) (vschema return vschemapb.Keyspace_unmanaged, nil } +func (si *declarativeSchemaInformation) KeyspaceError(keyspace string) error { + return nil +} + // addTable adds a fake table with an empty column list func (si *declarativeSchemaInformation) addTable(tableName string) { tbl := &vindexes.Table{ diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index 9bfb7747c09..6bc317db4ed 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -79,7 +79,8 @@ var ( VT09015 = errorWithoutState("VT09015", vtrpcpb.Code_FAILED_PRECONDITION, "schema tracking required", "This query cannot be planned without more information on the SQL schema. Please turn on schema tracking or add authoritative columns information to your VSchema.") VT09016 = errorWithState("VT09016", vtrpcpb.Code_FAILED_PRECONDITION, RowIsReferenced2, "Cannot delete or update a parent row: a foreign key constraint fails", "SET DEFAULT is not supported by InnoDB") VT09017 = errorWithoutState("VT09017", vtrpcpb.Code_FAILED_PRECONDITION, "%s", "Invalid syntax for the statement type.") - VT09018 = errorWithoutState("VT09017", vtrpcpb.Code_FAILED_PRECONDITION, "%s", "Invalid syntax for the vindex function statement.") + VT09018 = errorWithoutState("VT09018", vtrpcpb.Code_FAILED_PRECONDITION, "%s", "Invalid syntax for the vindex function statement.") + VT09019 = errorWithoutState("VT09019", vtrpcpb.Code_FAILED_PRECONDITION, "%s has cyclic foreign keys", "Vitess doesn't support cyclic foreign keys.") VT10001 = errorWithoutState("VT10001", vtrpcpb.Code_ABORTED, "foreign key constraints are not allowed", "Foreign key constraints are not allowed, see https://vitess.io/blog/2021-06-15-online-ddl-why-no-fk/.") diff --git a/go/vt/vtgate/planbuilder/plancontext/vschema.go b/go/vt/vtgate/planbuilder/plancontext/vschema.go index fc5ee6d9207..9dde6dee31c 100644 --- a/go/vt/vtgate/planbuilder/plancontext/vschema.go +++ b/go/vt/vtgate/planbuilder/plancontext/vschema.go @@ -57,6 +57,9 @@ type VSchema interface { // ForeignKeyMode returns the foreign_key flag value ForeignKeyMode(keyspace string) (vschemapb.Keyspace_ForeignKeyMode, error) + // KeyspaceError returns any error in the keyspace vschema. + KeyspaceError(keyspace string) error + // GetVSchema returns the latest cached vindexes.VSchema GetVSchema() *vindexes.VSchema diff --git a/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json b/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json index 039786362a1..4c6256d93cc 100644 --- a/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json @@ -426,46 +426,46 @@ { "comment": "select keyspace_id from user_index where id = 1 and id = 2", "query": "select keyspace_id from user_index where id = 1 and id = 2", - "plan": "VT09017: WHERE clause for vindex function must be of the form id = or id in(,...) (multiple filters)" + "plan": "VT09018: WHERE clause for vindex function must be of the form id = or id in(,...) (multiple filters)" }, { "comment": "select keyspace_id from user_index where func(id)", "query": "select keyspace_id from user_index where func(id)", - "plan": "VT09017: WHERE clause for vindex function must be of the form id = or id in(,...) (not a comparison)" + "plan": "VT09018: WHERE clause for vindex function must be of the form id = or id in(,...) (not a comparison)" }, { "comment": "select keyspace_id from user_index where id > 1", "query": "select keyspace_id from user_index where id > 1", - "plan": "VT09017: WHERE clause for vindex function must be of the form id = or id in(,...) (not equality)" + "plan": "VT09018: WHERE clause for vindex function must be of the form id = or id in(,...) (not equality)" }, { "comment": "select keyspace_id from user_index where 1 = id", "query": "select keyspace_id from user_index where 1 = id", - "plan": "VT09017: WHERE clause for vindex function must be of the form id = or id in(,...) (lhs is not a column)" + "plan": "VT09018: WHERE clause for vindex function must be of the form id = or id in(,...) (lhs is not a column)" }, { "comment": "select keyspace_id from user_index where keyspace_id = 1", "query": "select keyspace_id from user_index where keyspace_id = 1", - "plan": "VT09017: WHERE clause for vindex function must be of the form id = or id in(,...) (lhs is not id)" + "plan": "VT09018: WHERE clause for vindex function must be of the form id = or id in(,...) (lhs is not id)" }, { "comment": "select keyspace_id from user_index where id = id+1", "query": "select keyspace_id from user_index where id = id+1", - "plan": "VT09017: WHERE clause for vindex function must be of the form id = or id in(,...) (rhs is not a value)" + "plan": "VT09018: WHERE clause for vindex function must be of the form id = or id in(,...) (rhs is not a value)" }, { "comment": "vindex func without where condition", "query": "select keyspace_id from user_index", - "plan": "VT09017: WHERE clause for vindex function must be of the form id = or id in(,...) (where clause missing)" + "plan": "VT09018: WHERE clause for vindex function must be of the form id = or id in(,...) (where clause missing)" }, { "comment": "vindex func in subquery without where", "query": "select id from user where exists(select keyspace_id from user_index)", - "plan": "VT09017: WHERE clause for vindex function must be of the form id = or id in(,...) (where clause missing)" + "plan": "VT09018: WHERE clause for vindex function must be of the form id = or id in(,...) (where clause missing)" }, { "comment": "select func(keyspace_id) from user_index where id = :id", "query": "select func(keyspace_id) from user_index where id = :id", - "plan": "VT09017: cannot add 'func(keyspace_id)' expression to a table/vindex" + "plan": "VT09018: cannot add 'func(keyspace_id)' expression to a table/vindex" } ] diff --git a/go/vt/vtgate/semantics/FakeSI.go b/go/vt/vtgate/semantics/FakeSI.go index eaa8291e342..b7043b42980 100644 --- a/go/vt/vtgate/semantics/FakeSI.go +++ b/go/vt/vtgate/semantics/FakeSI.go @@ -34,6 +34,7 @@ type FakeSI struct { Tables map[string]*vindexes.Table VindexTables map[string]vindexes.Vindex KsForeignKeyMode map[string]vschemapb.Keyspace_ForeignKeyMode + KsError map[string]error } // FindTableOrVindex implements the SchemaInformation interface @@ -59,3 +60,14 @@ func (s *FakeSI) ForeignKeyMode(keyspace string) (vschemapb.Keyspace_ForeignKeyM } return vschemapb.Keyspace_unmanaged, nil } + +func (s *FakeSI) KeyspaceError(keyspace string) error { + if s.KsError != nil { + fkErr, isPresent := s.KsError[keyspace] + if !isPresent { + return fmt.Errorf("%v keyspace not found", keyspace) + } + return fkErr + } + return nil +} diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index ba3344d70c0..e524b1a33cf 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -469,6 +469,11 @@ func (a *analyzer) getAllManagedForeignKeys() (map[TableSet][]vindexes.ChildFKIn if fkMode != vschemapb.Keyspace_managed { continue } + // Cyclic foreign key constraints error is stored in the keyspace. + ksErr := a.tables.si.KeyspaceError(vi.Keyspace.Name) + if ksErr != nil { + return nil, nil, ksErr + } // Add all the child and parent foreign keys to our map. ts := SingleTableSet(idx) diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index f1227bfe7b0..c8251dd36c3 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -17,6 +17,7 @@ limitations under the License. package semantics import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -1716,6 +1717,27 @@ func TestGetAllManagedForeignKeys(t *testing.T) { }, expectedErr: "undefined_ks keyspace not found", }, + { + name: "Cyclic fk constraints error", + analyzer: &analyzer{ + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t0"], tbl["t1"], + &DerivedTable{}, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, + }, + KsError: map[string]error{ + "ks": fmt.Errorf("VT09019: ks has cyclic foreign keys"), + }, + }, + }, + }, + expectedErr: "VT09019: ks has cyclic foreign keys", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/go/vt/vtgate/semantics/info_schema.go b/go/vt/vtgate/semantics/info_schema.go index f8c8f711901..838f6276472 100644 --- a/go/vt/vtgate/semantics/info_schema.go +++ b/go/vt/vtgate/semantics/info_schema.go @@ -1717,3 +1717,7 @@ func (i *infoSchemaWithColumns) ConnCollation() collations.ID { func (i *infoSchemaWithColumns) ForeignKeyMode(keyspace string) (vschemapb.Keyspace_ForeignKeyMode, error) { return i.inner.ForeignKeyMode(keyspace) } + +func (i *infoSchemaWithColumns) KeyspaceError(keyspace string) error { + return i.inner.KeyspaceError(keyspace) +} diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index fdbf2f0e04d..94b1302b357 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -150,6 +150,7 @@ type ( ConnCollation() collations.ID // ForeignKeyMode returns the foreign_key flag value ForeignKeyMode(keyspace string) (vschemapb.Keyspace_ForeignKeyMode, error) + KeyspaceError(keyspace string) error } shortCut = int diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index d81fc3e9c9c..0e89d6fbc95 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -1047,6 +1047,14 @@ func (vc *vcursorImpl) ForeignKeyMode(keyspace string) (vschemapb.Keyspace_Forei return ks.ForeignKeyMode, nil } +func (vc *vcursorImpl) KeyspaceError(keyspace string) error { + ks := vc.vschema.Keyspaces[keyspace] + if ks == nil { + return vterrors.VT14004(keyspace) + } + return ks.Error +} + // ParseDestinationTarget parses destination target string and sets default keyspace if possible. func parseDestinationTarget(targetString string, vschema *vindexes.VSchema) (string, topodatapb.TabletType, key.Destination, error) { destKeyspace, destTabletType, dest, err := topoprotopb.ParseDestination(targetString, defaultTabletType) diff --git a/go/vt/vtgate/vschema_manager.go b/go/vt/vtgate/vschema_manager.go index 3b99be052b0..7f2b7267dc0 100644 --- a/go/vt/vtgate/vschema_manager.go +++ b/go/vt/vtgate/vschema_manager.go @@ -20,11 +20,13 @@ import ( "context" "sync" + "vitess.io/vitess/go/vt/graph" "vitess.io/vitess/go/vt/log" topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/topo" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/vindexes" vschemapb "vitess.io/vitess/go/vt/proto/vschema" @@ -188,6 +190,9 @@ func (vm *VSchemaManager) buildAndEnhanceVSchema(v *vschemapb.SrvVSchema) *vinde vschema := vindexes.BuildVSchema(v) if vm.schema != nil { vm.updateFromSchema(vschema) + // We mark the keyspaces that have foreign key management in Vitess and have cyclic foreign keys + // to have an error. This makes all queries against them to fail. + markErrorIfCyclesInFk(vschema) } return vschema } @@ -231,6 +236,47 @@ func (vm *VSchemaManager) updateFromSchema(vschema *vindexes.VSchema) { } } +type tableCol struct { + tableName sqlparser.TableName + colNames sqlparser.Columns +} + +var tableColHash = func(tc tableCol) string { + res := sqlparser.String(tc.tableName) + for _, colName := range tc.colNames { + res += "|" + sqlparser.String(colName) + } + return res +} + +func markErrorIfCyclesInFk(vschema *vindexes.VSchema) { + for ksName, ks := range vschema.Keyspaces { + // Only check cyclic foreign keys for keyspaces that have + // foreign keys managed in Vitess. + if ks.ForeignKeyMode != vschemapb.Keyspace_managed { + continue + } + g := graph.NewGraph[string]() + for _, table := range ks.Tables { + for _, cfk := range table.ChildForeignKeys { + childTable := cfk.Table + parentVertex := tableCol{ + tableName: table.GetTableName(), + colNames: cfk.ParentColumns, + } + childVertex := tableCol{ + tableName: childTable.GetTableName(), + colNames: cfk.ChildColumns, + } + g.AddEdge(tableColHash(parentVertex), tableColHash(childVertex)) + } + } + if g.HasCycles() { + ks.Error = vterrors.VT09019(ksName) + } + } +} + func setColumns(ks *vindexes.KeyspaceSchema, tblName string, columns []vindexes.Column) *vindexes.Table { vTbl := ks.Tables[tblName] if vTbl == nil { diff --git a/go/vt/vtgate/vschema_manager_test.go b/go/vt/vtgate/vschema_manager_test.go index 6e7a9a9a2d1..9c51266c26a 100644 --- a/go/vt/vtgate/vschema_manager_test.go +++ b/go/vt/vtgate/vschema_manager_test.go @@ -3,11 +3,12 @@ package vtgate import ( "testing" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/test/utils" querypb "vitess.io/vitess/go/vt/proto/query" - "vitess.io/vitess/go/vt/sqlparser" - vschemapb "vitess.io/vitess/go/vt/proto/vschema" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -421,3 +422,102 @@ func (f *fakeSchema) Views(string) map[string]sqlparser.SelectStatement { } var _ SchemaInfo = (*fakeSchema)(nil) + +func TestMarkErrorIfCyclesInFk(t *testing.T) { + ksName := "ks" + keyspace := &vindexes.Keyspace{ + Name: ksName, + } + tests := []struct { + name string + getVschema func() *vindexes.VSchema + errWanted string + }{ + { + name: "Has a cycle", + getVschema: func() *vindexes.VSchema { + vschema := &vindexes.VSchema{ + Keyspaces: map[string]*vindexes.KeyspaceSchema{ + ksName: { + ForeignKeyMode: vschemapb.Keyspace_managed, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: keyspace, + }, + "t2": { + Name: sqlparser.NewIdentifierCS("t2"), + Keyspace: keyspace, + }, + "t3": { + Name: sqlparser.NewIdentifierCS("t3"), + Keyspace: keyspace, + }, + }, + }, + }, + } + _ = vschema.AddForeignKey("ks", "t2", createFkDefinition([]string{"col"}, "t1", []string{"col"}, sqlparser.Cascade, sqlparser.Cascade)) + _ = vschema.AddForeignKey("ks", "t3", createFkDefinition([]string{"col"}, "t2", []string{"col"}, sqlparser.Cascade, sqlparser.Cascade)) + _ = vschema.AddForeignKey("ks", "t1", createFkDefinition([]string{"col"}, "t3", []string{"col"}, sqlparser.Cascade, sqlparser.Cascade)) + return vschema + }, + errWanted: "VT09019: ks has cyclic foreign keys", + }, + { + name: "No cycle", + getVschema: func() *vindexes.VSchema { + vschema := &vindexes.VSchema{ + Keyspaces: map[string]*vindexes.KeyspaceSchema{ + ksName: { + ForeignKeyMode: vschemapb.Keyspace_managed, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: keyspace, + }, + "t2": { + Name: sqlparser.NewIdentifierCS("t2"), + Keyspace: keyspace, + }, + "t3": { + Name: sqlparser.NewIdentifierCS("t3"), + Keyspace: keyspace, + }, + }, + }, + }, + } + _ = vschema.AddForeignKey("ks", "t2", createFkDefinition([]string{"col"}, "t1", []string{"col"}, sqlparser.Cascade, sqlparser.Cascade)) + _ = vschema.AddForeignKey("ks", "t3", createFkDefinition([]string{"col"}, "t2", []string{"col"}, sqlparser.Cascade, sqlparser.Cascade)) + return vschema + }, + errWanted: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vschema := tt.getVschema() + markErrorIfCyclesInFk(vschema) + if tt.errWanted != "" { + require.EqualError(t, vschema.Keyspaces[ksName].Error, tt.errWanted) + return + } + require.NoError(t, vschema.Keyspaces[ksName].Error) + }) + } +} + +// createFkDefinition is a helper function to create a Foreign key definition struct from the columns used in it provided as list of strings. +func createFkDefinition(childCols []string, parentTableName string, parentCols []string, onUpdate, onDelete sqlparser.ReferenceAction) *sqlparser.ForeignKeyDefinition { + pKs, pTbl, _ := sqlparser.ParseTable(parentTableName) + return &sqlparser.ForeignKeyDefinition{ + Source: sqlparser.MakeColumns(childCols...), + ReferenceDefinition: &sqlparser.ReferenceDefinition{ + ReferencedTable: sqlparser.NewTableNameWithQualifier(pTbl, pKs), + ReferencedColumns: sqlparser.MakeColumns(parentCols...), + OnUpdate: onUpdate, + OnDelete: onDelete, + }, + } +}