From afc0cd6311f6037a56fdd9140db191410779c2be Mon Sep 17 00:00:00 2001 From: Dave Freilich Date: Tue, 11 Feb 2025 16:55:41 +0200 Subject: [PATCH] feat(router): make header matching rules case insensitive --- router-tests/headers_test.go | 23 +++++--- router/core/header_rule_engine.go | 12 ++-- router/core/header_rule_engine_test.go | 80 ++++++++++++++++++++++++-- 3 files changed, 97 insertions(+), 18 deletions(-) diff --git a/router-tests/headers_test.go b/router-tests/headers_test.go index 84b7e499cb..28aa9376b9 100644 --- a/router-tests/headers_test.go +++ b/router-tests/headers_test.go @@ -17,13 +17,15 @@ func TestForwardHeaders(t *testing.T) { const ( // Make sure you copy these to the struct tag in the subscription test - headerNameInGlobalRule = "foo" - headerNameInSubgraphRule = "barista" // This matches the regex in test1 subgraph forwarding rules - headerValue = "bar" - headerValue2 = "baz" - - subscriptionForGlobalRulePayload = `{"query": "subscription { headerValue(name:\"foo\", repeat:3) { value }}"}` - subscriptionForSubgraphRulePayload = `{"query": "subscription { headerValue(name:\"barista\", repeat:3) { value }}"}` + headerNameInGlobalRule = "foo" + headerNameInSubgraphRule = "barista" // This matches the regex in test1 subgraph forwarding rules + headerNameCaseInsensitiveRule = "bAz-CAse-Insensitive" // This matches the regex in test1 subgraph forwarding rules + headerValue = "bar" + headerValue2 = "baz" + + subscriptionForGlobalRulePayload = `{"query": "subscription { headerValue(name:\"foo\", repeat:3) { value }}"}` + subscriptionForSubgraphRulePayload = `{"query": "subscription { headerValue(name:\"barista\", repeat:3) { value }}"}` + subscriptionForSubgraphCaseRulePayload = `{"query": "subscription { headerValue(name:\"baz-case-insensitive\", repeat:3) { value }}"}` ) headerRules := config.HeaderRules{ @@ -42,6 +44,10 @@ func TestForwardHeaders(t *testing.T) { Operation: config.HeaderRuleOperationPropagate, Matching: "(?i)^bar.*", }, + { + Operation: config.HeaderRuleOperationPropagate, + Matching: "^baz-case-.*", + }, }, }, }, @@ -56,6 +62,7 @@ func TestForwardHeaders(t *testing.T) { }{ {headerNameInGlobalRule, "global rule"}, {headerNameInSubgraphRule, "subgraph rule"}, + {headerNameCaseInsensitiveRule, "subgraph rule"}, } testenv.Run(t, &testenv.Config{ ModifyEngineExecutionConfiguration: func(cfg *config.EngineExecutionConfiguration) { @@ -160,6 +167,7 @@ func TestForwardHeaders(t *testing.T) { testName string }{ {headerNameInSubgraphRule, subscriptionForSubgraphRulePayload, "subgraph rule"}, + {headerNameCaseInsensitiveRule, subscriptionForSubgraphCaseRulePayload, "subgraph case insensitive rule"}, } testenv.Run(t, &testenv.Config{ ModifyEngineExecutionConfiguration: func(cfg *config.EngineExecutionConfiguration) { @@ -227,6 +235,7 @@ func TestForwardHeaders(t *testing.T) { }{ {headerNameInGlobalRule, subscriptionForGlobalRulePayload, "global rule"}, {headerNameInSubgraphRule, subscriptionForSubgraphRulePayload, "subgraph rule"}, + {headerNameCaseInsensitiveRule, subscriptionForSubgraphCaseRulePayload, "subgraph case insensitive rule"}, } testenv.Run(t, &testenv.Config{ ModifyEngineExecutionConfiguration: func(cfg *config.EngineExecutionConfiguration) { diff --git a/router/core/header_rule_engine.go b/router/core/header_rule_engine.go index b3e7bcdc11..64d59ed07f 100644 --- a/router/core/header_rule_engine.go +++ b/router/core/header_rule_engine.go @@ -51,9 +51,10 @@ var ( "Sec-Websocket-Protocol", "Sec-Websocket-Version", } - cacheControlKey = "Cache-Control" - expiresKey = "Expires" - noCache = "no-cache" + cacheControlKey = "Cache-Control" + expiresKey = "Expires" + noCache = "no-cache" + caseInsensitiveRegexp = "(?i)" ) type responseHeaderPropagationKey struct{} @@ -204,7 +205,7 @@ func (hf *HeaderPropagation) processRule(rule config.HeaderRule, index int) erro case config.HeaderRuleOperationSet: case config.HeaderRuleOperationPropagate: if rule.GetMatching() != "" { - regex, err := regexp.Compile(rule.GetMatching()) + regex, err := regexp.Compile(caseInsensitiveRegexp + rule.GetMatching()) if err != nil { return fmt.Errorf("invalid regex '%s' for header rule %d: %w", rule.GetMatching(), index, err) } @@ -644,7 +645,8 @@ func PropagatedHeaders(rules []*config.RequestHeaderRule) (headerNames []string, headerNames = append(headerNames, rule.Name) case config.HeaderRuleOperationPropagate: if rule.Matching != "" { - re, err := regexp.Compile(rule.Matching) + // Header Names are case insensitive: https://www.w3.org/Protocols/rfc2616/rfc2616.html + re, err := regexp.Compile(caseInsensitiveRegexp + rule.Matching) if err != nil { return nil, nil, fmt.Errorf("error compiling regular expression %q in header rule %+v: %w", rule.Matching, rule, err) } diff --git a/router/core/header_rule_engine_test.go b/router/core/header_rule_engine_test.go index b024d34dd5..67c43e78d7 100644 --- a/router/core/header_rule_engine_test.go +++ b/router/core/header_rule_engine_test.go @@ -17,7 +17,6 @@ import ( func TestPropagateHeaderRule(t *testing.T) { t.Run("Should propagate with named header name / named", func(t *testing.T) { - ht, err := NewHeaderPropagation(&config.HeaderRules{ All: &config.GlobalHeaderRule{ Request: []*config.RequestHeaderRule{ @@ -25,6 +24,10 @@ func TestPropagateHeaderRule(t *testing.T) { Operation: "propagate", Named: "X-Test-1", }, + { + Operation: "propagate", + Named: "x-teST-3", + }, }, }, }) @@ -36,6 +39,7 @@ func TestPropagateHeaderRule(t *testing.T) { require.NoError(t, err) clientReq.Header.Set("X-Test-1", "test1") clientReq.Header.Set("X-Test-2", "test2") + clientReq.Header.Set("X-tesT-3", "test3") originReq, err := http.NewRequest("POST", "http://localhost", nil) assert.Nil(t, err) @@ -48,10 +52,10 @@ func TestPropagateHeaderRule(t *testing.T) { subgraphResolver: NewSubgraphResolver(nil), }) - assert.Len(t, updatedClientReq.Header, 1) + assert.Len(t, updatedClientReq.Header, 2) assert.Equal(t, "test1", updatedClientReq.Header.Get("X-Test-1")) assert.Empty(t, updatedClientReq.Header.Get("X-Test-2")) - + assert.Equal(t, "test3", updatedClientReq.Header.Get("X-Test-3")) }) t.Run("Should propagate based on matching regex / matching", func(t *testing.T) { @@ -92,6 +96,44 @@ func TestPropagateHeaderRule(t *testing.T) { assert.Empty(t, updatedClientReq.Header.Get("Y-Test")) }) + t.Run("Should propagate based on matching regex / matching in different case", func(t *testing.T) { + ht, err := NewHeaderPropagation(&config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ + { + Operation: "propagate", + Matching: "x-tEsT-.*", + }, + }, + }, + }) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + + clientReq, err := http.NewRequest("POST", "http://localhost", nil) + require.NoError(t, err) + clientReq.Header.Set("x-Test-1", "test1") + clientReq.Header.Set("X-tEsT-2", "test2") + clientReq.Header.Set("Y-Test", "test3") + + originReq, err := http.NewRequest("POST", "http://localhost", nil) + assert.Nil(t, err) + + updatedClientReq, _ := ht.OnOriginRequest(originReq, &requestContext{ + logger: zap.NewNop(), + responseWriter: rr, + request: clientReq, + operation: &operationContext{}, + subgraphResolver: NewSubgraphResolver(nil), + }) + + assert.Len(t, updatedClientReq.Header, 2) + assert.Equal(t, "test1", updatedClientReq.Header.Get("X-Test-1")) + assert.Equal(t, "test2", updatedClientReq.Header.Get("X-Test-2")) + assert.Empty(t, updatedClientReq.Header.Get("Y-Test")) + }) + t.Run("Should propagate with default value / named + default", func(t *testing.T) { ht, err := NewHeaderPropagation(&config.HeaderRules{ All: &config.GlobalHeaderRule{ @@ -204,6 +246,11 @@ func TestRenamePropagateHeaderRule(t *testing.T) { Named: "X-Test-1", Rename: "X-Test-Renamed", }, + { + Operation: "propagate", + Named: "X-teST-cASE-insensitive", + Rename: "X-Test-case-not-sensitive", + }, }, }, }) @@ -215,6 +262,7 @@ func TestRenamePropagateHeaderRule(t *testing.T) { require.NoError(t, err) clientReq.Header.Set("X-Test-1", "test1") clientReq.Header.Set("X-Test-2", "test2") + clientReq.Header.Set("X-Test-Case-Insensitive", "test3") originReq, err := http.NewRequest("POST", "http://localhost", nil) assert.Nil(t, err) @@ -227,10 +275,11 @@ func TestRenamePropagateHeaderRule(t *testing.T) { subgraphResolver: NewSubgraphResolver(nil), }) - assert.Len(t, updatedClientReq.Header, 1) + assert.Len(t, updatedClientReq.Header, 2) assert.Equal(t, "test1", updatedClientReq.Header.Get("X-Test-Renamed")) assert.Empty(t, updatedClientReq.Header.Get("X-Test-1")) assert.Empty(t, updatedClientReq.Header.Get("X-Test-2")) + assert.Equal(t, "test3", updatedClientReq.Header.Get("X-Test-Case-Not-Sensitive")) }) t.Run("Rename based on matching regex pattern / matching", func(t *testing.T) { @@ -243,6 +292,11 @@ func TestRenamePropagateHeaderRule(t *testing.T) { Matching: "(?i)X-Test-.*", Rename: "X-Test-Renamed-1", }, + { + Operation: "propagate", + Matching: "x-testcase-in.*", + Rename: "X-Test-Renamed-Case", + }, { Operation: "propagate", Matching: "(?i)X-Test-Default-.*", @@ -260,6 +314,7 @@ func TestRenamePropagateHeaderRule(t *testing.T) { require.NoError(t, err) clientReq.Header.Set("X-Test-1", "test1") clientReq.Header.Set("X-Test-Default-2", "") + clientReq.Header.Set("x-TESTCASE-INSENSITIVE", "test3") originReq, err := http.NewRequest("POST", "http://localhost", nil) assert.Nil(t, err) @@ -272,9 +327,10 @@ func TestRenamePropagateHeaderRule(t *testing.T) { subgraphResolver: NewSubgraphResolver(nil), }) - assert.Len(t, updatedClientReq.Header, 2) + assert.Len(t, updatedClientReq.Header, 3) assert.Equal(t, "test1", updatedClientReq.Header.Get("X-Test-Renamed-1")) assert.Equal(t, "default", updatedClientReq.Header.Get("X-Test-Renamed-Default-2")) + assert.Equal(t, "test3", updatedClientReq.Header.Get("X-Test-Renamed-Case")) assert.Empty(t, updatedClientReq.Header.Get("X-Test-1")) assert.Empty(t, updatedClientReq.Header.Get("X-Test-2")) }) @@ -385,6 +441,10 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { Operation: "propagate", Named: "X-Test-Subgraph", }, + { + Operation: "propagate", + Named: "X-test-suBGraph-case", + }, }, }, }, @@ -396,6 +456,7 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { clientReq, err := http.NewRequest("POST", "http://localhost", nil) require.NoError(t, err) clientReq.Header.Set("X-Test-Subgraph", "Test-Value") + clientReq.Header.Set("X-Test-Subgraph-Case", "Test-Value1") sg1Url, _ := url.Parse("http://subgraph-1.local") @@ -420,8 +481,9 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { assert.Nil(t, err) updatedClientReq1, _ := ht.OnOriginRequest(originReq1, ctx) - assert.Len(t, updatedClientReq1.Header, 1) + assert.Len(t, updatedClientReq1.Header, 2) assert.Equal(t, "Test-Value", updatedClientReq1.Header.Get("X-Test-Subgraph")) + assert.Equal(t, "Test-Value1", updatedClientReq1.Header.Get("X-Test-Subgraph-case")) assert.Empty(t, updatedClientReq1.Header.Get("Test-Value")) }) @@ -434,6 +496,10 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { Operation: "propagate", Matching: "(?i)X-Test-.*", }, + { + Operation: "propagate", + Matching: "X-TestCASE-.*", + }, }, }, }, @@ -445,6 +511,7 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { clientReq, err := http.NewRequest("POST", "http://localhost", nil) require.NoError(t, err) clientReq.Header.Set("X-Test-Subgraph", "Test-Value") + clientReq.Header.Set("X-TestCase-Subgraph", "Test-Value") sg1Url, _ := url.Parse("http://subgraph-1.local") @@ -470,6 +537,7 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { updatedClientReq1, _ := ht.OnOriginRequest(originReq1, ctx) assert.Equal(t, "Test-Value", updatedClientReq1.Header.Get("X-Test-Subgraph")) + assert.Equal(t, "Test-Value", updatedClientReq1.Header.Get("X-TestCase-Subgraph")) assert.Empty(t, updatedClientReq1.Header.Get("Test-Value")) })