Skip to content

Commit 9827bd7

Browse files
committed
chore(volo-http): format codes of client ip
Signed-off-by: Yu Li <liyu.yukiteru@bytedance.com>
1 parent 661346b commit 9827bd7

File tree

2 files changed

+76
-75
lines changed

2 files changed

+76
-75
lines changed

volo-http/src/server/extract.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use crate::{
2323
context::ServerContext,
2424
error::server::{body_collection_error, ExtractBodyError},
2525
request::{Request, RequestPartsExt},
26-
server::utils::client_ip::ClientIP,
26+
server::utils::client_ip::ClientIp,
2727
utils::macros::impl_deref_and_deref_mut,
2828
};
2929

@@ -291,18 +291,15 @@ impl FromContext for Method {
291291
}
292292
}
293293

294-
impl FromContext for ClientIP {
294+
impl FromContext for ClientIp {
295295
type Rejection = Infallible;
296296

297-
async fn from_context(
298-
cx: &mut ServerContext,
299-
_: &mut Parts,
300-
) -> Result<ClientIP, Self::Rejection> {
301-
Ok(ClientIP(
297+
async fn from_context(cx: &mut ServerContext, _: &mut Parts) -> Result<Self, Self::Rejection> {
298+
Ok(ClientIp(
302299
cx.rpc_info
303300
.caller()
304301
.tags
305-
.get::<ClientIP>()
302+
.get::<ClientIp>()
306303
.and_then(|v| v.0),
307304
))
308305
}

volo-http/src/server/utils/client_ip.rs

Lines changed: 71 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,46 @@
11
//! Utilities for extracting original client ip
22
//!
3-
//! See [`ClientIP`] for more details.
4-
use std::{net::IpAddr, str::FromStr};
3+
//! See [`ClientIp`] for more details.
4+
use std::{
5+
net::{IpAddr, Ipv4Addr, Ipv6Addr},
6+
str::FromStr,
7+
};
58

69
use http::{HeaderMap, HeaderName};
7-
use ipnet::IpNet;
10+
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
811
use motore::{layer::Layer, Service};
912
use volo::{context::Context, net::Address};
1013

11-
use crate::{context::ServerContext, request::Request, utils::macros::impl_deref_and_deref_mut};
14+
use crate::{context::ServerContext, request::Request};
1215

1316
/// [`Layer`] for extracting client ip
1417
///
15-
/// See [`ClientIP`] for more details.
16-
#[derive(Clone, Default)]
17-
pub struct ClientIPLayer {
18-
config: ClientIPConfig,
18+
/// See [`ClientIp`] for more details.
19+
#[derive(Clone, Debug, Default)]
20+
pub struct ClientIpLayer {
21+
config: ClientIpConfig,
1922
}
2023

21-
impl ClientIPLayer {
22-
/// Create a new [`ClientIPLayer`] with default config
24+
impl ClientIpLayer {
25+
/// Create a new [`ClientIpLayer`] with default config
2326
pub fn new() -> Self {
2427
Default::default()
2528
}
2629

27-
/// Create a new [`ClientIPLayer`] with the given [`ClientIPConfig`]
28-
pub fn with_config(self, config: ClientIPConfig) -> Self {
30+
/// Create a new [`ClientIpLayer`] with the given [`ClientIpConfig`]
31+
pub fn with_config(self, config: ClientIpConfig) -> Self {
2932
Self { config }
3033
}
3134
}
3235

33-
impl<S> Layer<S> for ClientIPLayer
36+
impl<S> Layer<S> for ClientIpLayer
3437
where
3538
S: Send + Sync + 'static,
3639
{
37-
type Service = ClientIPService<S>;
40+
type Service = ClientIpService<S>;
3841

3942
fn layer(self, inner: S) -> Self::Service {
40-
ClientIPService {
43+
ClientIpService {
4144
service: inner,
4245
config: self.config,
4346
}
@@ -46,25 +49,31 @@ where
4649

4750
/// Config for extract client ip
4851
#[derive(Clone, Debug)]
49-
pub struct ClientIPConfig {
52+
pub struct ClientIpConfig {
5053
remote_ip_headers: Vec<HeaderName>,
5154
trusted_cidrs: Vec<IpNet>,
5255
}
5356

54-
impl Default for ClientIPConfig {
57+
impl Default for ClientIpConfig {
5558
fn default() -> Self {
5659
Self {
5760
remote_ip_headers: vec![
5861
HeaderName::from_static("x-real-ip"),
5962
HeaderName::from_static("x-forwarded-for"),
6063
],
61-
trusted_cidrs: vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()],
64+
trusted_cidrs: vec![
65+
IpNet::V4(Ipv4Net::new_assert(Ipv4Addr::new(0, 0, 0, 0), 0)),
66+
IpNet::V6(Ipv6Net::new_assert(
67+
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
68+
0,
69+
)),
70+
],
6271
}
6372
}
6473
}
6574

66-
impl ClientIPConfig {
67-
/// Create a new [`ClientIPConfig`] with default values
75+
impl ClientIpConfig {
76+
/// Create a new [`ClientIpConfig`] with default values
6877
///
6978
/// default remote ip headers: `["X-Real-IP", "X-Forwarded-For"]`
7079
///
@@ -75,15 +84,15 @@ impl ClientIPConfig {
7584

7685
/// Get Real Client IP by parsing the given headers.
7786
///
78-
/// See [`ClientIP`] for more details.
87+
/// See [`ClientIp`] for more details.
7988
///
8089
/// # Example
8190
///
8291
/// ```rust
83-
/// use volo_http::server::utils::client_ip::ClientIPConfig;
92+
/// use volo_http::server::utils::client_ip::ClientIpConfig;
8493
///
8594
/// let client_ip_config =
86-
/// ClientIPConfig::new().with_remote_ip_headers(vec!["X-Real-IP", "X-Forwarded-For"]);
95+
/// ClientIpConfig::new().with_remote_ip_headers(vec!["X-Real-IP", "X-Forwarded-For"]);
8796
/// ```
8897
pub fn with_remote_ip_headers<I>(
8998
self,
@@ -108,14 +117,14 @@ impl ClientIPConfig {
108117

109118
/// Get Real Client IP if it is trusted, otherwise it will just return caller ip.
110119
///
111-
/// See [`ClientIP`] for more details.
120+
/// See [`ClientIp`] for more details.
112121
///
113122
/// # Example
114123
///
115124
/// ```rust
116-
/// use volo_http::server::utils::client_ip::ClientIPConfig;
125+
/// use volo_http::server::utils::client_ip::ClientIpConfig;
117126
///
118-
/// let client_ip_config = ClientIPConfig::new()
127+
/// let client_ip_config = ClientIpConfig::new()
119128
/// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]);
120129
/// ```
121130
pub fn with_trusted_cidrs<H>(self, cidrs: H) -> Self
@@ -132,11 +141,11 @@ impl ClientIPConfig {
132141
/// Return original client IP Address
133142
///
134143
/// If you want to get client IP by retrieving specific headers, you can use
135-
/// [`with_remote_ip_headers`](ClientIPConfig::with_remote_ip_headers) to set the
144+
/// [`with_remote_ip_headers`](ClientIpConfig::with_remote_ip_headers) to set the
136145
/// headers.
137146
///
138147
/// If you want to get client IP that is trusted with specific cidrs, you can use
139-
/// [`with_trusted_cidrs`](ClientIPConfig::with_trusted_cidrs) to set the cidrs.
148+
/// [`with_trusted_cidrs`](ClientIpConfig::with_trusted_cidrs) to set the cidrs.
140149
///
141150
/// # Example
142151
///
@@ -148,20 +157,20 @@ impl ClientIPConfig {
148157
///
149158
/// ```rust
150159
/// ///
151-
/// use volo_http::server::utils::client_ip::ClientIP;
160+
/// use volo_http::server::utils::client_ip::ClientIp;
152161
/// use volo_http::server::{
153162
/// route::{get, Router},
154-
/// utils::client_ip::{ClientIPConfig, ClientIPLayer},
163+
/// utils::client_ip::{ClientIpConfig, ClientIpLayer},
155164
/// Server,
156165
/// };
157166
///
158-
/// async fn handler(client_ip: ClientIP) -> String {
167+
/// async fn handler(client_ip: ClientIp) -> String {
159168
/// client_ip.unwrap().to_string()
160169
/// }
161170
///
162171
/// let router: Router = Router::new()
163172
/// .route("/", get(handler))
164-
/// .layer(ClientIPLayer::new());
173+
/// .layer(ClientIpLayer::new());
165174
/// ```
166175
///
167176
/// ## With custom config
@@ -172,85 +181,80 @@ impl ClientIPConfig {
172181
/// context::ServerContext,
173182
/// server::{
174183
/// route::{get, Router},
175-
/// utils::client_ip::{ClientIP, ClientIPConfig, ClientIPLayer},
184+
/// utils::client_ip::{ClientIp, ClientIpConfig, ClientIpLayer},
176185
/// Server,
177186
/// },
178187
/// };
179188
///
180-
/// async fn handler(client_ip: ClientIP) -> String {
189+
/// async fn handler(client_ip: ClientIp) -> String {
181190
/// client_ip.unwrap().to_string()
182191
/// }
183192
///
184193
/// let router: Router = Router::new().route("/", get(handler)).layer(
185-
/// ClientIPLayer::new().with_config(
186-
/// ClientIPConfig::new()
194+
/// ClientIpLayer::new().with_config(
195+
/// ClientIpConfig::new()
187196
/// .with_remote_ip_headers(vec!["x-real-ip", "x-forwarded-for"])
188197
/// .unwrap()
189198
/// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]),
190199
/// ),
191200
/// );
192201
/// ```
193-
pub struct ClientIP(pub Option<IpAddr>);
194-
195-
impl_deref_and_deref_mut!(ClientIP, Option<IpAddr>, 0);
202+
#[derive(Clone, Debug, PartialEq, Eq)]
203+
pub struct ClientIp(pub Option<IpAddr>);
196204

197-
/// [`ClientIPLayer`] generated [`Service`]
205+
/// [`ClientIpLayer`] generated [`Service`]
198206
///
199-
/// See [`ClientIP`] for more details.
200-
#[derive(Clone)]
201-
pub struct ClientIPService<S> {
207+
/// See [`ClientIp`] for more details.
208+
#[derive(Clone, Debug)]
209+
pub struct ClientIpService<S> {
202210
service: S,
203-
config: ClientIPConfig,
211+
config: ClientIpConfig,
204212
}
205213

206-
impl<S> ClientIPService<S> {
207-
fn get_client_ip(&self, cx: &ServerContext, headers: &HeaderMap) -> ClientIP {
214+
impl<S> ClientIpService<S> {
215+
fn get_client_ip(&self, cx: &ServerContext, headers: &HeaderMap) -> ClientIp {
208216
let remote_ip = match &cx.rpc_info().caller().address {
209217
Some(Address::Ip(socket_addr)) => Some(socket_addr.ip()),
210218
#[cfg(target_family = "unix")]
211219
Some(Address::Unix(_)) => None,
212-
None => return ClientIP(None),
220+
None => return ClientIp(None),
213221
};
214222

215-
if let Some(remote_ip) = remote_ip {
223+
if let Some(remote_ip) = &remote_ip {
216224
if !self
217225
.config
218226
.trusted_cidrs
219227
.iter()
220-
.any(|cidr| cidr.contains(&IpNet::from(remote_ip)))
228+
.any(|cidr| cidr.contains(remote_ip))
221229
{
222-
return ClientIP(None);
230+
return ClientIp(None);
223231
}
224232
}
225233

226234
for remote_ip_header in self.config.remote_ip_headers.iter() {
227-
let remote_ips = match headers
228-
.get(remote_ip_header)
229-
.and_then(|v| v.to_str().ok())
230-
.map(|v| v.split(',').map(|s| s.trim()).collect::<Vec<_>>())
231-
{
232-
Some(remote_ips) => remote_ips,
233-
None => continue,
235+
let Some(remote_ips) = headers.get(remote_ip_header).and_then(|v| v.to_str().ok())
236+
else {
237+
continue;
234238
};
235-
for remote_ip in remote_ips.iter() {
239+
for remote_ip in remote_ips.split(',').map(str::trim) {
236240
if let Ok(remote_ip_addr) = IpAddr::from_str(remote_ip) {
237241
if self
238242
.config
239243
.trusted_cidrs
240244
.iter()
241245
.any(|cidr| cidr.contains(&remote_ip_addr))
242246
{
243-
return ClientIP(Some(remote_ip_addr));
247+
return ClientIp(Some(remote_ip_addr));
244248
}
245249
}
246250
}
247251
}
248252

249-
ClientIP(remote_ip)
253+
ClientIp(remote_ip)
250254
}
251255
}
252256

253-
impl<S, B> Service<ServerContext, Request<B>> for ClientIPService<S>
257+
impl<S, B> Service<ServerContext, Request<B>> for ClientIpService<S>
254258
where
255259
S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
256260
B: Send,
@@ -264,7 +268,7 @@ where
264268
req: Request<B>,
265269
) -> Result<Self::Response, Self::Error> {
266270
let client_ip = self.get_client_ip(cx, req.headers());
267-
cx.rpc_info_mut().caller_mut().tags.insert(client_ip);
271+
cx.rpc_info_mut().caller_mut().insert(client_ip);
268272

269273
self.service.call(cx, req).await
270274
}
@@ -283,21 +287,21 @@ mod client_ip_tests {
283287
context::ServerContext,
284288
server::{
285289
route::{get, Route},
286-
utils::client_ip::{ClientIP, ClientIPConfig, ClientIPLayer},
290+
utils::client_ip::{ClientIp, ClientIpConfig, ClientIpLayer},
287291
},
288292
utils::test_helpers::simple_req,
289293
};
290294

291295
#[tokio::test]
292296
async fn test_client_ip() {
293-
async fn handler(client_ip: ClientIP) -> String {
294-
client_ip.unwrap().to_string()
297+
async fn handler(client_ip: ClientIp) -> String {
298+
client_ip.0.unwrap().to_string()
295299
}
296300

297301
let route: Route<&str> = Route::new(get(handler));
298-
let service = ClientIPLayer::new()
302+
let service = ClientIpLayer::new()
299303
.with_config(
300-
ClientIPConfig::default().with_trusted_cidrs(vec!["10.0.0.0/8".parse().unwrap()]),
304+
ClientIpConfig::default().with_trusted_cidrs(vec!["10.0.0.0/8".parse().unwrap()]),
301305
)
302306
.layer(route);
303307

0 commit comments

Comments
 (0)