diff --git a/Cargo.lock b/Cargo.lock index 18817b20..865a9ba4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4071,7 +4071,7 @@ dependencies = [ [[package]] name = "volo-http" -version = "0.3.0-rc.1" +version = "0.3.0-rc.2" dependencies = [ "ahash", "async-broadcast", diff --git a/examples/src/http/example-http-client.rs b/examples/src/http/example-http-client.rs index 87a7af9f..148c5cfd 100644 --- a/examples/src/http/example-http-client.rs +++ b/examples/src/http/example-http-client.rs @@ -40,12 +40,12 @@ async fn main() -> Result<(), BoxError> { let client = { let mut builder = ClientBuilder::new(); builder - .caller_name("example.http.client") - .callee_name("example.http.server") + .user_agent("example.http.client") + .default_host("example.http.server") // set default target address .address("127.0.0.1:8080".parse::().unwrap()) - .header("Test", "Test")?; - builder.build() + .header("Test", "Test"); + builder.build()? }; // set host and override the default one @@ -98,7 +98,7 @@ async fn main() -> Result<(), BoxError> { ); // an empty client - let client = ClientBuilder::new().build(); + let client = ClientBuilder::new().build()?; println!( "{}", client diff --git a/volo-http/Cargo.toml b/volo-http/Cargo.toml index 59291eac..da36335b 100644 --- a/volo-http/Cargo.toml +++ b/volo-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "volo-http" -version = "0.3.0-rc.1" +version = "0.3.0-rc.2" edition.workspace = true homepage.workspace = true repository.workspace = true diff --git a/volo-http/src/client/callopt.rs b/volo-http/src/client/callopt.rs index d30e85e7..9a5166e7 100644 --- a/volo-http/src/client/callopt.rs +++ b/volo-http/src/client/callopt.rs @@ -2,28 +2,28 @@ //! //! See [`CallOpt`] for more details. +use std::time::Duration; + use faststr::FastStr; use metainfo::{FastStrMap, TypeMap}; +use volo::{client::Apply, context::Context}; + +use crate::{context::ClientContext, error::ClientError}; /// Call options for requests -/// -/// It can be set to a [`Client`][Client] or a [`RequestBuilder`][RequestBuilder]. The -/// [`TargetParser`][TargetParser] will handle [`Target`][Target] and the [`CallOpt`] for -/// applying information to the [`Endpoint`][Endpoint]. -/// -/// [Client]: crate::client::Client -/// [RequestBuilder]: crate::client::RequestBuilder -/// [TargetParser]: crate::client::target::TargetParser -/// [Target]: crate::client::target::Target -/// [Endpoint]: volo::context::Endpoint #[derive(Debug, Default)] pub struct CallOpt { - /// `tags` is used to store additional information of the endpoint. + /// Timeout of the whole request + /// + /// This timeout includes connect, sending request headers, receiving response headers, but + /// without receiving streaming data. + pub timeout: Option, + /// Additional information of the endpoint. /// /// Users can use `tags` to store custom data, such as the datacenter name or the region name, /// which can be used by the service discoverer. pub tags: TypeMap, - /// `faststr_tags` is a optimized typemap to store additional information of the endpoint. + /// A optimized typemap for storing additional information of the endpoint. /// /// Use [`FastStrMap`] instead of [`TypeMap`] can reduce the Box allocation. /// @@ -32,13 +32,24 @@ pub struct CallOpt { } impl CallOpt { - /// Create a new [`CallOpt`] + /// Create a new [`CallOpt`]. #[inline] pub fn new() -> Self { Self::default() } - /// Check if [`CallOpt`] tags contain entry + /// Set a timeout for the [`CallOpt`]. + pub fn set_timeout(&mut self, timeout: Duration) { + self.timeout = Some(timeout); + } + + /// Consume current [`CallOpt`] and return a new one with the given timeout. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } + + /// Check if [`CallOpt`] tags contain entry. #[inline] pub fn contains(&self) -> bool { self.tags.contains::() @@ -50,7 +61,7 @@ impl CallOpt { self.tags.insert(val); } - /// Insert a tag into this [`CallOpt`] and return self. + /// Consume current [`CallOpt`] and return a new one with the tag. #[inline] pub fn with(mut self, val: T) -> Self { self.tags.insert(val); @@ -63,7 +74,7 @@ impl CallOpt { self.tags.get::() } - /// Check if [`CallOpt`] tags contain entry + /// Check if [`CallOpt`] tags contain entry. #[inline] pub fn contains_faststr(&self) -> bool { self.faststr_tags.contains::() @@ -75,7 +86,7 @@ impl CallOpt { self.faststr_tags.insert::(val); } - /// Insert a tag into this [`CallOpt`] and return self. + /// Consume current [`CallOpt`] and return a new one with the tag. #[inline] pub fn with_faststr(mut self, val: FastStr) -> Self { self.faststr_tags.insert::(val); @@ -88,3 +99,26 @@ impl CallOpt { self.faststr_tags.get::() } } + +impl Apply for CallOpt { + type Error = ClientError; + + fn apply(self, cx: &mut ClientContext) -> Result<(), Self::Error> { + { + let callee = cx.rpc_info_mut().callee_mut(); + if !self.tags.is_empty() { + callee.tags.extend(self.tags); + } + if !self.faststr_tags.is_empty() { + callee.faststr_tags.extend(self.faststr_tags); + } + } + { + let config = cx.rpc_info_mut().config_mut(); + if self.timeout.is_some() { + config.set_timeout(self.timeout); + } + } + Ok(()) + } +} diff --git a/volo-http/src/client/dns.rs b/volo-http/src/client/dns.rs index b7c36d37..7ca16bc6 100644 --- a/volo-http/src/client/dns.rs +++ b/volo-http/src/client/dns.rs @@ -17,14 +17,7 @@ use volo::{ net::Address, }; -use super::{target::RemoteTargetAddress, Target}; -#[cfg(feature = "__tls")] -use crate::client::transport::TlsTransport; -use crate::{ - client::callopt::CallOpt, - error::client::{bad_host_name, no_address}, - utils::consts, -}; +use crate::error::client::{bad_host_name, no_address}; /// The port for `DnsResolver`, and only used for `DnsResolver`. /// @@ -33,6 +26,7 @@ use crate::{ /// /// For setting port to `DnsResolver`, you can insert it into `Endpoint` of `callee` in /// `ClientContext`, the resolver will apply it. +#[derive(Clone, Copy, Debug, Default)] pub struct Port(pub u16); impl Deref for Port { @@ -99,14 +93,7 @@ impl Discover for DnsResolver { let port = match endpoint.get::() { Some(port) => port.0, None => { - #[cfg(feature = "__tls")] - if endpoint.contains::() { - consts::HTTPS_DEFAULT_PORT - } else { - consts::HTTP_DEFAULT_PORT - } - #[cfg(not(feature = "__tls"))] - consts::HTTP_DEFAULT_PORT + unreachable!(); } }; @@ -130,40 +117,3 @@ impl Discover for DnsResolver { None } } - -/// [`TargetParser`][TargetParser] for parsing [`Target`] and [`CallOpt`] to [`Endpoint`] -/// -/// Because [`LoadBalance`][LoadBalance] accepts only [`Endpoint`], but we should create an HTTP -/// target through [`Target`], the [`parse_target`] can parse them and apply to [`Endpoint`] for -/// [LoadBalance] using. -/// -/// [TargetParser]: crate::client::target::TargetParser -/// [LoadBalance]: volo::loadbalance::LoadBalance -pub fn parse_target(target: Target, _: Option<&CallOpt>, endpoint: &mut Endpoint) { - match target { - Target::None => (), - Target::Remote(rt) => { - let port = rt.port(); - - #[cfg(feature = "__tls")] - if rt.is_https() { - endpoint.insert(TlsTransport); - } - - match rt.addr { - RemoteTargetAddress::Ip(ip) => { - let sa = SocketAddr::new(ip, port); - endpoint.set_address(Address::Ip(sa)); - } - RemoteTargetAddress::Name(host) => { - endpoint.insert(Port(port)); - endpoint.set_service_name(host); - } - } - } - #[cfg(target_family = "unix")] - Target::Local(unix_socket) => { - endpoint.set_address(Address::Unix(unix_socket.clone())); - } - } -} diff --git a/volo-http/src/client/layer.rs b/volo-http/src/client/layer/fail_on_status.rs similarity index 71% rename from volo-http/src/client/layer.rs rename to volo-http/src/client/layer/fail_on_status.rs index 74c317c2..f5a5d26f 100644 --- a/volo-http/src/client/layer.rs +++ b/volo-http/src/client/layer/fail_on_status.rs @@ -1,83 +1,26 @@ -//! Collections of some useful `Layer`s. - -use std::{error::Error, fmt, time::Duration}; +use std::{error::Error, fmt}; use http::status::StatusCode; use motore::{layer::Layer, service::Service}; +use url::Url; +use volo::context::Context; use crate::{ error::{client::request_error, ClientError}, + request::RequestPartsExt, response::ClientResponse, }; -/// [`Layer`] for setting timeout to the request. -/// -/// See [`TimeoutLayer::new`] for more details. -pub struct TimeoutLayer { - duration: Duration, -} - -impl TimeoutLayer { - /// Create a new [`TimeoutLayer`] with given [`Duration`]. - /// - /// If the request times out, an error [`Timeout`] is returned. - /// - /// [`Timeout`]: crate::error::client::Timeout - pub fn new(duration: Duration) -> Self { - Self { duration } - } -} - -impl Layer for TimeoutLayer { - type Service = TimeoutService; - - fn layer(self, inner: S) -> Self::Service { - TimeoutService { - inner, - duration: self.duration, - } - } -} - -/// The [`Service`] generated by [`TimeoutLayer`]. -/// -/// See [`TimeoutLayer`] and [`TimeoutLayer::new`] for more details. -pub struct TimeoutService { - inner: S, - duration: Duration, -} - -impl Service for TimeoutService -where - Cx: Send, - Req: Send, - S: Service + Send + Sync, -{ - type Response = S::Response; - type Error = S::Error; - - async fn call(&self, cx: &mut Cx, req: Req) -> Result { - let fut = self.inner.call(cx, req); - let sleep = tokio::time::sleep(self.duration); - - tokio::select! { - res = fut => res, - _ = sleep => { - tracing::error!("[Volo-HTTP] request timeout"); - Err(crate::error::client::timeout()) - } - } - } -} - /// [`Layer`] for throwing service error with the response's error status code. /// /// Users can use [`FailOnStatus::all`], [`FailOnStatus::client_error`] or /// [`FailOnStatus::server_error`] for creating the [`FailOnStatus`] layer that convert all (4XX and /// 5XX), client error (4XX) or server error (5XX) to a error of service. +#[derive(Clone, Debug, Default)] pub struct FailOnStatus { client_error: bool, server_error: bool, + detailed: bool, } impl FailOnStatus { @@ -87,6 +30,7 @@ impl FailOnStatus { Self { client_error: true, server_error: true, + detailed: false, } } @@ -96,6 +40,7 @@ impl FailOnStatus { Self { client_error: true, server_error: false, + detailed: false, } } @@ -105,8 +50,18 @@ impl FailOnStatus { Self { client_error: false, server_error: true, + detailed: false, } } + + /// Collect more details in [`StatusCodeError`]. + /// + /// When error occurs, the request has been consumed and the original response will be dropped. + /// With this flag enabled, the layer will save more details in [`StatusCodeError`]. + pub fn detailed(mut self) -> Self { + self.detailed = true; + self + } } impl Layer for FailOnStatus { @@ -130,20 +85,26 @@ pub struct FailOnStatusService { impl Service for FailOnStatusService where - Cx: Send, - Req: Send, + Cx: Context + Send, + Req: RequestPartsExt + Send, S: Service, Error = ClientError> + Send + Sync, { type Response = S::Response; type Error = S::Error; async fn call(&self, cx: &mut Cx, req: Req) -> Result { + let url = if self.fail_on.detailed { + req.url() + } else { + None + }; let resp = self.inner.call(cx, req).await?; let status = resp.status(); if (self.fail_on.client_error && status.is_client_error()) || (self.fail_on.server_error && status.is_server_error()) { - Err(request_error(StatusCodeError::new(status))) + Err(request_error(StatusCodeError { status, url }) + .with_endpoint(cx.rpc_info().callee())) } else { Ok(resp) } @@ -152,13 +113,21 @@ where /// Client received a response with an error status code. pub struct StatusCodeError { - /// The original status code - pub status: StatusCode, + status: StatusCode, + url: Option, } impl StatusCodeError { - fn new(status: StatusCode) -> Self { - Self { status } + /// The original status code. + pub fn status(&self) -> StatusCode { + self.status + } + + /// The target [`Url`] + /// + /// It will only be saved when [`FailOnStatus::detailed`] enabled. + pub fn url(&self) -> Option<&Url> { + self.url.as_ref() } } @@ -172,14 +141,18 @@ impl fmt::Debug for StatusCodeError { impl fmt::Display for StatusCodeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "client received an error status code: {}", self.status) + write!(f, "client received an error status `{}`", self.status)?; + if let Some(url) = &self.url { + write!(f, " for `{url}`")?; + } + Ok(()) } } impl Error for StatusCodeError {} #[cfg(test)] -mod client_layers_tests { +mod fail_on_status_tests { use http::status::StatusCode; use motore::service::Service; @@ -216,7 +189,8 @@ mod client_layers_tests { // Reject all error status codes let client = ClientBuilder::new() .layer_outer_front(FailOnStatus::all()) - .mock(MockTransport::service(ReturnStatus)); + .mock(MockTransport::service(ReturnStatus)) + .unwrap(); client.get("/400").send().await.unwrap_err(); client.get("/500").send().await.unwrap_err(); } @@ -224,7 +198,8 @@ mod client_layers_tests { // Reject client error status codes let client = ClientBuilder::new() .layer_outer_front(FailOnStatus::client_error()) - .mock(MockTransport::service(ReturnStatus)); + .mock(MockTransport::service(ReturnStatus)) + .unwrap(); client.get("/400").send().await.unwrap_err(); // 5XX is server error, it should not be handled client.get("/500").send().await.unwrap(); @@ -233,7 +208,8 @@ mod client_layers_tests { // Reject all error status codes let client = ClientBuilder::new() .layer_outer_front(FailOnStatus::server_error()) - .mock(MockTransport::service(ReturnStatus)); + .mock(MockTransport::service(ReturnStatus)) + .unwrap(); // 4XX is client error, it should not be handled client.get("/400").send().await.unwrap(); client.get("/500").send().await.unwrap_err(); diff --git a/volo-http/src/client/layer/header.rs b/volo-http/src/client/layer/header.rs new file mode 100644 index 00000000..811e92a3 --- /dev/null +++ b/volo-http/src/client/layer/header.rs @@ -0,0 +1,471 @@ +//! [`Layer`]s for inserting header to requests. +//! +//! - [`Header`] inserts any [`HeaderName`] and [`HeaderValue`] +//! - [`Host`] inserts the given `Host` or a `Host` generated by the target hostname or target +//! address with its scheme and port. +//! - [`UserAgent`] inserts the given `User-Agent` or a `User-Agent` generated by the current +//! package information. + +use std::{error::Error, future::Future, ops::Deref}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + uri::Scheme, +}; +use motore::{layer::Layer, service::Service}; +use volo::{ + context::{Context, Endpoint}, + net::Address, +}; + +use crate::{ + client::{dns::Port, target::is_default_port}, + error::client::{builder_error, Result}, + request::ClientRequest, +}; + +/// [`Layer`] for inserting a header to requests. +#[derive(Clone, Debug)] +pub struct Header { + key: HeaderName, + val: HeaderValue, +} + +impl Header { + /// Create a new [`Header`] layer for inserting a header to requests. + /// + /// This function takes [`HeaderName`] and [`HeaderValue`], users should create it by + /// themselves. + /// + /// For using string types directly, see [`Header::try_new`]. + pub fn new(key: HeaderName, val: HeaderValue) -> Self { + Self { key, val } + } + + /// Create a new [`Header`] layer for inserting a header to requests. + /// + /// This function takes any types that can be converted into [`HeaderName`] or [`HeaderValue`]. + /// If the values are invalid [`HeaderName`] or [`HeaderValue`], an [`ClientError`] with + /// [`ErrorKind::Builder`] will be returned. + /// + /// [`ClientError`]: crate::error::client::ClientError + /// [`ErrorKind::Builder`]: crate::error::client::ErrorKind::Builder + pub fn try_new(key: K, val: V) -> Result + where + K: TryInto, + K::Error: Error + Send + Sync + 'static, + V: TryInto, + V::Error: Error + Send + Sync + 'static, + { + let key = key.try_into().map_err(builder_error)?; + let val = val.try_into().map_err(builder_error)?; + + Ok(Self::new(key, val)) + } +} + +impl Layer for Header { + type Service = HeaderService; + + fn layer(self, inner: S) -> Self::Service { + HeaderService { + inner, + key: self.key, + val: self.val, + } + } +} + +/// [`Service`] generated by [`Header`]. +/// +/// See [`Header`], [`Header::new`] and [`Header::try_new`] for more details. +pub struct HeaderService { + inner: S, + key: HeaderName, + val: HeaderValue, +} + +impl Service> for HeaderService +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + + fn call( + &self, + cx: &mut Cx, + mut req: ClientRequest, + ) -> impl Future> + Send { + req.headers_mut().insert(self.key.clone(), self.val.clone()); + self.inner.call(cx, req) + } +} + +/// [`Layer`] for inserting `Host` into the request header. +/// +/// See [`Host::new`] and [`Host::auto`] for more details. +pub struct Host { + val: Option, +} + +impl Host { + /// Create a new [`Host`] layer that inserts `Host` into the request header. + /// + /// Note that the layer only inserts it if there is no `Host` + pub fn new(val: HeaderValue) -> Self { + Self { val: Some(val) } + } + + /// Create a new [`Host`] layer that inserts `Host` by the current target host name, port or + /// address. + /// + /// Note that the layer only inserts it if there is no `Host`. + /// + /// This layer also does nothing if there is no target hostname and the target address is not + /// an address (such as a unix domain socket). + pub fn auto() -> Self { + Self { val: None } + } +} + +impl Layer for Host { + type Service = HostService; + + fn layer(self, inner: S) -> Self::Service { + HostService { + inner, + val: self.val, + } + } +} + +/// [`Service`] generated by [`Host`]. +/// +/// See [`Host`] and [`Host::new`] for more details. +pub struct HostService { + inner: S, + val: Option, +} + +// keep it as a separate function to facilitate unit testing +fn gen_host( + scheme: &Scheme, + name: &str, + addr: Option<&Address>, + port: Option, +) -> Option { + if name.is_empty() { + match addr? { + Address::Ip(sa) => { + if is_default_port(scheme, sa.port()) { + HeaderValue::try_from(format!("{}", sa.ip())).ok() + } else { + HeaderValue::try_from(format!("{}", sa)).ok() + } + } + #[cfg(target_family = "unix")] + Address::Unix(_) => None, + } + } else { + if let Some(port) = port { + if !is_default_port(scheme, port) { + return HeaderValue::try_from(format!("{name}:{port}")).ok(); + } + } + HeaderValue::from_str(name).ok() + } +} + +fn gen_host_by_ep(ep: &Endpoint) -> Option { + let scheme = ep.get::()?; + let name = ep.service_name_ref(); + let addr = ep.address.as_ref(); + let port = ep.get::().map(Deref::deref).cloned(); + gen_host(scheme, name, addr, port) +} + +impl Service> for HostService +where + Cx: Context, + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + + fn call( + &self, + cx: &mut Cx, + mut req: ClientRequest, + ) -> impl Future> + Send { + if !req.headers().contains_key(header::HOST) { + if let Some(val) = &self.val { + req.headers_mut().insert(header::HOST, val.clone()); + } else if let Some(val) = gen_host_by_ep(cx.rpc_info().callee()) { + req.headers_mut().insert(header::HOST, val); + } + } + self.inner.call(cx, req) + } +} + +const PKG_NAME_WITH_VER: &str = concat!(env!("CARGO_PKG_NAME"), '/', env!("CARGO_PKG_VERSION")); + +/// [`Layer`] for inserting `User-Agent` into the request header. +/// +/// See [`UserAgent::new`] for more details. +pub struct UserAgent { + val: HeaderValue, +} + +impl UserAgent { + /// Create a new [`UserAgent`] layer that inserts `User-Agent` into the request header. + /// + /// Note that the layer only inserts it if there is no `User-Agent` + pub fn new(val: HeaderValue) -> Self { + Self { val } + } + + /// Create a new [`UserAgent`] layer with the package name and package version as its default + /// value. + /// + /// Note that the layer only inserts it if there is no `User-Agent` + pub fn auto() -> Self { + Self { + val: HeaderValue::from_static(PKG_NAME_WITH_VER), + } + } +} + +impl Layer for UserAgent { + type Service = UserAgentService; + + fn layer(self, inner: S) -> Self::Service { + UserAgentService { + inner, + val: self.val, + } + } +} + +/// [`Service`] generated by [`UserAgent`]. +/// +/// See [`UserAgent`] and [`UserAgent::new`] for more details. +pub struct UserAgentService { + inner: S, + val: HeaderValue, +} + +impl Service> for UserAgentService +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + + fn call( + &self, + cx: &mut Cx, + mut req: ClientRequest, + ) -> impl Future> + Send { + if !req.headers().contains_key(header::USER_AGENT) { + req.headers_mut() + .insert(header::USER_AGENT, self.val.clone()); + } + self.inner.call(cx, req) + } +} + +#[cfg(test)] +mod layer_header_tests { + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + + use http::uri::Scheme; + use volo::net::Address; + + use crate::client::layer::header::gen_host; + + fn gen_ipv4addr(port: u16) -> Address { + // 127.0.0.1:port + Address::Ip(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + port, + )) + } + + fn gen_ipv6addr(port: u16) -> Address { + // [::1]:port + Address::Ip(SocketAddr::new( + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), + port, + )) + } + + #[test] + fn gen_host_test() { + // no host, no addr + assert_eq!(gen_host(&Scheme::HTTP, "", None, Some(80)), None); + + // host without port + assert_eq!( + gen_host(&Scheme::HTTP, "github.com", None, None).unwrap(), + "github.com" + ); + // host with default port + assert_eq!( + gen_host(&Scheme::HTTP, "github.com", None, Some(80)).unwrap(), + "github.com" + ); + // host with non-default port + assert_eq!( + gen_host(&Scheme::HTTP, "github.com", None, Some(8000)).unwrap(), + "github.com:8000" + ); + assert_eq!( + gen_host(&Scheme::HTTP, "github.com", None, Some(443)).unwrap(), + "github.com:443" + ); + + // same test case as above, but with a resolved address + // host without port + assert_eq!( + gen_host(&Scheme::HTTP, "github.com", Some(&gen_ipv4addr(80)), None).unwrap(), + "github.com" + ); + // host with default port + assert_eq!( + gen_host( + &Scheme::HTTP, + "github.com", + Some(&gen_ipv4addr(80)), + Some(80) + ) + .unwrap(), + "github.com" + ); + // host with non-default port + assert_eq!( + gen_host( + &Scheme::HTTP, + "github.com", + Some(&gen_ipv4addr(8000)), + Some(8000) + ) + .unwrap(), + "github.com:8000" + ); + assert_eq!( + gen_host( + &Scheme::HTTP, + "github.com", + Some(&gen_ipv4addr(8000)), + Some(443) + ) + .unwrap(), + "github.com:443" + ); + + // ipv4 addr with default port + assert_eq!( + gen_host(&Scheme::HTTP, "", Some(&gen_ipv4addr(80)), None).unwrap(), + "127.0.0.1" + ); + // ipv4 addr with non-default port + assert_eq!( + gen_host(&Scheme::HTTP, "", Some(&gen_ipv4addr(8000)), None).unwrap(), + "127.0.0.1:8000" + ); + assert_eq!( + gen_host(&Scheme::HTTP, "", Some(&gen_ipv4addr(443)), None).unwrap(), + "127.0.0.1:443" + ); + + // althrough these cases are impossible to happen, we also test it + // ipv4 addr with default port + assert_eq!( + gen_host(&Scheme::HTTP, "", Some(&gen_ipv4addr(80)), Some(8888)).unwrap(), + "127.0.0.1" + ); + // ipv4 addr with non-default port + assert_eq!( + gen_host(&Scheme::HTTP, "", Some(&gen_ipv4addr(8000)), Some(8888)).unwrap(), + "127.0.0.1:8000" + ); + assert_eq!( + gen_host(&Scheme::HTTP, "", Some(&gen_ipv4addr(443)), Some(8888)).unwrap(), + "127.0.0.1:443" + ); + + // ipv6 addr with default port + assert_eq!( + gen_host(&Scheme::HTTP, "", Some(&gen_ipv6addr(80)), None).unwrap(), + "::1" + ); + // ipv6 addr with non-default port + assert_eq!( + gen_host(&Scheme::HTTP, "", Some(&gen_ipv6addr(8000)), None).unwrap(), + "[::1]:8000" + ); + assert_eq!( + gen_host(&Scheme::HTTP, "", Some(&gen_ipv6addr(443)), None).unwrap(), + "[::1]:443" + ); + } + + #[cfg(feature = "__tls")] + #[test] + fn gen_host_with_tls_test() { + // no host, no addr + assert_eq!(gen_host(&Scheme::HTTPS, "", None, Some(443)), None); + + // host without port + assert_eq!( + gen_host(&Scheme::HTTPS, "github.com", None, None).unwrap(), + "github.com" + ); + // host with default port + assert_eq!( + gen_host(&Scheme::HTTPS, "github.com", None, Some(443)).unwrap(), + "github.com" + ); + // host with non-default port + assert_eq!( + gen_host(&Scheme::HTTPS, "github.com", None, Some(4430)).unwrap(), + "github.com:4430" + ); + assert_eq!( + gen_host(&Scheme::HTTPS, "github.com", None, Some(80)).unwrap(), + "github.com:80" + ); + + // ipv4 addr with default port + assert_eq!( + gen_host(&Scheme::HTTPS, "", Some(&gen_ipv4addr(443)), None).unwrap(), + "127.0.0.1" + ); + // ipv4 addr with non-default port + assert_eq!( + gen_host(&Scheme::HTTPS, "", Some(&gen_ipv4addr(4430)), None).unwrap(), + "127.0.0.1:4430" + ); + assert_eq!( + gen_host(&Scheme::HTTPS, "", Some(&gen_ipv4addr(80)), None).unwrap(), + "127.0.0.1:80" + ); + + // ipv6 addr with default port + assert_eq!( + gen_host(&Scheme::HTTPS, "", Some(&gen_ipv6addr(443)), None).unwrap(), + "::1" + ); + // ipv6 addr with non-default port + assert_eq!( + gen_host(&Scheme::HTTPS, "", Some(&gen_ipv6addr(4430)), None).unwrap(), + "[::1]:4430" + ); + assert_eq!( + gen_host(&Scheme::HTTPS, "", Some(&gen_ipv6addr(80)), None).unwrap(), + "[::1]:80" + ); + } +} diff --git a/volo-http/src/client/layer/mod.rs b/volo-http/src/client/layer/mod.rs new file mode 100644 index 00000000..033d03e1 --- /dev/null +++ b/volo-http/src/client/layer/mod.rs @@ -0,0 +1,12 @@ +//! Collections of some useful [`Layer`]s. +//! +//! [`Layer`]: motore::layer::Layer + +mod fail_on_status; +pub mod header; +mod timeout; + +pub use self::{ + fail_on_status::{FailOnStatus, StatusCodeError}, + timeout::Timeout, +}; diff --git a/volo-http/src/client/layer/timeout.rs b/volo-http/src/client/layer/timeout.rs new file mode 100644 index 00000000..4f7a5a3b --- /dev/null +++ b/volo-http/src/client/layer/timeout.rs @@ -0,0 +1,59 @@ +use motore::{layer::Layer, service::Service}; +use volo::context::Context; + +use crate::{context::client::Config, error::ClientError}; + +/// [`Layer`] for applying timeout from [`Config`]. +/// +/// This layer will be applied by default when using [`ClientBuilder::build`], without this layer, +/// timeout from [`Client`] or [`CallOpt`] will not work. +/// +/// [`Client`]: crate::client::Client +/// [`ClientBuilder::build`]: crate::client::ClientBuilder::build +/// [`CallOpt`]: crate::client::CallOpt +#[derive(Clone, Debug, Default)] +pub struct Timeout; + +impl Layer for Timeout { + type Service = TimeoutService; + + fn layer(self, inner: S) -> Self::Service { + TimeoutService { inner } + } +} + +/// The [`Service`] generated by [`Timeout`]. +/// +/// See [`Timeout`] for more details. +pub struct TimeoutService { + inner: S, +} + +impl Service for TimeoutService +where + Cx: Context + Send, + Req: Send, + S: Service + Send + Sync, +{ + type Response = S::Response; + type Error = S::Error; + + async fn call(&self, cx: &mut Cx, req: Req) -> Result { + let timeout = cx.rpc_info().config().timeout().cloned(); + let fut = self.inner.call(cx, req); + + if let Some(duration) = timeout { + let sleep = tokio::time::sleep(duration); + + tokio::select! { + res = fut => res, + _ = sleep => { + tracing::error!("[Volo-HTTP] request timeout"); + Err(crate::error::client::timeout().with_endpoint(cx.rpc_info().callee())) + } + } + } else { + fut.await + } + } +} diff --git a/volo-http/src/client/mod.rs b/volo-http/src/client/mod.rs index 7ea0e140..bfdad7d5 100644 --- a/volo-http/src/client/mod.rs +++ b/volo-http/src/client/mod.rs @@ -2,11 +2,18 @@ //! //! See [`Client`] for more details. -use std::{cell::RefCell, error::Error, sync::Arc, time::Duration}; +use std::{ + borrow::Cow, + cell::RefCell, + error::Error, + future::Future, + sync::{Arc, LazyLock}, + time::Duration, +}; use faststr::FastStr; use http::{ - header::{self, HeaderMap, HeaderName, HeaderValue}, + header::{HeaderMap, HeaderName, HeaderValue}, uri::{Scheme, Uri}, Method, }; @@ -17,7 +24,7 @@ use motore::{ }; use paste::paste; use volo::{ - client::MkClient, + client::{Apply, MkClient, OneShotService}, context::Context, loadbalance::MkLbLayer, net::{ @@ -26,27 +33,25 @@ use volo::{ }, }; -#[cfg(feature = "__tls")] -#[cfg_attr(docsrs, doc(cfg(any(feature = "rustls", feature = "native-tls"))))] -pub use self::transport::TlsTransport; use self::{ - callopt::CallOpt, - dns::parse_target, - loadbalance::{DefaultLB, DefaultLBService, LbConfig}, - target::TargetParser, + layer::{ + header::{Host, HostService, UserAgent, UserAgentService}, + Timeout, + }, + loadbalance::{DefaultLB, LbConfig}, transport::{ClientConfig, ClientTransport, ClientTransportConfig}, }; use crate::{ context::ClientContext, error::{ - client::{builder_error, no_address, ClientError, Result}, - BoxError, + client::{builder_error, Result}, + BoxError, ClientError, }, request::ClientRequest, response::ClientResponse, }; -pub mod callopt; +mod callopt; #[cfg(feature = "cookie")] pub mod cookie; pub mod dns; @@ -58,36 +63,44 @@ pub mod target; pub mod test_helpers; mod transport; -pub use self::{request_builder::RequestBuilder, target::Target}; +pub use self::{callopt::CallOpt, request_builder::RequestBuilder, target::Target}; #[doc(hidden)] pub mod prelude { pub use super::{Client, ClientBuilder}; } -const PKG_NAME_WITH_VER: &str = concat!(env!("CARGO_PKG_NAME"), '/', env!("CARGO_PKG_VERSION")); - /// Default inner service of [`Client`] pub type ClientMetaService = ClientTransport; -/// Default [`Client`] without any extra [`Layer`]s +/// [`Client`] generated service with given `IL`, `OL` and `LB` +pub type ClientService =
    ::Layer as Layer<>::Service>>::Service, +>>::Service; +/// Default [`Client`] without default [`Layer`]s +pub type SimpleClient = Client>; +/// Default [`Layer`]s that [`ClientBuilder::build`] append to outer layers +pub type DefaultClientOuterService = + >>>::Service; +/// Default [`Client`] with default [`Layer`]s pub type DefaultClient = - Client<
      >::Service>>>::Service>; + Client>>; /// A builder for configuring an HTTP [`Client`]. pub struct ClientBuilder { http_config: ClientConfig, builder_config: BuilderConfig, connector: DefaultMakeTransport, - callee_name: FastStr, - caller_name: FastStr, target: Target, - call_opt: Option, - target_parser: TargetParser, + timeout: Option, + user_agent: Option, + host: Option, + callee_name: FastStr, headers: HeaderMap, inner_layer: IL, outer_layer: OL, mk_client: C, mk_lb: LB, + status: Result<()>, #[cfg(feature = "__tls")] tls_config: Option, } @@ -119,16 +132,17 @@ impl ClientBuilder { http_config: Default::default(), builder_config: Default::default(), connector: Default::default(), - callee_name: FastStr::empty(), - caller_name: FastStr::empty(), target: Default::default(), - call_opt: Default::default(), - target_parser: parse_target, + timeout: None, + user_agent: None, + host: None, + callee_name: FastStr::empty(), headers: Default::default(), inner_layer: Identity::new(), outer_layer: Identity::new(), mk_client: DefaultMkClient, mk_lb: Default::default(), + status: Ok(()), #[cfg(feature = "__tls")] tls_config: None, } @@ -151,16 +165,17 @@ impl ClientBuilder> { http_config: self.http_config, builder_config: self.builder_config, connector: self.connector, - callee_name: self.callee_name, - caller_name: self.caller_name, target: self.target, - call_opt: self.call_opt, - target_parser: self.target_parser, + timeout: self.timeout, + user_agent: self.user_agent, + host: self.host, + callee_name: self.callee_name, headers: self.headers, inner_layer: self.inner_layer, outer_layer: self.outer_layer, mk_client: self.mk_client, mk_lb: self.mk_lb.load_balance(load_balance), + status: self.status, #[cfg(feature = "__tls")] tls_config: self.tls_config, } @@ -172,16 +187,17 @@ impl ClientBuilder> { http_config: self.http_config, builder_config: self.builder_config, connector: self.connector, - callee_name: self.callee_name, - caller_name: self.caller_name, target: self.target, - call_opt: self.call_opt, - target_parser: self.target_parser, + timeout: self.timeout, + user_agent: self.user_agent, + host: self.host, + callee_name: self.callee_name, headers: self.headers, inner_layer: self.inner_layer, outer_layer: self.outer_layer, mk_client: self.mk_client, mk_lb: self.mk_lb.discover(discover), + status: self.status, #[cfg(feature = "__tls")] tls_config: self.tls_config, } @@ -196,16 +212,17 @@ impl ClientBuilder { http_config: self.http_config, builder_config: self.builder_config, connector: self.connector, - callee_name: self.callee_name, - caller_name: self.caller_name, target: self.target, - call_opt: self.call_opt, - target_parser: self.target_parser, + timeout: self.timeout, + user_agent: self.user_agent, + host: self.host, + callee_name: self.callee_name, headers: self.headers, inner_layer: self.inner_layer, outer_layer: self.outer_layer, mk_client: new_mk_client, mk_lb: self.mk_lb, + status: self.status, #[cfg(feature = "__tls")] tls_config: self.tls_config, } @@ -229,16 +246,17 @@ impl ClientBuilder { http_config: self.http_config, builder_config: self.builder_config, connector: self.connector, - callee_name: self.callee_name, - caller_name: self.caller_name, target: self.target, - call_opt: self.call_opt, - target_parser: self.target_parser, + timeout: self.timeout, + user_agent: self.user_agent, + host: self.host, + callee_name: self.callee_name, headers: self.headers, inner_layer: Stack::new(layer, self.inner_layer), outer_layer: self.outer_layer, mk_client: self.mk_client, mk_lb: self.mk_lb, + status: self.status, #[cfg(feature = "__tls")] tls_config: self.tls_config, } @@ -265,16 +283,17 @@ impl ClientBuilder { http_config: self.http_config, builder_config: self.builder_config, connector: self.connector, - callee_name: self.callee_name, - caller_name: self.caller_name, target: self.target, - call_opt: self.call_opt, - target_parser: self.target_parser, + timeout: self.timeout, + user_agent: self.user_agent, + host: self.host, + callee_name: self.callee_name, headers: self.headers, inner_layer: Stack::new(self.inner_layer, layer), outer_layer: self.outer_layer, mk_client: self.mk_client, mk_lb: self.mk_lb, + status: self.status, #[cfg(feature = "__tls")] tls_config: self.tls_config, } @@ -298,16 +317,17 @@ impl ClientBuilder { http_config: self.http_config, builder_config: self.builder_config, connector: self.connector, - callee_name: self.callee_name, - caller_name: self.caller_name, target: self.target, - call_opt: self.call_opt, - target_parser: self.target_parser, + timeout: self.timeout, + user_agent: self.user_agent, + host: self.host, + callee_name: self.callee_name, headers: self.headers, inner_layer: self.inner_layer, outer_layer: Stack::new(layer, self.outer_layer), mk_client: self.mk_client, mk_lb: self.mk_lb, + status: self.status, #[cfg(feature = "__tls")] tls_config: self.tls_config, } @@ -334,16 +354,17 @@ impl ClientBuilder { http_config: self.http_config, builder_config: self.builder_config, connector: self.connector, - callee_name: self.callee_name, - caller_name: self.caller_name, target: self.target, - call_opt: self.call_opt, - target_parser: self.target_parser, + timeout: self.timeout, + user_agent: self.user_agent, + host: self.host, + callee_name: self.callee_name, headers: self.headers, inner_layer: self.inner_layer, outer_layer: Stack::new(self.outer_layer, layer), mk_client: self.mk_client, mk_lb: self.mk_lb, + status: self.status, #[cfg(feature = "__tls")] tls_config: self.tls_config, } @@ -355,123 +376,98 @@ impl ClientBuilder { http_config: self.http_config, builder_config: self.builder_config, connector: self.connector, - callee_name: self.callee_name, - caller_name: self.caller_name, target: self.target, - call_opt: self.call_opt, - target_parser: self.target_parser, + timeout: self.timeout, + user_agent: self.user_agent, + host: self.host, + callee_name: self.callee_name, headers: self.headers, inner_layer: self.inner_layer, outer_layer: self.outer_layer, mk_client: self.mk_client, mk_lb: mk_load_balance, + status: self.status, #[cfg(feature = "__tls")] tls_config: self.tls_config, } } - /// Set the target server's name. + /// Set default target address of the client. /// - /// When sending a request, the `Host` in headers will be the host name or host address by - /// default. If the callee name is not empty, it can override the default `Host`. - /// - /// Default is empty, the default `Host` will be used. - pub fn callee_name(&mut self, callee: S) -> &mut Self - where - S: AsRef, - { - self.callee_name = FastStr::from_string(callee.as_ref().to_owned()); - self - } - - /// Set the client's name sent to the server. - /// - /// When sending a request, the `User-Agent` in headers will be the crate name with its - /// version. If the caller name is not empty, it can override the default `User-Agent`. - /// - /// Default is empty, the default `User-Agnet` will be used. - pub fn caller_name(&mut self, caller: S) -> &mut Self - where - S: AsRef, - { - self.caller_name = FastStr::from_string(caller.as_ref().to_owned()); - self - } - - /// Set the target address of the client. + /// For using given `Host` in header or using HTTPS, use [`ClientBuilder::default_host`] for + /// setting it. /// /// If there is no target specified when building a request, client will use this address. pub fn address(&mut self, address: A) -> &mut Self where A: Into
      , { - self.target = Target::from_address(address); + self.target = Target::from(address.into()); self } - /// Set the target host of the client. + /// Set default target host of the client. /// /// If there is no target specified when building a request, client will use this address. /// /// It uses http with port 80 by default. /// - /// For setting scheme and port, use [`Self::with_port`] and [`Self::with_https`] after - /// specifying host. - pub fn host(&mut self, host: H) -> &mut Self + /// For setting scheme and port, use [`ClientBuilder::with_scheme`] and + /// [`ClientBuilder::with_port`] after specifying host. + pub fn host(&mut self, host: S) -> &mut Self where - H: AsRef, + S: Into>, { - self.target = Target::from_host(host); + let host = host.into(); + self.target = Target::from_host(host.clone()); + self.default_host(host); self } - /// Set the port of the default target. + /// Set port of the default target. /// /// If there is no target specified, the function will do nothing. pub fn with_port(&mut self, port: u16) -> &mut Self { - self.target.set_port(port); - self - } + if self.status.is_err() { + return self; + } - /// Set if the default target uses https for transporting. - #[cfg(feature = "__tls")] - pub fn with_https(&mut self, https: bool) -> &mut Self { - self.target.set_https(https); - self - } + if let Err(err) = self.target.set_port(port) { + self.status = Err(err); + } - /// Set a [`CallOpt`] to the client as default options for the default target. - /// - /// The [`CallOpt`] is used for service discover, default is an empty one. - /// - /// See [`CallOpt`] for more details. - pub fn with_callopt(&mut self, call_opt: CallOpt) -> &mut Self { - self.call_opt = Some(call_opt); self } - /// Set a target parser for parsing `Target` and updating `Endpoint`. - /// - /// The `TargetParser` usually used for service discover, it can update `Endpoint` from - /// `Target` and the service discover will resolve the `Endpoint` to `Address`. - pub fn target_parser(&mut self, target_parser: TargetParser) -> &mut Self { - self.target_parser = target_parser; + /// Set scheme of default target. + pub fn with_scheme(&mut self, scheme: Scheme) -> &mut Self { + if self.status.is_err() { + return self; + } + + if let Err(err) = self.target.set_scheme(scheme) { + self.status = Err(err); + } + self } /// Insert a header to the request. - pub fn header(&mut self, key: K, value: V) -> Result<&mut Self> + pub fn header(&mut self, key: K, value: V) -> &mut Self where K: TryInto, K::Error: Error + Send + Sync + 'static, V: TryInto, V::Error: Error + Send + Sync + 'static, { - self.headers.insert( - key.try_into().map_err(builder_error)?, - value.try_into().map_err(builder_error)?, - ); - Ok(self) + if self.status.is_err() { + return self; + } + + if let Err(err) = insert_header(&mut self.headers, key, value) { + self.status = Err(err); + } + self } /// Get a reference of [`Target`]. @@ -484,16 +480,6 @@ impl ClientBuilder { &mut self.target } - /// Get a reference of [`CallOpt`]. - pub fn callopt_ref(&self) -> &Option { - &self.call_opt - } - - /// Get a mutable reference of [`CallOpt`]. - pub fn callopt_mut(&mut self) -> &mut Option { - &mut self.call_opt - } - /// Set tls config for the client. #[cfg(feature = "__tls")] #[cfg_attr(docsrs, doc(cfg(any(feature = "rustls", feature = "native-tls"))))] @@ -613,8 +599,74 @@ impl ClientBuilder { self } - /// Build the HTTP client. - pub fn build(mut self) -> C::Target + /// Set the maximum idle time for the whole request. + pub fn set_request_timeout(&mut self, timeout: Duration) -> &mut Self { + self.timeout = Some(timeout); + self + } + + /// Set default `User-Agent` in request header. + /// + /// If there is `User-Agent` given, a default `User-Agent` will be generated by crate name and + /// version. + pub fn user_agent(&mut self, val: V) -> &mut Self + where + V: TryInto, + V::Error: Error + Send + Sync + 'static, + { + if self.status.is_err() { + return self; + } + match val.try_into() { + Ok(val) => self.user_agent = Some(val), + Err(err) => self.status = Err(builder_error(err)), + } + self + } + + /// Set default `Host` for service name and request header. + /// + /// If there is no default `Host`, it will be generated by target hostname and address for each + /// request. + pub fn default_host(&mut self, host: S) -> &mut Self + where + S: Into>, + { + if self.status.is_err() { + return self; + } + let host = FastStr::from(host.into()); + match HeaderValue::from_str(&host) { + Ok(val) => self.host = Some(val), + Err(err) => self.status = Err(builder_error(err)), + } + self.callee_name = host; + self + } + + /// Build the HTTP client with default configurations. + /// + /// This method will insert some default layers: [`Timeout`], [`UserAgent`] and [`Host`], and + /// the final calling sequence will be as follows: + /// + /// - Outer: + /// - [`Timeout`]: Apply timeout from [`ClientBuilder::set_request_timeout`] or + /// [`CallOpt::with_timeout`]. Note that without this layer, timeout from [`Client`] or + /// [`CallOpt`] will not work. + /// - [`Host`]: Insert `Host` to request headers, it takes the given value from + /// [`ClientBuilder::default_host`] or generating by request everytime. If there is already + /// a `Host`, the layer does nothing. + /// - [`UserAgent`]: Insert `User-Agent` into the request header, it takes the given value + /// from [`ClientBuilder::user_agent`] or generates a value based on the current package + /// name and version. If `User-Agent` already exists, this layer does nothing. + /// - Other outer layers + /// - LoadBalance ([`LbConfig`] with [`DnsResolver`] by default) + /// - Inner layers + /// - Other inner layers + /// - Transport through network or unix domain socket. + /// + /// [`DnsResolver`]: crate::client::dns::DnsResolver + pub fn build(mut self) -> Result where IL: Layer, IL::Service: Send + Sync + 'static, @@ -623,8 +675,44 @@ impl ClientBuilder { >::Service: Send + Sync, OL: Layer<>::Service>, OL::Service: Send + Sync + 'static, + C: MkClient< + Client<>>>::Service>, + >, + { + let timeout_layer = Timeout; + let host_layer = match self.host.take() { + Some(host) => Host::new(host), + None => Host::auto(), + }; + let ua_layer = match self.user_agent.take() { + Some(ua) => UserAgent::new(ua), + None => UserAgent::auto(), + }; + self.layer_outer_front(ua_layer) + .layer_outer_front(host_layer) + .layer_outer_front(timeout_layer) + .build_without_extra_layers() + } + + /// Build the HTTP client without inserting any extra layers. + /// + /// This method is provided for advanced users, some features may not work properly without the + /// default layers, + /// + /// See [`ClientBuilder::build`] for more details. + pub fn build_without_extra_layers(self) -> Result + where + IL: Layer, + IL::Service: Send + Sync + 'static, + LB: MkLbLayer, + LB::Layer: Layer, + >::Service: Send + Sync, + OL: Layer<>::Service>, + OL::Service: Send + Sync + 'static, C: MkClient>, { + self.status?; + let transport_config = ClientTransportConfig { stat_enable: self.builder_config.stat_enable, #[cfg(feature = "__tls")] @@ -637,48 +725,43 @@ impl ClientBuilder { #[cfg(feature = "__tls")] self.tls_config.unwrap_or_default(), ); - let meta_service = transport; - let service = self.outer_layer.layer( - self.mk_lb - .make() - .layer(self.inner_layer.layer(meta_service)), - ); - - let caller_name = if self.caller_name.is_empty() { - FastStr::from_static_str(PKG_NAME_WITH_VER) - } else { - self.caller_name - }; - if !caller_name.is_empty() && self.headers.get(header::USER_AGENT).is_none() { - self.headers.insert( - header::USER_AGENT, - HeaderValue::from_str(caller_name.as_str()).expect("Invalid caller name"), - ); - } + let service = self + .outer_layer + .layer(self.mk_lb.make().layer(self.inner_layer.layer(transport))); let client_inner = ClientInner { service, - caller_name, - callee_name: self.callee_name, - default_target: self.target, - default_call_opt: self.call_opt, - target_parser: self.target_parser, + target: self.target, + timeout: self.timeout, + default_callee_name: self.callee_name, headers: self.headers, }; let client = Client { inner: Arc::new(client_inner), }; - self.mk_client.mk_client(client) + Ok(self.mk_client.mk_client(client)) } } +fn insert_header(headers: &mut HeaderMap, key: K, value: V) -> Result<()> +where + K: TryInto, + K::Error: Error + Send + Sync + 'static, + V: TryInto, + V::Error: Error + Send + Sync + 'static, +{ + headers.insert( + key.try_into().map_err(builder_error)?, + value.try_into().map_err(builder_error)?, + ); + Ok(()) +} + struct ClientInner { service: S, - caller_name: FastStr, - callee_name: FastStr, - default_target: Target, - default_call_opt: Option, - target_parser: TargetParser, + target: Target, + timeout: Option, + default_callee_name: FastStr, headers: HeaderMap, } @@ -690,7 +773,7 @@ struct ClientInner { /// use volo_http::{body::BodyConversion, client::Client}; /// /// # tokio_test::block_on(async { -/// let client = Client::builder().build(); +/// let client = Client::builder().build().unwrap(); /// let resp = client /// .get("http://httpbin.org/get") /// .send() @@ -706,6 +789,12 @@ pub struct Client { inner: Arc>, } +impl Default for DefaultClient { + fn default() -> Self { + ClientBuilder::default().build().unwrap() + } +} + impl Clone for Client { fn clone(&self) -> Self { Self { @@ -718,7 +807,7 @@ macro_rules! method_requests { ($method:ident) => { paste! { #[doc = concat!("Create a request with `", stringify!([<$method:upper>]) ,"` method and the given `uri`.")] - pub fn [<$method:lower>](&self, uri: U) -> RequestBuilder + pub fn [<$method:lower>](&self, uri: U) -> RequestBuilder where U: TryInto, U::Error: Into, @@ -738,12 +827,12 @@ impl Client<()> { impl Client { /// Create a builder for building a request. - pub fn request_builder(&self) -> RequestBuilder { + pub fn request_builder(&self) -> RequestBuilder { RequestBuilder::new(self.clone()) } /// Create a builder for building a request with the specified method and URI. - pub fn request(&self, method: Method, uri: U) -> RequestBuilder + pub fn request(&self, method: Method, uri: U) -> RequestBuilder where U: TryInto, U::Error: Into, @@ -763,128 +852,54 @@ impl Client { /// Get the default target address of the client. pub fn default_target(&self) -> &Target { - &self.inner.default_target - } - - /// Send a request to the target address. - /// - /// This is a low-level method and you should build the `uri` and `request`, and get the - /// address by yourself. - /// - /// For simple usage, you can use the `get`, `post` and other methods directly. - /// - /// # Example - /// - /// ```no_run - /// use std::net::SocketAddr; - /// - /// use http::{Method, Uri}; - /// use volo::net::Address; - /// use volo_http::{ - /// body::{Body, BodyConversion}, - /// client::{Client, Target}, - /// request::ClientRequest, - /// }; - /// - /// # tokio_test::block_on(async { - /// let client = Client::builder().build(); - /// let addr: SocketAddr = "[::]:8080".parse().unwrap(); - /// let resp = client - /// .send_request( - /// Target::from_address(addr), - /// Default::default(), - /// ClientRequest::builder() - /// .method(Method::GET) - /// .uri("/") - /// .body(Body::empty()) - /// .expect("build request failed"), - /// ) - /// .await - /// .expect("request failed") - /// .into_string() - /// .await - /// .expect("response failed to convert to string"); - /// println!("{resp:?}"); - /// # }) - /// ``` - pub async fn send_request( - &self, - target: Target, - call_opt: Option, - mut request: ClientRequest, - ) -> Result - where - S: Service, Response = ClientResponse, Error = ClientError> - + Send - + Sync - + 'static, - B: Send + 'static, - { - let caller_name = self.inner.caller_name.clone(); - let callee_name = self.inner.callee_name.clone(); - - let (target, call_opt) = match (target.is_none(), self.inner.default_target.is_none()) { - // The target specified by request exists and we can use it directly. - // - // Note that the default callopt only applies to the default target and should not be - // used here. - (false, _) => (target, call_opt.as_ref()), - // Target is not specified by request, we can use the default target. - // - // Although the request does not set a target, its callopt should be valid for the - // default target. - (true, false) => ( - self.inner.default_target.clone(), - call_opt.as_ref().or(self.inner.default_call_opt.as_ref()), - ), - // Both target are none, return an error. - (true, true) => { - return Err(no_address()); - } - }; - - let host = if callee_name.is_empty() { - target.gen_host() - } else { - HeaderValue::from_str(callee_name.as_str()).ok() - }; - if let Some(host) = host { - request.headers_mut().insert(header::HOST, host); - } - - let scheme = match target.is_https() { - true => Scheme::HTTPS, - false => Scheme::HTTP, - }; - request.extensions_mut().insert(scheme); - - let mut cx = ClientContext::new(); - cx.rpc_info_mut().caller_mut().set_service_name(caller_name); - cx.rpc_info_mut().callee_mut().set_service_name(callee_name); - (self.inner.target_parser)(target, call_opt, cx.rpc_info_mut().callee_mut()); - - self.call(&mut cx, request).await + &self.inner.target } } -impl Service> for Client +impl OneShotService> for Client where - S: Service, Response = ClientResponse, Error = ClientError> - + Send - + Sync - + 'static, - B: Send + 'static, + S: Service, Error = ClientError> + Send + Sync, + B: Send, { type Response = S::Response; type Error = S::Error; async fn call( - &self, + self, cx: &mut ClientContext, mut req: ClientRequest, ) -> Result { + // set target + self.inner.target.clone().apply(cx)?; + + // also save a scheme in request + { + if let Some(scheme) = cx.rpc_info().callee().get::() { + req.extensions_mut().insert(scheme.to_owned()); + } + } + + // set default callee name + { + let callee = cx.rpc_info_mut().callee_mut(); + if callee.service_name_ref().is_empty() { + callee.set_service_name(self.inner.default_callee_name.clone()); + } + } + + // set timeout + { + let config = cx.rpc_info_mut().config_mut(); + // We should check it here because CallOptService must be outer of the client service + if config.timeout().is_none() { + config.set_timeout(self.inner.timeout); + } + } + + // extend headermap req.headers_mut().extend(self.inner.headers.clone()); + // apply metainfo if it does not exist let has_metainfo = METAINFO.try_with(|_| {}).is_ok(); let fut = self.inner.service.call(cx, req); @@ -897,6 +912,23 @@ where } } +impl Service> for Client +where + S: Service, Error = ClientError> + Send + Sync, + B: Send, +{ + type Response = S::Response; + type Error = S::Error; + + fn call( + &self, + cx: &mut ClientContext, + req: ClientRequest, + ) -> impl Future> + Send { + OneShotService::call(self.clone(), cx, req) + } +} + /// A dummy [`MkClient`] that does not have any functionality pub struct DefaultMkClient; @@ -908,13 +940,15 @@ impl MkClient> for DefaultMkClient { } } +static CLIENT: LazyLock = LazyLock::new(Default::default); + /// Create a GET request to the specified URI. pub async fn get(uri: U) -> Result where U: TryInto, U::Error: Into, { - ClientBuilder::new().build().get(uri).send().await + CLIENT.clone().get(uri).send().await } // The `httpbin.org` always responses a json data. @@ -925,22 +959,19 @@ mod client_tests { #[cfg(feature = "cookie")] use cookie::Cookie; - use http::{header, StatusCode}; + use http::{header, status::StatusCode}; use motore::{ - layer::{Layer, Stack}, + layer::{Identity, Layer, Stack}, service::Service, }; use serde::Deserialize; - use volo::{context::Endpoint, layer::Identity}; - use super::{ - callopt::CallOpt, - dns::{parse_target, DnsResolver}, - get, Client, DefaultClient, Target, - }; + use super::{dns::DnsResolver, get, Client, DefaultClient}; #[cfg(feature = "cookie")] use crate::client::cookie::CookieLayer; - use crate::{body::BodyConversion, utils::consts::HTTP_DEFAULT_PORT, ClientBuilder}; + use crate::{ + body::BodyConversion, client::SimpleClient, utils::consts::HTTP_DEFAULT_PORT, ClientBuilder, + }; #[derive(Deserialize)] struct HttpBinResponse { @@ -988,21 +1019,44 @@ mod client_tests { } } - let _: DefaultClient = ClientBuilder::new().build(); - let _: DefaultClient = ClientBuilder::new().layer_inner(TestLayer).build(); - let _: DefaultClient = ClientBuilder::new().layer_inner_front(TestLayer).build(); - let _: DefaultClient = - ClientBuilder::new().layer_outer(TestLayer).build(); + let _: SimpleClient = ClientBuilder::new().build_without_extra_layers().unwrap(); + let _: SimpleClient = ClientBuilder::new() + .layer_inner(TestLayer) + .build_without_extra_layers() + .unwrap(); + let _: SimpleClient = ClientBuilder::new() + .layer_outer(TestLayer) + .build_without_extra_layers() + .unwrap(); + let _: SimpleClient = ClientBuilder::new() + .layer_inner(TestLayer) + .layer_outer(TestLayer) + .build_without_extra_layers() + .unwrap(); + + let _: DefaultClient = ClientBuilder::new().build().unwrap(); + let _: DefaultClient = + ClientBuilder::new().layer_inner(TestLayer).build().unwrap(); + let _: DefaultClient = ClientBuilder::new() + .layer_inner_front(TestLayer) + .build() + .unwrap(); let _: DefaultClient = - ClientBuilder::new().layer_outer_front(TestLayer).build(); + ClientBuilder::new().layer_outer(TestLayer).build().unwrap(); + let _: DefaultClient = ClientBuilder::new() + .layer_outer_front(TestLayer) + .build() + .unwrap(); let _: DefaultClient = ClientBuilder::new() .layer_inner(TestLayer) .layer_outer(TestLayer) - .build(); + .build() + .unwrap(); let _: DefaultClient> = ClientBuilder::new() .layer_inner(TestLayer) .layer_inner(TestLayer) - .build(); + .build() + .unwrap(); } #[tokio::test] @@ -1020,8 +1074,8 @@ mod client_tests { #[tokio::test] async fn client_builder_with_header() { let mut builder = Client::builder(); - builder.header(header::USER_AGENT, USER_AGENT_VAL).unwrap(); - let client = builder.build(); + builder.header(header::USER_AGENT, USER_AGENT_VAL); + let client = builder.build().unwrap(); let resp = client .get(HTTPBIN_GET) @@ -1040,7 +1094,7 @@ mod client_tests { async fn client_builder_with_host() { let mut builder = Client::builder(); builder.host("httpbin.org"); - let client = builder.build(); + let client = builder.build().unwrap(); let resp = client .get("/get") @@ -1061,8 +1115,8 @@ mod client_tests { .await .unwrap(); let mut builder = Client::builder(); - builder.address(addr).callee_name("httpbin.org"); - let client = builder.build(); + builder.default_host("httpbin.org").address(addr); + let client = builder.build().unwrap(); let resp = client .get("/get") @@ -1080,8 +1134,10 @@ mod client_tests { #[tokio::test] async fn client_builder_with_https() { let mut builder = Client::builder(); - builder.host("httpbin.org").with_https(true); - let client = builder.build(); + builder + .host("httpbin.org") + .with_scheme(http::uri::Scheme::HTTPS); + let client = builder.build().unwrap(); let resp = client .get("/get") @@ -1104,10 +1160,10 @@ mod client_tests { .unwrap(); let mut builder = Client::builder(); builder + .default_host("httpbin.org") .address(addr) - .with_https(true) - .callee_name("httpbin.org"); - let client = builder.build(); + .with_scheme(http::uri::Scheme::HTTPS); + let client = builder.build().unwrap(); let resp = client .get("/get") @@ -1125,7 +1181,7 @@ mod client_tests { async fn client_builder_with_port() { let mut builder = Client::builder(); builder.host("httpbin.org").with_port(443); - let client = builder.build(); + let client = builder.build().unwrap(); let resp = client.get("/get").send().await.unwrap(); // Send HTTP request to the HTTPS port (443), `httpbin.org` will response `400 Bad @@ -1140,7 +1196,7 @@ mod client_tests { let mut builder = Client::builder(); builder.disable_tls(true); - let client = builder.build(); + let client = builder.build().unwrap(); assert_eq!( format!( "{}", @@ -1154,93 +1210,6 @@ mod client_tests { ); } - struct CallOptInserted; - - // Wrapper for [`parse_target`] with checking [`CallOptInserted`] - fn callopt_should_inserted( - target: Target, - call_opt: Option<&CallOpt>, - endpoint: &mut Endpoint, - ) { - assert!(call_opt.is_some()); - assert!(call_opt.unwrap().contains::()); - parse_target(target, call_opt, endpoint); - } - - fn callopt_should_not_inserted( - target: Target, - call_opt: Option<&CallOpt>, - endpoint: &mut Endpoint, - ) { - if let Some(call_opt) = call_opt { - assert!(!call_opt.contains::()); - } - parse_target(target, call_opt, endpoint); - } - - #[tokio::test] - async fn no_callopt() { - let mut builder = Client::builder(); - builder.target_parser(callopt_should_not_inserted); - let client = builder.build(); - - let resp = client.get(HTTPBIN_GET).send().await; - assert!(resp.is_ok()); - } - - #[tokio::test] - async fn default_callopt() { - let mut builder = Client::builder(); - builder.with_callopt(CallOpt::new().with(CallOptInserted)); - builder.target_parser(callopt_should_not_inserted); - let client = builder.build(); - - let resp = client.get(HTTPBIN_GET).send().await; - assert!(resp.is_ok()); - } - - #[tokio::test] - async fn request_callopt() { - let mut builder = Client::builder(); - builder.target_parser(callopt_should_inserted); - let client = builder.build(); - - let resp = client - .get(HTTPBIN_GET) - .with_callopt(CallOpt::new().with(CallOptInserted)) - .send() - .await; - assert!(resp.is_ok()); - } - - #[tokio::test] - async fn override_callopt() { - let mut builder = Client::builder(); - builder.with_callopt(CallOpt::new().with(CallOptInserted)); - builder.target_parser(callopt_should_not_inserted); - let client = builder.build(); - - let resp = client - .get(HTTPBIN_GET) - // insert an empty callopt - .with_callopt(CallOpt::new()) - .send() - .await; - assert!(resp.is_ok()); - } - - #[tokio::test] - async fn default_target_and_callopt_with_new_target() { - let mut builder = Client::builder(); - builder.host("httpbin.org"); - builder.with_callopt(CallOpt::new().with(CallOptInserted)); - builder.target_parser(callopt_should_not_inserted); - let client = builder.build(); - - let resp = client.get(HTTPBIN_GET).send().await; - assert!(resp.is_ok()); - } - #[cfg(feature = "cookie")] #[tokio::test] async fn cookie_store() { @@ -1248,7 +1217,7 @@ mod client_tests { builder.host("httpbin.org"); - let client = builder.build(); + let client = builder.build().unwrap(); // test server add cookie let resp = client diff --git a/volo-http/src/client/request_builder.rs b/volo-http/src/client/request_builder.rs index 12bfb015..288e9991 100644 --- a/volo-http/src/client/request_builder.rs +++ b/volo-http/src/client/request_builder.rs @@ -2,17 +2,22 @@ //! //! See [`RequestBuilder`] for more details. -use std::error::Error; +use std::{borrow::Cow, error::Error}; +use faststr::FastStr; use http::{ header::{HeaderMap, HeaderName, HeaderValue}, - uri::PathAndQuery, - Method, Request, Uri, Version, + method::Method, + request::Request, + uri::{PathAndQuery, Scheme, Uri}, + version::Version, +}; +use volo::{ + client::{Apply, OneShotService, WithOptService}, + net::Address, }; -use motore::service::Service; -use volo::net::Address; -use super::{callopt::CallOpt, target::Target, Client}; +use super::{insert_header, target::Target, CallOpt}; use crate::{ body::Body, context::ClientContext, @@ -22,23 +27,24 @@ use crate::{ }, request::ClientRequest, response::ClientResponse, + utils::consts, }; /// The builder for building a request. pub struct RequestBuilder { - client: Client, + inner: S, target: Target, - call_opt: Option, - request: Result>, + request: ClientRequest, + status: Result<()>, } -impl RequestBuilder { - pub(crate) fn new(client: Client) -> Self { +impl RequestBuilder { + pub(super) fn new(inner: S) -> Self { Self { - client, + inner, target: Default::default(), - call_opt: Default::default(), - request: Ok(ClientRequest::default()), + request: ClientRequest::default(), + status: Ok(()), } } @@ -48,23 +54,20 @@ impl RequestBuilder { D: TryInto, D::Error: Error + Send + Sync + 'static, { - if self.request.is_err() { + if self.status.is_err() { return self; } - let Ok(req) = self.request else { - unreachable!(); - }; let body = match data.try_into() { Ok(body) => body, Err(err) => { - self.request = Err(builder_error(err)); + self.status = Err(builder_error(err)); return self; } }; - let (parts, _) = req.into_parts(); - self.request = Ok(Request::from_parts(parts, body)); + let (parts, _) = self.request.into_parts(); + self.request = Request::from_parts(parts, body); self } @@ -75,27 +78,24 @@ impl RequestBuilder { where T: serde::Serialize, { - if self.request.is_err() { + if self.status.is_err() { return self; } - let Ok(req) = self.request else { - unreachable!(); - }; let json = match crate::utils::json::serialize(json) { Ok(json) => json, Err(err) => { - self.request = Err(builder_error(err)); + self.status = Err(builder_error(err)); return self; } }; - let (mut parts, _) = req.into_parts(); + let (mut parts, _) = self.request.into_parts(); parts.headers.insert( http::header::CONTENT_TYPE, crate::utils::consts::APPLICATION_JSON, ); - self.request = Ok(Request::from_parts(parts, Body::from(json))); + self.request = Request::from_parts(parts, Body::from(json)); self } @@ -106,27 +106,24 @@ impl RequestBuilder { where T: serde::Serialize, { - if self.request.is_err() { + if self.status.is_err() { return self; } - let Ok(req) = self.request else { - unreachable!(); - }; let form = match serde_urlencoded::to_string(form) { Ok(form) => form, Err(err) => { - self.request = Err(builder_error(err)); + self.status = Err(builder_error(err)); return self; } }; - let (mut parts, _) = req.into_parts(); + let (mut parts, _) = self.request.into_parts(); parts.headers.insert( http::header::CONTENT_TYPE, crate::utils::consts::APPLICATION_WWW_FORM_URLENCODED, ); - self.request = Ok(Request::from_parts(parts, Body::from(form))); + self.request = Request::from_parts(parts, Body::from(form)); self } @@ -135,15 +132,13 @@ impl RequestBuilder { impl RequestBuilder { /// Set method for the request. pub fn method(mut self, method: Method) -> Self { - if let Ok(req) = self.request.as_mut() { - *req.method_mut() = method; - } + *self.request.method_mut() = method; self } /// Get a reference to method in the request. - pub fn method_ref(&self) -> Option<&Method> { - self.request.as_ref().ok().map(Request::method) + pub fn method_ref(&self) -> &Method { + self.request.method() } /// Set uri for building request. @@ -159,21 +154,21 @@ impl RequestBuilder { U: TryInto, U::Error: Into, { - if self.request.is_err() { + if self.status.is_err() { return self; } let uri = match uri.try_into() { Ok(uri) => uri, Err(err) => { - self.request = Err(builder_error(err)); + self.status = Err(builder_error(err)); return self; } }; - if let Some(target) = Target::from_uri(&uri) { - match target { + if uri.host().is_some() { + match Target::from_uri(&uri) { Ok(target) => self.target = target, Err(err) => { - self.request = Err(err); + self.status = Err(err); return self; } } @@ -183,59 +178,31 @@ impl RequestBuilder { .map(PathAndQuery::to_owned) .unwrap_or_else(|| PathAndQuery::from_static("/")) .into(); - let Ok(req) = self.request.as_mut() else { - unreachable!(); - }; - *req.uri_mut() = rela_uri; + *self.request.uri_mut() = rela_uri; self } /// Set full uri for building request. /// - /// In this function, scheme and host will be resolved as the target address, and the full uri - /// will be set as the request uri. - /// /// This function is only used for using http(s) proxy. pub fn full_uri(mut self, uri: U) -> Self where U: TryInto, U::Error: Into, { - if self.request.is_err() { + if self.status.is_err() { return self; } let uri = match uri.try_into() { Ok(uri) => uri, Err(err) => { - self.request = Err(builder_error(err)); + self.status = Err(builder_error(err)); return self; } }; - if let Some(target) = Target::from_uri(&uri) { - match target { - Ok(target) => self.target = target, - Err(err) => { - self.request = Err(err); - return self; - } - } - } - let Ok(req) = self.request.as_mut() else { - unreachable!(); - }; - *req.uri_mut() = uri; - - self - } + *self.request.uri_mut() = uri; - /// Set a [`CallOpt`] to the request. - /// - /// The [`CallOpt`] is used for service discover, default is an empty one. - /// - /// See [`CallOpt`] for more details. - pub fn with_callopt(mut self, call_opt: CallOpt) -> Self { - self.call_opt = Some(call_opt); self } @@ -245,22 +212,19 @@ impl RequestBuilder { where T: serde::Serialize, { - if self.request.is_err() { + if self.status.is_err() { return self; } let query_str = match serde_urlencoded::to_string(query) { Ok(query) => query, Err(err) => { - self.request = Err(builder_error(err)); + self.status = Err(builder_error(err)); return self; } }; - let Ok(req) = self.request.as_mut() else { - unreachable!(); - }; // We should keep path only without query - let path_str = req.uri().path(); + let path_str = self.request.uri().path(); let mut path = String::with_capacity(path_str.len() + 1 + query_str.len()); path.push_str(path_str); path.push('?'); @@ -270,73 +234,54 @@ impl RequestBuilder { unreachable!(); }; - *req.uri_mut() = uri; + *self.request.uri_mut() = uri; self } /// Get a reference to uri in the request. - pub fn uri_ref(&self) -> Option<&Uri> { - self.request.as_ref().ok().map(Request::uri) + pub fn uri_ref(&self) -> &Uri { + self.request.uri() } /// Set version of the HTTP request. pub fn version(mut self, version: Version) -> Self { - if let Ok(req) = self.request.as_mut() { - *req.version_mut() = version; - } + *self.request.version_mut() = version; self } /// Get a reference to version in the request. - pub fn version_ref(&self) -> Option { - self.request.as_ref().ok().map(Request::version) + pub fn version_ref(&self) -> Version { + self.request.version() } /// Insert a header into the request header map. pub fn header(mut self, key: K, value: V) -> Self where K: TryInto, - K::Error: Into, + K::Error: Error + Send + Sync + 'static, V: TryInto, - V::Error: Into, + V::Error: Error + Send + Sync + 'static, { - if self.request.is_err() { + if self.status.is_err() { return self; } - let key = match key.try_into() { - Ok(key) => key, - Err(err) => { - self.request = Err(builder_error(err.into())); - return self; - } - }; - let value = match value.try_into() { - Ok(value) => value, - Err(err) => { - self.request = Err(builder_error(err.into())); - return self; - } - }; - - let Ok(req) = self.request.as_mut() else { - unreachable!(); - }; - - req.headers_mut().insert(key, value); + if let Err(err) = insert_header(self.request.headers_mut(), key, value) { + self.status = Err(err); + } self } /// Get a reference to headers in the request. - pub fn headers(&self) -> Option<&HeaderMap> { - self.request.as_ref().ok().map(Request::headers) + pub fn headers(&self) -> &HeaderMap { + self.request.headers() } /// Get a mutable reference to headers in the request. - pub fn headers_mut(&mut self) -> Option<&mut HeaderMap> { - self.request.as_mut().ok().map(Request::headers_mut) + pub fn headers_mut(&mut self) -> &mut HeaderMap { + self.request.headers_mut() } /// Set target address for the request. @@ -344,7 +289,7 @@ impl RequestBuilder { where A: Into
      , { - self.target = Target::from_address(address); + self.target = Target::from(address.into()); self } @@ -352,26 +297,42 @@ impl RequestBuilder { /// /// It uses http with port 80 by default. /// - /// For setting scheme and port, use [`Self::with_port`] and [`Self::with_https`] after + /// For setting scheme and port, use [`Self::with_scheme`] and [`Self::with_port`] after /// specifying host. pub fn host(mut self, host: H) -> Self where - H: AsRef, + H: Into>, { - self.target = Target::from_host(host); + // SAFETY: using HTTP is safe + self.target = unsafe { + Target::new_host_unchecked( + Scheme::HTTP, + FastStr::from(host.into()), + consts::HTTP_DEFAULT_PORT, + ) + }; self } - /// Set port for the target address of this request. - pub fn with_port(mut self, port: u16) -> Self { - self.target.set_port(port); + /// Set scheme for target of the request. + pub fn with_scheme(mut self, scheme: Scheme) -> Self { + if self.status.is_err() { + return self; + } + if let Err(err) = self.target.set_scheme(scheme) { + self.status = Err(err); + } self } - /// Set if the request uses https. - #[cfg(feature = "__tls")] - pub fn with_https(mut self, https: bool) -> Self { - self.target.set_https(https); + /// Set port for target address of this request. + pub fn with_port(mut self, port: u16) -> Self { + if self.status.is_err() { + return self; + } + if let Err(err) = self.target.set_port(port) { + self.status = Err(err); + } self } @@ -385,53 +346,51 @@ impl RequestBuilder { &mut self.target } - /// Get a reference to [`CallOpt`]. - pub fn callopt_ref(&self) -> &Option { - &self.call_opt - } - - /// Get a mutable reference to [`CallOpt`]. - pub fn callopt_mut(&mut self) -> &mut Option { - &mut self.call_opt - } - /// Set a request body. pub fn body(self, body: B2) -> RequestBuilder { - let request = match self.request { - Ok(req) => { - let (parts, _) = req.into_parts(); - Ok(Request::from_parts(parts, body)) - } - Err(err) => Err(err), - }; + let (parts, _) = self.request.into_parts(); + let request = Request::from_parts(parts, body); RequestBuilder { - client: self.client, + inner: self.inner, target: self.target, - call_opt: self.call_opt, request, + status: self.status, } } /// Get a reference to body in the request. - pub fn body_ref(&self) -> Option<&B> { - self.request.as_ref().ok().map(Request::body) + pub fn body_ref(&self) -> &B { + self.request.body() + } + + /// Apply a [`CallOpt`] to the request. + pub fn with_callopt(self, callopt: CallOpt) -> RequestBuilder, B> { + RequestBuilder { + inner: WithOptService::new(self.inner, callopt), + target: self.target, + request: self.request, + status: self.status, + } } -} -impl RequestBuilder -where - S: Service, Response = ClientResponse, Error = ClientError> - + Send - + Sync - + 'static, - B: Send + 'static, -{ /// Send the request and get the response. - pub async fn send(self) -> Result { - self.client - .send_request(self.target, self.call_opt, self.request?) - .await + pub async fn send(self) -> Result + where + S: OneShotService< + ClientContext, + ClientRequest, + Response = ClientResponse, + Error = ClientError, + > + Send + + Sync + + 'static, + B: Send + 'static, + { + self.status?; + let mut cx = ClientContext::new(); + self.target.apply(&mut cx)?; + self.inner.call(&mut cx, self.request).await } } @@ -443,8 +402,7 @@ mod request_tests { use serde::Deserialize; - use super::Client; - use crate::body::BodyConversion; + use crate::{body::BodyConversion, client::Client}; #[allow(dead_code)] #[derive(Deserialize)] @@ -471,7 +429,7 @@ mod request_tests { async fn set_query() { let data = test_data(); - let client = Client::builder().build(); + let client = Client::builder().build().unwrap(); let resp = client .get("http://httpbin.org/get") .set_query(&data) @@ -489,7 +447,7 @@ mod request_tests { async fn set_form() { let data = test_data(); - let client = Client::builder().build(); + let client = Client::builder().build().unwrap(); let resp = client .post("http://httpbin.org/post") .form(&data) @@ -506,7 +464,7 @@ mod request_tests { async fn set_json() { let data = test_data(); - let client = Client::builder().build(); + let client = Client::builder().build().unwrap(); let resp = client .post("http://httpbin.org/post") .json(&data) @@ -519,3 +477,85 @@ mod request_tests { assert_eq!(resp.json, Some(data)); } } + +#[cfg(test)] +mod with_callopt_tests { + use std::{future::Future, time::Duration}; + + use http::status::StatusCode; + use motore::service::Service; + use volo::context::Context; + + use crate::{ + body::{Body, BodyConversion}, + client::{layer::FailOnStatus, test_helpers::MockTransport, CallOpt, Client}, + context::client::Config, + error::ClientError, + response::ClientResponse, + }; + + struct GetTimeoutAsSeconds; + + impl Service for GetTimeoutAsSeconds + where + Cx: Context, + { + type Response = ClientResponse; + type Error = ClientError; + + fn call( + &self, + cx: &mut Cx, + _: Req, + ) -> impl Future> + Send { + let timeout = cx.rpc_info().config().timeout(); + let resp = match timeout { + Some(timeout) => { + let secs = timeout.as_secs(); + ClientResponse::new(Body::from(format!("{secs}"))) + } + None => { + let mut resp = ClientResponse::new(Body::empty()); + *resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + resp + } + }; + async { Ok(resp) } + } + } + + #[tokio::test] + async fn callopt_test() { + let mut builder = Client::builder(); + builder.set_request_timeout(Duration::from_secs(1)); + let client = builder + .layer_outer_front(FailOnStatus::server_error()) + .mock(MockTransport::service(GetTimeoutAsSeconds)) + .unwrap(); + // default timeout is 1 seconds + assert_eq!( + client + .get("/") + .send() + .await + .unwrap() + .into_string() + .await + .unwrap(), + "1" + ); + // callopt set timeout to 5 seconds + assert_eq!( + client + .get("/") + .with_callopt(CallOpt::new().with_timeout(Duration::from_secs(5))) + .send() + .await + .unwrap() + .into_string() + .await + .unwrap(), + "5" + ); + } +} diff --git a/volo-http/src/client/target.rs b/volo-http/src/client/target.rs index 7e5b67f5..8765d5bd 100644 --- a/volo-http/src/client/target.rs +++ b/volo-http/src/client/target.rs @@ -2,25 +2,25 @@ //! //! See [`Target`], [`RemoteTarget`] for more details. -use std::net::{IpAddr, SocketAddr}; +use std::{ + borrow::Cow, + net::{IpAddr, SocketAddr}, +}; use faststr::FastStr; -use http::{uri::Scheme, HeaderValue, Uri}; -use volo::{context::Endpoint, net::Address}; +use http::uri::{Scheme, Uri}; +use volo::{client::Apply, context::Context, net::Address}; +use super::dns::Port; use crate::{ - client::callopt::CallOpt, - error::{client::bad_scheme, ClientError}, + context::ClientContext, + error::{ + client::{bad_address, bad_scheme, no_address, Result}, + ClientError, + }, utils::consts, }; -/// Function for parsing [`Target`] and [`CallOpt`] to update [`Endpoint`]. -/// -/// The `TargetParser` usually used for service discover. It can update [`Endpoint` ]from -/// [`Target`] and [`CallOpt`], and the service discover will resolve the [`Endpoint`] to -/// [`Address`]\(es\) and access them. -pub type TargetParser = fn(Target, Option<&CallOpt>, &mut Endpoint); - /// HTTP target server descriptor #[derive(Clone, Debug, Default)] pub enum Target { @@ -37,13 +37,12 @@ pub enum Target { /// Remote part of [`Target`] #[derive(Clone, Debug)] pub struct RemoteTarget { + /// Target scheme + pub scheme: Scheme, /// The target address pub addr: RemoteTargetAddress, - /// Target port, its default value depends on scheme - pub port: Option, - /// Use https for transporting - #[cfg(feature = "__tls")] - pub https: bool, + /// Target port + pub port: u16, } /// Remote address of [`RemoteTarget`] @@ -55,68 +54,170 @@ pub enum RemoteTargetAddress { Name(FastStr), } -impl Target { - /// Build a `Target` from `Uri`. - /// - /// If there is no host, `None` will be returned. If there is a host, but the uri has something - /// invalid (e.g., unsupported scheme), an error will be returned. - pub fn from_uri(uri: &Uri) -> Option> { - let host = uri.host()?; - let Some(https) = is_https(uri) else { - tracing::error!("[Volo-HTTP] unsupported scheme: {:?}.", uri.scheme()); - return Some(Err(bad_scheme())); - }; +fn check_scheme(scheme: &Scheme) -> Result<()> { + if scheme == &Scheme::HTTPS { #[cfg(not(feature = "__tls"))] - if https { - tracing::error!("[Volo-HTTP] https is not allowed when feature `tls` is not enabled."); - return Some(Err(bad_scheme())); + { + tracing::error!("[Volo-HTTP] https is not allowed when feature `tls` is not enabled"); + return Err(bad_scheme()); } + #[cfg(feature = "__tls")] + return Ok(()); + } + if scheme == &Scheme::HTTP { + return Ok(()); + } + tracing::error!("[Volo-HTTP] scheme '{scheme}' is unsupported"); + Err(bad_scheme()) +} - let addr = match host - .trim_start_matches('[') - .trim_end_matches(']') - .parse::() - { - Ok(ip) => RemoteTargetAddress::Ip(ip), - Err(_) => RemoteTargetAddress::Name(FastStr::from_string(host.to_owned())), - }; - let port = uri.port_u16(); - Some(Ok(Self::Remote(RemoteTarget { - addr, - port, - #[cfg(feature = "__tls")] - https, - }))) +fn get_default_port(scheme: &Scheme) -> u16 { + #[cfg(feature = "__tls")] + if scheme == &Scheme::HTTPS { + return consts::HTTPS_DEFAULT_PORT; + } + if scheme == &Scheme::HTTP { + return consts::HTTP_DEFAULT_PORT; } + unreachable!("[Volo-HTTP] https is not allowed when feature `tls` is not enabled") +} - /// Build a `Target` from an address. - pub fn from_address(addr: A) -> Self - where - A: Into
      , - { - Self::from(addr.into()) +pub(super) fn is_default_port(scheme: &Scheme, port: u16) -> bool { + get_default_port(scheme) == port +} + +impl Target { + /// Create a [`Target`] by a scheme, host and port without checking scheme + /// + /// # Safety + /// + /// Users must ensure that the scheme is valid. + /// + /// - HTTP is always valid + /// - HTTPS is valid if any feature of tls is enabled + /// - Other schemes are always invalid + pub const unsafe fn new_host_unchecked(scheme: Scheme, host: FastStr, port: u16) -> Self { + Self::Remote(RemoteTarget { + scheme, + addr: RemoteTargetAddress::Name(host), + port, + }) } - /// Build a `Target` from a host name. + /// Create a [`Target`] by a scheme, ip address and port without checking scheme /// - /// Note that the `host` must be a host name, it will be used for service discover. + /// # Safety /// - /// It should NOT be an address or something with port. + /// Users must ensure that the scheme is valid. /// - /// If you have a uri and you are not sure if the host is a host, try `from_uri`. + /// - HTTP is always valid + /// - HTTPS is valid if any feature of tls is enabled + /// - Other schemes are always invalid + pub const unsafe fn new_addr_unchecked(scheme: Scheme, ip: IpAddr, port: u16) -> Self { + Self::Remote(RemoteTarget { + scheme, + addr: RemoteTargetAddress::Ip(ip), + port, + }) + } + + /// Create a [`Target`] through a scheme, host name and a port + pub fn new_host(scheme: Option, host: S, port: Option) -> Result + where + S: Into>, + { + let scheme = scheme.unwrap_or(Scheme::HTTP); + check_scheme(&scheme)?; + let host = FastStr::from(host.into()); + let port = match port { + Some(p) => p, + None => get_default_port(&scheme), + }; + // SAFETY: we've checked scheme + Ok(unsafe { Self::new_host_unchecked(scheme, host, port) }) + } + + /// Create a [`Target`] through a scheme, ip address and a port + pub fn new_addr(scheme: Option, ip: IpAddr, port: Option) -> Result { + let scheme = scheme.unwrap_or(Scheme::HTTP); + check_scheme(&scheme)?; + let port = match port { + Some(p) => p, + None => get_default_port(&scheme), + }; + // SAFETY: we've checked scheme + Ok(unsafe { Self::new_addr_unchecked(scheme, ip, port) }) + } + + /// Create a [`Target`] through a host name pub fn from_host(host: S) -> Self where - S: AsRef, + S: Into>, { - Self::Remote(RemoteTarget { - addr: RemoteTargetAddress::Name(FastStr::from_string(host.as_ref().to_owned())), - port: None, - #[cfg(feature = "__tls")] - https: false, + let host = FastStr::from(host.into()); + // SAFETY: HTTP is always valid + unsafe { Self::new_host_unchecked(Scheme::HTTP, host, consts::HTTP_DEFAULT_PORT) } + } + + /// Create a [`Target`] from [`Uri`] + pub fn from_uri(uri: &Uri) -> Result { + let scheme = uri.scheme().cloned().unwrap_or(Scheme::HTTP); + check_scheme(&scheme)?; + let host = uri.host().ok_or_else(no_address)?; + let port = match uri.port_u16() { + Some(p) => p, + None => get_default_port(&scheme), + }; + + // SAFETY: we've checked scheme + Ok(unsafe { + match host + .trim_start_matches('[') + .trim_end_matches(']') + .parse::() + { + Ok(ip) => Self::new_addr_unchecked(scheme, ip, port), + Err(_) => { + Self::new_host_unchecked(scheme, FastStr::from_string(host.to_owned()), port) + } + } }) } - /// Return if the `Target` is `None`. + /// Set a new scheme to the [`Target`] + /// + /// Note that if the previous is default port of the previous scheme, the port will be also + /// updated to default port of the new scheme. + pub fn set_scheme(&mut self, scheme: Scheme) -> Result<()> { + let rt = match self.remote_mut() { + Some(rt) => rt, + None => { + tracing::warn!("[Volo-HTTP] set scheme to an empty target or uds is invalid"); + return Err(bad_address()); + } + }; + check_scheme(&scheme)?; + if is_default_port(&rt.scheme, rt.port) { + rt.port = get_default_port(&scheme); + } + rt.scheme = scheme; + Ok(()) + } + + /// Set a new port to the [`Target`] + pub fn set_port(&mut self, port: u16) -> Result<()> { + let rt = match self.remote_mut() { + Some(rt) => rt, + None => { + tracing::warn!("[Volo-HTTP] set port to an empty target or uds is invalid"); + return Err(bad_address()); + } + }; + rt.port = port; + Ok(()) + } + + /// Return if the [`Target`] is [`Target::None`] pub fn is_none(&self) -> bool { matches!(self, Target::None) } @@ -137,39 +238,6 @@ impl Target { } } - /// Set remote port and return a new target. - pub fn set_port(&mut self, port: u16) { - if let Some(rt) = self.remote_mut() { - rt.port = Some(port); - } - } - - /// Set if use https for the target. - /// - /// If the [`Target`] cannot use https ([`Target::None`] or [`Target::Local`]), this function - /// will do nothing. - #[cfg(feature = "__tls")] - pub fn set_https(&mut self, https: bool) { - if let Some(rt) = self.remote_mut() { - rt.set_https(https); - } - } - - /// Check if the target uses https. - /// - /// If the [`Target`] cannot use https ([`Target::None`] or [`Target::Local`]), this function - /// will return `false`. - pub fn is_https(&self) -> bool { - #[cfg(feature = "__tls")] - if let Some(rt) = self.remote_ref() { - rt.is_https() - } else { - false - } - #[cfg(not(feature = "__tls"))] - false - } - /// Return the remote [`IpAddr`] if the [`Target`] is an IP address. pub fn remote_ip(&self) -> Option<&IpAddr> { if let Self::Remote(rt) = &self { @@ -200,60 +268,21 @@ impl Target { } } - /// Return the port if the [`Target`] is a remote address and the port is given. - pub fn port(&self) -> Option { - if let Some(remote) = self.remote_ref() { - remote.port + /// Return target scheme if the [`Target`] is a remote address + pub fn scheme(&self) -> Option<&Scheme> { + if let Self::Remote(rt) = self { + Some(&rt.scheme) } else { None } } - /// Generate a `HeaderValue` for `Host` in HTTP headers. - pub fn gen_host(&self) -> Option { - HeaderValue::try_from(self.try_to_string()?).ok() - } - - fn is_default_port(&self) -> bool { - let Some(rt) = self.remote_ref() else { - // Local address does not have port. - return false; - }; - let Some(port) = rt.port else { - // `None` means using default port. - return true; - }; - #[cfg(feature = "__tls")] - if rt.https { - return port == consts::HTTPS_DEFAULT_PORT; - } - port == consts::HTTP_DEFAULT_PORT - } - - fn try_to_string(&self) -> Option { - let rt = self.remote_ref()?; - let without_port = self.is_default_port(); - match rt.addr { - RemoteTargetAddress::Ip(ref ip) => { - if without_port { - return Some(ip.to_string()); - } - // SAFETY: the port must exist if the port is non-default one - let port = rt.port.unwrap(); - if ip.is_ipv6() { - Some(format!("[{ip}]:{port}")) - } else { - Some(format!("{ip}:{port}")) - } - } - RemoteTargetAddress::Name(ref name) => { - if without_port { - return Some(name.to_string()); - } - // SAFETY: the port must exist if the port is non-default one - let port = rt.port.unwrap(); - Some(format!("{name}:{port}")) - } + /// Return target port if the [`Target`] is a remote address + pub fn port(&self) -> Option { + if let Self::Remote(rt) = self { + Some(rt.port) + } else { + None } } } @@ -261,444 +290,56 @@ impl Target { impl From
      for Target { fn from(value: Address) -> Self { match value { - Address::Ip(sa) => Target::Remote(RemoteTarget { - addr: RemoteTargetAddress::Ip(sa.ip()), - port: Some(sa.port()), - #[cfg(feature = "__tls")] - https: false, - }), + Address::Ip(sa) => { + // SAFETY: HTTP is always valid + unsafe { Target::new_addr_unchecked(Scheme::HTTP, sa.ip(), sa.port()) } + } #[cfg(target_family = "unix")] Address::Unix(uds) => Target::Local(uds), } } } -impl TryFrom for Address { - type Error = Target; +impl Apply for Target { + type Error = ClientError; - fn try_from(value: Target) -> Result { - match value { - Target::None => Err(value), - #[cfg(target_family = "unix")] - Target::Local(sa) => Ok(Address::Unix(sa)), - Target::Remote(rt) => { - let port = rt.port(); - if let RemoteTargetAddress::Ip(ip) = rt.addr { - Ok(Address::Ip(SocketAddr::new(ip, port))) - } else { - Err(Target::Remote(rt)) - } - } + fn apply(self, cx: &mut ClientContext) -> Result<(), Self::Error> { + if self.is_none() { + return Ok(()); } - } -} -impl RemoteTarget { - /// Get the target port for the [`RemoteTarget`]. - /// - /// If the port has not been set, it will return a default port based on if https is enabled. - pub fn port(&self) -> u16 { - if let Some(port) = self.port { - return port; + let callee = cx.rpc_info_mut().callee_mut(); + if !(callee.service_name_ref().is_empty() && callee.address.is_none()) { + // Target exists in context + return Ok(()); } - #[cfg(feature = "__tls")] - if self.https { - return consts::HTTPS_DEFAULT_PORT; - } - consts::HTTP_DEFAULT_PORT - } - - /// Set if use https for the target. - #[cfg(feature = "__tls")] - pub fn set_https(&mut self, https: bool) { - self.https = https; - } - - /// Check if the target uses https. - #[cfg(feature = "__tls")] - pub fn is_https(&self) -> bool { - self.https - } -} - -fn is_https(uri: &Uri) -> Option { - let Some(scheme) = uri.scheme() else { - return Some(false); - }; - if scheme == &Scheme::HTTPS { - return Some(true); - } - if scheme == &Scheme::HTTP { - return Some(false); - } - None -} - -#[cfg(test)] -mod target_tests { - use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; - - use http::uri::Uri; - use volo::net::Address; - - use super::Target; - - #[test] - fn test_from_uri() { - // no domain name - let target = Target::from_uri(&Uri::from_static("/api/v1/config")); - assert!(target.is_none()); - - // invalid scheme - let target = Target::from_uri(&Uri::from_static("ftp://github.com")); - assert!(matches!(target, Some(Err(_)))); - - // ipv4 only - let target = Target::from_uri(&Uri::from_static("10.0.0.1")); - assert!(matches!(target, Some(Ok(_)))); - let target = target.unwrap().unwrap(); - assert_eq!( - target.remote_ip().unwrap().to_string(), - IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)).to_string(), - ); - assert_eq!(target.port(), None); - assert!(!target.is_https()); - - // ipv4 with port - let target = Target::from_uri(&Uri::from_static("10.0.0.1:8000")); - assert!(matches!(target, Some(Ok(_)))); - let target = target.unwrap().unwrap(); - assert_eq!( - target.remote_ip().unwrap().to_string(), - IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)).to_string(), - ); - assert_eq!(target.port(), Some(8000)); - assert!(!target.is_https()); - - // ipv6 with port - let target = Target::from_uri(&Uri::from_static("[ff::1]:8000")); - assert!(matches!(target, Some(Ok(_)))); - let target = target.unwrap().unwrap(); - assert_eq!( - target.remote_ip().unwrap().to_string(), - IpAddr::V6(Ipv6Addr::new(0xff, 0, 0, 0, 0, 0, 0, 1)).to_string(), - ); - assert_eq!(target.port(), Some(8000)); - assert!(!target.is_https()); - - // domain name only - let target = Target::from_uri(&Uri::from_static("github.com")); - assert!(matches!(target, Some(Ok(_)))); - let target = target.unwrap().unwrap(); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert_eq!(target.port(), None); - assert!(!target.is_https()); - - // domain with scheme (http) - let target = Target::from_uri(&Uri::from_static("http://github.com/")); - assert!(matches!(target, Some(Ok(_)))); - let target = target.unwrap().unwrap(); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert_eq!(target.port(), None); - assert!(!target.is_https()); - - // domain with port - let target = Target::from_uri(&Uri::from_static("github.com:8000")); - assert!(matches!(target, Some(Ok(_)))); - let target = target.unwrap().unwrap(); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert_eq!(target.port(), Some(8000)); - assert!(!target.is_https()); - - // domain with scheme (http) and port - let target = Target::from_uri(&Uri::from_static("http://github.com:8000/")); - assert!(matches!(target, Some(Ok(_)))); - let target = target.unwrap().unwrap(); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert_eq!(target.port(), Some(8000)); - assert!(!target.is_https()); - } - - #[cfg(not(feature = "__tls"))] - #[test] - fn test_from_uri_without_tls() { - // domain with scheme (https) - - use crate::error::client::bad_scheme; - let target = Target::from_uri(&Uri::from_static("https://github.com")); - assert!(matches!(target, Some(Err(_)))); - assert_eq!( - format!("{}", target.unwrap().unwrap_err()), - format!("{}", bad_scheme()), - ); - - // domain with scheme (https) and port - let target = Target::from_uri(&Uri::from_static("https://github.com:8000/")); - assert!(matches!(target, Some(Err(_)))); - assert_eq!( - format!("{}", target.unwrap().unwrap_err()), - format!("{}", bad_scheme()), - ); - } - #[cfg(feature = "__tls")] - #[test] - fn test_from_uri_with_tls() { - // domain with scheme (https) - let target = Target::from_uri(&Uri::from_static("https://github.com")); - assert!(matches!(target, Some(Ok(_)))); - let target = target.unwrap().unwrap(); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert_eq!(target.port(), None); - assert!(target.is_https()); - - // domain with scheme (https) and port - let target = Target::from_uri(&Uri::from_static("https://github.com:8000/")); - assert!(matches!(target, Some(Ok(_)))); - let target = target.unwrap().unwrap(); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert_eq!(target.port(), Some(8000)); - assert!(target.is_https()); - } - - #[test] - fn test_from_ip_address() { - // IPv4 - let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); - let port = 8000; - let addr = Address::Ip(SocketAddr::new(ip, port)); - let target = Target::from_address(addr); - assert_eq!(target.remote_ip(), Some(&ip)); - assert_eq!(target.port(), Some(port)); - assert!(!target.is_https()); - - // IPv6 - let ip = IpAddr::V6(Ipv6Addr::new(0xff, 0, 0, 0, 0, 0, 0, 0)); - let port = 8000; - let addr = Address::Ip(SocketAddr::new(ip, port)); - let target = Target::from_address(addr); - assert_eq!(target.remote_ip(), Some(&ip)); - assert_eq!(target.port(), Some(port)); - assert!(!target.is_https()); - } - - #[cfg(target_family = "unix")] - #[test] - fn test_from_uds_address() { - #[derive(Debug, PartialEq, Eq)] - struct SocketAddr { - addr: libc::sockaddr_un, - len: libc::socklen_t, - } - - let uds = std::os::unix::net::SocketAddr::from_pathname("/tmp/test.sock").unwrap(); - let addr = Address::Unix(uds.clone()); - let target = Target::from_address(addr); - - // Use a same struct with `PartialEq` and `Eq` and transmute them for comparing. - let uds: SocketAddr = unsafe { std::mem::transmute(uds) }; - let target_uds: SocketAddr = - unsafe { std::mem::transmute(target.unix_socket_addr().unwrap().to_owned()) }; - assert_eq!(target_uds, uds); - assert!(target.port().is_none()); - assert!(!target.is_https()); - } - - #[test] - fn test_from_host() { - let target = Target::from_host("github.com"); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert!(target.port().is_none()); - assert!(!target.is_https()); - } - - #[test] - fn test_uri_with_port() { - // domain name only - let target = Target::from_uri(&Uri::from_static("github.com")); - assert!(matches!(target, Some(Ok(_)))); - let mut target = target.unwrap().unwrap(); - target.set_port(8000); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert_eq!(target.port(), Some(8000)); - - // domain name with port and override it - let target = Target::from_uri(&Uri::from_static("github.com:80")); - assert!(matches!(target, Some(Ok(_)))); - let mut target = target.unwrap().unwrap(); - target.set_port(8000); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert_eq!(target.port(), Some(8000)); - } - - #[test] - fn test_ip_with_port() { - // IPv4 - let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); - let port = 8000; - let addr = Address::Ip(SocketAddr::new(ip, port)); - let mut target = Target::from_address(addr); - target.set_port(80); - assert_eq!(target.remote_ip(), Some(&ip)); - assert_eq!(target.port(), Some(80)); - - // IPv6 - let ip = IpAddr::V6(Ipv6Addr::new(0xff, 0, 0, 0, 0, 0, 0, 1)); - let port = 8000; - let addr = Address::Ip(SocketAddr::new(ip, port)); - let mut target = Target::from_address(addr); - target.set_port(80); - assert_eq!(target.remote_ip(), Some(&ip)); - assert_eq!(target.port(), Some(80)); - } - - #[cfg(target_family = "unix")] - #[test] - fn test_uds_with_port() { - let uds = std::os::unix::net::SocketAddr::from_pathname("/tmp/test.sock").unwrap(); - let addr = Address::Unix(uds.clone()); - let mut target = Target::from_address(addr); - assert!(target.port().is_none()); - target.set_port(80); - // uds does not have port - assert!(target.port().is_none()); - } - - #[test] - fn test_host_with_port() { - let mut target = Target::from_host("github.com"); - let port = 8000; - target.set_port(port); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert_eq!(target.port(), Some(port)); - } - - #[cfg(feature = "__tls")] - #[test] - fn test_uri_with_https() { - // domain name only - let target = Target::from_uri(&Uri::from_static("github.com")); - assert!(matches!(target, Some(Ok(_)))); - let mut target = target.unwrap().unwrap(); - target.set_https(true); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert!(target.is_https()); - - // domain name with http and override it - let target = Target::from_uri(&Uri::from_static("http://github.com")); - assert!(matches!(target, Some(Ok(_)))); - let mut target = target.unwrap().unwrap(); - target.set_https(true); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert!(target.is_https()); - } - - #[cfg(feature = "__tls")] - #[test] - fn test_ip_with_https() { - // IPv4 - let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); - let port = 8000; - let addr = Address::Ip(SocketAddr::new(ip, port)); - let mut target = Target::from_address(addr); - target.set_https(true); - assert_eq!(target.remote_ip(), Some(&ip)); - assert!(target.is_https()); - - // IPv6 - let ip = IpAddr::V6(Ipv6Addr::new(0xff, 0, 0, 0, 0, 0, 0, 0)); - let port = 8000; - let addr = Address::Ip(SocketAddr::new(ip, port)); - let mut target = Target::from_address(addr); - target.set_https(true); - assert_eq!(target.remote_ip(), Some(&ip)); - assert!(target.is_https()); - } - - #[cfg(all(feature = "__tls", target_family = "unix"))] - #[test] - fn test_uds_with_https() { - let uds = std::os::unix::net::SocketAddr::from_pathname("/tmp/test.sock").unwrap(); - let addr = Address::Unix(uds.clone()); - let mut target = Target::from_address(addr); - assert!(target.port().is_none()); - target.set_https(true); - // uds does not have port - assert!(!target.is_https()); - } - - #[cfg(feature = "__tls")] - #[test] - fn test_host_with_https() { - let mut target = Target::from_host("github.com"); - target.set_https(true); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert!(target.is_https()); - } - - #[cfg(feature = "__tls")] - #[test] - fn test_gen_host() { - fn gen_host_to_string(target: &Target) -> Option { - let host = target.gen_host()?; - Some(host.to_str().map(ToOwned::to_owned).unwrap_or_default()) + match self { + Self::Remote(rt) => { + callee.insert(rt.scheme); + match rt.addr { + RemoteTargetAddress::Ip(ip) => { + let sa = SocketAddr::new(ip, rt.port); + tracing::trace!("[Volo-HTTP] Target::apply: set target to {sa}"); + callee.set_address(Address::Ip(sa)); + } + RemoteTargetAddress::Name(host) => { + let port = rt.port; + tracing::trace!("[Volo-HTTP] Target::apply: set target to {host}:{port}"); + callee.set_service_name(host); + callee.insert(Port(port)); + } + } + } + #[cfg(target_family = "unix")] + Self::Local(uds) => { + callee.set_address(Address::Unix(uds)); + } + Self::None => { + unreachable!() + } } - // ipv4 with default http port - let target = Target::from_address(Address::Ip(SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - 80, - ))); - assert_eq!(gen_host_to_string(&target).as_deref(), Some("127.0.0.1")); - // ipv4 with non-default http port - let target = Target::from_address(Address::Ip(SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - 443, - ))); - assert_eq!( - gen_host_to_string(&target).as_deref(), - Some("127.0.0.1:443") - ); - // ipv4 with default https port - let mut target = Target::from_address(Address::Ip(SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - 443, - ))); - target.set_https(true); - assert_eq!(gen_host_to_string(&target).as_deref(), Some("127.0.0.1")); - // ipv4 with non-default https port - let mut target = Target::from_address(Address::Ip(SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - 80, - ))); - target.set_https(true); - assert_eq!(gen_host_to_string(&target).as_deref(), Some("127.0.0.1:80")); - - // ipv6 with default http port - let target = Target::from_address(Address::Ip(SocketAddr::new( - IpAddr::V6(Ipv6Addr::new(0xff, 0, 0, 0, 0, 0, 0, 1)), - 80, - ))); - assert_eq!(gen_host_to_string(&target).as_deref(), Some("ff::1")); - // ipv6 with non-default http port - let target = Target::from_address(Address::Ip(SocketAddr::new( - IpAddr::V6(Ipv6Addr::new(0xff, 0, 0, 0, 0, 0, 0, 1)), - 443, - ))); - assert_eq!(gen_host_to_string(&target).as_deref(), Some("[ff::1]:443")); - // ipv6 with default https port - let mut target = Target::from_address(Address::Ip(SocketAddr::new( - IpAddr::V6(Ipv6Addr::new(0xff, 0, 0, 0, 0, 0, 0, 1)), - 443, - ))); - target.set_https(true); - assert_eq!(gen_host_to_string(&target).as_deref(), Some("ff::1")); - // ipv6 with non-default https port - let mut target = Target::from_address(Address::Ip(SocketAddr::new( - IpAddr::V6(Ipv6Addr::new(0xff, 0, 0, 0, 0, 0, 0, 1)), - 80, - ))); - target.set_https(true); - assert_eq!(gen_host_to_string(&target).as_deref(), Some("[ff::1]:80")); + Ok(()) } } diff --git a/volo-http/src/client/test_helpers.rs b/volo-http/src/client/test_helpers.rs index 74e19856..b4ec7710 100644 --- a/volo-http/src/client/test_helpers.rs +++ b/volo-http/src/client/test_helpers.rs @@ -3,17 +3,20 @@ use std::sync::Arc; use faststr::FastStr; -use http::{header, header::HeaderValue, status::StatusCode}; +use http::status::StatusCode; use motore::{ layer::{Identity, Layer}, service::{BoxService, Service}, }; -use volo::{client::MkClient, context::Endpoint}; +use volo::client::MkClient; -use super::{callopt::CallOpt, Client, ClientBuilder, ClientInner, Target, PKG_NAME_WITH_VER}; +use super::{Client, ClientBuilder, ClientInner, Target}; use crate::{ - context::client::ClientContext, error::ClientError, request::ClientRequest, - response::ClientResponse, utils::test_helpers::mock_address, + context::client::ClientContext, + error::client::{ClientError, Result}, + request::ClientRequest, + response::ClientResponse, + utils::test_helpers::mock_address, }; /// Default mock service of [`Client`] @@ -110,7 +113,7 @@ impl Service for MockTransport { impl ClientBuilder { /// Build a mock HTTP client with a [`MockTransport`] service. - pub fn mock(mut self, transport: MockTransport) -> C::Target + pub fn mock(self, transport: MockTransport) -> Result where IL: Layer, IL::Service: Send + Sync + 'static, @@ -119,42 +122,26 @@ impl ClientBuilder { OL::Service: Send + Sync + 'static, C: MkClient>, { + self.status?; + let meta_service = transport; let service = self.outer_layer.layer(self.inner_layer.layer(meta_service)); - let caller_name = if self.caller_name.is_empty() { - FastStr::from_static_str(PKG_NAME_WITH_VER) - } else { - self.caller_name - }; - if !caller_name.is_empty() && self.headers.get(header::USER_AGENT).is_none() { - self.headers.insert( - header::USER_AGENT, - HeaderValue::from_str(caller_name.as_str()).expect("Invalid caller name"), - ); - } - let client_inner = ClientInner { service, - caller_name, - callee_name: self.callee_name, // set a default target so that we can create a request without authority - default_target: Target::from_address(mock_address()), - default_call_opt: self.call_opt, - // do nothing - target_parser: parse_target, + target: Target::from(mock_address()), + timeout: self.timeout, + default_callee_name: FastStr::empty(), headers: self.headers, }; let client = Client { inner: Arc::new(client_inner), }; - self.mk_client.mk_client(client) + Ok(self.mk_client.mk_client(client)) } } -// do nothing -fn parse_target(_: Target, _: Option<&CallOpt>, _: &mut Endpoint) {} - #[allow(unused)] fn client_types_check() { struct TestLayer; @@ -186,23 +173,28 @@ fn client_types_check() { } } - let _: DefaultMockClient = ClientBuilder::new().mock(Default::default()); + let _: DefaultMockClient = ClientBuilder::new().mock(Default::default()).unwrap(); let _: DefaultMockClient = ClientBuilder::new() .layer_inner(TestLayer) - .mock(Default::default()); + .mock(Default::default()) + .unwrap(); let _: DefaultMockClient = ClientBuilder::new() .layer_inner_front(TestLayer) - .mock(Default::default()); + .mock(Default::default()) + .unwrap(); let _: DefaultMockClient = ClientBuilder::new() .layer_outer(TestLayer) - .mock(Default::default()); + .mock(Default::default()) + .unwrap(); let _: DefaultMockClient = ClientBuilder::new() .layer_outer_front(TestLayer) - .mock(Default::default()); + .mock(Default::default()) + .unwrap(); let _: DefaultMockClient = ClientBuilder::new() .layer_inner(TestLayer) .layer_outer(TestLayer) - .mock(Default::default()); + .mock(Default::default()) + .unwrap(); } mod mock_transport_tests { @@ -213,7 +205,7 @@ mod mock_transport_tests { #[tokio::test] async fn empty_response_test() { - let client = ClientBuilder::new().mock(MockTransport::default()); + let client = ClientBuilder::new().mock(MockTransport::default()).unwrap(); let resp = client.get("/").send().await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert!(resp.headers().is_empty()); diff --git a/volo-http/src/client/transport.rs b/volo-http/src/client/transport.rs index a111ba58..ab36e172 100644 --- a/volo-http/src/client/transport.rs +++ b/volo-http/src/client/transport.rs @@ -17,15 +17,6 @@ use crate::{ response::ClientResponse, }; -/// TLS transport tag with no content -/// -/// When implementing a service discover and TLS should be enable, just inserting it to the callee. -/// -/// The struct is used for advanced users. -#[cfg(feature = "__tls")] -#[cfg_attr(docsrs, doc(cfg(any(feature = "rustls", feature = "native-tls"))))] -pub struct TlsTransport; - #[derive(Clone)] pub struct ClientTransport { client: http1::Builder, @@ -60,18 +51,23 @@ impl ClientTransport { } async fn connect_to(&self, address: Address) -> Result { - self.mk_conn.make_connection(address).await.map_err(|err| { - tracing::error!("[Volo-HTTP] failed to make connection, error: {err}"); - request_error(err) - }) + self.mk_conn + .make_connection(address.clone()) + .await + .map_err(|err| { + tracing::error!("[Volo-HTTP] failed to make connection, error: {err}"); + request_error(err).with_address(address) + }) } #[cfg(feature = "__tls")] async fn make_connection(&self, cx: &ClientContext) -> Result { + use http::uri::Scheme; + use crate::error::client::bad_scheme; let callee = cx.rpc_info().callee(); - let https = callee.contains::(); + let https = callee.get::() == Some(&Scheme::HTTPS); if self.config.disable_tls && https { // TLS is disabled but the request still use TLS @@ -124,12 +120,12 @@ impl ClientTransport { let io = TokioIo::new(conn); let (mut sender, conn) = self.client.handshake(io).await.map_err(|err| { tracing::error!("[Volo-HTTP] failed to handshake, error: {err}"); - request_error(err) + request_error(err).with_endpoint(cx.rpc_info().callee()) })?; tokio::spawn(conn); let resp = sender.send_request(req).await.map_err(|err| { tracing::error!("[Volo-HTTP] failed to send request, error: {err}"); - request_error(err) + request_error(err).with_endpoint(cx.rpc_info().callee()) })?; Ok(resp.map(crate::body::Body::from_incoming)) } diff --git a/volo-http/src/context/client.rs b/volo-http/src/context/client.rs index 4a095b58..043b8526 100644 --- a/volo-http/src/context/client.rs +++ b/volo-http/src/context/client.rs @@ -1,5 +1,7 @@ //! Context and its utilities of client +use std::time::Duration; + use chrono::{DateTime, Local}; use volo::{ context::{Reusable, Role, RpcCx, RpcInfo}, @@ -59,8 +61,33 @@ impl ClientStats { /// Configuration of the request #[derive(Clone, Debug, Default)] -pub struct Config; +pub struct Config { + /// Timeout of the current request + pub timeout: Option, +} + +impl Config { + /// Create a default [`Config`] + #[inline] + pub fn new() -> Self { + Default::default() + } + + /// Get current timeout of the request + #[inline] + pub fn timeout(&self) -> Option<&Duration> { + self.timeout.as_ref() + } + + /// Set timeout to the request + #[inline] + pub fn set_timeout(&mut self, timeout: Option) { + self.timeout = timeout; + } +} impl Reusable for Config { - fn clear(&mut self) {} + fn clear(&mut self) { + self.timeout = None; + } } diff --git a/volo-http/src/error/client.rs b/volo-http/src/error/client.rs index f8cda040..969dfac1 100644 --- a/volo-http/src/error/client.rs +++ b/volo-http/src/error/client.rs @@ -1,9 +1,10 @@ //! Generic error types for client -use std::{error::Error, fmt}; +use std::{error::Error, fmt, net::SocketAddr}; use http::uri::Uri; use paste::paste; +use volo::{context::Endpoint, net::Address}; use super::BoxError; use crate::body::BodyConvertError; @@ -16,7 +17,8 @@ pub type Result = std::result::Result; pub struct ClientError { kind: ErrorKind, source: Option, - url: Option, + uri: Option, + addr: Option, } impl ClientError { @@ -28,51 +30,95 @@ impl ClientError { Self { kind, source: error.map(Into::into), - url: None, + uri: None, + addr: None, } } - /// Set a [`Uri`] to the [`ClientError`], it can be displayed when printing - pub fn with_url(self, url: Uri) -> Self { - Self { - kind: self.kind, - source: self.source, - url: Some(url), + /// Set a [`Uri`] to the [`ClientError`]. + #[inline] + pub fn set_url(&mut self, uri: Uri) { + self.uri = Some(uri); + } + + /// Set a [`SocketAddr`] to the [`ClientError`]. + #[inline] + pub fn set_addr(&mut self, addr: SocketAddr) { + self.addr = Some(addr); + } + + /// Consume current [`ClientError`] and return a new one with given [`Uri`]. + #[inline] + pub fn with_url(mut self, uri: Uri) -> Self { + self.uri = Some(uri); + self + } + + /// Remove [`Uri`] from the [`ClientError`]. + #[inline] + pub fn without_url(mut self) -> Self { + self.uri = None; + self + } + + /// Consume current [`ClientError`] and return a new one with given [`SocketAddr`]. + #[inline] + pub fn with_addr(mut self, addr: SocketAddr) -> Self { + self.addr = Some(addr); + self + } + + /// Consume current [`ClientError`] and return a new one with [`SocketAddr`] from the + /// [`Address`] if exists. + #[inline] + pub fn with_address(mut self, address: Address) -> Self { + match address { + Address::Ip(addr) => self.addr = Some(addr), + #[cfg(target_family = "unix")] + Address::Unix(_) => {} } + self } - /// Remote the [`Uri`] from the [`ClientError`] - pub fn without_url(self) -> Self { - Self { - kind: self.kind, - source: self.source, - url: None, + /// Consume current [`ClientError`] and return a new one with [`SocketAddr`] from the + /// [`Address`] if exists. + #[inline] + pub fn with_endpoint(mut self, ep: &Endpoint) -> Self { + if let Some(Address::Ip(addr)) = &ep.address { + self.addr = Some(*addr); } + self } /// Get a reference to the [`ErrorKind`] + #[inline] pub fn kind(&self) -> &ErrorKind { &self.kind } /// Get a reference to the [`Uri`] if it exists - pub fn url(&self) -> Option<&Uri> { - self.url.as_ref() + #[inline] + pub fn uri(&self) -> Option<&Uri> { + self.uri.as_ref() } - /// Get a mutable reference to the [`Uri`] if it exists - pub fn url_mut(&mut self) -> Option<&mut Uri> { - self.url.as_mut() + /// Get a reference to the [`SocketAddr`] if it exists + #[inline] + pub fn addr(&self) -> Option<&SocketAddr> { + self.addr.as_ref() } } impl fmt::Display for ClientError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.kind)?; - if let Some(ref url) = self.url { - write!(f, "for url `{url}`")?; + if let Some(addr) = &self.addr { + write!(f, " to addr {addr}")?; + } + if let Some(uri) = &self.uri { + write!(f, " for uri `{uri}`")?; } - if let Some(ref source) = self.source { + if let Some(source) = &self.source { write!(f, ": {source}")?; } Ok(()) @@ -182,6 +228,7 @@ macro_rules! simple_error { simple_error!(Builder => NoAddress => "missing target address"); simple_error!(Builder => BadScheme => "bad scheme"); simple_error!(Builder => BadHostName => "bad host name"); +simple_error!(Builder => BadAddress => "bad address"); simple_error!(Request => Timeout => "request timeout"); simple_error!(LoadBalance => NoAvailableEndpoint => "no available endpoint"); diff --git a/volo-http/src/request.rs b/volo-http/src/request.rs index 375f9b8b..63ced196 100644 --- a/volo-http/src/request.rs +++ b/volo-http/src/request.rs @@ -1,11 +1,13 @@ //! Request types and utils. +use std::str::FromStr; + use http::{ header::{self, HeaderMap, HeaderName}, request::{Parts, Request}, - uri::Scheme, - Uri, + uri::{Scheme, Uri}, }; +use url::Url; /// [`Request`] with [`Body`] as default body. /// @@ -38,7 +40,7 @@ pub trait RequestPartsExt: sealed::SealedRequestPartsExt { mod sealed { pub trait SealedRequestPartsExt { fn headers(&self) -> &http::header::HeaderMap; - fn uri(&self) -> &http::Uri; + fn uri(&self) -> &http::uri::Uri; fn extensions(&self) -> &http::Extensions; } } @@ -54,6 +56,7 @@ impl sealed::SealedRequestPartsExt for Parts { &self.extensions } } + impl sealed::SealedRequestPartsExt for Request { fn headers(&self) -> &HeaderMap { self.headers() @@ -74,22 +77,11 @@ where simdutf8::basic::from_utf8(self.headers().get(header::HOST)?.as_bytes()).ok() } - fn url(&self) -> Option { + fn url(&self) -> Option { + let scheme = self.extensions().get::().unwrap_or(&Scheme::HTTP); let host = self.host()?; - let uri = self.uri(); - - let mut url_str = String::new(); - - if let Some(scheme) = self.extensions().get::() { - url_str.push_str(scheme.as_str()); - url_str.push_str("://"); - } else { - url_str.push_str("http://"); - } - - url_str.push_str(host); - url_str.push_str(uri.path()); + let path = self.uri().path(); - url::Url::parse(url_str.as_str()).ok() + Url::from_str(&format!("{scheme}://{host}{path}")).ok() } } diff --git a/volo-http/src/utils/test_helpers.rs b/volo-http/src/utils/test_helpers.rs index 81ce6d64..3063376b 100644 --- a/volo-http/src/utils/test_helpers.rs +++ b/volo-http/src/utils/test_helpers.rs @@ -95,10 +95,8 @@ mod convert_service { #[cfg(feature = "__tls")] fn new_server_config(client_cx: &ClientContext) -> crate::context::server::Config { let mut config = crate::context::server::Config::default(); - if client_cx - .rpc_info() - .callee() - .contains::() + if client_cx.rpc_info().callee().get::() + == Some(&http::uri::Scheme::HTTPS) { config.set_tls(true); } @@ -138,7 +136,9 @@ mod helper_tests { #[tokio::test] async fn client_call_router() { let router: Router = Router::new().route("/get", get(|| async { HELLO_WORLD })); - let client = ClientBuilder::new().mock(MockTransport::server_service(router)); + let client = ClientBuilder::new() + .mock(MockTransport::server_service(router)) + .unwrap(); { let ret = client .get("/get") diff --git a/volo/src/net/mod.rs b/volo/src/net/mod.rs index 93b85f6f..495443cf 100644 --- a/volo/src/net/mod.rs +++ b/volo/src/net/mod.rs @@ -29,6 +29,24 @@ pub enum Address { Unix(StdUnixSocketAddr), } +impl Address { + pub fn ip_addr(&self) -> Option<&SocketAddr> { + match self { + Self::Ip(ip) => Some(ip), + #[cfg(target_family = "unix")] + Self::Unix(_) => None, + } + } + + #[cfg(target_family = "unix")] + pub fn unix_addr(&self) -> Option<&StdUnixSocketAddr> { + match self { + Self::Ip(_) => None, + Self::Unix(unix) => Some(unix), + } + } +} + impl PartialEq for Address { fn eq(&self, other: &Self) -> bool { match (self, other) {