diff --git a/auth/client/iam/client.go b/auth/client/iam/client.go index 657200e8f..1dc1d9b05 100644 --- a/auth/client/iam/client.go +++ b/auth/client/iam/client.go @@ -23,11 +23,15 @@ import ( "context" "encoding/json" "fmt" + "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/nuts-foundation/nuts-node/crypto" + "github.com/nuts-foundation/nuts-node/vdr/resolver" "io" "net/http" "net/url" "strings" + "time" "github.com/nuts-foundation/go-did/vc" "github.com/nuts-foundation/nuts-node/auth/log" @@ -38,8 +42,9 @@ import ( // HTTPClient holds the server address and other basic settings for the http client type HTTPClient struct { - strictMode bool - httpClient core.HTTPRequestDoer + strictMode bool + keyResolver resolver.KeyResolver + httpClient core.HTTPRequestDoer } // OAuthAuthorizationServerMetadata retrieves the OAuth authorization server metadata for the given oauth issuer. @@ -232,10 +237,8 @@ func (hb HTTPClient) OpenIDConfiguration(ctx context.Context, issuerURL string) return nil, err } var configuration oauth.OpenIDConfiguration - request, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL.String(), nil) - if err != nil { - return nil, err - } + // url already checked + request, _ := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL.String(), nil) response, err := hb.httpClient.Do(request.WithContext(ctx)) if err != nil { return nil, fmt.Errorf("failed to call endpoint: %w", err) @@ -247,8 +250,8 @@ func (hb HTTPClient) OpenIDConfiguration(ctx context.Context, issuerURL string) if data, err = core.LimitedReadAll(response.Body); err != nil { return nil, fmt.Errorf("unable to read response: %w", err) } - // todo check kid against something? get keys from somewhere? (issuerURL to keys) - token, err := jwt.Parse(data, jwt.WithVerify(false)) + // kid is checked against did resolver + token, err := jwt.Parse(data, jwt.WithKeyProvider(hb.KeyProvider()), jwt.WithAcceptableSkew(5*time.Second)) if err != nil { return nil, fmt.Errorf("unable to parse response: %w", err) } @@ -259,13 +262,28 @@ func (hb HTTPClient) OpenIDConfiguration(ctx context.Context, issuerURL string) // hack, broken iat claims["iat"] = token.IssuedAt().Unix() asJSON, _ := json.Marshal(claims) - println("TOKEN ", string(asJSON)) if err = json.Unmarshal(asJSON, &configuration); err != nil { return nil, fmt.Errorf("unable to unmarshal response: %w", err) } return &configuration, err } +func (hb HTTPClient) KeyProvider() jws.KeyProviderFunc { + return func(context context.Context, keySink jws.KeySink, signature *jws.Signature, message *jws.Message) error { + keyID := signature.ProtectedHeaders().KeyID() + publicKey, err := hb.keyResolver.ResolveKeyByID(keyID, nil, resolver.AssertionMethod) + if err != nil { + return fmt.Errorf("failed to resolve key (kid=%s): %w", keyID, err) + } + alg, err := crypto.SignatureAlgorithm(publicKey) + if err != nil { + return fmt.Errorf("failed to resolve key (kid=%s): %w", keyID, err) + } + keySink.Key(alg, publicKey) + return nil + } +} + // CredentialRequest represents ths request to fetch a credential, the JSON object holds the proof as // CredentialRequestProof. type CredentialRequest struct { diff --git a/auth/client/iam/client_test.go b/auth/client/iam/client_test.go index a2e1bdf4a..3c37deed7 100644 --- a/auth/client/iam/client_test.go +++ b/auth/client/iam/client_test.go @@ -20,10 +20,21 @@ package iam import ( "context" + "crypto" + "crypto/ecdsa" + "encoding/json" + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/nuts-foundation/go-did/did" + "github.com/nuts-foundation/nuts-node/audit" + nutsCrypto "github.com/nuts-foundation/nuts-node/crypto" + test2 "github.com/nuts-foundation/nuts-node/crypto/test" + "github.com/nuts-foundation/nuts-node/vdr/resolver" "net/http" "net/http/httptest" "net/url" "testing" + "time" ssi "github.com/nuts-foundation/go-did" "github.com/nuts-foundation/go-did/vc" @@ -173,6 +184,83 @@ func TestHTTPClient_ClientMetadata(t *testing.T) { }) } +func TestHTTPClient_OpenIDConfiguration(t *testing.T) { + ctx := context.Background() + configuration := oauth.OpenIDConfiguration{ + Issuer: "issuer", + } + + // create jwt + createToken := func(t *testing.T, client *HTTPClient) string { + testKey := client.keyResolver.(testKeyResolver).key + claims := make(map[string]interface{}) + asJson, _ := json.Marshal(configuration) + _ = json.Unmarshal(asJson, &claims) + alg, _ := nutsCrypto.SignatureAlgorithm(testKey.Public()) + headers := map[string]interface{}{jws.AlgorithmKey: alg, jws.KeyIDKey: "test"} + token, err := nutsCrypto.SignJWT(audit.TestContext(), testKey, alg, claims, headers) + require.NoError(t, err) + return token + } + + t.Run("ok", func(t *testing.T) { + handler := http2.Handler{StatusCode: http.StatusOK} + tlsServer, client := testServerAndClient(t, &handler) + handler.ResponseData = createToken(t, client) + + response, err := client.OpenIDConfiguration(ctx, tlsServer.URL) + + require.NoError(t, err) + require.NotNil(t, response) + assert.Equal(t, configuration, *response) + require.NotNil(t, handler.Request) + }) + t.Run("error - invalid url", func(t *testing.T) { + handler := http2.Handler{StatusCode: http.StatusOK} + _, client := testServerAndClient(t, &handler) + handler.ResponseData = createToken(t, client) + + _, err := client.OpenIDConfiguration(ctx, ":") + + require.Error(t, err) + assert.EqualError(t, err, "parse \":\": missing protocol scheme") + }) + t.Run("error - error return", func(t *testing.T) { + handler := http2.Handler{StatusCode: http.StatusInternalServerError} + tlsServer, client := testServerAndClient(t, &handler) + + response, err := client.OpenIDConfiguration(ctx, tlsServer.URL) + + require.Error(t, err) + require.Nil(t, response) + assert.EqualError(t, err, "server returned HTTP 500 (expected: 200)") + }) + t.Run("error - not a signed jwt", func(t *testing.T) { + handler := http2.Handler{StatusCode: http.StatusOK, ResponseData: ""} + tlsServer, client := testServerAndClient(t, &handler) + + response, err := client.OpenIDConfiguration(ctx, tlsServer.URL) + + require.Error(t, err) + require.Nil(t, response) + assert.EqualError(t, err, "unable to parse response: failed to parse jws: invalid byte sequence") + }) + t.Run("error - unknown key", func(t *testing.T) { + otherClient := &HTTPClient{ + keyResolver: newTestKeyResolver(), + } + handler := http2.Handler{StatusCode: http.StatusOK} + tlsServer, client := testServerAndClient(t, &handler) + handler.ResponseData = createToken(t, otherClient) + + response, err := client.OpenIDConfiguration(ctx, tlsServer.URL) + + require.Error(t, err) + require.Nil(t, response) + assert.EqualError(t, err, "unable to parse response: could not verify message using any of the signatures or keys") + }) +} + func TestHTTPClient_PostError(t *testing.T) { redirectReturn := oauth.Redirect{ RedirectURI: "http://test.test", @@ -299,13 +387,6 @@ func TestHTTPClient_RequestObjectPost(t *testing.T) { }) } -func testServerAndClient(t *testing.T, handler http.Handler) (*httptest.Server, *HTTPClient) { - tlsServer := http2.TestTLSServer(t, handler) - return tlsServer, &HTTPClient{ - httpClient: tlsServer.Client(), - } -} - func TestHTTPClient_doGet(t *testing.T) { t.Run("error - non 200 return value", func(t *testing.T) { handler := http2.Handler{StatusCode: http.StatusBadRequest} @@ -333,3 +414,31 @@ func TestHTTPClient_doGet(t *testing.T) { assert.Error(t, err) }) } + +func newTestKeyResolver() resolver.KeyResolver { + return testKeyResolver{ + kid: uuid.NewString(), + key: test2.GenerateECKey(), + } +} + +type testKeyResolver struct { + kid string + key *ecdsa.PrivateKey +} + +func (t testKeyResolver) ResolveKeyByID(keyID string, validAt *time.Time, relationType resolver.RelationType) (crypto.PublicKey, error) { + return t.key.Public(), nil +} + +func (t testKeyResolver) ResolveKey(id did.DID, validAt *time.Time, relationType resolver.RelationType) (string, crypto.PublicKey, error) { + return t.kid, t.key.Public(), nil +} + +func testServerAndClient(t *testing.T, handler http.Handler) (*httptest.Server, *HTTPClient) { + tlsServer := http2.TestTLSServer(t, handler) + return tlsServer, &HTTPClient{ + httpClient: tlsServer.Client(), + keyResolver: newTestKeyResolver(), + } +} diff --git a/auth/client/iam/openid4vp_test.go b/auth/client/iam/openid4vp_test.go index 80378694e..db9fb9df8 100644 --- a/auth/client/iam/openid4vp_test.go +++ b/auth/client/iam/openid4vp_test.go @@ -234,6 +234,27 @@ func TestIAMClient_AuthorizationServerMetadata(t *testing.T) { }) } +func TestIAMClient_OpenIDConfiguration(t *testing.T) { + t.Run("ok", func(t *testing.T) { + ctx := createClientServerTestContext(t) + + metadata, err := ctx.client.OpenIDConfiguration(context.Background(), ctx.tlsServer.URL) + + require.NoError(t, err) + require.NotNil(t, metadata) + assert.Equal(t, *ctx.authzServerMetadata, *metadata) + }) + t.Run("error - failed to get metadata", func(t *testing.T) { + ctx := createClientServerTestContext(t) + ctx.metadata = nil + + _, err := ctx.client.OpenIDConfiguration(context.Background(), ctx.tlsServer.URL) + + require.Error(t, err) + assert.EqualError(t, err, "failed to retrieve remote OAuth Authorization Server metadata: server returned HTTP 404 (expected: 200)") + }) +} + func TestRelyingParty_RequestRFC021AccessToken(t *testing.T) { const subjectID = "subby" primaryWalletDID := did.MustParseDID("did:primary:123")