@@ -20,24 +20,35 @@ import (
20
20
"context"
21
21
"encoding/json"
22
22
"os"
23
+ "os/signal"
24
+ "sync"
25
+ "syscall"
23
26
24
27
"google.golang.org/grpc"
25
28
"google.golang.org/grpc/credentials"
29
+
30
+ "vitess.io/vitess/go/vt/servenv"
26
31
)
27
32
28
33
var (
29
34
credsFile string // registered as --grpc_auth_static_client_creds in RegisterFlags
30
35
// StaticAuthClientCreds implements client interface to be able to WithPerRPCCredentials
31
36
_ credentials.PerRPCCredentials = (* StaticAuthClientCreds )(nil )
37
+
38
+ clientCreds * StaticAuthClientCreds
39
+ clientCredsCancel context.CancelFunc
40
+ clientCredsErr error
41
+ clientCredsMu sync.Mutex
42
+ clientCredsSigChan chan os.Signal
32
43
)
33
44
34
- // StaticAuthClientCreds holder for client credentials
45
+ // StaticAuthClientCreds holder for client credentials.
35
46
type StaticAuthClientCreds struct {
36
47
Username string
37
48
Password string
38
49
}
39
50
40
- // GetRequestMetadata gets the request metadata as a map from StaticAuthClientCreds
51
+ // GetRequestMetadata gets the request metadata as a map from StaticAuthClientCreds.
41
52
func (c * StaticAuthClientCreds ) GetRequestMetadata (context.Context , ... string ) (map [string ]string , error ) {
42
53
return map [string ]string {
43
54
"username" : c .Username ,
@@ -47,30 +58,82 @@ func (c *StaticAuthClientCreds) GetRequestMetadata(context.Context, ...string) (
47
58
48
59
// RequireTransportSecurity indicates whether the credentials requires transport security.
49
60
// Given that people can use this with or without TLS, at the moment we are not enforcing
50
- // transport security
61
+ // transport security.
51
62
func (c * StaticAuthClientCreds ) RequireTransportSecurity () bool {
52
63
return false
53
64
}
54
65
55
66
// AppendStaticAuth optionally appends static auth credentials if provided.
56
67
func AppendStaticAuth (opts []grpc.DialOption ) ([]grpc.DialOption , error ) {
57
- if credsFile == "" {
58
- return opts , nil
59
- }
60
- data , err := os .ReadFile (credsFile )
68
+ creds , err := getStaticAuthCreds ()
61
69
if err != nil {
62
70
return nil , err
63
71
}
64
- clientCreds := & StaticAuthClientCreds {}
65
- err = json .Unmarshal (data , clientCreds )
72
+ if creds != nil {
73
+ grpcCreds := grpc .WithPerRPCCredentials (creds )
74
+ opts = append (opts , grpcCreds )
75
+ }
76
+ return opts , nil
77
+ }
78
+
79
+ // ResetStaticAuth resets the static auth credentials.
80
+ func ResetStaticAuth () {
81
+ clientCredsMu .Lock ()
82
+ defer clientCredsMu .Unlock ()
83
+ if clientCredsCancel != nil {
84
+ clientCredsCancel ()
85
+ clientCredsCancel = nil
86
+ }
87
+ clientCreds = nil
88
+ clientCredsErr = nil
89
+ }
90
+
91
+ // getStaticAuthCreds returns the static auth creds and error.
92
+ func getStaticAuthCreds () (* StaticAuthClientCreds , error ) {
93
+ clientCredsMu .Lock ()
94
+ defer clientCredsMu .Unlock ()
95
+ if credsFile != "" && clientCreds == nil {
96
+ var ctx context.Context
97
+ ctx , clientCredsCancel = context .WithCancel (context .Background ())
98
+ go handleClientCredsSignals (ctx )
99
+ clientCreds , clientCredsErr = loadStaticAuthCredsFromFile (credsFile )
100
+ }
101
+ return clientCreds , clientCredsErr
102
+ }
103
+
104
+ // handleClientCredsSignals handles signals to reload client creds.
105
+ func handleClientCredsSignals (ctx context.Context ) {
106
+ for {
107
+ select {
108
+ case <- ctx .Done ():
109
+ return
110
+ case <- clientCredsSigChan :
111
+ if newCreds , err := loadStaticAuthCredsFromFile (credsFile ); err == nil {
112
+ clientCredsMu .Lock ()
113
+ clientCreds = newCreds
114
+ clientCredsErr = err
115
+ clientCredsMu .Unlock ()
116
+ }
117
+ }
118
+ }
119
+ }
120
+
121
+ // loadStaticAuthCredsFromFile loads static auth credentials from a file.
122
+ func loadStaticAuthCredsFromFile (path string ) (* StaticAuthClientCreds , error ) {
123
+ data , err := os .ReadFile (path )
66
124
if err != nil {
67
125
return nil , err
68
126
}
69
- creds := grpc . WithPerRPCCredentials ( clientCreds )
70
- opts = append ( opts , creds )
71
- return opts , nil
127
+ creds := & StaticAuthClientCreds {}
128
+ err = json . Unmarshal ( data , creds )
129
+ return creds , err
72
130
}
73
131
74
132
func init () {
133
+ servenv .OnInit (func () {
134
+ clientCredsSigChan = make (chan os.Signal , 1 )
135
+ signal .Notify (clientCredsSigChan , syscall .SIGHUP )
136
+ _ , _ = getStaticAuthCreds () // preload static auth credentials
137
+ })
75
138
RegisterGRPCDialOptions (AppendStaticAuth )
76
139
}
0 commit comments