Skip to content

Commit

Permalink
Refactor and merge Callbacks (#3151)
Browse files Browse the repository at this point in the history
* refactor and merge Callbacks

* cleanup + fix proof JWT

* pr feedback
  • Loading branch information
gerardsn authored May 31, 2024
1 parent 0568158 commit 219beee
Show file tree
Hide file tree
Showing 14 changed files with 406 additions and 600 deletions.
58 changes: 51 additions & 7 deletions auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,19 +240,63 @@ func (r Wrapper) HandleTokenRequest(ctx context.Context, request HandleTokenRequ
}

func (r Wrapper) Callback(ctx context.Context, request CallbackRequestObject) (CallbackResponseObject, error) {
// check id in path
_, err := r.toOwnedDID(ctx, request.Did)
// validate request
// check did in path
ownDID, err := r.toOwnedDID(ctx, request.Did)
if err != nil {
// this is an OAuthError already, will be rendered as 400 but that's fine (for now) for an illegal id
return nil, err
}
// check if state is present and resolves to a client state
if request.Params.State == nil || *request.Params.State == "" {
// without state it is an invalid request, but try to provide as much useful information as possible
if request.Params.Error != nil && *request.Params.Error != "" {
callbackError := callbackRequestToError(request, nil)
callbackError.InternalError = errors.New("missing state parameter")
return nil, callbackError
}
return nil, oauthError(oauth.InvalidRequest, "missing state parameter")
}
oauthSession := new(OAuthSession)
if err = r.oauthClientStateStore().Get(*request.Params.State, oauthSession); err != nil {
return nil, oauthError(oauth.InvalidRequest, "invalid or expired state", err)
}
if !ownDID.Equals(*oauthSession.OwnDID) {
// TODO: this is a manipulated request, add error logging?
return nil, withCallbackURI(oauthError(oauth.InvalidRequest, "session DID does not match request"), oauthSession.redirectURI())
}

// if error is present, redirect error back to application initiating the flow
if request.Params.Error != nil && *request.Params.Error != "" {
return nil, callbackRequestToError(request, oauthSession.redirectURI())
}

// if error is present, delegate call to error handler
if request.Params.Error != nil {
return r.handleCallbackError(request)
// check if code is present
if request.Params.Code == nil || *request.Params.Code == "" {
return nil, withCallbackURI(oauthError(oauth.InvalidRequest, "missing code parameter"), oauthSession.redirectURI())
}

return r.handleCallback(ctx, request)
// continue flow
switch oauthSession.ClientFlow {
case credentialRequestClientFlow:
return r.handleOpenID4VCICallback(ctx, *request.Params.Code, oauthSession)
case accessTokenRequestClientFlow:
return r.handleCallback(ctx, *request.Params.Code, oauthSession)
default:
// programming error, should never happen
return nil, withCallbackURI(oauthError(oauth.ServerError, "unknown client flow for callback: '"+oauthSession.ClientFlow+"'"), oauthSession.redirectURI())
}
}

// callbackRequestToError should only be used if request.params.Error is present
func callbackRequestToError(request CallbackRequestObject, redirectURI *url.URL) oauth.OAuth2Error {
requestErr := oauth.OAuth2Error{
Code: oauth.ErrorCode(*request.Params.Error),
RedirectURI: redirectURI,
}
if request.Params.ErrorDescription != nil {
requestErr.Description = *request.Params.ErrorDescription
}
return requestErr
}

func (r Wrapper) RetrieveAccessToken(_ context.Context, request RetrieveAccessTokenRequestObject) (RetrieveAccessTokenResponseObject, error) {
Expand Down
144 changes: 122 additions & 22 deletions auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,12 +390,11 @@ func TestWrapper_HandleAuthorizeRequest(t *testing.T) {
// handleAuthorizeRequestFromVerifier
_ = ctx.client.storageEngine.GetSessionDatabase().GetStore(oAuthFlowTimeout, oauthClientStateKey...).Put("state", OAuthSession{
// this is the state from the holder that was stored at the creation of the first authorization request to the verifier
ClientID: holderDID.String(),
Scope: "test",
OwnDID: &holderDID,
ClientState: "state",
RedirectURI: "https://example.com/iam/holder/cb",
ResponseType: "code",
ClientID: holderDID.String(),
Scope: "test",
OwnDID: &holderDID,
ClientState: "state",
RedirectURI: "https://example.com/iam/holder/cb",
})
callCtx, _ := user.CreateTestSession(requestContext(nil), holderDID)
clientMetadata := oauth.OAuthClientMetadata{VPFormats: oauth.DefaultOpenIDSupportedFormats()}
Expand Down Expand Up @@ -454,31 +453,40 @@ func TestWrapper_Callback(t *testing.T) {
errorDescription := "error description"
state := "state"
token := "token"
redirectURI, parseErr := url.Parse("https://example.com/iam/holder/cb")
require.NoError(t, parseErr)

session := OAuthSession{
ClientFlow: "access_token_request",
SessionID: "token",
OwnDID: &holderDID,
RedirectURI: "https://example.com/iam/holder/cb",
VerifierDID: &verifierDID,
RedirectURI: redirectURI.String(),
OtherDID: &verifierDID,
TokenEndpoint: "https://example.com/token",
}

t.Run("ok - error flow", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
putState(ctx, "state", session)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
Params: CallbackParams{
State: &state,
Error: &errorCode,
ErrorDescription: &errorDescription,
},
})

require.NoError(t, err)
assert.Equal(t, "https://example.com/iam/holder/cb?error=error&error_description=error+description", res.(Callback302Response).Headers.Location)
var oauthErr oauth.OAuth2Error
require.ErrorAs(t, err, &oauthErr)
assert.Equal(t, oauth.OAuth2Error{
Code: oauth.ErrorCode(errorCode),
Description: errorDescription,
RedirectURI: redirectURI,
}, err)
assert.Nil(t, res)
})
t.Run("ok - success flow", func(t *testing.T) {
ctx := newTestClient(t)
Expand All @@ -487,11 +495,11 @@ func TestWrapper_Callback(t *testing.T) {
putState(ctx, "state", withDPoP)
putToken(ctx, token)
codeVerifier := getState(ctx, state).PKCEParams.Verifier
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil).Times(2)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:123/callback", holderDID, codeVerifier, true).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:holder/callback", holderDID, codeVerifier, true).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
Expand All @@ -511,21 +519,22 @@ func TestWrapper_Callback(t *testing.T) {
t.Run("ok - no DPoP", func(t *testing.T) {
ctx := newTestClient(t)
_ = ctx.client.oauthClientStateStore().Put(state, OAuthSession{
ClientFlow: "access_token_request",
OwnDID: &holderDID,
PKCEParams: generatePKCEParams(),
RedirectURI: "https://example.com/iam/holder/cb",
SessionID: "token",
UseDPoP: false,
VerifierDID: &verifierDID,
OtherDID: &verifierDID,
TokenEndpoint: session.TokenEndpoint,
})
putToken(ctx, token)
codeVerifier := getState(ctx, state).PKCEParams.Verifier
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil).Times(2)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:123/callback", holderDID, codeVerifier, false).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:holder/callback", holderDID, codeVerifier, false).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
Expand All @@ -535,17 +544,108 @@ func TestWrapper_Callback(t *testing.T) {
require.NoError(t, err)
assert.NotNil(t, res)
})
t.Run("unknown did", func(t *testing.T) {
t.Run("err - unknown did", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(false, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(false, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
})

assert.EqualError(t, err, "DID document not managed by this node")
assert.Nil(t, res)
})
t.Run("err - did mismatch", func(t *testing.T) {
ctx := newTestClient(t)
putState(ctx, "state", session)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
},
})

assert.Nil(t, res)
requireOAuthError(t, err, oauth.InvalidRequest, "session DID does not match request")

})
t.Run("err - missing state", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
},
})

requireOAuthError(t, err, oauth.InvalidRequest, "missing state parameter")
})
t.Run("err - error flow but missing state", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: holderDID.String(),
Params: CallbackParams{
Error: &errorCode,
ErrorDescription: &errorDescription,
},
})

requireOAuthError(t, err, oauth.ErrorCode(errorCode), errorDescription)
assert.EqualError(t, err, "error - missing state parameter - error description")
})
t.Run("err - expired state/session", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
},
})

requireOAuthError(t, err, oauth.InvalidRequest, "invalid or expired state")
})
t.Run("err - missing code", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
putState(ctx, "state", session)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: holderDID.String(),
Params: CallbackParams{
State: &state,
},
})

requireOAuthError(t, err, oauth.InvalidRequest, "missing code parameter")
})
t.Run("err - unknown flow", func(t *testing.T) {
ctx := newTestClient(t)
_ = ctx.client.oauthClientStateStore().Put(state, OAuthSession{
ClientFlow: "",
OwnDID: &holderDID,
})
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
},
})

requireOAuthError(t, err, oauth.ServerError, "unknown client flow for callback: ''")
})
}

func TestWrapper_RetrieveAccessToken(t *testing.T) {
Expand Down
Loading

0 comments on commit 219beee

Please sign in to comment.