diff --git a/descope/internal/auth/auth_test.go b/descope/internal/auth/auth_test.go index f7e18097..657f24e6 100644 --- a/descope/internal/auth/auth_test.go +++ b/descope/internal/auth/auth_test.go @@ -958,6 +958,7 @@ func TestExchangeAccessKey(t *testing.T) { func TestExchangeAccessKeyWithLoginOptions(t *testing.T) { response := map[string]any{} err := utils.Unmarshal([]byte(mockAuthSessionBody), &response) + require.NoError(t, err) a, err := newTestAuth(nil, helpers.DoOkWithBody(func(r *http.Request) { req := map[string]any{} require.NoError(t, helpers.ReadBody(r, &req)) @@ -971,7 +972,7 @@ func TestExchangeAccessKeyWithLoginOptions(t *testing.T) { require.True(t, found) require.EqualValues(t, "v1", d) }, response)) - + require.NoError(t, err) loginOptions := &descope.AccessKeyLoginOptions{ CustomClaims: map[string]any{"k1": "v1"}, } diff --git a/descope/internal/auth/jwt.go b/descope/internal/auth/jwt.go index d6b76028..1281874e 100644 --- a/descope/internal/auth/jwt.go +++ b/descope/internal/auth/jwt.go @@ -3,30 +3,36 @@ package auth import ( "context" "path" - - "github.com/lestrrat-go/jwx/v2/jwa" - "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jws" + "sync/atomic" "github.com/descope/go-sdk/descope" "github.com/descope/go-sdk/descope/api" "github.com/descope/go-sdk/descope/internal/utils" "github.com/descope/go-sdk/descope/logger" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" ) type provider struct { client *api.Client conf *AuthParams providedKey jwk.Key - keySet map[string]jwk.Key + keySet atomic.Value } func newProvider(client *api.Client, conf *AuthParams) *provider { - return &provider{client: client, conf: conf, keySet: make(map[string]jwk.Key)} + ks := atomic.Value{} + ks.Store(map[string]jwk.Key{}) + return &provider{client: client, conf: conf, keySet: ks} +} + +func (p *provider) keySetMap() map[string]jwk.Key { + return p.keySet.Load().(map[string]jwk.Key) } func (p *provider) publicKeyExists() bool { - return len(p.keySet) > 0 || p.providedKey != nil + return len(p.keySetMap()) > 0 || p.providedKey != nil } func (p *provider) selectKey(sink jws.KeySink, key jwk.Key) error { @@ -50,7 +56,7 @@ func (p *provider) selectKey(sink jws.KeySink, key jwk.Key) error { func (p *provider) requestKeys() error { projectID := p.conf.ProjectID keysWrapper := map[string][]map[string]interface{}{} - _, err := p.client.DoGetRequest(nil, path.Join(api.Routes.GetKeys(), projectID), &api.HTTPRequest{ResBodyObj: &keysWrapper}, "") + _, err := p.client.DoGetRequest(context.Background(), path.Join(api.Routes.GetKeys(), projectID), &api.HTTPRequest{ResBodyObj: &keysWrapper}, "") if err != nil { return err } @@ -79,7 +85,7 @@ func (p *provider) requestKeys() error { } logger.LogDebug("Refresh keys set with %d key(s)", len(tempKeySet)) - p.keySet = tempKeySet + p.keySet.Store(tempKeySet) return nil } @@ -119,10 +125,10 @@ func (p *provider) findKey(kid string) (jwk.Key, error) { return nil, err } - key, ok := p.keySet[kid] + key, ok := p.keySetMap()[kid] if !ok { err := descope.ErrPublicKey.WithMessage("Required public key does not exist in key set") - logger.LogInfo("Required public key does not exist in key set (key set size [%d])", len(p.keySet)) + logger.LogInfo("Required public key does not exist in key set (key set size [%d])", len(p.keySetMap())) return nil, err } @@ -131,7 +137,7 @@ func (p *provider) findKey(kid string) (jwk.Key, error) { func (p *provider) FetchKeys(_ context.Context, sink jws.KeySink, sig *jws.Signature, _ *jws.Message) error { wantedKid := sig.ProtectedHeaders().KeyID() - v, ok := p.keySet[wantedKid] + v, ok := p.keySetMap()[wantedKid] if !ok { logger.LogDebug("Key was not found, looking for key id [%s]", wantedKid) if key, err := p.findKey(wantedKid); key != nil {