diff --git a/README.md b/README.md index fd1daf9..b754e35 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ A high-performance, feature-rich Layer 7 (L7) load balancer with a robust and us - ✅ Dynamic Middleware Plug-in - ✅ Configurable Request Logging - ✅ Restrict access to API via IPs whitelist +- ✅ Custom Request Headers ### Core Features - ✅ Health Checking @@ -43,7 +44,6 @@ A high-performance, feature-rich Layer 7 (L7) load balancer with a robust and us ## WIP - ⏳ WebSocket Support (WIP) - ⏳ Automatic Certificate Management (WIP) -- ❌ Custom Request Headers ## Quick Start @@ -156,18 +156,29 @@ services: host: internal-api1.local.com port: 8455 log_name: backend-api # Maps to logger configuration - + headers: # Custom headers + request_headers: + X-Custom-Header: "custom-value" + response_headers: + Cache-Control: "no-cache" + remove_request_headers: + - User-Agent + - Accept-Encoding + remove_response_headers: + - Server + - X-Powered-By + # Service-specific TLS tls: cert_file: "/path/to/api-cert.pem" key_file: "/path/to/api-key.pem" - + # Service-specific middleware (overrides global) middleware: - rate_limit: requests_per_second: 2500 burst: 500 - + # Service-specific health check health_check: type: "http" @@ -177,7 +188,7 @@ services: thresholds: healthy: 2 unhealthy: 3 - + # Path-based routing locations: - path: "/api/" diff --git a/internal/config/config.go b/internal/config/config.go index 404e004..4a6989e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -177,6 +177,15 @@ type Service struct { Middleware []Middleware `yaml:"middleware"` // Middleware configurations specific to the service. Locations []Location `yaml:"locations"` // Routing paths and backend configurations for the service. LogName string `yaml:"log_name,omitempty"` // Name of the logger to use for this service. + Headers HeaderConfig `yaml:"headers,omitempty"` // Custom headers configuration for request and response objects +} + +// HeaderConfig is custom response and request headers modifier +type HeaderConfig struct { + RequestHeaders map[string]string `yaml:"request_headers,omitempty"` // Request headers to be added/modified when forwarding to backend + ResponseHeaders map[string]string `yaml:"response_headers,omitempty"` // Response headers to be added/modified before sending back to client + RemoveRequestHeaders []string `yaml:"remove_request_headers,omitempty"` // Headers to be removed from the request before forwarding + RemoveResponseHeaders []string `yaml:"remove_response_headers,omitempty"` // Headers to be removed from the response before sending back } // Middleware defines the configuration for various middleware components. diff --git a/internal/pool/options.go b/internal/pool/options.go index b601d3f..57962a4 100644 --- a/internal/pool/options.go +++ b/internal/pool/options.go @@ -3,10 +3,11 @@ package pool import ( "net/url" + "github.com/unkn0wn-root/terraster/internal/config" "go.uber.org/zap" ) -// WithURLRewriter is a functional option for configuring the URLRewriteProxy. +// WithURLRewriter is configuring the URLRewriteProxy. // It sets up a URL rewriter based on the provided RouteConfig and backend URL. // This allows the proxy to modify incoming request URLs according to the specified rewrite rules, // ensuring that requests are correctly routed to the intended backend services. @@ -16,9 +17,16 @@ func WithURLRewriter(config RouteConfig, backendURL *url.URL) ProxyOption { } } -// Functional option for configuring the URLRewriteProxy with a custom logger. +// WithLogger is configuring the URLRewriteProxy with a custom logger. func WithLogger(logger *zap.Logger) ProxyOption { return func(p *URLRewriteProxy) { p.logger = logger } } + +// WithHeaderConfig sets custom req/res headers +func WithHeaderConfig(config *config.HeaderConfig) ProxyOption { + return func(p *URLRewriteProxy) { + p.headerConfig = config + } +} diff --git a/internal/pool/proxy.go b/internal/pool/proxy.go index 3372825..160e538 100644 --- a/internal/pool/proxy.go +++ b/internal/pool/proxy.go @@ -8,8 +8,10 @@ import ( "net/http" "net/http/httputil" "net/url" + "strings" "time" + "github.com/unkn0wn-root/terraster/internal/config" "go.uber.org/zap" ) @@ -67,13 +69,14 @@ func NewTransport(transport http.RoundTripper, skipTLSVerify bool) *Transport { // URLRewriteProxy is a custom reverse proxy that handles URL rewriting and redirection based on RouteConfig. type URLRewriteProxy struct { - proxy *httputil.ReverseProxy // proxy is the underlying reverse proxy handling the HTTP requests. - target *url.URL // target is the destination URL to which the proxy forwards requests. - path string // path is the URL path prefix that this proxy handles. - rewriteURL string // rewriteURL specifies the URL to which incoming requests should be rewritten. - urlRewriter *URLRewriter // urlRewriter handles the logic for rewriting request URLs and managing redirects. - rConfig RewriteConfig // rConfig holds the rewrite and redirect configurations. - logger *zap.Logger // logger is used for logging proxy-related activities. + proxy *httputil.ReverseProxy // proxy is the underlying reverse proxy handling the HTTP requests. + target *url.URL // target is the destination URL to which the proxy forwards requests. + path string // path is the URL path prefix that this proxy handles. + rewriteURL string // rewriteURL specifies the URL to which incoming requests should be rewritten. + urlRewriter *URLRewriter // urlRewriter handles the logic for rewriting request URLs and managing redirects. + rConfig RewriteConfig // rConfig holds the rewrite and redirect configurations. + logger *zap.Logger // logger is used for logging proxy-related activities. + headerConfig *config.HeaderConfig // headerConfig is used to modify request/response headers } // ProxyOption defines a function type for applying optional configurations to URLRewriteProxy instances. @@ -173,6 +176,19 @@ func (p *URLRewriteProxy) updateRequestHeaders(req *http.Request) { originalHost := req.Host req.Header.Set(HeaderXForwardedHost, originalHost) req.Header.Set(HeaderXForwardedFor, originalHost) + + hc := p.headerConfig + if hc != nil { + for _, header := range hc.RemoveRequestHeaders { + req.Header.Del(header) + } + + // Add/modify configured request headers + for key, value := range hc.RequestHeaders { + processedValue := p.processHeaderValue(value, req) + req.Header.Set(key, processedValue) + } + } } // handleRedirect processes HTTP redirect responses from the backend server. @@ -214,6 +230,20 @@ func (p *URLRewriteProxy) updateResponseHeaders(resp *http.Response) { resp.Header.Del(HeaderServer) resp.Header.Del(HeaderXPoweredBy) resp.Header.Set(HeaderXProxyBy, DefaultProxyLabel) + + hc := p.headerConfig + if hc != nil { + // Remove specified response headers + for _, header := range hc.RemoveResponseHeaders { + resp.Header.Del(header) + } + + // Add/modify configured response headers + for key, value := range hc.ResponseHeaders { + processedValue := p.processHeaderValue(value, resp.Request) + resp.Header.Set(key, processedValue) + } + } } // isRedirect checks if the provided HTTP status code is one that indicates a redirection. @@ -242,3 +272,23 @@ func (p *URLRewriteProxy) errorHandler(w http.ResponseWriter, r *http.Request, e p.logger.Error("Unexpected error in proxy", zap.Error(err)) http.Error(w, "Something went wrong", http.StatusInternalServerError) } + +// processHeaderValue replaces placeholder values with actual req values +func (p *URLRewriteProxy) processHeaderValue(value string, req *http.Request) string { + // Replace placeholders with actual values + placeholders := map[string]func(*http.Request) string{ + "${remote_addr}": func(r *http.Request) string { return r.RemoteAddr }, + "${host}": func(r *http.Request) string { return r.Host }, + "${uri}": func(r *http.Request) string { return r.RequestURI }, + "${method}": func(r *http.Request) string { return r.Method }, + } + + result := value + for placeholder, getter := range placeholders { + if strings.Contains(value, placeholder) { + result = strings.ReplaceAll(result, placeholder, getter(req)) + } + } + + return result +} diff --git a/internal/pool/server_pool.go b/internal/pool/server_pool.go index 7e98b15..987d4d3 100644 --- a/internal/pool/server_pool.go +++ b/internal/pool/server_pool.go @@ -32,15 +32,16 @@ type BackendSnapshot struct { // ServerPool manages a pool of backend servers, handling load balancing and connection management. type ServerPool struct { - backends atomic.Value // Atomic value storing the current BackendSnapshot. - current uint64 // Atomic counter used for round-robin load balancing. - algorithm atomic.Value // Atomic value storing the current load balancing algorithm. - maxConnections atomic.Int32 // Atomic integer representing the maximum allowed connections per backend. - log *zap.Logger // Logger instance for logging pool activities. + backends atomic.Value // Atomic value storing the current BackendSnapshot. + current uint64 // Atomic counter used for round-robin load balancing. + algorithm atomic.Value // Atomic value storing the current load balancing algorithm. + maxConnections atomic.Int32 // Atomic integer representing the maximum allowed connections per backend. + log *zap.Logger // Logger instance for logging pool activities. + serviceHeaders config.HeaderConfig // Service request and response custom headers } -func NewServerPool(logger *zap.Logger) *ServerPool { - pool := &ServerPool{log: logger} +func NewServerPool(svc *config.Service, logger *zap.Logger) *ServerPool { + pool := &ServerPool{serviceHeaders: svc.Headers, log: logger} initialSnapshot := &BackendSnapshot{ Backends: []*Backend{}, BackendCache: make(map[string]*Backend), @@ -72,6 +73,7 @@ func (s *ServerPool) AddBackend( createProxy, s.log, WithURLRewriter(rc, url), + WithHeaderConfig(&s.serviceHeaders), ) maxConnections := cfg.MaxConnections diff --git a/internal/service/manager.go b/internal/service/manager.go index 8168001..e688859 100644 --- a/internal/service/manager.go +++ b/internal/service/manager.go @@ -46,6 +46,7 @@ type ServiceInfo struct { Middleware []config.Middleware // Middleware configurations for the service. LogName string // LogName will be used to get service logger from config. Logger *zap.Logger // Logger instance for logging service activities. + Headers config.HeaderConfig // Request/Response custom headers } // ServiceType determines the protocol type of the service based on its TLS configuration. @@ -137,7 +138,7 @@ func (m *Manager) AddService(service config.Service, globalHealthCheck *config.H locationPaths[location.Path] = true - serverPool, err := m.createServerPool(location, globalHealthCheck) + serverPool, err := m.createServerPool(service, location, globalHealthCheck) if err != nil { return err } @@ -183,6 +184,7 @@ func (m *Manager) AddService(service config.Service, globalHealthCheck *config.H Locations: locations, // Associated locations with their backends. Middleware: service.Middleware, LogName: service.LogName, + Headers: service.Headers, } m.mu.Unlock() @@ -268,17 +270,21 @@ func (m *Manager) AssignLogger(serviceName string, logger *zap.Logger) { // createServerPool initializes and configures a ServerPool for a given service location. // It sets up the load balancing algorithm and adds all backends associated with the location to the pool. -func (m *Manager) createServerPool(srvc config.Location, serviceHealthCheck *config.HealthCheckConfig) (*pool.ServerPool, error) { - serverPool := pool.NewServerPool(m.logger) +func (m *Manager) createServerPool( + svc config.Service, + lc config.Location, + serviceHealthCheck *config.HealthCheckConfig, +) (*pool.ServerPool, error) { + serverPool := pool.NewServerPool(&svc, m.logger) serverPool.UpdateConfig(pool.PoolConfig{ - Algorithm: srvc.LoadBalancer, + Algorithm: lc.LoadBalancer, }) - for _, backend := range srvc.Backends { + for _, backend := range lc.Backends { rc := pool.RouteConfig{ - Path: srvc.Path, // The path associated with the backend. - RewriteURL: srvc.Rewrite, // URL rewrite rules for the backend. - Redirect: srvc.Redirect, // Redirect settings if applicable. + Path: lc.Path, // The path associated with the backend. + RewriteURL: lc.Rewrite, // URL rewrite rules for the backend. + Redirect: lc.Redirect, // Redirect settings if applicable. SkipTLSVerify: backend.SkipTLSVerify, // TLS verification settings for the backend. }