diff --git a/client/client_execute.go b/client/client_execute.go index 2646d14..0e0b256 100644 --- a/client/client_execute.go +++ b/client/client_execute.go @@ -28,6 +28,10 @@ func (r *Request) Execute(ctx context.Context, req *Req) ([]byte, error) { return nil, err } + if req.rawResponseContainer != nil && res != nil { + *req.rawResponseContainer = *res + } + err = r.HttpErrorHandler(res, request.URL.String()) if err != nil { return nil, err diff --git a/client/request.go b/client/request.go index eb782bc..67fccba 100644 --- a/client/request.go +++ b/client/request.go @@ -1,6 +1,7 @@ package client import ( + "net/http" "net/url" ) @@ -9,12 +10,13 @@ import ( // // To build this struct, use NewReqBuilder. type Req struct { - headers map[string]string - resultContainer any - method string - path Path - query url.Values - body any + headers map[string]string + resultContainer any + method string + path Path + query url.Values + body any + rawResponseContainer *http.Response metricName string pathMetricEnabled bool @@ -46,6 +48,11 @@ func (builder *ReqBuilder) WriteTo(resultContainer any) *ReqBuilder { return builder } +func (builder *ReqBuilder) WriteRawResponseTo(resp *http.Response) *ReqBuilder { + builder.req.rawResponseContainer = resp + return builder +} + func (builder *ReqBuilder) Method(method string) *ReqBuilder { builder.req.method = method return builder diff --git a/client/request_test.go b/client/request_test.go new file mode 100644 index 0000000..e137131 --- /dev/null +++ b/client/request_test.go @@ -0,0 +1,74 @@ +package client + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRequest_WriteRawResponseTo(t *testing.T) { + const ( + pathOk = "/ok" + path5xx = "/5xx" + ) + + tests := []struct { + name string + path string + statusCode int + headers http.Header + }{ + { + name: "Test write raw response with statusOK", + path: pathOk, + statusCode: http.StatusOK, + headers: http.Header{ + "Content-Type": []string{"application/json"}, + "x-aptos-block-height": []string{"73287085"}, + }, + }, + { + name: "Test write raw response with status5xx", + path: path5xx, + statusCode: http.StatusInternalServerError, + headers: http.Header{ + "Content-Type": []string{"application/json"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := InitClient("http://www.example.com", nil, + WithHttpClient(&http.Client{ + Transport: RoundTripperFunc(func(request *http.Request) (*http.Response, error) { + switch request.URL.Path { + case pathOk: + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"Data": "ok"}`)), + Header: tt.headers, + }, nil + case path5xx: + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Request: request, + Body: io.NopCloser(strings.NewReader(`{"Data": "5xx"}`)), + Header: tt.headers, + }, nil + default: + return nil, nil + } + }), + }), + ) + var resp http.Response + _, _ = client.Execute(context.Background(), NewReqBuilder().Method(http.MethodGet).PathStatic(tt.path).WriteRawResponseTo(&resp).Build()) + require.Equal(t, tt.headers, resp.Header) + require.Equal(t, tt.statusCode, resp.StatusCode) + }) + } +}