Skip to content

Commit 66bcd97

Browse files
authored
VCR: Deduplicate AccessToken functions in IAM client (#3077)
1 parent f262272 commit 66bcd97

File tree

17 files changed

+193
-173
lines changed

17 files changed

+193
-173
lines changed

auth/api/auth/v1/api_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ func TestWrapper_CreateAccessToken(t *testing.T) {
598598
params := CreateAccessTokenRequest{GrantType: "urn:ietf:params:oauth:grant-type:jwt-bearer", Assertion: validJwt}
599599

600600
in800000 := 800000
601-
pkgResponse := &oauth2.TokenResponse{AccessToken: "foo", ExpiresIn: &in800000}
601+
pkgResponse := oauth2.NewTokenResponse("foo", "Bearer", in800000, "")
602602
ctx.authzServerMock.EXPECT().CreateAccessToken(gomock.Any(), services.CreateAccessTokenRequest{RawJwtBearerToken: validJwt}).Return(pkgResponse, nil)
603603

604604
expectedResponse := CreateAccessToken200JSONResponse{

auth/api/iam/api.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ func (r Wrapper) RetrieveAccessToken(_ context.Context, request RetrieveAccessTo
220220
if err != nil {
221221
return nil, err
222222
}
223-
if token.Status != nil && *token.Status == oauth.AccessTokenRequestStatusPending {
223+
if token.Get("status") == oauth.AccessTokenRequestStatusPending {
224224
// return pending status
225225
return RetrieveAccessToken200JSONResponse(token), nil
226226
}
@@ -625,10 +625,10 @@ func (r Wrapper) RequestUserAccessToken(ctx context.Context, request RequestUser
625625
if err != nil {
626626
return nil, err
627627
}
628-
status := oauth.AccessTokenRequestStatusPending
629-
err = r.accessTokenClientStore().Put(sessionID, TokenResponse{
630-
Status: &status,
631-
})
628+
tokenResponse := (&TokenResponse{}).With("status", oauth.AccessTokenRequestStatusPending)
629+
if err = r.accessTokenClientStore().Put(sessionID, tokenResponse); err != nil {
630+
return nil, err
631+
}
632632

633633
// generate a link to the redirect endpoint
634634
webURL, err := createOAuth2BaseURL(*requestHolder)
@@ -786,12 +786,13 @@ func (r Wrapper) CallbackOid4vciCredentialIssuance(ctx context.Context, request
786786
log.Logger().WithError(err).Error("cannot fetch the right endpoints")
787787
return nil, withCallbackURI(oauthError(oauth.ServerError, fmt.Sprintf("cannot fetch the right endpoints: %s", err.Error())), oid4vciSession.remoteRedirectUri())
788788
}
789-
response, err := r.auth.IAMClient().AccessTokenOid4vci(ctx, holderDid.String(), tokenEndpoint, oid4vciSession.RedirectUri, code, &pkceParams.Verifier)
789+
response, err := r.auth.IAMClient().AccessToken(ctx, code, *issuerDid, oid4vciSession.RedirectUri, *holderDid, pkceParams.Verifier)
790790
if err != nil {
791791
log.Logger().WithError(err).Errorf("error while fetching the access_token from endpoint: %s", tokenEndpoint)
792792
return nil, withCallbackURI(oauthError(oauth.AccessDenied, fmt.Sprintf("error while fetching the access_token from endpoint: %s, error: %s", tokenEndpoint, err.Error())), oid4vciSession.remoteRedirectUri())
793793
}
794-
proofJWT, err := r.proofJwt(ctx, *holderDid, *issuerDid, response.CNonce)
794+
cNonce := response.Get(oauth.CNonceParam)
795+
proofJWT, err := r.proofJwt(ctx, *holderDid, *issuerDid, &cNonce)
795796
if err != nil {
796797
log.Logger().WithError(err).Error("error while building proof")
797798
return nil, withCallbackURI(oauthError(oauth.ServerError, fmt.Sprintf("error while fetching the credential from endpoint %s, error: %s", credentialEndpoint, err.Error())), oid4vciSession.remoteRedirectUri())

auth/api/iam/api_test.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ func TestWrapper_Callback(t *testing.T) {
588588
var tokenResponse TokenResponse
589589
err = ctx.client.accessTokenClientStore().Get(token, &tokenResponse)
590590
require.NoError(t, err)
591-
assert.Equal(t, oauth.AccessTokenRequestStatusActive, *tokenResponse.Status)
591+
assert.Equal(t, oauth.AccessTokenRequestStatusActive, tokenResponse.Get("status"))
592592
assert.Equal(t, "access", tokenResponse.AccessToken)
593593
})
594594
t.Run("unknown did", func(t *testing.T) {
@@ -904,7 +904,7 @@ func TestWrapper_RequestUserAccessToken(t *testing.T) {
904904
var tokenResponse TokenResponse
905905
require.NotNil(t, redirectResponse.SessionId)
906906
err = ctx.client.accessTokenClientStore().Get(redirectResponse.SessionId, &tokenResponse)
907-
assert.Equal(t, oauth.AccessTokenRequestStatusPending, *tokenResponse.Status)
907+
assert.Equal(t, oauth.AccessTokenRequestStatusPending, tokenResponse.Get("status"))
908908
})
909909
t.Run("preauthorized_user", func(t *testing.T) {
910910
ctx := newTestClient(t)
@@ -1288,19 +1288,15 @@ func TestWrapper_CallbackOid4vciCredentialIssuance(t *testing.T) {
12881288
IssuerTokenEndpoint: tokenEndpoint,
12891289
IssuerCredentialEndpoint: credEndpoint,
12901290
}
1291-
tokenResponse := oauth.Oid4vciTokenResponse{
1292-
AccessToken: accessToken,
1293-
TokenType: "Bearer",
1294-
CNonce: &cNonce,
1295-
}
1291+
tokenResponse := oauth.NewTokenResponse(accessToken, "Bearer", 0, "").With("c_nonce", cNonce)
12961292
credentialResponse := iam.CredentialResponse{
12971293
Format: "jwt_vc",
12981294
Credential: verifiableCredential.Raw(),
12991295
}
13001296
t.Run("ok", func(t *testing.T) {
13011297
ctx := newTestClient(t)
13021298
ctx.client.storageEngine.GetSessionDatabase().GetStore(15*time.Minute, "oid4vci").Put(state, &session)
1303-
ctx.iamClient.EXPECT().AccessTokenOid4vci(nil, holderDID.String(), tokenEndpoint, redirectURI, code, &pkceParams.Verifier).Return(&tokenResponse, nil)
1299+
ctx.iamClient.EXPECT().AccessToken(nil, code, issuerDID, redirectURI, holderDID, pkceParams.Verifier).Return(tokenResponse, nil)
13041300
ctx.keyResolver.EXPECT().ResolveKey(holderDID, nil, resolver.NutsSigningKeyType).Return(ssi.MustParseURI("kid"), nil, nil)
13051301
ctx.jwtSigner.EXPECT().SignJWT(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return("signed-proof", nil)
13061302
ctx.iamClient.EXPECT().VerifiableCredentials(nil, credEndpoint, accessToken, "signed-proof").Return(&credentialResponse, nil)
@@ -1352,7 +1348,7 @@ func TestWrapper_CallbackOid4vciCredentialIssuance(t *testing.T) {
13521348
t.Run("fail_access_token", func(t *testing.T) {
13531349
ctx := newTestClient(t)
13541350
ctx.client.storageEngine.GetSessionDatabase().GetStore(15*time.Minute, "oid4vci").Put(state, &session)
1355-
ctx.iamClient.EXPECT().AccessTokenOid4vci(nil, holderDID.String(), tokenEndpoint, redirectURI, code, &pkceParams.Verifier).Return(nil, errors.New("FAIL"))
1351+
ctx.iamClient.EXPECT().AccessToken(nil, code, issuerDID, redirectURI, holderDID, pkceParams.Verifier).Return(nil, errors.New("FAIL"))
13561352

13571353
callback, err := ctx.client.CallbackOid4vciCredentialIssuance(nil, CallbackOid4vciCredentialIssuanceRequestObject{
13581354
Params: CallbackOid4vciCredentialIssuanceParams{
@@ -1368,7 +1364,7 @@ func TestWrapper_CallbackOid4vciCredentialIssuance(t *testing.T) {
13681364
t.Run("fail_credential_response", func(t *testing.T) {
13691365
ctx := newTestClient(t)
13701366
require.NoError(t, ctx.client.storageEngine.GetSessionDatabase().GetStore(15*time.Minute, "oid4vci").Put(state, &session))
1371-
ctx.iamClient.EXPECT().AccessTokenOid4vci(nil, holderDID.String(), tokenEndpoint, redirectURI, code, &pkceParams.Verifier).Return(&tokenResponse, nil)
1367+
ctx.iamClient.EXPECT().AccessToken(nil, code, issuerDID, redirectURI, holderDID, pkceParams.Verifier).Return(tokenResponse, nil)
13721368
ctx.keyResolver.EXPECT().ResolveKey(holderDID, nil, resolver.NutsSigningKeyType).Return(ssi.MustParseURI("kid"), nil, nil)
13731369
ctx.jwtSigner.EXPECT().SignJWT(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return("signed-proof", nil)
13741370
ctx.iamClient.EXPECT().VerifiableCredentials(nil, credEndpoint, accessToken, "signed-proof").Return(nil, errors.New("FAIL"))
@@ -1387,7 +1383,7 @@ func TestWrapper_CallbackOid4vciCredentialIssuance(t *testing.T) {
13871383
t.Run("fail_verify", func(t *testing.T) {
13881384
ctx := newTestClient(t)
13891385
require.NoError(t, ctx.client.storageEngine.GetSessionDatabase().GetStore(15*time.Minute, "oid4vci").Put(state, &session))
1390-
ctx.iamClient.EXPECT().AccessTokenOid4vci(nil, holderDID.String(), tokenEndpoint, redirectURI, code, &pkceParams.Verifier).Return(&tokenResponse, nil)
1386+
ctx.iamClient.EXPECT().AccessToken(nil, code, issuerDID, redirectURI, holderDID, pkceParams.Verifier).Return(tokenResponse, nil)
13911387
ctx.keyResolver.EXPECT().ResolveKey(holderDID, nil, resolver.NutsSigningKeyType).Return(ssi.MustParseURI("kid"), nil, nil)
13921388
ctx.jwtSigner.EXPECT().SignJWT(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return("signed-proof", nil)
13931389
ctx.iamClient.EXPECT().VerifiableCredentials(nil, credEndpoint, accessToken, "signed-proof").Return(&credentialResponse, nil)
@@ -1406,7 +1402,7 @@ func TestWrapper_CallbackOid4vciCredentialIssuance(t *testing.T) {
14061402
t.Run("error - key not found", func(t *testing.T) {
14071403
ctx := newTestClient(t)
14081404
require.NoError(t, ctx.client.storageEngine.GetSessionDatabase().GetStore(15*time.Minute, "oid4vci").Put(state, &session))
1409-
ctx.iamClient.EXPECT().AccessTokenOid4vci(nil, holderDID.String(), tokenEndpoint, redirectURI, code, &pkceParams.Verifier).Return(&tokenResponse, nil)
1405+
ctx.iamClient.EXPECT().AccessToken(nil, code, issuerDID, redirectURI, holderDID, pkceParams.Verifier).Return(tokenResponse, nil)
14101406
ctx.keyResolver.EXPECT().ResolveKey(holderDID, nil, resolver.NutsSigningKeyType).Return(ssi.URI{}, nil, resolver.ErrKeyNotFound)
14111407

14121408
callback, err := ctx.client.CallbackOid4vciCredentialIssuance(nil, CallbackOid4vciCredentialIssuanceRequestObject{
@@ -1423,7 +1419,7 @@ func TestWrapper_CallbackOid4vciCredentialIssuance(t *testing.T) {
14231419
t.Run("error - signature failure", func(t *testing.T) {
14241420
ctx := newTestClient(t)
14251421
require.NoError(t, ctx.client.storageEngine.GetSessionDatabase().GetStore(15*time.Minute, "oid4vci").Put(state, &session))
1426-
ctx.iamClient.EXPECT().AccessTokenOid4vci(nil, holderDID.String(), tokenEndpoint, redirectURI, code, &pkceParams.Verifier).Return(&tokenResponse, nil)
1422+
ctx.iamClient.EXPECT().AccessToken(nil, code, issuerDID, redirectURI, holderDID, pkceParams.Verifier).Return(tokenResponse, nil)
14271423
ctx.keyResolver.EXPECT().ResolveKey(holderDID, nil, resolver.NutsSigningKeyType).Return(ssi.MustParseURI("kid"), nil, nil)
14281424
ctx.jwtSigner.EXPECT().SignJWT(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return("", errors.New("signature failed"))
14291425

auth/api/iam/openid4vp.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,8 +689,7 @@ func (r Wrapper) handleCallback(ctx context.Context, request CallbackRequestObje
689689
return nil, withCallbackURI(oauthError(oauth.ServerError, fmt.Sprintf("failed to retrieve access token: %s", err.Error())), appCallbackURI)
690690
}
691691
// update TokenResponse using session.SessionID
692-
statusActive := oauth.AccessTokenRequestStatusActive
693-
tokenResponse.Status = &statusActive
692+
tokenResponse = tokenResponse.With("status", oauth.AccessTokenRequestStatusActive)
694693
if err = r.accessTokenClientStore().Put(oauthSession.SessionID, tokenResponse); err != nil {
695694
return nil, withCallbackURI(oauthError(oauth.ServerError, fmt.Sprintf("failed to store access token: %s", err.Error())), appCallbackURI)
696695
}

auth/client/iam/client.go

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -211,51 +211,6 @@ func (hb HTTPClient) OpenIdCredentialIssuerMetadata(ctx context.Context, webDID
211211
return &metadata, err
212212
}
213213

214-
func (hb HTTPClient) AccessTokenOid4vci(ctx context.Context, presentationDefinitionURL url.URL, data url.Values) (*oauth.Oid4vciTokenResponse, error) {
215-
// create a POST request with x-www-form-urlencoded body
216-
request, err := http.NewRequestWithContext(ctx, http.MethodPost, presentationDefinitionURL.String(), strings.NewReader(data.Encode()))
217-
request.Header.Add("Accept", "application/json")
218-
request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
219-
if err != nil {
220-
return nil, err
221-
}
222-
response, err := hb.httpClient.Do(request.WithContext(ctx))
223-
if err != nil {
224-
return nil, fmt.Errorf("failed to call endpoint: %w", err)
225-
}
226-
if err = core.TestResponseCode(http.StatusOK, response); err != nil {
227-
// check for oauth error
228-
if innerErr := core.TestResponseCode(http.StatusBadRequest, response); innerErr != nil {
229-
// a non oauth error, the response body could contain a lot of stuff. We'll log and return the entire error
230-
log.Logger().Debugf("authorization server token endpoint returned non oauth error (statusCode=%d)", response.StatusCode)
231-
return nil, err
232-
}
233-
httpErr := err.(core.HttpError)
234-
oauthError := oauth.OAuth2Error{}
235-
if err := json.Unmarshal(httpErr.ResponseBody, &oauthError); err != nil {
236-
return nil, fmt.Errorf("unable to unmarshal OAuth error response: %w", err)
237-
}
238-
239-
return nil, oauthError
240-
}
241-
242-
var responseData []byte
243-
if responseData, err = core.LimitedReadAll(response.Body); err != nil {
244-
return nil, fmt.Errorf("unable to read response: %w", err)
245-
}
246-
247-
var token oauth.Oid4vciTokenResponse
248-
if err = json.Unmarshal(responseData, &token); err != nil {
249-
// Cut off the response body to 100 characters max to prevent logging of large responses
250-
responseBodyString := string(responseData)
251-
if len(responseBodyString) > 100 {
252-
responseBodyString = responseBodyString[:100] + "...(clipped)"
253-
}
254-
return nil, fmt.Errorf("unable to unmarshal response: %w, %s", err, string(responseData))
255-
}
256-
return &token, nil
257-
}
258-
259214
// CredentialRequest represents ths request to fetch a credential, the JSON object holds the proof as
260215
// CredentialRequestProof.
261216
type CredentialRequest struct {

auth/client/iam/interface.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ import (
2828

2929
// Client defines OpenID4VP client methods using the IAM OpenAPI Spec.
3030
type Client interface {
31-
// AccessToken requests an access token at the oauth2 token endpoint.
31+
// AccessToken requests an access token at the OAuth2 Token Endpoint.
32+
// The token endpoint can be a regular OAuth2 token endpoint or OpenID4VCI-related endpoint.
33+
// The response will be unmarshalled into the given tokenResponseOut parameter.
3234
AccessToken(ctx context.Context, code string, verifier did.DID, callbackURI string, clientID did.DID, codeVerifier string) (*oauth.TokenResponse, error)
3335
// AuthorizationServerMetadata returns the metadata of the remote wallet.
3436
AuthorizationServerMetadata(ctx context.Context, webdid did.DID) (*oauth.AuthorizationServerMetadata, error)
@@ -47,8 +49,6 @@ type Client interface {
4749

4850
OpenIdCredentialIssuerMetadata(ctx context.Context, webDID did.DID) (*oauth.OpenIDCredentialIssuerMetadata, error)
4951

50-
AccessTokenOid4vci(ctx context.Context, clientId string, tokenEndpoint string, redirectUri string, code string, pkceCodeVerifier *string) (*oauth.Oid4vciTokenResponse, error)
51-
5252
VerifiableCredentials(ctx context.Context, credentialEndpoint string, accessToken string, proofJWT string) (*CredentialResponse, error)
5353
// RequestObject is returned from the authorization request's 'request_uri' defined in RFC9101.
5454
RequestObject(ctx context.Context, requestURI string) (string, error)

auth/client/iam/mock.go

Lines changed: 0 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

auth/client/iam/openid4vp.go

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -231,28 +231,6 @@ func (c *OpenID4VPClient) OpenIdCredentialIssuerMetadata(ctx context.Context, we
231231
return rsp, nil
232232
}
233233

234-
func (c *OpenID4VPClient) AccessTokenOid4vci(ctx context.Context, clientId string, tokenEndpoint string, redirectUri string, code string, pkceCodeVerifier *string) (*oauth.Oid4vciTokenResponse, error) {
235-
iamClient := c.httpClient
236-
data := url.Values{}
237-
data.Set("client_id", clientId)
238-
data.Set(oauth.GrantTypeParam, oauth.AuthorizationCodeGrantType)
239-
data.Set(oauth.CodeParam, code)
240-
data.Set("redirect_uri", redirectUri)
241-
if pkceCodeVerifier != nil {
242-
data.Set("code_verifier", *pkceCodeVerifier)
243-
}
244-
presentationDefinitionURL, err := url.Parse(tokenEndpoint)
245-
if err != nil {
246-
return nil, err
247-
}
248-
249-
rsp, err := iamClient.AccessTokenOid4vci(ctx, *presentationDefinitionURL, data)
250-
if err != nil {
251-
return nil, fmt.Errorf("remote server: failed to retrieve an access_token: %w", err)
252-
}
253-
return rsp, nil
254-
}
255-
256234
func (c *OpenID4VPClient) VerifiableCredentials(ctx context.Context, credentialEndpoint string, accessToken string, proofJWT string) (*CredentialResponse, error) {
257235
iamClient := c.httpClient
258236
rsp, err := iamClient.VerifiableCredentials(ctx, credentialEndpoint, accessToken, proofJWT)

auth/client/iam/openid4vp_test.go

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -539,31 +539,7 @@ func TestIAMClient_OpenIdCredentialIssuerMetadata(t *testing.T) {
539539
})
540540

541541
}
542-
func TestIAMClient_AccessTokenOid4vci(t *testing.T) {
543-
code := "code"
544-
redirectUri := "https://test.test/callback"
545-
pkceCodeVerifier := "verifier"
546-
547-
t.Run("ok", func(t *testing.T) {
548-
ctx := createClientServerTestContext(t)
549-
550-
response, err := ctx.client.AccessTokenOid4vci(context.Background(), ctx.verifierDID.String(), ctx.openIDConfigurationMetadata.TokenEndpoint, redirectUri, code, &pkceCodeVerifier)
551-
552-
require.NoError(t, err)
553-
require.NotNil(t, response)
554-
assert.Equal(t, "token", response.AccessToken)
555-
assert.Equal(t, "bearer", response.TokenType)
556-
})
557-
t.Run("error - failed to get access token", func(t *testing.T) {
558-
ctx := createClientServerTestContext(t)
559-
ctx.token = nil
560542

561-
response, err := ctx.client.AccessTokenOid4vci(context.Background(), ctx.verifierDID.String(), ctx.openIDConfigurationMetadata.TokenEndpoint, redirectUri, code, &pkceCodeVerifier)
562-
563-
assert.EqualError(t, err, "remote server: failed to retrieve an access_token: server returned HTTP 404 (expected: 200)")
564-
assert.Nil(t, response)
565-
})
566-
}
567543
func TestIAMClient_VerifiableCredentials(t *testing.T) {
568544
//walletDID := did.MustParseDID("did:web:test.test:iam:123")
569545
accessToken := "code"

auth/oauth/test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright (C) 2024 Nuts community
3+
*
4+
* This program is free software: you can redistribute it and/or modify
5+
* it under the terms of the GNU General Public License as published by
6+
* the Free Software Foundation, either version 3 of the License, or
7+
* (at your option) any later version.
8+
*
9+
* This program is distributed in the hope that it will be useful,
10+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
* GNU General Public License for more details.
13+
*
14+
* You should have received a copy of the GNU General Public License
15+
* along with this program. If not, see <https://www.gnu.org/licenses/>.
16+
*
17+
*/
18+
19+
package oauth
20+
21+
// NewTokenResponse is a convenience function for creating a TokenResponse with the given parameters.
22+
// expires_in and scope are only set if they are passed a valid value.
23+
func NewTokenResponse(accessToken, tokenType string, expiresIn int, scope string) *TokenResponse {
24+
tr := &TokenResponse{
25+
AccessToken: accessToken,
26+
TokenType: tokenType,
27+
}
28+
if expiresIn > 0 {
29+
tr.ExpiresIn = &expiresIn
30+
}
31+
if scope != "" {
32+
tr.Scope = &scope
33+
}
34+
return tr
35+
}

0 commit comments

Comments
 (0)