Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gerardsn committed May 13, 2024
1 parent 66b1d52 commit 6e49453
Show file tree
Hide file tree
Showing 10 changed files with 284 additions and 55 deletions.
18 changes: 9 additions & 9 deletions auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,16 +327,16 @@ func (r Wrapper) HandleAuthorizeRequest(ctx context.Context, request HandleAutho
}

// handleAuthorizeRequest handles calls to the authorization endpoint for starting an authorization code flow.
// ownDID must be validated by the caller
// The caller must ensure ownDID is actually owned by this node.
func (r Wrapper) handleAuthorizeRequest(ctx context.Context, ownDID did.DID, request url.URL) (HandleAuthorizeRequestResponseObject, error) {
// parse and validate as JAR (RFC9101, JWT Authorization Request)
authzParams, err := r.jar.Parse(ctx, ownDID, request.Query())
requestObject, err := r.jar.Parse(ctx, ownDID, request.Query())
if err != nil {
// already an oauth.OAuth2Error
return nil, err
}

switch authzParams.get(oauth.ResponseTypeParam) {
switch requestObject.get(oauth.ResponseTypeParam) {
case responseTypeCode:
// Options:
// - Regular authorization code flow for EHR data access through access token, authentication of end-user using OpenID4VP.
Expand All @@ -349,10 +349,10 @@ func (r Wrapper) handleAuthorizeRequest(ctx context.Context, ownDID did.DID, req
// when client_id is a did:web, it is a cloud/server wallet
// otherwise it's a normal registered client which we do not support yet
// Note: this is the user facing OpenID4VP flow with a "vp_token" responseType, the demo uses the "vp_token id_token" responseType
clientId := authzParams.get(oauth.ClientIDParam)
clientId := requestObject.get(oauth.ClientIDParam)
if strings.HasPrefix(clientId, "did:web:") {
// client is a cloud wallet with user
return r.handleAuthorizeRequestFromHolder(ctx, ownDID, authzParams)
return r.handleAuthorizeRequestFromHolder(ctx, ownDID, requestObject)
} else {
return nil, oauth.OAuth2Error{
Code: oauth.InvalidRequest,
Expand All @@ -369,10 +369,10 @@ func (r Wrapper) handleAuthorizeRequest(ctx context.Context, ownDID did.DID, req
if strings.HasPrefix(request.String(), "openid4vp:") {
walletOwnerType = pe.WalletOwnerUser
}
return r.handleAuthorizeRequestFromVerifier(ctx, ownDID, authzParams, walletOwnerType)
return r.handleAuthorizeRequestFromVerifier(ctx, ownDID, requestObject, walletOwnerType)
default:
// TODO: This should be a redirect?
redirectURI, _ := url.Parse(authzParams.get(oauth.RedirectURIParam))
redirectURI, _ := url.Parse(requestObject.get(oauth.RedirectURIParam))
return nil, oauth.OAuth2Error{
Code: oauth.UnsupportedResponseType,
RedirectURI: redirectURI,
Expand All @@ -395,7 +395,7 @@ func (r Wrapper) GetRequestJWT(ctx context.Context, request GetRequestJWTRequest
if ro.Client.String() != request.Did {
return nil, oauth.OAuth2Error{
Code: oauth.InvalidRequest,
Description: "request object not found",
Description: "client_id does not match request",
InternalError: errors.New("DID does not match client_id for requestID"),
}
}
Expand Down Expand Up @@ -440,7 +440,7 @@ func (r Wrapper) PostRequestJWT(ctx context.Context, request PostRequestJWTReque
if ro.Client.String() != request.Did {
return nil, oauth.OAuth2Error{
Code: oauth.InvalidRequest,
Description: "request object not found",
Description: "client_id does not match request",
InternalError: errors.New("DID does not match client_id for requestID"),
}
}
Expand Down
143 changes: 126 additions & 17 deletions auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -942,15 +942,14 @@ func TestWrapper_StatusList(t *testing.T) {
}

func TestWrapper_GetRequestJWT(t *testing.T) {
cont := context.Background()
requestID := "thisID"
expectedToken := "validToken"
t.Run("ok", func(t *testing.T) {
cont := context.Background()
requestID := "thisID"
expectedToken := "validToken"
ro := jar{}.Create(webDID, &holderDID, func(claims map[string]string) {})

ctx := newTestClient(t)
ctx.jar.EXPECT().Sign(cont, ro.Claims).Return(expectedToken, nil)
ro := jar{}.Create(webDID, &holderDID, func(claims map[string]string) {})
require.NoError(t, ctx.client.authzRequestObjectStore().Put(requestID, ro))
ctx.jar.EXPECT().Sign(cont, ro.Claims).Return(expectedToken, nil)

response, err := ctx.client.GetRequestJWT(cont, GetRequestJWTRequestObject{Did: webDID.String(), Id: requestID})

Expand All @@ -968,15 +967,126 @@ func TestWrapper_GetRequestJWT(t *testing.T) {
assert.Nil(t, response)
assert.EqualError(t, err, "invalid_request - request object not found")
})
t.Run("error - clientID does not match request", func(t *testing.T) {
ctx := newTestClient(t)
ro := jar{}.Create(webDID, &holderDID, func(claims map[string]string) {})
require.NoError(t, ctx.client.authzRequestObjectStore().Put(requestID, ro))

response, err := ctx.client.GetRequestJWT(cont, GetRequestJWTRequestObject{Did: holderDID.String(), Id: requestID})

assert.Nil(t, response)
assert.EqualError(t, err, "invalid_request - DID does not match client_id for requestID - client_id does not match request")
})
t.Run("error - wrong request_uri_method used", func(t *testing.T) {
ctx := newTestClient(t)
ro := jar{}.Create(webDID, &holderDID, func(claims map[string]string) {})
ro.RequestURIMethod = "post"
require.NoError(t, ctx.client.authzRequestObjectStore().Put(requestID, ro))

response, err := ctx.client.GetRequestJWT(cont, GetRequestJWTRequestObject{Did: webDID.String(), Id: requestID})

assert.Nil(t, response)
assert.EqualError(t, err, "invalid_request - wrong 'request_uri_method' authorization server or wallet probably does not support 'request_uri_method' - used request_uri_method 'get' on a 'post' request_uri")
})
t.Run("error - signing failed", func(t *testing.T) {
ctx := newTestClient(t)
ro := jar{}.Create(webDID, &holderDID, func(claims map[string]string) {})
require.NoError(t, ctx.client.authzRequestObjectStore().Put(requestID, ro))
ctx.jar.EXPECT().Sign(cont, ro.Claims).Return("", errors.New("fail"))

response, err := ctx.client.GetRequestJWT(cont, GetRequestJWTRequestObject{Did: webDID.String(), Id: requestID})

assert.Nil(t, response)
assert.EqualError(t, err, "server_error - fail - failed to sign authorization RequestObject")
})
}

func TestWrapper_PostRequestJWT(t *testing.T) {
//ctx := newTestClient(t)
//
//response, err := ctx.client.PostRequestJWT(nil, PostRequestJWTRequestObject{Id: "unknownID"})
//
//assert.Nil(t, response)
//assert.EqualError(t, err, "invalid_request - not implemented")
cont := context.Background()
requestID := "thisID"
expectedToken := "validToken"
newReqObj := func(issuer, nonce string) jarRequest {
ro := jar{}.Create(webDID, nil, func(claims map[string]string) {})
if issuer != "" {
ro.Claims[jwt.AudienceKey] = issuer
}
if nonce != "" {
ro.Claims[oauth.WalletNonceParam] = nonce
}
return ro
}
t.Run("ok", func(t *testing.T) {
ctx := newTestClient(t)
ro := newReqObj("https://self-issued.me/v2", "")
require.NoError(t, ctx.client.authzRequestObjectStore().Put(requestID, ro))
ctx.jar.EXPECT().Sign(cont, ro.Claims).Return(expectedToken, nil)

response, err := ctx.client.PostRequestJWT(cont, PostRequestJWTRequestObject{Did: webDID.String(), Id: requestID})

assert.NoError(t, err)
assert.Equal(t, PostRequestJWT200ApplicationoauthAuthzReqJwtResponse{
Body: bytes.NewReader([]byte(expectedToken)),
ContentLength: 10,
}, response)
})
t.Run("ok - with metadata and nonce", func(t *testing.T) {
wallet_nonce := "wallet_nonce"
ctx := newTestClient(t)
ro := newReqObj("mario", wallet_nonce)
require.NoError(t, ctx.client.authzRequestObjectStore().Put(requestID, ro))
ctx.jar.EXPECT().Sign(cont, ro.Claims).Return(expectedToken, nil)
body := PostRequestJWTFormdataRequestBody(PostRequestJWTFormdataBody{
WalletMetadata: &oauth.AuthorizationServerMetadata{Issuer: "mario"},
WalletNonce: &wallet_nonce,
})

response, err := ctx.client.PostRequestJWT(cont, PostRequestJWTRequestObject{Did: webDID.String(), Id: requestID, Body: &body})

assert.NoError(t, err)
assert.Equal(t, PostRequestJWT200ApplicationoauthAuthzReqJwtResponse{
Body: bytes.NewReader([]byte(expectedToken)),
ContentLength: 10,
}, response)
})
t.Run("error - not found", func(t *testing.T) {
ctx := newTestClient(t)

response, err := ctx.client.PostRequestJWT(nil, PostRequestJWTRequestObject{Id: "unknownID"})

assert.Nil(t, response)
assert.EqualError(t, err, "invalid_request - request object not found")
})
t.Run("error - clientID does not match request", func(t *testing.T) {
ctx := newTestClient(t)
require.NoError(t, ctx.client.authzRequestObjectStore().Put(requestID, newReqObj("", "")))

response, err := ctx.client.PostRequestJWT(cont, PostRequestJWTRequestObject{Did: holderDID.String(), Id: requestID})

assert.Nil(t, response)
assert.EqualError(t, err, "invalid_request - DID does not match client_id for requestID - client_id does not match request")
})
t.Run("error - wrong request_uri_method used", func(t *testing.T) {
ctx := newTestClient(t)
ro := newReqObj("", "")
ro.RequestURIMethod = "get"
require.NoError(t, ctx.client.authzRequestObjectStore().Put(requestID, ro))

response, err := ctx.client.PostRequestJWT(cont, PostRequestJWTRequestObject{Did: webDID.String(), Id: requestID})

assert.Nil(t, response)
assert.EqualError(t, err, "invalid_request - used request_uri_method 'post' on a 'get' request_uri")
})
t.Run("error - signing failed", func(t *testing.T) {
ctx := newTestClient(t)
ro := newReqObj("https://self-issued.me/v2", "")
require.NoError(t, ctx.client.authzRequestObjectStore().Put(requestID, ro))
ctx.jar.EXPECT().Sign(cont, ro.Claims).Return("", errors.New("fail"))

response, err := ctx.client.PostRequestJWT(cont, PostRequestJWTRequestObject{Did: webDID.String(), Id: requestID})

assert.Nil(t, response)
assert.EqualError(t, err, "server_error - fail - failed to sign authorization RequestObject")
})
}

func TestWrapper_CreateAuthorizationRequest(t *testing.T) {
Expand Down Expand Up @@ -1014,22 +1124,21 @@ func TestWrapper_CreateAuthorizationRequest(t *testing.T) {
require.NoError(t, ctx.client.authzRequestObjectStore().Get(requestURIID, &jarReq))
assert.Equal(t, expectedJarReq, jarReq)
})
t.Run("ok - RequireSignedRequestObject=false", func(t *testing.T) {
t.Run("ok - no server -> RequireSignedRequestObject=false", func(t *testing.T) {
var expectedJarReq jarRequest
ctx := newTestClient(t)
ctx.iamClient.EXPECT().AuthorizationServerMetadata(gomock.Any(), serverDID).Return(&oauth.AuthorizationServerMetadata{AuthorizationEndpoint: serverMetadata.AuthorizationEndpoint}, nil)
ctx.jar.EXPECT().Create(clientDID, &serverDID, gomock.Any()).DoAndReturn(func(client did.DID, server *did.DID, modifier requestObjectModifier) jarRequest {
ctx.jar.EXPECT().Create(clientDID, nil, gomock.Any()).DoAndReturn(func(client did.DID, server *did.DID, modifier requestObjectModifier) jarRequest {
expectedJarReq = createJarRequest(client, server, modifier)
assert.Equal(t, "value", expectedJarReq.Claims.get("custom"))
return expectedJarReq
})

redirectURL, err := ctx.client.CreateAuthorizationRequest(context.Background(), clientDID, &serverDID, modifier)
redirectURL, err := ctx.client.CreateAuthorizationRequest(context.Background(), clientDID, nil, modifier)

assert.NoError(t, err)
assert.Equal(t, "value", redirectURL.Query().Get("custom"))
assert.Equal(t, clientDID.String(), redirectURL.Query().Get(oauth.ClientIDParam))
assert.Equal(t, "get", redirectURL.Query().Get(oauth.RequestURIMethodParam))
assert.Equal(t, "post", redirectURL.Query().Get(oauth.RequestURIMethodParam))
assert.NotEmpty(t, redirectURL.Query().Get(oauth.RequestURIParam))
})
t.Run("error - missing authorization endpoint", func(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion auth/api/iam/jar.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,15 @@ func (j jar) Parse(ctx context.Context, ownDID did.DID, q url.Values) (oauthPara
case "post":
md, err := authorizationServerMetadata(ownDID)
if err != nil {
// DB error
return nil, err
}
rawRequestObject, err = j.auth.IAMClient().RequestObject(ctx, requestURI, "post", md)
if err != nil {
return nil, oauth.OAuth2Error{Code: oauth.InvalidRequestURI, Description: "failed to get Request Object", InternalError: err}
}
default:
return nil, oauth.OAuth2Error{Code: oauth.InvalidRequest, Description: "unsupported request_uri_method"}
return nil, oauth.OAuth2Error{Code: oauth.InvalidRequestURIMethod, Description: "unsupported request_uri_method"}
}
} else {
// require_signed_request_object is true, so we reject anything that isn't
Expand Down
93 changes: 75 additions & 18 deletions auth/api/iam/jar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,61 @@ func TestJar_Parse(t *testing.T) {
require.NoError(t, err)
token := string(bytes)
ctx := newJarTestCtx(t)
t.Run("ok - 'request_uri'", func(t *testing.T) {
ctx.iamClient.EXPECT().RequestObject(context.Background(), "request_uri", "get", nil).Return(token, nil)
ctx.keyResolver.EXPECT().ResolveKeyByID(key.KID(), nil, resolver.AssertionMethod).Return(key.Public(), nil)
t.Run("request_uri_method", func(t *testing.T) {
t.Run("ok - get", func(t *testing.T) {
ctx.iamClient.EXPECT().RequestObject(context.Background(), "request_uri", "get", nil).Return(token, nil)
ctx.keyResolver.EXPECT().ResolveKeyByID(key.KID(), nil, resolver.AssertionMethod).Return(key.Public(), nil)

res, err := ctx.jar.Parse(context.Background(), verifierDID,
map[string][]string{
oauth.ClientIDParam: {holderDID.String()},
oauth.RequestURIParam: {"request_uri"},
})
res, err := ctx.jar.Parse(context.Background(), verifierDID,
map[string][]string{
oauth.ClientIDParam: {holderDID.String()},
oauth.RequestURIParam: {"request_uri"},
oauth.RequestURIMethodParam: {"get"},
})

assert.NoError(t, err)
require.NotNil(t, res)
assert.NoError(t, err)
require.NotNil(t, res)
})
t.Run("ok - param not supported", func(t *testing.T) {
ctx.iamClient.EXPECT().RequestObject(context.Background(), "request_uri", "get", nil).Return(token, nil)
ctx.keyResolver.EXPECT().ResolveKeyByID(key.KID(), nil, resolver.AssertionMethod).Return(key.Public(), nil)

res, err := ctx.jar.Parse(context.Background(), verifierDID,
map[string][]string{
oauth.ClientIDParam: {holderDID.String()},
oauth.RequestURIParam: {"request_uri"},
oauth.RequestURIMethodParam: {""},
})

assert.NoError(t, err)
require.NotNil(t, res)
})
t.Run("ok - post", func(t *testing.T) {
md, _ := authorizationServerMetadata(verifierDID)
ctx.iamClient.EXPECT().RequestObject(context.Background(), "request_uri", "post", md).Return(token, nil)
ctx.keyResolver.EXPECT().ResolveKeyByID(key.KID(), nil, resolver.AssertionMethod).Return(key.Public(), nil)

res, err := ctx.jar.Parse(context.Background(), verifierDID,
map[string][]string{
oauth.ClientIDParam: {holderDID.String()},
oauth.RequestURIParam: {"request_uri"},
oauth.RequestURIMethodParam: {"post"},
})

assert.NoError(t, err)
require.NotNil(t, res)
})
t.Run("error - unsupported method", func(t *testing.T) {
res, err := ctx.jar.Parse(context.Background(), verifierDID,
map[string][]string{
oauth.ClientIDParam: {holderDID.String()},
oauth.RequestURIParam: {"request_uri"},
oauth.RequestURIMethodParam: {"invalid"},
})

assert.EqualError(t, err, "invalid_request_uri_method - unsupported request_uri_method")
assert.Nil(t, res)
})
})
t.Run("ok - 'request'", func(t *testing.T) {
ctx.keyResolver.EXPECT().ResolveKeyByID(key.KID(), nil, resolver.AssertionMethod).Return(key.Public(), nil)
Expand All @@ -138,15 +181,29 @@ func TestJar_Parse(t *testing.T) {
assert.NoError(t, err)
require.NotNil(t, res)
})
t.Run("error - server error", func(t *testing.T) {
ctx.iamClient.EXPECT().RequestObject(context.Background(), "request_uri", "get", nil).Return("", errors.New("server error"))
res, err := ctx.jar.Parse(context.Background(), verifierDID,
map[string][]string{
oauth.RequestURIParam: {"request_uri"},
})
t.Run("server error", func(t *testing.T) {
t.Run("get", func(t *testing.T) {
ctx.iamClient.EXPECT().RequestObject(context.Background(), "request_uri", "get", nil).Return("", errors.New("server error"))
res, err := ctx.jar.Parse(context.Background(), verifierDID,
map[string][]string{
oauth.RequestURIParam: {"request_uri"},
})

requireOAuthError(t, err, oauth.InvalidRequestURI, "failed to get Request Object")
assert.Nil(t, res)
requireOAuthError(t, err, oauth.InvalidRequestURI, "failed to get Request Object")
assert.Nil(t, res)
})
t.Run("post", func(t *testing.T) {
md, _ := authorizationServerMetadata(verifierDID)
ctx.iamClient.EXPECT().RequestObject(context.Background(), "request_uri", "post", md).Return("", errors.New("server error"))
res, err := ctx.jar.Parse(context.Background(), verifierDID,
map[string][]string{
oauth.RequestURIParam: {"request_uri"},
oauth.RequestURIMethodParam: {"post"},
})

requireOAuthError(t, err, oauth.InvalidRequestURI, "failed to get Request Object")
assert.Nil(t, res)
})
})
t.Run("error - both 'request' and 'request_uri'", func(t *testing.T) {
res, err := ctx.jar.Parse(context.Background(), verifierDID,
Expand Down
Loading

0 comments on commit 6e49453

Please sign in to comment.