Skip to content

Commit

Permalink
Merge pull request #9515 from vegaprotocol/8979-XFF
Browse files Browse the repository at this point in the history
feat: add trusted proxy support to XFF check
  • Loading branch information
EVODelavega authored Sep 25, 2023
2 parents 048aa1e + 190168e commit 70d52d7
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 31 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
- [9489](https://github.com/vegaprotocol/vega/issues/9489) - A referrer cannot join another team.
- [9074](https://github.com/vegaprotocol/vega/issues/9074) - Fix error response for `CSV` exports.
- [9512](https://github.com/vegaprotocol/vega/issues/9512) - Allow hysteresis period to be set to 0.
- [8979](https://github.com/vegaprotocol/vega/issues/8979) - Add trusted proxy config and verification for `XFF` header.

## 0.72.1

Expand Down
2 changes: 2 additions & 0 deletions datanode/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (

// API Errors and descriptions.
var (
// ErrNoTrustedProxy indactes a forwarded request that did not pass through a trusted proxy.
ErrNoTrustedProxy = errors.New("forwarded requests need to pass through a trusted proxy")
// ErrChannelClosed signals that the channel streaming data is closed.
ErrChannelClosed = errors.New("channel closed")
// ErrNotAValidVegaID signals an invalid id.
Expand Down
65 changes: 46 additions & 19 deletions datanode/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ import (
"github.com/fullstorydev/grpcui/standalone"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/status"
)

// EventService ...
Expand Down Expand Up @@ -136,6 +138,8 @@ type GRPCServer struct {
// used in order to gracefully close streams
ctx context.Context
cfunc context.CancelFunc

trustedProxies map[string]struct{}
}

// NewGRPCServer create a new instance of the GPRC api for the vega node.
Expand Down Expand Up @@ -192,6 +196,10 @@ func NewGRPCServer(
log = log.Named(namedLogger)
log.SetLevel(config.Level.Get())
ctx, cfunc := context.WithCancel(context.Background())
tps := make(map[string]struct{}, len(config.RateLimit.TrustedProxies))
for _, ip := range config.RateLimit.TrustedProxies {
tps[ip] = struct{}{}
}

return &GRPCServer{
log: log,
Expand Down Expand Up @@ -247,8 +255,9 @@ func NewGRPCServer(
eventService: eventService,
Config: config,
},
ctx: ctx,
cfunc: cfunc,
ctx: ctx,
cfunc: cfunc,
trustedProxies: tps,
}
}

Expand All @@ -262,24 +271,39 @@ func (g *GRPCServer) ReloadConf(cfg Config) {
)
g.log.SetLevel(cfg.Level.Get())
}
tps := make(map[string]struct{}, len(cfg.RateLimit.TrustedProxies))
for _, ip := range cfg.RateLimit.TrustedProxies {
tps[ip] = struct{}{}
}

// TODO(): not updating the actual server for now, may need to look at this later
// e.g restart the http server on another port or whatever
g.Config = cfg
g.trustedProxies = tps
}

func ipFromContext(ctx context.Context, method string, log *logging.Logger) string {
func (g *GRPCServer) ipFromContext(ctx context.Context, method string, log *logging.Logger) (string, error) {
// first check if the request is forwarded from our restproxy
// get the metadata
md, ok := metadata.FromIncomingContext(ctx)
if ok {
forwardedFor, ok := md["x-forwarded-for"]
if ok && len(forwardedFor) > 0 {
log.Debug("grpc request x-forwarded-for",
logging.String("method", method),
logging.String("remote-ip-addr", forwardedFor[0]),
)
return forwardedFor[0]
tps := g.trustedProxies
if len(tps) > 0 {
// get the metadata
if md, ok := metadata.FromIncomingContext(ctx); ok {
// if trusted proxies are specified, the XFF header will be used to rate-limit the IP
// for which the request is forwarded. If no proxies are specified, or no trusted proxies
// are found, the peer is rate limited.
if forwardedFor, ok := md["x-forwarded-for"]; ok && len(forwardedFor) >= 2 {
// check the proxies for trusted
for _, pip := range forwardedFor[1:] {
// trusted proxy found, return
if _, ok := tps[pip]; ok {
log.Debug("grpc request x-forwarded-for",
logging.String("method", method),
logging.String("remote-ip-addr", forwardedFor[0]),
)
return forwardedFor[0], nil
}
}
}
}
}

Expand All @@ -289,20 +313,23 @@ func ipFromContext(ctx context.Context, method string, log *logging.Logger) stri
log.Debug("grpc peer client request",
logging.String("method", method),
logging.String("remote-ip-addr", p.Addr.String()))
return p.Addr.String()
return p.Addr.String(), nil
}

return ""
return "", nil
}

func remoteAddrInterceptor(log *logging.Logger) grpc.UnaryServerInterceptor {
func (g *GRPCServer) remoteAddrInterceptor(log *logging.Logger) grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (resp interface{}, err error) {
ip := ipFromContext(ctx, info.FullMethod, log)
ip, err := g.ipFromContext(ctx, info.FullMethod, log)
if err != nil {
return nil, status.Error(codes.PermissionDenied, err.Error())
}

ctx = contextutil.WithRemoteIPAddr(ctx, ip)

Expand Down Expand Up @@ -385,12 +412,12 @@ func (g *GRPCServer) Start(ctx context.Context, lis net.Listener) error {

rateLimit := ratelimit.NewFromConfig(&g.RateLimit, g.log)
intercept := grpc.ChainUnaryInterceptor(
remoteAddrInterceptor(g.log),
g.remoteAddrInterceptor(g.log),
headersInterceptor(g.blockService.GetLastBlock, g.log),
rateLimit.GRPCInterceptor,
)

streamIntercept := grpc.StreamInterceptor(subscriptionRateLimiter.WithGrpcInterceptor(ipFromContext))
streamIntercept := grpc.StreamInterceptor(subscriptionRateLimiter.WithGrpcInterceptor(g.ipFromContext))

g.srv = grpc.NewServer(intercept, streamIntercept)

Expand Down
7 changes: 5 additions & 2 deletions datanode/gateway/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,14 @@ func (s *SubscriptionRateLimiter) WithSubscriptionRateLimiter(next http.Handler)
})
}

type ipGetter func(ctx context.Context, method string, log *logging.Logger) string
type ipGetter func(ctx context.Context, method string, log *logging.Logger) (string, error)

func (s *SubscriptionRateLimiter) WithGrpcInterceptor(ipGetterFunc ipGetter) grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
addr := ipGetterFunc(ss.Context(), info.FullMethod, s.log)
addr, err := ipGetterFunc(ss.Context(), info.FullMethod, s.log)
if err != nil {
return status.Error(codes.PermissionDenied, err.Error())
}
if addr == "" {
// If we don't have an IP we can't rate limit
return handler(srv, ss)
Expand Down
1 change: 1 addition & 0 deletions datanode/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ func newTestConfig(postgresRuntimePath string) (*config.Config, error) {
}

cfg := config.NewDefaultConfig()
// cfg.API.RateLimit.TrustedProxies = []string{}
cfg.Broker.UseEventFile = true
cfg.Broker.PanicOnError = true
cfg.Broker.FileEventSourceConfig.Directory = filepath.Join(cwd, eventsDir)
Expand Down
22 changes: 12 additions & 10 deletions datanode/ratelimit/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@ import (
)

type Config struct {
Enabled bool `description:"Enable rate limit of API requests per IP address. Based on a 'token bucket' algorithm" long:"enabled"`
Rate float64 `description:"Refill rate of token bucket; maximum average request rate" long:"rate"`
Burst int `description:"Size of token bucket; maximum number of requests in short time window" long:"burst"`
TTL encoding.Duration `description:"Time after which inactive token buckets are reset" long:"ttl"`
BanFor encoding.Duration `description:"If IP continues to make requests after passing rate limit threshold, ban for this duration. Setting to 0 seconds disables banning." long:"banfor"`
Enabled bool `description:"Enable rate limit of API requests per IP address. Based on a 'token bucket' algorithm" long:"enabled"`
TrustedProxies []string `description:"specify a trusted proxy for forwarded requests" long:"trusted-proxy"`
Rate float64 `description:"Refill rate of token bucket; maximum average request rate" long:"rate"`
Burst int `description:"Size of token bucket; maximum number of requests in short time window" long:"burst"`
TTL encoding.Duration `description:"Time after which inactive token buckets are reset" long:"ttl"`
BanFor encoding.Duration `description:"If IP continues to make requests after passing rate limit threshold, ban for this duration. Setting to 0 seconds disables banning." long:"banfor"`
}

func NewDefaultConfig() Config {
return Config{
Enabled: true,
Rate: 20,
Burst: 100,
TTL: encoding.Duration{Duration: time.Hour},
BanFor: encoding.Duration{Duration: 10 * time.Minute},
Enabled: true,
TrustedProxies: []string{"127.0.0.1"},
Rate: 20,
Burst: 100,
TTL: encoding.Duration{Duration: time.Hour},
BanFor: encoding.Duration{Duration: 10 * time.Minute},
}
}

0 comments on commit 70d52d7

Please sign in to comment.