Skip to content

Commit

Permalink
feat(proxy): refactor proxy error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
unkn0wn-root committed Jan 3, 2025
1 parent 7f9f03a commit 1dd3042
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 15 deletions.
206 changes: 206 additions & 0 deletions internal/pool/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package pool

import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"syscall"
)

// Common proxy error types
var (
ErrBackendUnavailable = errors.New("server unavailable")
ErrBackendTimeout = errors.New("server timeout")
ErrInvalidRedirect = errors.New("invalid redirect received from server")
)

// ProxyErrorCode represents specific error conditions in the proxy
type ProxyErrorCode int

const (
ErrCodeUnknown ProxyErrorCode = iota
ErrCodeBackendConnFailed
ErrCodeBackendTimeout
ErrCodeInvalidResponse
ErrCodeTLSError
ErrCodeClientDisconnect
)

// ProxyError represents a detailed error that occurs during proxy operations
type ProxyError struct {
Op string
Code ProxyErrorCode
Message string
Err error
Retryable bool
StatusCode int
}

func (e *ProxyError) Error() string {
if e.Err != nil {
return fmt.Sprintf("%s: %s: %v", e.Op, e.Message, e.Err)
}
return fmt.Sprintf("%s: %s", e.Op, e.Message)
}

func (e *ProxyError) Unwrap() error {
return e.Err
}

// IsTemporaryError determines if an error is temporary and the request can be retried
func IsTemporaryError(err error) bool {
// Check our custom error first
var proxyErr *ProxyError
if errors.As(err, &proxyErr) {
return proxyErr.Retryable
}

// Check for network operation timeouts
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}

// Check for specific network errors
var opErr *net.OpError
if errors.As(err, &opErr) {
// Check for specific syscall errors
var syscallErr syscall.Errno
if errors.As(opErr.Err, &syscallErr) {
switch syscallErr {
case
syscall.ECONNREFUSED,
syscall.ECONNRESET,
syscall.ETIMEDOUT,
syscall.EPIPE,
syscall.ECONNABORTED,
syscall.EHOSTDOWN,
syscall.ENETUNREACH,
syscall.EHOSTUNREACH:
return true
}
}

// Check for DNS temporary errors
var dnsErr *net.DNSError
if errors.As(opErr.Err, &dnsErr) {
return dnsErr.IsTemporary
}
}

return false
}

// NewProxyError creates a new ProxyError with appropriate defaults based on the error type
func NewProxyError(op string, err error) *ProxyError {
pe := &ProxyError{
Op: op,
Err: err,
Code: ErrCodeUnknown,
StatusCode: http.StatusBadGateway,
Retryable: false,
}

switch {
case errors.Is(err, context.Canceled):
pe.Code = ErrCodeClientDisconnect
pe.Message = "Request canceled by client"
pe.StatusCode = 499 // Client closed request
pe.Retryable = false

case errors.Is(err, ErrBackendUnavailable):
pe.Code = ErrCodeBackendConnFailed
pe.Message = "Backend server unavailable"
pe.StatusCode = http.StatusBadGateway
pe.Retryable = true

case errors.Is(err, ErrBackendTimeout):
pe.Code = ErrCodeBackendTimeout
pe.Message = "Backend server timeout"
pe.StatusCode = http.StatusGatewayTimeout
pe.Retryable = true

default:
// Check for network errors
var opErr *net.OpError
if errors.As(err, &opErr) {
pe.Retryable = IsTemporaryError(err)

// Handle DNS errors specifically
var dnsErr *net.DNSError
if errors.As(opErr.Err, &dnsErr) {
pe.Code = ErrCodeBackendConnFailed
pe.Message = fmt.Sprintf("DNS error: %s", dnsErr.Error())
pe.StatusCode = http.StatusBadGateway
pe.Retryable = dnsErr.IsTemporary
return pe
}

// Handle syscall errors
var syscallErr syscall.Errno
if errors.As(opErr.Err, &syscallErr) {
switch syscallErr {
case syscall.ECONNREFUSED:
pe.Message = "Connection refused by backend"
case syscall.ECONNRESET:
pe.Message = "Connection reset by backend"
case syscall.ETIMEDOUT:
pe.Code = ErrCodeBackendTimeout
pe.Message = "Connection timed out"
pe.StatusCode = http.StatusGatewayTimeout
default:
pe.Message = fmt.Sprintf("Network error: %s", syscallErr.Error())
}
pe.Code = ErrCodeBackendConnFailed
return pe
}
}

// Handle standard net.Error timeouts
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
pe.Code = ErrCodeBackendTimeout
pe.Message = "Network timeout"
pe.StatusCode = http.StatusGatewayTimeout
pe.Retryable = true
} else {
pe.Code = ErrCodeBackendConnFailed
pe.Message = "Network error"
pe.Retryable = IsTemporaryError(err)
}
return pe
}

// Generic error handling
pe.Message = fmt.Sprintf("Unexpected error: %v", err)
}

return pe
}

// ErrorResponse represents the structure of error responses sent to clients
type ErrorResponse struct {
Status string `json:"status"`
Message string `json:"message"`
}

// WriteErrorResponse writes a structured error response to the client
func WriteErrorResponse(w http.ResponseWriter, err error) {
var pe *ProxyError
if !errors.As(err, &pe) {
pe = NewProxyError("unknown", err)
}

response := ErrorResponse{
Status: "error",
Message: pe.Message,
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(pe.StatusCode)
json.NewEncoder(w).Encode(response)
}
29 changes: 14 additions & 15 deletions internal/pool/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,6 @@ const (
DefaultProxyLabel = "terraster"
)

// ProxyError represents an error that occurs during proxy operations.
type ProxyError struct {
Op string // Op describes the operation being performed when the error occurred.
Err error // Err is the underlying error that was encountered.
}

// Error implements the error interface for ProxyError.
func (e *ProxyError) Error() string {
return fmt.Sprintf("proxy error during %s: %v", e.Op, e.Err)
}

// RouteConfig holds configuration settings for routing requests through the proxy.
type RouteConfig struct {
Path string // Path is the proxy path (upstream) used to match incoming requests (optional).
Expand Down Expand Up @@ -165,7 +154,12 @@ func (p *URLRewriteProxy) director(req *http.Request) {

// RoundTrip implements the RoundTripper interface for the Transport type.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.transport.RoundTrip(req)
r, err := t.transport.RoundTrip(req)
if err != nil {
return nil, NewProxyError("round_trip", err)
}

return r, nil
}

// updateRequestHeaders modifies the HTTP request headers before forwarding the request to the backend.
Expand All @@ -188,7 +182,7 @@ func (p *URLRewriteProxy) handleRedirect(resp *http.Response) error {
location := resp.Header.Get(HeaderLocation)
locURL, err := url.Parse(location)
if err != nil {
return &ProxyError{Op: "parse_redirect_url", Err: err} // Return a ProxyError if parsing fails.
return NewProxyError("handle_redirect", fmt.Errorf("invalid redirect URL: %w", err))
}

// Ensure that redirects to external hosts are not rewritten.
Expand Down Expand Up @@ -252,6 +246,11 @@ func (p *URLRewriteProxy) errorHandler(w http.ResponseWriter, r *http.Request, e
return
}

p.logger.Error("Unexpected error in proxy", zap.Error(err))
http.Error(w, "Something went wrong", http.StatusInternalServerError)
p.logger.Error("Proxy error",
zap.Error(err),
zap.String("method", r.Method),
zap.String("path", r.URL.Path),
)

WriteErrorResponse(w, err)
}

0 comments on commit 1dd3042

Please sign in to comment.