diff --git a/auth/client/iam/client.go b/auth/client/iam/client.go index bfcff021df..0020fea550 100644 --- a/auth/client/iam/client.go +++ b/auth/client/iam/client.go @@ -136,7 +136,7 @@ func (hb HTTPClient) RequestObjectByGet(ctx context.Context, requestURI string) return "", httpErr } - data, err := core.LimitedReadAll(response.Body) + data, err := io.ReadAll(response.Body) if err != nil { return "", fmt.Errorf("unable to read response: %w", err) } @@ -161,7 +161,7 @@ func (hb HTTPClient) RequestObjectByPost(ctx context.Context, requestURI string, return "", httpErr } - data, err := core.LimitedReadAll(response.Body) + data, err := io.ReadAll(response.Body) if err != nil { return "", fmt.Errorf("unable to read response: %w", err) } @@ -206,7 +206,7 @@ func (hb HTTPClient) AccessToken(ctx context.Context, tokenEndpoint string, data } var responseData []byte - if responseData, err = core.LimitedReadAll(response.Body); err != nil { + if responseData, err = io.ReadAll(response.Body); err != nil { return token, fmt.Errorf("unable to read response: %w", err) } if err = json.Unmarshal(responseData, &token); err != nil { @@ -271,7 +271,7 @@ func (hb HTTPClient) OpenIDConfiguration(ctx context.Context, issuerURL string) return nil, httpErr } var data []byte - if data, err = core.LimitedReadAll(response.Body); err != nil { + if data, err = io.ReadAll(response.Body); err != nil { return nil, fmt.Errorf("unable to read response: %w", err) } // kid is checked against did resolver @@ -407,7 +407,7 @@ func (hb HTTPClient) doRequest(ctx context.Context, request *http.Request, targe var data []byte - if data, err = core.LimitedReadAll(response.Body); err != nil { + if data, err = io.ReadAll(response.Body); err != nil { return fmt.Errorf("unable to read response: %w", err) } if err = json.Unmarshal(data, &target); err != nil { diff --git a/auth/services/oauth/relying_party.go b/auth/services/oauth/relying_party.go index 1b9b7819f2..754f55bf11 100644 --- a/auth/services/oauth/relying_party.go +++ b/auth/services/oauth/relying_party.go @@ -22,12 +22,11 @@ import ( "context" "crypto/tls" "fmt" - "github.com/lestrrat-go/jwx/v2/jwt" - "net/http" "net/url" "strings" "time" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/nuts-node/auth/api/auth/v1/client" "github.com/nuts-foundation/nuts-node/auth/oauth" @@ -35,6 +34,7 @@ import ( "github.com/nuts-foundation/nuts-node/core" nutsCrypto "github.com/nuts-foundation/nuts-node/crypto" "github.com/nuts-foundation/nuts-node/didman" + strictHttp "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/vcr/credential" "github.com/nuts-foundation/nuts-node/vcr/holder" "github.com/nuts-foundation/nuts-node/vdr/resolver" @@ -110,12 +110,7 @@ func (s *relyingParty) RequestRFC003AccessToken(ctx context.Context, jwtGrantTok if s.strictMode && strings.ToLower(authorizationServerEndpoint.Scheme) != "https" { return nil, fmt.Errorf("authorization server endpoint must be HTTPS when in strict mode: %s", authorizationServerEndpoint.String()) } - httpClient := &http.Client{} - if s.httpClientTLS != nil { - httpClient.Transport = &http.Transport{ - TLSClientConfig: s.httpClientTLS, - } - } + httpClient := strictHttp.NewWithTLSConfig(s.httpClientTimeout, s.httpClientTLS) authClient, err := client.NewHTTPClient("", s.httpClientTimeout, client.WithHTTPClient(httpClient), client.WithRequestEditorFn(core.UserAgentRequestEditor)) if err != nil { return nil, fmt.Errorf("unable to create HTTP client: %w", err) diff --git a/core/http_client.go b/core/http_client.go index dc9aff03bb..53dbc01918 100644 --- a/core/http_client.go +++ b/core/http_client.go @@ -31,21 +31,6 @@ import ( // If the response body is longer than this, it will be truncated. const HttpResponseBodyLogClipAt = 200 -// DefaultMaxHttpResponseSize is a default maximum size of an HTTP response body that will be read. -// Very large or unbounded HTTP responses can cause denial-of-service, so it's good to limit how much data is read. -// This of course heavily depends on the use case, but 1MB is a reasonable default. -const DefaultMaxHttpResponseSize = 1024 * 1024 - -// LimitedReadAll reads the given reader until the DefaultMaxHttpResponseSize is reached. -// It returns an error if more data is available than DefaultMaxHttpResponseSize. -func LimitedReadAll(reader io.Reader) ([]byte, error) { - result, err := io.ReadAll(io.LimitReader(reader, DefaultMaxHttpResponseSize+1)) - if len(result) > DefaultMaxHttpResponseSize { - return nil, fmt.Errorf("data to read exceeds max. safety limit of %d bytes", DefaultMaxHttpResponseSize) - } - return result, err -} - // HttpError describes an error returned when invoking a remote server. type HttpError struct { error @@ -63,7 +48,7 @@ func TestResponseCode(expectedStatusCode int, response *http.Response) error { // It logs using the given logger, unless nil is passed. func TestResponseCodeWithLog(expectedStatusCode int, response *http.Response, log *logrus.Entry) error { if response.StatusCode != expectedStatusCode { - responseData, _ := LimitedReadAll(response.Body) + responseData, _ := io.ReadAll(response.Body) if log != nil { // Cut off the response body to 100 characters max to prevent logging of large responses responseBodyString := string(responseData) @@ -104,16 +89,18 @@ func (w httpRequestDoerAdapter) Do(req *http.Request) (*http.Response, error) { return w.fn(req) } -// CreateHTTPClient creates a new HTTP client with the given client configuration. +// CreateHTTPInternalClient creates a new HTTP client with the given client configuration. +// This client is to be used for internal API calls (CMDs and such) // The result HTTPRequestDoer can be supplied to OpenAPI generated clients for executing requests. // This does not use the generated client options for e.g. authentication, // because each generated OpenAPI client reimplements the client options using structs, // which makes them incompatible with each other, making it impossible to use write generic client code for common traits like authorization. // If the given authorization token builder is non-nil, it calls it and passes the resulting token as bearer token with requests. -func CreateHTTPClient(cfg ClientConfig, generator AuthorizationTokenGenerator) (HTTPRequestDoer, error) { +func CreateHTTPInternalClient(cfg ClientConfig, generator AuthorizationTokenGenerator) (HTTPRequestDoer, error) { var result *httpRequestDoerAdapter client := &http.Client{} client.Timeout = cfg.Timeout + result = &httpRequestDoerAdapter{ fn: client.Do, } @@ -149,9 +136,9 @@ func CreateHTTPClient(cfg ClientConfig, generator AuthorizationTokenGenerator) ( return result, nil } -// MustCreateHTTPClient is like CreateHTTPClient but panics if it returns an error. -func MustCreateHTTPClient(cfg ClientConfig, generator AuthorizationTokenGenerator) HTTPRequestDoer { - client, err := CreateHTTPClient(cfg, generator) +// MustCreateInternalHTTPClient is like CreateHTTPInternalClient but panics if it returns an error. +func MustCreateInternalHTTPClient(cfg ClientConfig, generator AuthorizationTokenGenerator) HTTPRequestDoer { + client, err := CreateHTTPInternalClient(cfg, generator) if err != nil { panic(err) } diff --git a/core/http_client_test.go b/core/http_client_test.go index 891364b279..1bc9e31835 100644 --- a/core/http_client_test.go +++ b/core/http_client_test.go @@ -44,7 +44,7 @@ func TestHTTPClient(t *testing.T) { t.Run("no auth token", func(t *testing.T) { authToken = "" - client, err := CreateHTTPClient(ClientConfig{}, nil) + client, err := CreateHTTPInternalClient(ClientConfig{}, nil) require.NoError(t, err) req, _ := stdHttp.NewRequest(stdHttp.MethodGet, server.URL, nil) @@ -56,7 +56,7 @@ func TestHTTPClient(t *testing.T) { }) t.Run("with auth token", func(t *testing.T) { authToken = "" - client, err := CreateHTTPClient(ClientConfig{ + client, err := CreateHTTPInternalClient(ClientConfig{ Token: "test", }, nil) require.NoError(t, err) @@ -69,7 +69,7 @@ func TestHTTPClient(t *testing.T) { assert.Equal(t, "Bearer test", authToken) }) t.Run("with custom token builder", func(t *testing.T) { - client, err := CreateHTTPClient(ClientConfig{}, newLegacyTokenGenerator("test")) + client, err := CreateHTTPInternalClient(ClientConfig{}, newLegacyTokenGenerator("test")) require.NoError(t, err) req, _ := stdHttp.NewRequest(stdHttp.MethodGet, server.URL, nil) @@ -80,7 +80,7 @@ func TestHTTPClient(t *testing.T) { assert.Equal(t, "Bearer test", authToken) }) t.Run("with errored token builder", func(t *testing.T) { - client, err := CreateHTTPClient(ClientConfig{}, newErrorTokenBuilder()) + client, err := CreateHTTPInternalClient(ClientConfig{}, newErrorTokenBuilder()) require.NoError(t, err) req, _ := stdHttp.NewRequest(stdHttp.MethodGet, server.URL, nil) @@ -162,20 +162,3 @@ func newErrorTokenBuilder() AuthorizationTokenGenerator { return "", errors.New("error") } } - -func TestLimitedReadAll(t *testing.T) { - t.Run("less than limit", func(t *testing.T) { - data := strings.Repeat("a", 10) - result, err := LimitedReadAll(strings.NewReader(data)) - - assert.NoError(t, err) - assert.Equal(t, []byte(data), result) - }) - t.Run("more than limit", func(t *testing.T) { - data := strings.Repeat("a", DefaultMaxHttpResponseSize+1) - result, err := LimitedReadAll(strings.NewReader(data)) - - assert.EqualError(t, err, "data to read exceeds max. safety limit of 1048576 bytes") - assert.Nil(t, result) - }) -} diff --git a/crypto/storage/external/client.go b/crypto/storage/external/client.go index e810b8a2df..3bda879270 100644 --- a/crypto/storage/external/client.go +++ b/crypto/storage/external/client.go @@ -30,6 +30,7 @@ import ( "github.com/nuts-foundation/nuts-node/core" "github.com/nuts-foundation/nuts-node/crypto/storage/spi" "github.com/nuts-foundation/nuts-node/crypto/util" + safeHttp "github.com/nuts-foundation/nuts-node/http/client" ) // StorageType is the name of this storage type, used in health check reports and configuration. @@ -82,7 +83,7 @@ func NewAPIClient(config Config) (spi.Storage, error) { if _, err := url.ParseRequestURI(config.Address); err != nil { return nil, err } - client, _ := NewClientWithResponses(config.Address, WithHTTPClient(&http.Client{Timeout: config.Timeout})) + client, _ := NewClientWithResponses(config.Address, WithHTTPClient(safeHttp.New(config.Timeout))) return &APIClient{httpClient: client}, nil } diff --git a/didman/api/v1/client.go b/didman/api/v1/client.go index 6b2b1cd877..9a888b8bf1 100644 --- a/didman/api/v1/client.go +++ b/didman/api/v1/client.go @@ -32,7 +32,7 @@ type HTTPClient struct { } func (hb HTTPClient) client() ClientInterface { - response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateHTTPClient(hb.ClientConfig, hb.TokenGenerator))) + response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateInternalHTTPClient(hb.ClientConfig, hb.TokenGenerator))) if err != nil { panic(err) } diff --git a/http/client/caching.go b/http/client/caching.go index 5226e4674c..f9ff63f7da 100644 --- a/http/client/caching.go +++ b/http/client/caching.go @@ -32,8 +32,8 @@ import ( // DefaultCachingTransport is a http.RoundTripper that can be used as a default transport for HTTP clients. // If caching is enabled, it will cache responses according to RFC 7234. -// If caching is disabled, it will behave like http.DefaultTransport. -var DefaultCachingTransport = http.DefaultTransport +// If caching is disabled, it will behave like our safe http.DefaultTransport. +var DefaultCachingTransport http.RoundTripper // maxCacheTime is the maximum time responses are cached. // Even if the server responds with a longer cache time, responses are never cached longer than maxCacheTime. diff --git a/http/client/client.go b/http/client/client.go index 5ec00bef13..072e4da6f7 100644 --- a/http/client/client.go +++ b/http/client/client.go @@ -19,23 +19,59 @@ package client import ( + "bytes" "crypto/tls" "errors" + "fmt" + "github.com/nuts-foundation/nuts-node/core" + "io" "net/http" "time" ) +// SafeHttpTransport is a http.Transport that can be used as a default transport for HTTP clients. +var SafeHttpTransport *http.Transport + func init() { - httpTransport := http.DefaultTransport.(*http.Transport) - if httpTransport.TLSClientConfig == nil { - httpTransport.TLSClientConfig = &tls.Config{} + SafeHttpTransport = http.DefaultTransport.(*http.Transport).Clone() + if SafeHttpTransport.TLSClientConfig == nil { + SafeHttpTransport.TLSClientConfig = &tls.Config{} } - httpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12 + SafeHttpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12 + // to prevent slow responses from public clients to have significant impact (default was unlimited) + SafeHttpTransport.MaxConnsPerHost = 5 + + DefaultCachingTransport = SafeHttpTransport } // StrictMode is a flag that can be set to true to enable strict mode for the HTTP client. var StrictMode bool +// DefaultMaxHttpResponseSize is a default maximum size of an HTTP response body that will be read. +// Very large or unbounded HTTP responses can cause denial-of-service, so it's good to limit how much data is read. +// This of course heavily depends on the use case, but 1MB is a reasonable default. +const DefaultMaxHttpResponseSize = 1024 * 1024 + +// limitedReadAll reads the given reader until the DefaultMaxHttpResponseSize is reached. +// It returns an error if more data is available than DefaultMaxHttpResponseSize. +func limitedReadAll(reader io.Reader) ([]byte, error) { + result, err := io.ReadAll(io.LimitReader(reader, DefaultMaxHttpResponseSize+1)) + if len(result) > DefaultMaxHttpResponseSize { + return nil, fmt.Errorf("data to read exceeds max. safety limit of %d bytes", DefaultMaxHttpResponseSize) + } + return result, err +} + +// New creates a new HTTP client with the given timeout. +func New(timeout time.Duration) *StrictHTTPClient { + return &StrictHTTPClient{ + client: &http.Client{ + Transport: SafeHttpTransport, + Timeout: timeout, + }, + } +} + // NewWithCache creates a new HTTP client with the given timeout. // It uses the DefaultCachingTransport as the underlying transport. func NewWithCache(timeout time.Duration) *StrictHTTPClient { @@ -51,7 +87,7 @@ func NewWithCache(timeout time.Duration) *StrictHTTPClient { // It copies the http.DefaultTransport and sets the TLSClientConfig to the given tls.Config. // As such, it can't be used in conjunction with the CachingRoundTripper. func NewWithTLSConfig(timeout time.Duration, tlsConfig *tls.Config) *StrictHTTPClient { - transport := http.DefaultTransport.(*http.Transport).Clone() + transport := SafeHttpTransport.Clone() transport.TLSClientConfig = tlsConfig return &StrictHTTPClient{ client: &http.Client{ @@ -69,5 +105,17 @@ func (s *StrictHTTPClient) Do(req *http.Request) (*http.Response, error) { if StrictMode && req.URL.Scheme != "https" { return nil, errors.New("strictmode is enabled, but request is not over HTTPS") } - return s.client.Do(req) + req.Header.Set("User-Agent", core.UserAgent()) + result, err := s.client.Do(req) + if err != nil { + return nil, err + } + if result.Body != nil { + body, err := limitedReadAll(result.Body) + if err != nil { + return nil, err + } + result.Body = io.NopCloser(bytes.NewReader(body)) + } + return result, nil } diff --git a/http/client/client_test.go b/http/client/client_test.go index aad3b7bd1b..76d5c3d401 100644 --- a/http/client/client_test.go +++ b/http/client/client_test.go @@ -20,8 +20,14 @@ package client import ( "crypto/tls" + "fmt" "github.com/stretchr/testify/assert" - stdHttp "net/http" + "github.com/stretchr/testify/require" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" "testing" "time" ) @@ -34,7 +40,7 @@ func TestStrictHTTPClient(t *testing.T) { StrictMode = true client := NewWithCache(time.Second) - httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + httpRequest, _ := http.NewRequest("GET", "http://example.com", nil) _, err := client.Do(httpRequest) assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") @@ -46,7 +52,7 @@ func TestStrictHTTPClient(t *testing.T) { StrictMode = false client := NewWithCache(time.Second) - httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + httpRequest, _ := http.NewRequest("GET", "http://example.com", nil) _, err := client.Do(httpRequest) assert.NoError(t, err) @@ -60,7 +66,7 @@ func TestStrictHTTPClient(t *testing.T) { StrictMode = true client := NewWithCache(time.Second) - httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + httpRequest, _ := http.NewRequest("GET", "http://example.com", nil) _, err := client.Do(httpRequest) assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") @@ -70,7 +76,7 @@ func TestStrictHTTPClient(t *testing.T) { client := NewWithTLSConfig(time.Second, &tls.Config{ InsecureSkipVerify: true, }) - ts := client.client.Transport.(*stdHttp.Transport) + ts := client.client.Transport.(*http.Transport) assert.True(t, ts.TLSClientConfig.InsecureSkipVerify) }) }) @@ -80,10 +86,114 @@ func TestStrictHTTPClient(t *testing.T) { StrictMode = true client := NewWithCache(time.Second) - httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + httpRequest, _ := http.NewRequest("GET", "http://example.com", nil) _, err := client.Do(httpRequest) assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") assert.Equal(t, 0, rt.invocations) }) } + +func TestLimitedReadAll(t *testing.T) { + t.Run("less than limit", func(t *testing.T) { + data := strings.Repeat("a", 10) + result, err := limitedReadAll(strings.NewReader(data)) + + assert.NoError(t, err) + assert.Equal(t, []byte(data), result) + }) + t.Run("more than limit", func(t *testing.T) { + data := strings.Repeat("a", DefaultMaxHttpResponseSize+1) + result, err := limitedReadAll(strings.NewReader(data)) + + assert.EqualError(t, err, "data to read exceeds max. safety limit of 1048576 bytes") + assert.Nil(t, result) + }) +} + +func TestMaxConns(t *testing.T) { + oldStrictMode := StrictMode + StrictMode = false + t.Cleanup(func() { StrictMode = oldStrictMode }) + // Our safe http Transport has MaxConnsPerHost = 5 + // if we request 6 resources multiple times, we expect a max connection usage of 5 + + // counter for the number of concurrent requests + var counter atomic.Int32 + + // create a test server with 6 different url handlers + handler := http.NewServeMux() + createHandler := func(id int) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + counter.Add(1) + assert.True(t, counter.Load() < 6) + _, _ = w.Write([]byte(fmt.Sprintf("%d", id))) + time.Sleep(time.Millisecond) // to allow for some parallel requests + counter.Add(-1) + } + } + handler.HandleFunc("/1", createHandler(1)) + handler.HandleFunc("/2", createHandler(2)) + handler.HandleFunc("/3", createHandler(3)) + handler.HandleFunc("/4", createHandler(4)) + handler.HandleFunc("/5", createHandler(5)) + handler.HandleFunc("/6", createHandler(6)) + + server := httptest.NewServer(handler) + defer server.Close() + client := New(time.Second) + + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + request, _ := http.NewRequest("GET", fmt.Sprintf("%s/%d", server.URL, i%6), nil) + _, _ = client.Do(request) + }() + } + + wg.Wait() +} + +func TestCaching(t *testing.T) { + oldStrictMode := StrictMode + StrictMode = false + t.Cleanup(func() { StrictMode = oldStrictMode }) + // counter for the number of concurrent requests + var total atomic.Int32 + + // create a test server with 6 different url handlers + handler := http.NewServeMux() + createHandler := func(id int) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + total.Add(1) + w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", 5)) + _, _ = w.Write([]byte(fmt.Sprintf("%d", id))) + } + } + handler.HandleFunc("/1", createHandler(1)) + + server := httptest.NewServer(handler) + defer server.Close() + DefaultCachingTransport = NewCachingTransport(SafeHttpTransport, 1024*1024) + client := NewWithCache(time.Second) + + // fill cache + request, _ := http.NewRequest("GET", fmt.Sprintf("%s/1", server.URL), nil) + _, err := client.Do(request) + require.NoError(t, err) + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req, _ := http.NewRequest("GET", fmt.Sprintf("%s/1", server.URL), nil) + _, _ = client.Do(req) + }() + } + wg.Wait() + + assert.Equal(t, int32(1), total.Load()) +} diff --git a/http/engine.go b/http/engine.go index fa3c1f6739..790f3da4b7 100644 --- a/http/engine.go +++ b/http/engine.go @@ -99,7 +99,7 @@ func (h *Engine) configureClient(serverConfig core.ServerConfig) { client.StrictMode = serverConfig.Strictmode // Configure the HTTP caching client, if enabled. Set it to http.DefaultTransport so it can be used by any subsystem. if h.config.ResponseCacheSize > 0 { - client.DefaultCachingTransport = client.NewCachingTransport(http.DefaultTransport, h.config.ResponseCacheSize) + client.DefaultCachingTransport = client.NewCachingTransport(client.SafeHttpTransport, h.config.ResponseCacheSize) } } diff --git a/network/api/v1/client.go b/network/api/v1/client.go index 66469d2651..a1d06a172a 100644 --- a/network/api/v1/client.go +++ b/network/api/v1/client.go @@ -129,7 +129,7 @@ func (hb HTTPClient) Reprocess(contentType string) error { } func (hb HTTPClient) client() ClientInterface { - response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateHTTPClient(hb.ClientConfig, hb.TokenGenerator))) + response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateInternalHTTPClient(hb.ClientConfig, hb.TokenGenerator))) if err != nil { panic(err) } diff --git a/pki/denylist.go b/pki/denylist.go index 1f836b9a0e..e1a3039fb1 100644 --- a/pki/denylist.go +++ b/pki/denylist.go @@ -219,6 +219,7 @@ func (b *denylistImpl) Subscribe(f func()) { // download retrieves and parses the denylist func (b *denylistImpl) download() ([]byte, error) { // Make an HTTP GET request for the denylist URL + // We do not use our safe http client here since we're downloading from our own resource httpClient := http.Client{Timeout: syncTimeout} response, err := httpClient.Get(b.url) if err != nil { diff --git a/pki/validator.go b/pki/validator.go index 3043668a3d..e8e70dbd09 100644 --- a/pki/validator.go +++ b/pki/validator.go @@ -88,6 +88,7 @@ func newRevocationList(cert *x509.Certificate) *revocationList { // newValidator returns a new PKI (crl/denylist) validator. func newValidator(config Config) (*validator, error) { + // we do not use our safe http client here since we're downloading from a trusted resource return newValidatorWithHTTPClient(config, &http.Client{Timeout: syncTimeout}) } diff --git a/vcr/api/vcr/v2/client.go b/vcr/api/vcr/v2/client.go index bb77afb380..cbf3b208e3 100644 --- a/vcr/api/vcr/v2/client.go +++ b/vcr/api/vcr/v2/client.go @@ -36,7 +36,7 @@ type HTTPClient struct { } func (hb HTTPClient) client() ClientInterface { - response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateHTTPClient(hb.ClientConfig, hb.TokenGenerator))) + response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateInternalHTTPClient(hb.ClientConfig, hb.TokenGenerator))) if err != nil { panic(err) } diff --git a/vcr/openid4vci/identifiers.go b/vcr/openid4vci/identifiers.go index ffacf3cb3c..37c2f4bed3 100644 --- a/vcr/openid4vci/identifiers.go +++ b/vcr/openid4vci/identifiers.go @@ -23,6 +23,7 @@ import ( "fmt" "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/nuts-node/core" + "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/vcr/log" "github.com/nuts-foundation/nuts-node/vdr/resolver" "net/http" @@ -143,12 +144,8 @@ func (t tlsIdentifierResolver) resolveFromCertificate(id did.DID) (string, error } // Resolve URLs - httpTransport := http.DefaultTransport.(*http.Transport).Clone() - httpTransport.TLSClientConfig = t.config - httpClient := &http.Client{ - Timeout: 5 * time.Second, - Transport: httpTransport, - } + httpClient := client.NewWithTLSConfig(5*time.Second, t.config) + for _, candidateURL := range candidateURLs { issuerIdentifier := core.JoinURLPaths(candidateURL, "n2n", "identity", url.PathEscape(id.String())) err := t.testIdentifier(issuerIdentifier, httpClient) @@ -161,9 +158,13 @@ func (t tlsIdentifierResolver) resolveFromCertificate(id did.DID) (string, error return "", nil } -func (t tlsIdentifierResolver) testIdentifier(issuerIdentifier string, httpClient *http.Client) error { +func (t tlsIdentifierResolver) testIdentifier(issuerIdentifier string, httpClient core.HTTPRequestDoer) error { metadataURL := core.JoinURLPaths(issuerIdentifier, CredentialIssuerMetadataWellKnownPath) - httpResponse, err := httpClient.Head(metadataURL) + request, err := http.NewRequest(http.MethodHead, metadataURL, nil) + if err != nil { + return err + } + httpResponse, err := httpClient.Do(request) if err != nil { return err } diff --git a/vcr/openid4vci/issuer_client.go b/vcr/openid4vci/issuer_client.go index 91ab1eafd9..c355aa96d5 100644 --- a/vcr/openid4vci/issuer_client.go +++ b/vcr/openid4vci/issuer_client.go @@ -28,6 +28,7 @@ import ( "github.com/nuts-foundation/nuts-node/auth/oauth" "github.com/nuts-foundation/nuts-node/core" "github.com/nuts-foundation/nuts-node/vcr/log" + "io" "net/http" "net/http/httptrace" "net/url" @@ -167,7 +168,7 @@ func httpDo(httpClient core.HTTPRequestDoer, httpRequest *http.Request, result i return fmt.Errorf("http request error: %w", err) } defer httpResponse.Body.Close() - responseBody, err := core.LimitedReadAll(httpResponse.Body) + responseBody, err := io.ReadAll(httpResponse.Body) if err != nil { return fmt.Errorf("read error (%s): %w", httpRequest.URL, err) } diff --git a/vcr/revocation/statuslist2021_verifier.go b/vcr/revocation/statuslist2021_verifier.go index a1eeb89a21..4071ac54cd 100644 --- a/vcr/revocation/statuslist2021_verifier.go +++ b/vcr/revocation/statuslist2021_verifier.go @@ -22,6 +22,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "strconv" "time" @@ -198,7 +199,7 @@ func (cs *StatusList2021) download(statusListCredential string) (*vc.VerifiableC Debug("Failed to close response body") } }() - body, err := core.LimitedReadAll(res.Body) // default minimum size is 16kb (PII entropy), so 1mb is already unlikely + body, err := io.ReadAll(res.Body) if res.StatusCode > 299 || err != nil { return nil, errors.Join(fmt.Errorf("fetching StatusList2021Credential from '%s' failed", statusListCredential), err) } diff --git a/vcr/test/openid4vci_integration_test.go b/vcr/test/openid4vci_integration_test.go index 769d8e12be..fc7f877323 100644 --- a/vcr/test/openid4vci_integration_test.go +++ b/vcr/test/openid4vci_integration_test.go @@ -81,13 +81,10 @@ func TestOpenID4VCIHappyFlow(t *testing.T) { } func TestOpenID4VCIConnectionReuse(t *testing.T) { - // default http.Transport has MaxConnsPerHost=100, - // but we need to adjust it to something lower, so we can assert connection reuse - const maxConnsPerHost = 2 + // Our safe http Transport has MaxConnsPerHost = 5 // for 2 http.Transport instance (one for issuer, one for wallet), - // so we expect max maxConnsPerHost*2 connections in total. - const maxExpectedConnCount = maxConnsPerHost * 2 - http.DefaultTransport.(*http.Transport).MaxConnsPerHost = maxConnsPerHost + // so we expect max 10 connections in total. + const maxExpectedConnCount = 10 ctx := audit.TestContext() _, baseURL, system := node.StartServer(t) @@ -117,7 +114,7 @@ func TestOpenID4VCIConnectionReuse(t *testing.T) { }, } - const numCreds = 10 + const numCreds = 12 errChan := make(chan error, numCreds) wg := sync.WaitGroup{} for i := 0; i < numCreds; i++ { @@ -149,7 +146,7 @@ func TestOpenID4VCIConnectionReuse(t *testing.T) { } assert.Empty(t, errs, "error issuing credential") for host, v := range newConns { - assert.LessOrEqualf(t, v, maxExpectedConnCount, "number of created HTTP connections should be at most %d for host %s", maxConnsPerHost, host) + assert.LessOrEqualf(t, v, maxExpectedConnCount, "number of created HTTP connections should be at most %d for host %s", 5, host) } } diff --git a/vdr/api/v1/client.go b/vdr/api/v1/client.go index 93db80fb60..66f76eeefe 100644 --- a/vdr/api/v1/client.go +++ b/vdr/api/v1/client.go @@ -36,7 +36,7 @@ type HTTPClient struct { } func (hb HTTPClient) client() ClientInterface { - response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateHTTPClient(hb.ClientConfig, hb.TokenGenerator))) + response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateInternalHTTPClient(hb.ClientConfig, hb.TokenGenerator))) if err != nil { panic(err) } diff --git a/vdr/api/v2/client.go b/vdr/api/v2/client.go index 2d944f0443..4fd3b7b0f2 100644 --- a/vdr/api/v2/client.go +++ b/vdr/api/v2/client.go @@ -34,7 +34,7 @@ type HTTPClient struct { } func (hb HTTPClient) client() ClientInterface { - response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateHTTPClient(hb.ClientConfig, hb.TokenGenerator))) + response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateInternalHTTPClient(hb.ClientConfig, hb.TokenGenerator))) if err != nil { panic(err) } diff --git a/vdr/didweb/web.go b/vdr/didweb/web.go index 50d9fe50fc..25f78ca10a 100644 --- a/vdr/didweb/web.go +++ b/vdr/didweb/web.go @@ -25,6 +25,7 @@ import ( "github.com/nuts-foundation/nuts-node/core" "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/vdr/resolver" + "io" "mime" "net/http" "time" @@ -37,16 +38,13 @@ var _ resolver.DIDResolver = (*Resolver)(nil) // Resolver is a DID resolver for the did:web method. type Resolver struct { - HttpClient *http.Client + HttpClient core.HTTPRequestDoer } // NewResolver creates a new did:web Resolver with default TLS configuration. func NewResolver() *Resolver { return &Resolver{ - HttpClient: &http.Client{ - Transport: client.DefaultCachingTransport, - Timeout: 5 * time.Second, - }, + HttpClient: client.NewWithCache(5 * time.Second), } } @@ -68,11 +66,14 @@ func (w Resolver) Resolve(id did.DID, _ *resolver.ResolveMetadata) (*did.Documen targetURL := baseURL.String() // TODO: Support DNS over HTTPS (DOH), https://www.rfc-editor.org/rfc/rfc8484 - httpResponse, err := w.HttpClient.Get(targetURL) + request, err := http.NewRequest(http.MethodGet, targetURL, nil) + if err != nil { + return nil, nil, err + } + httpResponse, err := w.HttpClient.Do(request) if err != nil { return nil, nil, fmt.Errorf("did:web HTTP error: %w", err) } - defer httpResponse.Body.Close() if !(httpResponse.StatusCode >= 200 && httpResponse.StatusCode < 300) { return nil, nil, fmt.Errorf("did:web non-ok HTTP status: %s", httpResponse.Status) } @@ -98,7 +99,7 @@ func (w Resolver) Resolve(id did.DID, _ *resolver.ResolveMetadata) (*did.Documen } // Read document - data, err := core.LimitedReadAll(httpResponse.Body) + data, err := io.ReadAll(httpResponse.Body) if err != nil { return nil, nil, fmt.Errorf("did:web HTTP response read error: %w", err) } diff --git a/vdr/didweb/web_test.go b/vdr/didweb/web_test.go index 519464be39..b887de4096 100644 --- a/vdr/didweb/web_test.go +++ b/vdr/didweb/web_test.go @@ -20,7 +20,6 @@ package didweb import ( "github.com/nuts-foundation/go-did/did" - "github.com/nuts-foundation/nuts-node/http/client" http2 "github.com/nuts-foundation/nuts-node/test/http" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -55,10 +54,6 @@ const didDocTemplate = ` func TestResolver_NewResolver(t *testing.T) { resolver := NewResolver() assert.NotNil(t, resolver.HttpClient) - - t.Run("it uses cached transport", func(t *testing.T) { - assert.Same(t, client.DefaultCachingTransport, resolver.HttpClient.Transport) - }) } func TestResolver_Resolve(t *testing.T) {