From 8692d20d8dda0eecb2072348a6704a8a1185ae91 Mon Sep 17 00:00:00 2001 From: Bohdan Siryk Date: Fri, 28 Nov 2025 10:14:22 +0200 Subject: [PATCH] Protocol version negotiation doesn't work if server replies with stream id different than 0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, protocol negotiation didn't work properly when C* was responding with stream id different from 0. This patch changes the way protocol negotiation works. Instead of parsing a supported protocol version from C* error response, the driver tries to connect with each supported protocol starting from the latest. Patch by Bohdan Siryk; Reviewed by João Reis for CASSGO-98 --- CHANGELOG.md | 1 + Makefile | 2 +- conn.go | 51 ++++++- conn_test.go | 78 ++++++++-- control.go | 98 +++++-------- control_test.go | 35 ----- errors.go | 10 ++ frame.go | 15 +- protocol_negotiation_test.go | 267 +++++++++++++++++++++++++++++++++++ 9 files changed, 442 insertions(+), 115 deletions(-) create mode 100644 protocol_negotiation_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index cd327073b..872ebdd1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Prevent panic with queries during session init (CASSGO-92) - Return correct values from RowData (CASSGO-95) - Prevent setting a compression flag in a frame header when native proto v5 is being used (CASSGO-98) +- Use protocol downgrading approach during protocol negotiation (CASSGO-97) ## [2.0.0] diff --git a/Makefile b/Makefile index 56ed015db..c4f6df212 100644 --- a/Makefile +++ b/Makefile @@ -100,7 +100,7 @@ test-integration-auth: .prepare-cassandra-cluster test-unit: @echo "Run unit tests" @go clean -testcache - go test -tags unit -timeout=5m -race ./... + go test -v -tags unit -timeout=5m -race ./... check: .prepare-golangci @echo "Build" diff --git a/conn.go b/conn.go index 40044565d..a9bb4f5ab 100644 --- a/conn.go +++ b/conn.go @@ -378,6 +378,13 @@ func (s *startupCoordinator) setupConn(ctx context.Context) error { select { case err := <-startupErr: if err != nil { + if s.checkProtocolRelatedError(err) { + return &unsupportedProtocolVersionError{ + err: err, + hostInfo: s.conn.host, + version: protoVersion(s.conn.version), + } + } return err } case <-ctx.Done(): @@ -387,6 +394,38 @@ func (s *startupCoordinator) setupConn(ctx context.Context) error { return nil } +// Checks if the error is protocol related and should be retried during startup. +// It returns the frame that caused the error and whether the error should be retried. +func (s *startupCoordinator) checkProtocolRelatedError(err error) bool { + var unwrappedFrame frame + + var protocolErr *protocolError + if !errors.As(err, &protocolErr) { + var errFrame errorFrame + if !errors.As(err, &errFrame) { + return false + } else { + unwrappedFrame = errFrame + } + } else { + unwrappedFrame = protocolErr.frame + } + + switch frame := unwrappedFrame.(type) { + case *supportedFrame: + // We can receive a supportedFrame wrapped in protocolError from Conn.recv if the host responds to a 0 stream id. + // If we receive a supportedFrame then we know that the host is not compatible with the protocol version, but it is reachable, so we can retry + return true + case errorFrame: + // If we receive an errorFrame with codes ErrCodeProtocol or ErrCodeServer, + // then we should try to downgrade a protocol version, so do not skip the host + return frame.code == ErrCodeProtocol || frame.code == ErrCodeServer + default: + // In any other case we should not retry as it means the host is not reachable or some other error happened + return false + } +} + func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder, startupCompleted *atomic.Bool) (frame, error) { select { case s.frameTicker <- struct{}{}: @@ -408,12 +447,14 @@ func (s *startupCoordinator) options(ctx context.Context, startupCompleted *atom return err } - supported, ok := frame.(*supportedFrame) - if !ok { - return NewErrProtocol("Unknown type of response to startup frame: %T", frame) + switch frame := frame.(type) { + case *supportedFrame: + return s.startup(ctx, frame.supported, startupCompleted) + case error: + return frame + default: + return NewErrProtocol("Unknown type of response to startup frame: %T (frame=%s)", frame, frame.String()) } - - return s.startup(ctx, supported.supported, startupCompleted) } func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string, startupCompleted *atomic.Bool) error { diff --git a/conn_test.go b/conn_test.go index 60e4a2a8a..ad4e66e54 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1054,6 +1054,9 @@ type newTestServerOpts struct { addr string protocol uint8 recvHook func(*framer) + + customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error + dontFailOnProtocolMismatch bool } func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestServer { @@ -1078,6 +1081,9 @@ func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestS cancel: cancel, onRecv: nts.recvHook, + + customRequestHandler: nts.customRequestHandler, + dontFailOnProtocolMismatch: nts.dontFailOnProtocolMismatch, } go srv.closeWatch() @@ -1142,6 +1148,10 @@ type TestServer struct { // onRecv is a hook point for tests, called in receive loop. onRecv func(*framer) + + // customRequestHandler allows overriding the default request handling for testing purposes. + customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error + dontFailOnProtocolMismatch bool } func (srv *TestServer) closeWatch() { @@ -1162,9 +1172,26 @@ func (srv *TestServer) serve() { } go func(conn net.Conn) { + var startupCompleted bool + var useProtoV5 bool + defer conn.Close() for !srv.isClosed() { - framer, err := srv.readFrame(conn) + var reader io.Reader = conn + + if useProtoV5 && startupCompleted { + frame, _, err := readUncompressedSegment(conn) + if err != nil { + if errors.Is(err, io.EOF) { + return + } + srv.errorLocked(err) + return + } + reader = bytes.NewReader(frame) + } + + framer, err := srv.readFrame(reader) if err != nil { if err == io.EOF { return @@ -1177,7 +1204,7 @@ func (srv *TestServer) serve() { srv.onRecv(framer) } - go srv.process(conn, framer) + srv.process(conn, framer, &useProtoV5, &startupCompleted) } }(conn) } @@ -1215,13 +1242,22 @@ func (srv *TestServer) errorLocked(err interface{}) { srv.t.Error(err) } -func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { +func (srv *TestServer) process(conn net.Conn, reqFrame *framer, useProtoV5, startupCompleted *bool) { head := reqFrame.header if head == nil { srv.errorLocked("process frame with a nil header") return } - respFrame := newFramer(nil, reqFrame.proto, GlobalTypes) + respFrame := newFramer(nil, byte(head.version), GlobalTypes) + + if srv.customRequestHandler != nil { + if err := srv.customRequestHandler(srv, reqFrame, respFrame); err != nil { + srv.errorLocked(err) + return + } + // Dont like this but... + goto finish + } switch head.op { case opStartup: @@ -1412,26 +1448,46 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { respFrame.writeString("not supported") } - respFrame.buf[0] = srv.protocol | 0x80 +finish: + + respFrame.buf[0] |= 0x80 if err := respFrame.finish(); err != nil { srv.errorLocked(err) } - if err := respFrame.writeTo(conn); err != nil { - srv.errorLocked(err) + if *useProtoV5 && *startupCompleted { + segment, err := newUncompressedSegment(respFrame.buf, true) + if err == nil { + _, err = conn.Write(segment) + } + if err != nil { + srv.errorLocked(err) + return + } + } else { + if err := respFrame.writeTo(conn); err != nil { + srv.errorLocked(err) + } + + if reqFrame.header.op == opStartup { + *startupCompleted = true + if head.version == protoVersion5 { + *useProtoV5 = true + } + } } } -func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) { +func (srv *TestServer) readFrame(reader io.Reader) (*framer, error) { buf := make([]byte, srv.headerSize) - head, err := readHeader(conn, buf) + head, err := readHeader(reader, buf) if err != nil { return nil, err } framer := newFramer(nil, srv.protocol, GlobalTypes) - err = framer.readFrame(conn, &head) + err = framer.readFrame(reader, &head) if err != nil { return nil, err } @@ -1439,7 +1495,7 @@ func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) { // should be a request frame if head.version.response() { return nil, fmt.Errorf("expected to read a request frame got version: %v", head.version) - } else if head.version.version() != srv.protocol { + } else if !srv.dontFailOnProtocolMismatch && head.version.version() != srv.protocol { return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, head.version.version()) } diff --git a/control.go b/control.go index e59acb402..c518ba685 100644 --- a/control.go +++ b/control.go @@ -32,7 +32,6 @@ import ( "math/rand" "net" "os" - "regexp" "strconv" "sync" "sync/atomic" @@ -202,56 +201,9 @@ func shuffleHosts(hosts []*HostInfo) []*HostInfo { return shuffled } -// this is going to be version dependant and a nightmare to maintain :( -var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`) -var betaProtocolRe = regexp.MustCompile(`Beta version of the protocol used \(.*\), but USE_BETA flag is unset`) - -func parseProtocolFromError(err error) int { - errStr := err.Error() - - var errProtocol ErrProtocol - if errors.As(err, &errProtocol) { - err = errProtocol.error - } - - // I really wish this had the actual info in the error frame... - matches := betaProtocolRe.FindAllStringSubmatch(errStr, -1) - if len(matches) == 1 { - var protoErr *protocolError - if errors.As(err, &protoErr) { - version := protoErr.frame.Header().version.version() - if version > 0 { - return int(version - 1) - } - } - return 0 - } - - matches = protocolSupportRe.FindAllStringSubmatch(errStr, -1) - if len(matches) != 1 || len(matches[0]) != 2 { - var protoErr *protocolError - if errors.As(err, &protoErr) { - return int(protoErr.frame.Header().version.version()) - } - return 0 - } - - max, err := strconv.Atoi(matches[0][1]) - if err != nil { - return 0 - } - - return max -} - -const highestProtocolVersionSupported = 5 - func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) { hosts = shuffleHosts(hosts) - connCfg := *c.session.connCfg - connCfg.ProtoVersion = highestProtocolVersionSupported - handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) { // we should never get here, but if we do it means we connected to a // host successfully which means our attempted protocol version worked @@ -261,30 +213,56 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) { }) var err error + var proto int for _, host := range hosts { - var conn *Conn - conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler) + proto, err = c.tryProtocolVersionsForHost(host, handler) + if err == nil { + return proto, nil + } + + c.session.logger.Debug("Failed to discover protocol version for host.", + NewLogFieldIP("host_addr", host.ConnectAddress()), + NewLogFieldError("err", err)) + } + + return 0, err +} + +func (c *controlConn) tryProtocolVersionsForHost(host *HostInfo, handler ConnErrorHandler) (int, error) { + connCfg := *c.session.connCfg + + var triedVersions []int + + for proto := highestProtocolVersionSupported; proto >= lowestProtocolVersionSupported; proto-- { + connCfg.ProtoVersion = proto + + conn, err := c.session.dial(c.session.ctx, host, &connCfg, handler) if conn != nil { conn.Close() } if err == nil { - c.session.logger.Debug("Discovered protocol version using host.", - NewLogFieldInt("protocol_version", connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID())) - return connCfg.ProtoVersion, nil + return proto, nil } - if proto := parseProtocolFromError(err); proto > 0 { - c.session.logger.Debug("Discovered protocol version using host after parsing protocol error.", - NewLogFieldInt("protocol_version", proto), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID())) - return proto, nil + var unsupportedErr *unsupportedProtocolVersionError + if errors.As(err, &unsupportedErr) { + // the host does not support this protocol version, try a lower version + c.session.logger.Debug("Failed to connect to host during protocol negotiation.", + NewLogFieldIP("host_addr", host.ConnectAddress()), + NewLogFieldInt("proto_version", proto), + NewLogFieldError("err", err)) + triedVersions = append(triedVersions, connCfg.ProtoVersion) + continue } - c.session.logger.Debug("Failed to discover protocol version using host.", - NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()), NewLogFieldError("err", err)) + c.session.logger.Debug("Error connecting to host during protocol negotiation.", + NewLogFieldIP("host_addr", host.ConnectAddress()), + NewLogFieldError("err", err)) + return 0, err } - return 0, err + return 0, fmt.Errorf("gocql: failed to discover protocol version for host %s, tried versions: %v", host.ConnectAddress(), triedVersions) } func (c *controlConn) connect(hosts []*HostInfo, sessionInit bool) error { diff --git a/control_test.go b/control_test.go index 9f83ec955..7d9311a68 100644 --- a/control_test.go +++ b/control_test.go @@ -57,38 +57,3 @@ func TestHostInfo_Lookup(t *testing.T) { } } } - -func TestParseProtocol(t *testing.T) { - tests := [...]struct { - err error - proto int - }{ - { - err: &protocolError{ - frame: errorFrame{ - code: 0x10, - message: "Invalid or unsupported protocol version (5); the lowest supported version is 3 and the greatest is 4", - }, - }, - proto: 4, - }, - { - err: &protocolError{ - frame: errorFrame{ - frameHeader: frameHeader{ - version: 0x83, - }, - code: 0x10, - message: "Invalid or unsupported protocol version: 5", - }, - }, - proto: 3, - }, - } - - for i, test := range tests { - if proto := parseProtocolFromError(test.err); proto != test.proto { - t.Errorf("%d: exepcted proto %d got %d", i, test.proto, proto) - } - } -} diff --git a/errors.go b/errors.go index 2d1c2205d..4305f78fd 100644 --- a/errors.go +++ b/errors.go @@ -244,3 +244,13 @@ type RequestErrCASWriteUnknown struct { Received int BlockFor int } + +type unsupportedProtocolVersionError struct { + hostInfo *HostInfo + version protoVersion + err error +} + +func (e unsupportedProtocolVersionError) Error() string { + return fmt.Sprintf("unsupported protocol version %d for host %s", e.version, e.hostInfo.ConnectAddress()) +} diff --git a/frame.go b/frame.go index e86c538c8..0316032e5 100644 --- a/frame.go +++ b/frame.go @@ -65,12 +65,13 @@ func NamedValue(name string, value interface{}) interface{} { const ( protoDirectionMask = 0x80 protoVersionMask = 0x7F - protoVersion1 = 0x01 - protoVersion2 = 0x02 protoVersion3 = 0x03 protoVersion4 = 0x04 protoVersion5 = 0x05 + lowestProtocolVersionSupported = protoVersion3 + highestProtocolVersionSupported = protoVersion5 + maxFrameSize = 256 * 1024 * 1024 maxSegmentPayloadSize = 0x1FFFF @@ -422,7 +423,7 @@ func readHeader(r io.Reader, p []byte) (head frameHeader, err error) { version := p[0] & protoVersionMask - if version < protoVersion3 || version > protoVersion5 { + if version < lowestProtocolVersionSupported || version > highestProtocolVersionSupported { return frameHeader{}, fmt.Errorf("gocql: unsupported protocol response version: %d", version) } @@ -2370,6 +2371,14 @@ func (f *framer) writeStringMap(m map[string]string) { } } +func (f *framer) writeStringMultiMap(m map[string][]string) { + f.writeShort(uint16(len(m))) + for k, v := range m { + f.writeString(k) + f.writeStringList(v) + } +} + func (f *framer) writeBytesMap(m map[string][]byte) { f.writeShort(uint16(len(m))) for k, v := range m { diff --git a/protocol_negotiation_test.go b/protocol_negotiation_test.go new file mode 100644 index 000000000..567c74e36 --- /dev/null +++ b/protocol_negotiation_test.go @@ -0,0 +1,267 @@ +//go:build all || unit +// +build all unit + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gocql + +import ( + "context" + "encoding/binary" + "fmt" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type requestHandlerForProtocolNegotiationTest struct { + supportedProtocolVersions []protoVersion + supportedBetaProtocols []protoVersion + + // forces stream id to 0 + forceZeroStreamID bool + + forceCloseConnection bool +} + +func (r *requestHandlerForProtocolNegotiationTest) supportsBetaProtocol(version protoVersion) bool { + return slices.Contains(r.supportedBetaProtocols, version) +} + +func (r *requestHandlerForProtocolNegotiationTest) supportsProtocol(version protoVersion) bool { + return slices.Contains(r.supportedProtocolVersions, version) +} + +func (r *requestHandlerForProtocolNegotiationTest) hasBetaFlag(header *frameHeader) bool { + return header.flags&flagBetaProtocol == flagBetaProtocol +} + +func (r *requestHandlerForProtocolNegotiationTest) createBetaFlagUnsetProtocolErrorMessage(version protoVersion) string { + return fmt.Sprintf("Beta version of the protocol used (%d/v%d-beta), but USE_BETA flag is unset", version, version) +} + +func (r *requestHandlerForProtocolNegotiationTest) handle(_ *TestServer, reqFrame, respFrame *framer) error { + if r.forceCloseConnection { + return fmt.Errorf("NEGOTIATION TEST: forcing close connection") + } + + stream := reqFrame.header.stream + + // If a client uses beta protocol, but the USE_BETA flag is not set, we respond with an error + if r.supportsBetaProtocol(reqFrame.header.version) && !r.hasBetaFlag(reqFrame.header) { + if r.forceZeroStreamID { + stream = 0 + } + respFrame.writeHeader(0, opError, stream) + respFrame.writeInt(ErrCodeProtocol) + respFrame.writeString(r.createBetaFlagUnsetProtocolErrorMessage(reqFrame.header.version)) + return nil + } + + // if a client uses an unsupported protocol version, we respond with an error + if !r.supportsProtocol(reqFrame.header.version) { + if r.forceZeroStreamID { + stream = 0 + } + respFrame.writeHeader(0, opError, stream) + respFrame.writeInt(ErrCodeProtocol) + respFrame.writeString(fmt.Sprintf("NEGOTIATION TEST: Unsupported protocol version %d", reqFrame.header.version)) + return nil + } + + switch reqFrame.header.op { + case opStartup, opRegister: + respFrame.writeHeader(0, opReady, stream) + case opOptions: + // Emulating C* behavior. + // If a client uses an unsupported protocol version, C* responds with supported versions to 0 stream id. + // If a client uses a beta protocol version, but the USE_BETA flag is not set, C* responds with supported versions to 0 stream id. + if r.forceZeroStreamID && !(r.supportsProtocol(reqFrame.header.version) || r.supportsBetaProtocol(reqFrame.header.version) && !r.hasBetaFlag(reqFrame.header)) { + stream = 0 + } + respFrame.writeHeader(0, opSupported, stream) + var supportedVersionsWithDesc []string + for _, supportedVersion := range r.supportedProtocolVersions { + supportedVersionsWithDesc = append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d", supportedVersion, supportedVersion)) + } + for _, betaProtocol := range r.supportedBetaProtocols { + supportedVersionsWithDesc = append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d-beta", betaProtocol, betaProtocol)) + } + supported := map[string][]string{ + "PROTOCOL_VERSIONS": supportedVersionsWithDesc, + } + respFrame.writeStringMultiMap(supported) + case opQuery: + respFrame.writeHeader(0, opResult, stream) + respFrame.writeInt(resultKindRows) + respFrame.writeInt(int32(flagGlobalTableSpec)) + respFrame.writeInt(1) + respFrame.writeString("system") + respFrame.writeString("local") + respFrame.writeString("rack") + respFrame.writeShort(uint16(TypeVarchar)) + respFrame.writeInt(1) + respFrame.writeInt(int32(len("rack-1"))) + respFrame.writeString("rack-1") + case opPrepare: + // This doesn't really make any sense, but it's enough to test the protocol negotiation + respFrame.writeHeader(0, opResult, stream) + respFrame.writeInt(resultKindPrepared) + // + respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 111)) + if respFrame.proto >= protoVersion5 { + respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 222)) + } + // + respFrame.writeInt(0) // + respFrame.writeInt(0) // + if reqFrame.header.version >= protoVersion4 { + respFrame.writeInt(0) // + } + // + respFrame.writeInt(int32(flagGlobalTableSpec)) // + respFrame.writeInt(1) // + // + respFrame.writeString("system") + respFrame.writeString("keyspaces") + // + respFrame.writeString("col0") // + respFrame.writeShort(uint16(TypeBoolean)) // + case opExecute: + // This doesn't really make any sense, but it's enough to test the protocol negotiation + respFrame.writeHeader(0, opResult, stream) + respFrame.writeInt(resultKindRows) + // + respFrame.writeInt(0) // + respFrame.writeInt(0) // + // + respFrame.writeInt(0) + } + + return nil +} + +func mockedErrorCodeHandler(errorCode int) func(*TestServer, *framer, *framer) error { + return func(_ *TestServer, reqFrame *framer, respFrame *framer) error { + reqFrame.writeHeader(0, opError, reqFrame.header.stream) + reqFrame.writeInt(int32(errorCode)) + reqFrame.writeString(fmt.Sprintf("NEGOTIATION TEST: Error code %d", errorCode)) + return nil + } +} + +func TestProtocolNegotiation(t *testing.T) { + testCases := []struct { + name string + supportedVersions []protoVersion + supportedBetaVersions []protoVersion + expectedVersion protoVersion + expectedErrorMsg string + + forceZeroStreamID bool + overrideHost string + + requestHandler func(*TestServer, *framer, *framer) error + }{ + { + name: "all supported versions", + supportedVersions: []protoVersion{protoVersion3, protoVersion4, protoVersion5}, + expectedVersion: protoVersion5, + }, + { + name: "v5-beta is supported", + supportedVersions: []protoVersion{protoVersion3, protoVersion4}, + supportedBetaVersions: []protoVersion{protoVersion5}, + expectedVersion: protoVersion4, + }, + { + name: "v5 is unsupported", + supportedVersions: []protoVersion{protoVersion3, protoVersion4}, + expectedVersion: protoVersion4, + }, + { + name: "all supported versions / 0 stream id", + supportedVersions: []protoVersion{protoVersion3, protoVersion4, protoVersion5}, + expectedVersion: protoVersion5, + forceZeroStreamID: true, + }, + { + name: "v5-beta is supported / 0 stream id", + supportedVersions: []protoVersion{protoVersion3, protoVersion4}, + supportedBetaVersions: []protoVersion{protoVersion5}, + expectedVersion: protoVersion4, + forceZeroStreamID: true, + }, + { + name: "v5 is unsupported / 0 stream id", + supportedVersions: []protoVersion{protoVersion3, protoVersion4}, + expectedVersion: protoVersion4, + forceZeroStreamID: true, + }, + { + name: "wrong host addr", + expectedErrorMsg: "unable to discover protocol version", + overrideHost: "1.2.3.4", // totally wrong addr to get network related error + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + handler := &requestHandlerForProtocolNegotiationTest{ + supportedProtocolVersions: tc.supportedVersions, + supportedBetaProtocols: tc.supportedBetaVersions, + forceZeroStreamID: tc.forceZeroStreamID, + } + + srv := newTestServerOpts{ + addr: "127.0.0.1:0", + protocol: 5, + customRequestHandler: handler.handle, + dontFailOnProtocolMismatch: true, + }.newServer(t, context.Background()) + + go srv.serve() + defer srv.Stop() + + cluster := NewCluster(srv.Address) + if tc.overrideHost != "" { + cluster.Hosts = []string{tc.overrideHost} + } + + cluster.Compressor = nil + cluster.ProtoVersion = 0 + cluster.Logger = NewLogger(LogLevelDebug) + cluster.ConnectTimeout = time.Second * 2 + cluster.Timeout = time.Second * 2 + cluster.DisableInitialHostLookup = true + + s, err := cluster.CreateSession() + switch { + case tc.expectedErrorMsg != "": + require.Error(t, err) + require.ErrorContains(t, err, tc.expectedErrorMsg) + default: + require.NoError(t, err) + require.Equal(t, tc.expectedVersion, protoVersion(s.cfg.ProtoVersion)) + } + }) + } +}