Skip to content

Commit

Permalink
secure outgoing http client with max connections
Browse files Browse the repository at this point in the history
  • Loading branch information
woutslakhorst committed Oct 21, 2024
1 parent c8a740d commit 901fea3
Show file tree
Hide file tree
Showing 22 changed files with 229 additions and 107 deletions.
10 changes: 5 additions & 5 deletions auth/client/iam/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (hb HTTPClient) RequestObjectByGet(ctx context.Context, requestURI string)
return "", httpErr
}

data, err := core.LimitedReadAll(response.Body)
data, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("unable to read response: %w", err)
}
Expand All @@ -161,7 +161,7 @@ func (hb HTTPClient) RequestObjectByPost(ctx context.Context, requestURI string,
return "", httpErr
}

data, err := core.LimitedReadAll(response.Body)
data, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("unable to read response: %w", err)
}
Expand Down Expand Up @@ -206,7 +206,7 @@ func (hb HTTPClient) AccessToken(ctx context.Context, tokenEndpoint string, data
}

var responseData []byte
if responseData, err = core.LimitedReadAll(response.Body); err != nil {
if responseData, err = io.ReadAll(response.Body); err != nil {
return token, fmt.Errorf("unable to read response: %w", err)
}
if err = json.Unmarshal(responseData, &token); err != nil {
Expand Down Expand Up @@ -271,7 +271,7 @@ func (hb HTTPClient) OpenIDConfiguration(ctx context.Context, issuerURL string)
return nil, httpErr
}
var data []byte
if data, err = core.LimitedReadAll(response.Body); err != nil {
if data, err = io.ReadAll(response.Body); err != nil {
return nil, fmt.Errorf("unable to read response: %w", err)
}
// kid is checked against did resolver
Expand Down Expand Up @@ -407,7 +407,7 @@ func (hb HTTPClient) doRequest(ctx context.Context, request *http.Request, targe

var data []byte

if data, err = core.LimitedReadAll(response.Body); err != nil {
if data, err = io.ReadAll(response.Body); err != nil {
return fmt.Errorf("unable to read response: %w", err)
}
if err = json.Unmarshal(data, &target); err != nil {
Expand Down
11 changes: 3 additions & 8 deletions auth/services/oauth/relying_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ import (
"context"
"crypto/tls"
"fmt"
"github.com/lestrrat-go/jwx/v2/jwt"
"net/http"
"net/url"
"strings"
"time"

"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/nuts-foundation/go-did/did"
"github.com/nuts-foundation/nuts-node/auth/api/auth/v1/client"
"github.com/nuts-foundation/nuts-node/auth/oauth"
"github.com/nuts-foundation/nuts-node/auth/services"
"github.com/nuts-foundation/nuts-node/core"
nutsCrypto "github.com/nuts-foundation/nuts-node/crypto"
"github.com/nuts-foundation/nuts-node/didman"
strictHttp "github.com/nuts-foundation/nuts-node/http/client"
"github.com/nuts-foundation/nuts-node/vcr/credential"
"github.com/nuts-foundation/nuts-node/vcr/holder"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
Expand Down Expand Up @@ -110,12 +110,7 @@ func (s *relyingParty) RequestRFC003AccessToken(ctx context.Context, jwtGrantTok
if s.strictMode && strings.ToLower(authorizationServerEndpoint.Scheme) != "https" {
return nil, fmt.Errorf("authorization server endpoint must be HTTPS when in strict mode: %s", authorizationServerEndpoint.String())
}
httpClient := &http.Client{}
if s.httpClientTLS != nil {
httpClient.Transport = &http.Transport{
TLSClientConfig: s.httpClientTLS,
}
}
httpClient := strictHttp.NewWithTLSConfig(s.httpClientTimeout, s.httpClientTLS)
authClient, err := client.NewHTTPClient("", s.httpClientTimeout, client.WithHTTPClient(httpClient), client.WithRequestEditorFn(core.UserAgentRequestEditor))
if err != nil {
return nil, fmt.Errorf("unable to create HTTP client: %w", err)
Expand Down
29 changes: 8 additions & 21 deletions core/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,6 @@ import (
// If the response body is longer than this, it will be truncated.
const HttpResponseBodyLogClipAt = 200

// DefaultMaxHttpResponseSize is a default maximum size of an HTTP response body that will be read.
// Very large or unbounded HTTP responses can cause denial-of-service, so it's good to limit how much data is read.
// This of course heavily depends on the use case, but 1MB is a reasonable default.
const DefaultMaxHttpResponseSize = 1024 * 1024

// LimitedReadAll reads the given reader until the DefaultMaxHttpResponseSize is reached.
// It returns an error if more data is available than DefaultMaxHttpResponseSize.
func LimitedReadAll(reader io.Reader) ([]byte, error) {
result, err := io.ReadAll(io.LimitReader(reader, DefaultMaxHttpResponseSize+1))
if len(result) > DefaultMaxHttpResponseSize {
return nil, fmt.Errorf("data to read exceeds max. safety limit of %d bytes", DefaultMaxHttpResponseSize)
}
return result, err
}

// HttpError describes an error returned when invoking a remote server.
type HttpError struct {
error
Expand All @@ -63,7 +48,7 @@ func TestResponseCode(expectedStatusCode int, response *http.Response) error {
// It logs using the given logger, unless nil is passed.
func TestResponseCodeWithLog(expectedStatusCode int, response *http.Response, log *logrus.Entry) error {
if response.StatusCode != expectedStatusCode {
responseData, _ := LimitedReadAll(response.Body)
responseData, _ := io.ReadAll(response.Body)
if log != nil {
// Cut off the response body to 100 characters max to prevent logging of large responses
responseBodyString := string(responseData)
Expand Down Expand Up @@ -104,16 +89,18 @@ func (w httpRequestDoerAdapter) Do(req *http.Request) (*http.Response, error) {
return w.fn(req)
}

// CreateHTTPClient creates a new HTTP client with the given client configuration.
// CreateHTTPInternalClient creates a new HTTP client with the given client configuration.
// This client is to be used for internal API calls (CMDs and such)
// The result HTTPRequestDoer can be supplied to OpenAPI generated clients for executing requests.
// This does not use the generated client options for e.g. authentication,
// because each generated OpenAPI client reimplements the client options using structs,
// which makes them incompatible with each other, making it impossible to use write generic client code for common traits like authorization.
// If the given authorization token builder is non-nil, it calls it and passes the resulting token as bearer token with requests.
func CreateHTTPClient(cfg ClientConfig, generator AuthorizationTokenGenerator) (HTTPRequestDoer, error) {
func CreateHTTPInternalClient(cfg ClientConfig, generator AuthorizationTokenGenerator) (HTTPRequestDoer, error) {
var result *httpRequestDoerAdapter
client := &http.Client{}
client.Timeout = cfg.Timeout

result = &httpRequestDoerAdapter{
fn: client.Do,
}
Expand Down Expand Up @@ -149,9 +136,9 @@ func CreateHTTPClient(cfg ClientConfig, generator AuthorizationTokenGenerator) (
return result, nil
}

// MustCreateHTTPClient is like CreateHTTPClient but panics if it returns an error.
func MustCreateHTTPClient(cfg ClientConfig, generator AuthorizationTokenGenerator) HTTPRequestDoer {
client, err := CreateHTTPClient(cfg, generator)
// MustCreateInternalHTTPClient is like CreateHTTPInternalClient but panics if it returns an error.
func MustCreateInternalHTTPClient(cfg ClientConfig, generator AuthorizationTokenGenerator) HTTPRequestDoer {
client, err := CreateHTTPInternalClient(cfg, generator)
if err != nil {
panic(err)
}
Expand Down
25 changes: 4 additions & 21 deletions core/http_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestHTTPClient(t *testing.T) {

t.Run("no auth token", func(t *testing.T) {
authToken = ""
client, err := CreateHTTPClient(ClientConfig{}, nil)
client, err := CreateHTTPInternalClient(ClientConfig{}, nil)
require.NoError(t, err)

req, _ := stdHttp.NewRequest(stdHttp.MethodGet, server.URL, nil)
Expand All @@ -56,7 +56,7 @@ func TestHTTPClient(t *testing.T) {
})
t.Run("with auth token", func(t *testing.T) {
authToken = ""
client, err := CreateHTTPClient(ClientConfig{
client, err := CreateHTTPInternalClient(ClientConfig{
Token: "test",
}, nil)
require.NoError(t, err)
Expand All @@ -69,7 +69,7 @@ func TestHTTPClient(t *testing.T) {
assert.Equal(t, "Bearer test", authToken)
})
t.Run("with custom token builder", func(t *testing.T) {
client, err := CreateHTTPClient(ClientConfig{}, newLegacyTokenGenerator("test"))
client, err := CreateHTTPInternalClient(ClientConfig{}, newLegacyTokenGenerator("test"))
require.NoError(t, err)

req, _ := stdHttp.NewRequest(stdHttp.MethodGet, server.URL, nil)
Expand All @@ -80,7 +80,7 @@ func TestHTTPClient(t *testing.T) {
assert.Equal(t, "Bearer test", authToken)
})
t.Run("with errored token builder", func(t *testing.T) {
client, err := CreateHTTPClient(ClientConfig{}, newErrorTokenBuilder())
client, err := CreateHTTPInternalClient(ClientConfig{}, newErrorTokenBuilder())
require.NoError(t, err)

req, _ := stdHttp.NewRequest(stdHttp.MethodGet, server.URL, nil)
Expand Down Expand Up @@ -162,20 +162,3 @@ func newErrorTokenBuilder() AuthorizationTokenGenerator {
return "", errors.New("error")
}
}

func TestLimitedReadAll(t *testing.T) {
t.Run("less than limit", func(t *testing.T) {
data := strings.Repeat("a", 10)
result, err := LimitedReadAll(strings.NewReader(data))

assert.NoError(t, err)
assert.Equal(t, []byte(data), result)
})
t.Run("more than limit", func(t *testing.T) {
data := strings.Repeat("a", DefaultMaxHttpResponseSize+1)
result, err := LimitedReadAll(strings.NewReader(data))

assert.EqualError(t, err, "data to read exceeds max. safety limit of 1048576 bytes")
assert.Nil(t, result)
})
}
3 changes: 2 additions & 1 deletion crypto/storage/external/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/nuts-foundation/nuts-node/core"
"github.com/nuts-foundation/nuts-node/crypto/storage/spi"
"github.com/nuts-foundation/nuts-node/crypto/util"
safeHttp "github.com/nuts-foundation/nuts-node/http/client"
)

// StorageType is the name of this storage type, used in health check reports and configuration.
Expand Down Expand Up @@ -82,7 +83,7 @@ func NewAPIClient(config Config) (spi.Storage, error) {
if _, err := url.ParseRequestURI(config.Address); err != nil {
return nil, err
}
client, _ := NewClientWithResponses(config.Address, WithHTTPClient(&http.Client{Timeout: config.Timeout}))
client, _ := NewClientWithResponses(config.Address, WithHTTPClient(safeHttp.New(config.Timeout)))
return &APIClient{httpClient: client}, nil
}

Expand Down
2 changes: 1 addition & 1 deletion didman/api/v1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type HTTPClient struct {
}

func (hb HTTPClient) client() ClientInterface {
response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateHTTPClient(hb.ClientConfig, hb.TokenGenerator)))
response, err := NewClientWithResponses(hb.GetAddress(), WithHTTPClient(core.MustCreateInternalHTTPClient(hb.ClientConfig, hb.TokenGenerator)))
if err != nil {
panic(err)
}
Expand Down
4 changes: 2 additions & 2 deletions http/client/caching.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ import (

// DefaultCachingTransport is a http.RoundTripper that can be used as a default transport for HTTP clients.
// If caching is enabled, it will cache responses according to RFC 7234.
// If caching is disabled, it will behave like http.DefaultTransport.
var DefaultCachingTransport = http.DefaultTransport
// If caching is disabled, it will behave like our safe http.DefaultTransport.
var DefaultCachingTransport http.RoundTripper

// maxCacheTime is the maximum time responses are cached.
// Even if the server responds with a longer cache time, responses are never cached longer than maxCacheTime.
Expand Down
60 changes: 54 additions & 6 deletions http/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,59 @@
package client

import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"github.com/nuts-foundation/nuts-node/core"
"io"
"net/http"
"time"
)

// SafeHttpTransport is a http.Transport that can be used as a default transport for HTTP clients.
var SafeHttpTransport *http.Transport

func init() {
httpTransport := http.DefaultTransport.(*http.Transport)
if httpTransport.TLSClientConfig == nil {
httpTransport.TLSClientConfig = &tls.Config{}
SafeHttpTransport = http.DefaultTransport.(*http.Transport).Clone()
if SafeHttpTransport.TLSClientConfig == nil {
SafeHttpTransport.TLSClientConfig = &tls.Config{}
}
httpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12
SafeHttpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12
// to prevent slow responses from public clients to have significant impact (default was unlimited)
SafeHttpTransport.MaxConnsPerHost = 5

DefaultCachingTransport = SafeHttpTransport
}

// StrictMode is a flag that can be set to true to enable strict mode for the HTTP client.
var StrictMode bool

// DefaultMaxHttpResponseSize is a default maximum size of an HTTP response body that will be read.
// Very large or unbounded HTTP responses can cause denial-of-service, so it's good to limit how much data is read.
// This of course heavily depends on the use case, but 1MB is a reasonable default.
const DefaultMaxHttpResponseSize = 1024 * 1024

// limitedReadAll reads the given reader until the DefaultMaxHttpResponseSize is reached.
// It returns an error if more data is available than DefaultMaxHttpResponseSize.
func limitedReadAll(reader io.Reader) ([]byte, error) {
result, err := io.ReadAll(io.LimitReader(reader, DefaultMaxHttpResponseSize+1))
if len(result) > DefaultMaxHttpResponseSize {
return nil, fmt.Errorf("data to read exceeds max. safety limit of %d bytes", DefaultMaxHttpResponseSize)
}
return result, err
}

// New creates a new HTTP client with the given timeout.
func New(timeout time.Duration) *StrictHTTPClient {
return &StrictHTTPClient{
client: &http.Client{
Transport: SafeHttpTransport,
Timeout: timeout,
},
}
}

// NewWithCache creates a new HTTP client with the given timeout.
// It uses the DefaultCachingTransport as the underlying transport.
func NewWithCache(timeout time.Duration) *StrictHTTPClient {
Expand All @@ -51,7 +87,7 @@ func NewWithCache(timeout time.Duration) *StrictHTTPClient {
// It copies the http.DefaultTransport and sets the TLSClientConfig to the given tls.Config.
// As such, it can't be used in conjunction with the CachingRoundTripper.
func NewWithTLSConfig(timeout time.Duration, tlsConfig *tls.Config) *StrictHTTPClient {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport := SafeHttpTransport.Clone()
transport.TLSClientConfig = tlsConfig
return &StrictHTTPClient{
client: &http.Client{
Expand All @@ -69,5 +105,17 @@ func (s *StrictHTTPClient) Do(req *http.Request) (*http.Response, error) {
if StrictMode && req.URL.Scheme != "https" {
return nil, errors.New("strictmode is enabled, but request is not over HTTPS")
}
return s.client.Do(req)
req.Header.Set("User-Agent", core.UserAgent())
result, err := s.client.Do(req)
if err != nil {
return nil, err
}
if result.Body != nil {
body, err := limitedReadAll(result.Body)
if err != nil {
return nil, err
}
result.Body = io.NopCloser(bytes.NewReader(body))
}
return result, nil
}
Loading

0 comments on commit 901fea3

Please sign in to comment.