Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

120: Support protocol v2 #119

Merged
merged 30 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9257cbd
ZDM-71: Introduce protocol negotiation
lukasz-antoniak Jun 17, 2024
72e5518
ZDM-71: Introduce protocol negotiation
lukasz-antoniak Jun 18, 2024
9de161b
Allow to run with specific Simulacron cluster version
lukasz-antoniak Jun 18, 2024
f60bd28
Allow to run with specific Simulacron cluster version
lukasz-antoniak Jun 18, 2024
a244c8b
Better classification of ProtocolError
lukasz-antoniak Jun 18, 2024
3002a92
Validation of protocol version
lukasz-antoniak Jun 20, 2024
4bf2cdd
Protocol V2 support
lukasz-antoniak Jun 25, 2024
b8feb06
Upgrade go-cassandra-native-protocol library
lukasz-antoniak Jun 26, 2024
1d353fa
Cleanup
lukasz-antoniak Jun 26, 2024
081bec0
Protocol V2 stubbed tests
lukasz-antoniak Jun 27, 2024
65ce9a0
Protocol V2 stubbed tests
lukasz-antoniak Jun 27, 2024
d8ba2db
Update README
lukasz-antoniak Jun 27, 2024
58680ad
Tidy dependencies
lukasz-antoniak Jun 28, 2024
7247e64
Apply review comments
lukasz-antoniak Jun 28, 2024
5efe8c2
Apply review comments
lukasz-antoniak Jun 28, 2024
f8e9529
Limit number of maximum stream IDs
lukasz-antoniak Jul 11, 2024
a918eaf
Merge branch 'main' of github.com:datastax/zdm-proxy into ZDM-71
lukasz-antoniak Jul 11, 2024
f94e037
Fix merge issues
lukasz-antoniak Jul 11, 2024
a5389e0
Fix merge issues
lukasz-antoniak Jul 11, 2024
ce7179c
Fix merge issues
lukasz-antoniak Jul 11, 2024
c2e46f1
Fix merge issues
lukasz-antoniak Jul 11, 2024
4bd3785
New maximum stream IDs test
lukasz-antoniak Jul 17, 2024
7722bf1
Automated gofmt changes
Jul 17, 2024
1dabb23
Cleanup
lukasz-antoniak Jul 17, 2024
2615ce9
Use DSEv2 as default max protocol version
lukasz-antoniak Jul 18, 2024
21eebb6
More various protocol version tests
lukasz-antoniak Jul 18, 2024
5ea7eef
Stream ID verification
lukasz-antoniak Jul 19, 2024
6e4b20d
Documentation
lukasz-antoniak Jul 19, 2024
bca996c
Fix build
lukasz-antoniak Jul 19, 2024
3a18a11
Cleanup
lukasz-antoniak Jul 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integration-tests/asyncreads_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func TestAsyncReadsRequestTypes(t *testing.T) {
}

testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig(
t, false, false, 1, nil)
t, false, false, 1, nil, nil)
require.Nil(t, err)
defer testSetup.Cleanup()

Expand Down
29 changes: 29 additions & 0 deletions integration-tests/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/datastax/go-cassandra-native-protocol/primitive"
"github.com/datastax/zdm-proxy/integration-tests/client"
"github.com/datastax/zdm-proxy/integration-tests/setup"
"github.com/datastax/zdm-proxy/integration-tests/simulacron"
"github.com/datastax/zdm-proxy/integration-tests/utils"
"github.com/datastax/zdm-proxy/proxy/pkg/config"
"github.com/rs/zerolog"
Expand Down Expand Up @@ -45,6 +46,34 @@ func TestGoCqlConnect(t *testing.T) {
require.Equal(t, "fake", iter.Columns()[0].Name)
}

func TestProtocolVersionNegotiation(t *testing.T) {
c := setup.NewTestConfig("", "")
lukasz-antoniak marked this conversation as resolved.
Show resolved Hide resolved
c.ControlConnMaxProtocolVersion = 4 // configure unsupported protocol version
testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, true, false, 1, c, &simulacron.ClusterVersion{"2.1", "2.1"})
require.Nil(t, err)
defer testSetup.Cleanup()

// Connect to proxy as a "client"
proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3)

if err != nil {
t.Fatal("Unable to connect to proxy session.")
}
defer proxy.Close()

iter := proxy.Query("SELECT * FROM fakeks.faketb").Iter()
result, err := iter.SliceMap()

if err != nil {
t.Fatal("query failed:", err)
}

require.Equal(t, 0, len(result))

// simulacron generates fake response metadata when queries aren't primed
require.Equal(t, "fake", iter.Columns()[0].Name)
}

func TestMaxClientsThreshold(t *testing.T) {
maxClients := 10
goCqlConnectionsPerHost := 1
Expand Down
11 changes: 6 additions & 5 deletions integration-tests/setup/testcluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,22 +127,22 @@ func NewSimulacronTestSetupWithSession(t *testing.T, createProxy bool, createSes
}

func NewSimulacronTestSetupWithSessionAndConfig(t *testing.T, createProxy bool, createSession bool, config *config.Config) (*SimulacronTestSetup, error) {
return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, 1, config)
return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, 1, config, nil)
}

func NewSimulacronTestSetupWithSessionAndNodes(t *testing.T, createProxy bool, createSession bool, nodes int) (*SimulacronTestSetup, error) {
return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, nodes, nil)
return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, nodes, nil, nil)
}

func NewSimulacronTestSetupWithSessionAndNodesAndConfig(t *testing.T, createProxy bool, createSession bool, nodes int, config *config.Config) (*SimulacronTestSetup, error) {
func NewSimulacronTestSetupWithSessionAndNodesAndConfig(t *testing.T, createProxy bool, createSession bool, nodes int, config *config.Config, version *simulacron.ClusterVersion) (*SimulacronTestSetup, error) {
if !env.RunMockTests {
t.Skip("Skipping Simulacron tests, RUN_MOCKTESTS is set false")
}
origin, err := simulacron.GetNewCluster(createSession, nodes)
origin, err := simulacron.GetNewCluster(createSession, nodes, version)
if err != nil {
log.Panic("simulacron origin startup failed: ", err)
}
target, err := simulacron.GetNewCluster(createSession, nodes)
target, err := simulacron.GetNewCluster(createSession, nodes, version)
if err != nil {
log.Panic("simulacron target startup failed: ", err)
}
Expand Down Expand Up @@ -452,6 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config {
conf.ReadMode = config.ReadModePrimaryOnly
conf.SystemQueriesMode = config.SystemQueriesModeOrigin
conf.AsyncHandshakeTimeoutMs = 4000
conf.ControlConnMaxProtocolVersion = 3

conf.ProxyRequestTimeoutMs = 10000

Expand Down
4 changes: 2 additions & 2 deletions integration-tests/simulacron/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ func (baseSimulacron *baseSimulacron) GetId() string {
return baseSimulacron.id
}

func GetNewCluster(startSession bool, numberOfNodes int) (*Cluster, error) {
func GetNewCluster(startSession bool, numberOfNodes int, version *ClusterVersion) (*Cluster, error) {
process, err := GetOrCreateGlobalSimulacronProcess()

if err != nil {
return nil, err
}

cluster, createErr := process.Create(startSession, numberOfNodes)
cluster, createErr := process.Create(startSession, numberOfNodes, version)

if createErr != nil {
return nil, createErr
Expand Down
12 changes: 10 additions & 2 deletions integration-tests/simulacron/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ type ClusterData struct {
Datacenters []*DatacenterData `json:"data_centers"`
}

// ClusterVersion carries the version strings used when a Simulacron cluster is created.
// Process.Create substitutes them into the create-cluster URL; when a nil *ClusterVersion
// is passed there, it falls back to env.CassandraVersion and env.DseVersion.
type ClusterVersion struct {
	// Cassandra is formatted into the cassandra_version query parameter.
	Cassandra string
	// Dse is formatted into the dse_version query parameter.
	Dse string
}

type DatacenterData struct {
Id int `json:"id"`
Nodes []*NodeData `json:"nodes"`
Expand All @@ -31,11 +36,14 @@ type NodeData struct {

const createUrl = "/cluster?data_centers=%s&cassandra_version=%s&dse_version=%s&name=%s&activity_log=%s&num_tokens=%d"

func (process *Process) Create(startSession bool, numberOfNodes int) (*Cluster, error) {
func (process *Process) Create(startSession bool, numberOfNodes int, version *ClusterVersion) (*Cluster, error) {
if version == nil {
version = &ClusterVersion{env.CassandraVersion, env.DseVersion}
}
name := "test_" + uuid.New().String()
resp, err := process.execHttp(
"POST",
fmt.Sprintf(createUrl, strconv.FormatInt(int64(numberOfNodes), 10), env.CassandraVersion, env.DseVersion, name, "true", 1),
fmt.Sprintf(createUrl, strconv.FormatInt(int64(numberOfNodes), 10), version.Cassandra, version.Dse, name, "true", 1),
nil)

if err != nil {
Expand Down
9 changes: 7 additions & 2 deletions integration-tests/utils/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ func CheckMetricsEndpointResult(httpAddr string, success bool) error {
return nil
}

// ConnectToCluster is used to connect to source and destination clusters
func ConnectToCluster(hostname string, username string, password string, port int) (*gocql.Session, error) {
func ConnectToClusterUsingVersion(hostname string, username string, password string, port int, protoVersion int) (*gocql.Session, error) {
cluster := NewCluster(hostname, username, password, port)
cluster.ProtoVersion = protoVersion
lukasz-antoniak marked this conversation as resolved.
Show resolved Hide resolved
session, err := cluster.CreateSession()
log.Debugf("Connection established with Cluster: %s:%d", cluster.Hosts[0], cluster.Port)
if err != nil {
Expand All @@ -127,6 +127,11 @@ func ConnectToCluster(hostname string, username string, password string, port in
return session, nil
}

// ConnectToCluster opens a gocql session against the given host and port with the supplied
// credentials. It delegates to ConnectToClusterUsingVersion with protocol version 4.
func ConnectToCluster(hostname string, username string, password string, port int) (*gocql.Session, error) {
	// Protocol version requested when the caller does not pick one explicitly.
	const defaultProtoVersion = 4
	session, err := ConnectToClusterUsingVersion(hostname, username, password, port, defaultProtoVersion)
	return session, err
}

// NewCluster initializes a ClusterConfig object with common settings
func NewCluster(hostname string, username string, password string, port int) *gocql.ClusterConfig {
cluster := gocql.NewCluster(hostname)
Expand Down
11 changes: 6 additions & 5 deletions proxy/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ type Config struct {

// Global bucket

PrimaryCluster string `default:"ORIGIN" split_words:"true"`
ReadMode string `default:"PRIMARY_ONLY" split_words:"true"`
ReplaceCqlFunctions bool `default:"false" split_words:"true"`
AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"`
LogLevel string `default:"INFO" split_words:"true"`
PrimaryCluster string `default:"ORIGIN" split_words:"true"`
ReadMode string `default:"PRIMARY_ONLY" split_words:"true"`
ReplaceCqlFunctions bool `default:"false" split_words:"true"`
AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"`
LogLevel string `default:"INFO" split_words:"true"`
ControlConnMaxProtocolVersion uint `default:"3" split_words:"true"`
lukasz-antoniak marked this conversation as resolved.
Show resolved Hide resolved

// Proxy Topology (also known as system.peers "virtualization") bucket

Expand Down
70 changes: 56 additions & 14 deletions proxy/pkg/zdmproxy/controlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zdmproxy

import (
"context"
"errors"
"fmt"
"github.com/datastax/go-cassandra-native-protocol/frame"
"github.com/datastax/go-cassandra-native-protocol/message"
Expand Down Expand Up @@ -58,7 +59,6 @@ type ControlConn struct {

const ProxyVirtualRack = "rack0"
const ProxyVirtualPartitioner = "org.apache.cassandra.dht.Murmur3Partitioner"
const ccProtocolVersion = primitive.ProtocolVersion3
const ccWriteTimeout = 5 * time.Second
const ccReadTimeout = 10 * time.Second

Expand Down Expand Up @@ -320,15 +320,9 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) (

currentIndex := (firstEndpointIndex + i) % len(endpoints)
endpoint = endpoints[currentIndex]
tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false)
if err != nil {
log.Warnf("Failed to open control connection to %v using endpoint %v: %v",
cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err)
continue
}

newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf)
err = newConn.InitializeContext(ccProtocolVersion, ctx)
newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ControlConnMaxProtocolVersion, ctx)

lukasz-antoniak marked this conversation as resolved.
Show resolved Hide resolved
if err == nil {
newConn.SetEventHandler(func(f *frame.Frame, c CqlConnection) {
switch f.Body.Message.(type) {
Expand All @@ -355,9 +349,11 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) (
log.Warnf("Error while initializing a new cql connection for the control connection of %v: %v",
cc.connConfig.GetClusterType(), err)
}
err2 := newConn.Close()
if err2 != nil {
log.Errorf("Failed to close cql connection: %v", err2)
if newConn != nil {
err2 := newConn.Close()
if err2 != nil {
log.Errorf("Failed to close cql connection: %v", err2)
}
}

continue
Expand All @@ -372,6 +368,52 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) (
return conn, endpoint
}

func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoVer uint, ctx context.Context) (CqlConnection, error) {
protoVer := primitive.ProtocolVersion(initialProtoVer)
for {
tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false)
if err != nil {
log.Warnf("Failed to open control connection to %v using endpoint %v: %v",
cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err)
return nil, err
}
newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf)
err = newConn.InitializeContext(protoVer, ctx)
var respErr *ResponseError
if err != nil && errors.As(err, &respErr) && respErr.IsProtocolError() && strings.Contains(err.Error(), "Invalid or unsupported protocol version") {
// unsupported protocol version
// protocol renegotiation requires opening a new TCP connection
err2 := newConn.Close()
if err2 != nil {
log.Errorf("Failed to close cql connection: %v", err2)
}
protoVer = downgradeProtocol(protoVer)
log.Infof("Downgrading protocol version: %v", protoVer)
lukasz-antoniak marked this conversation as resolved.
Show resolved Hide resolved
if protoVer == 0 {
// we cannot downgrade anymore
return nil, err
}
continue // retry lower protocol version
} else {
return newConn, err // we may have successfully established connection or faced other error
}
}
}

// downgradeProtocol returns the next protocol version to try after the given one was
// rejected, following the order DSE v2 -> DSE v1 -> v4 -> v3 -> v2. It returns 0 when
// version is the last entry of that chain or is not part of it at all.
func downgradeProtocol(version primitive.ProtocolVersion) primitive.ProtocolVersion {
	// Ordered from the highest version to the lowest; each entry falls back to its successor.
	ladder := [...]primitive.ProtocolVersion{
		primitive.ProtocolVersionDse2,
		primitive.ProtocolVersionDse1,
		primitive.ProtocolVersion4,
		primitive.ProtocolVersion3,
		primitive.ProtocolVersion2,
	}
	// The final rung has no successor, so it is deliberately excluded from the search.
	for i := 0; i < len(ladder)-1; i++ {
		if ladder[i] == version {
			return ladder[i+1]
		}
	}
	return 0
}

func (cc *ControlConn) Close() {
cc.cqlConnLock.Lock()
conn := cc.cqlConn
Expand All @@ -387,7 +429,7 @@ func (cc *ControlConn) Close() {
}

func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([]*Host, error) {
localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), ccProtocolVersion, ctx)
localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), ctx)
if err != nil {
return nil, fmt.Errorf("could not fetch information from system.local table: %w", err)
}
Expand All @@ -410,7 +452,7 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([]
}
}

peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), ccProtocolVersion, ctx)
peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), ctx)
if err != nil {
return nil, fmt.Errorf("could not fetch information from system.peers table: %w", err)
}
Expand Down
19 changes: 14 additions & 5 deletions proxy/pkg/zdmproxy/cqlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -32,7 +33,7 @@ type CqlConnection interface {
SendAndReceive(request *frame.Frame, ctx context.Context) (*frame.Frame, error)
Close() error
Execute(msg message.Message, ctx context.Context) (message.Message, error)
Query(cql string, genericTypeCodec *GenericTypeCodec, version primitive.ProtocolVersion, ctx context.Context) (*ParsedRowSet, error)
Query(cql string, genericTypeCodec *GenericTypeCodec, ctx context.Context) (*ParsedRowSet, error)
SendHeartbeat(ctx context.Context) error
SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection))
SubscribeToProtocolEvents(ctx context.Context, eventTypes []primitive.EventType) error
Expand All @@ -59,6 +60,7 @@ type cqlConn struct {
eventHandlerLock *sync.Mutex
authEnabled bool
frameProcessor FrameProcessor
protocolVersion *atomic.Value
}

var (
Expand Down Expand Up @@ -237,6 +239,8 @@ func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx conte
return fmt.Errorf("failed to perform handshake: %w", err)
}

c.protocolVersion = &atomic.Value{}
lukasz-antoniak marked this conversation as resolved.
Show resolved Hide resolved
c.protocolVersion.Store(version)
c.initialized = true
c.authEnabled = authEnabled
return nil
Expand Down Expand Up @@ -353,6 +357,8 @@ func (c *cqlConn) PerformHandshake(version primitive.ProtocolVersion, ctx contex
}
}
}
case *message.ProtocolError:
err = &ResponseError{Response: response}
default:
err = fmt.Errorf("expected AUTHENTICATE or READY, got %v", response.Body.Message)
}
Expand All @@ -367,15 +373,16 @@ func (c *cqlConn) PerformHandshake(version primitive.ProtocolVersion, ctx contex
}

func (c *cqlConn) Query(
cql string, genericTypeCodec *GenericTypeCodec, version primitive.ProtocolVersion, ctx context.Context) (*ParsedRowSet, error) {
cql string, genericTypeCodec *GenericTypeCodec, ctx context.Context) (*ParsedRowSet, error) {
queryMsg := &message.Query{
Query: cql,
Options: &message.QueryOptions{
Consistency: primitive.ConsistencyLevelLocalQuorum,
},
}

queryFrame := frame.NewFrame(ccProtocolVersion, -1, queryMsg)
version := c.protocolVersion.Load().(primitive.ProtocolVersion)
queryFrame := frame.NewFrame(version, -1, queryMsg)
var rowSet *ParsedRowSet
for {
localResponse, err := c.SendAndReceive(queryFrame, ctx)
Expand Down Expand Up @@ -429,7 +436,8 @@ func (c *cqlConn) Query(
}

func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Message, error) {
queryFrame := frame.NewFrame(ccProtocolVersion, -1, msg)
version := c.protocolVersion.Load().(primitive.ProtocolVersion)
queryFrame := frame.NewFrame(version, -1, msg)
localResponse, err := c.SendAndReceive(queryFrame, ctx)
if err != nil {
return nil, err
Expand All @@ -440,7 +448,8 @@ func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Mes

func (c *cqlConn) SendHeartbeat(ctx context.Context) error {
optionsMsg := &message.Options{}
heartBeatFrame := frame.NewFrame(ccProtocolVersion, -1, optionsMsg)
version := c.protocolVersion.Load().(primitive.ProtocolVersion)
heartBeatFrame := frame.NewFrame(version, -1, optionsMsg)

response, err := c.SendAndReceive(heartBeatFrame, ctx)
if err != nil {
Expand Down
Loading
Loading