6
6
"context"
7
7
"crypto/rsa"
8
8
"crypto/sha256"
9
+ "crypto/tls"
9
10
"crypto/x509"
10
11
"encoding/base64"
11
12
"encoding/binary"
@@ -66,6 +67,8 @@ type AuthHandler struct {
66
67
Auths map [string ]map [string ]string
67
68
Authenticators []Authenticator
68
69
ACLs []ACLRecord
70
+ HaveCertAuth bool
71
+ certRoleCAs map [string ][]* x509.Certificate
69
72
}
70
73
71
74
const jwtParams = `(cookie|header|query)=([A-Za-z0-9_-]+)`
@@ -126,8 +129,100 @@ func sshKeysToPEM(in []byte) (out []byte) {
126
129
return
127
130
}
128
131
132
+ func (ah * AuthHandler ) ClientAuthHostsCAs () (hostMap map [string ][]* x509.Certificate ) {
133
+ acls := ah .ACLs
134
+ if acls == nil {
135
+ acls = []ACLRecord {{RolesToCheck : map [string ]interface {}{}}}
136
+ for role := range ah .certRoleCAs {
137
+ acls [0 ].RolesToCheck [role ] = true
138
+ }
139
+ }
140
+ for _ , acl := range acls {
141
+ haveCertRoles := false
142
+ caCerts := []* x509.Certificate {}
143
+ for role := range acl .RolesToCheck {
144
+ if roleCerts , have := ah .certRoleCAs [role ]; have {
145
+ haveCertRoles = true
146
+ if caCerts == nil || roleCerts == nil {
147
+ caCerts = nil
148
+ } else {
149
+ caCerts = append (caCerts , roleCerts ... )
150
+ }
151
+ }
152
+ }
153
+ if ! haveCertRoles {
154
+ continue
155
+ }
156
+ if hostMap == nil {
157
+ hostMap = make (map [string ][]* x509.Certificate )
158
+ }
159
+ hosts := acl .Hosts
160
+ if hosts == nil {
161
+ hosts = map [string ]bool {"*" : true }
162
+ }
163
+ for host := range hosts {
164
+ if caCerts == nil {
165
+ hostMap [host ] = nil
166
+ } else if curList , have := hostMap [host ]; curList != nil || ! have {
167
+ hostMap [host ] = append (curList , caCerts ... )
168
+ }
169
+ }
170
+ }
171
+ return
172
+ }
173
+
174
+ func (ah * AuthHandler ) ConfigureServerTLSConfig (cfg * tls.Config ) (hostToCAsMap map [string ][]* x509.Certificate ) {
175
+ if ! ah .HaveCertAuth {
176
+ return
177
+ }
178
+ hostToCAsMap = ah .ClientAuthHostsCAs ()
179
+ globalCAs , haveGlobal := hostToCAsMap ["*" ]
180
+ if haveGlobal && (len (hostToCAsMap ) == 1 || globalCAs == nil ) {
181
+ cfg .ClientAuth = tls .RequestClientCert
182
+ if globalCAs != nil {
183
+ cfg .ClientCAs = x509 .NewCertPool ()
184
+ for _ , cert := range globalCAs {
185
+ cfg .ClientCAs .AddCert (cert )
186
+ }
187
+ return
188
+ }
189
+ }
190
+
191
+ cfg .GetConfigForClient = func (chi * tls.ClientHelloInfo ) (ret * tls.Config , err error ) {
192
+ hostCAs , haveHost := hostToCAsMap [chi .ServerName ]
193
+ if haveHost || haveGlobal {
194
+ ret = cfg .Clone ()
195
+ ret .ClientAuth = tls .RequestClientCert
196
+ if (haveHost && hostCAs == nil ) || (haveGlobal && globalCAs == nil ) {
197
+ return
198
+ }
199
+ ret .ClientCAs = x509 .NewCertPool ()
200
+ if haveHost {
201
+ for _ , cert := range hostCAs {
202
+ ret .ClientCAs .AddCert (cert )
203
+ }
204
+ }
205
+ if haveGlobal {
206
+ for _ , cert := range globalCAs {
207
+ ret .ClientCAs .AddCert (cert )
208
+ }
209
+ }
210
+ }
211
+ return
212
+ }
213
+ return
214
+ }
215
+
129
216
// AddAuth : add authentication method to identify role(s)
130
217
func (ah * AuthHandler ) AddAuth (method , check , name string ) {
218
+ switch method {
219
+ case "Cert" , "CertBy" , "CertKeyHash" :
220
+ ah .HaveCertAuth = true
221
+ if ah .certRoleCAs == nil {
222
+ ah .certRoleCAs = make (map [string ][]* x509.Certificate )
223
+ }
224
+ }
225
+
131
226
if ah .Auths == nil {
132
227
ah .Auths = make (map [string ]map [string ]string )
133
228
}
@@ -146,6 +241,9 @@ func (ah *AuthHandler) AddAuth(method, check, name string) {
146
241
147
242
switch method {
148
243
case "CertKeyHash" :
244
+ for _ , role := range strings .Split (name , "+" ) {
245
+ ah .certRoleCAs [role ] = nil
246
+ }
149
247
if strings .HasPrefix (check , "file:" ) {
150
248
fileName := check [5 :]
151
249
data , err := os .ReadFile (fileName )
@@ -224,6 +322,24 @@ func (ah *AuthHandler) AddAuth(method, check, name string) {
224
322
}
225
323
logf (nil , logLevelInfo , "Read %d certificates from %#v for role %#v" , nrDone , fileName , name )
226
324
return
325
+ } else if method == "CertBy" {
326
+ certBytes , err := hex .DecodeString (check )
327
+ if err != nil {
328
+ logf (nil , logLevelFatal , "decoding hex: %s: %#v" , check )
329
+ }
330
+ cert , err := x509 .ParseCertificate (certBytes )
331
+ if err != nil {
332
+ logf (nil , logLevelFatal , "parsing cert: %s: %#v" , check )
333
+ }
334
+ for _ , role := range strings .Split (name , "+" ) {
335
+ if pool , have := ah .certRoleCAs [role ]; ! have || pool != nil {
336
+ ah .certRoleCAs [role ] = append (pool , cert )
337
+ }
338
+ }
339
+ } else {
340
+ for _ , role := range strings .Split (name , "+" ) {
341
+ ah .certRoleCAs [role ] = nil
342
+ }
227
343
}
228
344
case "JWTSecret" , "JWTFilePat" :
229
345
logf (nil , logLevelWarning , "DEPRECATED: please use JWT auth method instead of %#v" , method )
0 commit comments