diff --git a/CHANGELOG.md b/CHANGELOG.md index 94ceba6c8..79485a0b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -212,6 +212,7 @@ * [ENHANCEMENT] Middleware: determine route name in a single place, and add `middleware.ExtractRouteName()` method to allow consuming applications to retrieve the route name. #527 * [ENHANCEMENT] SpanProfiler: do less work on unsampled traces. #528 * [ENHANCEMENT] Log Middleware: if the trace is not sampled, log its ID as `trace_id_unsampled` instead of `trace_id`. #529 +* [EHNANCEMENT] httpgrpc: httpgrpc Server can now use error message from special HTTP header when converting HTTP response to an error. This is useful when HTTP response body contains binary data that doesn't form valid utf-8 string, otherwise grpc would fail to marshal returned error. #531 * [BUGFIX] spanlogger: Support multiple tenant IDs. #59 * [BUGFIX] Memberlist: fixed corrupted packets when sending compound messages with more than 255 messages or messages bigger than 64KB. #85 * [BUGFIX] Ring: `ring_member_ownership_percent` and `ring_tokens_owned` metrics are not updated on scale down. #109 diff --git a/httpgrpc/httpgrpc.go b/httpgrpc/httpgrpc.go index b755e2adc..02e6e4937 100644 --- a/httpgrpc/httpgrpc.go +++ b/httpgrpc/httpgrpc.go @@ -116,8 +116,14 @@ func Errorf(code int, tmpl string, args ...interface{}) error { }) } -// ErrorFromHTTPResponse converts an HTTP response into a grpc error +// ErrorFromHTTPResponse converts an HTTP response into a grpc error, and uses HTTP response body as an error message. +// Note that if HTTP response body contains non-utf8 string, then returned error cannot be marshalled by protobuf. func ErrorFromHTTPResponse(resp *HTTPResponse) error { + return ErrorFromHTTPResponseWithMessage(resp, string(resp.Body)) +} + +// ErrorFromHTTPResponseWithMessage converts an HTTP response into a grpc error, and uses supplied message for Error message. +func ErrorFromHTTPResponseWithMessage(resp *HTTPResponse, msg string) error { a, err := types.MarshalAny(resp) if err != nil { return err @@ -125,7 +131,7 @@ func ErrorFromHTTPResponse(resp *HTTPResponse) error { return status.ErrorProto(&spb.Status{ Code: resp.Code, - Message: string(resp.Body), + Message: msg, Details: []*types.Any{a}, }) } diff --git a/httpgrpc/server/server.go b/httpgrpc/server/server.go index b73c5a0f7..6a831dac0 100644 --- a/httpgrpc/server/server.go +++ b/httpgrpc/server/server.go @@ -26,12 +26,22 @@ import ( ) var ( - // DoNotLogErrorHeaderKey is a header key used for marking non-loggable errors. More precisely, if an HTTP response + // DoNotLogErrorHeaderKey is a header name used for marking non-loggable errors. More precisely, if an HTTP response // has a status code 5xx, and contains a header with key DoNotLogErrorHeaderKey and any values, the generated error // will be marked as non-loggable. DoNotLogErrorHeaderKey = http.CanonicalHeaderKey("X-DoNotLogError") + + // ErrorMessageHeaderKey is a header name for header that contains error message that should be used when Server.Handle + // (httpgrpc.HTTP/Handle implementation) decides to return the response as an error, using status.ErrorProto. + // Normally Server.Handle would use entire response body as a error message, but Message field of rcp.Status object + // is a string, and if body contains non-utf8 bytes, marshalling of this object will fail. + ErrorMessageHeaderKey = http.CanonicalHeaderKey("X-ErrorMessage") ) +type contextType int + +const handledByHttpgrpcServer contextType = 0 + type Option func(*Server) func WithReturn4XXErrors(s *Server) { @@ -59,6 +69,8 @@ func NewServer(handler http.Handler, opts ...Option) *Server { // Handle implements HTTPServer. func (s Server) Handle(ctx context.Context, r *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) { + ctx = context.WithValue(ctx, handledByHttpgrpcServer, true) + req, err := httpgrpc.ToHTTPRequest(ctx, r) if err != nil { return nil, err @@ -74,13 +86,24 @@ func (s Server) Handle(ctx context.Context, r *httpgrpc.HTTPRequest) (*httpgrpc. header.Del(DoNotLogErrorHeaderKey) // remove before converting to httpgrpc resp } + errorMessageFromHeader := "" + if msg, ok := header[ErrorMessageHeaderKey]; ok { + errorMessageFromHeader = msg[0] + header.Del(ErrorMessageHeaderKey) // remove before converting to httpgrpc resp + } + resp := &httpgrpc.HTTPResponse{ Code: int32(recorder.Code), Headers: httpgrpc.FromHeader(header), Body: recorder.Body.Bytes(), } if s.shouldReturnError(resp) { - err := httpgrpc.ErrorFromHTTPResponse(resp) + var err error + if errorMessageFromHeader != "" { + err = httpgrpc.ErrorFromHTTPResponseWithMessage(resp, errorMessageFromHeader) + } else { + err = httpgrpc.ErrorFromHTTPResponse(resp) + } if doNotLogError { err = middleware.DoNotLogError{Err: err} } @@ -206,3 +229,13 @@ func (c *Client) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } } + +// IsHandledByHttpgrpcServer returns true if context is associated with HTTP request that was initiated by +// Server.Handle, which is an implementation of httpgrpc.HTTP/Handle gRPC method. +func IsHandledByHttpgrpcServer(ctx context.Context) bool { + val := ctx.Value(handledByHttpgrpcServer) + if v, ok := val.(bool); ok { + return v + } + return false +} diff --git a/httpgrpc/server/server_test.go b/httpgrpc/server/server_test.go index 50b827a20..43bedc024 100644 --- a/httpgrpc/server/server_test.go +++ b/httpgrpc/server/server_test.go @@ -12,6 +12,7 @@ import ( "net" "net/http" "net/http/httptest" + "strconv" "testing" opentracing "github.com/opentracing/opentracing-go" @@ -326,3 +327,79 @@ func TestTracePropagation(t *testing.T) { assert.Equal(t, "world", recorder.Body.String()) assert.Equal(t, 200, recorder.Code) } + +func TestGrpcErrorsHaveCorrectMessage(t *testing.T) { + testCases := map[string]struct { + responseBody string + errorMessageInHeader string + + expectedErrorMessage string + }{ + "error response with string body": { + responseBody: "hello world", + expectedErrorMessage: "rpc error: code = Code(500) desc = hello world", + }, + "error response with binary body": { + responseBody: "\x08\x08\x12\xc7\x03the request has been rejected", + expectedErrorMessage: "rpc error: code = Code(500) desc = \x08\x08\x12\xc7\x03the request has been rejected", + }, + "error response with binary body and provided message via header": { + responseBody: "\x08\x08\x12\xc7\x03the request has been rejected", + errorMessageInHeader: "hello world", + expectedErrorMessage: "rpc error: code = Code(500) desc = hello world", + }, + } + for testName, testData := range testCases { + t.Run(testName, func(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if testData.errorMessageInHeader != "" { + w.Header().Set(ErrorMessageHeaderKey, testData.errorMessageInHeader) + } + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(testData.responseBody)) + }) + + s := NewServer(h) + req := &httpgrpc.HTTPRequest{Method: "GET", Url: "/test"} + resp, err := s.Handle(context.Background(), req) + require.Error(t, err) + require.Nil(t, resp) + + require.Equal(t, testData.expectedErrorMessage, err.Error()) + + httpResp, ok := httpgrpc.HTTPResponseFromError(err) + require.True(t, ok) + // Verify that header was removed + require.Empty(t, httpResp.Headers) + }) + } +} + +func TestIsHandledByHttpgrpcServer(t *testing.T) { + t.Run("false by default", func(t *testing.T) { + require.False(t, IsHandledByHttpgrpcServer(context.Background())) + }) + + const testHeader = "X-HandledByHttpgrpcServer" + + // Handler will return value returned by IsHandledByHttpgrpcServer in test header. + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(testHeader, strconv.FormatBool(IsHandledByHttpgrpcServer(r.Context()))) + w.WriteHeader(200) + }) + + t.Run("handler runs outside of httpgrpc.Server", func(t *testing.T) { + rec := httptest.NewRecorder() + handler(rec, &http.Request{}) + require.Equal(t, "false", rec.Header().Get(testHeader)) + }) + + t.Run("handler runs from httpgrpc.Server", func(t *testing.T) { + s := NewServer(handler) + resp, err := s.Handle(context.Background(), &httpgrpc.HTTPRequest{Method: "GET", Url: "/test"}) + require.NoError(t, err) + require.NotNil(t, resp) + + require.Equal(t, []*httpgrpc.Header{{Key: http.CanonicalHeaderKey(testHeader), Values: []string{"true"}}}, resp.Headers) + }) +}