diff --git a/config/config_testdata/config_multipart_err.json b/config/config_testdata/config_multipart_err.json new file mode 100644 index 00000000..d9a605c1 --- /dev/null +++ b/config/config_testdata/config_multipart_err.json @@ -0,0 +1,18 @@ +{ + "steps": [ + { + "id": 1, + "url": "https://app.servdown.com/accounts/login/?next=/", + "method": "GET", + "payload_multipart": [ + + { + "name": "example-name-5", + "value": "https://uplo333ad.wikimedia.org/wikipedia/commons/b/bd/Test.svg", + "type": "file", + "src": "remote" + } + ] + } + ] +} \ No newline at end of file diff --git a/config/json.go b/config/json.go index 8caf9936..d53d7278 100644 --- a/config/json.go +++ b/config/json.go @@ -368,32 +368,37 @@ func prepareMultipartPayload(parts []multipartFormData) (body string, contentTyp writer := multipart.NewWriter(byteBody) for _, part := range parts { - var err error - + var multipartError RemoteMultipartError if strings.EqualFold(part.Type, "file") { if strings.EqualFold(part.Src, "remote") { response, err := http.Get(part.Value) if err != nil { - return "", "", err + multipartError.wrappedErr = err + multipartError.msg = "Error while getting remote file" + return "", "", multipartError } defer response.Body.Close() u, _ := url.Parse(part.Value) formPart, err := writer.CreateFormFile(part.Name, path.Base(u.Path)) if err != nil { + multipartError.wrappedErr = err + multipartError.msg = "Error while creating form file" return "", "", err } _, err = io.Copy(formPart, response.Body) if err != nil { - return "", "", err + multipartError.wrappedErr = err + multipartError.msg = "Error while copying response body" + return "", "", multipartError } } else { file, err := os.Open(part.Value) - defer file.Close() if err != nil { return "", "", err } + defer file.Close() formPart, err := writer.CreateFormFile(part.Name, filepath.Base(file.Name())) if err != nil { @@ -418,3 +423,19 @@ func prepareMultipartPayload(parts []multipartFormData) (body string, contentTyp writer.Close() return byteBody.String(), writer.FormDataContentType(), err } + +type RemoteMultipartError struct { // UnWrappable + msg string + wrappedErr error +} + +func (nf RemoteMultipartError) Error() string { + if nf.wrappedErr != nil { + return fmt.Sprintf("%s, %s", nf.msg, nf.wrappedErr.Error()) + } + return nf.msg +} + +func (nf RemoteMultipartError) Unwrap() error { + return nf.wrappedErr +} diff --git a/config/json_test.go b/config/json_test.go index 30718b59..53287891 100644 --- a/config/json_test.go +++ b/config/json_test.go @@ -21,6 +21,7 @@ package config import ( + "errors" "fmt" "io" "net/http" @@ -362,6 +363,21 @@ func TestCreateHammerMultipartPayload(t *testing.T) { } } +func TestCreateHammerMultipartPayload_RemoteErr(t *testing.T) { + t.Parallel() + jsonReader, _ := NewConfigReader(readConfigFile("config_testdata/config_multipart_err.json"), ConfigTypeJson) + + _, err := jsonReader.CreateHammer() + if err == nil { + t.Error("TestCreateHammerMultipartPayload_RemoteErr should return error") + } + + var multipartErr RemoteMultipartError + if !errors.As(err, &multipartErr) { + t.Errorf("Expected: %v, Found: %v", multipartErr, err) + } +} + func TestCreateHammerAuth(t *testing.T) { t.Parallel() jsonReader, _ := NewConfigReader(readConfigFile("config_testdata/config_auth.json"), ConfigTypeJson) diff --git a/core/engine.go b/core/engine.go index c17c5bed..fab90fc0 100644 --- a/core/engine.go +++ b/core/engine.go @@ -30,7 +30,7 @@ import ( "go.ddosify.com/ddosify/core/proxy" "go.ddosify.com/ddosify/core/report" "go.ddosify.com/ddosify/core/scenario" - "go.ddosify.com/ddosify/core/scenario/testdata" + "go.ddosify.com/ddosify/core/scenario/data" "go.ddosify.com/ddosify/core/types" ) @@ -378,7 +378,7 @@ var readTestData = func(testDataConf map[string]types.CsvConf) (map[string]types for k, conf := range testDataConf { var rows []map[string]interface{} var err error - rows, err = testdata.ReadCsv(conf) + rows, err = data.ReadCsv(conf) if err != nil { return nil, err } diff --git a/core/scenario/testdata/csv.go b/core/scenario/data/csv.go similarity index 77% rename from core/scenario/testdata/csv.go rename to core/scenario/data/csv.go index 9b9d4105..89dcb425 100644 --- a/core/scenario/testdata/csv.go +++ b/core/scenario/data/csv.go @@ -1,4 +1,4 @@ -package testdata +package data import ( "encoding/csv" @@ -20,6 +20,22 @@ func validateConf(conf types.CsvConf) error { return nil } +type RemoteCsvError struct { // UnWrappable + msg string + wrappedErr error +} + +func (nf RemoteCsvError) Error() string { + if nf.wrappedErr != nil { + return fmt.Sprintf("%s,%s", nf.msg, nf.wrappedErr.Error()) + } + return nf.msg +} + +func (nf RemoteCsvError) Unwrap() error { + return nf.wrappedErr +} + func ReadCsv(conf types.CsvConf) ([]map[string]interface{}, error) { err := validateConf(conf) if err != nil { @@ -32,15 +48,15 @@ func ReadCsv(conf types.CsvConf) ([]map[string]interface{}, error) { if pUrl, err = url.ParseRequestURI(conf.Path); err == nil && pUrl.IsAbs() { // url req, err := http.NewRequest(http.MethodGet, conf.Path, nil) if err != nil { - return nil, err + return nil, wrapAsCsvError("can not create request", err) } resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, err + return nil, wrapAsCsvError("can not get response", err) } if !(resp.StatusCode >= 200 && resp.StatusCode <= 299) { - return nil, fmt.Errorf("request to remote url failed: %d", resp.StatusCode) + return nil, wrapAsCsvError(fmt.Sprintf("request to remote url failed: %d", resp.StatusCode), nil) } reader = resp.Body defer resp.Body.Close() @@ -52,7 +68,7 @@ func ReadCsv(conf types.CsvConf) ([]map[string]interface{}, error) { reader = f defer f.Close() } else { - return nil, err + return nil, wrapAsCsvError(fmt.Sprintf("can not parse path: %s", conf.Path), err) } // read csv values using csv.Reader @@ -132,3 +148,10 @@ func emptyLine(row []string) bool { } return true } + +func wrapAsCsvError(msg string, err error) RemoteCsvError { + var csvReqError RemoteCsvError + csvReqError.msg = msg + csvReqError.wrappedErr = err + return csvReqError +} diff --git a/core/scenario/testdata/csv_test.go b/core/scenario/data/csv_test.go similarity index 59% rename from core/scenario/testdata/csv_test.go rename to core/scenario/data/csv_test.go index 87026074..25ee1bc1 100644 --- a/core/scenario/testdata/csv_test.go +++ b/core/scenario/data/csv_test.go @@ -1,7 +1,10 @@ -package testdata +package data import ( + "errors" "fmt" + "net/http" + "net/http/httptest" "reflect" "strings" "testing" @@ -29,6 +32,95 @@ func TestValidateCsvConf(t *testing.T) { } } +func TestReadCsv_RemoteErr(t *testing.T) { + t.Parallel() + conf := types.CsvConf{ + Path: "https://invalidurl.com/csv", + Delimiter: ";", + SkipFirstLine: true, + Vars: map[string]types.Tag{ + "0": {Tag: "name", Type: "string"}, + "3": {Tag: "payload", Type: "json"}, + "4": {Tag: "age", Type: "int"}, + "5": {Tag: "percent", Type: "float"}, + "6": {Tag: "boolField", Type: "bool"}, + }, + SkipEmptyLine: true, + AllowQuota: true, + Order: "sequential", + } + + _, err := ReadCsv(conf) + + if err == nil { + t.Errorf("TestReadCsv_RemoteErr %v", err) + } + + var remoteCsvErr RemoteCsvError + if !errors.As(err, &remoteCsvErr) { + t.Errorf("Expected: %v, Found: %v", remoteCsvErr, err) + } + if remoteCsvErr.Unwrap() == nil { + t.Errorf("Expected: %v, Found: %v", "not nil", remoteCsvErr.Unwrap()) + } +} + +func TestWrapAsRemoteCsvError(t *testing.T) { + msg := "xxyy" + csvErr := wrapAsCsvError(msg, fmt.Errorf("error")) + + var remoteCsvErr RemoteCsvError + if !errors.As(csvErr, &remoteCsvErr) { + t.Errorf("Expected: %v, Found: %v", remoteCsvErr, csvErr) + } + errmsg := remoteCsvErr.Error() + if errmsg != msg+",error" { + t.Errorf("Expected: %v, Found: %v", msg, remoteCsvErr.msg) + } +} + +func TestReadCsvFromRemote(t *testing.T) { + // Test server + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + } + + path := "/csv" + mux := http.NewServeMux() + mux.HandleFunc(path, handler) + + server := httptest.NewServer(mux) + defer server.Close() + + conf := types.CsvConf{ + Path: server.URL + path, + Delimiter: ";", + SkipFirstLine: true, + Vars: map[string]types.Tag{ + "0": {Tag: "name", Type: "string"}, + "3": {Tag: "payload", Type: "json"}, + "4": {Tag: "age", Type: "int"}, + "5": {Tag: "percent", Type: "float"}, + "6": {Tag: "boolField", Type: "bool"}, + }, + SkipEmptyLine: true, + AllowQuota: true, + Order: "sequential", + } + + _, err := ReadCsv(conf) + + if err == nil { + t.Errorf("TestReadCsvFromRemote %v", err) + } + + var remoteCsvErr RemoteCsvError + if !errors.As(err, &remoteCsvErr) { + t.Errorf("Expected: %v, Found: %v", remoteCsvErr, err) + } + +} + func TestReadCsv(t *testing.T) { t.Parallel() conf := types.CsvConf{