diff --git a/headers.go b/headers.go index 8d4359f..a86f4e2 100644 --- a/headers.go +++ b/headers.go @@ -7,6 +7,9 @@ const ( HeaderProxyAuthorization = "Proxy-Authorization" HeaderWWWAuthenticate = "WWW-Authenticate" + BearerAuthHeader = "Bearer " + BasicAuthHeader = "Basic " + // Caching HeaderAge = "Age" HeaderCacheControl = "Cache-Control" diff --git a/mocks.go b/mocks.go index 873d8b9..efd648d 100644 --- a/mocks.go +++ b/mocks.go @@ -59,7 +59,7 @@ func MockHandler(statusCode int, options ...Option) http.Handler { r := MustNew(options...) return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - req, err := r.RequestContext(request.Context()) + req, err := r.RequestWithContext(request.Context()) if err != nil { panic(err) } diff --git a/options.go b/options.go index a1bb938..f44461c 100644 --- a/options.go +++ b/options.go @@ -150,7 +150,7 @@ func BasicAuth(username, password string) Option { return DeleteHeader(HeaderAuthorization) } - return Header(HeaderAuthorization, "Basic "+basicAuth(username, password)) + return Header(HeaderAuthorization, BearerAuthHeader+basicAuth(username, password)) } // basicAuth returns the base64 encoded username:password for basic auth copied from net/http @@ -166,7 +166,7 @@ func BearerAuth(token string) Option { return DeleteHeader(HeaderAuthorization) } - return Header(HeaderAuthorization, "Bearer "+token) + return Header(HeaderAuthorization, BearerAuthHeader+token) } // URL sets the request URL diff --git a/options_test.go b/options_test.go index a4a7d76..87c379e 100644 --- a/options_test.go +++ b/options_test.go @@ -323,7 +323,7 @@ func TestBasicAuth(t *testing.T) { reqs, err := New(c.options...) require.NoError(t, err) - req, err := reqs.RequestContext(context.Background()) + req, err := reqs.RequestWithContext(context.Background()) require.NoError(t, err) username, password, ok := req.BasicAuth() diff --git a/packagefunctions.go b/packagefunctions.go index 4b23b60..394840e 100644 --- a/packagefunctions.go +++ b/packagefunctions.go @@ -12,9 +12,9 @@ func Request(opts ...Option) (*http.Request, error) { return DefaultRequester.Request(opts...) } -// RequestContext does the same as Request(), but attaches a Context to the request -func RequestContext(ctx context.Context, opts ...Option) (*http.Request, error) { - return DefaultRequester.RequestContext(ctx, opts...) +// RequestWithContext does the same as Request(), but attaches a Context to the request +func RequestWithContext(ctx context.Context, opts ...Option) (*http.Request, error) { + return DefaultRequester.RequestWithContext(ctx, opts...) } // Send uses the DefaultRequester to create a request and execute it @@ -22,17 +22,17 @@ func Send(opts ...Option) (*http.Response, error) { return DefaultRequester.Send(opts...) } -// SendContext does the same as Send(), but attaches a Context to the request -func SendContext(ctx context.Context, opts ...Option) (*http.Response, error) { - return DefaultRequester.SendContext(ctx, opts...) -} - -// ReceiveContext does the same as Receive(), but attaches a Context to the request -func ReceiveContext(ctx context.Context, into interface{}, opts ...Option) (*http.Response, error) { - return DefaultRequester.ReceiveContext(ctx, into, opts...) +// SendWithContext does the same as Send(), but attaches a Context to the request +func SendWithContext(ctx context.Context, opts ...Option) (*http.Response, error) { + return DefaultRequester.SendWithContext(ctx, opts...) } // Receive uses the DefaultRequester to create a request, execute it, and read the response func Receive(into interface{}, opts ...Option) (*http.Response, error) { return DefaultRequester.Receive(into, opts...) } + +// ReceiveWithContext does the same as Receive(), but attaches a Context to the request +func ReceiveWithContext(ctx context.Context, into interface{}, opts ...Option) (*http.Response, error) { + return DefaultRequester.ReceiveWithContext(ctx, into, opts...) +} diff --git a/packagefunctions_test.go b/packagefunctions_test.go index 5b41d4e..3d55926 100644 --- a/packagefunctions_test.go +++ b/packagefunctions_test.go @@ -21,7 +21,7 @@ type testContextKey string const colorContextKey = testContextKey("color") func TestRequestContext(t *testing.T) { - req, err := RequestContext( + req, err := RequestWithContext( context.WithValue(context.Background(), colorContextKey, "green"), Get("http://blue.com/red"), ) @@ -47,7 +47,7 @@ func TestSend(t *testing.T) { func TestSendContext(t *testing.T) { i := Inspector{} - resp, err := SendContext( + resp, err := SendWithContext( context.WithValue(context.Background(), colorContextKey, "blue"), Get("/profile"), WithDoer(MockDoer(204)), @@ -82,7 +82,7 @@ func TestReceive(t *testing.T) { i := Inspector{} - resp, err := ReceiveContext( + resp, err := ReceiveWithContext( context.WithValue(context.Background(), colorContextKey, "yellow"), &m, Get("/red"), diff --git a/requester.go b/requester.go index da0541b..6ec64dc 100644 --- a/requester.go +++ b/requester.go @@ -103,11 +103,11 @@ func (r *Requester) Clone() *Requester { // Request returns a new http.Request func (r *Requester) Request(opts ...Option) (*http.Request, error) { - return r.RequestContext(context.Background(), opts...) + return r.RequestWithContext(context.Background(), opts...) } -// RequestContext does the same as Request, but requires a context -func (r *Requester) RequestContext(ctx context.Context, opts ...Option) (*http.Request, error) { +// RequestWithContext does the same as Request, but requires a context +func (r *Requester) RequestWithContext(ctx context.Context, opts ...Option) (*http.Request, error) { requester, err := r.withOpts(opts...) if err != nil { return nil, err @@ -208,7 +208,7 @@ func (r *Requester) getRequestBody() (body io.Reader, contentType string, _ erro // Send executes a request with the Doer func (r *Requester) Send(opts ...Option) (*http.Response, error) { - return r.SendContext(context.Background(), opts...) + return r.SendWithContext(context.Background(), opts...) } // withOpts is like With(), but skips the clone if there are no options to apply @@ -220,14 +220,14 @@ func (r *Requester) withOpts(opts ...Option) (*Requester, error) { return r, nil } -// SendContext does the same as Send, but requires a context -func (r *Requester) SendContext(ctx context.Context, opts ...Option) (*http.Response, error) { +// SendWithContext does the same as Send, but requires a context +func (r *Requester) SendWithContext(ctx context.Context, opts ...Option) (*http.Response, error) { reqs, err := r.withOpts(opts...) if err != nil { return nil, err } - req, err := reqs.RequestContext(ctx) + req, err := reqs.RequestWithContext(ctx) if err != nil { return nil, err } @@ -249,11 +249,12 @@ func (r *Requester) Do(req *http.Request) (*http.Response, error) { // Receive creates a new HTTP request and returns the response func (r *Requester) Receive(into interface{}, opts ...Option) (resp *http.Response, err error) { - return r.ReceiveContext(context.Background(), into, opts...) + return r.ReceiveWithContext(context.Background(), into, opts...) } -// ReceiveContext does the same as Receive, but requires a context -func (r *Requester) ReceiveContext(ctx context.Context, into interface{}, opts ...Option) (resp *http.Response, err error) { +// ReceiveWithContext does the same as Receive, but requires a context +func (r *Requester) ReceiveWithContext(ctx context.Context, into interface{}, opts ...Option) (resp *http.Response, err error) { + // if the first option is an Option, we need to copy those over and set into to nil if opt, ok := into.(Option); ok { opts = append(opts, nil) copy(opts[1:], opts) @@ -261,23 +262,25 @@ func (r *Requester) ReceiveContext(ctx context.Context, into interface{}, opts . into = nil } + // apply the options to the requester r, err = r.withOpts(opts...) if err != nil { return nil, err } - resp, err = r.SendContext(ctx) - - body, bodyReadError := readBody(resp) - + // send the request + resp, err = r.SendWithContext(ctx) if err != nil { return resp, err } + // read the body + body, bodyReadError := readBody(resp) if bodyReadError != nil { return resp, bodyReadError } + // if the into is not nil, unmarshal the body into it if into != nil { unmarshaler := r.Unmarshaler if unmarshaler == nil { @@ -290,24 +293,26 @@ func (r *Requester) ReceiveContext(ctx context.Context, into interface{}, opts . return resp, err } +// readBody reads the body of an HTTP response func readBody(resp *http.Response) ([]byte, error) { + // check for a nil response if resp == nil || resp.Body == nil || resp.Body == http.NoBody { return nil, nil } defer resp.Body.Close() - cls := resp.Header.Get(HeaderContentLength) + contentLengthHeader := resp.Header.Get(HeaderContentLength) - var cl int64 + var contentLength int64 - if cls != "" { - cl, _ = strconv.ParseInt(cls, 10, 0) + if contentLengthHeader != "" { + contentLength, _ = strconv.ParseInt(contentLengthHeader, 10, 0) } buf := bytes.Buffer{} - if cl > 0 { - buf.Grow(int(cl)) + if contentLength > 0 { + buf.Grow(int(contentLength)) } if _, err := buf.ReadFrom(resp.Body); err != nil { @@ -353,3 +358,13 @@ func (r *Requester) HTTPClient() *http.Client { return client } + +// CookieJar returns the CookieJar used by the Requester, if it exists +func (r *Requester) CookieJar() http.CookieJar { + client := r.HTTPClient() + if client == nil { + return nil + } + + return client.Jar +} diff --git a/requester_test.go b/requester_test.go index f08c6e9..fb576d6 100644 --- a/requester_test.go +++ b/requester_test.go @@ -136,7 +136,7 @@ func TestRequesterRequestURLAndMethod(t *testing.T) { t.Run("", func(t *testing.T) { reqs, err := New(c.options...) require.NoError(t, err) - req, err := reqs.RequestContext(context.Background()) + req, err := reqs.RequestWithContext(context.Background()) require.NoError(t, err) assert.Equal(t, c.expectedURL, req.URL.String()) assert.Equal(t, c.expectedMethod, req.Method) @@ -146,7 +146,7 @@ func TestRequesterRequestURLAndMethod(t *testing.T) { t.Run("invalidmethod", func(t *testing.T) { b, err := New(Method("@")) require.NoError(t, err) - req, err := b.RequestContext(context.Background()) + req, err := b.RequestWithContext(context.Background()) require.Error(t, err) require.Nil(t, req) }) @@ -167,7 +167,7 @@ func TestRequesterRequestQueryParams(t *testing.T) { reqs, err := New(c.options...) require.NoError(t, err) - req, _ := reqs.RequestContext(context.Background()) + req, _ := reqs.RequestWithContext(context.Background()) require.Equal(t, c.expectedURL, req.URL.String()) }) } @@ -199,7 +199,7 @@ func TestRequesterRequestBody(t *testing.T) { t.Run("", func(t *testing.T) { reqs, err := New(c.options...) require.NoError(t, err) - req, err := reqs.RequestContext(context.Background()) + req, err := reqs.RequestWithContext(context.Background()) require.NoError(t, err) if reqs.Body != nil { @@ -230,7 +230,7 @@ func TestRequesterRequestMarshaler(t *testing.T) { }), } - req, err := requester.RequestContext(context.Background()) + req, err := requester.RequestWithContext(context.Background()) require.NoError(t, err) require.Equal(t, []string{"blue"}, capturedV) @@ -246,7 +246,7 @@ func TestRequesterRequestMarshaler(t *testing.T) { return nil, "", errors.New("boom") // nolint: err113 }) - _, err := requester.RequestContext(context.Background()) + _, err := requester.RequestWithContext(context.Background()) require.Error(t, err, "boom") }) } @@ -255,7 +255,7 @@ func TestRequesterRequestContentLength(t *testing.T) { reqs, err := New(Body("1234")) require.NoError(t, err) - req, err := reqs.RequestContext(context.Background()) + req, err := reqs.RequestWithContext(context.Background()) require.NoError(t, err) // content length should be set automatically @@ -264,7 +264,7 @@ func TestRequesterRequestContentLength(t *testing.T) { // I should be able to override it reqs.ContentLength = 10 - req, err = reqs.RequestContext(context.Background()) + req, err = reqs.RequestWithContext(context.Background()) require.NoError(t, err) require.EqualValues(t, 10, req.ContentLength) @@ -274,7 +274,7 @@ func TestRequesterRequestGetBody(t *testing.T) { reqs, err := New(Body("1234")) require.NoError(t, err) - req, err := reqs.RequestContext(context.Background()) + req, err := reqs.RequestWithContext(context.Background()) require.NoError(t, err) // GetBody should be populated automatically @@ -291,7 +291,7 @@ func TestRequesterRequestGetBody(t *testing.T) { return io.NopCloser(strings.NewReader("5678")), nil } - req, err = reqs.RequestContext(context.Background()) + req, err = reqs.RequestWithContext(context.Background()) require.NoError(t, err) rdr, err = req.GetBody() @@ -307,7 +307,7 @@ func TestRequesterRequestHost(t *testing.T) { reqs, err := New(URL("http://test.com/red")) require.NoError(t, err) - req, err := reqs.RequestContext(context.Background()) + req, err := reqs.RequestWithContext(context.Background()) require.NoError(t, err) // Host should be set automatically @@ -316,7 +316,7 @@ func TestRequesterRequestHost(t *testing.T) { // but I can override it reqs.Host = "test2.com" - req, err = reqs.RequestContext(context.Background()) + req, err = reqs.RequestWithContext(context.Background()) require.NoError(t, err) require.Equal(t, "test2.com", req.Host) @@ -325,7 +325,7 @@ func TestRequesterRequestHost(t *testing.T) { func TestRequesterRequestTransferEncoding(t *testing.T) { reqs := Requester{} - req, err := reqs.RequestContext(context.Background()) + req, err := reqs.RequestWithContext(context.Background()) require.NoError(t, err) // should be empty by default @@ -334,7 +334,7 @@ func TestRequesterRequestTransferEncoding(t *testing.T) { // but I can set it reqs.TransferEncoding = []string{"red"} - req, err = reqs.RequestContext(context.Background()) + req, err = reqs.RequestWithContext(context.Background()) require.NoError(t, err) require.Equal(t, reqs.TransferEncoding, req.TransferEncoding) @@ -343,7 +343,7 @@ func TestRequesterRequestTransferEncoding(t *testing.T) { func TestRequesterRequestClose(t *testing.T) { reqs := Requester{} - req, err := reqs.RequestContext(context.Background()) + req, err := reqs.RequestWithContext(context.Background()) require.NoError(t, err) // should be false by default @@ -352,7 +352,7 @@ func TestRequesterRequestClose(t *testing.T) { // but I can set it reqs.Close = true - req, err = reqs.RequestContext(context.Background()) + req, err = reqs.RequestWithContext(context.Background()) require.NoError(t, err) require.True(t, req.Close) @@ -361,7 +361,7 @@ func TestRequesterRequestClose(t *testing.T) { func TestRequesterRequestTrailer(t *testing.T) { reqs := Requester{} - req, err := reqs.RequestContext(context.Background()) + req, err := reqs.RequestWithContext(context.Background()) require.NoError(t, err) // should be empty by default @@ -370,7 +370,7 @@ func TestRequesterRequestTrailer(t *testing.T) { // but I can set it reqs.Trailer = http.Header{"color": []string{"red"}} - req, err = reqs.RequestContext(context.Background()) + req, err = reqs.RequestWithContext(context.Background()) require.NoError(t, err) require.Equal(t, reqs.Trailer, req.Trailer) @@ -379,7 +379,7 @@ func TestRequesterRequestTrailer(t *testing.T) { func TestRequesterRequestHeader(t *testing.T) { reqs := Requester{} - req, err := reqs.RequestContext(context.Background()) + req, err := reqs.RequestWithContext(context.Background()) require.NoError(t, err) // should be empty by default @@ -388,7 +388,7 @@ func TestRequesterRequestHeader(t *testing.T) { // but I can set it reqs.Header = http.Header{"color": []string{"red"}} - req, err = reqs.RequestContext(context.Background()) + req, err = reqs.RequestWithContext(context.Background()) require.NoError(t, err) require.Equal(t, reqs.Header, req.Header) @@ -397,7 +397,7 @@ func TestRequesterRequestHeader(t *testing.T) { func TestRequesterRequestContext(t *testing.T) { reqs := Requester{} - req, err := reqs.RequestContext(context.WithValue(context.Background(), colorContextKey, "red")) + req, err := reqs.RequestWithContext(context.WithValue(context.Background(), colorContextKey, "red")) require.NoError(t, err) require.Equal(t, "red", req.Context().Value(colorContextKey)) @@ -430,7 +430,7 @@ func TestRequesterSendContext(t *testing.T) { i := Inspector{} r := MustNew(Get(ts.URL), &i) - resp, err := r.SendContext( + resp, err := r.SendWithContext( context.WithValue(context.Background(), colorContextKey, "purple"), Post("/server"), ) @@ -504,7 +504,7 @@ func TestRequesterReceiveContext(t *testing.T) { t.Run(fmt.Sprintf("into=%v", c.into), func(t *testing.T) { i := Inspector{} - resp, err := ReceiveContext( + resp, err := ReceiveWithContext( context.WithValue(context.Background(), colorContextKey, "purple"), c.into, Get(ts.URL, "/model.json"), @@ -530,7 +530,7 @@ func TestRequesterReceiveContext(t *testing.T) { ) urlBefore := r.URL.String() - resp, err := r.ReceiveContext( + resp, err := r.ReceiveWithContext( context.Background(), Get("/err"), ) @@ -580,7 +580,7 @@ func TestRequesterReceiveContext(t *testing.T) { // variants ctx := context.Background() - resp, err = r.ReceiveContext(ctx, Get("/blue")) + resp, err = r.ReceiveWithContext(ctx, Get("/blue")) require.NoError(t, err) defer resp.Body.Close() diff --git a/response.go b/response.go new file mode 100644 index 0000000..65390e3 --- /dev/null +++ b/response.go @@ -0,0 +1,10 @@ +package httpsling + +import "net/http" + +// IsSuccess checks if the response status code indicates success +func IsSuccess(resp *http.Response) bool { + code := resp.StatusCode + + return code >= http.StatusOK && code <= http.StatusIMUsed +} diff --git a/response_test.go b/response_test.go new file mode 100644 index 0000000..4017844 --- /dev/null +++ b/response_test.go @@ -0,0 +1,56 @@ +package httpsling_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/theopenlane/httpsling" +) + +func TestIsSuccess(t *testing.T) { + testCases := []struct { + name string + code int + expected bool + }{ + { + name: "OK", + code: http.StatusOK, + expected: true, + }, + { + name: "Unauthorized", + code: http.StatusUnauthorized, + expected: false, + }, + { + name: "Created", + code: http.StatusCreated, + expected: true, + }, + { + name: "InternalServerError", + code: http.StatusInternalServerError, + expected: false, + }, + { + name: "BadRequest", + code: http.StatusBadRequest, + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resp := &http.Response{ + StatusCode: tc.code, + } + + result := httpsling.IsSuccess(resp) + + assert.Equal(t, tc.expected, result) + }) + } +} diff --git a/retry_test.go b/retry_test.go index a842528..f51d67d 100644 --- a/retry_test.go +++ b/retry_test.go @@ -454,7 +454,7 @@ func TestRetryCancelContext(t *testing.T) { done := make(chan bool) go func() { - _, err = r.ReceiveContext(ctx, nil) // nolint: bodyclose + _, err = r.ReceiveWithContext(ctx, nil) // nolint: bodyclose done <- true }()