Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding grpc_cluster client and server middleware #640

Merged
merged 8 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@
* [FEATURE] Add methods `Increment`, `FlushAll`, `CompareAndSwap`, `Touch` to `cache.MemcachedClient` #477
* [FEATURE] Add `concurrency.ForEachJobMergeResults()` utility function. #486
* [FEATURE] Add `ring.DoMultiUntilQuorumWithoutSuccessfulContextCancellation()`. #495
* [FEATURE] Add `middleware.ClusterUnaryClientInterceptor`, a `grpc.UnaryClientInterceptor` that propagates a cluster info to the outgoing gRPC metadata. #640
* [FEATURE] Add `middleware.ClusterUnaryServerInterceptor`, a `grpc.UnaryServerInterceptor` that checks if the incoming gRPC metadata contains a correct cluster info, and returns an error if it is not the case. #640
* [FEATURE] Add `ring.GetWithOptions()` method to support additional features at a per-call level. #632
* [ENHANCEMENT] Add option to hide token information in ring status page #633
* [ENHANCEMENT] Display token information in partition ring status page #631
Expand Down
53 changes: 53 additions & 0 deletions middleware/grpc_cluster.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package middleware

import (
"context"
"fmt"

"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/gogo/status"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
)

const (
MetadataClusterKey = "x-cluster"
)

// ClusterUnaryClientInterceptor propagates the given cluster info to gRPC metadata.
func ClusterUnaryClientInterceptor(cluster string) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if cluster != "" {
ctx = metadata.AppendToOutgoingContext(ctx, MetadataClusterKey, cluster)
}

return invoker(ctx, method, req, reply, cc, opts...)
}
}

// ClusterUnaryServerInterceptor checks if the incoming gRPC metadata contains any cluster information and if so,
// checks if the latter corresponds to the given cluster. If it is the case, the request is further propagated.
// Otherwise, an error is returned.
func ClusterUnaryServerInterceptor(cluster string, logger log.Logger) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
reqCluster := getClusterFromIncomingContext(ctx, logger)
if cluster != reqCluster {
msg := fmt.Sprintf("request intended for cluster %q - this is cluster %q", reqCluster, cluster)
level.Warn(logger).Log("msg", msg)
return nil, status.Error(codes.FailedPrecondition, msg)
}
return handler(ctx, req)
}
}

func getClusterFromIncomingContext(ctx context.Context, logger log.Logger) string {
clusterIDs := metadata.ValueFromIncomingContext(ctx, MetadataClusterKey)
if len(clusterIDs) != 1 {
msg := fmt.Sprintf("gRPC metadata should contain exactly 1 value for key \"%s\", but the current set of values is %v. Returning an empty string.", MetadataClusterKey, clusterIDs)
level.Warn(logger).Log("msg", msg)
return ""
}
return clusterIDs[0]
}
139 changes: 139 additions & 0 deletions middleware/grpc_cluster_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package middleware

import (
"context"
"net/http"
"os"
"strings"
"testing"

"github.com/go-kit/log"
"github.com/gogo/status"

"github.com/grafana/dskit/httpgrpc"

"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
)

func TestClusterUnaryClientInterceptor(t *testing.T) {
testCases := map[string]struct {
cluster string
expectedClusterFromContext string
}{
"no cluster info sets no cluster info in context": {
cluster: "",
expectedClusterFromContext: "",
},
"if cluster info is set, it should be propagated to invoker": {
cluster: "cluster",
expectedClusterFromContext: "cluster",
},
}
verify := func(ctx context.Context, expectedCluster string) {
md, ok := metadata.FromOutgoingContext(ctx)
require.True(t, ok)
clusterIDs, ok := md[MetadataClusterKey]
require.True(t, ok)
require.Len(t, clusterIDs, 1)
require.Equal(t, expectedCluster, clusterIDs[0])
}
for testName, testCase := range testCases {
t.Run(testName, func(t *testing.T) {
interceptor := ClusterUnaryClientInterceptor(testCase.cluster)
invoker := func(ctx context.Context, _ string, _, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
if testCase.expectedClusterFromContext != "" {
verify(ctx, testCase.expectedClusterFromContext)
}
duricanikolic marked this conversation as resolved.
Show resolved Hide resolved
return nil
}

err := interceptor(context.Background(), "GET", createRequest(t), nil, nil, invoker)
require.NoError(t, err)
})
}
}

func TestClusterUnaryServerInterceptor(t *testing.T) {
testCases := map[string]struct {
incomingContext context.Context
requestCluster string
serverCluster string
expectedError error
}{
"equal request and server clusters give no error": {
incomingContext: createIncomingContext(true, "cluster"),
requestCluster: "cluster",
serverCluster: "cluster",
expectedError: nil,
},
"different request and server clusters give rise to an error": {
incomingContext: createIncomingContext(true, "wrong-cluster"),
requestCluster: "wrong-cluster",
serverCluster: "cluster",
expectedError: status.Error(codes.FailedPrecondition, "request intended for cluster \"wrong-cluster\" - this is cluster \"cluster\""),
},
"empty request cluster and non-empty server cluster give rise to an error": {
incomingContext: createIncomingContext(true, ""),
requestCluster: "",
serverCluster: "cluster",
expectedError: status.Error(codes.FailedPrecondition, "request intended for cluster \"\" - this is cluster \"cluster\""),
},
"no request cluster and non-empty server cluster give rise to an error": {
incomingContext: createIncomingContext(false, ""),
requestCluster: "",
serverCluster: "cluster",
expectedError: status.Error(codes.FailedPrecondition, "request intended for cluster \"\" - this is cluster \"cluster\""),
},
"empty request cluster and empty server cluster give no error": {
incomingContext: createIncomingContext(true, ""),
requestCluster: "",
serverCluster: "",
expectedError: nil,
},
"no request cluster and empty server cluster give no error": {
incomingContext: createIncomingContext(false, ""),
requestCluster: "",
serverCluster: "",
expectedError: nil,
},
duricanikolic marked this conversation as resolved.
Show resolved Hide resolved
}
for testName, testCase := range testCases {
t.Run(testName, func(t *testing.T) {
logger := log.NewLogfmtLogger(os.Stdin)
interceptor := ClusterUnaryServerInterceptor(testCase.serverCluster, logger)
handler := func(context.Context, interface{}) (interface{}, error) {
return nil, nil
}
info := &grpc.UnaryServerInfo{FullMethod: "/Test/Me"}
req := createRequest(t)
_, err := interceptor(testCase.incomingContext, req, info, handler)
if testCase.expectedError == nil {
require.NoError(t, err)
} else {
require.Equal(t, testCase.expectedError, err)
}
})
}
}

func createIncomingContext(containsRequestCluster bool, requestCluster string) context.Context {
ctx := context.Background()
if !containsRequestCluster {
return ctx
}
md := map[string][]string{
MetadataClusterKey: {requestCluster},
}
return metadata.NewIncomingContext(ctx, md)
}

func createRequest(t *testing.T) *httpgrpc.HTTPRequest {
r, err := http.NewRequest("POST", "/i/am/calling/you", strings.NewReader("some body"))
require.NoError(t, err)
req, err := httpgrpc.FromHTTPRequest(r)
require.NoError(t, err)
return req
}