Skip to content

Commit

Permalink
refactor: add headers handler to the pool package
Browse files Browse the repository at this point in the history
  • Loading branch information
unkn0wn-root committed Dec 30, 2024
1 parent e0d9fa6 commit 10779ba
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 60 deletions.
2 changes: 1 addition & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ type Service struct {
Middleware []Middleware `yaml:"middleware"` // Middleware configurations specific to the service.
Locations []Location `yaml:"locations"` // Routing paths and backend configurations for the service.
LogName string `yaml:"log_name,omitempty"` // Name of the logger to use for this service.
Headers HeaderConfig `yaml:"headers,omitempty"` // Custom headers configuration for request and response objects
Headers *HeaderConfig `yaml:"headers,omitempty"` // Custom headers configuration for request and response objects
}

// HeaderConfig is custom response and request headers modifier
Expand Down
67 changes: 67 additions & 0 deletions internal/pool/header_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package pool

import (
"net/http"
"strings"

"github.com/unkn0wn-root/terraster/internal/config"
)

// HeaderHandler manages request and response header modifications
type HeaderHandler struct {
headerConfig config.HeaderConfig
placeholders map[string]func(*http.Request) string
}

// NewHeaderHandler creates a new HeaderHandler
func NewHeaderHandler(cfg config.HeaderConfig) *HeaderHandler {
return &HeaderHandler{
headerConfig: cfg,
placeholders: map[string]func(*http.Request) string{
"${remote_addr}": func(r *http.Request) string { return r.RemoteAddr },
"${host}": func(r *http.Request) string { return r.Host },
"${uri}": func(r *http.Request) string { return r.RequestURI },
"${method}": func(r *http.Request) string { return r.Method },
},
}
}

// ProcessRequestHeaders modifies the request headers
func (h *HeaderHandler) ProcessRequestHeaders(req *http.Request) {
for _, header := range h.headerConfig.RemoveRequestHeaders {
req.Header.Del(header)
}

for key, value := range h.headerConfig.RequestHeaders {
processedValue := h.processPlaceholders(value, req)
req.Header.Set(key, processedValue)
}
}

// ProcessResponseHeaders modifies the response headers
func (h *HeaderHandler) ProcessResponseHeaders(resp *http.Response) {
for _, header := range h.headerConfig.RemoveResponseHeaders {
resp.Header.Del(header)
}

for key, value := range h.headerConfig.ResponseHeaders {
processedValue := h.processPlaceholders(value, resp.Request)
resp.Header.Set(key, processedValue)
}
}

// processPlaceholders replaces placeholder values with actual request values
func (h *HeaderHandler) processPlaceholders(value string, req *http.Request) string {
if req == nil {
return value
}

result := value
for placeholder, getter := range h.placeholders {
if strings.Contains(value, placeholder) {
result = strings.ReplaceAll(result, placeholder, getter(req))
}
}

return result
}
8 changes: 6 additions & 2 deletions internal/pool/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ func WithLogger(logger *zap.Logger) ProxyOption {
}

// WithHeaderConfig sets custom req/res headers
func WithHeaderConfig(config *config.HeaderConfig) ProxyOption {
func WithHeaderConfig(cfg *config.HeaderConfig) ProxyOption {
return func(p *URLRewriteProxy) {
p.headerConfig = config
if cfg == nil {
return
}

p.headerHandler = NewHeaderHandler(*cfg)
}
}
61 changes: 12 additions & 49 deletions internal/pool/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ import (
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"

"github.com/unkn0wn-root/terraster/internal/config"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -69,14 +67,14 @@ func NewTransport(transport http.RoundTripper, skipTLSVerify bool) *Transport {

// URLRewriteProxy is a custom reverse proxy that handles URL rewriting and redirection based on RouteConfig.
type URLRewriteProxy struct {
proxy *httputil.ReverseProxy // proxy is the underlying reverse proxy handling the HTTP requests.
target *url.URL // target is the destination URL to which the proxy forwards requests.
path string // path is the URL path prefix that this proxy handles.
rewriteURL string // rewriteURL specifies the URL to which incoming requests should be rewritten.
urlRewriter *URLRewriter // urlRewriter handles the logic for rewriting request URLs and managing redirects.
rConfig RewriteConfig // rConfig holds the rewrite and redirect configurations.
logger *zap.Logger // logger is used for logging proxy-related activities.
headerConfig *config.HeaderConfig // headerConfig is used to modify request/response headers
proxy *httputil.ReverseProxy // proxy is the underlying reverse proxy handling the HTTP requests.
target *url.URL // target is the destination URL to which the proxy forwards requests.
path string // path is the URL path prefix that this proxy handles.
rewriteURL string // rewriteURL specifies the URL to which incoming requests should be rewritten.
urlRewriter *URLRewriter // urlRewriter handles the logic for rewriting request URLs and managing redirects.
rConfig RewriteConfig // rConfig holds the rewrite and redirect configurations.
logger *zap.Logger // logger is used for logging proxy-related activities.
headerHandler *HeaderHandler // headerHandler is used to modify request/response headers
}

// ProxyOption defines a function type for applying optional configurations to URLRewriteProxy instances.
Expand Down Expand Up @@ -177,15 +175,8 @@ func (p *URLRewriteProxy) updateRequestHeaders(req *http.Request) {
req.Header.Set(HeaderXForwardedHost, originalHost)
req.Header.Set(HeaderXForwardedFor, originalHost)

if hc := p.headerConfig; hc != nil {
for _, header := range hc.RemoveRequestHeaders {
req.Header.Del(header)
}

for key, value := range hc.RequestHeaders {
processedValue := p.processHeaderValue(value, req)
req.Header.Set(key, processedValue)
}
if p.headerHandler != nil {
p.headerHandler.ProcessRequestHeaders(req)
}
}

Expand Down Expand Up @@ -229,15 +220,8 @@ func (p *URLRewriteProxy) updateResponseHeaders(resp *http.Response) {
resp.Header.Del(HeaderXPoweredBy)
resp.Header.Set(HeaderXProxyBy, DefaultProxyLabel)

if hc := p.headerConfig; hc != nil {
for _, header := range hc.RemoveResponseHeaders {
resp.Header.Del(header)
}

for key, value := range hc.ResponseHeaders {
processedValue := p.processHeaderValue(value, resp.Request)
resp.Header.Set(key, processedValue)
}
if p.headerHandler != nil {
p.headerHandler.ProcessResponseHeaders(resp)
}
}

Expand Down Expand Up @@ -267,24 +251,3 @@ func (p *URLRewriteProxy) errorHandler(w http.ResponseWriter, r *http.Request, e
p.logger.Error("Unexpected error in proxy", zap.Error(err))
http.Error(w, "Something went wrong", http.StatusInternalServerError)
}

// processHeaderValue replaces placeholder values with actual req values
// if string placeholders in not being used - it returns value back
func (p *URLRewriteProxy) processHeaderValue(value string, req *http.Request) string {
// Replace placeholders with actual values
placeholders := map[string]func(*http.Request) string{
"${remote_addr}": func(r *http.Request) string { return r.RemoteAddr },
"${host}": func(r *http.Request) string { return r.Host },
"${uri}": func(r *http.Request) string { return r.RequestURI },
"${method}": func(r *http.Request) string { return r.Method },
}

result := value
for placeholder, getter := range placeholders {
if strings.Contains(value, placeholder) {
result = strings.ReplaceAll(result, placeholder, getter(req))
}
}

return result
}
14 changes: 7 additions & 7 deletions internal/pool/server_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ type BackendSnapshot struct {

// ServerPool manages a pool of backend servers, handling load balancing and connection management.
type ServerPool struct {
backends atomic.Value // Atomic value storing the current BackendSnapshot.
current uint64 // Atomic counter used for round-robin load balancing.
algorithm atomic.Value // Atomic value storing the current load balancing algorithm.
maxConnections atomic.Int32 // Atomic integer representing the maximum allowed connections per backend.
log *zap.Logger // Logger instance for logging pool activities.
serviceHeaders config.HeaderConfig // Service request and response custom headers
backends atomic.Value // Atomic value storing the current BackendSnapshot.
current uint64 // Atomic counter used for round-robin load balancing.
algorithm atomic.Value // Atomic value storing the current load balancing algorithm.
maxConnections atomic.Int32 // Atomic integer representing the maximum allowed connections per backend.
log *zap.Logger // Logger instance for logging pool activities.
serviceHeaders *config.HeaderConfig // Service request and response custom headers
}

func NewServerPool(svc *config.Service, logger *zap.Logger) *ServerPool {
Expand Down Expand Up @@ -73,7 +73,7 @@ func (s *ServerPool) AddBackend(
createProxy,
s.log,
WithURLRewriter(rc, url),
WithHeaderConfig(&s.serviceHeaders),
WithHeaderConfig(s.serviceHeaders),
)

maxConnections := cfg.MaxConnections
Expand Down
2 changes: 1 addition & 1 deletion internal/service/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type ServiceInfo struct {
Middleware []config.Middleware // Middleware configurations for the service.
LogName string // LogName will be used to get service logger from config.
Logger *zap.Logger // Logger instance for logging service activities.
Headers config.HeaderConfig // Request/Response custom headers
Headers *config.HeaderConfig // Request/Response custom headers
}

// ServiceType determines the protocol type of the service based on its TLS configuration.
Expand Down

0 comments on commit 10779ba

Please sign in to comment.