diff --git a/clickhouse_options.go b/clickhouse_options.go index b457364e6e..58079e8d28 100644 --- a/clickhouse_options.go +++ b/clickhouse_options.go @@ -164,6 +164,10 @@ type Options struct { scheme string ReadTimeout time.Duration + + // ClientTCPProtocolVersion specifies the custom protocol revision, as defined in lib/proto/const.go + // if not specified, the latest supported protocol revision, proto.DBMS_TCP_PROTOCOL_VERSION , is used. + ClientTCPProtocolVersion uint64 } func (o *Options) fromDSN(in string) error { @@ -399,5 +403,9 @@ func (o Options) setDefaults() *Options { o.Addr = []string{"localhost:8123"} } } + if o.ClientTCPProtocolVersion == 0 { + o.ClientTCPProtocolVersion = ClientTCPProtocolVersion + } + return &o } diff --git a/conn.go b/conn.go index 64e14db20a..a13c668df4 100644 --- a/conn.go +++ b/conn.go @@ -91,6 +91,10 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er compressor = compress.NewWriter(compress.LevelZero, compress.None) } + if opt.ClientTCPProtocolVersion < proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO || opt.ClientTCPProtocolVersion > proto.DBMS_TCP_PROTOCOL_VERSION { + return nil, fmt.Errorf("unsupported protocol revision") + } + var ( connect = &connect{ id: num, @@ -99,7 +103,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er debugf: debugf, buffer: new(chproto.Buffer), reader: chproto.NewReader(conn), - revision: ClientTCPProtocolVersion, + revision: opt.ClientTCPProtocolVersion, structMap: &structMap{}, compression: compression, connectedAt: time.Now(), diff --git a/conn_handshake.go b/conn_handshake.go index 34e65df8bc..271c998992 100644 --- a/conn_handshake.go +++ b/conn_handshake.go @@ -37,7 +37,7 @@ func (c *connect) handshake(auth Auth) error { { c.buffer.PutByte(proto.ClientHello) handshake := &proto.ClientHandshake{ - ProtocolVersion: ClientTCPProtocolVersion, + ProtocolVersion: c.revision, ClientName: c.opt.ClientInfo.String(), ClientVersion: proto.Version{ClientVersionMajor, ClientVersionMinor, ClientVersionPatch}, //nolint:govet } @@ -60,7 +60,7 @@ func (c *connect) handshake(auth Auth) error { case proto.ServerException: return c.exception() case proto.ServerHello: - if err := c.server.Decode(c.reader); err != nil { + if err := c.server.Decode(c.reader, c.revision); err != nil { return err } case proto.ServerEndOfStream: diff --git a/conn_send_query.go b/conn_send_query.go index 8897a8c768..dc05305781 100644 --- a/conn_send_query.go +++ b/conn_send_query.go @@ -27,7 +27,7 @@ func (c *connect) sendQuery(body string, o *QueryOptions) error { c.debugf("[send query] compression=%q %s", c.compression, body) c.buffer.PutByte(proto.ClientQuery) q := proto.Query{ - ClientTCPProtocolVersion: ClientTCPProtocolVersion, + ClientTCPProtocolVersion: c.revision, ClientName: c.opt.ClientInfo.String(), ClientVersion: proto.Version{ClientVersionMajor, ClientVersionMinor, ClientVersionPatch}, //nolint:govet ID: o.queryID, diff --git a/lib/proto/handshake.go b/lib/proto/handshake.go index 6ee620905c..c880d565a6 100644 --- a/lib/proto/handshake.go +++ b/lib/proto/handshake.go @@ -85,7 +85,7 @@ func CheckMinVersion(constraint Version, version Version) bool { return true } -func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) { +func (srv *ServerHandshake) Decode(reader *chproto.Reader, clientRevision uint64) (err error) { if srv.Name, err = reader.Str(); err != nil { return fmt.Errorf("could not read server name: %v", err) } @@ -98,7 +98,8 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) { if srv.Revision, err = reader.UVarInt(); err != nil { return fmt.Errorf("could not read server revision: %v", err) } - if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE { + rev := min(clientRevision, srv.Revision) + if rev >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE { name, err := reader.Str() if err != nil { return fmt.Errorf("could not read server timezone: %v", err) @@ -107,12 +108,12 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) { return fmt.Errorf("could not load time location: %v", err) } } - if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME { + if rev >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME { if srv.DisplayName, err = reader.Str(); err != nil { return fmt.Errorf("could not read server display name: %v", err) } } - if srv.Revision >= DBMS_MIN_REVISION_WITH_VERSION_PATCH { + if rev >= DBMS_MIN_REVISION_WITH_VERSION_PATCH { if srv.Version.Patch, err = reader.UVarInt(); err != nil { return fmt.Errorf("could not read server patch: %v", err) } diff --git a/tests/std/conn_test.go b/tests/std/conn_test.go index a6d2cd47e2..73f93ffb2e 100644 --- a/tests/std/conn_test.go +++ b/tests/std/conn_test.go @@ -31,6 +31,8 @@ import ( "time" "github.com/ClickHouse/clickhouse-go/v2" + "github.com/ClickHouse/clickhouse-go/v2/lib/driver" + "github.com/ClickHouse/clickhouse-go/v2/lib/proto" clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -231,6 +233,80 @@ func TestStdConnector(t *testing.T) { require.NoError(t, err) } +func TestCustomProtocolRevision(t *testing.T) { + env, err := GetStdTestEnvironment() + require.NoError(t, err) + useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false")) + require.NoError(t, err) + port := env.Port + var tlsConfig *tls.Config + if useSSL { + port = env.SslPort + tlsConfig = &tls.Config{} + } + baseOpts := clickhouse.Options{ + Addr: []string{fmt.Sprintf("%s:%d", env.Host, port)}, + Auth: clickhouse.Auth{ + Database: "default", + Username: env.Username, + Password: env.Password, + }, + Compression: &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }, + TLS: tlsConfig, + } + t.Run("unsupported proto versions", func(t *testing.T) { + badOpts := baseOpts + badOpts.ClientTCPProtocolVersion = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO - 1 + conn, _ := clickhouse.Open(&badOpts) + require.NotNil(t, conn) + err = conn.Ping(t.Context()) + require.Error(t, err) + badOpts.ClientTCPProtocolVersion = proto.DBMS_TCP_PROTOCOL_VERSION + 1 + conn, _ = clickhouse.Open(&badOpts) + require.NotNil(t, conn) + err = conn.Ping(t.Context()) + require.Error(t, err) + }) + + t.Run("minimal proto version", func(t *testing.T) { + opts := baseOpts + opts.ClientTCPProtocolVersion = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO + conn, err := clickhouse.Open(&opts) + require.NoError(t, err) + require.NotNil(t, conn) + err = conn.Ping(t.Context()) + require.NoError(t, err) + + defer func() { + _ = conn.Exec(t.Context(), "DROP TABLE insert_example") + }() + err = conn.Exec(t.Context(), "DROP TABLE IF EXISTS insert_example") + + err = conn.Exec(t.Context(), ` + CREATE TABLE insert_example ( + Col1 UInt64 + ) Engine = MergeTree() ORDER BY tuple() + `) + require.NoError(t, err) + var batch driver.Batch + batch, err = conn.PrepareBatch(t.Context(), "INSERT INTO insert_example (Col1)") + require.NoError(t, err) + require.NoError(t, batch.Append(10)) + require.NoError(t, batch.Send()) + + rows, err := conn.Query(t.Context(), "SELECT Col1 FROM insert_example") + require.NoError(t, err) + count := 0 + for rows.Next() { + count++ + } + assert.Equal(t, 1, count) + }) + +} + func TestBlockBufferSize(t *testing.T) { env, err := GetStdTestEnvironment() require.NoError(t, err)