@@ -17,7 +17,7 @@ use axum::http::{
1717 AUTHORIZATION , CACHE_CONTROL , CONTENT_SECURITY_POLICY , CONTENT_TYPE , EXPIRES ,
1818 PRAGMA , STRICT_TRANSPORT_SECURITY , X_CONTENT_TYPE_OPTIONS , X_FRAME_OPTIONS ,
1919 } ,
20- HeaderValue , Method ,
20+ HeaderName , HeaderValue , Method ,
2121} ;
2222use tower_http:: cors:: { AllowHeaders , AllowMethods , AllowOrigin , CorsLayer } ;
2323use tower_http:: limit:: RequestBodyLimitLayer ;
@@ -29,7 +29,45 @@ use tracing_subscriber;
2929use tokio:: signal;
3030use std:: net:: SocketAddr ;
3131use std:: io:: ErrorKind ;
32- use tower_governor:: key_extractor:: SmartIpKeyExtractor ;
32+ use tower_governor:: key_extractor:: { PeerIpKeyExtractor , SmartIpKeyExtractor } ;
33+
34+ const PERMISSIONS_POLICY : HeaderName = HeaderName :: from_static ( "permissions-policy" ) ;
35+ const REFERRER_POLICY : HeaderName = HeaderName :: from_static ( "referrer-policy" ) ;
36+ const X_XSS_PROTECTION : HeaderName = HeaderName :: from_static ( "x-xss-protection" ) ;
37+ const FORWARDED_HEADER : HeaderName = HeaderName :: from_static ( "forwarded" ) ;
38+ const X_FORWARDED_FOR_HEADER : HeaderName = HeaderName :: from_static ( "x-forwarded-for" ) ;
39+ const X_FORWARDED_PROTO_HEADER : HeaderName = HeaderName :: from_static ( "x-forwarded-proto" ) ;
40+ const X_FORWARDED_HOST_HEADER : HeaderName = HeaderName :: from_static ( "x-forwarded-host" ) ;
41+ const X_REAL_IP_HEADER : HeaderName = HeaderName :: from_static ( "x-real-ip" ) ;
42+
43+ fn parse_env_bool ( key : & str , default : bool ) -> bool {
44+ env:: var ( key)
45+ . ok ( )
46+ . and_then ( |value| {
47+ match value. trim ( ) . to_ascii_lowercase ( ) . as_str ( ) {
48+ "1" | "true" | "yes" | "on" => Some ( true ) ,
49+ "0" | "false" | "no" | "off" => Some ( false ) ,
50+ _ => {
51+ tracing:: warn!( key = %key, value = %value, "Invalid boolean env value; using default" ) ;
52+ None
53+ }
54+ }
55+ } )
56+ . unwrap_or ( default)
57+ }
58+
59+ async fn strip_untrusted_forwarded_headers ( mut request : Request , next : Next ) -> Response {
60+ {
61+ let headers = request. headers_mut ( ) ;
62+ headers. remove ( FORWARDED_HEADER ) ;
63+ headers. remove ( X_FORWARDED_FOR_HEADER ) ;
64+ headers. remove ( X_FORWARDED_PROTO_HEADER ) ;
65+ headers. remove ( X_FORWARDED_HOST_HEADER ) ;
66+ headers. remove ( X_REAL_IP_HEADER ) ;
67+ }
68+
69+ next. run ( request) . await
70+ }
3371
3472// Security headers middleware
3573async fn security_headers (
@@ -75,10 +113,10 @@ async fn security_headers(
75113 // Content Security Policy - Environment-dependent for dev mode support
76114 let csp = if cfg ! ( debug_assertions) {
77115 // Development mode: Allow WebSocket for Vite HMR
78- "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:; object-src 'none'; base-uri 'self'; form-action 'self';"
116+ "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:; object-src 'none'; base-uri 'self'; form-action 'self'; frame-ancestors 'none'; "
79117 } else {
80118 // Production mode: Strict CSP (no inline styles)
81- "default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; connect-src 'self'; object-src 'none'; base-uri 'self'; form-action 'self';"
119+ "default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; connect-src 'self'; object-src 'none'; base-uri 'self'; form-action 'self'; frame-ancestors 'none'; upgrade-insecure-requests; "
82120 } ;
83121
84122 headers. insert (
@@ -105,6 +143,21 @@ async fn security_headers(
105143 X_FRAME_OPTIONS ,
106144 HeaderValue :: from_static ( "DENY" ) ,
107145 ) ;
146+
147+ headers. insert (
148+ REFERRER_POLICY ,
149+ HeaderValue :: from_static ( "no-referrer" ) ,
150+ ) ;
151+
152+ headers. insert (
153+ PERMISSIONS_POLICY ,
154+ HeaderValue :: from_static ( "geolocation=(), microphone=(), camera=()" ) ,
155+ ) ;
156+
157+ headers. insert (
158+ X_XSS_PROTECTION ,
159+ HeaderValue :: from_static ( "0" ) ,
160+ ) ;
108161
109162 response
110163}
@@ -199,12 +252,25 @@ async fn main() {
199252
200253 tracing:: info!( origins = ?allowed_origins, "Configured CORS origins" ) ;
201254
255+ let trust_proxy_ip_headers = parse_env_bool ( "TRUST_PROXY_IP_HEADERS" , false ) ;
256+ if trust_proxy_ip_headers {
257+ tracing:: info!( "Trusting X-Forwarded-* headers for client IP extraction" ) ;
258+ } else {
259+ tracing:: info!( "Proxy headers will be stripped before rate limiting to prevent spoofing" ) ;
260+ }
261+
262+ let login_key_extractor = if trust_proxy_ip_headers {
263+ SmartIpKeyExtractor
264+ } else {
265+ PeerIpKeyExtractor
266+ } ;
267+
202268 // Configure rate limiting (average 1 request/sec for login, burst up to 5)
203269 let rate_limit_config = std:: sync:: Arc :: new (
204270 GovernorConfigBuilder :: default ( )
205271 . per_second ( 1 )
206272 . burst_size ( 5 )
207- . key_extractor ( SmartIpKeyExtractor )
273+ . key_extractor ( login_key_extractor )
208274 . finish ( )
209275 . expect ( "Failed to build governor config" ) ,
210276 ) ;
@@ -217,11 +283,17 @@ async fn main() {
217283 . layer ( RequestBodyLimitLayer :: new ( LOGIN_BODY_LIMIT ) )
218284 . layer ( GovernorLayer :: new ( rate_limit_config) ) ;
219285
286+ let admin_key_extractor = if trust_proxy_ip_headers {
287+ SmartIpKeyExtractor
288+ } else {
289+ PeerIpKeyExtractor
290+ } ;
291+
220292 let admin_rate_limit_config = std:: sync:: Arc :: new (
221293 GovernorConfigBuilder :: default ( )
222294 . per_second ( 1 )
223295 . burst_size ( 3 )
224- . key_extractor ( SmartIpKeyExtractor )
296+ . key_extractor ( admin_key_extractor )
225297 . finish ( )
226298 . expect ( "Failed to build governor config for write routes" ) ,
227299 ) ;
@@ -276,7 +348,7 @@ async fn main() {
276348 . layer ( RequestBodyLimitLayer :: new ( ADMIN_BODY_LIMIT ) )
277349 . layer ( GovernorLayer :: new ( admin_rate_limit_config. clone ( ) ) ) ;
278350
279- let app = Router :: new ( )
351+ let mut app = Router :: new ( )
280352 . merge ( login_router)
281353 // Auth routes
282354 . route ( "/api/auth/me" , get ( handlers:: auth:: me) )
@@ -321,6 +393,10 @@ async fn main() {
321393 . layer ( middleware:: from_fn ( security_headers) )
322394 . with_state ( pool) ;
323395
396+ if !trust_proxy_ip_headers {
397+ app = app. layer ( middleware:: from_fn ( strip_untrusted_forwarded_headers) ) ;
398+ }
399+
324400 // Get port from environment or use default
325401 let port_str = env:: var ( "PORT" ) . unwrap_or_else ( |_| "8489" . to_string ( ) ) ;
326402 let port: u16 = match port_str. parse ( ) {
0 commit comments