diff --git a/backend.go b/backend.go index 775cba57..2d12c931 100644 --- a/backend.go +++ b/backend.go @@ -7,6 +7,7 @@ import ( "fmt" "strings" "sync" + "time" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" @@ -27,6 +28,18 @@ var ( // when adding new alias name sources make sure to update the corresponding FieldSchema description in path_role.go aliasNameSources = []string{aliasNameSourceSAUid, aliasNameSourceSAName} errInvalidAliasNameSource = fmt.Errorf(`invalid alias_name_source, must be one of: %s`, strings.Join(aliasNameSources, ", ")) + + // jwtReloadPeriod is the time period how often the in-memory copy of local + // service account token can be used, before reading it again from disk. + // + // The value is selected according to recommendation in Kubernetes 1.21 changelog: + // "Clients should reload the token from disk periodically (once per minute + // is recommended) to ensure they continue to use a valid token." + jwtReloadPeriod = 1 * time.Minute + + // caReloadPeriod is the time period how often the in-memory copy of local + // CA cert can be used, before reading it again from disk. + caReloadPeriod = 1 * time.Hour ) // kubeAuthBackend implements logical.Backend @@ -38,6 +51,19 @@ type kubeAuthBackend struct { // review. Mocks should only be used in tests. reviewFactory tokenReviewFactory + // localSATokenReader caches the service account token in memory. + // It periodically reloads the token to support token rotation/renewal. + // Local token is used when running in a pod with following configuration + // - token_reviewer_jwt is not set + // - disable_local_ca_jwt is false + localSATokenReader *cachingFileReader + + // localCACertReader contains the local CA certificate. Local CA certificate is + // used when running in a pod with following configuration + // - kubernetes_ca_cert is not set + // - disable_local_ca_jwt is false + localCACertReader *cachingFileReader + l sync.RWMutex } @@ -51,7 +77,10 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, } func Backend() *kubeAuthBackend { - b := &kubeAuthBackend{} + b := &kubeAuthBackend{ + localSATokenReader: newCachingFileReader(localJWTPath, jwtReloadPeriod, time.Now), + localCACertReader: newCachingFileReader(localCACertPath, caReloadPeriod, time.Now), + } b.Backend = &framework.Backend{ AuthRenew: b.pathLoginRenew(), @@ -80,7 +109,8 @@ func Backend() *kubeAuthBackend { return b } -// config takes a storage object and returns a kubeConfig object +// config takes a storage object and returns a kubeConfig object. +// It does not return local token and CA file which are specific to the pod we run in. func (b *kubeAuthBackend) config(ctx context.Context, s logical.Storage) (*kubeConfig, error) { raw, err := s.Get(ctx, configPath) if err != nil { @@ -107,6 +137,8 @@ func (b *kubeAuthBackend) config(ctx context.Context, s logical.Storage) (*kubeC return conf, nil } +// loadConfig fetches the kubeConfig from storage and optionally decorates it with +// local token and CA certificate. func (b *kubeAuthBackend) loadConfig(ctx context.Context, s logical.Storage) (*kubeConfig, error) { config, err := b.config(ctx, s) if err != nil { @@ -115,6 +147,30 @@ func (b *kubeAuthBackend) loadConfig(ctx context.Context, s logical.Storage) (*k if config == nil { return nil, errors.New("could not load backend configuration") } + + // Nothing more to do if loading local CA cert and JWT token is disabled. + if config.DisableLocalCAJwt { + return config, nil + } + + // Read local JWT token unless it was not stored in config. + if config.TokenReviewerJWT == "" { + config.TokenReviewerJWT, err = b.localSATokenReader.ReadFile() + if err != nil { + // Ignore error: make best effort trying to load local JWT, + // otherwise the JWT submitted in login payload will be used. + b.Logger().Debug("failed to read local service account token, will use client token", "error", err) + } + } + + // Read local CA cert unless it was stored in config. + if config.CACert == "" { + config.CACert, err = b.localCACertReader.ReadFile() + if err != nil { + return nil, err + } + } + return config, nil } diff --git a/caching_file_reader.go b/caching_file_reader.go new file mode 100644 index 00000000..d18445fc --- /dev/null +++ b/caching_file_reader.go @@ -0,0 +1,68 @@ +package kubeauth + +import ( + "io/ioutil" + "sync" + "time" +) + +// cachingFileReader reads a file and keeps an in-memory copy of it, until the +// copy is considered stale. Next ReadFile() after expiry will re-read the file from disk. +type cachingFileReader struct { + // path is the file path to the cached file. + path string + + // ttl is the time-to-live duration when cached file is considered stale + ttl time.Duration + + // cache is the buffer holding the in-memory copy of the file. + cache cachedFile + + l sync.RWMutex + + // currentTime is a function that returns the current local time. + // Normally set to time.Now but it can be overwritten by test cases to manipulate time. + currentTime func() time.Time +} + +type cachedFile struct { + // buf is the buffer holding the in-memory copy of the file. + buf string + + // expiry is the time when the cached copy is considered stale and must be re-read. + expiry time.Time +} + +func newCachingFileReader(path string, ttl time.Duration, currentTime func() time.Time) *cachingFileReader { + return &cachingFileReader{ + path: path, + ttl: ttl, + currentTime: currentTime, + } +} + +func (r *cachingFileReader) ReadFile() (string, error) { + // Fast path requiring read lock only: file is already in memory and not stale. + r.l.RLock() + now := r.currentTime() + cache := r.cache + r.l.RUnlock() + if now.Before(cache.expiry) { + return cache.buf, nil + } + + // Slow path: read the file from disk. + r.l.Lock() + defer r.l.Unlock() + + buf, err := ioutil.ReadFile(r.path) + if err != nil { + return "", err + } + r.cache = cachedFile{ + buf: string(buf), + expiry: now.Add(r.ttl), + } + + return r.cache.buf, nil +} diff --git a/caching_file_reader_test.go b/caching_file_reader_test.go new file mode 100644 index 00000000..ba282510 --- /dev/null +++ b/caching_file_reader_test.go @@ -0,0 +1,65 @@ +package kubeauth + +import ( + "io/ioutil" + "os" + "testing" + "time" +) + +func TestCachingFileReader(t *testing.T) { + content1 := "before" + content2 := "after" + + // Create temporary file. + f, err := ioutil.TempFile("", "testfile") + if err != nil { + t.Error(err) + } + f.Close() + defer os.Remove(f.Name()) + + currentTime := time.Now() + + r := newCachingFileReader(f.Name(), 1*time.Minute, + func() time.Time { + return currentTime + }) + + // Write initial content to file and check that we can read it. + ioutil.WriteFile(f.Name(), []byte(content1), 0644) + got, err := r.ReadFile() + if err != nil { + t.Error(err) + } + if got != content1 { + t.Errorf("got '%s', expected '%s'", got, content1) + } + + // Write new content to the file. + ioutil.WriteFile(f.Name(), []byte(content2), 0644) + + // Advance simulated time, but not enough for cache to expire. + currentTime = currentTime.Add(30 * time.Second) + + // Read again and check we still got the old cached content. + got, err = r.ReadFile() + if err != nil { + t.Error(err) + } + if got != content1 { + t.Errorf("got '%s', expected '%s'", got, content1) + } + + // Advance simulated time for cache to expire. + currentTime = currentTime.Add(30 * time.Second) + + // Read again and check that we got the new content. + got, err = r.ReadFile() + if err != nil { + t.Error(err) + } + if got != content2 { + t.Errorf("got '%s', expected '%s'", got, content2) + } +} diff --git a/path_config.go b/path_config.go index b593dede..d61b5d93 100644 --- a/path_config.go +++ b/path_config.go @@ -7,14 +7,13 @@ import ( "crypto/x509" "encoding/pem" "errors" - "io/ioutil" "github.com/briankassouf/jose/jws" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" ) -var ( +const ( localCACertPath = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" localJWTPath = "/var/run/secrets/kubernetes.io/serviceaccount/token" ) @@ -126,30 +125,13 @@ func (b *kubeAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Requ } disableLocalJWT := data.Get("disable_local_ca_jwt").(bool) - localCACert := []byte{} - localTokenReviewer := []byte{} - if !disableLocalJWT { - localCACert, _ = ioutil.ReadFile(localCACertPath) - localTokenReviewer, _ = ioutil.ReadFile(localJWTPath) - } pemList := data.Get("pem_keys").([]string) caCert := data.Get("kubernetes_ca_cert").(string) issuer := data.Get("issuer").(string) disableIssValidation := data.Get("disable_iss_validation").(bool) - if len(pemList) == 0 && len(caCert) == 0 { - if len(localCACert) > 0 { - caCert = string(localCACert) - } else { - return logical.ErrorResponse("one of pem_keys or kubernetes_ca_cert must be set"), nil - } - } - tokenReviewer := data.Get("token_reviewer_jwt").(string) - if !disableLocalJWT && len(tokenReviewer) == 0 && len(localTokenReviewer) > 0 { - tokenReviewer = string(localTokenReviewer) - } - if len(tokenReviewer) > 0 { + if tokenReviewer != "" { // Validate it's a JWT _, err := jws.ParseJWT([]byte(tokenReviewer)) if err != nil { @@ -157,6 +139,10 @@ func (b *kubeAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Requ } } + if disableLocalJWT && caCert == "" { + return logical.ErrorResponse("kubernetes_ca_cert must be given when disable_local_ca_jwt is true"), nil + } + config := &kubeConfig{ PublicKeys: make([]interface{}, len(pemList)), PEMKeys: pemList, diff --git a/path_config_test.go b/path_config_test.go index c49ed6a4..e5cd86bb 100644 --- a/path_config_test.go +++ b/path_config_test.go @@ -6,13 +6,40 @@ import ( "os" "reflect" "testing" + "time" "github.com/hashicorp/vault/sdk/logical" ) +func setupLocalFiles(t *testing.T, b logical.Backend) func() { + cert, err := ioutil.TempFile("", "ca.crt") + if err != nil { + t.Fatal(err) + } + cert.WriteString(testLocalCACert) + cert.Close() + + token, err := ioutil.TempFile("", "token") + if err != nil { + t.Fatal(err) + } + token.WriteString(testLocalJWT) + token.Close() + b.(*kubeAuthBackend).localCACertReader = newCachingFileReader(cert.Name(), caReloadPeriod, time.Now) + b.(*kubeAuthBackend).localSATokenReader = newCachingFileReader(token.Name(), jwtReloadPeriod, time.Now) + + return func() { + os.Remove(cert.Name()) + os.Remove(token.Name()) + } +} + func TestConfig_Read(t *testing.T) { b, storage := getBackend(t) + cleanup := setupLocalFiles(t, b) + defer cleanup() + data := map[string]interface{}{ "pem_keys": []string{testRSACert, testECCert}, "kubernetes_host": "host", @@ -54,6 +81,9 @@ func TestConfig_Read(t *testing.T) { func TestConfig(t *testing.T) { b, storage := getBackend(t) + cleanup := setupLocalFiles(t, b) + defer cleanup() + // test no certificate data := map[string]interface{}{ "kubernetes_host": "host", @@ -67,11 +97,8 @@ func TestConfig(t *testing.T) { } resp, err := b.HandleRequest(context.Background(), req) - if resp == nil || !resp.IsError() { - t.Fatal("expected error") - } - if resp.Error().Error() != "one of pem_keys or kubernetes_ca_cert must be set" { - t.Fatalf("got unexpected error: %v", resp.Error()) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) } // test no host @@ -331,24 +358,16 @@ func TestConfig(t *testing.T) { } func TestConfig_LocalCaJWT(t *testing.T) { - b, storage := getBackend(t) - - // write "local" CA and JWT, and override local path vars - caFile := writeToTempFile(t, testLocalCACert) - localCACertPath = caFile - defer os.Remove(caFile) - jwtFile := writeToTempFile(t, testLocalJWT) - localJWTPath = jwtFile - defer os.Remove(jwtFile) - testCases := map[string]struct { - config map[string]interface{} - expected *kubeConfig + config map[string]interface{} + setupInClusterFiles bool + expected *kubeConfig }{ "no CA or JWT, default to local": { config: map[string]interface{}{ "kubernetes_host": "host", }, + setupInClusterFiles: true, expected: &kubeConfig{ PublicKeys: []interface{}{}, PEMKeys: []string{}, @@ -364,6 +383,7 @@ func TestConfig_LocalCaJWT(t *testing.T) { "kubernetes_host": "host", "kubernetes_ca_cert": testCACert, }, + setupInClusterFiles: true, expected: &kubeConfig{ PublicKeys: []interface{}{}, PEMKeys: []string{}, @@ -379,6 +399,7 @@ func TestConfig_LocalCaJWT(t *testing.T) { "kubernetes_host": "host", "token_reviewer_jwt": jwtData, }, + setupInClusterFiles: true, expected: &kubeConfig{ PublicKeys: []interface{}{}, PEMKeys: []string{}, @@ -409,6 +430,13 @@ func TestConfig_LocalCaJWT(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { + b, storage := getBackend(t) + + if tc.setupInClusterFiles { + cleanup := setupLocalFiles(t, b) + defer cleanup() + } + req := &logical.Request{ Operation: logical.CreateOperation, Path: configPath, @@ -421,7 +449,7 @@ func TestConfig_LocalCaJWT(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - conf, err := b.(*kubeAuthBackend).config(context.Background(), storage) + conf, err := b.(*kubeAuthBackend).loadConfig(context.Background(), storage) if err != nil { t.Fatal(err) } @@ -433,18 +461,84 @@ func TestConfig_LocalCaJWT(t *testing.T) { } } -func writeToTempFile(t *testing.T, contents string) string { - t.Helper() +func TestConfig_LocalJWTRenewal(t *testing.T) { + b, storage := getBackend(t) - f, err := ioutil.TempFile("", "test") + cleanup := setupLocalFiles(t, b) + defer cleanup() + + // Create temp file that will be used as token. + f, err := ioutil.TempFile("", "renewed-token") if err != nil { - t.Fatalf("Failure to create test file: %s", err) + t.Error(err) } - _, err = f.WriteString(contents) - if err != nil { - t.Fatalf("Failure to write test file: %s", err) + f.Close() + defer os.Remove(f.Name()) + + currentTime := time.Now() + + b.(*kubeAuthBackend).localSATokenReader = newCachingFileReader(f.Name(), jwtReloadPeriod, func() time.Time { + return currentTime + }) + + token1 := "before-renewal" + token2 := "after-renewal" + + // Write initial token to the temp file. + ioutil.WriteFile(f.Name(), []byte(token1), 0644) + + data := map[string]interface{}{ + "kubernetes_host": "host", } - return f.Name() + req := &logical.Request{ + Operation: logical.CreateOperation, + Path: configPath, + Storage: storage, + Data: data, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Loading the config will load the initial token file from disk. + conf, err := b.(*kubeAuthBackend).loadConfig(context.Background(), storage) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Check that we loaded the initial token. + if conf.TokenReviewerJWT != token1 { + t.Fatalf("got unexpected JWT: expected %#v\n got %#v\n", token1, conf.TokenReviewerJWT) + } + + // Write new value to the token file to simulate renewal. + ioutil.WriteFile(f.Name(), []byte(token2), 0644) + + // Load again to check we still got the old cached token from memory. + conf, err = b.(*kubeAuthBackend).loadConfig(context.Background(), storage) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if conf.TokenReviewerJWT != token1 { + t.Fatalf("got unexpected JWT: expected %#v\n got %#v\n", token1, conf.TokenReviewerJWT) + } + + // Advance simulated time for cache to expire + currentTime = currentTime.Add(1 * time.Minute) + + // Load again and check we the new renewed token from disk. + conf, err = b.(*kubeAuthBackend).loadConfig(context.Background(), storage) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if conf.TokenReviewerJWT != token2 { + t.Fatalf("got unexpected JWT: expected %#v\n got %#v\n", token2, conf.TokenReviewerJWT) + } + } var testLocalCACert string = `-----BEGIN CERTIFICATE----- diff --git a/path_login.go b/path_login.go index 51f1fcf6..ec8d8a1b 100644 --- a/path_login.go +++ b/path_login.go @@ -180,6 +180,9 @@ func (b *kubeAuthBackend) aliasLookahead(ctx context.Context, req *logical.Reque return resp, nil } + b.l.RLock() + defer b.l.RUnlock() + role, err := b.role(ctx, req.Storage, roleName) if err != nil { return nil, err