From f2b8305c06a3ef91e4c572b4d1d49f5d0594ebdb Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Sun, 14 May 2023 15:11:00 -0500 Subject: [PATCH] Add Generic gateway test and bugfixes fix broken tests (random seed) --- gateway/generic.go | 12 ++++++++++-- main.go | 4 ++++ main_test.go | 28 +++++++++++++++++++++++++++- version_test.go | 6 ++++++ 4 files changed, 47 insertions(+), 3 deletions(-) diff --git a/gateway/generic.go b/gateway/generic.go index 8574cd7..68b3c9b 100644 --- a/gateway/generic.go +++ b/gateway/generic.go @@ -35,7 +35,10 @@ func (m Generic) Get() []byte { func (m Generic) Req(body []byte, req http.Request) ([]*http.Request, error) { myurl := req.URL.EscapedPath() - encodedEndpoint := strings.SplitN(myurl, "/", 4)[2] + encodedEndpoint := "" + if encodedEndpoints := strings.SplitN(myurl, "/", 4); len(encodedEndpoints) >= 3 { + encodedEndpoint = encodedEndpoints[2] + } endpointBytes, err := base64.RawURLEncoding.DecodeString(encodedEndpoint) if err != nil { return nil, fmt.Errorf("Encoded endpoint not valid base64: %w", err) @@ -63,7 +66,12 @@ func (m Generic) Req(body []byte, req http.Request) ([]*http.Request, error) { } func (Generic) Resp(r []*http.Response, w http.ResponseWriter) { - w.WriteHeader(r[0].StatusCode) + if r[0] != nil { + w.WriteHeader(r[0].StatusCode) + } else { + w.WriteHeader(500) + } + w.Header().Add("TTL", "0") } func (m *Generic) Defaults() (failed bool) { diff --git a/main.go b/main.go index f526ca4..b870e04 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,10 @@ var configFile = flag.String("c", "config.toml", "path to toml file for config") // various translaters var handlers = []Handler{} +func init() { + log.SetFlags(log.LstdFlags | log.Lmicroseconds) +} + func main() { flag.Parse() err := ParseConf(*configFile) diff --git a/main_test.go b/main_test.go index f3432d4..3823c80 100644 --- a/main_test.go +++ b/main_test.go @@ -2,8 +2,10 @@ package main import ( "bytes" + "encoding/base64" "io" "io/ioutil" + "math/rand" "net/http" "net/http/httptest" "net/url" @@ -43,6 +45,7 @@ func (s *RewriteTests) SetupTest() { u, _ := url.Parse(s.ts.URL) config.Config.Gateway.AllowedHosts = []string{u.Host} + ForceBypassValidProviderCheck(s.ts.URL, "http://temp.test") s.resetTest() } @@ -62,6 +65,7 @@ const goodFCMResponse = `{"Results": [{"Error":""}]}` func (s *RewriteTests) TestFCM() { fcm := rewrite.FCM{Key: "testkey", APIURL: s.ts.URL} + rand.Seed(0) myFancyContent, myFancyContent64 := myFancyContentGenerate() cases := [][]string{ @@ -69,7 +73,7 @@ func (s *RewriteTests) TestFCM() { {"FCMD", "/?token=a&app=a", `{"to":"a","data":{"app":"a","body":"content"}}`, `content`}, {"FCMv2", "/?token=a&instance=myinst&v2", `{"to":"a","data":{"b":"Y29udGVudA==","i":"myinst"}}`, `content`}, {"FCMv2-2", "/?v2&token=a&instance=myinst", `{"to":"a","data":{"b":"Y29udGVudA==","i":"myinst"}}`, `content`}, - {"FCMv2-3", "/?v2&token=a&instance=myinst", `{"to":"a","data":{"b":"` + myFancyContent64[3000:] + `","i":"myinst","m":"5577006791947779411","s":"2"}}`, myFancyContent}, // this test only tests the second value because that's much easier than testing for the first one due to the architecture of this file. Someday I'll fix that TODO. + {"FCMv2-3", "/?v2&token=a&instance=myinst", `{"to":"a","data":{"b":"` + myFancyContent64[3000:] + `","i":"myinst","m":"8717895732742165506","s":"2"}}`, myFancyContent}, // this test only tests the second value because that's much easier than testing for the first one due to the architecture of this file. Someday I'll fix that TODO. } for _, i := range cases { @@ -181,6 +185,28 @@ func (s *RewriteTests) TestMatrixResp() { //TODO } +func (s *RewriteTests) TestGenericGateway() { + gw := gateway.Generic{} + + content := `this is + +my msg` + request := httptest.NewRequest("POST", "/generic/"+base64.RawURLEncoding.EncodeToString([]byte(s.ts.URL)), bytes.NewBufferString(content)) + request.Header.Add("cOntent-Encoding", "aesgcm") + request.Header.Add("cryPTo-KEY", `dh="BNoRDbb84JGm8g5Z5CFxurSqsXWJ11ItfXEWYVLE85Y7CYkDjXsIEc4aqxYaQ1G8BqkXCJ6DPpDrWtdWj_mugHU"`) + request.Header.Add("EncRYPTION", `Encryption: salt="lngarbyKfMoi9Z75xYXmkg"`) + handle(&gw)(s.Resp, request) + + s.Equal(200, s.Resp.Result().StatusCode, "request should be valid") + s.Equal(`this is + +my msg +dh="BNoRDbb84JGm8g5Z5CFxurSqsXWJ11ItfXEWYVLE85Y7CYkDjXsIEc4aqxYaQ1G8BqkXCJ6DPpDrWtdWj_mugHU" +Encryption: salt="lngarbyKfMoi9Z75xYXmkg" +aesgcm`, string(s.CallBody), "body should match") + +} + func (s *RewriteTests) TestHealth() { resp, err := http.Get(s.ts.URL + "/health") s.Require().Nil(err) diff --git a/version_test.go b/version_test.go index 26ae0e9..ae0ea79 100644 --- a/version_test.go +++ b/version_test.go @@ -13,6 +13,12 @@ import ( "github.com/stretchr/testify/assert" ) +func ForceBypassValidProviderCheck(host ...string) { + for _, i := range host { + allowedProxies.Set(i, true, 1000*time.Minute) + } +} + func TestCheckRewriteProxy(t *testing.T) { c := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error {