From 1cdf83da23253e1d35180b7f269f5d092c95090f Mon Sep 17 00:00:00 2001 From: EwenQuim Date: Thu, 21 Mar 2024 21:39:17 +0100 Subject: [PATCH] Fixes error content-type according to situations --- serialization.go | 12 +++++++----- tests_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 tests_test.go diff --git a/serialization.go b/serialization.go index 6c53eb53..350d4d85 100644 --- a/serialization.go +++ b/serialization.go @@ -92,16 +92,18 @@ func SendJSON(w http.ResponseWriter, ans any) { // If the error implements ErrorWithStatus, the status code will be set. func SendJSONError(w http.ResponseWriter, err error) { status := http.StatusInternalServerError - errorStatus := HTTPError{ - Err: err, - } + var errorStatus ErrorWithStatus if errors.As(err, &errorStatus) { status = errorStatus.StatusCode() } w.WriteHeader(status) - w.Header().Set("Content-Type", "application/problem+json") - SendJSON(w, errorStatus) + SendJSON(w, err) + + var httpError HTTPError + if errors.As(err, &httpError) { + w.Header().Set("Content-Type", "application/problem+json") + } } // SendXML sends a XML response. diff --git a/tests_test.go b/tests_test.go new file mode 100644 index 00000000..4dac4c34 --- /dev/null +++ b/tests_test.go @@ -0,0 +1,46 @@ +package fuego + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +// Contains random tests reported on the issues. + +func TestContentType(t *testing.T) { + server := NewServer() + + t.Run("Sends application/problem+json when return type is HTTPError", func(t *testing.T) { + GetStd(server, "/json-problems", func(w http.ResponseWriter, r *http.Request) { + SendJSONError(w, UnauthorizedError{ + Title: "Unauthorized", + }) + }) + + req := httptest.NewRequest("GET", "/json-problems", nil) + w := httptest.NewRecorder() + server.Mux.ServeHTTP(w, req) + + require.Equal(t, "application/problem+json", w.Header().Get("Content-Type")) + require.Equal(t, 401, w.Code) + require.Contains(t, w.Body.String(), "Unauthorized") + }) + + t.Run("Sends application/json when return type is not HTTPError", func(t *testing.T) { + GetStd(server, "/json", func(w http.ResponseWriter, r *http.Request) { + SendJSONError(w, errors.New("error")) + }) + + req := httptest.NewRequest("GET", "/json", nil) + w := httptest.NewRecorder() + server.Mux.ServeHTTP(w, req) + + require.Equal(t, "application/json", w.Header().Get("Content-Type")) + require.Equal(t, 500, w.Code) + require.Equal(t, "{}\n", w.Body.String()) + }) +}