diff --git a/experimental/experimental.go b/experimental/experimental.go index 719692636505..3ba948bab316 100644 --- a/experimental/experimental.go +++ b/experimental/experimental.go @@ -62,3 +62,12 @@ func WithBufferPool(bufferPool mem.BufferPool) grpc.DialOption { func BufferPool(bufferPool mem.BufferPool) grpc.ServerOption { return internal.BufferPool.(func(mem.BufferPool) grpc.ServerOption)(bufferPool) } + +// AcceptCompressors returns a CallOption that limits the values +// advertised in the grpc-accept-encoding header for the provided RPC. The +// supplied names must correspond to compressors registered via +// encoding.RegisterCompressor. Passing no names advertises "identity" (no +// compression) only. +func AcceptCompressors(names ...string) grpc.CallOption { + return internal.AcceptCompressors.(func(...string) grpc.CallOption)(names...) +} diff --git a/internal/experimental.go b/internal/experimental.go index 7617be215895..c90cc51bdd2b 100644 --- a/internal/experimental.go +++ b/internal/experimental.go @@ -25,4 +25,8 @@ var ( // BufferPool is implemented by the grpc package and returns a server // option to configure a shared buffer pool for a grpc.Server. BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption + + // AcceptCompressors is implemented by the grpc package and returns + // a call option that restricts the grpc-accept-encoding header for a call. + AcceptCompressors any // func(...string) grpc.CallOption ) diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 65b4ab2439e2..19c9b1ebad0b 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -551,6 +551,9 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te hfLen += len(authData) + len(callAuthData) registeredCompressors := t.registeredCompressors + if callHdr.AcceptedCompressors != nil { + registeredCompressors = *callHdr.AcceptedCompressors + } if callHdr.PreviousAttempts > 0 { hfLen++ } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 5ff83a7d7d74..e1e466698e34 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -553,6 +553,12 @@ type CallHdr struct { // outbound message. SendCompress string + // AcceptedCompressors overrides the grpc-accept-encoding header for this + // call. When nil, the transport advertises the default set of registered + // compressors. A non-nil pointer overrides that value (including the empty + // string to advertise none). + AcceptedCompressors *string + // Creds specifies credentials.PerRPCCredentials for a call. Creds credentials.PerRPCCredentials diff --git a/rpc_util.go b/rpc_util.go index 6b04c9e87357..32dddc68ab42 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -33,6 +33,8 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding/proto" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/mem" "google.golang.org/grpc/metadata" @@ -41,6 +43,10 @@ import ( "google.golang.org/grpc/status" ) +func init() { + internal.AcceptCompressors = AcceptCompressors +} + // Compressor defines the interface gRPC uses to compress a message. // // Deprecated: use package encoding. @@ -151,16 +157,32 @@ func (d *gzipDecompressor) Type() string { // callInfo contains all related configuration and information about an RPC. type callInfo struct { - compressorName string - failFast bool - maxReceiveMessageSize *int - maxSendMessageSize *int - creds credentials.PerRPCCredentials - contentSubtype string - codec baseCodec - maxRetryRPCBufferSize int - onFinish []func(err error) - authority string + compressorName string + failFast bool + maxReceiveMessageSize *int + maxSendMessageSize *int + creds credentials.PerRPCCredentials + contentSubtype string + codec baseCodec + maxRetryRPCBufferSize int + onFinish []func(err error) + authority string + acceptedResponseCompressors []string +} + +func acceptedCompressorAllows(allowed []string, name string) bool { + if allowed == nil { + return true + } + if name == "" || name == encoding.Identity { + return true + } + for _, a := range allowed { + if a == name { + return true + } + } + return false } func defaultCallInfo() *callInfo { @@ -170,6 +192,29 @@ func defaultCallInfo() *callInfo { } } +func newAcceptedCompressionConfig(names []string) ([]string, error) { + if len(names) == 0 { + return nil, nil + } + var allowed []string + seen := make(map[string]struct{}, len(names)) + for _, name := range names { + name = strings.TrimSpace(name) + if name == "" || name == encoding.Identity { + continue + } + if !grpcutil.IsCompressorNameRegistered(name) { + return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name) + } + if _, dup := seen[name]; dup { + continue + } + seen[name] = struct{}{} + allowed = append(allowed, name) + } + return allowed, nil +} + // CallOption configures a Call before it starts or extracts information from // a Call after it completes. type CallOption interface { @@ -471,6 +516,31 @@ func (o CompressorCallOption) before(c *callInfo) error { } func (o CompressorCallOption) after(*callInfo, *csAttempt) {} +// AcceptCompressors returns a CallOption that limits the compression algorithms +// advertised in the grpc-accept-encoding header for response messages. +// Compression algorithms not in the provided list will not be advertised, and +// responses compressed with non-listed algorithms will be rejected. +func AcceptCompressors(names ...string) CallOption { + cp := append([]string(nil), names...) + return AcceptCompressorsCallOption{names: cp} +} + +// AcceptCompressorsCallOption is a CallOption that limits response compression. +type AcceptCompressorsCallOption struct { + names []string +} + +func (o AcceptCompressorsCallOption) before(c *callInfo) error { + allowed, err := newAcceptedCompressionConfig(o.names) + if err != nil { + return err + } + c.acceptedResponseCompressors = allowed + return nil +} + +func (AcceptCompressorsCallOption) after(*callInfo, *csAttempt) {} + // CallContentSubtype returns a CallOption that will set the content-subtype // for a call. For example, if content-subtype is "json", the Content-Type over // the wire will be "application/grpc+json". The content-subtype is converted @@ -857,8 +927,7 @@ func (p *payloadInfo) free() { // the buffer is no longer needed. // TODO: Refactor this function to reduce the number of arguments. // See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists -func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, -) (out mem.BufferSlice, err error) { +func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) (out mem.BufferSlice, err error) { pf, compressed, err := p.recvMsg(maxReceiveMessageSize) if err != nil { return nil, err diff --git a/rpc_util_test.go b/rpc_util_test.go index a5c5cb8b17e2..79628d1be1d1 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -48,6 +48,118 @@ const ( decompressionErrorMsg = "invalid compression format" ) +type testCompressorForRegistry struct { + name string +} + +func (c *testCompressorForRegistry) Compress(w io.Writer) (io.WriteCloser, error) { + return &testWriteCloser{w}, nil +} + +func (c *testCompressorForRegistry) Decompress(r io.Reader) (io.Reader, error) { + return r, nil +} + +func (c *testCompressorForRegistry) Name() string { + return c.name +} + +type testWriteCloser struct { + io.Writer +} + +func (w *testWriteCloser) Close() error { + return nil +} + +func (s) TestNewAcceptedCompressionConfig(t *testing.T) { + // Register a test compressor for multi-compressor tests + testCompressor := &testCompressorForRegistry{name: "test-compressor"} + encoding.RegisterCompressor(testCompressor) + defer func() { + // Unregister the test compressor + encoding.RegisterCompressor(&testCompressorForRegistry{name: "test-compressor"}) + }() + + tests := []struct { + name string + input []string + wantAllowed []string + wantErr bool + }{ + { + name: "identity-only", + input: nil, + wantAllowed: nil, + }, + { + name: "single valid", + input: []string{"gzip"}, + wantAllowed: []string{"gzip"}, + }, + { + name: "dedupe and trim", + input: []string{" gzip ", "gzip"}, + wantAllowed: []string{"gzip"}, + }, + { + name: "ignores identity", + input: []string{"identity", "gzip"}, + wantAllowed: []string{"gzip"}, + }, + { + name: "explicit identity only", + input: []string{"identity"}, + wantAllowed: nil, + }, + { + name: "invalid compressor", + input: []string{"does-not-exist"}, + wantErr: true, + }, + { + name: "only whitespace", + input: []string{" ", "\t"}, + wantAllowed: nil, + }, + { + name: "multiple valid compressors", + input: []string{"gzip", "test-compressor"}, + wantAllowed: []string{"gzip", "test-compressor"}, + }, + { + name: "multiple with identity and whitespace", + input: []string{"gzip", "identity", " test-compressor ", " "}, + wantAllowed: []string{"gzip", "test-compressor"}, + }, + { + name: "empty string in list", + input: []string{"gzip", "", "test-compressor"}, + wantAllowed: []string{"gzip", "test-compressor"}, + }, + { + name: "mixed valid and invalid", + input: []string{"gzip", "invalid-comp"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + allowed, err := newAcceptedCompressionConfig(tt.input) + if (err != nil) != tt.wantErr { + t.Fatalf("newAcceptedCompressionConfig(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + if tt.wantErr { + return + } + if diff := cmp.Diff(tt.wantAllowed, allowed); diff != "" { + t.Fatalf("allowed diff (-want +got): %v", diff) + } + }) + } +} + type fullReader struct { data []byte } diff --git a/stream.go b/stream.go index ca87ff9776ef..e9e7adb54164 100644 --- a/stream.go +++ b/stream.go @@ -25,6 +25,7 @@ import ( "math" rand "math/rand/v2" "strconv" + "strings" "sync" "time" @@ -301,6 +302,10 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client DoneFunc: doneFunc, Authority: callInfo.authority, } + if allowed := callInfo.acceptedResponseCompressors; len(allowed) > 0 { + headerValue := strings.Join(allowed, ",") + callHdr.AcceptedCompressors = &headerValue + } // Set our outgoing compression according to the UseCompressor CallOption, if // set. In that case, also find the compressor from the encoding package. @@ -1134,6 +1139,10 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { a.decompressorV0 = nil a.decompressorV1 = encoding.GetCompressor(ct) } + // Validate that the compression method is acceptable for this call. + if !acceptedCompressorAllows(cs.callInfo.acceptedResponseCompressors, ct) { + return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct) + } } else { // No compression is used; disable our decompressor. a.decompressorV0 = nil @@ -1479,6 +1488,10 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { as.decompressorV0 = nil as.decompressorV1 = encoding.GetCompressor(ct) } + // Validate that the compression method is acceptable for this call. + if !acceptedCompressorAllows(as.callInfo.acceptedResponseCompressors, ct) { + return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct) + } } else { // No compression is used; disable our decompressor. as.decompressorV0 = nil diff --git a/test/compressor_test.go b/test/compressor_test.go index dbdc06222220..ebc42f2ede3b 100644 --- a/test/compressor_test.go +++ b/test/compressor_test.go @@ -30,6 +30,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/experimental" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -533,6 +535,57 @@ func (s) TestClientSupportedCompressors(t *testing.T) { } } +func (s) TestAcceptCompressorsCallOption(t *testing.T) { + tests := []struct { + name string + callOption grpc.CallOption + wantHeader string + }{ + { + name: "with AcceptCompressors", + callOption: experimental.AcceptCompressors("gzip"), + wantHeader: "gzip", + }, + { + name: "without AcceptCompressors uses default", + callOption: nil, + wantHeader: grpcutil.RegisteredCompressors(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + md, _ := metadata.FromIncomingContext(ctx) + header := md.Get("grpc-accept-encoding") + + if len(header) != 1 || header[0] != tt.wantHeader { + t.Errorf("unexpected grpc-accept-encoding header: got %v, want %v", header, tt.wantHeader) + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("failed to start server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + opts := []grpc.CallOption{} + if tt.callOption != nil { + opts = append(opts, tt.callOption) + } + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, opts...); err != nil { + t.Fatalf("EmptyCall failed: %v", err) + } + }) + } +} + func (s) TestCompressorRegister(t *testing.T) { for _, e := range listTestEnv() { testCompressorRegister(t, e)