Skip to content

Commit f455436

Browse files
committed
feat(channel): Make channel feature additive
1 parent 9e9bc59 commit f455436

File tree

20 files changed

+254
-207
lines changed

20 files changed

+254
-207
lines changed

examples/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,13 @@ hyper-warp-multiplex = ["hyper-warp"]
298298
uds = ["tokio-stream/net", "dep:tower", "dep:hyper"]
299299
streaming = ["tokio-stream", "dep:h2"]
300300
mock = ["tokio-stream", "dep:tower"]
301-
tower = ["dep:hyper", "dep:tower", "dep:http"]
301+
tower = ["dep:hyper", "tower/timeout", "dep:http"]
302302
json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"]
303303
compression = ["tonic/gzip"]
304304
tls = ["tonic/tls"]
305305
tls-rustls = ["dep:hyper", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"]
306306
dynamic-load-balance = ["dep:tower"]
307-
timeout = ["tokio/time", "dep:tower"]
307+
timeout = ["tokio/time", "tower/timeout"]
308308
tls-client-auth = ["tonic/tls"]
309309
types = ["dep:tonic-types"]
310310
h2c = ["dep:hyper", "dep:tower", "dep:http"]

tonic/Cargo.toml

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,26 @@ version = "0.11.0"
2626
codegen = ["dep:async-trait"]
2727
gzip = ["dep:flate2"]
2828
zstd = ["dep:zstd"]
29-
default = ["transport", "codegen", "prost"]
29+
default = ["channel", "codegen", "prost"]
3030
prost = ["dep:prost"]
3131
tls = ["dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
3232
tls-roots = ["tls-roots-common", "dep:rustls-native-certs"]
33-
tls-roots-common = ["tls"]
33+
tls-roots-common = ["tls", "channel"]
3434
tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"]
3535
transport = [
3636
"dep:async-stream",
3737
"dep:axum",
38-
"channel",
3938
"dep:h2",
40-
"dep:hyper",
39+
"dep:hyper", "hyper?/server",
4140
"dep:tokio", "tokio?/net", "tokio?/time",
42-
"dep:tower",
41+
"dep:tower", "tower?/util", "tower?/limit",
42+
]
43+
channel = [
44+
"transport",
45+
"dep:hyper", "hyper?/client",
46+
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make",
4347
"dep:hyper-timeout",
4448
]
45-
channel = []
4649

4750
# [[bench]]
4851
# name = "bench_main"
@@ -69,13 +72,15 @@ async-trait = {version = "0.1.13", optional = true}
6972
# transport
7073
async-stream = {version = "0.3", optional = true}
7174
h2 = {version = "0.3.24", optional = true}
72-
hyper = {version = "0.14.26", features = ["full"], optional = true}
73-
hyper-timeout = {version = "0.4", optional = true}
75+
hyper = {version = "0.14.26", features = ["http1", "http2", "runtime", "stream"], optional = true}
7476
tokio = {version = "1.0.1", optional = true}
7577
tokio-stream = "0.1"
76-
tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true}
78+
tower = {version = "0.4.7", default-features = false, optional = true}
7779
axum = {version = "0.6.9", default-features = false, optional = true}
7880

81+
# channel
82+
hyper-timeout = {version = "0.4", optional = true}
83+
7984
# rustls
8085
rustls-pemfile = { version = "2.0", optional = true }
8186
rustls-native-certs = { version = "0.7", optional = true }

tonic/src/transport/channel/endpoint.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
use super::super::service;
2-
use super::Channel;
3-
#[cfg(feature = "tls")]
4-
use super::ClientTlsConfig;
5-
#[cfg(feature = "tls")]
6-
use crate::transport::service::TlsConnector;
7-
use crate::transport::{service::SharedExec, Error, Executor};
1+
use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration};
2+
83
use bytes::Bytes;
94
use http::{uri::Uri, HeaderValue};
10-
use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration};
115
use tower::make::MakeConnection;
12-
// use crate::transport::E
6+
7+
use crate::transport::Error;
8+
9+
use super::service::{Connector, Executor, SharedExec};
10+
use super::Channel;
11+
#[cfg(feature = "tls")]
12+
use super::{service::TlsConnector, ClientTlsConfig};
1313

1414
/// Channel builder.
1515
///
@@ -318,15 +318,15 @@ impl Endpoint {
318318
self
319319
}
320320

321-
pub(crate) fn connector<C>(&self, c: C) -> service::Connector<C> {
321+
pub(crate) fn connector<C>(&self, c: C) -> Connector<C> {
322322
#[cfg(all(feature = "tls", not(feature = "tls-roots-common")))]
323-
let connector = service::Connector::new(c, self.tls.clone());
323+
let connector = Connector::new(c, self.tls.clone());
324324

325325
#[cfg(all(feature = "tls", feature = "tls-roots-common"))]
326-
let connector = service::Connector::new(c, self.tls.clone(), self.tls_assume_http2);
326+
let connector = Connector::new(c, self.tls.clone(), self.tls_assume_http2);
327327

328328
#[cfg(not(feature = "tls"))]
329-
let connector = service::Connector::new(c);
329+
let connector = Connector::new(c);
330330

331331
connector
332332
}

tonic/src/transport/channel/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Client implementation and builder.
22
33
mod endpoint;
4+
pub(crate) mod service;
45
#[cfg(feature = "tls")]
56
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
67
mod tls;
@@ -9,7 +10,7 @@ pub use endpoint::Endpoint;
910
#[cfg(feature = "tls")]
1011
pub use tls::ClientTlsConfig;
1112

12-
use super::service::{Connection, DynamicServiceStream, SharedExec};
13+
use self::service::{Connection, DynamicServiceStream, SharedExec};
1314
use crate::body::BoxBody;
1415
use crate::transport::Executor;
1516
use bytes::Bytes;

tonic/src/transport/service/connection.rs renamed to tonic/src/transport/channel/service/connection.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent};
1+
use super::{reconnect::Reconnect, AddOrigin, UserAgent};
2+
use crate::transport::service::GrpcTimeout;
23
use crate::{
34
body::BoxBody,
45
transport::{BoxFuture, Endpoint},

tonic/src/transport/service/connector.rs renamed to tonic/src/transport/channel/service/connector.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
use super::super::BoxFuture;
2-
use super::io::BoxedIo;
3-
#[cfg(feature = "tls")]
4-
use super::tls::TlsConnector;
51
use http::Uri;
62
#[cfg(feature = "tls")]
73
use std::fmt;
84
use std::task::{Context, Poll};
95
use tower::make::MakeConnection;
106
use tower_service::Service;
117

8+
use super::io::BoxedIo;
9+
#[cfg(feature = "tls")]
10+
use super::tls::TlsConnector;
11+
use crate::transport::BoxFuture;
12+
1213
pub(crate) struct Connector<C> {
1314
inner: C,
1415
#[cfg(feature = "tls")]
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use std::io::{self, IoSlice};
2+
use std::pin::Pin;
3+
use std::task::{Context, Poll};
4+
5+
use hyper::client::connect::{Connected as HyperConnected, Connection};
6+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7+
8+
pub(super) trait Io: AsyncRead + AsyncWrite + Send + 'static {}
9+
10+
impl<T> Io for T where T: AsyncRead + AsyncWrite + Send + 'static {}
11+
12+
pub(crate) struct BoxedIo(Pin<Box<dyn Io>>);
13+
14+
impl BoxedIo {
15+
pub(super) fn new<I: Io>(io: I) -> Self {
16+
BoxedIo(Box::pin(io))
17+
}
18+
}
19+
20+
impl Connection for BoxedIo {
21+
fn connected(&self) -> HyperConnected {
22+
HyperConnected::new()
23+
}
24+
}
25+
26+
impl AsyncRead for BoxedIo {
27+
fn poll_read(
28+
mut self: Pin<&mut Self>,
29+
cx: &mut Context<'_>,
30+
buf: &mut ReadBuf<'_>,
31+
) -> Poll<io::Result<()>> {
32+
Pin::new(&mut self.0).poll_read(cx, buf)
33+
}
34+
}
35+
36+
impl AsyncWrite for BoxedIo {
37+
fn poll_write(
38+
mut self: Pin<&mut Self>,
39+
cx: &mut Context<'_>,
40+
buf: &[u8],
41+
) -> Poll<io::Result<usize>> {
42+
Pin::new(&mut self.0).poll_write(cx, buf)
43+
}
44+
45+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
46+
Pin::new(&mut self.0).poll_flush(cx)
47+
}
48+
49+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
50+
Pin::new(&mut self.0).poll_shutdown(cx)
51+
}
52+
53+
fn poll_write_vectored(
54+
mut self: Pin<&mut Self>,
55+
cx: &mut Context<'_>,
56+
bufs: &[IoSlice<'_>],
57+
) -> Poll<Result<usize, io::Error>> {
58+
Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
59+
}
60+
61+
fn is_write_vectored(&self) -> bool {
62+
self.0.is_write_vectored()
63+
}
64+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
mod add_origin;
2+
pub(crate) use self::add_origin::AddOrigin;
3+
4+
mod connector;
5+
pub(crate) use self::connector::Connector;
6+
7+
mod connection;
8+
pub(crate) use self::connection::Connection;
9+
10+
mod discover;
11+
pub(crate) use self::discover::DynamicServiceStream;
12+
13+
pub(crate) mod executor;
14+
pub(crate) use self::executor::{Executor, SharedExec};
15+
16+
pub(crate) mod io;
17+
18+
mod reconnect;
19+
20+
mod user_agent;
21+
pub(crate) use self::user_agent::UserAgent;
22+
23+
#[cfg(feature = "tls")]
24+
mod tls;
25+
#[cfg(feature = "tls")]
26+
pub(crate) use self::tls::TlsConnector;
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
use std::fmt;
2+
use std::io::Cursor;
3+
use std::sync::Arc;
4+
5+
use tokio::io::{AsyncRead, AsyncWrite};
6+
use tokio_rustls::{
7+
rustls::{pki_types::ServerName, ClientConfig, RootCertStore},
8+
TlsConnector as RustlsConnector,
9+
};
10+
11+
use super::io::BoxedIo;
12+
use crate::transport::service::tls::{add_certs_from_pem, load_identity, ALPN_H2};
13+
use crate::transport::tls::{Certificate, Identity};
14+
15+
#[derive(Debug)]
16+
enum TlsError {
17+
H2NotNegotiated,
18+
}
19+
20+
impl fmt::Display for TlsError {
21+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22+
match self {
23+
TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."),
24+
}
25+
}
26+
}
27+
28+
impl std::error::Error for TlsError {}
29+
30+
#[derive(Clone)]
31+
pub(crate) struct TlsConnector {
32+
config: Arc<ClientConfig>,
33+
domain: Arc<ServerName<'static>>,
34+
assume_http2: bool,
35+
}
36+
37+
impl TlsConnector {
38+
pub(crate) fn new(
39+
ca_cert: Option<Certificate>,
40+
identity: Option<Identity>,
41+
domain: &str,
42+
assume_http2: bool,
43+
) -> Result<Self, crate::Error> {
44+
let builder = ClientConfig::builder();
45+
let mut roots = RootCertStore::empty();
46+
47+
#[cfg(feature = "tls-roots")]
48+
roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
49+
50+
#[cfg(feature = "tls-webpki-roots")]
51+
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
52+
53+
if let Some(cert) = ca_cert {
54+
add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
55+
}
56+
57+
let builder = builder.with_root_certificates(roots);
58+
let mut config = match identity {
59+
Some(identity) => {
60+
let (client_cert, client_key) = load_identity(identity)?;
61+
builder.with_client_auth_cert(client_cert, client_key)?
62+
}
63+
None => builder.with_no_client_auth(),
64+
};
65+
66+
config.alpn_protocols.push(ALPN_H2.into());
67+
Ok(Self {
68+
config: Arc::new(config),
69+
domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
70+
assume_http2,
71+
})
72+
}
73+
74+
pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
75+
where
76+
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
77+
{
78+
let io = RustlsConnector::from(self.config.clone())
79+
.connect(self.domain.as_ref().to_owned(), io)
80+
.await?;
81+
82+
// Generally we require ALPN to be negotiated, but if the user has
83+
// explicitly set `assume_http2` to true, we'll allow it to be missing.
84+
let (_, session) = io.get_ref();
85+
let alpn_protocol = session.alpn_protocol();
86+
if alpn_protocol != Some(ALPN_H2) {
87+
if alpn_protocol.is_some() || !self.assume_http2 {
88+
return Err(TlsError::H2NotNegotiated.into());
89+
}
90+
}
91+
92+
Ok(BoxedIo::new(io))
93+
}
94+
}
95+
96+
impl fmt::Debug for TlsConnector {
97+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98+
f.debug_struct("TlsConnector").finish()
99+
}
100+
}

tonic/src/transport/channel/tls.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
use super::service::TlsConnector;
12
use crate::transport::{
2-
service::TlsConnector,
33
tls::{Certificate, Identity},
44
Error,
55
};

tonic/src/transport/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ struct ErrorImpl {
1515
#[derive(Debug)]
1616
pub(crate) enum Kind {
1717
Transport,
18+
#[cfg(feature = "channel")]
1819
InvalidUri,
20+
#[cfg(feature = "channel")]
1921
InvalidUserAgent,
2022
}
2123

@@ -35,18 +37,22 @@ impl Error {
3537
Error::new(Kind::Transport).with(source)
3638
}
3739

40+
#[cfg(feature = "channel")]
3841
pub(crate) fn new_invalid_uri() -> Self {
3942
Error::new(Kind::InvalidUri)
4043
}
4144

45+
#[cfg(feature = "channel")]
4246
pub(crate) fn new_invalid_user_agent() -> Self {
4347
Error::new(Kind::InvalidUserAgent)
4448
}
4549

4650
fn description(&self) -> &str {
4751
match &self.inner.kind {
4852
Kind::Transport => "transport error",
53+
#[cfg(feature = "channel")]
4954
Kind::InvalidUri => "invalid URI",
55+
#[cfg(feature = "channel")]
5056
Kind::InvalidUserAgent => "user agent is not a valid header value",
5157
}
5258
}

0 commit comments

Comments
 (0)