From 9423458b476545e417d7606f6371ce621b725674 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 8 Jan 2025 15:10:37 +0100 Subject: [PATCH] feat: add extensions.code to rate limiting error (#1027) --- v2/pkg/engine/resolve/context.go | 7 +++++++ v2/pkg/engine/resolve/loader.go | 27 +++++++++++++++++-------- v2/pkg/engine/resolve/ratelimit_test.go | 16 +++++++++++++++ 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 7a191e1d93..68871cb876 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -79,6 +79,13 @@ type RateLimitOptions struct { Period time.Duration RateLimitKey string RejectExceedingRequests bool + + ErrorExtensionCode RateLimitErrorExtensionCode +} + +type RateLimitErrorExtensionCode struct { + Enabled bool + Code string } type RateLimitDeny struct { diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 90a866eb9f..38bc8c7b1c 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1034,35 +1034,46 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result) error { l.ctx.appendSubgraphError(goerrors.Join(res.err, NewRateLimitError(res.ds.Name, fetchItem.ResponsePath, res.rateLimitRejectedReason))) pathPart := l.renderAtPathErrorPart(fetchItem.ResponsePath) + var ( + err error + errorObject *astjson.Value + ) if res.ds.Name == "" { if res.rateLimitRejectedReason == "" { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) + errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) if err != nil { return err } - astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) + errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } - astjson.AppendToArray(l.resolvable.errors, errorObject) } } else { if res.rateLimitRejectedReason == "" { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) + errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) if err != nil { return err } - astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) + errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } - astjson.AppendToArray(l.resolvable.errors, errorObject) } } + if l.ctx.RateLimitOptions.ErrorExtensionCode.Enabled { + extension, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code)) + if err != nil { + return err + } + errorObject, _, err = astjson.MergeValuesWithPath(errorObject, extension, "extensions") + if err != nil { + return err + } + } + astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } diff --git a/v2/pkg/engine/resolve/ratelimit_test.go b/v2/pkg/engine/resolve/ratelimit_test.go index 47b8202dec..b7a6b00d41 100644 --- a/v2/pkg/engine/resolve/ratelimit_test.go +++ b/v2/pkg/engine/resolve/ratelimit_test.go @@ -95,6 +95,22 @@ func TestRateLimiter(t *testing.T) { assert.Equal(t, int64(1), limiter.rateLimitPreFetchCalls.Load()) } })) + t.Run("deny with code", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { + + limiter := &testRateLimiter{ + allowFn: func(ctx *Context, info *FetchInfo, input json.RawMessage) (*RateLimitDeny, error) { + return &RateLimitDeny{Reason: "rate limit exceeded"}, nil + }, + } + + res := generateTestFederationGraphQLResponse(t, ctrl) + + return res, &Context{ctx: context.Background(), Variables: nil, rateLimiter: limiter, RateLimitOptions: RateLimitOptions{Enable: true, ErrorExtensionCode: RateLimitErrorExtensionCode{Enabled: true, Code: "RATE_LIMIT_EXCEEDED"}}}, + `{"errors":[{"message":"Rate limit exceeded for Subgraph 'users' at Path 'query', Reason: rate limit exceeded.","extensions":{"code":"RATE_LIMIT_EXCEEDED"}},{"message":"Failed to fetch from Subgraph 'reviews' at Path 'query.me'.","extensions":{"errors":[{"message":"Failed to render Fetch Input","path":["me"]}]}},{"message":"Failed to fetch from Subgraph 'products' at Path 'query.me.reviews.@.product'.","extensions":{"errors":[{"message":"Failed to render Fetch Input","path":["me","reviews","@","product"]}]}}],"data":{"me":null}}`, + func(t *testing.T) { + assert.Equal(t, int64(1), limiter.rateLimitPreFetchCalls.Load()) + } + })) t.Run("err all", testFnWithError(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { limiter := &testRateLimiter{