Skip to content

Commit acae954

Browse files
Load --grpc_auth_static_client_creds file once (#15030)
Signed-off-by: Tim Vaillancourt <tim@timvaillancourt.com>
1 parent 9df1763 commit acae954

File tree

3 files changed

+203
-12
lines changed

3 files changed

+203
-12
lines changed

go/vt/grpcclient/client_auth_static.go

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,35 @@ import (
2020
"context"
2121
"encoding/json"
2222
"os"
23+
"os/signal"
24+
"sync"
25+
"syscall"
2326

2427
"google.golang.org/grpc"
2528
"google.golang.org/grpc/credentials"
29+
30+
"vitess.io/vitess/go/vt/servenv"
2631
)
2732

2833
var (
2934
credsFile string // registered as --grpc_auth_static_client_creds in RegisterFlags
3035
// StaticAuthClientCreds implements client interface to be able to WithPerRPCCredentials
3136
_ credentials.PerRPCCredentials = (*StaticAuthClientCreds)(nil)
37+
38+
clientCreds *StaticAuthClientCreds
39+
clientCredsCancel context.CancelFunc
40+
clientCredsErr error
41+
clientCredsMu sync.Mutex
42+
clientCredsSigChan chan os.Signal
3243
)
3344

34-
// StaticAuthClientCreds holder for client credentials
45+
// StaticAuthClientCreds holder for client credentials.
3546
type StaticAuthClientCreds struct {
3647
Username string
3748
Password string
3849
}
3950

40-
// GetRequestMetadata gets the request metadata as a map from StaticAuthClientCreds
51+
// GetRequestMetadata gets the request metadata as a map from StaticAuthClientCreds.
4152
func (c *StaticAuthClientCreds) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
4253
return map[string]string{
4354
"username": c.Username,
@@ -47,30 +58,82 @@ func (c *StaticAuthClientCreds) GetRequestMetadata(context.Context, ...string) (
4758

4859
// RequireTransportSecurity indicates whether the credentials requires transport security.
4960
// Given that people can use this with or without TLS, at the moment we are not enforcing
50-
// transport security
61+
// transport security.
5162
func (c *StaticAuthClientCreds) RequireTransportSecurity() bool {
5263
return false
5364
}
5465

5566
// AppendStaticAuth optionally appends static auth credentials if provided.
5667
func AppendStaticAuth(opts []grpc.DialOption) ([]grpc.DialOption, error) {
57-
if credsFile == "" {
58-
return opts, nil
59-
}
60-
data, err := os.ReadFile(credsFile)
68+
creds, err := getStaticAuthCreds()
6169
if err != nil {
6270
return nil, err
6371
}
64-
clientCreds := &StaticAuthClientCreds{}
65-
err = json.Unmarshal(data, clientCreds)
72+
if creds != nil {
73+
grpcCreds := grpc.WithPerRPCCredentials(creds)
74+
opts = append(opts, grpcCreds)
75+
}
76+
return opts, nil
77+
}
78+
79+
// ResetStaticAuth resets the static auth credentials.
80+
func ResetStaticAuth() {
81+
clientCredsMu.Lock()
82+
defer clientCredsMu.Unlock()
83+
if clientCredsCancel != nil {
84+
clientCredsCancel()
85+
clientCredsCancel = nil
86+
}
87+
clientCreds = nil
88+
clientCredsErr = nil
89+
}
90+
91+
// getStaticAuthCreds returns the static auth creds and error.
92+
func getStaticAuthCreds() (*StaticAuthClientCreds, error) {
93+
clientCredsMu.Lock()
94+
defer clientCredsMu.Unlock()
95+
if credsFile != "" && clientCreds == nil {
96+
var ctx context.Context
97+
ctx, clientCredsCancel = context.WithCancel(context.Background())
98+
go handleClientCredsSignals(ctx)
99+
clientCreds, clientCredsErr = loadStaticAuthCredsFromFile(credsFile)
100+
}
101+
return clientCreds, clientCredsErr
102+
}
103+
104+
// handleClientCredsSignals handles signals to reload client creds.
105+
func handleClientCredsSignals(ctx context.Context) {
106+
for {
107+
select {
108+
case <-ctx.Done():
109+
return
110+
case <-clientCredsSigChan:
111+
if newCreds, err := loadStaticAuthCredsFromFile(credsFile); err == nil {
112+
clientCredsMu.Lock()
113+
clientCreds = newCreds
114+
clientCredsErr = err
115+
clientCredsMu.Unlock()
116+
}
117+
}
118+
}
119+
}
120+
121+
// loadStaticAuthCredsFromFile loads static auth credentials from a file.
122+
func loadStaticAuthCredsFromFile(path string) (*StaticAuthClientCreds, error) {
123+
data, err := os.ReadFile(path)
66124
if err != nil {
67125
return nil, err
68126
}
69-
creds := grpc.WithPerRPCCredentials(clientCreds)
70-
opts = append(opts, creds)
71-
return opts, nil
127+
creds := &StaticAuthClientCreds{}
128+
err = json.Unmarshal(data, creds)
129+
return creds, err
72130
}
73131

74132
func init() {
133+
servenv.OnInit(func() {
134+
clientCredsSigChan = make(chan os.Signal, 1)
135+
signal.Notify(clientCredsSigChan, syscall.SIGHUP)
136+
_, _ = getStaticAuthCreds() // preload static auth credentials
137+
})
75138
RegisterGRPCDialOptions(AppendStaticAuth)
76139
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
Copyright 2024 The Vitess Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package grpcclient
18+
19+
import (
20+
"errors"
21+
"fmt"
22+
"os"
23+
"reflect"
24+
"syscall"
25+
"testing"
26+
"time"
27+
28+
"github.com/stretchr/testify/assert"
29+
"google.golang.org/grpc"
30+
)
31+
32+
func TestAppendStaticAuth(t *testing.T) {
33+
{
34+
clientCreds = nil
35+
clientCredsErr = nil
36+
opts, err := AppendStaticAuth([]grpc.DialOption{})
37+
assert.Nil(t, err)
38+
assert.Len(t, opts, 0)
39+
}
40+
{
41+
clientCreds = nil
42+
clientCredsErr = errors.New("test err")
43+
opts, err := AppendStaticAuth([]grpc.DialOption{})
44+
assert.NotNil(t, err)
45+
assert.Len(t, opts, 0)
46+
}
47+
{
48+
clientCreds = &StaticAuthClientCreds{Username: "test", Password: "123456"}
49+
clientCredsErr = nil
50+
opts, err := AppendStaticAuth([]grpc.DialOption{})
51+
assert.Nil(t, err)
52+
assert.Len(t, opts, 1)
53+
}
54+
}
55+
56+
func TestGetStaticAuthCreds(t *testing.T) {
57+
tmp, err := os.CreateTemp("", t.Name())
58+
assert.Nil(t, err)
59+
defer os.Remove(tmp.Name())
60+
credsFile = tmp.Name()
61+
clientCredsSigChan = make(chan os.Signal, 1)
62+
63+
// load old creds
64+
fmt.Fprint(tmp, `{"Username": "old", "Password": "123456"}`)
65+
ResetStaticAuth()
66+
creds, err := getStaticAuthCreds()
67+
assert.Nil(t, err)
68+
assert.Equal(t, &StaticAuthClientCreds{Username: "old", Password: "123456"}, creds)
69+
70+
// write new creds to the same file
71+
_ = tmp.Truncate(0)
72+
_, _ = tmp.Seek(0, 0)
73+
fmt.Fprint(tmp, `{"Username": "new", "Password": "123456789"}`)
74+
75+
// test the creds did not change yet
76+
creds, err = getStaticAuthCreds()
77+
assert.Nil(t, err)
78+
assert.Equal(t, &StaticAuthClientCreds{Username: "old", Password: "123456"}, creds)
79+
80+
// test SIGHUP signal triggers reload
81+
credsOld := creds
82+
clientCredsSigChan <- syscall.SIGHUP
83+
timeoutChan := time.After(time.Second * 10)
84+
for {
85+
select {
86+
case <-timeoutChan:
87+
assert.Fail(t, "timed out waiting for SIGHUP reload of static auth creds")
88+
return
89+
default:
90+
// confirm new creds get loaded
91+
creds, err = getStaticAuthCreds()
92+
if reflect.DeepEqual(creds, credsOld) {
93+
continue // not changed yet
94+
}
95+
assert.Nil(t, err)
96+
assert.Equal(t, &StaticAuthClientCreds{Username: "new", Password: "123456789"}, creds)
97+
return
98+
}
99+
}
100+
}
101+
102+
func TestLoadStaticAuthCredsFromFile(t *testing.T) {
103+
{
104+
f, err := os.CreateTemp("", t.Name())
105+
if !assert.Nil(t, err) {
106+
assert.FailNowf(t, "cannot create temp file: %s", err.Error())
107+
}
108+
defer os.Remove(f.Name())
109+
fmt.Fprint(f, `{
110+
"Username": "test",
111+
"Password": "correct horse battery staple"
112+
}`)
113+
if !assert.Nil(t, err) {
114+
assert.FailNowf(t, "cannot read auth file: %s", err.Error())
115+
}
116+
117+
creds, err := loadStaticAuthCredsFromFile(f.Name())
118+
assert.Nil(t, err)
119+
assert.Equal(t, "test", creds.Username)
120+
assert.Equal(t, "correct horse battery staple", creds.Password)
121+
}
122+
{
123+
_, err := loadStaticAuthCredsFromFile(`does-not-exist`)
124+
assert.NotNil(t, err)
125+
}
126+
}

go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ func TestGRPCVTGateConnAuth(t *testing.T) {
108108
fs := pflag.NewFlagSet("", pflag.ContinueOnError)
109109
grpcclient.RegisterFlags(fs)
110110

111+
grpcclient.ResetStaticAuth()
111112
err = fs.Parse([]string{
112113
"--grpc_auth_static_client_creds",
113114
f.Name(),
@@ -148,6 +149,7 @@ func TestGRPCVTGateConnAuth(t *testing.T) {
148149
fs = pflag.NewFlagSet("", pflag.ContinueOnError)
149150
grpcclient.RegisterFlags(fs)
150151

152+
grpcclient.ResetStaticAuth()
151153
err = fs.Parse([]string{
152154
"--grpc_auth_static_client_creds",
153155
f.Name(),

0 commit comments

Comments
 (0)