diff --git a/dune/dune.go b/dune/dune.go index 1b2dc86..3a57d37 100644 --- a/dune/dune.go +++ b/dune/dune.go @@ -169,6 +169,9 @@ func (c *duneClient) QueryExecute(queryID int, queryParameters map[string]any) ( if err != nil { return nil, err } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(jsonData)), nil } resp, err := httpRequest(c.env.APIKey, req) if err != nil { return nil, err @@ -197,7 +200,9 @@ func (c *duneClient) SQLExecute(sql string, performance string) (*models.Execute if err != nil { return nil, err } - + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(jsonData)), nil } resp, err := httpRequest(c.env.APIKey, req) if err != nil { return nil, err @@ -225,6 +230,9 @@ func (c *duneClient) QueryPipelineExecute(queryID string, performance string) (* if err != nil { return nil, err } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(jsonData)), nil } resp, err := httpRequest(c.env.APIKey, req) if err != nil { @@ -382,6 +390,9 @@ func (c *duneClient) getUsage(startDate, endDate *string) (*models.UsageResponse if err != nil { return nil, err } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(jsonData)), nil } resp, err := httpRequest(c.env.APIKey, req) if err != nil { diff --git a/dune/execution.go b/dune/execution.go index 552df8d..5ae48f0 100644 --- a/dune/execution.go +++ b/dune/execution.go @@ -70,16 +70,21 @@ func (e *execution) GetResultsCSV() (io.Reader, error) { } func (e *execution) WaitGetResults(pollInterval time.Duration, maxRetries int) (*models.ResultsResponse, error) { - errCount := 0 + errAttempts := 0 for { resultsResp, err := e.client.QueryResultsV2(e.ID, models.ResultOptions{}) if err != nil { - if maxRetries != 0 && errCount > maxRetries { + errAttempts++ + if maxRetries != 0 && errAttempts >= maxRetries { return nil, fmt.Errorf("%w. %s", ErrorRetriesExhausted, err.Error()) } fmt.Fprintln(os.Stderr, "failed to retrieve results. Retrying...\n", err) - errCount += 1 - } else if resultsResp.IsExecutionFinished { + sleep := defaultRetryPolicy.NextBackoff(errAttempts) + time.Sleep(sleep) + continue + } + errAttempts = 0 + if resultsResp.IsExecutionFinished { return resultsResp, nil } time.Sleep(pollInterval) diff --git a/dune/http.go b/dune/http.go index 346464c..a06db64 100644 --- a/dune/http.go +++ b/dune/http.go @@ -4,7 +4,10 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" + "strconv" + "time" ) var ErrorReqUnsuccessful = errors.New("request was not successful") @@ -13,6 +16,59 @@ type ErrorResponse struct { Error string `json:"error"` } +type RateLimit struct { + Limit int + Remaining int + Reset int64 +} + +type APIError struct { + StatusCode int + StatusText string + BodySnippet string + RateLimit *RateLimit + RetryAfter time.Duration +} + +func (e *APIError) Error() string { + if e.BodySnippet != "" { + return fmt.Sprintf("http %d %s: %s", e.StatusCode, e.StatusText, e.BodySnippet) + } + return fmt.Sprintf("http %d %s", e.StatusCode, e.StatusText) +} + + +func parseRateLimitHeaders(h http.Header) *RateLimit { + limStr := h.Get("X-RateLimit-Limit") + remStr := h.Get("X-RateLimit-Remaining") + resetStr := h.Get("X-RateLimit-Reset") + + var limit, remaining int + var reset int64 + + if limStr != "" { + if v, err := strconv.Atoi(limStr); err == nil { + limit = v + } + } + if remStr != "" { + if v, err := strconv.Atoi(remStr); err == nil { + remaining = v + } + } + if resetStr != "" { + if v, err := strconv.ParseInt(resetStr, 10, 64); err == nil { + reset = v + } + } + + if limit == 0 && remaining == 0 && reset == 0 { + return nil + } + return &RateLimit{Limit: limit, Remaining: remaining, Reset: reset} +} + + func decodeBody(resp *http.Response, dest interface{}) error { defer resp.Body.Close() err := json.NewDecoder(resp.Body).Decode(dest) @@ -24,20 +80,86 @@ func decodeBody(resp *http.Response, dest interface{}) error { func httpRequest(apiKey string, req *http.Request) (*http.Response, error) { req.Header.Add("X-DUNE-API-KEY", apiKey) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - - if resp.StatusCode != 200 { - defer resp.Body.Close() - var errorResponse ErrorResponse - err := json.NewDecoder(resp.Body).Decode(&errorResponse) + p := defaultRetryPolicy + attempt := 1 + for { + if attempt > 1 && req.Body != nil { + if req.GetBody != nil { + b, err := req.GetBody() + if err != nil { + return nil, fmt.Errorf("failed to reset request body: %w", err) + } + req.Body = b + } + } + resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to read error response body: %w", err) + if attempt >= p.MaxAttempts { + return nil, fmt.Errorf("failed to send request: %w", err) + } + time.Sleep(p.NextBackoff(attempt)) + attempt++ + continue + } + + if resp.StatusCode == 200 { + return resp, nil } - return resp, fmt.Errorf("%w [%d]: %s", ErrorReqUnsuccessful, resp.StatusCode, errorResponse.Error) - } - return resp, nil + snippetBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + _ = resp.Body.Close() + var errorResp ErrorResponse + msg := string(snippetBytes) + if err := json.Unmarshal(snippetBytes, &errorResp); err == nil && errorResp.Error != "" { + msg = errorResp.Error + } + rl := parseRateLimitHeaders(resp.Header) + retryAfter := time.Duration(0) + if ra := resp.Header.Get("Retry-After"); ra != "" { + if secs, err := strconv.Atoi(ra); err == nil { + retryAfter = time.Duration(secs) * time.Second + } + } + apiErr := &APIError{StatusCode: resp.StatusCode, StatusText: resp.Status, BodySnippet: msg, RateLimit: rl, RetryAfter: retryAfter} + retryable := false + for _, code := range p.RetryableStatusCodes { + if resp.StatusCode == code { + retryable = true + break + } + } + if retryable && attempt < p.MaxAttempts { + sleep := p.NextBackoff(attempt) + if apiErr.RetryAfter > 0 && apiErr.RetryAfter > sleep { + sleep = apiErr.RetryAfter + } + time.Sleep(sleep) + attempt++ + continue + } + return nil, fmt.Errorf("%w: %v", ErrorReqUnsuccessful, apiErr) + rl := parseRateLimitHeaders(resp.Header) + retryAfter := time.Duration(0) + if ra := resp.Header.Get("Retry-After"); ra != "" { + if secs, err := strconv.Atoi(ra); err == nil { + retryAfter = time.Duration(secs) * time.Second + } + } + apiErr := &APIError{ + StatusCode: resp.StatusCode, + StatusText: resp.Status, + BodySnippet: msg, + RateLimit: rl, + RetryAfter: retryAfter, + } + retryable := false + for _, code := range p.RetryableStatusCodes { + if resp.StatusCode == code { + retryable = true + break + } + } + // unreachable due to early returns above; kept for clarity + return nil, fmt.Errorf("%w: unexpected error state", ErrorReqUnsuccessful) + } } diff --git a/dune/retries.go b/dune/retries.go new file mode 100644 index 0000000..2201777 --- /dev/null +++ b/dune/retries.go @@ -0,0 +1,34 @@ +package dune + +import "time" + +type RetryPolicy struct { + MaxAttempts int + InitialBackoff time.Duration + MaxBackoff time.Duration + Jitter time.Duration + RetryableStatusCodes []int +} + +var defaultRetryPolicy = RetryPolicy{ + MaxAttempts: 5, + InitialBackoff: 2 * time.Second, + MaxBackoff: 60 * time.Second, + Jitter: 250 * time.Millisecond, + RetryableStatusCodes: []int{429, 500, 502, 503, 504}, +} + +func (p RetryPolicy) NextBackoff(attempt int) time.Duration { + b := p.InitialBackoff + for i := 1; i < attempt; i++ { + b *= 2 + if b > p.MaxBackoff { + b = p.MaxBackoff + break + } + } + if p.Jitter > 0 { + b += p.Jitter + } + return b +}