diff --git a/go/mysql/conn.go b/go/mysql/conn.go index d61549c92ef..9164e658111 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -1624,12 +1624,18 @@ func (c *Conn) TLSEnabled() bool { return c.Capabilities&CapabilityClientSSL > 0 } -// IsUnixSocket returns true if this connection is over a Unix socket. +// IsUnixSocket returns true if the server connection is over a Unix socket. func (c *Conn) IsUnixSocket() bool { _, ok := c.listener.listener.(*net.UnixListener) return ok } +// IsClientUnixSocket returns true if the client connection is over a Unix socket with the server. +func (c *Conn) IsClientUnixSocket() bool { + _, ok := c.conn.(*net.UnixConn) + return ok +} + // GetRawConn returns the raw net.Conn for nefarious purposes. func (c *Conn) GetRawConn() net.Conn { return c.conn diff --git a/go/vt/vttablet/endtoend/connecttcp/main_test.go b/go/vt/vttablet/endtoend/connecttcp/main_test.go new file mode 100644 index 00000000000..9d52b1287a1 --- /dev/null +++ b/go/vt/vttablet/endtoend/connecttcp/main_test.go @@ -0,0 +1,119 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package connecttcp + +import ( + "context" + "flag" + "fmt" + "os" + "testing" + + "vitess.io/vitess/go/mysql" + vttestpb "vitess.io/vitess/go/vt/proto/vttest" + "vitess.io/vitess/go/vt/vttablet/endtoend/framework" + "vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv" + "vitess.io/vitess/go/vt/vttest" +) + +var ( + connParams mysql.ConnParams + connAppDebugParams mysql.ConnParams +) + +func TestMain(m *testing.M) { + flag.Parse() // Do not remove this comment, import into google3 depends on it + tabletenv.Init() + + exitCode := func() int { + // Launch MySQL. + // We need a Keyspace in the topology, so the DbName is set. + // We need a Shard too, so the database 'vttest' is created. + cfg := vttest.Config{ + Topology: &vttestpb.VTTestTopology{ + Keyspaces: []*vttestpb.Keyspace{ + { + Name: "vttest", + Shards: []*vttestpb.Shard{ + { + Name: "0", + DbNameOverride: "vttest", + }, + }, + }, + }, + }, + OnlyMySQL: true, + Charset: "utf8mb4_general_ci", + } + if err := cfg.InitSchemas("vttest", testSchema, nil); err != nil { + fmt.Fprintf(os.Stderr, "InitSchemas failed: %v\n", err) + return 1 + } + defer os.RemoveAll(cfg.SchemaDir) + cluster := vttest.LocalCluster{ + Config: cfg, + } + if err := cluster.Setup(); err != nil { + fmt.Fprintf(os.Stderr, "could not launch mysql: %v\n", err) + return 1 + } + defer cluster.TearDown() + + if err := allowConnectOnTCP(cluster); err != nil { + fmt.Fprintf(os.Stderr, "failed to allow tcp priviliges: %v", err) + return 1 + } + + connParams = cluster.MySQLTCPConnParams() + connAppDebugParams = cluster.MySQLAppDebugConnParams() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + config := tabletenv.NewDefaultConfig() + config.TwoPCEnable = true + config.TwoPCAbandonAge = 1 + + if err := framework.StartCustomServer(ctx, connParams, connAppDebugParams, cluster.DbName(), config); err != nil { + fmt.Fprintf(os.Stderr, "%v", err) + return 1 + } + defer framework.StopServer() + + return m.Run() + }() + os.Exit(exitCode) +} + +func allowConnectOnTCP(cluster vttest.LocalCluster) error { + connParams = cluster.MySQLConnParams() + conn, err := mysql.Connect(context.Background(), &connParams) + if err != nil { + return err + } + if _, err = conn.ExecuteFetch("UPDATE mysql.user SET Host = '%' WHERE User = 'vt_dba';", 0, false); err != nil { + return err + } + if _, err = conn.ExecuteFetch("FLUSH PRIVILEGES;", 0, false); err != nil { + return err + } + conn.Close() + return nil +} + +var testSchema = `create table vitess_test(intval int primary key);` diff --git a/go/vt/vttablet/endtoend/connecttcp/prepare_test.go b/go/vt/vttablet/endtoend/connecttcp/prepare_test.go new file mode 100644 index 00000000000..524b4a3dcc8 --- /dev/null +++ b/go/vt/vttablet/endtoend/connecttcp/prepare_test.go @@ -0,0 +1,41 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package connecttcp + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/vttablet/endtoend/framework" +) + +// TestPrepareOnTCP tests that a prepare statement is not allowed on a network connection. +func TestPrepareOnTCP(t *testing.T) { + client := framework.NewClient() + + query := "insert into vitess_test (intval) values(4)" + + err := client.Begin(false) + require.NoError(t, err) + + _, err = client.Execute(query, nil) + require.NoError(t, err) + + err = client.Prepare("aa") + require.ErrorContains(t, err, "VT10002: atomic distributed transaction not allowed: cannot prepare the transaction on a network connection") +} diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn.go b/go/vt/vttablet/tabletserver/connpool/dbconn.go index af8c5fbc78e..61816b16d08 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn.go @@ -598,3 +598,7 @@ func (dbc *Conn) applySameSetting(ctx context.Context) error { _, err := dbc.execOnce(ctx, dbc.setting.ApplyQuery(), 1, false, false) return err } + +func (dbc *Conn) IsUnixSocket() bool { + return dbc.conn.IsClientUnixSocket() +} diff --git a/go/vt/vttablet/tabletserver/dt_executor.go b/go/vt/vttablet/tabletserver/dt_executor.go index 9ddca3247a3..1fd1df12d56 100644 --- a/go/vt/vttablet/tabletserver/dt_executor.go +++ b/go/vt/vttablet/tabletserver/dt_executor.go @@ -67,6 +67,13 @@ func (dte *DTExecutor) Prepare(transactionID int64, dtid string) error { return nil } + // We can only prepare on a Unix socket connection. + // Unix socket are reliable and we can be sure that the connection is not lost with the server after prepare. + if !conn.IsUnixSocket() { + dte.te.txPool.RollbackAndRelease(dte.ctx, conn) + return vterrors.VT10002("cannot prepare the transaction on a network connection") + } + // If the connection is tainted, we cannot prepare it. As there could be temporary tables involved. if conn.IsTainted() { dte.te.txPool.RollbackAndRelease(dte.ctx, conn) diff --git a/go/vt/vttablet/tabletserver/stateful_connection.go b/go/vt/vttablet/tabletserver/stateful_connection.go index 067f2194655..9b34cfce737 100644 --- a/go/vt/vttablet/tabletserver/stateful_connection.go +++ b/go/vt/vttablet/tabletserver/stateful_connection.go @@ -264,7 +264,7 @@ func (sc *StatefulConnection) IsTainted() bool { // LogTransaction logs transaction related stats func (sc *StatefulConnection) LogTransaction(reason tx.ReleaseReason) { if sc.txProps == nil { - return //Nothing to log as no transaction exists on this connection. + return // Nothing to log as no transaction exists on this connection. } sc.txProps.Conclusion = reason.Name() sc.txProps.EndTime = time.Now() @@ -288,7 +288,7 @@ func (sc *StatefulConnection) SetTimeout(timeout time.Duration) { // logReservedConn logs reserved connection related stats. func (sc *StatefulConnection) logReservedConn() { if sc.reservedProps == nil { - return //Nothing to log as this connection is not reserved. + return // Nothing to log as this connection is not reserved. } duration := time.Since(sc.reservedProps.StartTime) username := sc.getUsername() @@ -315,3 +315,8 @@ func (sc *StatefulConnection) ApplySetting(ctx context.Context, setting *smartco func (sc *StatefulConnection) resetExpiryTime() { sc.expiryTime = time.Now().Add(sc.timeout) } + +// IsUnixSocket returns true if the connection is using a unix socket +func (sc *StatefulConnection) IsUnixSocket() bool { + return sc.dbConn.Conn.IsUnixSocket() +} diff --git a/go/vt/vttest/local_cluster.go b/go/vt/vttest/local_cluster.go index 406269ef749..576a78bb761 100644 --- a/go/vt/vttest/local_cluster.go +++ b/go/vt/vttest/local_cluster.go @@ -292,6 +292,15 @@ func (db *LocalCluster) MySQLConnParams() mysql.ConnParams { return connParams } +func (db *LocalCluster) MySQLTCPConnParams() mysql.ConnParams { + connParams := db.mysql.Params(db.DbName()) + _, port := db.mysql.Address() + connParams.UnixSocket = "" + connParams.Host = "127.0.0.1" + connParams.Port = port + return connParams +} + // MySQLAppDebugConnParams returns a mysql.ConnParams struct that can be used // to connect directly to the mysqld service in the self-contained cluster, // using the appdebug user. It's valid only if you used MySQLOnly option.