diff --git a/lib/queue_manager.go b/lib/queue_manager.go index 8564193..6e2bbde 100644 --- a/lib/queue_manager.go +++ b/lib/queue_manager.go @@ -3,15 +3,21 @@ package lib import ( "context" "errors" - lru "github.com/hashicorp/golang-lru" - "github.com/hashicorp/memberlist" - "github.com/sirupsen/logrus" "net/http" "sort" "strconv" "strings" "sync" "time" + + lru "github.com/hashicorp/golang-lru" + "github.com/hashicorp/memberlist" + "github.com/sirupsen/logrus" +) + +const ( + nirnRetryCountHeader = "X-Nirn-Retry-Count" + maxRetries = 2 ) type QueueType int64 @@ -169,14 +175,26 @@ func (m *QueueManager) routeRequest(addr string, req *http.Request) (*http.Respo nodeReq, err := http.NewRequestWithContext(req.Context(), req.Method, "http://"+addr+req.URL.Path+"?"+req.URL.RawQuery, req.Body) nodeReq.Header = req.Header.Clone() nodeReq.Header.Set("nirn-routed-to", addr) + + // Increment retry count header when routing + retryCount := 0 + if s := req.Header.Get(nirnRetryCountHeader); s != "" { + if v, err := strconv.Atoi(s); err == nil { + retryCount = v + } + } + retryCount++ + nodeReq.Header.Set(nirnRetryCountHeader, strconv.Itoa(retryCount)) + if err != nil { return nil, err } logger.WithFields(logrus.Fields{ - "to": addr, - "path": req.URL.Path, - "method": req.Method, + "to": addr, + "path": req.URL.Path, + "method": req.Method, + "retryCount": retryCount, }).Trace("Routing request to node in cluster") resp, err := client.Do(nodeReq) logger.WithFields(logrus.Fields{ @@ -261,11 +279,25 @@ func (m *QueueManager) getOrCreateBearerQueue(token string) (*RequestQueue, erro } func (m *QueueManager) DiscordRequestHandler(resp http.ResponseWriter, req *http.Request) { + // Check retry count early and fail fast to prevent infinite loops + if s := req.Header.Get(nirnRetryCountHeader); s != "" { + if v, err := strconv.Atoi(s); err == nil && v > maxRetries { + logger.WithFields(logrus.Fields{ + "retryCount": v, + "maxRetries": maxRetries, + "path": req.URL.Path, + "method": req.Method, + }).Warn("Request exceeded max retry count, rejecting") + Generate429(&resp) + return + } + } + reqStart := time.Now() metricsPath := GetMetricsPath(req.URL.Path) token := req.Header.Get("Authorization") - clientId := GetBotId(token) + clientId := GetBotId(token) ConnectionsOpen.With(map[string]string{"route": metricsPath, "method": req.Method, "clientId": clientId}).Inc() defer ConnectionsOpen.With(map[string]string{"route": metricsPath, "method": req.Method, "clientId": clientId}).Dec()