1
1
//! Utilities for extracting original client ip
2
2
//!
3
- //! See [`ClientIP`] for more details.
4
- use std:: { net:: IpAddr , str:: FromStr } ;
3
+ //! See [`ClientIp`] for more details.
4
+ use std:: {
5
+ net:: { IpAddr , Ipv4Addr , Ipv6Addr } ,
6
+ str:: FromStr ,
7
+ } ;
5
8
6
9
use http:: { HeaderMap , HeaderName } ;
7
- use ipnet:: IpNet ;
10
+ use ipnet:: { IpNet , Ipv4Net , Ipv6Net } ;
8
11
use motore:: { layer:: Layer , Service } ;
9
12
use volo:: { context:: Context , net:: Address } ;
10
13
11
- use crate :: { context:: ServerContext , request:: Request , utils :: macros :: impl_deref_and_deref_mut } ;
14
+ use crate :: { context:: ServerContext , request:: Request } ;
12
15
13
16
/// [`Layer`] for extracting client ip
14
17
///
15
- /// See [`ClientIP `] for more details.
16
- #[ derive( Clone , Default ) ]
17
- pub struct ClientIPLayer {
18
- config : ClientIPConfig ,
18
+ /// See [`ClientIp `] for more details.
19
+ #[ derive( Clone , Debug , Default ) ]
20
+ pub struct ClientIpLayer {
21
+ config : ClientIpConfig ,
19
22
}
20
23
21
- impl ClientIPLayer {
22
- /// Create a new [`ClientIPLayer `] with default config
24
+ impl ClientIpLayer {
25
+ /// Create a new [`ClientIpLayer `] with default config
23
26
pub fn new ( ) -> Self {
24
27
Default :: default ( )
25
28
}
26
29
27
- /// Create a new [`ClientIPLayer `] with the given [`ClientIPConfig `]
28
- pub fn with_config ( self , config : ClientIPConfig ) -> Self {
30
+ /// Create a new [`ClientIpLayer `] with the given [`ClientIpConfig `]
31
+ pub fn with_config ( self , config : ClientIpConfig ) -> Self {
29
32
Self { config }
30
33
}
31
34
}
32
35
33
- impl < S > Layer < S > for ClientIPLayer
36
+ impl < S > Layer < S > for ClientIpLayer
34
37
where
35
38
S : Send + Sync + ' static ,
36
39
{
37
- type Service = ClientIPService < S > ;
40
+ type Service = ClientIpService < S > ;
38
41
39
42
fn layer ( self , inner : S ) -> Self :: Service {
40
- ClientIPService {
43
+ ClientIpService {
41
44
service : inner,
42
45
config : self . config ,
43
46
}
@@ -46,25 +49,31 @@ where
46
49
47
50
/// Config for extract client ip
48
51
#[ derive( Clone , Debug ) ]
49
- pub struct ClientIPConfig {
52
+ pub struct ClientIpConfig {
50
53
remote_ip_headers : Vec < HeaderName > ,
51
54
trusted_cidrs : Vec < IpNet > ,
52
55
}
53
56
54
- impl Default for ClientIPConfig {
57
+ impl Default for ClientIpConfig {
55
58
fn default ( ) -> Self {
56
59
Self {
57
60
remote_ip_headers : vec ! [
58
61
HeaderName :: from_static( "x-real-ip" ) ,
59
62
HeaderName :: from_static( "x-forwarded-for" ) ,
60
63
] ,
61
- trusted_cidrs : vec ! [ "0.0.0.0/0" . parse( ) . unwrap( ) , "::/0" . parse( ) . unwrap( ) ] ,
64
+ trusted_cidrs : vec ! [
65
+ IpNet :: V4 ( Ipv4Net :: new_assert( Ipv4Addr :: new( 0 , 0 , 0 , 0 ) , 0 ) ) ,
66
+ IpNet :: V6 ( Ipv6Net :: new_assert(
67
+ Ipv6Addr :: new( 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ) ,
68
+ 0 ,
69
+ ) ) ,
70
+ ] ,
62
71
}
63
72
}
64
73
}
65
74
66
- impl ClientIPConfig {
67
- /// Create a new [`ClientIPConfig `] with default values
75
+ impl ClientIpConfig {
76
+ /// Create a new [`ClientIpConfig `] with default values
68
77
///
69
78
/// default remote ip headers: `["X-Real-IP", "X-Forwarded-For"]`
70
79
///
@@ -75,15 +84,15 @@ impl ClientIPConfig {
75
84
76
85
/// Get Real Client IP by parsing the given headers.
77
86
///
78
- /// See [`ClientIP `] for more details.
87
+ /// See [`ClientIp `] for more details.
79
88
///
80
89
/// # Example
81
90
///
82
91
/// ```rust
83
- /// use volo_http::server::utils::client_ip::ClientIPConfig ;
92
+ /// use volo_http::server::utils::client_ip::ClientIpConfig ;
84
93
///
85
94
/// let client_ip_config =
86
- /// ClientIPConfig ::new().with_remote_ip_headers(vec!["X-Real-IP", "X-Forwarded-For"]);
95
+ /// ClientIpConfig ::new().with_remote_ip_headers(vec!["X-Real-IP", "X-Forwarded-For"]);
87
96
/// ```
88
97
pub fn with_remote_ip_headers < I > (
89
98
self ,
@@ -108,14 +117,14 @@ impl ClientIPConfig {
108
117
109
118
/// Get Real Client IP if it is trusted, otherwise it will just return caller ip.
110
119
///
111
- /// See [`ClientIP `] for more details.
120
+ /// See [`ClientIp `] for more details.
112
121
///
113
122
/// # Example
114
123
///
115
124
/// ```rust
116
- /// use volo_http::server::utils::client_ip::ClientIPConfig ;
125
+ /// use volo_http::server::utils::client_ip::ClientIpConfig ;
117
126
///
118
- /// let client_ip_config = ClientIPConfig ::new()
127
+ /// let client_ip_config = ClientIpConfig ::new()
119
128
/// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]);
120
129
/// ```
121
130
pub fn with_trusted_cidrs < H > ( self , cidrs : H ) -> Self
@@ -132,11 +141,11 @@ impl ClientIPConfig {
132
141
/// Return original client IP Address
133
142
///
134
143
/// If you want to get client IP by retrieving specific headers, you can use
135
- /// [`with_remote_ip_headers`](ClientIPConfig ::with_remote_ip_headers) to set the
144
+ /// [`with_remote_ip_headers`](ClientIpConfig ::with_remote_ip_headers) to set the
136
145
/// headers.
137
146
///
138
147
/// If you want to get client IP that is trusted with specific cidrs, you can use
139
- /// [`with_trusted_cidrs`](ClientIPConfig ::with_trusted_cidrs) to set the cidrs.
148
+ /// [`with_trusted_cidrs`](ClientIpConfig ::with_trusted_cidrs) to set the cidrs.
140
149
///
141
150
/// # Example
142
151
///
@@ -148,20 +157,20 @@ impl ClientIPConfig {
148
157
///
149
158
/// ```rust
150
159
/// ///
151
- /// use volo_http::server::utils::client_ip::ClientIP ;
160
+ /// use volo_http::server::utils::client_ip::ClientIp ;
152
161
/// use volo_http::server::{
153
162
/// route::{get, Router},
154
- /// utils::client_ip::{ClientIPConfig, ClientIPLayer },
163
+ /// utils::client_ip::{ClientIpConfig, ClientIpLayer },
155
164
/// Server,
156
165
/// };
157
166
///
158
- /// async fn handler(client_ip: ClientIP ) -> String {
167
+ /// async fn handler(client_ip: ClientIp ) -> String {
159
168
/// client_ip.unwrap().to_string()
160
169
/// }
161
170
///
162
171
/// let router: Router = Router::new()
163
172
/// .route("/", get(handler))
164
- /// .layer(ClientIPLayer ::new());
173
+ /// .layer(ClientIpLayer ::new());
165
174
/// ```
166
175
///
167
176
/// ## With custom config
@@ -172,85 +181,80 @@ impl ClientIPConfig {
172
181
/// context::ServerContext,
173
182
/// server::{
174
183
/// route::{get, Router},
175
- /// utils::client_ip::{ClientIP, ClientIPConfig, ClientIPLayer },
184
+ /// utils::client_ip::{ClientIp, ClientIpConfig, ClientIpLayer },
176
185
/// Server,
177
186
/// },
178
187
/// };
179
188
///
180
- /// async fn handler(client_ip: ClientIP ) -> String {
189
+ /// async fn handler(client_ip: ClientIp ) -> String {
181
190
/// client_ip.unwrap().to_string()
182
191
/// }
183
192
///
184
193
/// let router: Router = Router::new().route("/", get(handler)).layer(
185
- /// ClientIPLayer ::new().with_config(
186
- /// ClientIPConfig ::new()
194
+ /// ClientIpLayer ::new().with_config(
195
+ /// ClientIpConfig ::new()
187
196
/// .with_remote_ip_headers(vec!["x-real-ip", "x-forwarded-for"])
188
197
/// .unwrap()
189
198
/// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]),
190
199
/// ),
191
200
/// );
192
201
/// ```
193
- pub struct ClientIP ( pub Option < IpAddr > ) ;
194
-
195
- impl_deref_and_deref_mut ! ( ClientIP , Option <IpAddr >, 0 ) ;
202
+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
203
+ pub struct ClientIp ( pub Option < IpAddr > ) ;
196
204
197
- /// [`ClientIPLayer `] generated [`Service`]
205
+ /// [`ClientIpLayer `] generated [`Service`]
198
206
///
199
- /// See [`ClientIP `] for more details.
200
- #[ derive( Clone ) ]
201
- pub struct ClientIPService < S > {
207
+ /// See [`ClientIp `] for more details.
208
+ #[ derive( Clone , Debug ) ]
209
+ pub struct ClientIpService < S > {
202
210
service : S ,
203
- config : ClientIPConfig ,
211
+ config : ClientIpConfig ,
204
212
}
205
213
206
- impl < S > ClientIPService < S > {
207
- fn get_client_ip ( & self , cx : & ServerContext , headers : & HeaderMap ) -> ClientIP {
214
+ impl < S > ClientIpService < S > {
215
+ fn get_client_ip ( & self , cx : & ServerContext , headers : & HeaderMap ) -> ClientIp {
208
216
let remote_ip = match & cx. rpc_info ( ) . caller ( ) . address {
209
217
Some ( Address :: Ip ( socket_addr) ) => Some ( socket_addr. ip ( ) ) ,
210
218
#[ cfg( target_family = "unix" ) ]
211
219
Some ( Address :: Unix ( _) ) => None ,
212
- None => return ClientIP ( None ) ,
220
+ None => return ClientIp ( None ) ,
213
221
} ;
214
222
215
- if let Some ( remote_ip) = remote_ip {
223
+ if let Some ( remote_ip) = & remote_ip {
216
224
if !self
217
225
. config
218
226
. trusted_cidrs
219
227
. iter ( )
220
- . any ( |cidr| cidr. contains ( & IpNet :: from ( remote_ip) ) )
228
+ . any ( |cidr| cidr. contains ( remote_ip) )
221
229
{
222
- return ClientIP ( None ) ;
230
+ return ClientIp ( None ) ;
223
231
}
224
232
}
225
233
226
234
for remote_ip_header in self . config . remote_ip_headers . iter ( ) {
227
- let remote_ips = match headers
228
- . get ( remote_ip_header)
229
- . and_then ( |v| v. to_str ( ) . ok ( ) )
230
- . map ( |v| v. split ( ',' ) . map ( |s| s. trim ( ) ) . collect :: < Vec < _ > > ( ) )
231
- {
232
- Some ( remote_ips) => remote_ips,
233
- None => continue ,
235
+ let Some ( remote_ips) = headers. get ( remote_ip_header) . and_then ( |v| v. to_str ( ) . ok ( ) )
236
+ else {
237
+ continue ;
234
238
} ;
235
- for remote_ip in remote_ips. iter ( ) {
239
+ for remote_ip in remote_ips. split ( ',' ) . map ( str :: trim ) {
236
240
if let Ok ( remote_ip_addr) = IpAddr :: from_str ( remote_ip) {
237
241
if self
238
242
. config
239
243
. trusted_cidrs
240
244
. iter ( )
241
245
. any ( |cidr| cidr. contains ( & remote_ip_addr) )
242
246
{
243
- return ClientIP ( Some ( remote_ip_addr) ) ;
247
+ return ClientIp ( Some ( remote_ip_addr) ) ;
244
248
}
245
249
}
246
250
}
247
251
}
248
252
249
- ClientIP ( remote_ip)
253
+ ClientIp ( remote_ip)
250
254
}
251
255
}
252
256
253
- impl < S , B > Service < ServerContext , Request < B > > for ClientIPService < S >
257
+ impl < S , B > Service < ServerContext , Request < B > > for ClientIpService < S >
254
258
where
255
259
S : Service < ServerContext , Request < B > > + Send + Sync + ' static ,
256
260
B : Send ,
@@ -264,7 +268,7 @@ where
264
268
req : Request < B > ,
265
269
) -> Result < Self :: Response , Self :: Error > {
266
270
let client_ip = self . get_client_ip ( cx, req. headers ( ) ) ;
267
- cx. rpc_info_mut ( ) . caller_mut ( ) . tags . insert ( client_ip) ;
271
+ cx. rpc_info_mut ( ) . caller_mut ( ) . insert ( client_ip) ;
268
272
269
273
self . service . call ( cx, req) . await
270
274
}
@@ -283,21 +287,21 @@ mod client_ip_tests {
283
287
context:: ServerContext ,
284
288
server:: {
285
289
route:: { get, Route } ,
286
- utils:: client_ip:: { ClientIP , ClientIPConfig , ClientIPLayer } ,
290
+ utils:: client_ip:: { ClientIp , ClientIpConfig , ClientIpLayer } ,
287
291
} ,
288
292
utils:: test_helpers:: simple_req,
289
293
} ;
290
294
291
295
#[ tokio:: test]
292
296
async fn test_client_ip ( ) {
293
- async fn handler ( client_ip : ClientIP ) -> String {
294
- client_ip. unwrap ( ) . to_string ( )
297
+ async fn handler ( client_ip : ClientIp ) -> String {
298
+ client_ip. 0 . unwrap ( ) . to_string ( )
295
299
}
296
300
297
301
let route: Route < & str > = Route :: new ( get ( handler) ) ;
298
- let service = ClientIPLayer :: new ( )
302
+ let service = ClientIpLayer :: new ( )
299
303
. with_config (
300
- ClientIPConfig :: default ( ) . with_trusted_cidrs ( vec ! [ "10.0.0.0/8" . parse( ) . unwrap( ) ] ) ,
304
+ ClientIpConfig :: default ( ) . with_trusted_cidrs ( vec ! [ "10.0.0.0/8" . parse( ) . unwrap( ) ] ) ,
301
305
)
302
306
. layer ( route) ;
303
307
0 commit comments