diff --git a/examples/internal/integration/integration_test.go b/examples/internal/integration/integration_test.go index 3382b9483d1..954d2f040bc 100644 --- a/examples/internal/integration/integration_test.go +++ b/examples/internal/integration/integration_test.go @@ -53,6 +53,8 @@ func TestEcho(t *testing.T) { testEchoBody(t, 8089, apiPrefix, true) testEchoBody(t, 8089, apiPrefix, false) testEchoBodyParamOverwrite(t, 8088) + testEchoWithNonASCIIHeaderValues(t, 8088, apiPrefix) + testEchoWithInvalidHeaderKey(t, 8088, apiPrefix) }) } } @@ -2504,3 +2506,79 @@ func testABETrace(t *testing.T, port int) { return } } + +func testEchoWithNonASCIIHeaderValues(t *testing.T, port int, apiPrefix string) { + apiURL := fmt.Sprintf("http://localhost:%d/%s/example/echo/myid", port, apiPrefix) + + req, err := http.NewRequest("POST", apiURL, strings.NewReader("{}")) + if err != nil { + t.Errorf("http.NewRequest() = err: %v", err) + return + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Grpc-Metadata-Location", "Gjøvik") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("http.Post(%q) failed with %v; want success", apiURL, err) + return + } + defer resp.Body.Close() + + buf, err := io.ReadAll(resp.Body) + if err != nil { + t.Errorf("io.ReadAll(resp.Body) failed with %v; want success", err) + return + } + + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("resp.StatusCode = %d; want %d", got, want) + t.Logf("%s", buf) + } + + msg := new(examplepb.UnannotatedSimpleMessage) + if err := marshaler.Unmarshal(buf, msg); err != nil { + t.Errorf("marshaler.Unmarshal(%s, msg) failed with %v; want success", buf, err) + return + } + if got, want := msg.Id, "myid"; got != want { + t.Errorf("msg.Id = %q; want %q", got, want) + } +} + +func testEchoWithInvalidHeaderKey(t *testing.T, port int, apiPrefix string) { + apiURL := fmt.Sprintf("http://localhost:%d/%s/example/echo/myid", port, apiPrefix) + + req, err := http.NewRequest("POST", apiURL, strings.NewReader("{}")) + if err != nil { + t.Errorf("http.NewRequest() = err: %v", err) + return + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Grpc-Metadata-Foo+Bar", "Hello") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("http.Post(%q) failed with %v; want success", apiURL, err) + return + } + defer resp.Body.Close() + + buf, err := io.ReadAll(resp.Body) + if err != nil { + t.Errorf("io.ReadAll(resp.Body) failed with %v; want success", err) + return + } + + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("resp.StatusCode = %d; want %d", got, want) + t.Logf("%s", buf) + } + + msg := new(examplepb.UnannotatedSimpleMessage) + if err := marshaler.Unmarshal(buf, msg); err != nil { + t.Errorf("marshaler.Unmarshal(%s, msg) failed with %v; want success", buf, err) + return + } + if got, want := msg.Id, "myid"; got != want { + t.Errorf("msg.Id = %q; want %q", got, want) + } +} diff --git a/runtime/BUILD.bazel b/runtime/BUILD.bazel index b5140a3c9d1..195df460c39 100644 --- a/runtime/BUILD.bazel +++ b/runtime/BUILD.bazel @@ -26,6 +26,7 @@ go_library( deps = [ "//internal/httprule", "//utilities", + "@com_github_golang_glog//:glog", "@go_googleapis//google/api:httpbody_go_proto", "@io_bazel_rules_go//proto/wkt:field_mask_go_proto", "@org_golang_google_grpc//codes", diff --git a/runtime/context.go b/runtime/context.go index 5ab5b3841da..9956e7ef0bd 100644 --- a/runtime/context.go +++ b/runtime/context.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/golang/glog" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -99,6 +100,38 @@ func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Reque return metadata.NewIncomingContext(ctx, md), nil } +func isValidGRPCMetadataKey(key string) bool { + // Must be a valid gRPC "Header-Name" as defined here: + // https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md + // This means 0-9 a-z _ - . + // Only lowercase letters are valid in the wire protocol, but the client library will normalize + // uppercase ASCII to lowercase, so uppercase ASCII is also acceptable. + bytes := []byte(key) // gRPC validates strings on the byte level, not Unicode. + for _, ch := range bytes { + validLowercaseLetter := ch >= 'a' && ch <= 'z' + validUppercaseLetter := ch >= 'A' && ch <= 'Z' + validDigit := ch >= '0' && ch <= '9' + validOther := ch == '.' || ch == '-' || ch == '_' + if !validLowercaseLetter && !validUppercaseLetter && !validDigit && !validOther { + return false + } + } + return true +} + +func isValidGRPCMetadataTextValue(textValue string) bool { + // Must be a valid gRPC "ASCII-Value" as defined here: + // https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md + // This means printable ASCII (including/plus spaces); 0x20 to 0x7E inclusive. + bytes := []byte(textValue) // gRPC validates strings on the byte level, not Unicode. + for _, ch := range bytes { + if ch < 0x20 || ch > 0x7E { + return false + } + } + return true +} + func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, metadata.MD, error) { ctx = withRPCMethod(ctx, rpcMethodName) for _, o := range options { @@ -121,6 +154,10 @@ func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcM pairs = append(pairs, "authorization", val) } if h, ok := mux.incomingHeaderMatcher(key); ok { + if !isValidGRPCMetadataKey(h) { + glog.Errorf("HTTP header name %q is not valid as gRPC metadata key; skipping", h) + continue + } // Handles "-bin" metadata in grpc, since grpc will do another base64 // encode before sending to server, we need to decode it first. if strings.HasSuffix(key, metadataHeaderBinarySuffix) { @@ -130,6 +167,9 @@ func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcM } val = string(b) + } else if !isValidGRPCMetadataTextValue(val) { + glog.Errorf("Value of HTTP header %q contains non-ASCII value (not valid as gRPC metadata): skipping", h) + continue } pairs = append(pairs, h, val) }