diff --git a/check/base.go b/check/base.go index 548dacdd..b6244c11 100644 --- a/check/base.go +++ b/check/base.go @@ -47,7 +47,7 @@ func (c *FTWCheck) SetExpectResponse(response string) { // SetExpectError sets the boolean if we are expecting an error from the server func (c *FTWCheck) SetExpectError(expect bool) { - c.expected.ExpectError = expect + c.expected.ExpectError = &expect } // SetLogContains sets the string to look for in logs diff --git a/check/base_test.go b/check/base_test.go index f29d28fe..6a2e62b6 100644 --- a/check/base_test.go +++ b/check/base_test.go @@ -62,11 +62,11 @@ func (s *checkBaseTestSuite) TestNewCheck() { ResponseContains: "", LogContains: "nothing", NoLogContains: "", - ExpectError: true, + ExpectError: func() *bool { b := true; return &b }(), } c.SetExpectTestOutput(&to) - s.True(c.expected.ExpectError, "Problem setting expected output") + s.True(*c.expected.ExpectError, "Problem setting expected output") c.SetNoLogContains("nologcontains") diff --git a/check/error.go b/check/error.go index 04fce446..118d3f46 100644 --- a/check/error.go +++ b/check/error.go @@ -5,11 +5,11 @@ import "github.com/rs/zerolog/log" // AssertExpectError helper to check if this error was expected or not func (c *FTWCheck) AssertExpectError(err error) bool { if err != nil { - log.Debug().Msgf("ftw/check: expected error? -> %t, and error is %s", c.expected.ExpectError, err.Error()) + log.Debug().Msgf("ftw/check: expected error? -> %t, and error is %s", *c.expected.ExpectError, err.Error()) } else { - log.Debug().Msgf("ftw/check: expected error? -> %t, and error is nil", c.expected.ExpectError) + log.Debug().Msgf("ftw/check: expected error? -> %t, and error is nil", *c.expected.ExpectError) } - if c.expected.ExpectError && err != nil { + if *c.expected.ExpectError && err != nil { return true } return false diff --git a/cmd/root.go b/cmd/root.go index 4f832048..29caed95 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -68,4 +68,5 @@ func initConfig() { if cloud { cfg.RunMode = config.CloudRunMode } + } diff --git a/runner/run.go b/runner/run.go index cd83d3a1..1d266607 100644 --- a/runner/run.go +++ b/runner/run.go @@ -19,7 +19,7 @@ import ( "github.com/coreruleset/go-ftw/waflog" ) -var errBadTestRequest = errors.New("ftw/run: bad test: choose between data, encoded_request, or raw_request") +var errBadTestInput = errors.New("ftw/run: bad test input: choose between data, encoded_request, or raw_request") // Run runs your tests with the specified Config. func Run(cfg *config.FTWConfiguration, tests []test.FTWTest, c RunnerConfig, out *output.Output) (*TestRunContext, error) { @@ -75,9 +75,9 @@ func RunTest(runContext *TestRunContext, ftwTest test.FTWTest) error { for _, testCase := range ftwTest.Tests { // if we received a particular testid, skip until we find it - if needToSkipTest(runContext.Include, runContext.Exclude, testCase.TestTitle, ftwTest.Meta.Enabled) { + if needToSkipTest(runContext.Include, runContext.Exclude, testCase.TestTitle, *ftwTest.Meta.Enabled) { runContext.Stats.addResultToStats(Skipped, testCase.TestTitle, 0) - if !ftwTest.Meta.Enabled && !runContext.ShowOnlyFailed { + if !*ftwTest.Meta.Enabled && !runContext.ShowOnlyFailed { runContext.Output.Println("\tskipping %s - (enabled: false) in file.", testCase.TestTitle) } continue @@ -115,16 +115,13 @@ func RunStage(runContext *TestRunContext, ftwCheck *check.FTWCheck, testCase tes stageStartTime := time.Now() stageID := uuid.NewString() // Apply global overrides initially - testRequest := stage.Input - err := applyInputOverride(runContext.Config.TestOverride, &testRequest) - if err != nil { - log.Debug().Msgf("ftw/run: problem overriding input: %s", err.Error()) - } + testInput := stage.Input + test.ApplyInputOverrides(&runContext.Config.TestOverride.Overrides, &testInput) expectedOutput := stage.Output // Check sanity first - if checkTestSanity(testRequest) { - return errBadTestRequest + if checkTestSanity(testInput) { + return errBadTestInput } // Do not even run test if result is overridden. Just use the override and display the overridden result. @@ -138,24 +135,24 @@ func RunStage(runContext *TestRunContext, ftwCheck *check.FTWCheck, testCase tes // Destination is needed for a request dest := &ftwhttp.Destination{ - DestAddr: testRequest.GetDestAddr(), - Port: testRequest.GetPort(), - Protocol: testRequest.GetProtocol(), + DestAddr: testInput.GetDestAddr(), + Port: testInput.GetPort(), + Protocol: testInput.GetProtocol(), } if notRunningInCloudMode(ftwCheck) { startMarker, err := markAndFlush(runContext, dest, stageID) - if err != nil && !expectedOutput.ExpectError { + if err != nil && !*expectedOutput.ExpectError { return fmt.Errorf("failed to find start marker: %w", err) } ftwCheck.SetStartMarker(startMarker) } - req = getRequestFromTest(testRequest) + req = getRequestFromTest(testInput) - err = runContext.Client.NewConnection(*dest) + err := runContext.Client.NewConnection(*dest) - if err != nil && !expectedOutput.ExpectError { + if err != nil && !*expectedOutput.ExpectError { return fmt.Errorf("can't connect to destination %+v: %w", dest, err) } runContext.Client.StartTrackingTime() @@ -163,13 +160,13 @@ func RunStage(runContext *TestRunContext, ftwCheck *check.FTWCheck, testCase tes response, responseErr := runContext.Client.Do(*req) runContext.Client.StopTrackingTime() - if responseErr != nil && !expectedOutput.ExpectError { + if responseErr != nil && !*expectedOutput.ExpectError { return fmt.Errorf("failed sending request to destination %+v: %w", dest, responseErr) } if notRunningInCloudMode(ftwCheck) { endMarker, err := markAndFlush(runContext, dest, stageID) - if err != nil && !expectedOutput.ExpectError { + if err != nil && !*expectedOutput.ExpectError { return fmt.Errorf("failed to find end marker: %w", err) } @@ -270,10 +267,10 @@ func needToSkipTest(include *regexp.Regexp, exclude *regexp.Regexp, title string return result } -func checkTestSanity(testRequest test.Input) bool { - return (utils.IsNotEmpty(testRequest.Data) && testRequest.EncodedRequest != "") || - (utils.IsNotEmpty(testRequest.Data) && testRequest.RAWRequest != "") || - (testRequest.EncodedRequest != "" && testRequest.RAWRequest != "") +func checkTestSanity(testInput test.Input) bool { + return (utils.IsNotEmpty(testInput.Data) && testInput.EncodedRequest != "") || + (utils.IsNotEmpty(testInput.Data) && testInput.RAWRequest != "") || + (testInput.EncodedRequest != "" && testInput.RAWRequest != "") } func displayResult(rc *TestRunContext, result TestResult, roundTripTime time.Duration, stageTime time.Duration) { @@ -353,104 +350,33 @@ func checkResult(c *check.FTWCheck, response *ftwhttp.Response, responseError er return Success } -func getRequestFromTest(testRequest test.Input) *ftwhttp.Request { +func getRequestFromTest(testInput test.Input) *ftwhttp.Request { var req *ftwhttp.Request // get raw request, if anything - raw, err := testRequest.GetRawRequest() + raw, err := testInput.GetRawRequest() if err != nil { log.Error().Msgf("ftw/run: error getting raw data: %s\n", err.Error()) } // If we use raw or encoded request, then we don't use other fields if raw != nil { - req = ftwhttp.NewRawRequest(raw, !testRequest.NoAutocompleteHeaders) + req = ftwhttp.NewRawRequest(raw, !*testInput.NoAutocompleteHeaders) } else { rline := &ftwhttp.RequestLine{ - Method: testRequest.GetMethod(), - URI: testRequest.GetURI(), - Version: testRequest.GetVersion(), + Method: testInput.GetMethod(), + URI: testInput.GetURI(), + Version: testInput.GetVersion(), } - data := testRequest.ParseData() + data := testInput.ParseData() // create a new request - req = ftwhttp.NewRequest(rline, testRequest.Headers, - data, !testRequest.NoAutocompleteHeaders) + req = ftwhttp.NewRequest(rline, testInput.Headers, + data, !*testInput.NoAutocompleteHeaders) } return req } -// applyInputOverride will check if config had global overrides and write that into the test. -func applyInputOverride(o config.FTWTestOverride, testRequest *test.Input) error { - overrides := o.Overrides - - if overrides.DestAddr != nil { - testRequest.DestAddr = overrides.DestAddr - if testRequest.Headers == nil { - testRequest.Headers = ftwhttp.Header{} - } - if overrides.OverrideEmptyHostHeader && testRequest.Headers.Get("Host") == "" { - testRequest.Headers.Set("Host", *overrides.DestAddr) - } - } - - if overrides.Port != nil { - testRequest.Port = overrides.Port - } - - if overrides.Protocol != nil { - testRequest.Protocol = overrides.Protocol - } - - if overrides.URI != nil { - testRequest.URI = overrides.URI - } - - if overrides.Version != nil { - testRequest.Version = overrides.Version - } - - if overrides.Headers != nil { - if testRequest.Headers == nil { - testRequest.Headers = ftwhttp.Header{} - } - for k, v := range overrides.Headers { - testRequest.Headers.Set(k, v) - } - } - - if overrides.Method != nil { - testRequest.Method = overrides.Method - } - - if overrides.Data != nil { - testRequest.Data = overrides.Data - } - - // TODO: postprocess - if overrides.SaveCookie != nil { - testRequest.SaveCookie = overrides.SaveCookie - } - - if overrides.StopMagic != nil { - testRequest.StopMagic = overrides.StopMagic - } - - if overrides.NoAutocompleteHeaders != nil { - testRequest.NoAutocompleteHeaders = overrides.NoAutocompleteHeaders - } - - if overrides.EncodedRequest != nil { - testRequest.EncodedRequest = *overrides.EncodedRequest - } - - if overrides.RAWRequest != nil { - testRequest.RAWRequest = *overrides.RAWRequest - } - - return nil -} - func notRunningInCloudMode(c *check.FTWCheck) bool { return !c.CloudMode() } diff --git a/runner/run_input_override_test.go b/runner/run_input_override_test.go index 53637259..12238386 100644 --- a/runner/run_input_override_test.go +++ b/runner/run_input_override_test.go @@ -3,6 +3,7 @@ package runner import ( "bytes" "errors" + "fmt" "runtime" "strconv" "strings" @@ -27,6 +28,7 @@ var configTemplate = ` testoverride: input: {{ with .StopMagic }}stop_magic: {{ . }}{{ end }} + {{ with .NoAutocompleteHeaders }}no_autocomplete_headers: {{ . }}{{ end }} {{ with .BrokenConfig }}this_does_not_exist: "test"{{ end }} {{ with .Port }}port: {{ . }}{{ end }} {{ with .DestAddr }}dest_addr: {{ . }}{{ end }} @@ -93,6 +95,9 @@ var overrideConfigMap = map[string]interface{}{ "TestApplyInputOverrideStopMagic": map[string]interface{}{ "StopMagic": "true", }, + "TestApplyInputOverrideNoAutocompleteHeaders": map[string]interface{}{ + "NoAutocompleteHeaders": "true", + }, } // getOverrideConfigValue is useful to not repeat the text in the test itself @@ -109,7 +114,12 @@ func getOverrideConfigValue(key string) (string, error) { keyParts := strings.Split(key, ".") return overrideConfigMap[name].(map[string]interface{})[keyParts[0]].(map[string]string)[keyParts[1]], nil } - return overrideConfigMap[name].(map[string]interface{})[key].(string), nil + value, ok := overrideConfigMap[name].(map[string]interface{})[key] + if !ok { + return "", fmt.Errorf("Key '%s' not found four test '%s'", key, name) + } + + return value.(string), nil } return "", errors.New("failed to determine calling function") } @@ -150,13 +160,12 @@ func (s *inputOverrideTestSuite) TestSetHostFromDestAddr() { TestOverride: config.FTWTestOverride{ Overrides: test.Overrides{ DestAddr: &overrideHost, - OverrideEmptyHostHeader: true, + OverrideEmptyHostHeader: func() *bool { b := true; return &b }(), }, }, } - err = applyInputOverride(cfg.TestOverride, &testInput) - s.NoError(err, "Failed to apply input overrides") + test.ApplyInputOverrides(&cfg.TestOverride.Overrides, &testInput) s.Equal(overrideHost, *testInput.DestAddr, "`dest_addr` should have been overridden") @@ -176,8 +185,7 @@ func (s *inputOverrideTestSuite) TestSetHostFromHostHeaderOverride() { DestAddr: &originalDestAddr, } - err = applyInputOverride(s.cfg.TestOverride, &testInput) - s.NoError(err, "Failed to apply input overrides") + test.ApplyInputOverrides(&s.cfg.TestOverride.Overrides, &testInput) hostHeader := testInput.Headers.Get("Host") s.NotEqual("", hostHeader, "Host header must be set after overriding the `Host` header") @@ -199,7 +207,7 @@ func (s *inputOverrideTestSuite) TestSetHeaderOverridingExistingOne() { s.NotNil(testInput.Headers, "Header map must exist before overriding any header") - err = applyInputOverride(s.cfg.TestOverride, &testInput) + test.ApplyInputOverrides(&s.cfg.TestOverride.Overrides, &testInput) s.NoError(err, "Failed to apply input overrides") overriddenHeader := testInput.Headers.Get("unique_id") @@ -218,7 +226,7 @@ func (s *inputOverrideTestSuite) TestApplyInputOverrides() { s.NotNil(testInput.Headers, "Header map must exist before overriding any header") - err = applyInputOverride(s.cfg.TestOverride, &testInput) + test.ApplyInputOverrides(&s.cfg.TestOverride.Overrides, &testInput) s.NoError(err, "Failed to apply input overrides") overriddenHeader := testInput.Headers.Get("unique_id") @@ -235,7 +243,7 @@ func (s *inputOverrideTestSuite) TestApplyInputOverrideURI() { URI: &originalURI, } - err = applyInputOverride(s.cfg.TestOverride, &testInput) + test.ApplyInputOverrides(&s.cfg.TestOverride.Overrides, &testInput) s.NoError(err, "Failed to apply input overrides") s.Equal(overrideURI, *testInput.URI, "`URI` should have been overridden") } @@ -248,7 +256,7 @@ func (s *inputOverrideTestSuite) TestApplyInputOverrideVersion() { testInput := test.Input{ Version: &originalVersion, } - err = applyInputOverride(s.cfg.TestOverride, &testInput) + test.ApplyInputOverrides(&s.cfg.TestOverride.Overrides, &testInput) s.NoError(err, "Failed to apply input overrides") s.Equal(overrideVersion, *testInput.Version, "`Version` should have been overridden") } @@ -261,7 +269,7 @@ func (s *inputOverrideTestSuite) TestApplyInputOverrideMethod() { testInput := test.Input{ Method: &originalMethod, } - err = applyInputOverride(s.cfg.TestOverride, &testInput) + test.ApplyInputOverrides(&s.cfg.TestOverride.Overrides, &testInput) s.NoError(err, "Failed to apply input overrides") s.Equal(overrideMethod, *testInput.Method, "`Method` should have been overridden") } @@ -274,7 +282,7 @@ func (s *inputOverrideTestSuite) TestApplyInputOverrideData() { testInput := test.Input{ Data: &originalData, } - err = applyInputOverride(s.cfg.TestOverride, &testInput) + test.ApplyInputOverrides(&s.cfg.TestOverride.Overrides, &testInput) s.NoError(err, "Failed to apply input overrides") s.Equal(overrideData, *testInput.Data, "`Data` should have been overridden") } @@ -285,11 +293,26 @@ func (s *inputOverrideTestSuite) TestApplyInputOverrideStopMagic() { overrideStopMagic, err := strconv.ParseBool(stopMagicBool) s.NoError(err, "Failed to parse `StopMagic` override value") testInput := test.Input{ - StopMagic: false, + StopMagic: func() *bool { b := false; return &b }(), + } + test.ApplyInputOverrides(&s.cfg.TestOverride.Overrides, &testInput) + s.NoError(err, "Failed to apply input overrides") + // nolint + s.Equal(overrideStopMagic, *testInput.StopMagic, "`StopMagic` should have been overridden") +} + +func (s *inputOverrideTestSuite) TestApplyInputOverrideNoAutocompleteHeaders() { + noAutocompleteHeadersBool, err := getOverrideConfigValue("NoAutocompleteHeaders") + s.NoError(err, "cannot get override value") + overrideNoAutocompleteHeaders, err := strconv.ParseBool(noAutocompleteHeadersBool) + s.NoError(err, "Failed to parse `NoAutocompleteHeaders` override value") + testInput := test.Input{ + NoAutocompleteHeaders: func() *bool { b := false; return &b }(), } - err = applyInputOverride(s.cfg.TestOverride, &testInput) + test.ApplyInputOverrides(&s.cfg.TestOverride.Overrides, &testInput) s.NoError(err, "Failed to apply input overrides") - s.Equal(overrideStopMagic, testInput.StopMagic, "`StopMagic` should have been overridden") + // nolint + s.Equal(overrideNoAutocompleteHeaders, *testInput.NoAutocompleteHeaders, "`NoAutocompleteHeaders` should have been overridden") } func (s *inputOverrideTestSuite) TestApplyInputOverrideEncodedRequest() { @@ -299,7 +322,7 @@ func (s *inputOverrideTestSuite) TestApplyInputOverrideEncodedRequest() { testInput := test.Input{ EncodedRequest: originalEncodedRequest, } - err = applyInputOverride(s.cfg.TestOverride, &testInput) + test.ApplyInputOverrides(&s.cfg.TestOverride.Overrides, &testInput) s.NoError(err, "Failed to apply input overrides") s.Equal(overrideEncodedRequest, testInput.EncodedRequest, "`EncodedRequest` should have been overridden") } @@ -313,7 +336,7 @@ func (s *inputOverrideTestSuite) TestApplyInputOverrideRAWRequest() { RAWRequest: originalRAWRequest, } - err = applyInputOverride(s.cfg.TestOverride, &testInput) + test.ApplyInputOverrides(&s.cfg.TestOverride.Overrides, &testInput) s.NoError(err, "Failed to apply input overrides") s.Equal(overrideRAWRequest, testInput.RAWRequest, "`RAWRequest` should have been overridden") } diff --git a/runner/run_test.go b/runner/run_test.go index 98afdb0b..6d7421d5 100644 --- a/runner/run_test.go +++ b/runner/run_test.go @@ -105,56 +105,6 @@ type runTestSuite struct { tempFileName string } -var yamlNoAutocompleteHeadersTest = `--- -meta: - author: "tester" - enabled: true - name: "gotest-ftw.yaml" - description: "Example Test" -tests: - - test_title: "001" - description: "autocomplete headers by default" - stages: - - stage: - input: - dest_addr: "localhost" - headers: - User-Agent: "ModSecurity CRS 3 Tests" - Accept: "*/*" - Host: "localhost" - output: - expect_error: False - status: [200] - - test_title: "002" - description: "autocomplete headers explicitly" - stages: - - stage: - input: - no_autocomplete_headers: false - dest_addr: "localhost" - headers: - User-Agent: "ModSecurity CRS 3 Tests" - Accept: "*/*" - Host: "localhost" - output: - expect_error: False - status: [200] - - test_title: "003" - description: "do not autocomplete" - stages: - - stage: - input: - no_autocomplete_headers: true - dest_addr: "localhost" - headers: - User-Agent: "ModSecurity CRS 3 Tests" - Accept: "*/*" - Host: "localhost" - output: - expect_error: False - status: [200] -` - // Error checking omitted for brevity func (s *runTestSuite) newTestServer(logLines string) { var err error @@ -394,27 +344,3 @@ func (s *runTestSuite) TestIgnoredTestsRun() { s.NoError(err) s.Equal(res.Stats.TotalFailed(), 1, "Oops, test run failed!") } - -func TestNoAutocompleteHeadersDefault(t *testing.T) { - ftwTest, err := test.GetTestFromYaml([]byte(yamlNoAutocompleteHeadersTest)) - assert.NoError(t, err) - - request := getRequestFromTest(ftwTest.Tests[0].Stages[0].Stage.Input) - assert.True(t, request.WithAutoCompleteHeaders()) -} - -func TestNoAutocompleteHeadersFalse(t *testing.T) { - ftwTest, err := test.GetTestFromYaml([]byte(yamlNoAutocompleteHeadersTest)) - assert.NoError(t, err) - - request := getRequestFromTest(ftwTest.Tests[1].Stages[0].Stage.Input) - assert.True(t, request.WithAutoCompleteHeaders()) -} - -func TestNoAutocompleteHeadersTrue(t *testing.T) { - ftwTest, err := test.GetTestFromYaml([]byte(yamlNoAutocompleteHeadersTest)) - assert.NoError(t, err) - - request := getRequestFromTest(ftwTest.Tests[2].Stages[0].Stage.Input) - assert.False(t, request.WithAutoCompleteHeaders()) -} diff --git a/test/data_test.go b/test/data_test.go index 411adb1a..45e1699d 100644 --- a/test/data_test.go +++ b/test/data_test.go @@ -35,7 +35,7 @@ uri: "/" input := Input{} err := yaml.Unmarshal([]byte(yamlString), &input) s.NoError(err) - s.True(input.NoAutocompleteHeaders) + s.True(*input.NoAutocompleteHeaders) } func (s *dataTestSuite) TestGetPartialDataFromYAML() { @@ -57,7 +57,7 @@ uri: "/" err := yaml.Unmarshal([]byte(yamlString), &input) s.NoError(err) s.Empty(*input.Version) - s.False(input.NoAutocompleteHeaders) + s.False(*input.NoAutocompleteHeaders) } func (s *dataTestSuite) TestDataTemplateFromYAML() { @@ -83,5 +83,5 @@ uri: "/" data = input.ParseData() s.Equal([]byte(repeatTestSprig), data) - s.True(input.NoAutocompleteHeaders) + s.True(*input.NoAutocompleteHeaders) } diff --git a/test/defaults_test.go b/test/defaults_test.go index 3b0a8243..583cb135 100644 --- a/test/defaults_test.go +++ b/test/defaults_test.go @@ -23,8 +23,8 @@ func getTestInputDefaults() *Input { inputDefaults := Input{ Headers: make(ftwhttp.Header), Data: &data, - SaveCookie: false, - NoAutocompleteHeaders: false, + SaveCookie: func() *bool { b := false; return &b }(), + NoAutocompleteHeaders: func() *bool { b := false; return &b }(), } return &inputDefaults } @@ -47,8 +47,8 @@ func getTestExampleInput() *Input { Method: &method, Data: nil, EncodedRequest: "TXkgRGF0YQo=", - SaveCookie: false, - NoAutocompleteHeaders: false, + SaveCookie: func() *bool { b := false; return &b }(), + NoAutocompleteHeaders: func() *bool { b := false; return &b }(), } return &inputTest @@ -75,8 +75,8 @@ Keep-Alive: 300 Proxy-Connection: keep-alive User-Agent: Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; SV1; .NET CLR 2.0.50727) `, - SaveCookie: false, - NoAutocompleteHeaders: true, + SaveCookie: func() *bool { b := false; return &b }(), + NoAutocompleteHeaders: func() *bool { b := true; return &b }(), } return &inputTest @@ -128,7 +128,7 @@ func (s *defaultsTestSuite) TestDefaultGetters() { func (s *defaultsTestSuite) TestRaw() { raw := getRawInput() - s.True(raw.NoAutocompleteHeaders) + s.True(*raw.NoAutocompleteHeaders) request, _ := raw.GetRawRequest() s.NotEqual(2, bytes.Index(request, []byte("Acunetix"))) diff --git a/test/files.go b/test/files.go index 7890e692..40321a41 100644 --- a/test/files.go +++ b/test/files.go @@ -2,13 +2,9 @@ package test import ( "errors" - "fmt" "os" - "regexp" - "strings" "github.com/goccy/go-yaml" - "github.com/goccy/go-yaml/ast" "github.com/rs/zerolog/log" "github.com/yargevad/filepathx" ) @@ -38,12 +34,12 @@ func GetTestsFromFiles(globPattern string) ([]FTWTest, error) { if err != nil { log.Error().Msgf("Problem detected in file %s:\n%s\n%s", fileName, yaml.FormatError(err, true, true), - describeYamlError(err)) + DescribeYamlError(err)) return tests, err } ftwTest.FileName = fileName - tests = append(tests, ftwTest) + tests = append(tests, *ftwTest) } if len(tests) == 0 { @@ -52,23 +48,6 @@ func GetTestsFromFiles(globPattern string) ([]FTWTest, error) { return tests, nil } -// GetTestFromYaml will get the tests to be processed from a YAML string. -func GetTestFromYaml(testYaml []byte) (ftwTest FTWTest, err error) { - ftwTest, err = readTestYaml(testYaml) - if err != nil { - return FTWTest{}, err - } - - postProcess(testYaml, &ftwTest) - - return ftwTest, nil -} - -func readTestYaml(testYaml []byte) (t FTWTest, err error) { - err = yaml.Unmarshal(testYaml, &t) - return t, err -} - func readFileContents(fileName string) (contents []byte, err error) { contents, err = os.ReadFile(fileName) if err != nil { @@ -76,117 +55,3 @@ func readFileContents(fileName string) (contents []byte, err error) { } return contents, err } - -func describeYamlError(yamlError error) string { - matched, err := regexp.MatchString(`.*int was used where sequence is expected.*`, yamlError.Error()) - if err != nil { - return err.Error() - } - if matched { - return "\nTip: This might refer to a \"status\" line being '200', where it should be '[200]'.\n" + - "The default \"status\" is a list now.\n" + - "A simple example would be like this:\n\n" + - "status: 403\n" + - "needs to be changed to:\n\n" + - "status: [403]\n\n" - } - matched, err = regexp.MatchString(`.*cannot unmarshal \[]interface {} into Go struct field FTWTest.Tests of type string.*`, yamlError.Error()) - if err != nil { - return err.Error() - } - if matched { - return "\nTip: This might refer to \"data\" on the test being a list of strings instead of a proper YAML multiline.\n" + - "To fix this, convert this \"data\" string list to a multiline YAML and this will be fixed.\n" + - "A simple example would be like this:\n\n" + - "data:\n" + - " - 'Hello'\n" + - " - 'World'\n" + - "can be expressed as:\n\n" + - "data: |\n" + - " Hello\n" + - " World\n\n" + - "You can also remove single/double quotes from beggining and end of text, they are not needed. See https://yaml-multiline.info/ for additional help.\n" - } - - return "We do not have an extended explanation of this error." -} - -// TODO: make more general (env, string) -// TODO: also post process overrides -func postProcess(testYaml []byte, ftwTest *FTWTest) error { - yamlString := string(testYaml) - for index, test := range ftwTest.Tests { - path, err := yaml.PathString(fmt.Sprintf("$.tests[%d]", index)) - if err != nil { - return err - } - node, err := path.ReadNode(strings.NewReader(yamlString)) - if err != nil { - return err - } - err = postProcessTest(node.(*ast.MappingNode), &test) - if err != nil { - return err - } - } - return nil -} - -func postProcessTest(node *ast.MappingNode, test *Test) error { - nodeString := node.String() - for index, stage := range test.Stages { - path, err := yaml.PathString(fmt.Sprintf("$.stages[%d]", index)) - if err != nil { - return err - } - stageNode, err := path.ReadNode(strings.NewReader(nodeString)) - if err != nil { - return err - } - err = postProcessStage(stageNode.(*ast.MappingValueNode), &stage.Stage) - if err != nil { - return err - } - } - return nil -} - -func postProcessStage(node *ast.MappingValueNode, stage *Stage) error { - return postProcessInput(node.Value.(*ast.MappingNode), &stage.Input) -} - -func postProcessInput(node *ast.MappingNode, input *Input) error { - return postProcessNoAutocompleteHeaders(node, input) -} - -func postProcessNoAutocompleteHeaders(node ast.Node, input *Input) error { - noAutocompleteHeadersMissing := false - stopMagicMissing := false - err := readField(node, "no_autocomplete_headers", &input.NoAutocompleteHeaders) - if err != nil { - noAutocompleteHeadersMissing = true - } - err = readField(node, "stop_magic", &input.StopMagic) - if err != nil { - stopMagicMissing = true - } - - if noAutocompleteHeadersMissing && stopMagicMissing { - return nil - } - if noAutocompleteHeadersMissing && !stopMagicMissing { - input.NoAutocompleteHeaders = input.StopMagic - return nil - } - input.StopMagic = input.NoAutocompleteHeaders - return nil -} - -func readField(node ast.Node, fieldName string, out interface{}) error { - path, err := yaml.PathString("$." + fieldName) - if err != nil { - return err - } - err = path.Read(strings.NewReader(node.String()), out) - return err -} diff --git a/test/types.go b/test/types.go index d2a16a46..051fe59b 100644 --- a/test/types.go +++ b/test/types.go @@ -1,25 +1,27 @@ package test -import "github.com/coreruleset/go-ftw/ftwhttp" +import ( + "github.com/coreruleset/go-ftw/ftwhttp" +) // Input represents the input request in a stage // The fields `Version`, `Method` and `URI` we want to explicitly know when they are set to "" type Input struct { - DestAddr *string `yaml:"dest_addr,omitempty" koanf:"dest_addr,omitempty"` - Port *int `yaml:"port,omitempty" koanf:"port,omitempty"` - Protocol *string `yaml:"protocol,omitempty" koanf:"protocol,omitempty"` - URI *string `yaml:"uri,omitempty" koanf:"uri,omitempty"` - Version *string `yaml:"version,omitempty" koanf:"version,omitempty"` - Headers ftwhttp.Header `yaml:"headers,omitempty" koanf:"headers,omitempty"` - Method *string `yaml:"method,omitempty" koanf:"method,omitempty"` - Data *string `yaml:"data,omitempty" koanf:"data,omitempty"` - SaveCookie bool `yaml:"save_cookie,omitempty" koanf:"save_cookie,omitempty"` - // Deprecated, replaced with NoAutocompleteHeaders - StopMagic bool `yaml:"stop_magic" koanf:"stop_magic,omitempty"` - NoAutocompleteHeaders bool `yaml:"no_autocomplete_headers" koanf:"no_autocomplete_headers,omitempty"` - EncodedRequest string `yaml:"encoded_request,omitempty" koanf:"encoded_request,omitempty"` - RAWRequest string `yaml:"raw_request,omitempty" koanf:"raw_request,omitempty"` + DestAddr *string `yaml:"dest_addr,omitempty"` + Port *int `yaml:"port,omitempty"` + Protocol *string `yaml:"protocol,omitempty"` + URI *string `yaml:"uri,omitempty"` + Version *string `yaml:"version,omitempty"` + Headers ftwhttp.Header `yaml:"headers,omitempty"` + Method *string `yaml:"method,omitempty"` + Data *string `yaml:"data,omitempty"` + SaveCookie *bool `yaml:"save_cookie,omitempty"` + // Deprecated: replaced with NoAutocompleteHeaders + StopMagic *bool `yaml:"stop_magic"` + NoAutocompleteHeaders *bool `yaml:"no_autocomplete_headers"` + EncodedRequest string `yaml:"encoded_request,omitempty"` + RAWRequest string `yaml:"raw_request,omitempty"` } // Overrides represents the overridden inputs that have to be applied to tests @@ -32,13 +34,13 @@ type Overrides struct { Headers ftwhttp.Header `yaml:"headers,omitempty" koanf:"headers,omitempty"` Method *string `yaml:"method,omitempty" koanf:"method,omitempty"` Data *string `yaml:"data,omitempty" koanf:"data,omitempty"` - SaveCookie bool `yaml:"save_cookie,omitempty" koanf:"save_cookie,omitempty"` - // Deprecated, replaced with NoAutocompleteHeaders - StopMagic bool `yaml:"stop_magic" koanf:"stop_magic,omitempty"` - NoAutocompleteHeaders bool `yaml:"no_autocomplete_headers" koanf:"no_autocomplete_headers,omitempty"` + SaveCookie *bool `yaml:"save_cookie,omitempty" koanf:"save_cookie,omitempty"` + // Deprecated: replaced with NoAutocompleteHeaders + StopMagic *bool `yaml:"stop_magic" koanf:"stop_magic,omitempty"` + NoAutocompleteHeaders *bool `yaml:"no_autocomplete_headers" koanf:"no_autocomplete_headers,omitempty"` EncodedRequest *string `yaml:"encoded_request,omitempty" koanf:"encoded_request,omitempty"` RAWRequest *string `yaml:"raw_request,omitempty" koanf:"raw_request,omitempty"` - OverrideEmptyHostHeader bool `yaml:"override_empty_host_header,omitempty" koanf:"override_empty_host_header,omitempty"` + OverrideEmptyHostHeader *bool `yaml:"override_empty_host_header,omitempty" koanf:"override_empty_host_header,omitempty"` } // Output is the response expected from the test @@ -47,7 +49,7 @@ type Output struct { ResponseContains string `yaml:"response_contains,omitempty"` LogContains string `yaml:"log_contains,omitempty"` NoLogContains string `yaml:"no_log_contains,omitempty"` - ExpectError bool `yaml:"expect_error,omitempty"` + ExpectError *bool `yaml:"expect_error,omitempty"` } // Stage is an individual test stage @@ -70,9 +72,112 @@ type FTWTest struct { FileName string Meta struct { Author string `yaml:"author,omitempty"` - Enabled bool `yaml:"enabled,omitempty"` + Enabled *bool `yaml:"enabled,omitempty"` Name string `yaml:"name,omitempty"` Description string `yaml:"description,omitempty"` } `yaml:"meta"` Tests []Test `yaml:"tests"` } + +// ApplyInputOverride will check if config had global overrides and write that into the test. +func ApplyInputOverrides(overrides *Overrides, input *Input) { + applySimpleOverrides(overrides, input) + applyDestAddrOverride(overrides, input) + applyHeadersOverride(overrides, input) + postProcessNoAutocompleteHeaders(overrides.NoAutocompleteHeaders, overrides.StopMagic, input) +} + +func applyDestAddrOverride(overrides *Overrides, input *Input) { + if overrides.DestAddr != nil { + input.DestAddr = overrides.DestAddr + if input.Headers == nil { + input.Headers = ftwhttp.Header{} + } + if overrides.OverrideEmptyHostHeader != nil && *overrides.OverrideEmptyHostHeader && input.Headers.Get("Host") == "" { + input.Headers.Set("Host", *overrides.DestAddr) + } + } +} + +func applySimpleOverrides(overrides *Overrides, input *Input) { + if overrides.Port != nil { + input.Port = overrides.Port + } + + if overrides.Protocol != nil { + input.Protocol = overrides.Protocol + } + + if overrides.URI != nil { + input.URI = overrides.URI + } + + if overrides.Version != nil { + input.Version = overrides.Version + } + + if overrides.Method != nil { + input.Method = overrides.Method + } + + if overrides.Data != nil { + input.Data = overrides.Data + } + + if overrides.SaveCookie != nil { + input.SaveCookie = overrides.SaveCookie + } + + if overrides.EncodedRequest != nil { + input.EncodedRequest = *overrides.EncodedRequest + } + + if overrides.RAWRequest != nil { + input.RAWRequest = *overrides.RAWRequest + } +} + +func applyHeadersOverride(overrides *Overrides, input *Input) { + if overrides.Headers != nil { + if input.Headers == nil { + input.Headers = ftwhttp.Header{} + } + for k, v := range overrides.Headers { + input.Headers.Set(k, v) + } + } +} + +func postLoadTestFTWTest(ftwTest *FTWTest) { + for _, test := range ftwTest.Tests { + postLoadTest(&test) + } +} + +func postLoadTest(test *Test) { + for index := range test.Stages { + postLoadStage(&test.Stages[index].Stage) + } +} + +func postLoadStage(stage *Stage) { + postLoadInput(&stage.Input) +} + +func postLoadInput(input *Input) { + postProcessNoAutocompleteHeaders(input.NoAutocompleteHeaders, input.StopMagic, input) +} + +func postProcessNoAutocompleteHeaders(noAutocompleteHeaders *bool, stopMagic *bool, input *Input) { + noAutocompleteHeadersMissing := noAutocompleteHeaders == nil + stopMagicMissing := stopMagic == nil + finalValue := false + + if noAutocompleteHeadersMissing && !stopMagicMissing { + finalValue = *stopMagic + } else if !noAutocompleteHeadersMissing { + finalValue = *noAutocompleteHeaders + } + input.NoAutocompleteHeaders = &finalValue + input.StopMagic = &finalValue +} diff --git a/test/types_test.go b/test/types_test.go new file mode 100644 index 00000000..0d4d614b --- /dev/null +++ b/test/types_test.go @@ -0,0 +1,251 @@ +package test + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type typesTestSuite struct { + suite.Suite +} + +func TestTypesTestSuite(t *testing.T) { + suite.Run(t, new(typesTestSuite)) +} + +var noAutocompleteHeadersDefaultYaml = `--- +meta: + author: "tester" + enabled: true + name: "gotest-ftw.yaml" + description: "Example Test" +tests: + - test_title: "001" + description: "autocomplete headers by default" + stages: + - stage: + input: + dest_addr: "localhost" + headers: + User-Agent: "ModSecurity CRS 3 Tests" + Accept: "*/*" + Host: "localhost" + output: + expect_error: False + status: [200] + - test_title: "002" + description: "autocomplete headers by default" + stages: + - stage: + input: + stop_magic: true + dest_addr: "localhost" + headers: + User-Agent: "ModSecurity CRS 3 Tests" + Accept: "*/*" + Host: "localhost" + output: + expect_error: False + status: [200] + - test_title: "003" + description: "autocomplete headers by default" + stages: + - stage: + input: + stop_magic: false + dest_addr: "localhost" + headers: + User-Agent: "ModSecurity CRS 3 Tests" + Accept: "*/*" + Host: "localhost" + output: + expect_error: False + status: [200] +` + +var noAutocompleteHeadersFalseYaml = `--- +meta: + author: "tester" + enabled: true + name: "gotest-ftw.yaml" + description: "Example Test" +tests: + - test_title: "001" + description: "autocomplete headers explicitly" + stages: + - stage: + input: + no_autocomplete_headers: false + dest_addr: "localhost" + headers: + User-Agent: "ModSecurity CRS 3 Tests" + Accept: "*/*" + Host: "localhost" + output: + expect_error: False + status: [200] + - test_title: "002" + description: "autocomplete headers explicitly" + stages: + - stage: + input: + no_autocomplete_headers: false + stop_magic: true + dest_addr: "localhost" + headers: + User-Agent: "ModSecurity CRS 3 Tests" + Accept: "*/*" + Host: "localhost" + output: + expect_error: False + status: [200] + - test_title: "003" + description: "autocomplete headers explicitly" + stages: + - stage: + input: + no_autocomplete_headers: false + stop_magic: false + dest_addr: "localhost" + headers: + User-Agent: "ModSecurity CRS 3 Tests" + Accept: "*/*" + Host: "localhost" + output: + expect_error: False + status: [200] +` + +var noAutocompleteHeadersTrueYaml = `--- +meta: + author: "tester" + enabled: true + name: "gotest-ftw.yaml" + description: "Example Test" +tests: + - test_title: "001" + description: "do not autocomplete headers explicitly" + stages: + - stage: + input: + no_autocomplete_headers: true + dest_addr: "localhost" + headers: + User-Agent: "ModSecurity CRS 3 Tests" + Accept: "*/*" + Host: "localhost" + output: + expect_error: False + status: [200] + - test_title: "002" + description: "do not autocomplete headers explicitly" + stages: + - stage: + input: + no_autocomplete_headers: true + stop_magic: true + dest_addr: "localhost" + headers: + User-Agent: "ModSecurity CRS 3 Tests" + Accept: "*/*" + Host: "localhost" + output: + expect_error: False + status: [200] + - test_title: "003" + description: "do not autocomplete headers explicitly" + stages: + - stage: + input: + no_autocomplete_headers: true + stop_magic: false + dest_addr: "localhost" + headers: + User-Agent: "ModSecurity CRS 3 Tests" + Accept: "*/*" + Host: "localhost" + output: + expect_error: False + status: [200] +` + +func (s *typesTestSuite) TestNoAutocompleteHeadersDefault_StopMagicDefault() { + test, err := GetTestFromYaml([]byte(noAutocompleteHeadersDefaultYaml)) + s.NoError(err, "Parsing YAML shouldn't fail") + + input := test.Tests[0].Stages[0].Stage.Input + s.False(*input.NoAutocompleteHeaders) + s.False(*input.StopMagic) +} + +func (s *typesTestSuite) TestNoAutocompleteHeadersDefault_StopMagicTrue() { + test, err := GetTestFromYaml([]byte(noAutocompleteHeadersDefaultYaml)) + s.NoError(err, "Parsing YAML shouldn't fail") + + input := test.Tests[1].Stages[0].Stage.Input + s.True(*input.NoAutocompleteHeaders) + s.True(*input.StopMagic) +} +func (s *typesTestSuite) TestNoAutocompleteHeadersDefault_StopMagicFalse() { + test, err := GetTestFromYaml([]byte(noAutocompleteHeadersDefaultYaml)) + s.NoError(err, "Parsing YAML shouldn't fail") + + input := test.Tests[2].Stages[0].Stage.Input + s.False(*input.NoAutocompleteHeaders) + s.False(*input.StopMagic) +} + +func (s *typesTestSuite) TestNoAutocompleteHeadersFalse_StopMagicDefault() { + test, err := GetTestFromYaml([]byte(noAutocompleteHeadersFalseYaml)) + s.NoError(err, "Parsing YAML shouldn't fail") + + input := test.Tests[0].Stages[0].Stage.Input + s.False(*input.NoAutocompleteHeaders) + s.False(*input.StopMagic) +} + +func (s *typesTestSuite) TestNoAutocompleteHeadersFalse_StopMagicTrue() { + test, err := GetTestFromYaml([]byte(noAutocompleteHeadersFalseYaml)) + s.NoError(err, "Parsing YAML shouldn't fail") + + input := test.Tests[1].Stages[0].Stage.Input + s.False(*input.NoAutocompleteHeaders) + s.False(*input.StopMagic) +} + +func (s *typesTestSuite) TestNoAutocompleteHeadersFalse_StopMagicFalse() { + test, err := GetTestFromYaml([]byte(noAutocompleteHeadersFalseYaml)) + s.NoError(err, "Parsing YAML shouldn't fail") + + input := test.Tests[2].Stages[0].Stage.Input + s.False(*input.NoAutocompleteHeaders) + s.False(*input.StopMagic) +} + +func (s *typesTestSuite) TestNoAutocompleteHeadersTrue_StopMagicDefault() { + test, err := GetTestFromYaml([]byte(noAutocompleteHeadersTrueYaml)) + s.NoError(err, "Parsing YAML shouldn't fail") + + input := test.Tests[0].Stages[0].Stage.Input + s.True(*input.NoAutocompleteHeaders) + s.True(*input.StopMagic) +} + +func (s *typesTestSuite) TestNoAutocompleteHeadersTrue_StopMagicTrue() { + test, err := GetTestFromYaml([]byte(noAutocompleteHeadersTrueYaml)) + s.NoError(err, "Parsing YAML shouldn't fail") + + input := test.Tests[1].Stages[0].Stage.Input + s.True(*input.NoAutocompleteHeaders) + s.True(*input.StopMagic) +} + +func (s *typesTestSuite) TestNoAutocompleteHeadersTrue_StopMagicFalse() { + test, err := GetTestFromYaml([]byte(noAutocompleteHeadersTrueYaml)) + s.NoError(err, "Parsing YAML shouldn't fail") + + input := test.Tests[2].Stages[0].Stage.Input + s.True(*input.NoAutocompleteHeaders) + s.True(*input.StopMagic) +} diff --git a/test/yaml.go b/test/yaml.go new file mode 100644 index 00000000..693596b4 --- /dev/null +++ b/test/yaml.go @@ -0,0 +1,54 @@ +package test + +import ( + "regexp" + + "github.com/goccy/go-yaml" +) + +// GetTestFromYaml will get the tests to be processed from a YAML string. +func GetTestFromYaml(testYaml []byte) (ftwTest *FTWTest, err error) { + ftwTest = &FTWTest{} + err = yaml.Unmarshal(testYaml, ftwTest) + if err != nil { + return &FTWTest{}, err + } + + postLoadTestFTWTest(ftwTest) + + return ftwTest, nil +} + +func DescribeYamlError(yamlError error) string { + matched, err := regexp.MatchString(`.*int was used where sequence is expected.*`, yamlError.Error()) + if err != nil { + return err.Error() + } + if matched { + return "\nTip: This might refer to a \"status\" line being '200', where it should be '[200]'.\n" + + "The default \"status\" is a list now.\n" + + "A simple example would be like this:\n\n" + + "status: 403\n" + + "needs to be changed to:\n\n" + + "status: [403]\n\n" + } + matched, err = regexp.MatchString(`.*cannot unmarshal \[]interface {} into Go struct field FTWTest.Tests of type string.*`, yamlError.Error()) + if err != nil { + return err.Error() + } + if matched { + return "\nTip: This might refer to \"data\" on the test being a list of strings instead of a proper YAML multiline.\n" + + "To fix this, convert this \"data\" string list to a multiline YAML and this will be fixed.\n" + + "A simple example would be like this:\n\n" + + "data:\n" + + " - 'Hello'\n" + + " - 'World'\n" + + "can be expressed as:\n\n" + + "data: |\n" + + " Hello\n" + + " World\n\n" + + "You can also remove single/double quotes from beggining and end of text, they are not needed. See https://yaml-multiline.info/ for additional help.\n" + } + + return "We do not have an extended explanation of this error." +}