diff --git a/token/cache.go b/token/cache.go new file mode 100644 index 0000000..a9ce818 --- /dev/null +++ b/token/cache.go @@ -0,0 +1,56 @@ +// Copyright (C) 2024 vcs contributors +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this program. If not, see +// . +// +// SPDX-License-Identifier: LGPL-3.0 + +package token + +import ( + "sync" + + "github.com/jaredallard/vcs" + "github.com/jaredallard/vcs/token/internal/shared" +) + +// tokenCache is a cache of tokens that have been fetched from the +// user's machine. +type tokenCache struct { + // tokensMu is a mutex to protect the tokens map. + tokensMu sync.RWMutex + + // tokens is a map of VCS provider to their respective token. + tokens map[vcs.Provider]*shared.Token +} + +// Get returns a token from the cache if it exists. +func (c *tokenCache) Get(provider vcs.Provider) (*shared.Token, bool) { + c.tokensMu.RLock() + defer c.tokensMu.RUnlock() + + t, ok := c.tokens[provider] + return t, ok +} + +// Set sets a token in the cache. +func (c *tokenCache) Set(provider vcs.Provider, token *shared.Token) { + c.tokensMu.Lock() + defer c.tokensMu.Unlock() + + c.tokens[provider] = token +} + +// cache is the global token cache. +var cache = &tokenCache{tokens: make(map[vcs.Provider]*shared.Token)} diff --git a/token/internal/shared/shared.go b/token/internal/shared/shared.go index a8fabf6..1bce86a 100644 --- a/token/internal/shared/shared.go +++ b/token/internal/shared/shared.go @@ -22,6 +22,7 @@ package shared import ( "strings" + "time" ) // Token is a VCS token that can be used for API access. @@ -29,6 +30,9 @@ import ( // Do not use the 'shared.Token' type, instead use [token.Token] which // is an alias to this type. type Token struct { + // FetchedAt is the time that the token was fetched at. + FetchedAt time.Time + // Value is the token value. Value string @@ -56,6 +60,15 @@ func (t *Token) String() string { return t.Value } +// Clone returns a deep clone of the token. +func (t *Token) Clone() *Token { + return &Token{ + FetchedAt: t.FetchedAt, + Value: t.Value, + Type: t.Type, + } +} + // Provider is an interface for VCS providers to implement to provide a // token from a user's machine. type Provider interface { diff --git a/token/token.go b/token/token.go index 7c602e9..228d057 100644 --- a/token/token.go +++ b/token/token.go @@ -24,6 +24,7 @@ import ( "context" "errors" "fmt" + "time" "github.com/jaredallard/vcs" "github.com/jaredallard/vcs/token/internal/github" @@ -56,16 +57,66 @@ func (errs ErrNoToken) Error() string { return errors.Join(errs...).Error() } +// Options contains options for the [Fetch] function. +type Options struct { + // AllowUnauthenticated allows for an empty token to be returned if + // no token is found. + AllowUnauthenticated bool + + // UseGlobalCache allows for the use of a global cache for tokens. If + // set to true, the token will be cached globally (all instances of + // this library). Otherwise, the token will always be fetched. + // + // Defaults to true. + // + // Note: When using [shared.Token], the value will never change. + // Caching refers only to function calls provided by this package + // (e.g., [Fetch]). + UseGlobalCache *bool +} + // Fetch returns a valid token from one of the configured credential // providers. If no token is found, ErrNoToken is returned. // -// If allowUnauthenticated is true, then an empty token is returned if -// no token is found. -func Fetch(_ context.Context, vcsp vcs.Provider, allowUnauthenticated bool) (*shared.Token, error) { +// allowUnauthenticated is DEPRECATED and will be removed in a future +// release. Use the Options struct instead, setting AllowUnauthenticated +// to true/false. +// +// optss is a variadic argument only to avoid a breaking change. Only +// one option struct is allowed, an error will be returned if more than +// one is provided. +func Fetch(_ context.Context, vcsp vcs.Provider, allowUnauthenticated bool, optss ...*Options) (*shared.Token, error) { if _, ok := defaultProviders[vcsp]; !ok { return nil, fmt.Errorf("unknown VCS provider %q", vcsp) } + var opts Options + if len(optss) == 1 { + if optss[0] != nil { + opts = *optss[0] + } + } else if len(optss) > 1 { + return nil, fmt.Errorf("too many options provided") + } + + // Support the older API. + if allowUnauthenticated { + opts.AllowUnauthenticated = true + } + + // If UseGlobalCache is not set, default to true. + if opts.UseGlobalCache == nil { + b := true + opts.UseGlobalCache = &b + } + + if *opts.UseGlobalCache { + t, ok := cache.Get(vcsp) + if ok { + return t.Clone(), nil + } + } + var token *shared.Token errs := []error{} for _, p := range defaultProviders[vcsp] { @@ -83,11 +134,18 @@ func Fetch(_ context.Context, vcsp vcs.Provider, allowUnauthenticated bool) (*sh } } if token == nil { - if allowUnauthenticated { - return &shared.Token{}, nil + if !opts.AllowUnauthenticated { + return nil, ErrNoToken(errs) } - return nil, ErrNoToken(errs) + // Set an empty token since we're allowing unauthenticated access. + token = &shared.Token{} } + + // Set when the token was fetched and store it in the cache for + // possibly other calls to use. + token.FetchedAt = time.Now() + cache.Set(vcsp, token) + return token, nil } diff --git a/token/token_test.go b/token/token_test.go index c0c9759..6581ac5 100644 --- a/token/token_test.go +++ b/token/token_test.go @@ -6,11 +6,19 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/jaredallard/vcs" "github.com/jaredallard/vcs/token" "gotest.tools/v3/assert" ) +// ignoreTime is a [cmp.Option] that ignores time.Time values when +// comparing them, always returning true. +var ignoreTime = cmp.Comparer(func(_, _ time.Time) bool { + // Times are random, so ignore them. + return true +}) + // TestCanGetToken ensures that [token.Fetch] calls the underlying // provider to get the token. func TestCanGetToken(t *testing.T) { @@ -18,5 +26,24 @@ func TestCanGetToken(t *testing.T) { authToken, err := token.Fetch(context.Background(), vcs.ProviderGithub, false) assert.NilError(t, err) assert.Assert(t, authToken != nil, "expected a token to be returned") - assert.DeepEqual(t, authToken, &token.Token{Value: os.Getenv("GITHUB_TOKEN")}) + assert.DeepEqual(t, authToken, &token.Token{Value: os.Getenv("GITHUB_TOKEN")}, ignoreTime) +} + +// TestCanGetCachedToken ensures that [token.Fetch] returns the same +// token when called multiple times and caching is enabled. +func TestCanGetCachedToken(t *testing.T) { + bfalse := false + t.Setenv("GITHUB_TOKEN", time.Now().String()) + + originalToken, err := token.Fetch(context.Background(), vcs.ProviderGithub, false, &token.Options{UseGlobalCache: &bfalse}) + assert.NilError(t, err) + assert.Assert(t, originalToken != nil, "expected a token to be returned") + assert.DeepEqual(t, originalToken, &token.Token{Value: os.Getenv("GITHUB_TOKEN")}, ignoreTime) + assert.Equal(t, originalToken.FetchedAt.IsZero(), false) // should not be zero + + // Fetch again, should return the same token. + newToken, err := token.Fetch(context.Background(), vcs.ProviderGithub, false) + assert.NilError(t, err) + assert.Assert(t, newToken != nil, "expected a token to be returned") + assert.DeepEqual(t, newToken, &token.Token{FetchedAt: originalToken.FetchedAt, Value: os.Getenv("GITHUB_TOKEN")}) }