diff --git a/errors/errors.go b/errors/errors.go index ba99374..76d1ce1 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -5,6 +5,8 @@ import ( "fmt" "io" "net/http" + "strconv" + "time" ) // HTTPError preserves HTTP response details @@ -24,7 +26,24 @@ func (e *HTTPError) Error() string { // Temporary implements net.Error interface func (e *HTTPError) Temporary() bool { - return e.Response.StatusCode >= 500 + code := e.Response.StatusCode + return code == 408 || code == 429 || code >= 500 +} + +// RetryAfter returns how long to wait before retrying based on +// rate limit headers in the response +func (e *HTTPError) RetryAfter() time.Duration { + if v := e.Response.Header.Get("Retry-After"); v != "" { + if seconds, err := strconv.Atoi(v); err == nil && seconds > 0 { + return time.Duration(seconds) * time.Second + } + if t, err := http.ParseTime(v); err == nil { + if delay := time.Until(t); delay > 0 { + return delay + } + } + } + return 0 } // StatusCode returns the HTTP status code diff --git a/errors/errors_test.go b/errors/errors_test.go new file mode 100644 index 0000000..0f7fad5 --- /dev/null +++ b/errors/errors_test.go @@ -0,0 +1,114 @@ +package errors + +import ( + "net/http" + "testing" + "time" +) + +func newHTTPResponse(statusCode int, headers map[string]string) *http.Response { + h := make(http.Header) + for k, v := range headers { + h.Set(k, v) + } + return &http.Response{ + StatusCode: statusCode, + Header: h, + } +} + +func TestRetryAfter_Seconds(t *testing.T) { + e := &HTTPError{Response: newHTTPResponse(429, map[string]string{ + "Retry-After": "30", + })} + + got := e.RetryAfter() + want := 30 * time.Second + if got != want { + t.Errorf("RetryAfter() = %v, want %v", got, want) + } +} + +func TestRetryAfter_HTTPDate(t *testing.T) { + future := time.Now().Add(45 * time.Second) + dateStr := future.UTC().Format(http.TimeFormat) + + e := &HTTPError{Response: newHTTPResponse(429, map[string]string{ + "Retry-After": dateStr, + })} + + got := e.RetryAfter() + if got < 43*time.Second || got > 47*time.Second { + t.Errorf("RetryAfter() = %v, want ~45s", got) + } +} + +func TestRetryAfter_NoHeader(t *testing.T) { + e := &HTTPError{Response: newHTTPResponse(429, nil)} + + got := e.RetryAfter() + if got != 0 { + t.Errorf("RetryAfter() = %v, want 0", got) + } +} + +func TestRetryAfter_InvalidValue(t *testing.T) { + e := &HTTPError{Response: newHTTPResponse(429, map[string]string{ + "Retry-After": "not-a-number", + })} + + got := e.RetryAfter() + if got != 0 { + t.Errorf("RetryAfter() = %v, want 0 for unparseable value", got) + } +} + +func TestRetryAfter_ZeroSeconds(t *testing.T) { + e := &HTTPError{Response: newHTTPResponse(429, map[string]string{ + "Retry-After": "0", + })} + + got := e.RetryAfter() + if got != 0 { + t.Errorf("RetryAfter() = %v, want 0", got) + } +} + +func TestRetryAfter_PastHTTPDate(t *testing.T) { + past := time.Now().Add(-10 * time.Second) + dateStr := past.UTC().Format(http.TimeFormat) + + e := &HTTPError{Response: newHTTPResponse(429, map[string]string{ + "Retry-After": dateStr, + })} + + got := e.RetryAfter() + if got != 0 { + t.Errorf("RetryAfter() = %v, want 0 for past date", got) + } +} + +func TestTemporary(t *testing.T) { + tests := []struct { + code int + want bool + }{ + {200, false}, + {400, false}, + {401, false}, + {403, false}, + {404, false}, + {408, true}, + {429, true}, + {500, true}, + {502, true}, + {503, true}, + } + + for _, tc := range tests { + e := &HTTPError{Response: newHTTPResponse(tc.code, nil)} + if got := e.Temporary(); got != tc.want { + t.Errorf("Temporary() for status %d = %v, want %v", tc.code, got, tc.want) + } + } +} diff --git a/users/limit.go b/users/limit.go index 3b08246..02a27a3 100644 --- a/users/limit.go +++ b/users/limit.go @@ -4,10 +4,10 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "github.com/internxt/rclone-adapter/config" + sdkerrors "github.com/internxt/rclone-adapter/errors" ) type LimitResponse struct { @@ -30,8 +30,7 @@ func GetLimit(ctx context.Context, cfg *config.Config) (*LimitResponse, error) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("GET %s returned %d: %s", url, resp.StatusCode, string(body)) + return nil, sdkerrors.NewHTTPError(resp, "get limit") } var limit LimitResponse diff --git a/users/usage.go b/users/usage.go index 09ef4c3..7dcc824 100644 --- a/users/usage.go +++ b/users/usage.go @@ -4,10 +4,10 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "github.com/internxt/rclone-adapter/config" + sdkerrors "github.com/internxt/rclone-adapter/errors" ) type UsageResponse struct { @@ -30,8 +30,7 @@ func GetUsage(ctx context.Context, cfg *config.Config) (*UsageResponse, error) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("GET %s returned %d: %s", url, resp.StatusCode, string(body)) + return nil, sdkerrors.NewHTTPError(resp, "get usage") } var usage UsageResponse diff --git a/users/users_test.go b/users/users_test.go index a9204e0..f157fed 100644 --- a/users/users_test.go +++ b/users/users_test.go @@ -3,10 +3,14 @@ package users import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" "testing" + "time" + + sdkerrors "github.com/internxt/rclone-adapter/errors" ) func TestGetUsage(t *testing.T) { @@ -260,3 +264,92 @@ func TestGetLimitHTTPClientError(t *testing.T) { t.Errorf("expected error to contain 'failed to execute', got %q", err.Error()) } } + +// TestGetUsage429ReturnsHTTPError verifies that a 429 from GetUsage returns +// an *sdkerrors.HTTPError +func TestGetUsage429ReturnsHTTPError(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "15") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte(`{"error":"rate limited"}`)) + })) + defer mockServer.Close() + + cfg := newTestConfig(mockServer.URL) + _, err := GetUsage(context.Background(), cfg) + if err == nil { + t.Fatal("expected error, got nil") + } + + var httpErr *sdkerrors.HTTPError + if !errors.As(err, &httpErr) { + t.Fatalf("expected *sdkerrors.HTTPError, got %T: %v", err, err) + } + if httpErr.StatusCode() != 429 { + t.Errorf("expected status 429, got %d", httpErr.StatusCode()) + } + if !httpErr.Temporary() { + t.Error("expected Temporary() = true for 429") + } + if got := httpErr.RetryAfter(); got != 15*time.Second { + t.Errorf("RetryAfter() = %v, want 15s", got) + } +} + +// TestGetLimit429ReturnsHTTPError verifies that a 429 from GetLimit returns +// an *sdkerrors.HTTPError +func TestGetLimit429ReturnsHTTPError(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "60") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte(`{"error":"rate limited"}`)) + })) + defer mockServer.Close() + + cfg := newTestConfig(mockServer.URL) + _, err := GetLimit(context.Background(), cfg) + if err == nil { + t.Fatal("expected error, got nil") + } + + var httpErr *sdkerrors.HTTPError + if !errors.As(err, &httpErr) { + t.Fatalf("expected *sdkerrors.HTTPError, got %T: %v", err, err) + } + if httpErr.StatusCode() != 429 { + t.Errorf("expected status 429, got %d", httpErr.StatusCode()) + } + if !httpErr.Temporary() { + t.Error("expected Temporary() = true for 429") + } + if got := httpErr.RetryAfter(); got != 60*time.Second { + t.Errorf("RetryAfter() = %v, want 60s", got) + } +} + +// TestGetUsage408ReturnsHTTPError verifies that a 408 timeout returns +// a retryable *sdkerrors.HTTPError. +func TestGetUsage408ReturnsHTTPError(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusRequestTimeout) + w.Write([]byte(`{"message":"Request timed out"}`)) + })) + defer mockServer.Close() + + cfg := newTestConfig(mockServer.URL) + _, err := GetUsage(context.Background(), cfg) + if err == nil { + t.Fatal("expected error, got nil") + } + + var httpErr *sdkerrors.HTTPError + if !errors.As(err, &httpErr) { + t.Fatalf("expected *sdkerrors.HTTPError, got %T: %v", err, err) + } + if httpErr.StatusCode() != 408 { + t.Errorf("expected status 408, got %d", httpErr.StatusCode()) + } + if !httpErr.Temporary() { + t.Error("expected Temporary() = true for 408") + } +}