diff --git a/go/vt/grpcclient/client_auth_static_test.go b/go/vt/grpcclient/client_auth_static_test.go index e14ace527d1..325a3f6042c 100644 --- a/go/vt/grpcclient/client_auth_static_test.go +++ b/go/vt/grpcclient/client_auth_static_test.go @@ -17,7 +17,6 @@ limitations under the License. package grpcclient import ( - "errors" "fmt" "os" "reflect" @@ -26,39 +25,80 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" ) +func init() { + clientCredsSigChan = make(chan os.Signal, 1) +} + func TestAppendStaticAuth(t *testing.T) { - { - clientCreds = nil - clientCredsErr = nil - opts, err := AppendStaticAuth([]grpc.DialOption{}) - assert.Nil(t, err) - assert.Len(t, opts, 0) + oldCredsFile := credsFile + opts := []grpc.DialOption{ + grpc.EmptyDialOption{}, } - { - clientCreds = nil - clientCredsErr = errors.New("test err") - opts, err := AppendStaticAuth([]grpc.DialOption{}) - assert.NotNil(t, err) - assert.Len(t, opts, 0) + + tests := []struct { + name string + cFile string + expectedLen int + expectedErr string + }{ + { + name: "creds file not set", + expectedLen: 1, + }, + { + name: "non-existent creds file", + cFile: "./testdata/unknown.json", + expectedErr: "open ./testdata/unknown.json: no such file or directory", + }, + { + name: "valid creds file", + cFile: "./testdata/credsFile.json", + expectedLen: 2, + }, + { + name: "invalid creds file", + cFile: "./testdata/invalid.json", + expectedErr: "unexpected end of JSON input", + }, } - { - clientCreds = &StaticAuthClientCreds{Username: "test", Password: "123456"} - clientCredsErr = nil - opts, err := AppendStaticAuth([]grpc.DialOption{}) - assert.Nil(t, err) - assert.Len(t, opts, 1) + + for _, tt := range tests { + t.Run(tt.cFile, func(t *testing.T) { + defer func() { + credsFile = oldCredsFile + }() + + if tt.cFile != "" { + credsFile = tt.cFile + } + dialOpts, err := AppendStaticAuth(opts) + if tt.expectedErr == "" { + require.NoError(t, err) + require.Equal(t, tt.expectedLen, len(dialOpts)) + } else { + require.ErrorContains(t, err, tt.expectedErr) + } + ResetStaticAuth() + require.Nil(t, clientCredsCancel) + }) } } func TestGetStaticAuthCreds(t *testing.T) { + oldCredsFile := credsFile + defer func() { + ResetStaticAuth() + credsFile = oldCredsFile + }() tmp, err := os.CreateTemp("", t.Name()) assert.Nil(t, err) defer os.Remove(tmp.Name()) credsFile = tmp.Name() - clientCredsSigChan = make(chan os.Signal, 1) + ResetStaticAuth() // load old creds fmt.Fprint(tmp, `{"Username": "old", "Password": "123456"}`) diff --git a/go/vt/grpcclient/client_test.go b/go/vt/grpcclient/client_test.go index edc6d9be98c..40b03bef2f6 100644 --- a/go/vt/grpcclient/client_test.go +++ b/go/vt/grpcclient/client_test.go @@ -18,10 +18,13 @@ package grpcclient import ( "context" + "os" "strings" "testing" "time" + "github.com/spf13/pflag" + "github.com/stretchr/testify/require" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc" @@ -68,3 +71,33 @@ func TestDialErrors(t *testing.T) { } } } + +func TestRegisterGRPCClientFlags(t *testing.T) { + oldArgs := os.Args + defer func() { + os.Args = oldArgs + }() + + fs := pflag.NewFlagSet("test", pflag.ContinueOnError) + RegisterFlags(fs) + + // Test current values + require.Equal(t, 10*time.Second, keepaliveTime) + require.Equal(t, 10*time.Second, keepaliveTimeout) + require.Equal(t, 0, initialWindowSize) + require.Equal(t, 0, initialConnWindowSize) + require.Equal(t, "", compression) + require.Equal(t, "", credsFile) + + // Test setting flags from command-line arguments + os.Args = []string{"test", "--grpc_keepalive_time=5s", "--grpc_keepalive_timeout=5s", "--grpc_initial_conn_window_size=10", "--grpc_initial_window_size=10", "--grpc_compression=not-snappy", "--grpc_auth_static_client_creds=tempfile"} + err := fs.Parse(os.Args[1:]) + require.NoError(t, err) + + require.Equal(t, 5*time.Second, keepaliveTime) + require.Equal(t, 5*time.Second, keepaliveTimeout) + require.Equal(t, 10, initialWindowSize) + require.Equal(t, 10, initialConnWindowSize) + require.Equal(t, "not-snappy", compression) + require.Equal(t, "tempfile", credsFile) +} diff --git a/go/vt/grpcclient/glogger_test.go b/go/vt/grpcclient/glogger_test.go new file mode 100644 index 00000000000..6b394ff7ef9 --- /dev/null +++ b/go/vt/grpcclient/glogger_test.go @@ -0,0 +1,87 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package grpcclient + +import ( + "io" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func captureOutput(t *testing.T, f func()) string { + oldVal := os.Stderr + t.Cleanup(func() { + // Ensure reset even if deferred function panics + os.Stderr = oldVal + }) + + r, w, err := os.Pipe() + require.NoError(t, err) + + os.Stderr = w + + f() + + err = w.Close() + require.NoError(t, err) + + got, err := io.ReadAll(r) + require.NoError(t, err) + + return string(got) +} + +func TestGlogger(t *testing.T) { + gl := glogger{} + + output := captureOutput(t, func() { + gl.Warning("warning") + }) + require.Contains(t, output, "warning") + + output = captureOutput(t, func() { + gl.Warningln("warningln") + }) + require.Contains(t, output, "warningln\n") + + output = captureOutput(t, func() { + gl.Warningf("formatted %s", "warning") + }) + require.Contains(t, output, "formatted warning") + +} + +func TestGloggerError(t *testing.T) { + gl := glogger{} + + output := captureOutput(t, func() { + gl.Error("error message") + }) + require.Contains(t, output, "error message") + + output = captureOutput(t, func() { + gl.Errorln("error message line") + }) + require.Contains(t, output, "error message line\n") + + output = captureOutput(t, func() { + gl.Errorf("this is a %s error message", "formatted") + }) + require.Contains(t, output, "this is a formatted error message") +} diff --git a/go/vt/grpcclient/snappy_test.go b/go/vt/grpcclient/snappy_test.go new file mode 100644 index 00000000000..41d205bf04d --- /dev/null +++ b/go/vt/grpcclient/snappy_test.go @@ -0,0 +1,62 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package grpcclient + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +func TestCompressDecompress(t *testing.T) { + snappComp := SnappyCompressor{} + writer, err := snappComp.Compress(&bytes.Buffer{}) + require.NoError(t, err) + require.NotEmpty(t, writer) + + reader, err := snappComp.Decompress(&bytes.Buffer{}) + require.NoError(t, err) + require.NotEmpty(t, reader) +} + +func TestAppendCompression(t *testing.T) { + oldCompression := compression + defer func() { + compression = oldCompression + }() + + dialOpts := []grpc.DialOption{} + dialOpts, err := appendCompression(dialOpts) + require.NoError(t, err) + require.Equal(t, 0, len(dialOpts)) + + // Change the compression to snappy + compression = "snappy" + + dialOpts, err = appendCompression(dialOpts) + require.NoError(t, err) + require.Equal(t, 1, len(dialOpts)) + + // Change the compression to some unknown value + compression = "unknown" + + dialOpts, err = appendCompression(dialOpts) + require.NoError(t, err) + require.Equal(t, 1, len(dialOpts)) +} diff --git a/go/vt/grpcclient/testdata/credsFile.json b/go/vt/grpcclient/testdata/credsFile.json new file mode 100644 index 00000000000..e036126f78e --- /dev/null +++ b/go/vt/grpcclient/testdata/credsFile.json @@ -0,0 +1,4 @@ +{ + "Username": "test-user", + "Password": "test-pass" +} \ No newline at end of file diff --git a/go/vt/grpcclient/testdata/invalid.json b/go/vt/grpcclient/testdata/invalid.json new file mode 100644 index 00000000000..81750b96f9d --- /dev/null +++ b/go/vt/grpcclient/testdata/invalid.json @@ -0,0 +1 @@ +{ \ No newline at end of file