diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ed3bf2dd..c86077c15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -99,6 +99,8 @@ * [FEATURE] Add methods `Increment`, `FlushAll`, `CompareAndSwap`, `Touch` to `cache.MemcachedClient` #477 * [FEATURE] Add `concurrency.ForEachJobMergeResults()` utility function. #486 * [FEATURE] Add `ring.DoMultiUntilQuorumWithoutSuccessfulContextCancellation()`. #495 +* [FEATURE] Add `middleware.ClusterUnaryClientInterceptor`, a `grpc.UnaryClientInterceptor` that propagates a cluster info to the outgoing gRPC metadata. #640 +* [FEATURE] Add `middleware.ClusterUnaryServerInterceptor`, a `grpc.UnaryServerInterceptor` that checks if the incoming gRPC metadata contains a correct cluster info, and returns an error if it is not the case. #640 * [FEATURE] Add `ring.GetWithOptions()` method to support additional features at a per-call level. #632 * [ENHANCEMENT] Add option to hide token information in ring status page #633 * [ENHANCEMENT] Display token information in partition ring status page #631 diff --git a/middleware/grpc_cluster.go b/middleware/grpc_cluster.go new file mode 100644 index 000000000..a705e7550 --- /dev/null +++ b/middleware/grpc_cluster.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "context" + "fmt" + + "github.com/go-kit/log" + "github.com/go-kit/log/level" + "github.com/gogo/status" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +const ( + MetadataClusterKey = "x-cluster" +) + +// ClusterUnaryClientInterceptor propagates the given cluster info to gRPC metadata. +func ClusterUnaryClientInterceptor(cluster string) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if cluster != "" { + ctx = metadata.AppendToOutgoingContext(ctx, MetadataClusterKey, cluster) + } + + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +// ClusterUnaryServerInterceptor checks if the incoming gRPC metadata contains any cluster information and if so, +// checks if the latter corresponds to the given cluster. If it is the case, the request is further propagated. +// Otherwise, an error is returned. +func ClusterUnaryServerInterceptor(cluster string, logger log.Logger) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + reqCluster := getClusterFromIncomingContext(ctx, logger) + if cluster != reqCluster { + msg := fmt.Sprintf("request intended for cluster %q - this is cluster %q", reqCluster, cluster) + level.Warn(logger).Log("msg", msg) + return nil, status.Error(codes.FailedPrecondition, msg) + } + return handler(ctx, req) + } +} + +func getClusterFromIncomingContext(ctx context.Context, logger log.Logger) string { + clusterIDs := metadata.ValueFromIncomingContext(ctx, MetadataClusterKey) + if len(clusterIDs) != 1 { + msg := fmt.Sprintf("gRPC metadata should contain exactly 1 value for key \"%s\", but the current set of values is %v. Returning an empty string.", MetadataClusterKey, clusterIDs) + level.Warn(logger).Log("msg", msg) + return "" + } + return clusterIDs[0] +} diff --git a/middleware/grpc_cluster_test.go b/middleware/grpc_cluster_test.go new file mode 100644 index 000000000..436bc2698 --- /dev/null +++ b/middleware/grpc_cluster_test.go @@ -0,0 +1,139 @@ +package middleware + +import ( + "context" + "net/http" + "os" + "strings" + "testing" + + "github.com/go-kit/log" + "github.com/gogo/status" + + "github.com/grafana/dskit/httpgrpc" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +func TestClusterUnaryClientInterceptor(t *testing.T) { + testCases := map[string]struct { + cluster string + expectedClusterFromContext string + }{ + "no cluster info sets no cluster info in context": { + cluster: "", + expectedClusterFromContext: "", + }, + "if cluster info is set, it should be propagated to invoker": { + cluster: "cluster", + expectedClusterFromContext: "cluster", + }, + } + verify := func(ctx context.Context, expectedCluster string) { + md, ok := metadata.FromOutgoingContext(ctx) + require.True(t, ok) + clusterIDs, ok := md[MetadataClusterKey] + require.True(t, ok) + require.Len(t, clusterIDs, 1) + require.Equal(t, expectedCluster, clusterIDs[0]) + } + for testName, testCase := range testCases { + t.Run(testName, func(t *testing.T) { + interceptor := ClusterUnaryClientInterceptor(testCase.cluster) + invoker := func(ctx context.Context, _ string, _, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error { + if testCase.expectedClusterFromContext != "" { + verify(ctx, testCase.expectedClusterFromContext) + } + return nil + } + + err := interceptor(context.Background(), "GET", createRequest(t), nil, nil, invoker) + require.NoError(t, err) + }) + } +} + +func TestClusterUnaryServerInterceptor(t *testing.T) { + testCases := map[string]struct { + incomingContext context.Context + requestCluster string + serverCluster string + expectedError error + }{ + "equal request and server clusters give no error": { + incomingContext: createIncomingContext(true, "cluster"), + requestCluster: "cluster", + serverCluster: "cluster", + expectedError: nil, + }, + "different request and server clusters give rise to an error": { + incomingContext: createIncomingContext(true, "wrong-cluster"), + requestCluster: "wrong-cluster", + serverCluster: "cluster", + expectedError: status.Error(codes.FailedPrecondition, "request intended for cluster \"wrong-cluster\" - this is cluster \"cluster\""), + }, + "empty request cluster and non-empty server cluster give rise to an error": { + incomingContext: createIncomingContext(true, ""), + requestCluster: "", + serverCluster: "cluster", + expectedError: status.Error(codes.FailedPrecondition, "request intended for cluster \"\" - this is cluster \"cluster\""), + }, + "no request cluster and non-empty server cluster give rise to an error": { + incomingContext: createIncomingContext(false, ""), + requestCluster: "", + serverCluster: "cluster", + expectedError: status.Error(codes.FailedPrecondition, "request intended for cluster \"\" - this is cluster \"cluster\""), + }, + "empty request cluster and empty server cluster give no error": { + incomingContext: createIncomingContext(true, ""), + requestCluster: "", + serverCluster: "", + expectedError: nil, + }, + "no request cluster and empty server cluster give no error": { + incomingContext: createIncomingContext(false, ""), + requestCluster: "", + serverCluster: "", + expectedError: nil, + }, + } + for testName, testCase := range testCases { + t.Run(testName, func(t *testing.T) { + logger := log.NewLogfmtLogger(os.Stdin) + interceptor := ClusterUnaryServerInterceptor(testCase.serverCluster, logger) + handler := func(context.Context, interface{}) (interface{}, error) { + return nil, nil + } + info := &grpc.UnaryServerInfo{FullMethod: "/Test/Me"} + req := createRequest(t) + _, err := interceptor(testCase.incomingContext, req, info, handler) + if testCase.expectedError == nil { + require.NoError(t, err) + } else { + require.Equal(t, testCase.expectedError, err) + } + }) + } +} + +func createIncomingContext(containsRequestCluster bool, requestCluster string) context.Context { + ctx := context.Background() + if !containsRequestCluster { + return ctx + } + md := map[string][]string{ + MetadataClusterKey: {requestCluster}, + } + return metadata.NewIncomingContext(ctx, md) +} + +func createRequest(t *testing.T) *httpgrpc.HTTPRequest { + r, err := http.NewRequest("POST", "/i/am/calling/you", strings.NewReader("some body")) + require.NoError(t, err) + req, err := httpgrpc.FromHTTPRequest(r) + require.NoError(t, err) + return req +}