From e4d7531fe9643c601e1d22195a537ed5e7ffdfdd Mon Sep 17 00:00:00 2001 From: Dave Freilich Date: Fri, 7 Feb 2025 11:22:50 +0200 Subject: [PATCH] fix: add functions, and query test --- router-tests/cache_warmup_test.go | 63 +++++++++++++++++++++++++++++++ router/core/graph_server.go | 3 +- router/core/graphql_prehandler.go | 22 +++++++++-- 3 files changed, 84 insertions(+), 4 deletions(-) diff --git a/router-tests/cache_warmup_test.go b/router-tests/cache_warmup_test.go index 4fde9cf4db..9637dc91cf 100644 --- a/router-tests/cache_warmup_test.go +++ b/router-tests/cache_warmup_test.go @@ -392,6 +392,69 @@ func TestCacheWarmup(t *testing.T) { }) }) + t.Run("cache warmup persisted operation with multiple operations works with safelist enabled", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithCacheWarmupConfig(&config.CacheWarmupConfiguration{ + Enabled: true, + Source: config.CacheWarmupSource{ + Filesystem: &config.CacheWarmupFileSystemSource{ + Path: "testenv/testdata/cache_warmup/json_po_multi_operations", + }, + }, + }), + }, + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.Safelist = config.EnableOperationConfiguration{Enabled: true} + }, + AssertCacheMetrics: &testenv.CacheMetricsAssertions{ + BaseGraphAssertions: testenv.CacheMetricsAssertion{ + QueryNormalizationMisses: 1, // 1x miss during first safelist call + QueryNormalizationHits: 1, // 1x hit during second safelist call + PersistedQueryNormalizationHits: 2, // 1x hit after warmup, when called with operation name. No hit from second request because of missing operation name, it recomputes it + PersistedQueryNormalizationMisses: 5, // 1x miss during warmup, 1 miss for first operation trying without operation name, 1 miss for second operation trying without operation name, 2x miss during safelist because went to normal query normalization cache + ValidationHits: 4, + ValidationMisses: 1, + PlanHits: 4, + PlanMisses: 1, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + header := make(http.Header) + header.Add("graphql-client-name", "my-client") + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + OperationName: []byte(`"A"`), + Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "724399f210ef3f16e6e5427a70bb9609ecea7297e99c3e9241d5912d04eabe60"}}`), + Header: header, + }) + require.NoError(t, err) + require.Equal(t, `{"data":{"a":{"id":1,"details":{"pets":null}}}}`, res.Body) + + res2, err2 := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "724399f210ef3f16e6e5427a70bb9609ecea7297e99c3e9241d5912d04eabe60"}}`), + Header: header, + }) + require.NoError(t, err2) + require.Equal(t, `{"data":{"a":{"id":1,"details":{"pets":null}}}}`, res2.Body) + + res3, err3 := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Header: header, + Query: "query A {\n a: employee(id: 1) {\n id\n details {\n pets {\n name\n }\n }\n }\n}\n\nquery B ($id: Int!) {\n b: employee(id: $id) {\n id\n details {\n pets {\n name\n }\n }\n }\n}", + }) + require.NoError(t, err3) + require.Equal(t, `{"data":{"a":{"id":1,"details":{"pets":null}}}}`, res3.Body) + + res4, err4 := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + OperationName: []byte(`"A"`), + Header: header, + Query: "query A {\n a: employee(id: 1) {\n id\n details {\n pets {\n name\n }\n }\n }\n}\n\nquery B ($id: Int!) {\n b: employee(id: $id) {\n id\n details {\n pets {\n name\n }\n }\n }\n}", + }) + require.NoError(t, err4) + require.Equal(t, `{"data":{"a":{"id":1,"details":{"pets":null}}}}`, res4.Body) + }) + }) + t.Run("cache warmup workers throttle", func(t *testing.T) { t.Parallel() logger, err := zap.NewDevelopment() diff --git a/router/core/graph_server.go b/router/core/graph_server.go index ace67e0266..4d92739db6 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -454,7 +454,8 @@ func (s *graphMux) buildOperationCaches(srv *graphServer) (computeSha256 bool, e } } } else if srv.securityConfiguration.Safelist.Enabled || srv.securityConfiguration.LogUnknownOperations.Enabled { - // In this case, we'll want to compute the sha256 for every operation + // In these case, we'll want to compute the sha256 for every operation, in order to check that the operation + // is present in the Persisted Operation cache computeSha256 = true } diff --git a/router/core/graphql_prehandler.go b/router/core/graphql_prehandler.go index 193d0f2364..7a10755a8f 100644 --- a/router/core/graphql_prehandler.go +++ b/router/core/graphql_prehandler.go @@ -389,6 +389,23 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { }) } +func (h *PreHandler) shouldComputeHash(operationKit *OperationKit) bool { + if h.computeOperationSha256 { + return true + } + hasPersistedHash := operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery != nil && operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash != "" + // If it already has a persisted hash attached to the request, then there is no need for us to compute it anew + // Otherwise, we only want to compute the hash (an expensive operation) if we're safelisting or logging unknown persisted operations + return !hasPersistedHash && (h.operationBlocker.SafelistEnabled || h.operationBlocker.LogUnknownOperationsEnabled) +} + +// shouldFetchPersistedOperation determines if we should fetch a persisted operation. The most intuitive case is if the +// operation is a persisted operation. However, we also want to fetch persisted operations if we're enabling safelisting +// and if we're logging unknown operations. This is because we want to check if the operation is already persisted in the cache +func (h *PreHandler) shouldFetchPersistedOperation(operationKit *OperationKit) bool { + return operationKit.parsedOperation.IsPersistedOperation || h.operationBlocker.SafelistEnabled || h.operationBlocker.LogUnknownOperationsEnabled +} + func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson.Parser, httpOperation *httpOperation) error { operationKit, err := h.operationProcessor.NewKit() if err != nil { @@ -431,8 +448,7 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson } // Compute the operation sha256 hash as soon as possible for observability reasons - hasPersistedHash := operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery != nil && operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash != "" - if h.computeOperationSha256 || !hasPersistedHash && (h.operationBlocker.SafelistEnabled || h.operationBlocker.LogUnknownOperationsEnabled) { + if h.shouldComputeHash(operationKit) { if err := operationKit.ComputeOperationSha256(); err != nil { return &httpGraphqlError{ message: fmt.Sprintf("error hashing operation: %s", err), @@ -463,7 +479,7 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson isApq bool ) - if operationKit.parsedOperation.IsPersistedOperation || h.operationBlocker.SafelistEnabled || h.operationBlocker.LogUnknownOperationsEnabled { + if h.shouldFetchPersistedOperation(operationKit) { ctx, span := h.tracer.Start(req.Context(), "Load Persisted Operation", trace.WithSpanKind(trace.SpanKindClient), trace.WithAttributes(requestContext.telemetry.traceAttrs...),