From 338bb7dd2ac206b2e3c63e7ac001c2d2b23b30ca Mon Sep 17 00:00:00 2001 From: Philip Su Date: Tue, 2 Apr 2024 06:49:59 -0700 Subject: [PATCH] Add request length check to json rpcs (#221) * Add request length check to avoid ddos * Add tests --- rpc/jsonrpc/server/http_json_handler.go | 7 +++ rpc/jsonrpc/server/http_json_handler_test.go | 62 ++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/rpc/jsonrpc/server/http_json_handler.go b/rpc/jsonrpc/server/http_json_handler.go index 8ab88bcea..bbaa0727d 100644 --- a/rpc/jsonrpc/server/http_json_handler.go +++ b/rpc/jsonrpc/server/http_json_handler.go @@ -15,6 +15,8 @@ import ( // HTTP + JSON handler +const REQUEST_BATCH_SIZE_LIMIT = 10 + // jsonrpc calls grab the given method's function info and runs reflect.Call func makeJSONRPCHandler(funcMap map[string]*RPCFunc, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, hreq *http.Request) { @@ -41,6 +43,11 @@ func makeJSONRPCHandler(funcMap map[string]*RPCFunc, logger log.Logger) http.Han } requests, err := parseRequests(b) + if len(requests) > REQUEST_BATCH_SIZE_LIMIT { + writeRPCResponse(w, logger, rpctypes.RPCRequest{}.MakeErrorf( + rpctypes.CodeParseError, "Batch size limit exceeded.")) + return + } if err != nil { writeRPCResponse(w, logger, rpctypes.RPCRequest{}.MakeErrorf( rpctypes.CodeParseError, "decoding request: %v", err)) diff --git a/rpc/jsonrpc/server/http_json_handler_test.go b/rpc/jsonrpc/server/http_json_handler_test.go index dd4a9d8e2..a62b61aab 100644 --- a/rpc/jsonrpc/server/http_json_handler_test.go +++ b/rpc/jsonrpc/server/http_json_handler_test.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" @@ -221,6 +222,67 @@ func TestRPCNotificationInBatch(t *testing.T) { } } +func TestRPCBatchLimit(t *testing.T) { + mux := testMux() + tests := []struct { + payload string + success bool + }{ + { + `[ + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}, + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]} + ]`, + true, + }, + { + `[ + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}, + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}, + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}, + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}, + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}, + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}, + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}, + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}, + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}, + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}, + {"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]} + ]`, + false, + }, + } + for i, tt := range tests { + req, _ := http.NewRequest("POST", "http://localhost/", strings.NewReader(tt.payload)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + // Always expecting back a JSONRPCResponse + assert.True(t, statusOK(res.StatusCode), "#%d: should always return 2XX", i) + blob, err := io.ReadAll(res.Body) + + fmt.Printf("responses: %s\n", blob) + if err != nil { + t.Errorf("#%d: err reading body: %v", i, err) + continue + } + res.Body.Close() + + var responses []rpctypes.RPCResponse + err = json.Unmarshal(blob, &responses) + if err != nil { + if tt.success { + t.Errorf("#%d: expected successful parsing of an RPCResponse\nblob: %s", i, blob) + continue + } else { + fmt.Printf("blob: %s, %d\n", responses, len(responses)) + assert.Contains(t, string(blob), "Batch size limit exceeded.") + } + } + + } +} + func TestUnknownRPCPath(t *testing.T) { mux := testMux() req, _ := http.NewRequest("GET", "http://localhost/unknownrpcpath", strings.NewReader(""))