From 9257cbd46e745df8459fb815d780890e7d7d69e0 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Mon, 17 Jun 2024 11:20:29 +0200 Subject: [PATCH 01/29] ZDM-71: Introduce protocol negotiation --- integration-tests/connect_test.go | 35 +++++++++++++ integration-tests/setup/testcluster.go | 1 + integration-tests/utils/testutils.go | 9 +++- proxy/pkg/config/config.go | 1 + proxy/pkg/zdmproxy/controlconn.go | 70 ++++++++++++++++++++------ proxy/pkg/zdmproxy/cqlconn.go | 8 +-- 6 files changed, 105 insertions(+), 19 deletions(-) diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index b7df5f71..77fb11bf 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -8,6 +8,7 @@ import ( "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/datastax/zdm-proxy/integration-tests/client" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/utils" "github.com/datastax/zdm-proxy/proxy/pkg/config" @@ -45,6 +46,40 @@ func TestGoCqlConnect(t *testing.T) { require.Equal(t, "fake", iter.Columns()[0].Name) } +func TestProtocolVersionNegotiation(t *testing.T) { + testCassandraVersion := env.CassandraVersion + env.CassandraVersion = "2.1" // downgrade C* version for protocol negotiation test + defer func() { + env.CassandraVersion = testCassandraVersion + }() + c := setup.NewTestConfig("", "") + c.ProtocolVersion = 4 // configure unsupported protocol version + testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c) + require.Nil(t, err) + defer testSetup.Cleanup() + + // Connect to proxy as a "client" + proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3) + + if err != nil { + t.Log("Unable to connect to proxy session.") + t.Fatal(err) + } + defer proxy.Close() + + iter := proxy.Query("SELECT * FROM fakeks.faketb").Iter() + result, err := iter.SliceMap() + + if err != nil { + t.Fatal("query failed:", err) + } + + require.Equal(t, 0, len(result)) + + // simulacron generates fake response metadata when queries aren't primed + require.Equal(t, "fake", iter.Columns()[0].Name) +} + func TestMaxClientsThreshold(t *testing.T) { maxClients := 10 goCqlConnectionsPerHost := 1 diff --git a/integration-tests/setup/testcluster.go b/integration-tests/setup/testcluster.go index 1eb60144..dac16ac0 100644 --- a/integration-tests/setup/testcluster.go +++ b/integration-tests/setup/testcluster.go @@ -452,6 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config { conf.ReadMode = config.ReadModePrimaryOnly conf.SystemQueriesMode = config.SystemQueriesModeOrigin conf.AsyncHandshakeTimeoutMs = 4000 + conf.ProtocolVersion = 3 conf.ProxyRequestTimeoutMs = 10000 diff --git a/integration-tests/utils/testutils.go b/integration-tests/utils/testutils.go index 2c050ecd..e0ca5edd 100644 --- a/integration-tests/utils/testutils.go +++ b/integration-tests/utils/testutils.go @@ -116,9 +116,9 @@ func CheckMetricsEndpointResult(httpAddr string, success bool) error { return nil } -// ConnectToCluster is used to connect to source and destination clusters -func ConnectToCluster(hostname string, username string, password string, port int) (*gocql.Session, error) { +func ConnectToClusterUsingVersion(hostname string, username string, password string, port int, protoVersion int) (*gocql.Session, error) { cluster := NewCluster(hostname, username, password, port) + cluster.ProtoVersion = protoVersion session, err := cluster.CreateSession() log.Debugf("Connection established with Cluster: %s:%d", cluster.Hosts[0], cluster.Port) if err != nil { @@ -127,6 +127,11 @@ func ConnectToCluster(hostname string, username string, password string, port in return session, nil } +// ConnectToCluster is used to connect to source and destination clusters +func ConnectToCluster(hostname string, username string, password string, port int) (*gocql.Session, error) { + return ConnectToClusterUsingVersion(hostname, username, password, port, 4) +} + // NewCluster initializes a ClusterConfig object with common settings func NewCluster(hostname string, username string, password string, port int) *gocql.ClusterConfig { cluster := gocql.NewCluster(hostname) diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index d5cc5c67..6e6c4027 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -21,6 +21,7 @@ type Config struct { ReplaceCqlFunctions bool `default:"false" split_words:"true"` AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"` LogLevel string `default:"INFO" split_words:"true"` + ProtocolVersion uint `default:"3" split_words:"true"` // Proxy Topology (also known as system.peers "virtualization") bucket diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index e32bc967..e99f683f 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -54,11 +54,11 @@ type ControlConn struct { protocolEventSubscribers map[ProtocolEventObserver]interface{} authEnabled *atomic.Value metricsHandler *metrics.MetricHandler + protocolVersion primitive.ProtocolVersion } const ProxyVirtualRack = "rack0" const ProxyVirtualPartitioner = "org.apache.cassandra.dht.Murmur3Partitioner" -const ccProtocolVersion = primitive.ProtocolVersion3 const ccWriteTimeout = 5 * time.Second const ccReadTimeout = 10 * time.Second @@ -320,15 +320,9 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( currentIndex := (firstEndpointIndex + i) % len(endpoints) endpoint = endpoints[currentIndex] - tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false) - if err != nil { - log.Warnf("Failed to open control connection to %v using endpoint %v: %v", - cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err) - continue - } - newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf) - err = newConn.InitializeContext(ccProtocolVersion, ctx) + newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ProtocolVersion, ctx) + if err == nil { newConn.SetEventHandler(func(f *frame.Frame, c CqlConnection) { switch f.Body.Message.(type) { @@ -355,9 +349,11 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( log.Warnf("Error while initializing a new cql connection for the control connection of %v: %v", cc.connConfig.GetClusterType(), err) } - err2 := newConn.Close() - if err2 != nil { - log.Errorf("Failed to close cql connection: %v", err2) + if newConn != nil { + err2 := newConn.Close() + if err2 != nil { + log.Errorf("Failed to close cql connection: %v", err2) + } } continue @@ -372,6 +368,52 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( return conn, endpoint } +func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoVer uint, ctx context.Context) (CqlConnection, error) { + protoVer := primitive.ProtocolVersion(initialProtoVer) + for { + tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false) + if err != nil { + log.Warnf("Failed to open control connection to %v using endpoint %v: %v", + cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err) + return nil, err + } + newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf) + err = newConn.InitializeContext(protoVer, ctx) + if err != nil && strings.Contains(err.Error(), "Invalid or unsupported protocol version") { + // unsupported protocol version + // protocol renegotiation requires opening a new TCP connection + err2 := newConn.Close() + if err2 != nil { + log.Errorf("Failed to close cql connection: %v", err2) + } + protoVer = downgradeProtocol(protoVer) + log.Infof("Downgrading protocol version: %v", protoVer) + if protoVer == 0 { + // we cannot downgrade anymore + return nil, err + } + continue // retry lower protocol version + } else { + cc.protocolVersion = protoVer + return newConn, err // we may have successfully established connection or faced other error + } + } +} + +func downgradeProtocol(version primitive.ProtocolVersion) primitive.ProtocolVersion { + switch version { + case primitive.ProtocolVersionDse2: + return primitive.ProtocolVersionDse1 + case primitive.ProtocolVersionDse1: + return primitive.ProtocolVersion4 + case primitive.ProtocolVersion4: + return primitive.ProtocolVersion3 + case primitive.ProtocolVersion3: + return primitive.ProtocolVersion2 + } + return 0 +} + func (cc *ControlConn) Close() { cc.cqlConnLock.Lock() conn := cc.cqlConn @@ -387,7 +429,7 @@ func (cc *ControlConn) Close() { } func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([]*Host, error) { - localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), ccProtocolVersion, ctx) + localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx) if err != nil { return nil, fmt.Errorf("could not fetch information from system.local table: %w", err) } @@ -410,7 +452,7 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([] } } - peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), ccProtocolVersion, ctx) + peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx) if err != nil { return nil, fmt.Errorf("could not fetch information from system.peers table: %w", err) } diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index 041894fe..c8a6e43d 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -59,6 +59,7 @@ type cqlConn struct { eventHandlerLock *sync.Mutex authEnabled bool frameProcessor FrameProcessor + protocolVersion primitive.ProtocolVersion } var ( @@ -237,6 +238,7 @@ func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx conte return fmt.Errorf("failed to perform handshake: %w", err) } + c.protocolVersion = version c.initialized = true c.authEnabled = authEnabled return nil @@ -375,7 +377,7 @@ func (c *cqlConn) Query( }, } - queryFrame := frame.NewFrame(ccProtocolVersion, -1, queryMsg) + queryFrame := frame.NewFrame(c.protocolVersion, -1, queryMsg) var rowSet *ParsedRowSet for { localResponse, err := c.SendAndReceive(queryFrame, ctx) @@ -429,7 +431,7 @@ func (c *cqlConn) Query( } func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Message, error) { - queryFrame := frame.NewFrame(ccProtocolVersion, -1, msg) + queryFrame := frame.NewFrame(c.protocolVersion, -1, msg) localResponse, err := c.SendAndReceive(queryFrame, ctx) if err != nil { return nil, err @@ -440,7 +442,7 @@ func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Mes func (c *cqlConn) SendHeartbeat(ctx context.Context) error { optionsMsg := &message.Options{} - heartBeatFrame := frame.NewFrame(ccProtocolVersion, -1, optionsMsg) + heartBeatFrame := frame.NewFrame(c.protocolVersion, -1, optionsMsg) response, err := c.SendAndReceive(heartBeatFrame, ctx) if err != nil { From 72e55189e6725817910e42ef8345cfdaa0858678 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Tue, 18 Jun 2024 12:06:42 +0200 Subject: [PATCH 02/29] ZDM-71: Introduce protocol negotiation --- integration-tests/connect_test.go | 5 ++--- integration-tests/setup/testcluster.go | 2 +- proxy/pkg/config/config.go | 12 ++++++------ proxy/pkg/zdmproxy/controlconn.go | 8 +++----- proxy/pkg/zdmproxy/cqlconn.go | 19 ++++++++++++------- 5 files changed, 24 insertions(+), 22 deletions(-) diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index 77fb11bf..b3dbff43 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -53,7 +53,7 @@ func TestProtocolVersionNegotiation(t *testing.T) { env.CassandraVersion = testCassandraVersion }() c := setup.NewTestConfig("", "") - c.ProtocolVersion = 4 // configure unsupported protocol version + c.ControlConnMaxProtocolVersion = 4 // configure unsupported protocol version testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c) require.Nil(t, err) defer testSetup.Cleanup() @@ -62,8 +62,7 @@ func TestProtocolVersionNegotiation(t *testing.T) { proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3) if err != nil { - t.Log("Unable to connect to proxy session.") - t.Fatal(err) + t.Fatal("Unable to connect to proxy session.") } defer proxy.Close() diff --git a/integration-tests/setup/testcluster.go b/integration-tests/setup/testcluster.go index dac16ac0..929850c4 100644 --- a/integration-tests/setup/testcluster.go +++ b/integration-tests/setup/testcluster.go @@ -452,7 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config { conf.ReadMode = config.ReadModePrimaryOnly conf.SystemQueriesMode = config.SystemQueriesModeOrigin conf.AsyncHandshakeTimeoutMs = 4000 - conf.ProtocolVersion = 3 + conf.ControlConnMaxProtocolVersion = 3 conf.ProxyRequestTimeoutMs = 10000 diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index 6e6c4027..c48b5e6e 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -16,12 +16,12 @@ type Config struct { // Global bucket - PrimaryCluster string `default:"ORIGIN" split_words:"true"` - ReadMode string `default:"PRIMARY_ONLY" split_words:"true"` - ReplaceCqlFunctions bool `default:"false" split_words:"true"` - AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"` - LogLevel string `default:"INFO" split_words:"true"` - ProtocolVersion uint `default:"3" split_words:"true"` + PrimaryCluster string `default:"ORIGIN" split_words:"true"` + ReadMode string `default:"PRIMARY_ONLY" split_words:"true"` + ReplaceCqlFunctions bool `default:"false" split_words:"true"` + AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"` + LogLevel string `default:"INFO" split_words:"true"` + ControlConnMaxProtocolVersion uint `default:"3" split_words:"true"` // Proxy Topology (also known as system.peers "virtualization") bucket diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index e99f683f..cef8bb3c 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -54,7 +54,6 @@ type ControlConn struct { protocolEventSubscribers map[ProtocolEventObserver]interface{} authEnabled *atomic.Value metricsHandler *metrics.MetricHandler - protocolVersion primitive.ProtocolVersion } const ProxyVirtualRack = "rack0" @@ -321,7 +320,7 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( currentIndex := (firstEndpointIndex + i) % len(endpoints) endpoint = endpoints[currentIndex] - newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ProtocolVersion, ctx) + newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ControlConnMaxProtocolVersion, ctx) if err == nil { newConn.SetEventHandler(func(f *frame.Frame, c CqlConnection) { @@ -394,7 +393,6 @@ func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoV } continue // retry lower protocol version } else { - cc.protocolVersion = protoVer return newConn, err // we may have successfully established connection or faced other error } } @@ -429,7 +427,7 @@ func (cc *ControlConn) Close() { } func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([]*Host, error) { - localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx) + localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), ctx) if err != nil { return nil, fmt.Errorf("could not fetch information from system.local table: %w", err) } @@ -452,7 +450,7 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([] } } - peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx) + peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), ctx) if err != nil { return nil, fmt.Errorf("could not fetch information from system.peers table: %w", err) } diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index c8a6e43d..d7bb7a67 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -14,6 +14,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "time" ) @@ -32,7 +33,7 @@ type CqlConnection interface { SendAndReceive(request *frame.Frame, ctx context.Context) (*frame.Frame, error) Close() error Execute(msg message.Message, ctx context.Context) (message.Message, error) - Query(cql string, genericTypeCodec *GenericTypeCodec, version primitive.ProtocolVersion, ctx context.Context) (*ParsedRowSet, error) + Query(cql string, genericTypeCodec *GenericTypeCodec, ctx context.Context) (*ParsedRowSet, error) SendHeartbeat(ctx context.Context) error SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) SubscribeToProtocolEvents(ctx context.Context, eventTypes []primitive.EventType) error @@ -59,7 +60,7 @@ type cqlConn struct { eventHandlerLock *sync.Mutex authEnabled bool frameProcessor FrameProcessor - protocolVersion primitive.ProtocolVersion + protocolVersion *atomic.Value } var ( @@ -238,7 +239,8 @@ func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx conte return fmt.Errorf("failed to perform handshake: %w", err) } - c.protocolVersion = version + c.protocolVersion = &atomic.Value{} + c.protocolVersion.Store(version) c.initialized = true c.authEnabled = authEnabled return nil @@ -369,7 +371,7 @@ func (c *cqlConn) PerformHandshake(version primitive.ProtocolVersion, ctx contex } func (c *cqlConn) Query( - cql string, genericTypeCodec *GenericTypeCodec, version primitive.ProtocolVersion, ctx context.Context) (*ParsedRowSet, error) { + cql string, genericTypeCodec *GenericTypeCodec, ctx context.Context) (*ParsedRowSet, error) { queryMsg := &message.Query{ Query: cql, Options: &message.QueryOptions{ @@ -377,7 +379,8 @@ func (c *cqlConn) Query( }, } - queryFrame := frame.NewFrame(c.protocolVersion, -1, queryMsg) + version := c.protocolVersion.Load().(primitive.ProtocolVersion) + queryFrame := frame.NewFrame(version, -1, queryMsg) var rowSet *ParsedRowSet for { localResponse, err := c.SendAndReceive(queryFrame, ctx) @@ -431,7 +434,8 @@ func (c *cqlConn) Query( } func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Message, error) { - queryFrame := frame.NewFrame(c.protocolVersion, -1, msg) + version := c.protocolVersion.Load().(primitive.ProtocolVersion) + queryFrame := frame.NewFrame(version, -1, msg) localResponse, err := c.SendAndReceive(queryFrame, ctx) if err != nil { return nil, err @@ -442,7 +446,8 @@ func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Mes func (c *cqlConn) SendHeartbeat(ctx context.Context) error { optionsMsg := &message.Options{} - heartBeatFrame := frame.NewFrame(c.protocolVersion, -1, optionsMsg) + version := c.protocolVersion.Load().(primitive.ProtocolVersion) + heartBeatFrame := frame.NewFrame(version, -1, optionsMsg) response, err := c.SendAndReceive(heartBeatFrame, ctx) if err != nil { From 9de161bf7c211097103f210f63773aa0f245d02a Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Tue, 18 Jun 2024 13:32:53 +0200 Subject: [PATCH 03/29] Allow to run with specific Simulacron cluster version --- integration-tests/asyncreads_test.go | 2 +- integration-tests/connect_test.go | 3 ++- integration-tests/setup/testcluster.go | 10 +++++----- integration-tests/simulacron/cluster.go | 4 ++-- integration-tests/simulacron/http.go | 12 ++++++++++-- 5 files changed, 20 insertions(+), 11 deletions(-) diff --git a/integration-tests/asyncreads_test.go b/integration-tests/asyncreads_test.go index 3510d0b2..d487f2aa 100644 --- a/integration-tests/asyncreads_test.go +++ b/integration-tests/asyncreads_test.go @@ -287,7 +287,7 @@ func TestAsyncReadsRequestTypes(t *testing.T) { } testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig( - t, false, false, 1, nil) + t, false, false, 1, nil, nil) require.Nil(t, err) defer testSetup.Cleanup() diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index b3dbff43..19ae101f 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -10,6 +10,7 @@ import ( "github.com/datastax/zdm-proxy/integration-tests/client" "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/integration-tests/utils" "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/rs/zerolog" @@ -54,7 +55,7 @@ func TestProtocolVersionNegotiation(t *testing.T) { }() c := setup.NewTestConfig("", "") c.ControlConnMaxProtocolVersion = 4 // configure unsupported protocol version - testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c) + testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, true, false, 1, c, &simulacron.ClusterVersion{"2.1", "2.1"}) require.Nil(t, err) defer testSetup.Cleanup() diff --git a/integration-tests/setup/testcluster.go b/integration-tests/setup/testcluster.go index 929850c4..7791cab2 100644 --- a/integration-tests/setup/testcluster.go +++ b/integration-tests/setup/testcluster.go @@ -127,22 +127,22 @@ func NewSimulacronTestSetupWithSession(t *testing.T, createProxy bool, createSes } func NewSimulacronTestSetupWithSessionAndConfig(t *testing.T, createProxy bool, createSession bool, config *config.Config) (*SimulacronTestSetup, error) { - return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, 1, config) + return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, 1, config, nil) } func NewSimulacronTestSetupWithSessionAndNodes(t *testing.T, createProxy bool, createSession bool, nodes int) (*SimulacronTestSetup, error) { - return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, nodes, nil) + return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, nodes, nil, nil) } -func NewSimulacronTestSetupWithSessionAndNodesAndConfig(t *testing.T, createProxy bool, createSession bool, nodes int, config *config.Config) (*SimulacronTestSetup, error) { +func NewSimulacronTestSetupWithSessionAndNodesAndConfig(t *testing.T, createProxy bool, createSession bool, nodes int, config *config.Config, version *simulacron.ClusterVersion) (*SimulacronTestSetup, error) { if !env.RunMockTests { t.Skip("Skipping Simulacron tests, RUN_MOCKTESTS is set false") } - origin, err := simulacron.GetNewCluster(createSession, nodes) + origin, err := simulacron.GetNewCluster(createSession, nodes, version) if err != nil { log.Panic("simulacron origin startup failed: ", err) } - target, err := simulacron.GetNewCluster(createSession, nodes) + target, err := simulacron.GetNewCluster(createSession, nodes, version) if err != nil { log.Panic("simulacron target startup failed: ", err) } diff --git a/integration-tests/simulacron/cluster.go b/integration-tests/simulacron/cluster.go index 0c98387c..6423c833 100644 --- a/integration-tests/simulacron/cluster.go +++ b/integration-tests/simulacron/cluster.go @@ -83,14 +83,14 @@ func (baseSimulacron *baseSimulacron) GetId() string { return baseSimulacron.id } -func GetNewCluster(startSession bool, numberOfNodes int) (*Cluster, error) { +func GetNewCluster(startSession bool, numberOfNodes int, version *ClusterVersion) (*Cluster, error) { process, err := GetOrCreateGlobalSimulacronProcess() if err != nil { return nil, err } - cluster, createErr := process.Create(startSession, numberOfNodes) + cluster, createErr := process.Create(startSession, numberOfNodes, version) if createErr != nil { return nil, createErr diff --git a/integration-tests/simulacron/http.go b/integration-tests/simulacron/http.go index cfb18241..ff18967b 100644 --- a/integration-tests/simulacron/http.go +++ b/integration-tests/simulacron/http.go @@ -18,6 +18,11 @@ type ClusterData struct { Datacenters []*DatacenterData `json:"data_centers"` } +type ClusterVersion struct { + Cassandra string + Dse string +} + type DatacenterData struct { Id int `json:"id"` Nodes []*NodeData `json:"nodes"` @@ -31,11 +36,14 @@ type NodeData struct { const createUrl = "/cluster?data_centers=%s&cassandra_version=%s&dse_version=%s&name=%s&activity_log=%s&num_tokens=%d" -func (process *Process) Create(startSession bool, numberOfNodes int) (*Cluster, error) { +func (process *Process) Create(startSession bool, numberOfNodes int, version *ClusterVersion) (*Cluster, error) { + if version == nil { + version = &ClusterVersion{env.CassandraVersion, env.DseVersion} + } name := "test_" + uuid.New().String() resp, err := process.execHttp( "POST", - fmt.Sprintf(createUrl, strconv.FormatInt(int64(numberOfNodes), 10), env.CassandraVersion, env.DseVersion, name, "true", 1), + fmt.Sprintf(createUrl, strconv.FormatInt(int64(numberOfNodes), 10), version.Cassandra, version.Dse, name, "true", 1), nil) if err != nil { From f60bd28a4bda0af65981395e4534b6c259101d72 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Tue, 18 Jun 2024 13:35:12 +0200 Subject: [PATCH 04/29] Allow to run with specific Simulacron cluster version --- integration-tests/connect_test.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index 19ae101f..c634cbfa 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -8,7 +8,6 @@ import ( "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/datastax/zdm-proxy/integration-tests/client" - "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/integration-tests/utils" @@ -48,11 +47,6 @@ func TestGoCqlConnect(t *testing.T) { } func TestProtocolVersionNegotiation(t *testing.T) { - testCassandraVersion := env.CassandraVersion - env.CassandraVersion = "2.1" // downgrade C* version for protocol negotiation test - defer func() { - env.CassandraVersion = testCassandraVersion - }() c := setup.NewTestConfig("", "") c.ControlConnMaxProtocolVersion = 4 // configure unsupported protocol version testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, true, false, 1, c, &simulacron.ClusterVersion{"2.1", "2.1"}) From a244c8bb02ae042446fa1e54a6c48e96aa03d138 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Tue, 18 Jun 2024 16:31:17 +0200 Subject: [PATCH 05/29] Better classification of ProtocolError --- proxy/pkg/zdmproxy/controlconn.go | 4 +++- proxy/pkg/zdmproxy/cqlconn.go | 2 ++ proxy/pkg/zdmproxy/response.go | 23 ++++++++++++++++++++++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index cef8bb3c..7fea4d5c 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -2,6 +2,7 @@ package zdmproxy import ( "context" + "errors" "fmt" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" @@ -378,7 +379,8 @@ func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoV } newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf) err = newConn.InitializeContext(protoVer, ctx) - if err != nil && strings.Contains(err.Error(), "Invalid or unsupported protocol version") { + var respErr *ResponseError + if err != nil && errors.As(err, &respErr) && respErr.IsProtocolError() && strings.Contains(err.Error(), "Invalid or unsupported protocol version") { // unsupported protocol version // protocol renegotiation requires opening a new TCP connection err2 := newConn.Close() diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index d7bb7a67..00143218 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -357,6 +357,8 @@ func (c *cqlConn) PerformHandshake(version primitive.ProtocolVersion, ctx contex } } } + case *message.ProtocolError: + err = &ResponseError{Response: response} default: err = fmt.Errorf("expected AUTHENTICATE or READY, got %v", response.Body.Message) } diff --git a/proxy/pkg/zdmproxy/response.go b/proxy/pkg/zdmproxy/response.go index 9c531a4c..c328c3e2 100644 --- a/proxy/pkg/zdmproxy/response.go +++ b/proxy/pkg/zdmproxy/response.go @@ -1,6 +1,10 @@ package zdmproxy -import "github.com/datastax/go-cassandra-native-protocol/frame" +import ( + "fmt" + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" +) type Response struct { responseFrame *frame.RawFrame @@ -37,3 +41,20 @@ func (r *Response) GetStreamId() int16 { return r.requestFrame.Header.StreamId } } + +type ResponseError struct { + Response *frame.Frame +} + +func (pre *ResponseError) Error() string { + return fmt.Sprintf("%v", pre.Response.Body.Message) +} + +func (pre *ResponseError) IsProtocolError() bool { + switch pre.Response.Body.Message.(type) { + case *message.ProtocolError: + return true + default: + return false + } +} From 3002a923938e7aeb06bfe1751b9c69b47b77c4f8 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 20 Jun 2024 11:25:43 +0200 Subject: [PATCH 06/29] Validation of protocol version --- integration-tests/connect_test.go | 71 ++++++++++++++++------ integration-tests/setup/testcluster.go | 2 +- proxy/pkg/config/config.go | 25 +++++++- proxy/pkg/config/config_test.go | 84 ++++++++++++++++++++++++++ proxy/pkg/zdmproxy/controlconn.go | 11 ++-- proxy/pkg/zdmproxy/cqlconn.go | 7 ++- 6 files changed, 173 insertions(+), 27 deletions(-) diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index c634cbfa..a066eed8 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -47,31 +47,64 @@ func TestGoCqlConnect(t *testing.T) { } func TestProtocolVersionNegotiation(t *testing.T) { - c := setup.NewTestConfig("", "") - c.ControlConnMaxProtocolVersion = 4 // configure unsupported protocol version - testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, true, false, 1, c, &simulacron.ClusterVersion{"2.1", "2.1"}) - require.Nil(t, err) - defer testSetup.Cleanup() + tests := []struct { + name string + clusterVersion string + controlConnMaxProtocolVersion string + negotiatedProtocolVersion primitive.ProtocolVersion + }{ + { + name: "Cluster2.1_MaxCCProtoVer4_NegotiatedProtoVer3", + clusterVersion: "2.1", + controlConnMaxProtocolVersion: "4", + negotiatedProtocolVersion: primitive.ProtocolVersion3, // protocol downgraded to V3, V4 is not supported + }, + { + name: "Cluster3.0_MaxCCProtoVer4_NegotiatedProtoVer4", + clusterVersion: "3.0", + controlConnMaxProtocolVersion: "4", + negotiatedProtocolVersion: primitive.ProtocolVersion4, + }, + { + name: "Cluster4.0_MaxCCProtoVer4_NegotiatedProtoVer4", + clusterVersion: "4.0", + controlConnMaxProtocolVersion: "4", + negotiatedProtocolVersion: primitive.ProtocolVersion4, + }, + } - // Connect to proxy as a "client" - proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := setup.NewTestConfig("", "") + c.ControlConnMaxProtocolVersion = tt.controlConnMaxProtocolVersion + testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, true, false, 1, c, + &simulacron.ClusterVersion{tt.clusterVersion, tt.clusterVersion}) + require.Nil(t, err) + defer testSetup.Cleanup() - if err != nil { - t.Fatal("Unable to connect to proxy session.") - } - defer proxy.Close() + // Connect to proxy as a "client" + proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3) - iter := proxy.Query("SELECT * FROM fakeks.faketb").Iter() - result, err := iter.SliceMap() + if err != nil { + t.Fatal("Unable to connect to proxy session.") + } + defer proxy.Close() - if err != nil { - t.Fatal("query failed:", err) - } + cqlConn, _ := testSetup.Proxy.GetOriginControlConn().GetConnAndContactPoint() + negotiatedProto := cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion) - require.Equal(t, 0, len(result)) + require.Equal(t, tt.negotiatedProtocolVersion, negotiatedProto) - // simulacron generates fake response metadata when queries aren't primed - require.Equal(t, "fake", iter.Columns()[0].Name) + iter := proxy.Query("SELECT * FROM fakeks.faketb").Iter() + result, err := iter.SliceMap() + + if err != nil { + t.Fatal("query failed:", err) + } + + require.Equal(t, 0, len(result)) + }) + } } func TestMaxClientsThreshold(t *testing.T) { diff --git a/integration-tests/setup/testcluster.go b/integration-tests/setup/testcluster.go index 7791cab2..52613212 100644 --- a/integration-tests/setup/testcluster.go +++ b/integration-tests/setup/testcluster.go @@ -452,7 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config { conf.ReadMode = config.ReadModePrimaryOnly conf.SystemQueriesMode = config.SystemQueriesModeOrigin conf.AsyncHandshakeTimeoutMs = 4000 - conf.ControlConnMaxProtocolVersion = 3 + conf.ControlConnMaxProtocolVersion = "3" conf.ProxyRequestTimeoutMs = 10000 diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index c48b5e6e..4df23560 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -21,7 +21,7 @@ type Config struct { ReplaceCqlFunctions bool `default:"false" split_words:"true"` AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"` LogLevel string `default:"INFO" split_words:"true"` - ControlConnMaxProtocolVersion uint `default:"3" split_words:"true"` + ControlConnMaxProtocolVersion string `default:"3" split_words:"true"` // Numeric Cassandra OSS protocol version or Dse1 / Dse2 // Proxy Topology (also known as system.peers "virtualization") bucket @@ -283,6 +283,11 @@ func (c *Config) Validate() error { return err } + _, err = c.ParseControlConnMaxProtocolVersion() + if err != nil { + return err + } + return nil } @@ -337,6 +342,24 @@ func (c *Config) ParseReadMode() (common.ReadMode, error) { } } +func (c *Config) ParseControlConnMaxProtocolVersion() (uint, error) { + switch c.ControlConnMaxProtocolVersion { + case "Dse2": + return 0b_1_000010, nil + case "Dse1": + return 0b_1_000001, nil + } + ver, err := strconv.ParseUint(c.ControlConnMaxProtocolVersion, 10, 32) + if err != nil { + return 0, fmt.Errorf("could not parse control connection max protocol version, valid values are "+ + "2, 3, 4, Dse1, Dse2; original err: %w", err) + } + if ver < 2 || ver > 4 { + return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2") + } + return uint(ver), nil +} + func (c *Config) ParseLogLevel() (log.Level, error) { level, err := log.ParseLevel(strings.TrimSpace(c.LogLevel)) if err != nil { diff --git a/proxy/pkg/config/config_test.go b/proxy/pkg/config/config_test.go index 5265131b..6da7c431 100644 --- a/proxy/pkg/config/config_test.go +++ b/proxy/pkg/config/config_test.go @@ -93,3 +93,87 @@ func TestTargetConfig_WithHostnameButWithoutPort(t *testing.T) { require.Nil(t, err) require.Equal(t, 9042, c.TargetPort) } + +func TestTargetConfig_ParsingControlConnMaxProtocolVersion(t *testing.T) { + defer clearAllEnvVars() + + // general setup + clearAllEnvVars() + setOriginCredentialsEnvVars() + setTargetCredentialsEnvVars() + setOriginContactPointsAndPortEnvVars() + + // test-specific setup + setTargetContactPointsAndPortEnvVars() + + conf, _ := New().ParseEnvVars() + + tests := []struct { + name string + controlConnMaxProtocolVersion string + parsedProtocolVersion uint + errorMessage string + }{ + { + name: "ParsedV2", + controlConnMaxProtocolVersion: "2", + parsedProtocolVersion: 2, + errorMessage: "", + }, + { + name: "ParsedV3", + controlConnMaxProtocolVersion: "3", + parsedProtocolVersion: 3, + errorMessage: "", + }, + { + name: "ParsedV4", + controlConnMaxProtocolVersion: "4", + parsedProtocolVersion: 4, + errorMessage: "", + }, + { + name: "ParsedDse1", + controlConnMaxProtocolVersion: "Dse1", + parsedProtocolVersion: 65, + errorMessage: "", + }, + { + name: "ParsedDse2", + controlConnMaxProtocolVersion: "Dse2", + parsedProtocolVersion: 66, + errorMessage: "", + }, + { + name: "UnsupportedCassandraV5", + controlConnMaxProtocolVersion: "5", + parsedProtocolVersion: 0, + errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + }, + { + name: "UnsupportedCassandraV1", + controlConnMaxProtocolVersion: "1", + parsedProtocolVersion: 0, + errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + }, + { + name: "InvalidValue", + controlConnMaxProtocolVersion: "Dsev123", + parsedProtocolVersion: 0, + errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conf.ControlConnMaxProtocolVersion = tt.controlConnMaxProtocolVersion + ver, err := conf.ParseControlConnMaxProtocolVersion() + if ver == 0 { + require.NotNil(t, err) + require.Contains(t, err.Error(), tt.errorMessage) + } else { + require.Equal(t, tt.parsedProtocolVersion, ver) + } + }) + } +} diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index 7fea4d5c..b2472439 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -125,7 +125,7 @@ func (cc *ControlConn) Start(wg *sync.WaitGroup, ctx context.Context) error { log.Infof("Received topology event from %v, refreshing topology.", cc.connConfig.GetClusterType()) - conn, _ := cc.getConnAndContactPoint() + conn, _ := cc.GetConnAndContactPoint() if conn == nil { log.Debugf("Topology refresh scheduled but the control connection isn't open. " + "Falling back to the connection where the event was received.") @@ -162,7 +162,7 @@ func (cc *ControlConn) Start(wg *sync.WaitGroup, ctx context.Context) error { cc.Close() } - conn, _ := cc.getConnAndContactPoint() + conn, _ := cc.GetConnAndContactPoint() if conn == nil { useContactPointsOnly := false if !lastOpenSuccessful { @@ -251,7 +251,7 @@ func (cc *ControlConn) ReadFailureCounter() int { } func (cc *ControlConn) Open(contactPointsOnly bool, ctx context.Context) (CqlConnection, error) { - oldConn, _ := cc.getConnAndContactPoint() + oldConn, _ := cc.GetConnAndContactPoint() if oldConn != nil { cc.Close() oldConn = nil @@ -321,7 +321,8 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( currentIndex := (firstEndpointIndex + i) % len(endpoints) endpoint = endpoints[currentIndex] - newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ControlConnMaxProtocolVersion, ctx) + maxProtoVer, _ := cc.conf.ParseControlConnMaxProtocolVersion() + newConn, err := cc.connAndNegotiateProtoVer(endpoint, maxProtoVer, ctx) if err == nil { newConn.SetEventHandler(func(f *frame.Frame, c CqlConnection) { @@ -678,7 +679,7 @@ func (cc *ControlConn) setConn(oldConn CqlConnection, newConn CqlConnection, new return cc.cqlConn, cc.currentContactPoint } -func (cc *ControlConn) getConnAndContactPoint() (CqlConnection, Endpoint) { +func (cc *ControlConn) GetConnAndContactPoint() (CqlConnection, Endpoint) { cc.cqlConnLock.Lock() conn := cc.cqlConn contactPoint := cc.currentContactPoint diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index 00143218..8749516d 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -38,6 +38,7 @@ type CqlConnection interface { SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) SubscribeToProtocolEvents(ctx context.Context, eventTypes []primitive.EventType) error IsAuthEnabled() (bool, error) + GetProtocolVersion() *atomic.Value } // Not thread safe @@ -98,6 +99,7 @@ func NewCqlConnection( eventHandlerLock: &sync.Mutex{}, authEnabled: true, frameProcessor: NewStreamIdProcessor(NewInternalStreamIdMapper(conf.ProxyMaxStreamIds, nil)), + protocolVersion: &atomic.Value{}, } cqlConn.StartRequestLoop() cqlConn.StartResponseLoop() @@ -233,13 +235,16 @@ func (c *cqlConn) IsAuthEnabled() (bool, error) { return c.authEnabled, nil } +func (c *cqlConn) GetProtocolVersion() *atomic.Value { + return c.protocolVersion +} + func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx context.Context) error { authEnabled, err := c.PerformHandshake(version, ctx) if err != nil { return fmt.Errorf("failed to perform handshake: %w", err) } - c.protocolVersion = &atomic.Value{} c.protocolVersion.Store(version) c.initialized = true c.authEnabled = authEnabled From 4bf2cdd5d7ce87169007f4b03c5f6cd5650e5d9d Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Tue, 25 Jun 2024 11:01:09 +0200 Subject: [PATCH 07/29] Protocol V2 support --- proxy/pkg/zdmproxy/controlconn.go | 21 ++++++++++------ proxy/pkg/zdmproxy/cqlconn.go | 42 +++++++++++++++++++++++-------- proxy/pkg/zdmproxy/host.go | 13 +++++++--- 3 files changed, 54 insertions(+), 22 deletions(-) diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index b2472439..394c8ae7 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -378,7 +378,7 @@ func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoV cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err) return nil, err } - newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf) + newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf, protoVer) err = newConn.InitializeContext(protoVer, ctx) var respErr *ResponseError if err != nil && errors.As(err, &respErr) && respErr.IsProtocolError() && strings.Contains(err.Error(), "Invalid or unsupported protocol version") { @@ -436,6 +436,7 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([] } localInfo, localHost, err := ParseSystemLocalResult(localQueryResult, cc.defaultPort) + // localHost may be nil, if we did not find the address in system.local table if err != nil { return nil, err } @@ -475,11 +476,13 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([] } } - oldLocalhost, localHostExists := hostsById[localHost.HostId] - if localHostExists { - log.Warnf("Local host is also on the peers list: %v vs %v, ignoring the former one.", oldLocalhost, localHost) + if localHost != nil { + oldLocalhost, localHostExists := hostsById[localHost.HostId] + if localHostExists { + log.Warnf("Local host is also on the peers list: %v vs %v, ignoring the former one.", oldLocalhost, localHost) + } + hostsById[localHost.HostId] = localHost } - hostsById[localHost.HostId] = localHost orderedLocalHosts := make([]*Host, 0, len(hostsById)) for _, h := range hostsById { orderedLocalHosts = append(orderedLocalHosts, h) @@ -489,9 +492,11 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([] currentDc := cc.datacenter cc.topologyLock.RUnlock() - orderedLocalHosts, currentDc, err = filterHosts(orderedLocalHosts, currentDc, cc.connConfig, localHost) - if err != nil { - return nil, err + if localHost != nil { + orderedLocalHosts, currentDc, err = filterHosts(orderedLocalHosts, currentDc, cc.connConfig, localHost) + if err != nil { + return nil, err + } } sort.Slice(orderedLocalHosts, func(i, j int) bool { diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index 8749516d..c413cb50 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -19,10 +19,13 @@ import ( ) const ( - eventQueueLength = 2048 + eventQueueLength = 2048 + eventQueueLengthV2 = 128 - maxIncomingPending = 2048 - maxOutgoingPending = 2048 + maxIncomingPending = 2048 + maxIncomingPendingV2 = 128 + maxOutgoingPending = 2048 + maxOutgoingPendingV2 = 128 timeOutsThreshold = 1024 ) @@ -76,7 +79,7 @@ func NewCqlConnection( conn net.Conn, username string, password string, readTimeout time.Duration, writeTimeout time.Duration, - conf *config.Config) CqlConnection { + conf *config.Config, protoVer primitive.ProtocolVersion) CqlConnection { ctx, cFn := context.WithCancel(context.Background()) cqlConn := &cqlConn{ readTimeout: readTimeout, @@ -86,12 +89,13 @@ func NewCqlConnection( Username: username, Password: password, }, - initialized: false, - ctx: ctx, - cancelFn: cFn, - wg: &sync.WaitGroup{}, - outgoingCh: make(chan *frame.Frame, maxOutgoingPending), - eventsQueue: make(chan *frame.Frame, eventQueueLength), + initialized: false, + ctx: ctx, + cancelFn: cFn, + wg: &sync.WaitGroup{}, + // protoVer is the proposed protocol version using which we will try to establish connectivity + outgoingCh: make(chan *frame.Frame, maxOutgoingPendingRequests(protoVer)), + eventsQueue: make(chan *frame.Frame, maxEventsQueue(protoVer)), pendingOperations: make(map[int16]chan *frame.Frame), pendingOperationsLock: &sync.RWMutex{}, timedOutOperations: 0, @@ -107,6 +111,22 @@ func NewCqlConnection( return cqlConn } +func maxOutgoingPendingRequests(protocolVersion primitive.ProtocolVersion) int { + switch protocolVersion { + case primitive.ProtocolVersion2: + return maxOutgoingPendingV2 + } + return maxOutgoingPending +} + +func maxEventsQueue(protocolVersion primitive.ProtocolVersion) int { + switch protocolVersion { + case primitive.ProtocolVersion2: + return eventQueueLengthV2 + } + return eventQueueLength +} + func (c *cqlConn) SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) { c.eventHandlerLock.Lock() defer c.eventHandlerLock.Unlock() @@ -382,7 +402,7 @@ func (c *cqlConn) Query( queryMsg := &message.Query{ Query: cql, Options: &message.QueryOptions{ - Consistency: primitive.ConsistencyLevelLocalQuorum, + Consistency: primitive.ConsistencyLevelOne, }, } diff --git a/proxy/pkg/zdmproxy/host.go b/proxy/pkg/zdmproxy/host.go index 4033b1c0..b37e4753 100644 --- a/proxy/pkg/zdmproxy/host.go +++ b/proxy/pkg/zdmproxy/host.go @@ -64,9 +64,13 @@ func ParseSystemLocalResult(rs *ParsedRowSet, defaultPort int) (map[string]*opti return nil, nil, err } - host, err := parseHost(addr, port, row) - if err != nil { - return nil, nil, err + var host *Host + if addr != nil { + // could not resolve address from system.local table (e.g. not present in C* 2.0.0) + host, err = parseHost(addr, port, row) + if err != nil { + return nil, nil, err + } } sysLocalCols := map[string]*optionalColumn{ @@ -179,6 +183,9 @@ func ParseRpcAddress(isPeersV2 bool, row *ParsedRow, defaultPort int) (net.IP, i } else { addr = parseRpcAddressLocalOrPeersV1(row) } + if addr == nil { + return nil, -1, nil + } if addr.IsUnspecified() { peer, peerExists := row.GetByColumn("peer") From b8feb06a8b0f5073e825f0c90bee647dc50a91cc Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 26 Jun 2024 14:48:05 +0200 Subject: [PATCH 08/29] Upgrade go-cassandra-native-protocol library --- go.mod | 2 +- go.sum | 51 ++++++++++++++++ integration-tests/customhandler_test_utils.go | 4 +- integration-tests/functioncalls_test.go | 23 ++++---- integration-tests/prepared_statements_test.go | 58 +++++++++---------- proxy/pkg/config/config.go | 2 +- proxy/pkg/zdmproxy/clientconn.go | 2 +- proxy/pkg/zdmproxy/clienthandler.go | 18 +++--- proxy/pkg/zdmproxy/controlconn.go | 5 +- proxy/pkg/zdmproxy/cqlconn.go | 9 ++- proxy/pkg/zdmproxy/cqlparser.go | 11 ++-- .../cqlparser_adv_workloads_utils_test.go | 6 +- proxy/pkg/zdmproxy/cqlparser_test.go | 10 +++- proxy/pkg/zdmproxy/frameprocessor.go | 4 +- proxy/pkg/zdmproxy/host.go | 32 +++++++--- proxy/pkg/zdmproxy/nativeprotocol.go | 12 ++-- proxy/pkg/zdmproxy/parametermodifier_test.go | 6 +- proxy/pkg/zdmproxy/querymodifier.go | 8 +-- proxy/pkg/zdmproxy/querymodifier_test.go | 10 ++-- 19 files changed, 177 insertions(+), 96 deletions(-) diff --git a/go.mod b/go.mod index cd3ed9dc..ad8753b8 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.19 require ( github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20211106181442-e4c1a74c66bd - github.com/datastax/go-cassandra-native-protocol v0.0.0-20220525125956-6158d9e218b8 + github.com/datastax/go-cassandra-native-protocol v0.0.0-20240626123646-2abea740da8d github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e github.com/google/uuid v1.1.1 github.com/jpillora/backoff v1.0.0 diff --git a/go.sum b/go.sum index e28fd660..8d3e4fdf 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20211106181442-e4c1a74c66bd h1:fjJY1LimH0wVCvOHLX35SCX/MbWomAglET1H2kvz7xc= github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20211106181442-e4c1a74c66bd/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY= +github.com/antlr/antlr4/runtime/Go/antlr v1.4.10 h1:yL7+Jz0jTC6yykIK/Wh74gnTJnrGr5AyrNMXuA0gves= +github.com/antlr/antlr4/runtime/Go/antlr v1.4.10/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -14,9 +16,14 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4Yn github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/datastax/go-cassandra-native-protocol v0.0.0-20220525125956-6158d9e218b8 h1:NKLtNzC76ssf68VOenDAzMyQGg+QkxuD2QCubX+GvLk= github.com/datastax/go-cassandra-native-protocol v0.0.0-20220525125956-6158d9e218b8/go.mod h1:yFD0OKoVV9d1QW7Es58c1Gv6ijrqTGPcxgHv27wdC4Q= +github.com/datastax/go-cassandra-native-protocol v0.0.0-20240626123646-2abea740da8d h1:UnPtAA8Ux3GvHLazSSUydERFuoQRyxHrB8puzXyjXIE= +github.com/datastax/go-cassandra-native-protocol v0.0.0-20240626123646-2abea740da8d/go.mod h1:6FzirJfdffakAVqmHjwVfFkpru/gNbIazUOK5rIhndc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -27,20 +34,31 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e h1:SroDcndcOU9BVAduPf/PXihXoR2ZYTQYLXbupbqxAyQ= github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY= +github.com/gocql/gocql v1.6.0 h1:IdFdOTbnpbd0pDhl4REKQDM+Q0SzKXQ1Yh+YZZ8T/qU= +github.com/gocql/gocql v1.6.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= @@ -57,9 +75,17 @@ github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFB github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -69,32 +95,46 @@ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3Rllmb github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/pierrec/lz4/v4 v4.0.3 h1:vNQKSVZNYUEAvRY9FaUXAF1XPbSOHJtDTiP41kzDz2E= github.com/pierrec/lz4/v4 v4.0.3/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.3.0 h1:miYCvYqFXtl/J9FIy8eNpBfYthAEFg+Ys0XyUVEcDsc= github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= +github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= +github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.1.0 h1:ElTg5tNp4DqfV7UQjDqv2+RJlNzsDtvNAWccbItceIE= github.com/prometheus/client_model v0.1.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.7.0 h1:L+1lyG48J1zAQXA3RBX/nG/B3gjlHq0zTt2tlbJLyCY= github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= +github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.8 h1:+fpWZdT24pJBiqJdAwYBjPSk+5YmQzYNPYzQsdzLkt8= github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= +github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.20.0 h1:38k9hgtUBdxFwE34yS8rTHmHBa4eN16E4DJlv177LNs= github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= +github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= +github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= @@ -102,9 +142,12 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -119,12 +162,20 @@ golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/integration-tests/customhandler_test_utils.go b/integration-tests/customhandler_test_utils.go index 7e440d6f..97fdc7ae 100644 --- a/integration-tests/customhandler_test_utils.go +++ b/integration-tests/customhandler_test_utils.go @@ -57,7 +57,7 @@ var ( releaseVersionColumn = &message.ColumnMetadata{Keyspace: "system", Table: "local", Name: "release_version", Type: datatype.Varchar} rpcAddressColumn = &message.ColumnMetadata{Keyspace: "system", Table: "local", Name: "rpc_address", Type: datatype.Inet} schemaVersionColumn = &message.ColumnMetadata{Keyspace: "system", Table: "local", Name: "schema_version", Type: datatype.Uuid} - tokensColumn = &message.ColumnMetadata{Keyspace: "system", Table: "local", Name: "tokens", Type: datatype.NewSetType(datatype.Varchar)} + tokensColumn = &message.ColumnMetadata{Keyspace: "system", Table: "local", Name: "tokens", Type: datatype.NewSet(datatype.Varchar)} ) // These columns are a subset of the total columns returned by OSS C* 3.11.2, and contain all the information that @@ -86,7 +86,7 @@ var ( releaseVersionPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "release_version", Type: datatype.Varchar} rpcAddressPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "rpc_address", Type: datatype.Inet} schemaVersionPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "schema_version", Type: datatype.Uuid} - tokensPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "tokens", Type: datatype.NewSetType(datatype.Varchar)} + tokensPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "tokens", Type: datatype.NewSet(datatype.Varchar)} ) // These columns are a subset of the total columns returned by OSS C* 3.11.2, and contain all the information that diff --git a/integration-tests/functioncalls_test.go b/integration-tests/functioncalls_test.go index 7cf04d76..96fbbfeb 100644 --- a/integration-tests/functioncalls_test.go +++ b/integration-tests/functioncalls_test.go @@ -854,7 +854,7 @@ func TestNowFunctionReplacementPreparedStatement(t *testing.T) { isReplacedNow: false, value: []int{11, 22, 33}, valueSimulacron: []int{11, 22, 33}, - dataType: datatype.NewListType(datatype.Int), + dataType: datatype.NewList(datatype.Int), simulacronType: "list", }, { @@ -880,7 +880,7 @@ func TestNowFunctionReplacementPreparedStatement(t *testing.T) { {1, 2, 3}, {2, 3, 4}, }, - dataType: datatype.NewListType(datatype.NewTupleType(datatype.Int, datatype.Int, datatype.Int)), + dataType: datatype.NewList(datatype.NewTuple(datatype.Int, datatype.Int, datatype.Int)), simulacronType: "list>", }, { @@ -2261,7 +2261,7 @@ func TestNowFunctionReplacementBatchStatement(t *testing.T) { } expectedBatchChildQueries = append(expectedBatchChildQueries, expectedBatchChildQuery) - var queryOrId interface{} + var batchChild *message.BatchChild if childStatement.prepared { when := simulacron.NewWhenQueryOptions() for _, p := range expectedChildQueryParams { @@ -2285,18 +2285,21 @@ func TestNowFunctionReplacementBatchStatement(t *testing.T) { require.Nil(t, err) prepared, ok := resp.Body.Message.(*message.PreparedResult) require.True(t, ok) - queryOrId = prepared.PreparedQueryId + batchChild = &message.BatchChild{ + Id: prepared.PreparedQueryId, + Values: positionalValues, + } validateForwardedPrepare(simulacronSetup.Origin, childStatement) validateForwardedPrepare(simulacronSetup.Target, childStatement) } else { - queryOrId = childStatement.originalQuery + batchChild = &message.BatchChild{ + Query: childStatement.originalQuery, + Values: positionalValues, + } } - batchChildStatements = append(batchChildStatements, &message.BatchChild{ - QueryOrId: queryOrId, - Values: positionalValues, - }) + batchChildStatements = append(batchChildStatements, batchChild) } batchMsg := &message.Batch{ @@ -2325,7 +2328,7 @@ func TestNowFunctionReplacementBatchStatement(t *testing.T) { actualStmt := matching[0].QueriesOrIds[idx] actualParams := matching[0].Values[idx] if childStatement.prepared { - b64ExpectedValue := base64.StdEncoding.EncodeToString(batchChildStatements[idx].QueryOrId.([]byte)) + b64ExpectedValue := base64.StdEncoding.EncodeToString(batchChildStatements[idx].Id) require.Equal(t, b64ExpectedValue, actualStmt, idx) } else { if enableNowReplacement { diff --git a/integration-tests/prepared_statements_test.go b/integration-tests/prepared_statements_test.go index cf8edc11..6e62a5e6 100644 --- a/integration-tests/prepared_statements_test.go +++ b/integration-tests/prepared_statements_test.go @@ -353,9 +353,9 @@ func TestPreparedIdReplacement(t *testing.T) { var batchPrepareMsg *message.Prepare var expectedBatchPrepareMsg *message.Prepare if test.batchQuery != "" { - batchPrepareMsg = prepareMsg.Clone().(*message.Prepare) + batchPrepareMsg = prepareMsg.DeepCopy() batchPrepareMsg.Query = test.batchQuery - expectedBatchPrepareMsg = batchPrepareMsg.Clone().(*message.Prepare) + expectedBatchPrepareMsg = batchPrepareMsg.DeepCopy() expectedBatchPrepareMsg.Query = test.expectedBatchQuery prepareResp, err = testSetup.Client.CqlConnection.SendAndReceive( frame.NewFrame(primitive.ProtocolVersion4, 10, batchPrepareMsg)) @@ -391,15 +391,15 @@ func TestPreparedIdReplacement(t *testing.T) { Type: primitive.BatchTypeLogged, Children: []*message.BatchChild{ { - QueryOrId: test.query, + Query: test.query, // the decoder uses empty slices instead of nil so this has to be initialized this way // so that the equality assertions work later in this test Values: make([]*primitive.Value, 0), }, { - QueryOrId: originBatchPreparedId, - Values: make([]*primitive.Value, 0), + Id: originBatchPreparedId, + Values: make([]*primitive.Value, 0), }, }, Consistency: primitive.ConsistencyLevelLocalQuorum, @@ -482,7 +482,7 @@ func TestPreparedIdReplacement(t *testing.T) { require.Equal(t, originPreparedId, originExecuteMessages[0].QueryId) if expectedOriginBatches > 0 { require.Equal(t, 2, len(originBatchMessages[0].Children)) - require.Equal(t, originBatchPreparedId, originBatchMessages[0].Children[1].QueryOrId) + require.Equal(t, originBatchPreparedId, originBatchMessages[0].Children[1].Id) } for _, targetExecute := range targetExecuteMessages { @@ -491,7 +491,7 @@ func TestPreparedIdReplacement(t *testing.T) { } if expectedTargetBatches > 0 { require.Equal(t, 2, len(targetBatchMessages[0].Children)) - require.Equal(t, targetBatchPreparedId, targetBatchMessages[0].Children[1].QueryOrId) + require.Equal(t, targetBatchPreparedId, targetBatchMessages[0].Children[1].Id) require.NotEqual(t, batchMsg, targetBatchMessages[0]) } @@ -508,8 +508,8 @@ func TestPreparedIdReplacement(t *testing.T) { require.NotEqual(t, len(executeMsg.Options.PositionalValues), len(originExecuteMessages[0].Options.PositionalValues)) // check if only the positional values are different, we test the parameter replacement in depth on other tests - modifiedOriginExecuteMsg := originExecuteMessages[0].Clone() - modifiedOriginExecuteMsg.(*message.Execute).Options.PositionalValues = executeMsg.Options.PositionalValues + modifiedOriginExecuteMsg := originExecuteMessages[0].DeepCopy() + modifiedOriginExecuteMsg.Options.PositionalValues = executeMsg.Options.PositionalValues require.Equal(t, executeMsg, modifiedOriginExecuteMsg) require.Equal(t, originExecuteMessages[0].Options, targetExecuteMessages[0].Options) } else { @@ -524,19 +524,19 @@ func TestPreparedIdReplacement(t *testing.T) { require.Equal(t, expectedBatchPrepareMsg, originPrepareMessages[1]) if test.expectedBatchPreparedStmtVariables != nil { - require.NotEqual(t, batchMsg.Children[0].QueryOrId, originBatchMessages[0].Children[0].QueryOrId) - require.NotEqual(t, batchMsg.Children[0].QueryOrId, targetBatchMessages[0].Children[0].QueryOrId) - require.Equal(t, originBatchMessages[0].Children[0].QueryOrId, targetBatchMessages[0].Children[0].QueryOrId) + require.NotEqual(t, batchMsg.Children[0].Query, originBatchMessages[0].Children[0].Query) + require.NotEqual(t, batchMsg.Children[0].Query, targetBatchMessages[0].Children[0].Query) + require.Equal(t, originBatchMessages[0].Children[0].Query, targetBatchMessages[0].Children[0].Query) require.Equal(t, 0, len(targetBatchMessages[0].Children[0].Values)) require.Equal(t, 0, len(originBatchMessages[0].Children[0].Values)) require.Equal(t, 0, len(batchMsg.Children[0].Values)) - require.Equal(t, batchMsg.Children[1].QueryOrId, originBatchMessages[0].Children[1].QueryOrId) - require.NotEqual(t, batchMsg.Children[1].QueryOrId, targetBatchMessages[0].Children[1].QueryOrId) - require.NotEqual(t, originBatchMessages[0].Children[1].QueryOrId, targetBatchMessages[0].Children[1].QueryOrId) - require.Equal(t, targetBatchPreparedId, targetBatchMessages[0].Children[1].QueryOrId) - require.Equal(t, originBatchPreparedId, originBatchMessages[0].Children[1].QueryOrId) - require.Equal(t, originBatchPreparedId, batchMsg.Children[1].QueryOrId) + require.Equal(t, batchMsg.Children[1].Query, originBatchMessages[0].Children[1].Query) + require.NotEqual(t, batchMsg.Children[1].Id, targetBatchMessages[0].Children[1].Id) + require.NotEqual(t, originBatchMessages[0].Children[1].Id, targetBatchMessages[0].Children[1].Id) + require.Equal(t, targetBatchPreparedId, targetBatchMessages[0].Children[1].Id) + require.Equal(t, originBatchPreparedId, originBatchMessages[0].Children[1].Id) + require.Equal(t, originBatchPreparedId, batchMsg.Children[1].Id) require.Equal(t, len(test.expectedBatchPreparedStmtVariables.Columns), len(targetBatchMessages[0].Children[1].Values)) require.Equal(t, len(test.expectedBatchPreparedStmtVariables.Columns), len(originBatchMessages[0].Children[1].Values)) require.Equal(t, 0, len(batchMsg.Children[1].Values)) @@ -546,8 +546,8 @@ func TestPreparedIdReplacement(t *testing.T) { } else { require.Equal(t, batchMsg, originBatchMessages[0]) require.NotEqual(t, batchMsg, targetBatchMessages[0]) - clonedBatchMsg := targetBatchMessages[0].Clone().(*message.Batch) - clonedBatchMsg.Children[1].QueryOrId = originBatchPreparedId + clonedBatchMsg := targetBatchMessages[0].DeepCopy() + clonedBatchMsg.Children[1].Id = originBatchPreparedId require.Equal(t, batchMsg, clonedBatchMsg) } } @@ -706,7 +706,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { var batchMsg *message.Batch var batchPrepareMsg *message.Prepare if test.batchQuery != "" { - batchPrepareMsg = prepareMsg.Clone().(*message.Prepare) + batchPrepareMsg = prepareMsg.DeepCopy() batchPrepareMsg.Query = test.batchQuery prepareResp, err = testSetup.Client.CqlConnection.SendAndReceive( frame.NewFrame(primitive.ProtocolVersion4, 10, batchPrepareMsg)) @@ -721,14 +721,14 @@ func TestUnpreparedIdReplacement(t *testing.T) { Type: primitive.BatchTypeLogged, Children: []*message.BatchChild{ { - QueryOrId: test.query, + Query: test.query, // the decoder uses empty slices instead of nil so this has to be initialized this way // so that the equality assertions work later in this test Values: make([]*primitive.Value, 0), }, { - QueryOrId: originBatchPreparedId, - Values: make([]*primitive.Value, 0), + Id: originBatchPreparedId, + Values: make([]*primitive.Value, 0), }, }, Consistency: primitive.ConsistencyLevelLocalQuorum, @@ -843,7 +843,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { if expectedTargetBatches > 0 { for _, batch := range targetBatchMessages { require.Equal(t, 2, len(batch.Children)) - require.Equal(t, targetBatchPreparedId, batch.Children[1].QueryOrId) + require.Equal(t, targetBatchPreparedId, batch.Children[1].Id) require.NotEqual(t, batchMsg, batch) } } @@ -1117,13 +1117,11 @@ func NewPreparedTestHandler( func checkIfPreparedIdMatches(batchMsg *message.Batch, preparedId []byte) (bool, []byte) { var batchPreparedId []byte for _, child := range batchMsg.Children { - switch queryOrId := child.QueryOrId.(type) { - case []byte: - batchPreparedId = queryOrId - if !bytes.Equal(queryOrId, preparedId) { + if child.Id != nil { + batchPreparedId = child.Id + if !bytes.Equal(child.Id, preparedId) { return false, batchPreparedId } - default: } } diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index 4df23560..a249cbf8 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -20,7 +20,7 @@ type Config struct { ReadMode string `default:"PRIMARY_ONLY" split_words:"true"` ReplaceCqlFunctions bool `default:"false" split_words:"true"` AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"` - LogLevel string `default:"INFO" split_words:"true"` + LogLevel string `default:"DEBUG" split_words:"true"` ControlConnMaxProtocolVersion string `default:"3" split_words:"true"` // Numeric Cassandra OSS protocol version or Dse1 / Dse2 // Proxy Topology (also known as system.peers "virtualization") bucket diff --git a/proxy/pkg/zdmproxy/clientconn.go b/proxy/pkg/zdmproxy/clientconn.go index b6bafe65..ff8e2100 100644 --- a/proxy/pkg/zdmproxy/clientconn.go +++ b/proxy/pkg/zdmproxy/clientconn.go @@ -187,7 +187,7 @@ func (cc *ClientConnector) listenForRequests() { cc.sendResponseToClient(protocolErrResponseFrame) continue } else if alreadySentProtocolErr != nil { - clonedProtocolErr := alreadySentProtocolErr.Clone() + clonedProtocolErr := alreadySentProtocolErr.DeepCopy() clonedProtocolErr.Header.StreamId = f.Header.StreamId cc.sendResponseToClient(clonedProtocolErr) continue diff --git a/proxy/pkg/zdmproxy/clienthandler.go b/proxy/pkg/zdmproxy/clienthandler.go index 1eb5e1dd..cedf7c83 100644 --- a/proxy/pkg/zdmproxy/clienthandler.go +++ b/proxy/pkg/zdmproxy/clienthandler.go @@ -891,7 +891,7 @@ func (ch *ClientHandler) processClientResponse( return nil, fmt.Errorf("invalid cluster type: %v", responseClusterType) } - newFrame = decodedFrame.Clone() + newFrame = decodedFrame.DeepCopy() newUnprepared := &message.Unprepared{ ErrorMessage: fmt.Sprintf("Prepared query with ID %s not found (either the query was not prepared "+ "on this host (maybe the host has been restarted?) or you have prepared too many queries and it has "+ @@ -945,7 +945,7 @@ func (ch *ClientHandler) processPreparedResponse( return nil, fmt.Errorf("replaced terms in the prepared statement but prepared result doesn't have variables metadata: %v", bodyMsg) } - newResponse = response.Clone() + newResponse = response.DeepCopy() newPreparedBody, ok := newResponse.Body.Message.(*message.PreparedResult) if !ok { return nil, fmt.Errorf("could not modify prepared result to remove generated parameters because "+ @@ -1655,7 +1655,7 @@ func (ch *ClientHandler) handleExecuteRequest( } replacementTimeUuids = ch.parameterModifier.generateTimeUuids(prepareRequestInfo) - newOriginRequest := clientRequest.Clone() + newOriginRequest := clientRequest.DeepCopy() _, err = ch.parameterModifier.AddValuesToExecuteFrame( newOriginRequest, prepareRequestInfo, preparedData.GetOriginVariablesMetadata(), replacementTimeUuids) if err != nil { @@ -1677,7 +1677,7 @@ func (ch *ClientHandler) handleExecuteRequest( return nil, nil, nil, fmt.Errorf("could not decode execute raw frame: %w", err) } - newTargetRequest := clientRequest.Clone() + newTargetRequest := clientRequest.DeepCopy() var newTargetExecuteMsg *message.Execute if len(replacedTerms) > 0 { if replacementTimeUuids == nil { @@ -1726,7 +1726,7 @@ func (ch *ClientHandler) handleBatchRequest( var newOriginRequest *frame.Frame var newOriginBatchMsg *message.Batch - newTargetRequest := decodedFrame.Clone() + newTargetRequest := decodedFrame.DeepCopy() newTargetBatchMsg, ok := newTargetRequest.Body.Message.(*message.Batch) if !ok { return nil, nil, fmt.Errorf("expected Batch but got %v instead", newTargetRequest.Body.Message.GetOpCode()) @@ -1736,7 +1736,7 @@ func (ch *ClientHandler) handleBatchRequest( prepareRequestInfo := preparedData.GetPrepareRequestInfo() if len(prepareRequestInfo.GetReplacedTerms()) > 0 { if newOriginRequest == nil { - newOriginRequest = decodedFrame.Clone() + newOriginRequest = decodedFrame.DeepCopy() newOriginBatchMsg, ok = newOriginRequest.Body.Message.(*message.Batch) if !ok { return nil, nil, fmt.Errorf("expected Batch but got %v instead", newOriginRequest.Body.Message.GetOpCode()) @@ -1754,8 +1754,8 @@ func (ch *ClientHandler) handleBatchRequest( } } - originalQueryId := newTargetBatchMsg.Children[stmtIdx].QueryOrId.([]byte) - newTargetBatchMsg.Children[stmtIdx].QueryOrId = preparedData.GetTargetPreparedId() + originalQueryId := newTargetBatchMsg.Children[stmtIdx].Id + newTargetBatchMsg.Children[stmtIdx].Id = preparedData.GetTargetPreparedId() log.Tracef("Replacing prepared ID %s within a BATCH with %s for target cluster.", hex.EncodeToString(originalQueryId), hex.EncodeToString(preparedData.GetTargetPreparedId())) } @@ -1803,7 +1803,7 @@ func (ch *ClientHandler) sendToAsyncConnector( } if sendAlsoToAsync { - asyncRequest = asyncRequest.Clone() // forwardToAsyncOnly requests don't need to be cloned because they are only sent to 1 connector + asyncRequest = asyncRequest.DeepCopy() // forwardToAsyncOnly requests don't need to be cloned because they are only sent to 1 connector } if isFireAndForget { diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index 394c8ae7..47fb09ce 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -378,7 +378,7 @@ func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoV cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err) return nil, err } - newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf, protoVer) + newConn := NewCqlConnection(endpoint, tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf, protoVer) err = newConn.InitializeContext(protoVer, ctx) var respErr *ResponseError if err != nil && errors.As(err, &respErr) && respErr.IsProtocolError() && strings.Contains(err.Error(), "Invalid or unsupported protocol version") { @@ -435,8 +435,7 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([] return nil, fmt.Errorf("could not fetch information from system.local table: %w", err) } - localInfo, localHost, err := ParseSystemLocalResult(localQueryResult, cc.defaultPort) - // localHost may be nil, if we did not find the address in system.local table + localInfo, localHost, err := ParseSystemLocalResult(localQueryResult, conn.GetEndpoint(), cc.defaultPort) if err != nil { return nil, err } diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index c413cb50..929ec46d 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -31,6 +31,7 @@ const ( ) type CqlConnection interface { + GetEndpoint() Endpoint IsInitialized() bool InitializeContext(version primitive.ProtocolVersion, ctx context.Context) error SendAndReceive(request *frame.Frame, ctx context.Context) (*frame.Frame, error) @@ -48,6 +49,7 @@ type CqlConnection interface { type cqlConn struct { readTimeout time.Duration writeTimeout time.Duration + endpoint Endpoint conn net.Conn credentials *AuthCredentials initialized bool @@ -71,12 +73,16 @@ var ( StreamIdMismatchErr = errors.New("stream id of the response is different from the stream id of the request") ) +func (c *cqlConn) GetEndpoint() Endpoint { + return c.endpoint +} + func (c *cqlConn) String() string { return fmt.Sprintf("cqlConn{conn: %v}", c.conn.RemoteAddr().String()) } func NewCqlConnection( - conn net.Conn, + endpoint Endpoint, conn net.Conn, username string, password string, readTimeout time.Duration, writeTimeout time.Duration, conf *config.Config, protoVer primitive.ProtocolVersion) CqlConnection { @@ -84,6 +90,7 @@ func NewCqlConnection( cqlConn := &cqlConn{ readTimeout: readTimeout, writeTimeout: writeTimeout, + endpoint: endpoint, conn: conn, credentials: &AuthCredentials{ Username: username, diff --git a/proxy/pkg/zdmproxy/cqlparser.go b/proxy/pkg/zdmproxy/cqlparser.go index 56446ba9..937afe73 100644 --- a/proxy/pkg/zdmproxy/cqlparser.go +++ b/proxy/pkg/zdmproxy/cqlparser.go @@ -115,15 +115,13 @@ func buildRequestInfo( } preparedDataByStmtIdxMap := make(map[int]PreparedData) for childIdx, child := range batchMsg.Children { - switch queryOrId := child.QueryOrId.(type) { - case []byte: - preparedData, err := getPreparedData(psCache, mh, queryOrId, primitive.OpCodeBatch, decodedFrame) + if child.Id != nil { + preparedData, err := getPreparedData(psCache, mh, child.Id, primitive.OpCodeBatch, decodedFrame) if err != nil { return nil, err } else { preparedDataByStmtIdxMap[childIdx] = preparedData } - default: } } return NewBatchRequestInfo(preparedDataByStmtIdxMap), nil @@ -352,11 +350,10 @@ func (recv *frameDecodeContext) inspectStatements(currentKeyspace string, timeUu currentKeyspace = typedMsg.Keyspace } for idx, childStmt := range typedMsg.Children { - switch typedQueryOrId := childStmt.QueryOrId.(type) { - case string: + if len(childStmt.Query) > 0 { statementsQueryData = append( statementsQueryData, &statementQueryData{ - statementIndex: idx, queryData: inspectCqlQuery(typedQueryOrId, currentKeyspace, timeUuidGenerator)}) + statementIndex: idx, queryData: inspectCqlQuery(childStmt.Query, currentKeyspace, timeUuidGenerator)}) } } default: diff --git a/proxy/pkg/zdmproxy/cqlparser_adv_workloads_utils_test.go b/proxy/pkg/zdmproxy/cqlparser_adv_workloads_utils_test.go index 4ae50b02..46cefec9 100644 --- a/proxy/pkg/zdmproxy/cqlparser_adv_workloads_utils_test.go +++ b/proxy/pkg/zdmproxy/cqlparser_adv_workloads_utils_test.go @@ -39,6 +39,8 @@ func getGeneralParamsForTests(t *testing.T) params { } func buildQueryMessageForTests(queryString string) *message.Query { + var defaultTimestamp int64 = 1647023221311969 + var serialConsistency = primitive.ConsistencyLevelLocalSerial return &message.Query{ Query: queryString, Options: &message.QueryOptions{ @@ -49,8 +51,8 @@ func buildQueryMessageForTests(queryString string) *message.Query { PageSize: 5000, PageSizeInBytes: false, PagingState: nil, - SerialConsistency: &primitive.NillableConsistencyLevel{Value: primitive.ConsistencyLevelLocalSerial}, - DefaultTimestamp: &primitive.NillableInt64{Value: 1647023221311969}, + SerialConsistency: &serialConsistency, + DefaultTimestamp: &defaultTimestamp, Keyspace: "", NowInSeconds: nil, ContinuousPagingOptions: &message.ContinuousPagingOptions{ diff --git a/proxy/pkg/zdmproxy/cqlparser_test.go b/proxy/pkg/zdmproxy/cqlparser_test.go index a211abba..198cf822 100644 --- a/proxy/pkg/zdmproxy/cqlparser_test.go +++ b/proxy/pkg/zdmproxy/cqlparser_test.go @@ -182,7 +182,15 @@ func mockExecuteFrame(t *testing.T, preparedId string) *frame.RawFrame { } func mockBatch(t *testing.T, query interface{}) *frame.RawFrame { - batchMsg := &message.Batch{Children: []*message.BatchChild{{QueryOrId: query}}} + var child message.BatchChild + switch query.(type) { + case []byte: + child = message.BatchChild{Id: query.([]byte)} + default: + child = message.BatchChild{Query: query.(string)} + + } + batchMsg := &message.Batch{Children: []*message.BatchChild{&child}} return mockFrame(t, batchMsg, primitive.ProtocolVersion4) } diff --git a/proxy/pkg/zdmproxy/frameprocessor.go b/proxy/pkg/zdmproxy/frameprocessor.go index ee392a61..2beb62b1 100644 --- a/proxy/pkg/zdmproxy/frameprocessor.go +++ b/proxy/pkg/zdmproxy/frameprocessor.go @@ -84,7 +84,7 @@ func setRawFrameStreamId(f *frame.RawFrame, id int16) *frame.RawFrame { if f.Header.StreamId == id { return f } - newHeader := f.Header.Clone() + newHeader := f.Header.DeepCopy() newHeader.StreamId = id return &frame.RawFrame{ Header: newHeader, @@ -98,7 +98,7 @@ func setFrameStreamId(f *frame.Frame, id int16) *frame.Frame { if f.Header.StreamId == id { return f } - newHeader := f.Header.Clone() + newHeader := f.Header.DeepCopy() newHeader.StreamId = id return &frame.Frame{ Header: newHeader, diff --git a/proxy/pkg/zdmproxy/host.go b/proxy/pkg/zdmproxy/host.go index b37e4753..bc2168ca 100644 --- a/proxy/pkg/zdmproxy/host.go +++ b/proxy/pkg/zdmproxy/host.go @@ -7,6 +7,8 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" "net" + "strconv" + "strings" ) type Host struct { @@ -48,7 +50,7 @@ func (recv *Host) String() string { hex.EncodeToString(recv.HostId[:])) } -func ParseSystemLocalResult(rs *ParsedRowSet, defaultPort int) (map[string]*optionalColumn, *Host, error) { +func ParseSystemLocalResult(rs *ParsedRowSet, ccEndpoint Endpoint, defaultPort int) (map[string]*optionalColumn, *Host, error) { if len(rs.Rows) < 1 { return nil, nil, fmt.Errorf("could not parse system local query result: query returned %d rows", len(rs.Rows)) } @@ -60,17 +62,17 @@ func ParseSystemLocalResult(rs *ParsedRowSet, defaultPort int) (map[string]*opti row := rs.Rows[0] addr, port, err := ParseRpcAddress(false, row, defaultPort) + if addr == nil { + // could not resolve address from system.local table (e.g. not present in C* 2.0.0) + addr, port, err = ParseEndpoint(ccEndpoint) + } if err != nil { return nil, nil, err } - var host *Host - if addr != nil { - // could not resolve address from system.local table (e.g. not present in C* 2.0.0) - host, err = parseHost(addr, port, row) - if err != nil { - return nil, nil, err - } + host, err := parseHost(addr, port, row) + if err != nil { + return nil, nil, err } sysLocalCols := map[string]*optionalColumn{ @@ -222,6 +224,20 @@ func ParseRpcAddress(isPeersV2 bool, row *ParsedRow, defaultPort int) (net.IP, i return addr, rpcPort, nil } +func ParseEndpoint(endpoint Endpoint) (net.IP, int, error) { + socketEndpoint := endpoint.GetSocketEndpoint() + parts := strings.Split(socketEndpoint, ":") + if len(parts) != 2 { + return nil, -1, fmt.Errorf("invalid endpoint: %s", socketEndpoint) + } + addr := parts[0] + port, err := strconv.Atoi(parts[1]) + if err != nil { + return nil, -1, fmt.Errorf("invalid endpoint: %s", socketEndpoint) + } + return net.ParseIP(addr), port, nil +} + func parseRpcPortPeersV2(row *ParsedRow) (int, bool) { val, ok := row.GetByColumn("native_port") if ok && val != nil { diff --git a/proxy/pkg/zdmproxy/nativeprotocol.go b/proxy/pkg/zdmproxy/nativeprotocol.go index 8f3515ed..98b0dfe1 100644 --- a/proxy/pkg/zdmproxy/nativeprotocol.go +++ b/proxy/pkg/zdmproxy/nativeprotocol.go @@ -262,10 +262,10 @@ var ( storagePortColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "storage_port", Type: datatype.Int} storagePortSslColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "storage_port_ssl", Type: datatype.Int} thriftVersionColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "thrift_version", Type: datatype.Varchar} - tokensColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "tokens", Type: datatype.NewSetType(datatype.Varchar)} - truncatedAtColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "truncated_at", Type: datatype.NewMapType(datatype.Uuid, datatype.Blob)} + tokensColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "tokens", Type: datatype.NewSet(datatype.Varchar)} + truncatedAtColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "truncated_at", Type: datatype.NewMap(datatype.Uuid, datatype.Blob)} workloadColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "workload", Type: datatype.Varchar} - workloadsColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "workloads", Type: datatype.NewSetType(datatype.Varchar)} + workloadsColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "workloads", Type: datatype.NewSet(datatype.Varchar)} ) var systemLocalColumns = []*message.ColumnMetadata{ @@ -367,7 +367,7 @@ func columnFromSelector( } // we are assuming here that resultColumn always refers to an unaliased column because the cql grammar doesn't support alias recursion - aliasedColumn := resultColumn.Clone() + aliasedColumn := resultColumn.DeepCopy() aliasedColumn.Name = s.alias return aliasedColumn, isCountSelector, nil default: @@ -605,9 +605,9 @@ var ( serverIdPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "server_id", Type: datatype.Varchar} storagePortPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "storage_port", Type: datatype.Int} storagePortSslPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "storage_port_ssl", Type: datatype.Int} - tokensPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "tokens", Type: datatype.NewSetType(datatype.Varchar)} + tokensPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "tokens", Type: datatype.NewSet(datatype.Varchar)} workloadPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "workload", Type: datatype.Varchar} - workloadsPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "workloads", Type: datatype.NewSetType(datatype.Varchar)} + workloadsPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "workloads", Type: datatype.NewSet(datatype.Varchar)} ) var systemPeersColumns = []*message.ColumnMetadata{ diff --git a/proxy/pkg/zdmproxy/parametermodifier_test.go b/proxy/pkg/zdmproxy/parametermodifier_test.go index dc0f99cf..0cb85ade 100644 --- a/proxy/pkg/zdmproxy/parametermodifier_test.go +++ b/proxy/pkg/zdmproxy/parametermodifier_test.go @@ -25,7 +25,7 @@ func TestAddValuesToExecuteFrame_NoReplacedTerms(t *testing.T) { PkIndices: nil, Columns: nil, } - fClone := f.Clone() + fClone := f.DeepCopy() replacementTimeUuids := parameterModifier.generateTimeUuids(prepareRequestInfo) newMsg, err := parameterModifier.AddValuesToExecuteFrame(fClone, prepareRequestInfo, variablesMetadata, replacementTimeUuids) require.Same(t, fClone.Body.Message, newMsg) @@ -198,7 +198,7 @@ func TestAddValuesToExecuteFrame_PositionalValues(t *testing.T) { require.Nil(t, err) parameterModifier := NewParameterModifier(generator) queryOpts := &message.QueryOptions{PositionalValues: requestPosVals} - clonedQueryOpts := queryOpts.Clone() // we use this so that we keep the "original" request options + clonedQueryOpts := queryOpts.DeepCopy() // we use this so that we keep the "original" request options f := frame.NewFrame(primitive.ProtocolVersion4, 1, &message.Execute{ QueryId: nil, ResultMetadataId: nil, @@ -344,7 +344,7 @@ func TestAddValuesToExecuteFrame_NamedValues(t *testing.T) { require.Nil(t, err) parameterModifier := NewParameterModifier(generator) queryOpts := &message.QueryOptions{NamedValues: requestNamedVals} - clonedQueryOpts := queryOpts.Clone() // we use this so that we keep the "original" request options + clonedQueryOpts := queryOpts.DeepCopy() // we use this so that we keep the "original" request options f := frame.NewFrame(primitive.ProtocolVersion4, 1, &message.Execute{ QueryId: nil, ResultMetadataId: nil, diff --git a/proxy/pkg/zdmproxy/querymodifier.go b/proxy/pkg/zdmproxy/querymodifier.go index 91eb9c05..d2a0dd25 100644 --- a/proxy/pkg/zdmproxy/querymodifier.go +++ b/proxy/pkg/zdmproxy/querymodifier.go @@ -89,7 +89,7 @@ func (recv *QueryModifier) replaceQueryInBatchMessage( return decodedFrame, []*statementReplacedTerms{}, statementsQueryData, nil } - newFrame := decodedFrame.Clone() + newFrame := decodedFrame.DeepCopy() newBatchMsg, ok := newFrame.Body.Message.(*message.Batch) if !ok { return nil, nil, nil, fmt.Errorf("expected Batch in cloned frame but got %v instead", newFrame.Body.Message.GetOpCode()) @@ -100,7 +100,7 @@ func (recv *QueryModifier) replaceQueryInBatchMessage( return nil, nil, nil, fmt.Errorf("new query data statement index (%v) is greater or equal than "+ "number of batch child statements (%v)", newStmtQueryData.statementIndex, len(newBatchMsg.Children)) } - newBatchMsg.Children[newStmtQueryData.statementIndex].QueryOrId = newStmtQueryData.queryData.getQuery() + newBatchMsg.Children[newStmtQueryData.statementIndex].Query = newStmtQueryData.queryData.getQuery() } return newFrame, statementsReplacedTerms, newStatementsQueryData, nil @@ -117,7 +117,7 @@ func (recv *QueryModifier) replaceQueryInQueryMessage( return decodedFrame, []*statementReplacedTerms{}, statementsQueryData, nil } newQueryData, replacedTerms := stmtQueryData.queryData.replaceNowFunctionCallsWithLiteral() - newFrame := decodedFrame.Clone() + newFrame := decodedFrame.DeepCopy() newQueryMsg, ok := newFrame.Body.Message.(*message.Query) if !ok { return nil, nil, nil, fmt.Errorf("expected Query in cloned frame but got %v instead", newFrame.Body.Message.GetOpCode()) @@ -143,7 +143,7 @@ func (recv *QueryModifier) replaceQueryInPrepareMessage( } else { newQueryData, replacedTerms = stmtQueryData.queryData.replaceNowFunctionCallsWithPositionalBindMarkers() } - newFrame := decodedFrame.Clone() + newFrame := decodedFrame.DeepCopy() newPrepareMsg, ok := newFrame.Body.Message.(*message.Prepare) if !ok { return nil, nil, nil, fmt.Errorf("expected Prepare in cloned frame but got %v instead", newFrame.Body.Message.GetOpCode()) diff --git a/proxy/pkg/zdmproxy/querymodifier_test.go b/proxy/pkg/zdmproxy/querymodifier_test.go index d7634f84..da315367 100644 --- a/proxy/pkg/zdmproxy/querymodifier_test.go +++ b/proxy/pkg/zdmproxy/querymodifier_test.go @@ -123,7 +123,7 @@ func TestReplaceQueryString(t *testing.T) { {"OpCodeBatch Mixed Prepared and Simple", mockBatchWithChildren(t, []*message.BatchChild{ { - QueryOrId: "UPDATE blah SET a = ?, b = 123 " + + Query: "UPDATE blah SET a = ?, b = 123 " + "WHERE f[now()] = ? IF " + "g[123] IN (2, 3, ?, now(), ?, now()) AND " + "d IN ? AND " + @@ -132,12 +132,12 @@ func TestReplaceQueryString(t *testing.T) { Values: []*primitive.Value{}, // not used by the SUT (system under test) }, { - QueryOrId: []byte{0}, - Values: []*primitive.Value{}, // not used by the SUT + Id: []byte{0}, + Values: []*primitive.Value{}, // not used by the SUT }, { - QueryOrId: "DELETE FROM blah WHERE b = 123 AND a = now()", - Values: []*primitive.Value{}, // not used by the SUT + Query: "DELETE FROM blah WHERE b = 123 AND a = now()", + Values: []*primitive.Value{}, // not used by the SUT }}), []*statementReplacedTerms{ {statementIndex: 0, replacedTerms: []*term{ From 1d353fafe3c3e5ca343a0a3581ce3972a4810548 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 26 Jun 2024 15:11:15 +0200 Subject: [PATCH 09/29] Cleanup --- proxy/pkg/zdmproxy/controlconn.go | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index 47fb09ce..1c3e9621 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -475,13 +475,11 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([] } } - if localHost != nil { - oldLocalhost, localHostExists := hostsById[localHost.HostId] - if localHostExists { - log.Warnf("Local host is also on the peers list: %v vs %v, ignoring the former one.", oldLocalhost, localHost) - } - hostsById[localHost.HostId] = localHost + oldLocalhost, localHostExists := hostsById[localHost.HostId] + if localHostExists { + log.Warnf("Local host is also on the peers list: %v vs %v, ignoring the former one.", oldLocalhost, localHost) } + hostsById[localHost.HostId] = localHost orderedLocalHosts := make([]*Host, 0, len(hostsById)) for _, h := range hostsById { orderedLocalHosts = append(orderedLocalHosts, h) @@ -491,11 +489,9 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([] currentDc := cc.datacenter cc.topologyLock.RUnlock() - if localHost != nil { - orderedLocalHosts, currentDc, err = filterHosts(orderedLocalHosts, currentDc, cc.connConfig, localHost) - if err != nil { - return nil, err - } + orderedLocalHosts, currentDc, err = filterHosts(orderedLocalHosts, currentDc, cc.connConfig, localHost) + if err != nil { + return nil, err } sort.Slice(orderedLocalHosts, func(i, j int) bool { From 081bec0d7d027a5e90ca5e6ae9d04da09ded0d81 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 27 Jun 2024 14:16:27 +0200 Subject: [PATCH 10/29] Protocol V2 stubbed tests --- integration-tests/customhandler_test_utils.go | 50 +++++++-- integration-tests/protocolv2_test.go | 102 ++++++++++++++++++ 2 files changed, 142 insertions(+), 10 deletions(-) create mode 100644 integration-tests/protocolv2_test.go diff --git a/integration-tests/customhandler_test_utils.go b/integration-tests/customhandler_test_utils.go index 97fdc7ae..37de98ef 100644 --- a/integration-tests/customhandler_test_utils.go +++ b/integration-tests/customhandler_test_utils.go @@ -78,6 +78,19 @@ var systemLocalColumns = []*message.ColumnMetadata{ tokensColumn, } +var systemLocalColumnsV2 = []*message.ColumnMetadata{ + keyColumn, + clusterNameColumn, + cqlVersionColumn, + datacenterColumn, + hostIdColumn, + partitionerColumn, + rackColumn, + releaseVersionColumn, + schemaVersionColumn, + tokensColumn, +} + var ( peerColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "peer", Type: datatype.Inet} datacenterPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "data_center", Type: datatype.Varchar} @@ -112,13 +125,15 @@ var ( schemaVersionValue = message.Column{0xC0, 0xD1, 0xD2, 0x1E, 0xBB, 0x01, 0x41, 0x96, 0x86, 0xDB, 0xBC, 0x31, 0x7B, 0xC1, 0x79, 0x6A} ) -func systemLocalRow(cluster string, datacenter string, customPartitioner string, addr net.Addr, version primitive.ProtocolVersion) message.Row { +func systemLocalRow(cluster string, datacenter string, customPartitioner string, addr *net.Addr, version primitive.ProtocolVersion) message.Row { addrBuf := &bytes.Buffer{} - inetAddr := addr.(*net.TCPAddr).IP - if inetAddr.To4() != nil { - addrBuf.Write(inetAddr.To4()) - } else { - addrBuf.Write(inetAddr) + if addr != nil { + inetAddr := (*addr).(*net.TCPAddr).IP + if inetAddr.To4() != nil { + addrBuf.Write(inetAddr.To4()) + } else { + addrBuf.Write(inetAddr) + } } // emulate {'-9223372036854775808'} (entire ring) tokensBuf := &bytes.Buffer{} @@ -135,25 +150,40 @@ func systemLocalRow(cluster string, datacenter string, customPartitioner string, if customPartitioner != "" { partitionerValue = message.Column(customPartitioner) } + if addrBuf.Len() > 0 { + return message.Row{ + keyValue, + addrBuf.Bytes(), + message.Column(cluster), + cqlVersionValue, + message.Column(datacenter), + hostIdValue, + addrBuf.Bytes(), + partitionerValue, + rackValue, + releaseVersionValue, + addrBuf.Bytes(), + schemaVersionValue, + tokensBuf.Bytes(), + } + } return message.Row{ keyValue, - addrBuf.Bytes(), message.Column(cluster), cqlVersionValue, message.Column(datacenter), hostIdValue, - addrBuf.Bytes(), partitionerValue, rackValue, releaseVersionValue, - addrBuf.Bytes(), schemaVersionValue, tokensBuf.Bytes(), } } func fullSystemLocal(cluster string, datacenter string, customPartitioner string, request *frame.Frame, conn *client.CqlServerConnection) *frame.Frame { - systemLocalRow := systemLocalRow(cluster, datacenter, customPartitioner, conn.LocalAddr(), request.Header.Version) + localAddress := conn.LocalAddr() + systemLocalRow := systemLocalRow(cluster, datacenter, customPartitioner, &localAddress, request.Header.Version) msg := &message.RowsResult{ Metadata: &message.RowsMetadata{ ColumnCount: int32(len(systemLocalColumns)), diff --git a/integration-tests/protocolv2_test.go b/integration-tests/protocolv2_test.go new file mode 100644 index 00000000..8cba7509 --- /dev/null +++ b/integration-tests/protocolv2_test.go @@ -0,0 +1,102 @@ +package integration_tests + +import ( + "github.com/datastax/go-cassandra-native-protocol/client" + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/stretchr/testify/require" + "net" + "testing" +) + +func TestProtocolV2Basic(t *testing.T) { + originAddress := "127.0.0.2" + targetAddress := "127.0.0.3" + + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf.ControlConnMaxProtocolVersion = "2" + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewProtocolV2RequestHandler("origin", "dc1", "127.0.0.4") + targetRequestHandler := NewProtocolV2RequestHandler("target", "dc1", "127.0.0.5") + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, primitive.ProtocolVersion2) + require.Nil(t, err) + + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + + require.Nil(t, err) +} + +type ProtocolV2RequestHandler struct { + cluster string + datacenter string + peerIP string +} + +func NewProtocolV2RequestHandler(cluster string, datacenter string, peerIP string) *ProtocolV2RequestHandler { + return &ProtocolV2RequestHandler{ + cluster: cluster, + datacenter: datacenter, + peerIP: peerIP, + } +} + +func (recv *ProtocolV2RequestHandler) HandleRequest( + request *frame.Frame, + conn *client.CqlServerConnection, + ctx client.RequestHandlerContext) (response *frame.Frame) { + switch request.Body.Message.GetOpCode() { + case primitive.OpCodeStartup: + case primitive.OpCodeRegister: + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) + case primitive.OpCodeQuery: + query := request.Body.Message.(*message.Query) + switch query.Query { + case "SELECT * FROM system.local": + // C* 2.0.0 does not store local endpoint details in system.local table + sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version) + sysLocMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumnsV2)), + Columns: systemLocalColumnsV2, + }, + Data: message.RowSet{sysLocRow}, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) + case "SELECT * FROM system.peers": + sysPeerRow := systemPeersRow( + recv.datacenter, + &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, + primitive.ProtocolVersion2, + ) + sysPeeMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: int32(len(systemPeersColumns)), + Columns: systemPeersColumns, + }, + Data: message.RowSet{sysPeerRow}, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg) + } + } + return nil +} From 65ce9a0243a63c8a674396ef7fb1375245470514 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 27 Jun 2024 16:13:58 +0200 Subject: [PATCH 11/29] Protocol V2 stubbed tests --- integration-tests/protocolv2_test.go | 85 +++++++++++++++++++++++++--- 1 file changed, 77 insertions(+), 8 deletions(-) diff --git a/integration-tests/protocolv2_test.go b/integration-tests/protocolv2_test.go index 8cba7509..65d1f460 100644 --- a/integration-tests/protocolv2_test.go +++ b/integration-tests/protocolv2_test.go @@ -1,7 +1,10 @@ package integration_tests import ( + "context" + "fmt" "github.com/datastax/go-cassandra-native-protocol/client" + "github.com/datastax/go-cassandra-native-protocol/datatype" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" @@ -11,13 +14,13 @@ import ( "testing" ) -func TestProtocolV2Basic(t *testing.T) { +func TestProtocolV2Connect(t *testing.T) { originAddress := "127.0.0.2" targetAddress := "127.0.0.3" serverConf := setup.NewTestConfig(originAddress, targetAddress) proxyConf := setup.NewTestConfig(originAddress, targetAddress) - proxyConf.ControlConnMaxProtocolVersion = "2" + proxyConf.ControlConnMaxProtocolVersion = "3" // simulate protocol downgrade to V2 testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) require.Nil(t, err) @@ -42,8 +45,51 @@ func TestProtocolV2Basic(t *testing.T) { if proxy != nil { defer proxy.Shutdown() } + require.Nil(t, err) +} + +func TestProtocolV2Query(t *testing.T) { + originAddress := "127.0.0.2" + targetAddress := "127.0.0.3" + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf.ControlConnMaxProtocolVersion = "2" + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewProtocolV2RequestHandler("origin", "dc1", "") + targetRequestHandler := NewProtocolV2RequestHandler("target", "dc1", "") + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, primitive.ProtocolVersion2) + require.Nil(t, err) + + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + require.Nil(t, err) + + cqlConn, err := testSetup.Client.CqlClient.Connect(context.Background()) + query := &message.Query{ + Query: "SELECT * FROM fakeks.faketb", + Options: &message.QueryOptions{Consistency: primitive.ConsistencyLevelOne}, + } + + response, err := cqlConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion2, 0, query)) + resultSet := response.Body.Message.(*message.RowsResult).Data + require.Equal(t, 1, len(resultSet)) } type ProtocolV2RequestHandler struct { @@ -66,6 +112,12 @@ func (recv *ProtocolV2RequestHandler) HandleRequest( ctx client.RequestHandlerContext) (response *frame.Frame) { switch request.Body.Message.GetOpCode() { case primitive.OpCodeStartup: + if request.Header.Version != primitive.ProtocolVersion2 { + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.ProtocolError{ + ErrorMessage: fmt.Sprintf("Invalid or unsupported protocol version (%d)", request.Header.Version), + }) + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) case primitive.OpCodeRegister: return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) case primitive.OpCodeQuery: @@ -83,19 +135,36 @@ func (recv *ProtocolV2RequestHandler) HandleRequest( } return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) case "SELECT * FROM system.peers": - sysPeerRow := systemPeersRow( - recv.datacenter, - &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, - primitive.ProtocolVersion2, - ) + var sysPeerRows message.RowSet + if len(recv.peerIP) > 0 { + sysPeerRows = append(sysPeerRows, systemPeersRow( + recv.datacenter, + &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, + primitive.ProtocolVersion2, + )) + } sysPeeMsg := &message.RowsResult{ Metadata: &message.RowsMetadata{ ColumnCount: int32(len(systemPeersColumns)), Columns: systemPeersColumns, }, - Data: message.RowSet{sysPeerRow}, + Data: sysPeerRows, } return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg) + case "SELECT * FROM fakeks.faketb": + sysLocMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: 2, + Columns: []*message.ColumnMetadata{ + {Keyspace: "fakeks", Table: "faketb", Name: "key", Type: datatype.Varchar}, + {Keyspace: "fakeks", Table: "faketb", Name: "value", Type: datatype.Uuid}, + }, + }, + Data: message.RowSet{ + message.Row{keyValue, hostIdValue}, + }, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) } } return nil From d8ba2db708e3f658c47a91dd17fcb0ce0e64a2de Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 27 Jun 2024 16:23:56 +0200 Subject: [PATCH 12/29] Update README --- README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3a243d28..a1512e37 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ containerized sandbox environment. ## Supported Protocol Versions -**ZDM Proxy supports protocol versions v3, v4, DSE_V1 and DSE_V2.** +**ZDM Proxy supports protocol versions v2, v3, v4, DSE_V1 and DSE_V2.** It technically doesn't support v5, but handles protocol negotiation so that the client application properly downgrades the protocol version to v4 if v5 is requested. This means that any client application using a recent driver that supports @@ -96,8 +96,7 @@ migration process. In practice this means that ZDM Proxy supports the following cluster versions (as Origin and / or Target): -- Apache Cassandra from 2.1+ up to (and including) Apache Cassandra 4.x. Apache Cassandra 2.0 support will be introduced -when protocol version v2 is supported. +- Apache Cassandra from 2.0+ up to (and including) Apache Cassandra 4.x. - DataStax Enterprise 4.8+. DataStax Enterprise 4.6 and 4.7 support will be introduced when protocol version v2 is supported. - DataStax Astra DB (both Serverless and Classic) From 58680ad6d7af83e5504f4678cd0b341a92f021ec Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Fri, 28 Jun 2024 07:53:33 +0200 Subject: [PATCH 13/29] Tidy dependencies --- go.sum | 55 +------------------------------------------------------ 1 file changed, 1 insertion(+), 54 deletions(-) diff --git a/go.sum b/go.sum index 8d3e4fdf..ff73502b 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,6 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20211106181442-e4c1a74c66bd h1:fjJY1LimH0wVCvOHLX35SCX/MbWomAglET1H2kvz7xc= github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20211106181442-e4c1a74c66bd/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY= -github.com/antlr/antlr4/runtime/Go/antlr v1.4.10 h1:yL7+Jz0jTC6yykIK/Wh74gnTJnrGr5AyrNMXuA0gves= -github.com/antlr/antlr4/runtime/Go/antlr v1.4.10/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -16,12 +14,7 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4Yn github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/datastax/go-cassandra-native-protocol v0.0.0-20220525125956-6158d9e218b8 h1:NKLtNzC76ssf68VOenDAzMyQGg+QkxuD2QCubX+GvLk= -github.com/datastax/go-cassandra-native-protocol v0.0.0-20220525125956-6158d9e218b8/go.mod h1:yFD0OKoVV9d1QW7Es58c1Gv6ijrqTGPcxgHv27wdC4Q= github.com/datastax/go-cassandra-native-protocol v0.0.0-20240626123646-2abea740da8d h1:UnPtAA8Ux3GvHLazSSUydERFuoQRyxHrB8puzXyjXIE= github.com/datastax/go-cassandra-native-protocol v0.0.0-20240626123646-2abea740da8d/go.mod h1:6FzirJfdffakAVqmHjwVfFkpru/gNbIazUOK5rIhndc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -34,31 +27,20 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e h1:SroDcndcOU9BVAduPf/PXihXoR2ZYTQYLXbupbqxAyQ= github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY= -github.com/gocql/gocql v1.6.0 h1:IdFdOTbnpbd0pDhl4REKQDM+Q0SzKXQ1Yh+YZZ8T/qU= -github.com/gocql/gocql v1.6.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= -github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= @@ -75,17 +57,9 @@ github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFB github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -95,59 +69,41 @@ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3Rllmb github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/pierrec/lz4/v4 v4.0.3 h1:vNQKSVZNYUEAvRY9FaUXAF1XPbSOHJtDTiP41kzDz2E= github.com/pierrec/lz4/v4 v4.0.3/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.3.0 h1:miYCvYqFXtl/J9FIy8eNpBfYthAEFg+Ys0XyUVEcDsc= github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= -github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= -github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.1.0 h1:ElTg5tNp4DqfV7UQjDqv2+RJlNzsDtvNAWccbItceIE= github.com/prometheus/client_model v0.1.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= -github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.7.0 h1:L+1lyG48J1zAQXA3RBX/nG/B3gjlHq0zTt2tlbJLyCY= github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= -github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.8 h1:+fpWZdT24pJBiqJdAwYBjPSk+5YmQzYNPYzQsdzLkt8= github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= -github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.20.0 h1:38k9hgtUBdxFwE34yS8rTHmHBa4eN16E4DJlv177LNs= github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= -github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= -github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -162,24 +118,15 @@ golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From 7247e6446bbe697deedb393960e0a4916d0d4cc5 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Fri, 28 Jun 2024 08:37:20 +0200 Subject: [PATCH 14/29] Apply review comments --- integration-tests/customhandler_test_utils.go | 14 +++++---- integration-tests/protocolv2_test.go | 4 +-- proxy/pkg/config/config.go | 19 ++++++------ proxy/pkg/config/config_test.go | 29 ++++++++++++------- proxy/pkg/zdmproxy/controlconn.go | 10 +++---- proxy/pkg/zdmproxy/cqlconn.go | 13 ++------- 6 files changed, 45 insertions(+), 44 deletions(-) diff --git a/integration-tests/customhandler_test_utils.go b/integration-tests/customhandler_test_utils.go index 37de98ef..7101ca27 100644 --- a/integration-tests/customhandler_test_utils.go +++ b/integration-tests/customhandler_test_utils.go @@ -78,7 +78,10 @@ var systemLocalColumns = []*message.ColumnMetadata{ tokensColumn, } -var systemLocalColumnsV2 = []*message.ColumnMetadata{ +// These columns are a subset of the total columns returned by OSS C* 2.0.0, and contain all the information that +// drivers need in order to establish the cluster topology and determine its characteristics. Please note that RPC address +// column is not present. +var systemLocalColumnsProtocolV2 = []*message.ColumnMetadata{ keyColumn, clusterNameColumn, cqlVersionColumn, @@ -125,10 +128,10 @@ var ( schemaVersionValue = message.Column{0xC0, 0xD1, 0xD2, 0x1E, 0xBB, 0x01, 0x41, 0x96, 0x86, 0xDB, 0xBC, 0x31, 0x7B, 0xC1, 0x79, 0x6A} ) -func systemLocalRow(cluster string, datacenter string, customPartitioner string, addr *net.Addr, version primitive.ProtocolVersion) message.Row { +func systemLocalRow(cluster string, datacenter string, customPartitioner string, addr net.Addr, version primitive.ProtocolVersion) message.Row { addrBuf := &bytes.Buffer{} if addr != nil { - inetAddr := (*addr).(*net.TCPAddr).IP + inetAddr := addr.(*net.TCPAddr).IP if inetAddr.To4() != nil { addrBuf.Write(inetAddr.To4()) } else { @@ -150,7 +153,7 @@ func systemLocalRow(cluster string, datacenter string, customPartitioner string, if customPartitioner != "" { partitionerValue = message.Column(customPartitioner) } - if addrBuf.Len() > 0 { + if version >= primitive.ProtocolVersion3 { return message.Row{ keyValue, addrBuf.Bytes(), @@ -182,8 +185,7 @@ func systemLocalRow(cluster string, datacenter string, customPartitioner string, } func fullSystemLocal(cluster string, datacenter string, customPartitioner string, request *frame.Frame, conn *client.CqlServerConnection) *frame.Frame { - localAddress := conn.LocalAddr() - systemLocalRow := systemLocalRow(cluster, datacenter, customPartitioner, &localAddress, request.Header.Version) + systemLocalRow := systemLocalRow(cluster, datacenter, customPartitioner, conn.LocalAddr(), request.Header.Version) msg := &message.RowsResult{ Metadata: &message.RowsMetadata{ ColumnCount: int32(len(systemLocalColumns)), diff --git a/integration-tests/protocolv2_test.go b/integration-tests/protocolv2_test.go index 65d1f460..5b0f57f9 100644 --- a/integration-tests/protocolv2_test.go +++ b/integration-tests/protocolv2_test.go @@ -128,8 +128,8 @@ func (recv *ProtocolV2RequestHandler) HandleRequest( sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version) sysLocMsg := &message.RowsResult{ Metadata: &message.RowsMetadata{ - ColumnCount: int32(len(systemLocalColumnsV2)), - Columns: systemLocalColumnsV2, + ColumnCount: int32(len(systemLocalColumnsProtocolV2)), + Columns: systemLocalColumnsProtocolV2, }, Data: message.RowSet{sysLocRow}, } diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index a249cbf8..9d602708 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -3,6 +3,7 @@ package config import ( "encoding/json" "fmt" + "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/datastax/zdm-proxy/proxy/pkg/common" "github.com/kelseyhightower/envconfig" log "github.com/sirupsen/logrus" @@ -342,22 +343,22 @@ func (c *Config) ParseReadMode() (common.ReadMode, error) { } } -func (c *Config) ParseControlConnMaxProtocolVersion() (uint, error) { - switch c.ControlConnMaxProtocolVersion { - case "Dse2": - return 0b_1_000010, nil - case "Dse1": - return 0b_1_000001, nil +func (c *Config) ParseControlConnMaxProtocolVersion() (primitive.ProtocolVersion, error) { + if strings.EqualFold(c.ControlConnMaxProtocolVersion, "DseV2") { + return primitive.ProtocolVersionDse2, nil + } + if strings.EqualFold(c.ControlConnMaxProtocolVersion, "DseV1") { + return primitive.ProtocolVersionDse1, nil } ver, err := strconv.ParseUint(c.ControlConnMaxProtocolVersion, 10, 32) if err != nil { return 0, fmt.Errorf("could not parse control connection max protocol version, valid values are "+ - "2, 3, 4, Dse1, Dse2; original err: %w", err) + "2, 3, 4, DseV1, DseV2; original err: %w", err) } if ver < 2 || ver > 4 { - return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2") + return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2") } - return uint(ver), nil + return primitive.ProtocolVersion(ver), nil } func (c *Config) ParseLogLevel() (log.Level, error) { diff --git a/proxy/pkg/config/config_test.go b/proxy/pkg/config/config_test.go index 6da7c431..cea6ce48 100644 --- a/proxy/pkg/config/config_test.go +++ b/proxy/pkg/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/stretchr/testify/require" "testing" ) @@ -111,56 +112,62 @@ func TestTargetConfig_ParsingControlConnMaxProtocolVersion(t *testing.T) { tests := []struct { name string controlConnMaxProtocolVersion string - parsedProtocolVersion uint + parsedProtocolVersion primitive.ProtocolVersion errorMessage string }{ { name: "ParsedV2", controlConnMaxProtocolVersion: "2", - parsedProtocolVersion: 2, + parsedProtocolVersion: primitive.ProtocolVersion2, errorMessage: "", }, { name: "ParsedV3", controlConnMaxProtocolVersion: "3", - parsedProtocolVersion: 3, + parsedProtocolVersion: primitive.ProtocolVersion3, errorMessage: "", }, { name: "ParsedV4", controlConnMaxProtocolVersion: "4", - parsedProtocolVersion: 4, + parsedProtocolVersion: primitive.ProtocolVersion4, errorMessage: "", }, { name: "ParsedDse1", - controlConnMaxProtocolVersion: "Dse1", - parsedProtocolVersion: 65, + controlConnMaxProtocolVersion: "DseV1", + parsedProtocolVersion: primitive.ProtocolVersionDse1, errorMessage: "", }, { name: "ParsedDse2", - controlConnMaxProtocolVersion: "Dse2", - parsedProtocolVersion: 66, + controlConnMaxProtocolVersion: "DseV2", + parsedProtocolVersion: primitive.ProtocolVersionDse2, + errorMessage: "", + }, + { + name: "ParsedDse2CaseInsensitive", + controlConnMaxProtocolVersion: "dsev2", + parsedProtocolVersion: primitive.ProtocolVersionDse2, errorMessage: "", }, { name: "UnsupportedCassandraV5", controlConnMaxProtocolVersion: "5", parsedProtocolVersion: 0, - errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", }, { name: "UnsupportedCassandraV1", controlConnMaxProtocolVersion: "1", parsedProtocolVersion: 0, - errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", }, { name: "InvalidValue", controlConnMaxProtocolVersion: "Dsev123", parsedProtocolVersion: 0, - errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", }, } diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index 1c3e9621..4a2bc45c 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -361,16 +361,16 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( } conn = newConn - log.Infof("Successfully opened control connection to %v using endpoint %v.", - cc.connConfig.GetClusterType(), endpoint.String()) + log.Infof("Successfully opened control connection to %v using endpoint %v with %v.", + cc.connConfig.GetClusterType(), endpoint.String(), newConn.GetProtocolVersion().Load().(primitive.ProtocolVersion)) break } return conn, endpoint } -func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoVer uint, ctx context.Context) (CqlConnection, error) { - protoVer := primitive.ProtocolVersion(initialProtoVer) +func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoVer primitive.ProtocolVersion, ctx context.Context) (CqlConnection, error) { + protoVer := initialProtoVer for { tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false) if err != nil { @@ -389,7 +389,7 @@ func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoV log.Errorf("Failed to close cql connection: %v", err2) } protoVer = downgradeProtocol(protoVer) - log.Infof("Downgrading protocol version: %v", protoVer) + log.Debugf("Downgrading protocol version: %v", protoVer) if protoVer == 0 { // we cannot downgrade anymore return nil, err diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index 929ec46d..1c26651d 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -19,8 +19,7 @@ import ( ) const ( - eventQueueLength = 2048 - eventQueueLengthV2 = 128 + eventQueueLength = 2048 maxIncomingPending = 2048 maxIncomingPendingV2 = 128 @@ -102,7 +101,7 @@ func NewCqlConnection( wg: &sync.WaitGroup{}, // protoVer is the proposed protocol version using which we will try to establish connectivity outgoingCh: make(chan *frame.Frame, maxOutgoingPendingRequests(protoVer)), - eventsQueue: make(chan *frame.Frame, maxEventsQueue(protoVer)), + eventsQueue: make(chan *frame.Frame, eventQueueLength), pendingOperations: make(map[int16]chan *frame.Frame), pendingOperationsLock: &sync.RWMutex{}, timedOutOperations: 0, @@ -126,14 +125,6 @@ func maxOutgoingPendingRequests(protocolVersion primitive.ProtocolVersion) int { return maxOutgoingPending } -func maxEventsQueue(protocolVersion primitive.ProtocolVersion) int { - switch protocolVersion { - case primitive.ProtocolVersion2: - return eventQueueLengthV2 - } - return eventQueueLength -} - func (c *cqlConn) SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) { c.eventHandlerLock.Lock() defer c.eventHandlerLock.Unlock() From 5efe8c2d14ed3636d0540ef71aacdaca034e3181 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Fri, 28 Jun 2024 09:28:48 +0200 Subject: [PATCH 15/29] Apply review comments --- integration-tests/prepared_statements_test.go | 52 ++++++++++++++++--- 1 file changed, 46 insertions(+), 6 deletions(-) diff --git a/integration-tests/prepared_statements_test.go b/integration-tests/prepared_statements_test.go index 6e62a5e6..53a50294 100644 --- a/integration-tests/prepared_statements_test.go +++ b/integration-tests/prepared_statements_test.go @@ -524,16 +524,16 @@ func TestPreparedIdReplacement(t *testing.T) { require.Equal(t, expectedBatchPrepareMsg, originPrepareMessages[1]) if test.expectedBatchPreparedStmtVariables != nil { - require.NotEqual(t, batchMsg.Children[0].Query, originBatchMessages[0].Children[0].Query) - require.NotEqual(t, batchMsg.Children[0].Query, targetBatchMessages[0].Children[0].Query) - require.Equal(t, originBatchMessages[0].Children[0].Query, targetBatchMessages[0].Children[0].Query) + batchChildNotEqual(t, batchMsg.Children[0], originBatchMessages[0].Children[0]) + batchChildNotEqual(t, batchMsg.Children[0], targetBatchMessages[0].Children[0]) + batchChildEqual(t, originBatchMessages[0].Children[0], targetBatchMessages[0].Children[0]) require.Equal(t, 0, len(targetBatchMessages[0].Children[0].Values)) require.Equal(t, 0, len(originBatchMessages[0].Children[0].Values)) require.Equal(t, 0, len(batchMsg.Children[0].Values)) - require.Equal(t, batchMsg.Children[1].Query, originBatchMessages[0].Children[1].Query) - require.NotEqual(t, batchMsg.Children[1].Id, targetBatchMessages[0].Children[1].Id) - require.NotEqual(t, originBatchMessages[0].Children[1].Id, targetBatchMessages[0].Children[1].Id) + batchChildEqual(t, batchMsg.Children[1], originBatchMessages[0].Children[1]) + batchChildNotEqual(t, batchMsg.Children[1], targetBatchMessages[0].Children[1]) + batchChildNotEqual(t, originBatchMessages[0].Children[1], targetBatchMessages[0].Children[1]) require.Equal(t, targetBatchPreparedId, targetBatchMessages[0].Children[1].Id) require.Equal(t, originBatchPreparedId, originBatchMessages[0].Children[1].Id) require.Equal(t, originBatchPreparedId, batchMsg.Children[1].Id) @@ -555,6 +555,46 @@ func TestPreparedIdReplacement(t *testing.T) { } } +func batchChildEqual(t *testing.T, child1 *message.BatchChild, child2 *message.BatchChild) { + id := false + if child1.Id != nil && child2.Id != nil { + id = true + require.Equal(t, child1.Id, child2.Id) + } else if child1.Id != nil || child2.Id != nil { + require.Fail(t, "unexpected id field presence: [%v], [%v]", child1.Id, child2.Id) + } + + query := false + if len(child1.Query) > 0 && len(child2.Query) > 0 { + query = true + require.Equal(t, child1.Query, child2.Query) + } else if len(child1.Query) > 0 || len(child2.Query) > 0 { + require.Fail(t, "unexpected query field presence: [%v], [%v]", child1.Query, child2.Query) + } + + require.True(t, id || query, "id or query fields should be present") +} + +func batchChildNotEqual(t *testing.T, child1 *message.BatchChild, child2 *message.BatchChild) { + id := false + if child1.Id != nil && child2.Id != nil { + id = true + require.NotEqual(t, child1.Id, child2.Id) + } else if child1.Id != nil || child2.Id != nil { + require.Fail(t, "unexpected query field presence: [%v], [%v]", child1.Id, child2.Id) + } + + query := false + if len(child1.Query) > 0 && len(child2.Query) > 0 { + query = true + require.NotEqual(t, child1.Query, child2.Query) + } else if len(child1.Query) > 0 || len(child2.Query) > 0 { + require.Fail(t, "unexpected query field presence: [%v], [%v]", child1.Query, child2.Query) + } + + require.True(t, id || query, "id or query fields should be present") +} + func TestUnpreparedIdReplacement(t *testing.T) { type test struct { name string From f8e95295bac857f045b30ef22adcfcb1a5824ce1 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 11 Jul 2024 10:48:51 +0200 Subject: [PATCH 16/29] Limit number of maximum stream IDs --- integration-tests/connect_test.go | 47 +++++++++++++++--------- integration-tests/utils/testutils.go | 9 +---- proxy/pkg/config/config.go | 2 +- proxy/pkg/zdmproxy/clienthandler.go | 36 +++++++++++++++--- proxy/pkg/zdmproxy/clienthandler_test.go | 34 +++++++++++++++++ proxy/pkg/zdmproxy/cqlconn.go | 24 ++++-------- 6 files changed, 105 insertions(+), 47 deletions(-) create mode 100644 proxy/pkg/zdmproxy/clienthandler_test.go diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index a066eed8..b5d6e217 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -3,7 +3,7 @@ package integration_tests import ( "bytes" "context" - client2 "github.com/datastax/go-cassandra-native-protocol/client" + cqlClient "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" @@ -82,27 +82,40 @@ func TestProtocolVersionNegotiation(t *testing.T) { require.Nil(t, err) defer testSetup.Cleanup() - // Connect to proxy as a "client" - proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3) + query := "SELECT * FROM test" + expectedRows := simulacron.NewRowsResult( + map[string]simulacron.DataType{ + "company": simulacron.DataTypeText, + }).WithRow(map[string]interface{}{ + "company": "TBD", + }) - if err != nil { - t.Fatal("Unable to connect to proxy session.") - } - defer proxy.Close() + err = testSetup.Origin.Prime(simulacron.WhenQuery( + query, + simulacron.NewWhenQueryOptions()). + ThenRowsSuccess(expectedRows)) + require.Nil(t, err) + + // Connect to proxy as a "client" + client := cqlClient.NewCqlClient("127.0.0.1:14002", nil) + cqlClientConn, err := client.ConnectAndInit(context.Background(), tt.negotiatedProtocolVersion, 0) + require.Nil(t, err) + defer cqlClientConn.Close() cqlConn, _ := testSetup.Proxy.GetOriginControlConn().GetConnAndContactPoint() negotiatedProto := cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion) - require.Equal(t, tt.negotiatedProtocolVersion, negotiatedProto) - iter := proxy.Query("SELECT * FROM fakeks.faketb").Iter() - result, err := iter.SliceMap() - + queryMsg := &message.Query{ + Query: "SELECT * FROM test", + Options: &message.QueryOptions{Consistency: primitive.ConsistencyLevelOne}, + } + rsp, err := cqlClientConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion3, 0, queryMsg)) if err != nil { t.Fatal("query failed:", err) } - require.Equal(t, 0, len(result)) + require.Equal(t, 1, len(rsp.Body.Message.(*message.RowsResult).Data)) }) } } @@ -179,8 +192,8 @@ func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) { require.Nil(t, err) defer testSetup.Cleanup() - testSetup.Origin.CqlServer.RequestHandlers = []client2.RequestHandler{client2.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {})} - testSetup.Target.CqlServer.RequestHandlers = []client2.RequestHandler{client2.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {})} + testSetup.Origin.CqlServer.RequestHandlers = []cqlClient.RequestHandler{cqlClient.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {})} + testSetup.Target.CqlServer.RequestHandlers = []cqlClient.RequestHandler{cqlClient.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {})} err = testSetup.Start(cfg, false, primitive.ProtocolVersion3) require.Nil(t, err) @@ -234,7 +247,7 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { enableHandlers := atomic.Value{} enableHandlers.Store(false) - rawHandler := func(request *frame.Frame, conn *client2.CqlServerConnection, ctx client2.RequestHandlerContext) (response []byte) { + rawHandler := func(request *frame.Frame, conn *cqlClient.CqlServerConnection, ctx cqlClient.RequestHandlerContext) (response []byte) { if enableHandlers.Load().(bool) && request.Header.Version == test.requestVersion { encodedFrame, err := createFrameWithUnsupportedVersion(test.returnedVersion, request.Header.StreamId, true) if err != nil { @@ -246,8 +259,8 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { return nil } - testSetup.Origin.CqlServer.RequestRawHandlers = []client2.RawRequestHandler{rawHandler} - testSetup.Target.CqlServer.RequestRawHandlers = []client2.RawRequestHandler{rawHandler} + testSetup.Origin.CqlServer.RequestRawHandlers = []cqlClient.RawRequestHandler{rawHandler} + testSetup.Target.CqlServer.RequestRawHandlers = []cqlClient.RawRequestHandler{rawHandler} err = testSetup.Start(cfg, false, primitive.ProtocolVersion4) require.Nil(t, err) diff --git a/integration-tests/utils/testutils.go b/integration-tests/utils/testutils.go index e0ca5edd..2c050ecd 100644 --- a/integration-tests/utils/testutils.go +++ b/integration-tests/utils/testutils.go @@ -116,9 +116,9 @@ func CheckMetricsEndpointResult(httpAddr string, success bool) error { return nil } -func ConnectToClusterUsingVersion(hostname string, username string, password string, port int, protoVersion int) (*gocql.Session, error) { +// ConnectToCluster is used to connect to source and destination clusters +func ConnectToCluster(hostname string, username string, password string, port int) (*gocql.Session, error) { cluster := NewCluster(hostname, username, password, port) - cluster.ProtoVersion = protoVersion session, err := cluster.CreateSession() log.Debugf("Connection established with Cluster: %s:%d", cluster.Hosts[0], cluster.Port) if err != nil { @@ -127,11 +127,6 @@ func ConnectToClusterUsingVersion(hostname string, username string, password str return session, nil } -// ConnectToCluster is used to connect to source and destination clusters -func ConnectToCluster(hostname string, username string, password string, port int) (*gocql.Session, error) { - return ConnectToClusterUsingVersion(hostname, username, password, port, 4) -} - // NewCluster initializes a ClusterConfig object with common settings func NewCluster(hostname string, username string, password string, port int) *gocql.ClusterConfig { cluster := gocql.NewCluster(hostname) diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index 9d602708..02f9f11d 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -22,7 +22,7 @@ type Config struct { ReplaceCqlFunctions bool `default:"false" split_words:"true"` AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"` LogLevel string `default:"DEBUG" split_words:"true"` - ControlConnMaxProtocolVersion string `default:"3" split_words:"true"` // Numeric Cassandra OSS protocol version or Dse1 / Dse2 + ControlConnMaxProtocolVersion string `default:"3" split_words:"true"` // Numeric Cassandra OSS protocol version or DseV1 / DseV2 // Proxy Topology (also known as system.peers "virtualization") bucket diff --git a/proxy/pkg/zdmproxy/clienthandler.go b/proxy/pkg/zdmproxy/clienthandler.go index cedf7c83..30b59bc1 100644 --- a/proxy/pkg/zdmproxy/clienthandler.go +++ b/proxy/pkg/zdmproxy/clienthandler.go @@ -168,9 +168,12 @@ func NewClientHandler( requestsDoneCtx, requestsDoneCancelFn := context.WithCancel(context.Background()) // Initialize stream id processors to manage the ids sent to the clusters - originFrameProcessor := newFrameProcessor(conf, nodeMetrics, ClusterConnectorTypeOrigin) - targetFrameProcessor := newFrameProcessor(conf, nodeMetrics, ClusterConnectorTypeTarget) - asyncFrameProcessor := newFrameProcessor(conf, nodeMetrics, ClusterConnectorTypeAsync) + originCCProtoVer := originControlConn.cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion) + targetCCProtoVer := targetControlConn.cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion) + streamIds := maxStreamIds(originCCProtoVer, targetCCProtoVer, conf) + originFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeOrigin) + targetFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeTarget) + asyncFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeAsync) closeFrameProcessors := func() { originFrameProcessor.Close() @@ -2230,7 +2233,7 @@ func GetNodeMetricsByClusterConnector(nodeMetrics *metrics.NodeMetrics, connecto } } -func newFrameProcessor(conf *config.Config, nodeMetrics *metrics.NodeMetrics, connectorType ClusterConnectorType) FrameProcessor { +func newFrameProcessor(maxStreamIds int, nodeMetrics *metrics.NodeMetrics, connectorType ClusterConnectorType) FrameProcessor { var streamIdsMetric metrics.Gauge connectorMetrics, err := GetNodeMetricsByClusterConnector(nodeMetrics, connectorType) if err != nil { @@ -2241,9 +2244,30 @@ func newFrameProcessor(conf *config.Config, nodeMetrics *metrics.NodeMetrics, co } var mapper StreamIdMapper if connectorType == ClusterConnectorTypeAsync { - mapper = NewInternalStreamIdMapper(conf.ProxyMaxStreamIds, streamIdsMetric) + mapper = NewInternalStreamIdMapper(maxStreamIds, streamIdsMetric) } else { - mapper = NewStreamIdMapper(conf.ProxyMaxStreamIds, streamIdsMetric) + mapper = NewStreamIdMapper(maxStreamIds, streamIdsMetric) } return NewStreamIdProcessor(mapper) } + +// Calculate maximum number of stream IDs. Take the oldest protocol version negotiated between two clusters +// and apply limit defined in proxy configuration. If origin or target cluster are still running protocol V2, +// we will limit maximum number of stream IDs to 128 on both clusters. Logic is based on Java driver version 3.x. +// Java driver 3.x was the last one supporting protocol V2. It establishes control connection first, and then +// uses negotiated protocol version to configure maximum number of stream IDs on node connections. Driver does NOT +// change the number of stream IDs on per node basis. +func maxStreamIds(originProtoVer primitive.ProtocolVersion, targetProtoVer primitive.ProtocolVersion, conf *config.Config) int { + maxSupported := maxOutgoingPending + protoVer := originProtoVer + if targetProtoVer < originProtoVer { + protoVer = targetProtoVer + } + if protoVer == primitive.ProtocolVersion2 { + maxSupported = maxOutgoingPendingV2 + } + if maxSupported < conf.ProxyMaxStreamIds { + return maxSupported + } + return conf.ProxyMaxStreamIds +} diff --git a/proxy/pkg/zdmproxy/clienthandler_test.go b/proxy/pkg/zdmproxy/clienthandler_test.go new file mode 100644 index 00000000..f24c2817 --- /dev/null +++ b/proxy/pkg/zdmproxy/clienthandler_test.go @@ -0,0 +1,34 @@ +package zdmproxy + +import ( + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/proxy/pkg/config" + "github.com/stretchr/testify/require" + "testing" +) + +func TestMaxStreamIds(t *testing.T) { + type args struct { + originProtoVer primitive.ProtocolVersion + targetProtoVer primitive.ProtocolVersion + config *config.Config + expectedMaxStreamIds int + } + tests := []struct { + name string + args args + expectedMaxStreamIds int + }{ + {"OriginV3_TargetV3_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion3, targetProtoVer: primitive.ProtocolVersion3, config: &config.Config{ProxyMaxStreamIds: 2048}}, 2048}, + {"OriginV3_TargetV4_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion3, targetProtoVer: primitive.ProtocolVersion4, config: &config.Config{ProxyMaxStreamIds: 2048}}, 2048}, + {"OriginV3_TargetV4_LowerConfig", args{originProtoVer: primitive.ProtocolVersion3, targetProtoVer: primitive.ProtocolVersion4, config: &config.Config{ProxyMaxStreamIds: 1024}}, 1024}, + {"OriginV2_TargetV3_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion2, targetProtoVer: primitive.ProtocolVersion3, config: &config.Config{ProxyMaxStreamIds: 2048}}, 128}, + {"OriginV2_TargetV2_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion2, targetProtoVer: primitive.ProtocolVersion2, config: &config.Config{ProxyMaxStreamIds: 2048}}, 128}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ids := maxStreamIds(tt.args.originProtoVer, tt.args.targetProtoVer, tt.args.config) + require.Equal(t, tt.expectedMaxStreamIds, ids) + }) + } +} diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index 1c26651d..a9fdabb1 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -95,12 +95,11 @@ func NewCqlConnection( Username: username, Password: password, }, - initialized: false, - ctx: ctx, - cancelFn: cFn, - wg: &sync.WaitGroup{}, - // protoVer is the proposed protocol version using which we will try to establish connectivity - outgoingCh: make(chan *frame.Frame, maxOutgoingPendingRequests(protoVer)), + initialized: false, + ctx: ctx, + cancelFn: cFn, + wg: &sync.WaitGroup{}, + outgoingCh: make(chan *frame.Frame, maxOutgoingPending), eventsQueue: make(chan *frame.Frame, eventQueueLength), pendingOperations: make(map[int16]chan *frame.Frame), pendingOperationsLock: &sync.RWMutex{}, @@ -108,8 +107,9 @@ func NewCqlConnection( closed: false, eventHandlerLock: &sync.Mutex{}, authEnabled: true, - frameProcessor: NewStreamIdProcessor(NewInternalStreamIdMapper(conf.ProxyMaxStreamIds, nil)), - protocolVersion: &atomic.Value{}, + // protoVer is the proposed protocol version using which we will try to establish connectivity + frameProcessor: NewStreamIdProcessor(NewInternalStreamIdMapper(maxStreamIds(protoVer, protoVer, conf), nil)), + protocolVersion: &atomic.Value{}, } cqlConn.StartRequestLoop() cqlConn.StartResponseLoop() @@ -117,14 +117,6 @@ func NewCqlConnection( return cqlConn } -func maxOutgoingPendingRequests(protocolVersion primitive.ProtocolVersion) int { - switch protocolVersion { - case primitive.ProtocolVersion2: - return maxOutgoingPendingV2 - } - return maxOutgoingPending -} - func (c *cqlConn) SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) { c.eventHandlerLock.Lock() defer c.eventHandlerLock.Unlock() From f94e0373573b88dfe507accc8c496f59e8d1f85d Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 11 Jul 2024 11:13:47 +0200 Subject: [PATCH 17/29] Fix merge issues --- go.mod | 2 +- go.sum | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/go.mod b/go.mod index d433b080..28211350 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,6 @@ require ( github.com/rs/zerolog v1.20.0 github.com/sirupsen/logrus v1.6.0 github.com/stretchr/testify v1.8.0 - gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -35,4 +34,5 @@ require ( github.com/prometheus/procfs v0.0.8 // indirect golang.org/x/sys v0.3.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 170f51a0..387403a2 100644 --- a/go.sum +++ b/go.sum @@ -130,7 +130,6 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= From a5389e05fa266b05d4d52e747c61b74e7372f3d1 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 11 Jul 2024 11:17:55 +0200 Subject: [PATCH 18/29] Fix merge issues --- proxy/pkg/config/config.go | 1 + 1 file changed, 1 insertion(+) diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index 4fd2fa1b..72e35ceb 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -8,6 +8,7 @@ import ( "github.com/kelseyhightower/envconfig" def "github.com/mcuadros/go-defaults" log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" "net" "os" "strconv" From ce7179c1dce339b67e37910b727dd10b94de88be Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 11 Jul 2024 11:24:37 +0200 Subject: [PATCH 19/29] Fix merge issues --- proxy/pkg/config/config_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy/pkg/config/config_test.go b/proxy/pkg/config/config_test.go index 8c713e58..3e5ab34f 100644 --- a/proxy/pkg/config/config_test.go +++ b/proxy/pkg/config/config_test.go @@ -107,7 +107,7 @@ func TestTargetConfig_ParsingControlConnMaxProtocolVersion(t *testing.T) { // test-specific setup setTargetContactPointsAndPortEnvVars() - conf, _ := New().ParseEnvVars() + conf := New().parseEnvVars() tests := []struct { name string From c2e46f1c5b30bdadd3cf752395c47734eae8795e Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 11 Jul 2024 11:29:39 +0200 Subject: [PATCH 20/29] Fix merge issues --- proxy/pkg/config/config_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/proxy/pkg/config/config_test.go b/proxy/pkg/config/config_test.go index 3e5ab34f..74eaa557 100644 --- a/proxy/pkg/config/config_test.go +++ b/proxy/pkg/config/config_test.go @@ -107,7 +107,9 @@ func TestTargetConfig_ParsingControlConnMaxProtocolVersion(t *testing.T) { // test-specific setup setTargetContactPointsAndPortEnvVars() - conf := New().parseEnvVars() + conf := New() + err := conf.parseEnvVars() + require.Nil(t, err) tests := []struct { name string From 4bd37851aab796b44ddaebac3b00304b14cca26e Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 17 Jul 2024 15:50:54 +0200 Subject: [PATCH 21/29] New maximum stream IDs test --- README.md | 15 +- integration-tests/connect_test.go | 16 +- integration-tests/protocolv2_test.go | 171 -------------- integration-tests/protocolversions_test.go | 245 +++++++++++++++++++++ integration-tests/streamids_test.go | 210 ++++++++++++++++++ proxy/pkg/zdmproxy/clienthandler.go | 26 ++- proxy/pkg/zdmproxy/clusterconn.go | 5 +- proxy/pkg/zdmproxy/controlconn.go | 2 +- proxy/pkg/zdmproxy/cqlconn.go | 15 +- 9 files changed, 508 insertions(+), 197 deletions(-) delete mode 100644 integration-tests/protocolv2_test.go create mode 100644 integration-tests/protocolversions_test.go create mode 100644 integration-tests/streamids_test.go diff --git a/README.md b/README.md index a73ea1b6..723c005d 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,18 @@ It technically doesn't support v5, but handles protocol negotiation so that the the protocol version to v4 if v5 is requested. This means that any client application using a recent driver that supports protocol version v5 can be migrated using the ZDM Proxy (as long as it does not use v5-specific functionality). +ZDM Proxy requires origin and target clusters to have at least one protocol version in common. It is therefore not feasible +to configure Apache Cassandra 2.0 as origin and 3.x / 4.x as target. Below table displays protocol versions supported by +various C* versions: + +| Apache Cassandra | Protocol Version | +|------------------|------------------| +| 2.0 | V2 | +| 2.1 | V2, V3 | +| 2.2 | V3, V4 | +| 3.x | V3, V4 | +| 4.x | V3, V4, V5 | + --- :warning: **Thrift is not supported by ZDM Proxy.** If you are using a very old driver or cluster version that only supports Thrift then you need to change your client application to use CQL and potentially upgrade your cluster before starting the @@ -110,7 +122,8 @@ migration process. In practice this means that ZDM Proxy supports the following cluster versions (as Origin and / or Target): -- Apache Cassandra from 2.0+ up to (and including) Apache Cassandra 4.x. +- Apache Cassandra from 2.1+ up to (and including) Apache Cassandra 4.x. +- Apache Cassandra 2.0 up to 2.1. - DataStax Enterprise 4.8+. DataStax Enterprise 4.6 and 4.7 support will be introduced when protocol version v2 is supported. - DataStax Astra DB (both Serverless and Classic) diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index b5d6e217..8b31103d 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -46,7 +46,9 @@ func TestGoCqlConnect(t *testing.T) { require.Equal(t, "fake", iter.Columns()[0].Name) } -func TestProtocolVersionNegotiation(t *testing.T) { +// Simulacron-based test to make sure that we can handle invalid protocol error and downgrade +// used protocol on control connection. ORIGIN and TARGET are using the same C* version +func TestControlConnectionProtocolVersionNegotiation(t *testing.T) { tests := []struct { name string clusterVersion string @@ -63,13 +65,13 @@ func TestProtocolVersionNegotiation(t *testing.T) { name: "Cluster3.0_MaxCCProtoVer4_NegotiatedProtoVer4", clusterVersion: "3.0", controlConnMaxProtocolVersion: "4", - negotiatedProtocolVersion: primitive.ProtocolVersion4, + negotiatedProtocolVersion: primitive.ProtocolVersion4, // make sure that protocol negotiation does not fail if it is not actually needed }, { - name: "Cluster4.0_MaxCCProtoVer4_NegotiatedProtoVer4", - clusterVersion: "4.0", - controlConnMaxProtocolVersion: "4", - negotiatedProtocolVersion: primitive.ProtocolVersion4, + name: "Cluster3.0_MaxCCProtoVer3_NegotiatedProtoVer3", + clusterVersion: "3.0", + controlConnMaxProtocolVersion: "3", + negotiatedProtocolVersion: primitive.ProtocolVersion3, // protocol V3 applied as it is the maximum configured }, } @@ -103,7 +105,7 @@ func TestProtocolVersionNegotiation(t *testing.T) { defer cqlClientConn.Close() cqlConn, _ := testSetup.Proxy.GetOriginControlConn().GetConnAndContactPoint() - negotiatedProto := cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion) + negotiatedProto := cqlConn.GetProtocolVersion() require.Equal(t, tt.negotiatedProtocolVersion, negotiatedProto) queryMsg := &message.Query{ diff --git a/integration-tests/protocolv2_test.go b/integration-tests/protocolv2_test.go deleted file mode 100644 index 5b0f57f9..00000000 --- a/integration-tests/protocolv2_test.go +++ /dev/null @@ -1,171 +0,0 @@ -package integration_tests - -import ( - "context" - "fmt" - "github.com/datastax/go-cassandra-native-protocol/client" - "github.com/datastax/go-cassandra-native-protocol/datatype" - "github.com/datastax/go-cassandra-native-protocol/frame" - "github.com/datastax/go-cassandra-native-protocol/message" - "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/integration-tests/setup" - "github.com/stretchr/testify/require" - "net" - "testing" -) - -func TestProtocolV2Connect(t *testing.T) { - originAddress := "127.0.0.2" - targetAddress := "127.0.0.3" - - serverConf := setup.NewTestConfig(originAddress, targetAddress) - proxyConf := setup.NewTestConfig(originAddress, targetAddress) - proxyConf.ControlConnMaxProtocolVersion = "3" // simulate protocol downgrade to V2 - - testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) - require.Nil(t, err) - defer testSetup.Cleanup() - - originRequestHandler := NewProtocolV2RequestHandler("origin", "dc1", "127.0.0.4") - targetRequestHandler := NewProtocolV2RequestHandler("target", "dc1", "127.0.0.5") - - testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ - originRequestHandler.HandleRequest, - client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), - } - testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ - targetRequestHandler.HandleRequest, - client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), - } - - err = testSetup.Start(nil, false, primitive.ProtocolVersion2) - require.Nil(t, err) - - proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy - if proxy != nil { - defer proxy.Shutdown() - } - require.Nil(t, err) -} - -func TestProtocolV2Query(t *testing.T) { - originAddress := "127.0.0.2" - targetAddress := "127.0.0.3" - - serverConf := setup.NewTestConfig(originAddress, targetAddress) - proxyConf := setup.NewTestConfig(originAddress, targetAddress) - proxyConf.ControlConnMaxProtocolVersion = "2" - - testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) - require.Nil(t, err) - defer testSetup.Cleanup() - - originRequestHandler := NewProtocolV2RequestHandler("origin", "dc1", "") - targetRequestHandler := NewProtocolV2RequestHandler("target", "dc1", "") - - testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ - originRequestHandler.HandleRequest, - client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), - } - testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ - targetRequestHandler.HandleRequest, - client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), - } - - err = testSetup.Start(nil, false, primitive.ProtocolVersion2) - require.Nil(t, err) - - proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy - if proxy != nil { - defer proxy.Shutdown() - } - require.Nil(t, err) - - cqlConn, err := testSetup.Client.CqlClient.Connect(context.Background()) - query := &message.Query{ - Query: "SELECT * FROM fakeks.faketb", - Options: &message.QueryOptions{Consistency: primitive.ConsistencyLevelOne}, - } - - response, err := cqlConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion2, 0, query)) - resultSet := response.Body.Message.(*message.RowsResult).Data - require.Equal(t, 1, len(resultSet)) -} - -type ProtocolV2RequestHandler struct { - cluster string - datacenter string - peerIP string -} - -func NewProtocolV2RequestHandler(cluster string, datacenter string, peerIP string) *ProtocolV2RequestHandler { - return &ProtocolV2RequestHandler{ - cluster: cluster, - datacenter: datacenter, - peerIP: peerIP, - } -} - -func (recv *ProtocolV2RequestHandler) HandleRequest( - request *frame.Frame, - conn *client.CqlServerConnection, - ctx client.RequestHandlerContext) (response *frame.Frame) { - switch request.Body.Message.GetOpCode() { - case primitive.OpCodeStartup: - if request.Header.Version != primitive.ProtocolVersion2 { - return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.ProtocolError{ - ErrorMessage: fmt.Sprintf("Invalid or unsupported protocol version (%d)", request.Header.Version), - }) - } - return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) - case primitive.OpCodeRegister: - return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) - case primitive.OpCodeQuery: - query := request.Body.Message.(*message.Query) - switch query.Query { - case "SELECT * FROM system.local": - // C* 2.0.0 does not store local endpoint details in system.local table - sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version) - sysLocMsg := &message.RowsResult{ - Metadata: &message.RowsMetadata{ - ColumnCount: int32(len(systemLocalColumnsProtocolV2)), - Columns: systemLocalColumnsProtocolV2, - }, - Data: message.RowSet{sysLocRow}, - } - return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) - case "SELECT * FROM system.peers": - var sysPeerRows message.RowSet - if len(recv.peerIP) > 0 { - sysPeerRows = append(sysPeerRows, systemPeersRow( - recv.datacenter, - &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, - primitive.ProtocolVersion2, - )) - } - sysPeeMsg := &message.RowsResult{ - Metadata: &message.RowsMetadata{ - ColumnCount: int32(len(systemPeersColumns)), - Columns: systemPeersColumns, - }, - Data: sysPeerRows, - } - return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg) - case "SELECT * FROM fakeks.faketb": - sysLocMsg := &message.RowsResult{ - Metadata: &message.RowsMetadata{ - ColumnCount: 2, - Columns: []*message.ColumnMetadata{ - {Keyspace: "fakeks", Table: "faketb", Name: "key", Type: datatype.Varchar}, - {Keyspace: "fakeks", Table: "faketb", Name: "value", Type: datatype.Uuid}, - }, - }, - Data: message.RowSet{ - message.Row{keyValue, hostIdValue}, - }, - } - return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) - } - } - return nil -} diff --git a/integration-tests/protocolversions_test.go b/integration-tests/protocolversions_test.go new file mode 100644 index 00000000..cd9bddd7 --- /dev/null +++ b/integration-tests/protocolversions_test.go @@ -0,0 +1,245 @@ +package integration_tests + +import ( + "context" + "fmt" + "github.com/datastax/go-cassandra-native-protocol/client" + "github.com/datastax/go-cassandra-native-protocol/datatype" + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/stretchr/testify/require" + "net" + "slices" + "testing" +) + +// Test that proxy can establish connectivity with ORIGIN and TARGET +// clusters that support different set of protocol versions. Verify also that +// client driver can connect and successfully insert or query data. +func TestProtocolNegotiationDifferentClusters(t *testing.T) { + tests := []struct { + name string + proxyMaxProtoVer string + originProtoVer []primitive.ProtocolVersion + targetProtoVer []primitive.ProtocolVersion + clientProtoVer primitive.ProtocolVersion + failClientConnect bool + }{ + { + name: "OriginV2_TargetV2_ClientV2", + proxyMaxProtoVer: "2", + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + clientProtoVer: primitive.ProtocolVersion2, + }, + { + name: "OriginV2_TargetV2_ClientV2_ProxyControlConnNegotiation", + proxyMaxProtoVer: "4", + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + clientProtoVer: primitive.ProtocolVersion2, + }, + { + name: "OriginV2_TargetV23_ClientV2", + proxyMaxProtoVer: "3", + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3}, + clientProtoVer: primitive.ProtocolVersion2, + }, + { + name: "OriginV23_TargetV2_ClientV2", + proxyMaxProtoVer: "3", + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + clientProtoVer: primitive.ProtocolVersion2, + }, + { + // most common setup with OSS Cassandra + name: "OriginV345_TargetV345_ClientV4", + proxyMaxProtoVer: "3", + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + clientProtoVer: primitive.ProtocolVersion4, + }, + { + // most common setup with DSE + name: "OriginV345_TargetV34Dse1Dse2_ClientV4", + proxyMaxProtoVer: "3", + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse2}, + clientProtoVer: primitive.ProtocolVersion4, + }, + { + name: "OriginV2_TargetV3_ClientV2", + proxyMaxProtoVer: "3", + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, + clientProtoVer: primitive.ProtocolVersion2, + // client connection should fail as there is no common protocol version between origin and target + failClientConnect: true, + }, + } + + originAddress := "127.0.1.1" + targetAddress := "127.0.1.2" + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + + queryInsert := &message.Query{ + Query: "INSERT INTO test_ks.test(key, value) VALUES(1, '1')", // use INSERT to route request to both clusters + } + querySelect := &message.Query{ + Query: "SELECT * FROM test_ks.test", + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxyConf.ControlConnMaxProtocolVersion = test.proxyMaxProtoVer + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewProtocolNegotiationRequestHandler("origin", "dc1", originAddress, test.originProtoVer) + targetRequestHandler := NewProtocolNegotiationRequestHandler("target", "dc1", targetAddress, test.targetProtoVer) + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, test.clientProtoVer) + require.Nil(t, err) + + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + require.Nil(t, err) + + cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), test.clientProtoVer, 0) + if test.failClientConnect { + require.NotNil(t, err) + return + } + require.Nil(t, err) + defer cqlConn.Close() + + response, err := cqlConn.SendAndReceive(frame.NewFrame(test.clientProtoVer, 0, queryInsert)) + require.Nil(t, err) + require.IsType(t, &message.VoidResult{}, response.Body.Message) + + response, err = cqlConn.SendAndReceive(frame.NewFrame(test.clientProtoVer, 0, querySelect)) + require.Nil(t, err) + resultSet := response.Body.Message.(*message.RowsResult).Data + require.Equal(t, 1, len(resultSet)) + }) + } +} + +type ProtocolNegotiationRequestHandler struct { + cluster string + datacenter string + peerIP string + protocolVersions []primitive.ProtocolVersion // accepted protocol versions + // store negotiated protocol versions by socket port number + // protocol version negotiated by proxy on control connections can be different from the one + // used by client driver with ORIGIN and TARGET nodes. In the scenario 'OriginV2_TargetV23_ClientV2', proxy + // will establish control connection with ORIGIN using version 2, and TARGET with version 3. + // Protocol version applied on client connections with TARGET will be different - V2. + negotiatedProtoVer map[int]primitive.ProtocolVersion // negotiated protocol version on different sockets +} + +func NewProtocolNegotiationRequestHandler(cluster string, datacenter string, peerIP string, + protocolVersion []primitive.ProtocolVersion) *ProtocolNegotiationRequestHandler { + return &ProtocolNegotiationRequestHandler{ + cluster: cluster, + datacenter: datacenter, + peerIP: peerIP, + protocolVersions: protocolVersion, + negotiatedProtoVer: make(map[int]primitive.ProtocolVersion), + } +} + +func (recv *ProtocolNegotiationRequestHandler) HandleRequest( + request *frame.Frame, + conn *client.CqlServerConnection, + ctx client.RequestHandlerContext) (response *frame.Frame) { + port := conn.RemoteAddr().(*net.TCPAddr).Port + negotiatedProtoVer := recv.negotiatedProtoVer[port] + if !slices.Contains(recv.protocolVersions, request.Header.Version) || (negotiatedProtoVer != 0 && negotiatedProtoVer != request.Header.Version) { + // server does not support given protocol version, or it was not the one negotiated + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.ProtocolError{ + ErrorMessage: fmt.Sprintf("Invalid or unsupported protocol version (%d)", request.Header.Version), + }) + } + switch request.Body.Message.GetOpCode() { + case primitive.OpCodeStartup: + recv.negotiatedProtoVer[port] = request.Header.Version + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) + case primitive.OpCodeRegister: + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) + case primitive.OpCodeQuery: + query := request.Body.Message.(*message.Query) + switch query.Query { + case "SELECT * FROM system.local": + // C* 2.0.0 does not store local endpoint details in system.local table + sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version) + metadata := &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumns)), + Columns: systemLocalColumns, + } + if negotiatedProtoVer == primitive.ProtocolVersion2 { + metadata = &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumnsProtocolV2)), + Columns: systemLocalColumnsProtocolV2, + } + } + sysLocMsg := &message.RowsResult{ + Metadata: metadata, + Data: message.RowSet{sysLocRow}, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) + case "SELECT * FROM system.peers": + var sysPeerRows message.RowSet + if len(recv.peerIP) > 0 { + sysPeerRows = append(sysPeerRows, systemPeersRow( + recv.datacenter, + &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, + negotiatedProtoVer, + )) + } + sysPeeMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: int32(len(systemPeersColumns)), + Columns: systemPeersColumns, + }, + Data: sysPeerRows, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg) + case "SELECT * FROM test_ks.test": + qryMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: 2, + Columns: []*message.ColumnMetadata{ + {Keyspace: "test_ks", Table: "test", Name: "key", Type: datatype.Varchar}, + {Keyspace: "test_ks", Table: "test", Name: "value", Type: datatype.Uuid}, + }, + }, + Data: message.RowSet{ + message.Row{keyValue, hostIdValue}, + }, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, qryMsg) + case "INSERT INTO test_ks.test(key, value) VALUES(1, '1')": + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.VoidResult{}) + } + } + return nil +} diff --git a/integration-tests/streamids_test.go b/integration-tests/streamids_test.go new file mode 100644 index 00000000..3a781fc7 --- /dev/null +++ b/integration-tests/streamids_test.go @@ -0,0 +1,210 @@ +package integration_tests + +import ( + "context" + "fmt" + "github.com/datastax/go-cassandra-native-protocol/client" + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/datastax/zdm-proxy/integration-tests/utils" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "net" + "strings" + "sync" + "testing" + "time" +) + +// Test sending more concurrent, async request than allowed stream IDs. +// Origin and target clusters are stubbed and will return protocol error +// if we notice greater stream ID value than expected. We cannot easily test +// exceeding 127 stream IDs allowed in protocol V2, because clients will +// fail serializing the frame +func TestMaxStreamIds(t *testing.T) { + originAddress := "127.0.1.1" + targetAddress := "127.0.1.2" + originProtoVer := primitive.ProtocolVersion2 + targetProtoVer := primitive.ProtocolVersion2 + requestCount := 20 + maxStreamIdsConf := 10 + maxStreamIdsExpected := 10 + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + + queryInsert := &message.Query{ + Query: "INSERT INTO test_ks.test(key, value) VALUES(1, '1')", // use INSERT to route request to both clusters + } + + buffer := utils.CreateLogHooks(log.WarnLevel, log.ErrorLevel) + defer log.StandardLogger().ReplaceHooks(make(log.LevelHooks)) + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewMaxStreamIdsRequestHandler("origin", "dc1", originAddress, maxStreamIdsExpected) + targetRequestHandler := NewProtocolNegotiationRequestHandler("target", "dc1", targetAddress, []primitive.ProtocolVersion{targetProtoVer}) + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, originProtoVer) + require.Nil(t, err) + + proxyConf.ProxyMaxStreamIds = maxStreamIdsConf + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + require.Nil(t, err) + + testSetup.Client.CqlClient.MaxInFlight = 127 // set to 127, otherwise we fail to serialize in protocol + cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), originProtoVer, 0) + require.Nil(t, err) + defer cqlConn.Close() + + remainingRequests := requestCount + + for j := 0; j < 10; j++ { + var responses []client.InFlightRequest + for i := 0; i < remainingRequests; i++ { + inFlightReq, err := cqlConn.Send(frame.NewFrame(originProtoVer, 0, queryInsert)) + require.Nil(t, err) + responses = append(responses, inFlightReq) + } + + for _, response := range responses { + select { + case msg := <-response.Incoming(): + if response.Err() != nil { + t.Fatalf(response.Err().Error()) + } + switch msg.Body.Message.(type) { + case *message.VoidResult: + // expected, we have received successful response + remainingRequests-- + case *message.Overloaded: + // client received overloaded message due to insufficient stream ID pool, retry the request + default: + t.Fatalf(response.Err().Error()) + } + } + } + + if remainingRequests == 0 { + break + } + } + + require.True(t, strings.Contains(buffer.String(), "no stream id available")) + + require.True(t, len(originRequestHandler.usedStreamIdsPerConn) >= 1) + for _, idMap := range originRequestHandler.usedStreamIdsPerConn { + maxId := int16(0) + for streamId, _ := range idMap { + if streamId > maxId { + maxId = streamId + } + } + require.True(t, maxId < int16(maxStreamIdsExpected)) + } +} + +type MaxStreamIdsRequestHandler struct { + lock sync.Mutex + cluster string + datacenter string + peerIP string + maxStreamIds int + usedStreamIdsPerConn map[int]map[int16]bool +} + +func NewMaxStreamIdsRequestHandler(cluster string, datacenter string, peerIP string, maxStreamIds int) *MaxStreamIdsRequestHandler { + return &MaxStreamIdsRequestHandler{ + cluster: cluster, + datacenter: datacenter, + peerIP: peerIP, + maxStreamIds: maxStreamIds, + usedStreamIdsPerConn: make(map[int]map[int16]bool), + } +} + +func (recv *MaxStreamIdsRequestHandler) HandleRequest( + request *frame.Frame, + conn *client.CqlServerConnection, + ctx client.RequestHandlerContext) (response *frame.Frame) { + port := conn.RemoteAddr().(*net.TCPAddr).Port + + switch request.Body.Message.GetOpCode() { + case primitive.OpCodeStartup: + case primitive.OpCodeRegister: + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) + case primitive.OpCodeQuery: + query := request.Body.Message.(*message.Query) + switch query.Query { + case "SELECT * FROM system.local": + // C* 2.0.0 does not store local endpoint details in system.local table + sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version) + metadata := &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumns)), + Columns: systemLocalColumns, + } + if request.Header.Version == primitive.ProtocolVersion2 { + metadata = &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumnsProtocolV2)), + Columns: systemLocalColumnsProtocolV2, + } + } + sysLocMsg := &message.RowsResult{ + Metadata: metadata, + Data: message.RowSet{sysLocRow}, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) + case "SELECT * FROM system.peers": + var sysPeerRows message.RowSet + if len(recv.peerIP) > 0 { + sysPeerRows = append(sysPeerRows, systemPeersRow( + recv.datacenter, + &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, + request.Header.Version, + )) + } + sysPeeMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: int32(len(systemPeersColumns)), + Columns: systemPeersColumns, + }, + Data: sysPeerRows, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg) + case "INSERT INTO test_ks.test(key, value) VALUES(1, '1')": + recv.lock.Lock() + usedStreamIdsMap := recv.usedStreamIdsPerConn[port] + if usedStreamIdsMap == nil { + usedStreamIdsMap = make(map[int16]bool) + recv.usedStreamIdsPerConn[port] = usedStreamIdsMap + } + usedStreamIdsMap[request.Header.StreamId] = true + recv.lock.Unlock() + + time.Sleep(5 * time.Millisecond) // introduce some delay so that stream IDs are not released immediately + + if len(usedStreamIdsMap) > recv.maxStreamIds { + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.ProtocolError{ + ErrorMessage: fmt.Sprintf("Too many stream IDs used (%d)", len(usedStreamIdsMap)), + }) + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.VoidResult{}) + } + } + return nil +} diff --git a/proxy/pkg/zdmproxy/clienthandler.go b/proxy/pkg/zdmproxy/clienthandler.go index 30b59bc1..bfd91ba2 100644 --- a/proxy/pkg/zdmproxy/clienthandler.go +++ b/proxy/pkg/zdmproxy/clienthandler.go @@ -168,8 +168,8 @@ func NewClientHandler( requestsDoneCtx, requestsDoneCancelFn := context.WithCancel(context.Background()) // Initialize stream id processors to manage the ids sent to the clusters - originCCProtoVer := originControlConn.cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion) - targetCCProtoVer := targetControlConn.cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion) + originCCProtoVer := originControlConn.cqlConn.GetProtocolVersion() + targetCCProtoVer := targetControlConn.cqlConn.GetProtocolVersion() streamIds := maxStreamIds(originCCProtoVer, targetCCProtoVer, conf) originFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeOrigin) targetFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeTarget) @@ -1485,17 +1485,27 @@ func (ch *ClientHandler) executeRequest( case forwardToBoth: log.Tracef("Forwarding request with opcode %v for stream %v to %v and %v", f.Header.OpCode, f.Header.StreamId, common.ClusterTypeOrigin, common.ClusterTypeTarget) - ch.originCassandraConnector.sendRequestToCluster(originRequest) - ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + sendErr := ch.originCassandraConnector.sendRequestToCluster(originRequest) + if sendErr != nil { + ch.clientConnector.sendOverloadedToClient(frameContext.frame) + } else { + ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + } case forwardToOrigin: log.Tracef("Forwarding request with opcode %v for stream %v to %v", f.Header.OpCode, f.Header.StreamId, common.ClusterTypeOrigin) - ch.originCassandraConnector.sendRequestToCluster(originRequest) + sendErr := ch.originCassandraConnector.sendRequestToCluster(originRequest) + if sendErr != nil { + ch.clientConnector.sendOverloadedToClient(frameContext.frame) + } ch.targetCassandraConnector.sendHeartbeat(startupFrameVersion, ch.conf.HeartbeatIntervalMs) case forwardToTarget: log.Tracef("Forwarding request with opcode %v for stream %v to %v", f.Header.OpCode, f.Header.StreamId, common.ClusterTypeTarget) - ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + sendErr := ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + if sendErr != nil { + ch.clientConnector.sendOverloadedToClient(frameContext.frame) + } ch.originCassandraConnector.sendHeartbeat(startupFrameVersion, ch.conf.HeartbeatIntervalMs) case forwardToAsyncOnly: default: @@ -2258,13 +2268,13 @@ func newFrameProcessor(maxStreamIds int, nodeMetrics *metrics.NodeMetrics, conne // uses negotiated protocol version to configure maximum number of stream IDs on node connections. Driver does NOT // change the number of stream IDs on per node basis. func maxStreamIds(originProtoVer primitive.ProtocolVersion, targetProtoVer primitive.ProtocolVersion, conf *config.Config) int { - maxSupported := maxOutgoingPending + maxSupported := maxStreamIdsV3 protoVer := originProtoVer if targetProtoVer < originProtoVer { protoVer = targetProtoVer } if protoVer == primitive.ProtocolVersion2 { - maxSupported = maxOutgoingPendingV2 + maxSupported = maxStreamIdsV2 } if maxSupported < conf.ProxyMaxStreamIds { return maxSupported diff --git a/proxy/pkg/zdmproxy/clusterconn.go b/proxy/pkg/zdmproxy/clusterconn.go index a1a35e93..21bea2c1 100644 --- a/proxy/pkg/zdmproxy/clusterconn.go +++ b/proxy/pkg/zdmproxy/clusterconn.go @@ -394,17 +394,18 @@ func (cc *ClusterConnector) handleAsyncResponse(response *frame.RawFrame) *frame return nil } -func (cc *ClusterConnector) sendRequestToCluster(frame *frame.RawFrame) { +func (cc *ClusterConnector) sendRequestToCluster(frame *frame.RawFrame) error { var err error if cc.frameProcessor != nil { frame, err = cc.frameProcessor.AssignUniqueId(frame) } if err != nil { log.Errorf("[%v] Couldn't assign stream id to frame %v: %v", string(cc.connectorType), frame.Header.OpCode, err) - return + return err } else { cc.writeCoalescer.Enqueue(frame) } + return nil } func (cc *ClusterConnector) validateAsyncStateForRequest(frame *frame.RawFrame) bool { diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index 4a2bc45c..17fd8048 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -362,7 +362,7 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( conn = newConn log.Infof("Successfully opened control connection to %v using endpoint %v with %v.", - cc.connConfig.GetClusterType(), endpoint.String(), newConn.GetProtocolVersion().Load().(primitive.ProtocolVersion)) + cc.connConfig.GetClusterType(), endpoint.String(), newConn.GetProtocolVersion()) break } diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index a9fdabb1..92a047f4 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -21,10 +21,11 @@ import ( const ( eventQueueLength = 2048 - maxIncomingPending = 2048 - maxIncomingPendingV2 = 128 - maxOutgoingPending = 2048 - maxOutgoingPendingV2 = 128 + maxIncomingPending = 2048 + maxOutgoingPending = 2048 + + maxStreamIdsV3 = 2048 + maxStreamIdsV2 = 127 timeOutsThreshold = 1024 ) @@ -41,7 +42,7 @@ type CqlConnection interface { SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) SubscribeToProtocolEvents(ctx context.Context, eventTypes []primitive.EventType) error IsAuthEnabled() (bool, error) - GetProtocolVersion() *atomic.Value + GetProtocolVersion() primitive.ProtocolVersion } // Not thread safe @@ -245,8 +246,8 @@ func (c *cqlConn) IsAuthEnabled() (bool, error) { return c.authEnabled, nil } -func (c *cqlConn) GetProtocolVersion() *atomic.Value { - return c.protocolVersion +func (c *cqlConn) GetProtocolVersion() primitive.ProtocolVersion { + return c.protocolVersion.Load().(primitive.ProtocolVersion) } func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx context.Context) error { From 7722bf1b507aaeab941f9afbb9254639d7a7362d Mon Sep 17 00:00:00 2001 From: Auto Gofmt Date: Wed, 17 Jul 2024 13:51:20 +0000 Subject: [PATCH 22/29] Automated gofmt changes --- integration-tests/streamids_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration-tests/streamids_test.go b/integration-tests/streamids_test.go index 3a781fc7..9e47e880 100644 --- a/integration-tests/streamids_test.go +++ b/integration-tests/streamids_test.go @@ -110,7 +110,7 @@ func TestMaxStreamIds(t *testing.T) { require.True(t, len(originRequestHandler.usedStreamIdsPerConn) >= 1) for _, idMap := range originRequestHandler.usedStreamIdsPerConn { maxId := int16(0) - for streamId, _ := range idMap { + for streamId := range idMap { if streamId > maxId { maxId = streamId } From 1dabb2385d26a831114b0f89ea10f5f80305ac66 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 17 Jul 2024 16:03:23 +0200 Subject: [PATCH 23/29] Cleanup --- proxy/pkg/zdmproxy/clienthandler_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/proxy/pkg/zdmproxy/clienthandler_test.go b/proxy/pkg/zdmproxy/clienthandler_test.go index f24c2817..b7fd0ce5 100644 --- a/proxy/pkg/zdmproxy/clienthandler_test.go +++ b/proxy/pkg/zdmproxy/clienthandler_test.go @@ -22,8 +22,8 @@ func TestMaxStreamIds(t *testing.T) { {"OriginV3_TargetV3_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion3, targetProtoVer: primitive.ProtocolVersion3, config: &config.Config{ProxyMaxStreamIds: 2048}}, 2048}, {"OriginV3_TargetV4_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion3, targetProtoVer: primitive.ProtocolVersion4, config: &config.Config{ProxyMaxStreamIds: 2048}}, 2048}, {"OriginV3_TargetV4_LowerConfig", args{originProtoVer: primitive.ProtocolVersion3, targetProtoVer: primitive.ProtocolVersion4, config: &config.Config{ProxyMaxStreamIds: 1024}}, 1024}, - {"OriginV2_TargetV3_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion2, targetProtoVer: primitive.ProtocolVersion3, config: &config.Config{ProxyMaxStreamIds: 2048}}, 128}, - {"OriginV2_TargetV2_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion2, targetProtoVer: primitive.ProtocolVersion2, config: &config.Config{ProxyMaxStreamIds: 2048}}, 128}, + {"OriginV2_TargetV3_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion2, targetProtoVer: primitive.ProtocolVersion3, config: &config.Config{ProxyMaxStreamIds: 2048}}, 127}, + {"OriginV2_TargetV2_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion2, targetProtoVer: primitive.ProtocolVersion2, config: &config.Config{ProxyMaxStreamIds: 2048}}, 127}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 2615ce91d80274637a416e15853f773d9e768d8f Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 18 Jul 2024 10:50:49 +0200 Subject: [PATCH 24/29] Use DSEv2 as default max protocol version --- integration-tests/setup/testcluster.go | 2 +- integration-tests/streamids_test.go | 2 +- proxy/pkg/config/config.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/integration-tests/setup/testcluster.go b/integration-tests/setup/testcluster.go index 52613212..55ee74f7 100644 --- a/integration-tests/setup/testcluster.go +++ b/integration-tests/setup/testcluster.go @@ -452,7 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config { conf.ReadMode = config.ReadModePrimaryOnly conf.SystemQueriesMode = config.SystemQueriesModeOrigin conf.AsyncHandshakeTimeoutMs = 4000 - conf.ControlConnMaxProtocolVersion = "3" + conf.ControlConnMaxProtocolVersion = "DseV2" conf.ProxyRequestTimeoutMs = 10000 diff --git a/integration-tests/streamids_test.go b/integration-tests/streamids_test.go index 9e47e880..97eb9431 100644 --- a/integration-tests/streamids_test.go +++ b/integration-tests/streamids_test.go @@ -23,7 +23,7 @@ import ( // if we notice greater stream ID value than expected. We cannot easily test // exceeding 127 stream IDs allowed in protocol V2, because clients will // fail serializing the frame -func TestMaxStreamIds(t *testing.T) { +func TestLimitStreamIdsGeneration(t *testing.T) { originAddress := "127.0.1.1" targetAddress := "127.0.1.2" originProtoVer := primitive.ProtocolVersion2 diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index 72e35ceb..cc62842b 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -25,7 +25,7 @@ type Config struct { ReplaceCqlFunctions bool `default:"false" split_words:"true" yaml:"replace_cql_functions"` AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true" yaml:"async_handshake_timeout_ms"` LogLevel string `default:"INFO" split_words:"true" yaml:"log_level"` - ControlConnMaxProtocolVersion string `default:"3" split_words:"true" yaml:"control_conn_max_protocol_version"` // Numeric Cassandra OSS protocol version or DseV1 / DseV2 + ControlConnMaxProtocolVersion string `default:"DseV2" split_words:"true" yaml:"control_conn_max_protocol_version"` // Numeric Cassandra OSS protocol version or DseV1 / DseV2 // Proxy Topology (also known as system.peers "virtualization") bucket From 21eebb68ef2bf691e7fc42c258726b7120280b31 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 18 Jul 2024 11:18:27 +0200 Subject: [PATCH 25/29] More various protocol version tests --- integration-tests/protocolversions_test.go | 121 ++++++++++++++------- 1 file changed, 79 insertions(+), 42 deletions(-) diff --git a/integration-tests/protocolversions_test.go b/integration-tests/protocolversions_test.go index cd9bddd7..4da2980d 100644 --- a/integration-tests/protocolversions_test.go +++ b/integration-tests/protocolversions_test.go @@ -20,65 +20,92 @@ import ( // client driver can connect and successfully insert or query data. func TestProtocolNegotiationDifferentClusters(t *testing.T) { tests := []struct { - name string - proxyMaxProtoVer string - originProtoVer []primitive.ProtocolVersion - targetProtoVer []primitive.ProtocolVersion - clientProtoVer primitive.ProtocolVersion - failClientConnect bool + name string + proxyMaxProtoVer string + proxyOriginContConnVer primitive.ProtocolVersion + proxyTargetContConnVer primitive.ProtocolVersion + originProtoVer []primitive.ProtocolVersion + targetProtoVer []primitive.ProtocolVersion + clientProtoVer primitive.ProtocolVersion + failClientConnect bool + failProxyStartup bool }{ { - name: "OriginV2_TargetV2_ClientV2", - proxyMaxProtoVer: "2", - originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, - targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, - clientProtoVer: primitive.ProtocolVersion2, + name: "OriginV2_TargetV2_ClientV2", + proxyMaxProtoVer: "2", + proxyOriginContConnVer: primitive.ProtocolVersion2, + proxyTargetContConnVer: primitive.ProtocolVersion2, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + clientProtoVer: primitive.ProtocolVersion2, }, { - name: "OriginV2_TargetV2_ClientV2_ProxyControlConnNegotiation", - proxyMaxProtoVer: "4", - originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, - targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, - clientProtoVer: primitive.ProtocolVersion2, + name: "OriginV2_TargetV2_ClientV2_ProxyControlConnNegotiation", + proxyMaxProtoVer: "4", + proxyOriginContConnVer: primitive.ProtocolVersion2, + proxyTargetContConnVer: primitive.ProtocolVersion2, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + clientProtoVer: primitive.ProtocolVersion2, }, { - name: "OriginV2_TargetV23_ClientV2", - proxyMaxProtoVer: "3", - originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, - targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3}, - clientProtoVer: primitive.ProtocolVersion2, + name: "OriginV2_TargetV23_ClientV2", + proxyMaxProtoVer: "3", + proxyOriginContConnVer: primitive.ProtocolVersion2, + proxyTargetContConnVer: primitive.ProtocolVersion3, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3}, + clientProtoVer: primitive.ProtocolVersion2, }, { - name: "OriginV23_TargetV2_ClientV2", - proxyMaxProtoVer: "3", - originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3}, - targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, - clientProtoVer: primitive.ProtocolVersion2, + name: "OriginV23_TargetV2_ClientV2", + proxyMaxProtoVer: "3", + proxyOriginContConnVer: primitive.ProtocolVersion3, + proxyTargetContConnVer: primitive.ProtocolVersion2, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + clientProtoVer: primitive.ProtocolVersion2, }, { // most common setup with OSS Cassandra - name: "OriginV345_TargetV345_ClientV4", - proxyMaxProtoVer: "3", - originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, - targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, - clientProtoVer: primitive.ProtocolVersion4, + name: "OriginV345_TargetV345_ClientV4", + proxyMaxProtoVer: "DseV2", + proxyOriginContConnVer: primitive.ProtocolVersion4, + proxyTargetContConnVer: primitive.ProtocolVersion4, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + clientProtoVer: primitive.ProtocolVersion4, }, { // most common setup with DSE - name: "OriginV345_TargetV34Dse1Dse2_ClientV4", - proxyMaxProtoVer: "3", - originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, - targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse2}, - clientProtoVer: primitive.ProtocolVersion4, + name: "OriginV345_TargetV34Dse1Dse2_ClientV4", + proxyMaxProtoVer: "DseV2", + proxyOriginContConnVer: primitive.ProtocolVersion4, + proxyTargetContConnVer: primitive.ProtocolVersionDse2, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse2}, + clientProtoVer: primitive.ProtocolVersion4, }, { - name: "OriginV2_TargetV3_ClientV2", - proxyMaxProtoVer: "3", - originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, - targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, - clientProtoVer: primitive.ProtocolVersion2, + name: "OriginV2_TargetV3_ClientV2", + proxyMaxProtoVer: "3", + proxyOriginContConnVer: primitive.ProtocolVersion2, + proxyTargetContConnVer: primitive.ProtocolVersion3, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, + clientProtoVer: primitive.ProtocolVersion2, // client connection should fail as there is no common protocol version between origin and target failClientConnect: true, + }, { + name: "OriginV3_TargetV3_ClientV3_Too_Low_Proto_Configured", + proxyMaxProtoVer: "2", + proxyOriginContConnVer: primitive.ProtocolVersion3, + proxyTargetContConnVer: primitive.ProtocolVersion3, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, + clientProtoVer: primitive.ProtocolVersion2, + // client proxy startup, because configured protocol version is too low + failProxyStartup: true, }, } @@ -121,7 +148,12 @@ func TestProtocolNegotiationDifferentClusters(t *testing.T) { if proxy != nil { defer proxy.Shutdown() } - require.Nil(t, err) + if test.failProxyStartup { + require.NotNil(t, err) + return + } else { + require.Nil(t, err) + } cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), test.clientProtoVer, 0) if test.failClientConnect { @@ -139,6 +171,11 @@ func TestProtocolNegotiationDifferentClusters(t *testing.T) { require.Nil(t, err) resultSet := response.Body.Message.(*message.RowsResult).Data require.Equal(t, 1, len(resultSet)) + + proxyCqlConn, _ := proxy.GetOriginControlConn().GetConnAndContactPoint() + require.Equal(t, test.proxyOriginContConnVer, proxyCqlConn.GetProtocolVersion()) + proxyCqlConn, _ = proxy.GetTargetControlConn().GetConnAndContactPoint() + require.Equal(t, test.proxyTargetContConnVer, proxyCqlConn.GetProtocolVersion()) }) } } From 5ea7eeff8ddce562c2d7382bbbe4114121679058 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Fri, 19 Jul 2024 13:33:19 +0200 Subject: [PATCH 26/29] Stream ID verification --- integration-tests/connect_test.go | 28 ++++++--- integration-tests/streamids_test.go | 47 ++++++++++++++ proxy/pkg/zdmproxy/clientconn.go | 18 +++--- proxy/pkg/zdmproxy/clienthandler.go | 75 +++++++++++++---------- proxy/pkg/zdmproxy/clienthandler_test.go | 2 +- proxy/pkg/zdmproxy/clusterconn.go | 8 ++- proxy/pkg/zdmproxy/cqlconn.go | 2 +- proxy/pkg/zdmproxy/streamidmapper.go | 68 +++++++++++++++----- proxy/pkg/zdmproxy/streamidmapper_test.go | 7 ++- 9 files changed, 181 insertions(+), 74 deletions(-) diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index 8b31103d..b687ae44 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -159,20 +159,23 @@ func TestMaxClientsThreshold(t *testing.T) { func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) { tests := []struct { - name string - requestVersion primitive.ProtocolVersion - expectedVersion primitive.ProtocolVersion - errExpected string + name string + requestVersion primitive.ProtocolVersion + negotiatedVersion string + expectedVersion primitive.ProtocolVersion + errExpected string }{ { "request v5, response v4", primitive.ProtocolVersion5, + "4", primitive.ProtocolVersion4, "Invalid or unsupported protocol version (5)", }, { "request v1, response v4", primitive.ProtocolVersion(0x1), + "4", primitive.ProtocolVersion4, "Invalid or unsupported protocol version (1)", }, @@ -189,6 +192,7 @@ func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) { defer zerolog.SetGlobalLevel(oldZeroLogLevel) cfg := setup.NewTestConfig("127.0.1.1", "127.0.1.2") + cfg.ControlConnMaxProtocolVersion = test.negotiatedVersion cfg.LogLevel = "TRACE" // saw 1 test failure here once but logs didn't show enough info testSetup, err := setup.NewCqlServerTestSetup(t, cfg, false, false, false) require.Nil(t, err) @@ -218,16 +222,18 @@ func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) { func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { type test struct { - name string - requestVersion primitive.ProtocolVersion - returnedVersion primitive.ProtocolVersion - expectedVersion primitive.ProtocolVersion - errExpected string + name string + requestVersion primitive.ProtocolVersion + negotiatedVersion string + returnedVersion primitive.ProtocolVersion + expectedVersion primitive.ProtocolVersion + errExpected string } tests := []*test{ { "DSE_V2 request, v5 returned, v4 expected", primitive.ProtocolVersionDse2, + "4", primitive.ProtocolVersion5, primitive.ProtocolVersion4, "Invalid or unsupported protocol version (5)", @@ -235,6 +241,7 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { { "DSE_V2 request, v1 returned, v4 expected", primitive.ProtocolVersionDse2, + "4", primitive.ProtocolVersion(0x01), primitive.ProtocolVersion4, "Invalid or unsupported protocol version (1)", @@ -242,6 +249,7 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { } runTestFunc := func(t *testing.T, test *test, cfg *config.Config) { + cfg.ControlConnMaxProtocolVersion = test.negotiatedVersion // simulate what version was negotiated on control connection testSetup, err := setup.NewCqlServerTestSetup(t, cfg, false, false, false) require.Nil(t, err) defer testSetup.Cleanup() @@ -299,7 +307,7 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { } func createFrameWithUnsupportedVersion(version primitive.ProtocolVersion, streamId int16, isResponse bool) ([]byte, error) { - mostSimilarVersion := primitive.ProtocolVersion4 + mostSimilarVersion := version if version > primitive.ProtocolVersionDse2 { mostSimilarVersion = primitive.ProtocolVersionDse2 } else if version < primitive.ProtocolVersion2 { diff --git a/integration-tests/streamids_test.go b/integration-tests/streamids_test.go index 97eb9431..4796cc86 100644 --- a/integration-tests/streamids_test.go +++ b/integration-tests/streamids_test.go @@ -119,6 +119,53 @@ func TestLimitStreamIdsGeneration(t *testing.T) { } } +func TestFailOnNegativeStreamIDsFromClient(t *testing.T) { + originAddress := "127.0.1.1" + targetAddress := "127.0.1.2" + originProtoVer := primitive.ProtocolVersion2 + targetProtoVer := primitive.ProtocolVersion2 + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + + queryInsert := &message.Query{ + Query: "INSERT INTO test_ks.test(key, value) VALUES(1, '1')", // use INSERT to route request to both clusters + } + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewMaxStreamIdsRequestHandler("origin", "dc1", originAddress, 100) + targetRequestHandler := NewProtocolNegotiationRequestHandler("target", "dc1", targetAddress, []primitive.ProtocolVersion{targetProtoVer}) + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, originProtoVer) + require.Nil(t, err) + + proxyConf.ProxyMaxStreamIds = 100 + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + require.Nil(t, err) + + cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), originProtoVer, 0) + require.Nil(t, err) + defer cqlConn.Close() + + response, _ := cqlConn.SendAndReceive(frame.NewFrame(originProtoVer, -1, queryInsert)) + require.IsType(t, response.Body.Message, &message.ProtocolError{}) + require.Equal(t, "negative stream id: -1", response.Body.Message.(*message.ProtocolError).ErrorMessage) +} + type MaxStreamIdsRequestHandler struct { lock sync.Mutex cluster string diff --git a/proxy/pkg/zdmproxy/clientconn.go b/proxy/pkg/zdmproxy/clientconn.go index ff8e2100..33e8b66d 100644 --- a/proxy/pkg/zdmproxy/clientconn.go +++ b/proxy/pkg/zdmproxy/clientconn.go @@ -56,6 +56,8 @@ type ClientConnector struct { readScheduler *Scheduler shutdownRequestCtx context.Context + + minProtoVer primitive.ProtocolVersion } func NewClientConnector( @@ -71,7 +73,8 @@ func NewClientConnector( readScheduler *Scheduler, writeScheduler *Scheduler, shutdownRequestCtx context.Context, - clientHandlerShutdownRequestCancelFn context.CancelFunc) *ClientConnector { + clientHandlerShutdownRequestCancelFn context.CancelFunc, + minProtoVer primitive.ProtocolVersion) *ClientConnector { return &ClientConnector{ connection: connection, @@ -97,6 +100,7 @@ func NewClientConnector( readScheduler: readScheduler, shutdownRequestCtx: shutdownRequestCtx, clientHandlerShutdownRequestCancelFn: clientHandlerShutdownRequestCancelFn, + minProtoVer: minProtoVer, } } @@ -176,7 +180,7 @@ func (cc *ClientConnector) listenForRequests() { for cc.clientHandlerContext.Err() == nil { f, err := readRawFrame(bufferedReader, connectionAddr, cc.clientHandlerContext) - protocolErrResponseFrame, err, _ := checkProtocolError(f, err, protocolErrOccurred, ClientConnectorLogPrefix) + protocolErrResponseFrame, err, _ := checkProtocolError(f, cc.minProtoVer, err, protocolErrOccurred, ClientConnectorLogPrefix) if err != nil { handleConnectionError( err, cc.clientHandlerContext, cc.clientHandlerCancelFunc, ClientConnectorLogPrefix, "reading", connectionAddr) @@ -224,7 +228,7 @@ func (cc *ClientConnector) sendOverloadedToClient(request *frame.RawFrame) { } } -func checkProtocolError(f *frame.RawFrame, connErr error, protocolErrorOccurred bool, prefix string) (protocolErrResponse *frame.RawFrame, fatalErr error, errorCode int8) { +func checkProtocolError(f *frame.RawFrame, protoVer primitive.ProtocolVersion, connErr error, protocolErrorOccurred bool, prefix string) (protocolErrResponse *frame.RawFrame, fatalErr error, errorCode int8) { var protocolErrMsg *message.ProtocolError var streamId int16 var logMsg string @@ -244,7 +248,7 @@ func checkProtocolError(f *frame.RawFrame, connErr error, protocolErrorOccurred if !protocolErrorOccurred { log.Debugf("[%v] %v Returning a protocol error to the client to force a downgrade: %v.", prefix, logMsg, protocolErrMsg) } - rawProtocolErrResponse, err := generateProtocolErrorResponseFrame(streamId, protocolErrMsg) + rawProtocolErrResponse, err := generateProtocolErrorResponseFrame(streamId, protoVer, protocolErrMsg) if err != nil { return nil, fmt.Errorf("could not generate protocol error response raw frame (%v): %v", protocolErrMsg, err), -1 } else { @@ -255,10 +259,8 @@ func checkProtocolError(f *frame.RawFrame, connErr error, protocolErrorOccurred } } -func generateProtocolErrorResponseFrame(streamId int16, protocolErrMsg *message.ProtocolError) (*frame.RawFrame, error) { - // ideally we would use the maximum version between the versions used by both control connections if - // control connections implemented protocol version negotiation - response := frame.NewFrame(primitive.ProtocolVersion4, streamId, protocolErrMsg) +func generateProtocolErrorResponseFrame(streamId int16, protoVer primitive.ProtocolVersion, protocolErrMsg *message.ProtocolError) (*frame.RawFrame, error) { + response := frame.NewFrame(protoVer, streamId, protocolErrMsg) rawResponse, err := defaultCodec.ConvertToRawFrame(response) if err != nil { return nil, err diff --git a/proxy/pkg/zdmproxy/clienthandler.go b/proxy/pkg/zdmproxy/clienthandler.go index bfd91ba2..498e767c 100644 --- a/proxy/pkg/zdmproxy/clienthandler.go +++ b/proxy/pkg/zdmproxy/clienthandler.go @@ -16,6 +16,7 @@ import ( log "github.com/sirupsen/logrus" "net" "sort" + "strings" "sync" "sync/atomic" "time" @@ -170,10 +171,16 @@ func NewClientHandler( // Initialize stream id processors to manage the ids sent to the clusters originCCProtoVer := originControlConn.cqlConn.GetProtocolVersion() targetCCProtoVer := targetControlConn.cqlConn.GetProtocolVersion() - streamIds := maxStreamIds(originCCProtoVer, targetCCProtoVer, conf) - originFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeOrigin) - targetFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeTarget) - asyncFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeAsync) + // Calculate maximum number of stream IDs. Take the oldest protocol version negotiated between two clusters + // and apply limit defined in proxy configuration. If origin or target cluster are still running protocol V2, + // we will limit maximum number of stream IDs to 127 on both clusters. Logic is based on Java driver version 3.x. + // Java driver 3.x was the last one supporting protocol V2. It establishes control connection first, and then + // uses negotiated protocol version to configure maximum number of stream IDs on node connections. Driver does NOT + // change the number of stream IDs on per node basis. Maximum stream ID is calculated while creating stream ID mapper. + minimalProtoVer := minProtoVer(originCCProtoVer, targetCCProtoVer) + originFrameProcessor := newFrameProcessor(minimalProtoVer, conf, nodeMetrics, ClusterConnectorTypeOrigin) + targetFrameProcessor := newFrameProcessor(minimalProtoVer, conf, nodeMetrics, ClusterConnectorTypeTarget) + asyncFrameProcessor := newFrameProcessor(minimalProtoVer, conf, nodeMetrics, ClusterConnectorTypeAsync) closeFrameProcessors := func() { originFrameProcessor.Close() @@ -200,7 +207,7 @@ func NewClientHandler( originConnector, err := NewClusterConnector( originCassandraConnInfo, conf, psCache, nodeMetrics, localClientHandlerWg, clientHandlerRequestWg, clientHandlerContext, clientHandlerCancelFunc, respChannel, readScheduler, writeScheduler, requestsDoneCtx, - false, nil, handshakeDone, originFrameProcessor) + false, nil, handshakeDone, originFrameProcessor, originCCProtoVer) if err != nil { clientHandlerCancelFunc() return nil, err @@ -209,7 +216,7 @@ func NewClientHandler( targetConnector, err := NewClusterConnector( targetCassandraConnInfo, conf, psCache, nodeMetrics, localClientHandlerWg, clientHandlerRequestWg, clientHandlerContext, clientHandlerCancelFunc, respChannel, readScheduler, writeScheduler, requestsDoneCtx, - false, nil, handshakeDone, targetFrameProcessor) + false, nil, handshakeDone, targetFrameProcessor, targetCCProtoVer) if err != nil { clientHandlerCancelFunc() return nil, err @@ -227,7 +234,7 @@ func NewClientHandler( asyncConnector, err = NewClusterConnector( asyncConnInfo, conf, psCache, nodeMetrics, localClientHandlerWg, clientHandlerRequestWg, clientHandlerContext, clientHandlerCancelFunc, respChannel, readScheduler, writeScheduler, requestsDoneCtx, - true, asyncPendingRequests, handshakeDone, asyncFrameProcessor) + true, asyncPendingRequests, handshakeDone, asyncFrameProcessor, originCCProtoVer) if err != nil { log.Errorf("Could not create async cluster connector to %s, async requests will not be forwarded: %s", asyncConnInfo.connConfig.GetClusterType(), err.Error()) asyncConnector = nil @@ -263,7 +270,8 @@ func NewClientHandler( readScheduler, writeScheduler, clientHandlerShutdownRequestContext, - clientHandlerShutdownRequestCancelFn), + clientHandlerShutdownRequestCancelFn, + minProtoVer(originCCProtoVer, targetCCProtoVer)), asyncConnector: asyncConnector, originCassandraConnector: originConnector, @@ -1487,7 +1495,7 @@ func (ch *ClientHandler) executeRequest( f.Header.OpCode, f.Header.StreamId, common.ClusterTypeOrigin, common.ClusterTypeTarget) sendErr := ch.originCassandraConnector.sendRequestToCluster(originRequest) if sendErr != nil { - ch.clientConnector.sendOverloadedToClient(frameContext.frame) + ch.handleRequestSendFailure(sendErr, frameContext) } else { ch.targetCassandraConnector.sendRequestToCluster(targetRequest) } @@ -1496,7 +1504,7 @@ func (ch *ClientHandler) executeRequest( f.Header.OpCode, f.Header.StreamId, common.ClusterTypeOrigin) sendErr := ch.originCassandraConnector.sendRequestToCluster(originRequest) if sendErr != nil { - ch.clientConnector.sendOverloadedToClient(frameContext.frame) + ch.handleRequestSendFailure(sendErr, frameContext) } ch.targetCassandraConnector.sendHeartbeat(startupFrameVersion, ch.conf.HeartbeatIntervalMs) case forwardToTarget: @@ -1504,7 +1512,7 @@ func (ch *ClientHandler) executeRequest( f.Header.OpCode, f.Header.StreamId, common.ClusterTypeTarget) sendErr := ch.targetCassandraConnector.sendRequestToCluster(targetRequest) if sendErr != nil { - ch.clientConnector.sendOverloadedToClient(frameContext.frame) + ch.handleRequestSendFailure(sendErr, frameContext) } ch.originCassandraConnector.sendHeartbeat(startupFrameVersion, ch.conf.HeartbeatIntervalMs) case forwardToAsyncOnly: @@ -1526,6 +1534,20 @@ func (ch *ClientHandler) executeRequest( overallRequestStartTime, requestTimeout) } +func (ch *ClientHandler) handleRequestSendFailure(err error, frameContext *frameDecodeContext) { + if strings.Contains(err.Error(), "no stream id available") { + ch.clientConnector.sendOverloadedToClient(frameContext.frame) + } else if strings.Contains(err.Error(), "negative stream id") { + responseMessage := &message.ProtocolError{ErrorMessage: err.Error()} + responseFrame, err := generateProtocolErrorResponseFrame( + frameContext.frame.Header.StreamId, frameContext.frame.Header.Version, responseMessage) + if err != nil { + log.Errorf("could not generate protocol error response raw frame (%v): %v", responseMessage, err) + } + ch.clientConnector.sendResponseToClient(responseFrame) + } +} + func (ch *ClientHandler) handleInterceptedRequest( requestInfo RequestInfo, frameContext *frameDecodeContext, currentKeyspace string) (*frame.RawFrame, error) { @@ -2243,7 +2265,8 @@ func GetNodeMetricsByClusterConnector(nodeMetrics *metrics.NodeMetrics, connecto } } -func newFrameProcessor(maxStreamIds int, nodeMetrics *metrics.NodeMetrics, connectorType ClusterConnectorType) FrameProcessor { +func newFrameProcessor(protoVer primitive.ProtocolVersion, config *config.Config, nodeMetrics *metrics.NodeMetrics, + connectorType ClusterConnectorType) FrameProcessor { var streamIdsMetric metrics.Gauge connectorMetrics, err := GetNodeMetricsByClusterConnector(nodeMetrics, connectorType) if err != nil { @@ -2254,30 +2277,16 @@ func newFrameProcessor(maxStreamIds int, nodeMetrics *metrics.NodeMetrics, conne } var mapper StreamIdMapper if connectorType == ClusterConnectorTypeAsync { - mapper = NewInternalStreamIdMapper(maxStreamIds, streamIdsMetric) + mapper = NewInternalStreamIdMapper(protoVer, config, streamIdsMetric) } else { - mapper = NewStreamIdMapper(maxStreamIds, streamIdsMetric) + mapper = NewStreamIdMapper(protoVer, config, streamIdsMetric) } return NewStreamIdProcessor(mapper) } -// Calculate maximum number of stream IDs. Take the oldest protocol version negotiated between two clusters -// and apply limit defined in proxy configuration. If origin or target cluster are still running protocol V2, -// we will limit maximum number of stream IDs to 128 on both clusters. Logic is based on Java driver version 3.x. -// Java driver 3.x was the last one supporting protocol V2. It establishes control connection first, and then -// uses negotiated protocol version to configure maximum number of stream IDs on node connections. Driver does NOT -// change the number of stream IDs on per node basis. -func maxStreamIds(originProtoVer primitive.ProtocolVersion, targetProtoVer primitive.ProtocolVersion, conf *config.Config) int { - maxSupported := maxStreamIdsV3 - protoVer := originProtoVer - if targetProtoVer < originProtoVer { - protoVer = targetProtoVer - } - if protoVer == primitive.ProtocolVersion2 { - maxSupported = maxStreamIdsV2 - } - if maxSupported < conf.ProxyMaxStreamIds { - return maxSupported - } - return conf.ProxyMaxStreamIds +func minProtoVer(version1 primitive.ProtocolVersion, version2 primitive.ProtocolVersion) primitive.ProtocolVersion { + if version1 < version2 { + return version1 + } + return version2 } diff --git a/proxy/pkg/zdmproxy/clienthandler_test.go b/proxy/pkg/zdmproxy/clienthandler_test.go index b7fd0ce5..b5452557 100644 --- a/proxy/pkg/zdmproxy/clienthandler_test.go +++ b/proxy/pkg/zdmproxy/clienthandler_test.go @@ -27,7 +27,7 @@ func TestMaxStreamIds(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ids := maxStreamIds(tt.args.originProtoVer, tt.args.targetProtoVer, tt.args.config) + ids := maxStreamIds(minProtoVer(tt.args.originProtoVer, tt.args.targetProtoVer), tt.args.config) require.Equal(t, tt.expectedMaxStreamIds, ids) }) } diff --git a/proxy/pkg/zdmproxy/clusterconn.go b/proxy/pkg/zdmproxy/clusterconn.go index 21bea2c1..deeeaa45 100644 --- a/proxy/pkg/zdmproxy/clusterconn.go +++ b/proxy/pkg/zdmproxy/clusterconn.go @@ -75,6 +75,8 @@ type ClusterConnector struct { lastHeartbeatTime *atomic.Value lastHeartbeatLock sync.Mutex + + ccProtoVer primitive.ProtocolVersion } func NewClusterConnectionInfo(connConfig ConnectionConfig, endpointConfig Endpoint, isOriginCassandra bool) *ClusterConnectionInfo { @@ -101,7 +103,8 @@ func NewClusterConnector( asyncConnector bool, asyncPendingRequests *pendingRequests, handshakeDone *atomic.Value, - frameProcessor FrameProcessor) (*ClusterConnector, error) { + frameProcessor FrameProcessor, + ccProtoVer primitive.ProtocolVersion) (*ClusterConnector, error) { var connectorType ClusterConnectorType var clusterType common.ClusterType @@ -181,6 +184,7 @@ func NewClusterConnector( asyncPendingRequests: asyncPendingRequests, handshakeDone: handshakeDone, lastHeartbeatTime: lastHeartbeatTime, + ccProtoVer: ccProtoVer, }, nil } @@ -247,7 +251,7 @@ func (cc *ClusterConnector) runResponseListeningLoop() { protocolErrOccurred := false for { response, err := readRawFrame(bufferedReader, connectionAddr, cc.clusterConnContext) - protocolErrResponseFrame, err, errCode := checkProtocolError(response, err, protocolErrOccurred, string(cc.connectorType)) + protocolErrResponseFrame, err, errCode := checkProtocolError(response, cc.ccProtoVer, err, protocolErrOccurred, string(cc.connectorType)) if err != nil { handleConnectionError( diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index 92a047f4..16d9ad4e 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -109,7 +109,7 @@ func NewCqlConnection( eventHandlerLock: &sync.Mutex{}, authEnabled: true, // protoVer is the proposed protocol version using which we will try to establish connectivity - frameProcessor: NewStreamIdProcessor(NewInternalStreamIdMapper(maxStreamIds(protoVer, protoVer, conf), nil)), + frameProcessor: NewStreamIdProcessor(NewInternalStreamIdMapper(protoVer, conf, nil)), protocolVersion: &atomic.Value{}, } cqlConn.StartRequestLoop() diff --git a/proxy/pkg/zdmproxy/streamidmapper.go b/proxy/pkg/zdmproxy/streamidmapper.go index 02a78d21..652e9ff6 100644 --- a/proxy/pkg/zdmproxy/streamidmapper.go +++ b/proxy/pkg/zdmproxy/streamidmapper.go @@ -2,7 +2,10 @@ package zdmproxy import ( "fmt" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/datastax/zdm-proxy/proxy/pkg/metrics" + "math" "sync" ) @@ -17,30 +20,35 @@ type StreamIdMapper interface { type streamIdMapper struct { sync.Mutex - idMapper map[int16]int16 - clusterIds chan int16 - metrics metrics.Gauge + idMapper map[int16]int16 + clusterIds chan int16 + metrics metrics.Gauge + protocolVersion primitive.ProtocolVersion } type internalStreamIdMapper struct { - clusterIds chan int16 - metrics metrics.Gauge + clusterIds chan int16 + metrics metrics.Gauge + protocolVersion primitive.ProtocolVersion } // NewInternalStreamIdMapper is used to assign unique ids to frames that have no initial stream id defined, such as // CQL queries initiated by the proxy or ASYNC requests. -func NewInternalStreamIdMapper(maxStreamIds int, metrics metrics.Gauge) StreamIdMapper { - streamIdsQueue := make(chan int16, maxStreamIds) - for i := int16(0); i < int16(maxStreamIds); i++ { +func NewInternalStreamIdMapper(protocolVersion primitive.ProtocolVersion, config *config.Config, metrics metrics.Gauge) StreamIdMapper { + maximumStreamIds := maxStreamIds(protocolVersion, config) + streamIdsQueue := make(chan int16, maximumStreamIds) + for i := int16(0); i < int16(maximumStreamIds); i++ { streamIdsQueue <- i } return &internalStreamIdMapper{ - clusterIds: streamIdsQueue, - metrics: metrics, + protocolVersion: protocolVersion, + clusterIds: streamIdsQueue, + metrics: metrics, } } func (csid *internalStreamIdMapper) GetNewIdFor(_ int16) (int16, error) { + // do not validate provided stream ID select { case id := <-csid.clusterIds: if csid.metrics != nil { @@ -73,20 +81,25 @@ func (csid *internalStreamIdMapper) Close() { } } -func NewStreamIdMapper(maxStreamIds int, metrics metrics.Gauge) StreamIdMapper { +func NewStreamIdMapper(protocolVersion primitive.ProtocolVersion, config *config.Config, metrics metrics.Gauge) StreamIdMapper { + maximumStreamIds := maxStreamIds(protocolVersion, config) idMapper := make(map[int16]int16) - streamIdsQueue := make(chan int16, maxStreamIds) - for i := int16(0); i < int16(maxStreamIds); i++ { + streamIdsQueue := make(chan int16, maximumStreamIds) + for i := int16(0); i < int16(maximumStreamIds); i++ { streamIdsQueue <- i } return &streamIdMapper{ - idMapper: idMapper, - clusterIds: streamIdsQueue, - metrics: metrics, + protocolVersion: protocolVersion, + idMapper: idMapper, + clusterIds: streamIdsQueue, + metrics: metrics, } } func (sim *streamIdMapper) GetNewIdFor(streamId int16) (int16, error) { + if err := validateStreamId(sim.protocolVersion, streamId); err != nil { + return -1, err + } select { case id := <-sim.clusterIds: if sim.metrics != nil { @@ -134,3 +147,26 @@ func (sim *streamIdMapper) Close() { sim.metrics.Subtract(cap(sim.clusterIds) - len(sim.clusterIds)) } } + +func maxStreamIds(protoVer primitive.ProtocolVersion, conf *config.Config) int { + maxSupported := maxStreamIdsV3 + if protoVer == primitive.ProtocolVersion2 { + maxSupported = maxStreamIdsV2 + } + if maxSupported < conf.ProxyMaxStreamIds { + return maxSupported + } + return conf.ProxyMaxStreamIds +} + +func validateStreamId(version primitive.ProtocolVersion, streamId int16) error { + if version < primitive.ProtocolVersion3 { + if streamId > math.MaxInt8 || streamId < math.MinInt8 { + return fmt.Errorf("stream id out of range for %v: %v", version, streamId) + } + } + if streamId < 0 { + return fmt.Errorf("negative stream id: %v", streamId) + } + return nil +} diff --git a/proxy/pkg/zdmproxy/streamidmapper_test.go b/proxy/pkg/zdmproxy/streamidmapper_test.go index 92d04792..c361e8af 100644 --- a/proxy/pkg/zdmproxy/streamidmapper_test.go +++ b/proxy/pkg/zdmproxy/streamidmapper_test.go @@ -1,20 +1,21 @@ package zdmproxy import ( + "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/stretchr/testify/require" "sync" "testing" ) func TestStreamIdMapper(t *testing.T) { - var mapper = NewStreamIdMapper(2048, nil) + var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, 2048, nil) var syntheticId, _ = mapper.GetNewIdFor(1000) var originalId, _ = mapper.ReleaseId(syntheticId) require.Equal(t, int16(1000), originalId) } func BenchmarkStreamIdMapper(b *testing.B) { - var mapper = NewStreamIdMapper(2048, nil) + var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, 2048, nil) for i := 0; i < b.N; i++ { var originalId = int16(i) var syntheticId, _ = mapper.GetNewIdFor(originalId) @@ -28,7 +29,7 @@ func TestConcurrentStreamIdMapper(t *testing.T) { var wg = sync.WaitGroup{} wg.Add(concurrency) for i := 0; i < concurrency; i++ { - var mapper = NewStreamIdMapper(2048, nil) + var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, 2048, nil) getAndReleaseIds(t, mapper, int16(i), requestCount, &wg) } wg.Wait() From 6e4b20d601ce25d4ea151aeb58561b1f350f53b7 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Fri, 19 Jul 2024 13:35:55 +0200 Subject: [PATCH 27/29] Documentation --- README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 723c005d..70fa60fd 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ various C* versions: |------------------|------------------| | 2.0 | V2 | | 2.1 | V2, V3 | -| 2.2 | V3, V4 | +| 2.2 | V2, V3, V4 | | 3.x | V3, V4 | | 4.x | V3, V4, V5 | @@ -122,8 +122,7 @@ migration process. In practice this means that ZDM Proxy supports the following cluster versions (as Origin and / or Target): -- Apache Cassandra from 2.1+ up to (and including) Apache Cassandra 4.x. -- Apache Cassandra 2.0 up to 2.1. +- Apache Cassandra from 2.0+ up to (and including) Apache Cassandra 4.x. (although both clusters have to support a common protocol version as mentioned above). - DataStax Enterprise 4.8+. DataStax Enterprise 4.6 and 4.7 support will be introduced when protocol version v2 is supported. - DataStax Astra DB (both Serverless and Classic) From bca996c8c7b1f10a2a6465fda10195b8bedb9efb Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Fri, 19 Jul 2024 13:49:33 +0200 Subject: [PATCH 28/29] Fix build --- proxy/pkg/zdmproxy/streamidmapper_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/proxy/pkg/zdmproxy/streamidmapper_test.go b/proxy/pkg/zdmproxy/streamidmapper_test.go index c361e8af..d72e5f59 100644 --- a/proxy/pkg/zdmproxy/streamidmapper_test.go +++ b/proxy/pkg/zdmproxy/streamidmapper_test.go @@ -2,20 +2,21 @@ package zdmproxy import ( "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/stretchr/testify/require" "sync" "testing" ) func TestStreamIdMapper(t *testing.T) { - var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, 2048, nil) + var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, &config.Config{ProxyMaxStreamIds: 2048}, nil) var syntheticId, _ = mapper.GetNewIdFor(1000) var originalId, _ = mapper.ReleaseId(syntheticId) require.Equal(t, int16(1000), originalId) } func BenchmarkStreamIdMapper(b *testing.B) { - var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, 2048, nil) + var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, &config.Config{ProxyMaxStreamIds: 2048}, nil) for i := 0; i < b.N; i++ { var originalId = int16(i) var syntheticId, _ = mapper.GetNewIdFor(originalId) @@ -29,7 +30,7 @@ func TestConcurrentStreamIdMapper(t *testing.T) { var wg = sync.WaitGroup{} wg.Add(concurrency) for i := 0; i < concurrency; i++ { - var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, 2048, nil) + var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, &config.Config{ProxyMaxStreamIds: 2048}, nil) getAndReleaseIds(t, mapper, int16(i), requestCount, &wg) } wg.Wait() From 3a18a1118eded1e150a8eb37e48f35230d54111f Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Mon, 22 Jul 2024 19:29:28 +0200 Subject: [PATCH 29/29] Cleanup --- proxy/pkg/zdmproxy/clienthandler.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/proxy/pkg/zdmproxy/clienthandler.go b/proxy/pkg/zdmproxy/clienthandler.go index 498e767c..066acf0a 100644 --- a/proxy/pkg/zdmproxy/clienthandler.go +++ b/proxy/pkg/zdmproxy/clienthandler.go @@ -1543,8 +1543,9 @@ func (ch *ClientHandler) handleRequestSendFailure(err error, frameContext *frame frameContext.frame.Header.StreamId, frameContext.frame.Header.Version, responseMessage) if err != nil { log.Errorf("could not generate protocol error response raw frame (%v): %v", responseMessage, err) + } else { + ch.clientConnector.sendResponseToClient(responseFrame) } - ch.clientConnector.sendResponseToClient(responseFrame) } }