Skip to content

Commit

Permalink
refactor(transport): Move channel feature to channel module
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto committed Feb 21, 2024
1 parent 4aa7354 commit 4d0d3ea
Show file tree
Hide file tree
Showing 17 changed files with 213 additions and 202 deletions.
15 changes: 8 additions & 7 deletions tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use super::super::service;
use super::service::Connector;
#[cfg(feature = "tls")]
use super::service::TlsConnector;
use super::service::{Executor, SharedExec};
use super::Channel;
#[cfg(feature = "tls")]
use super::ClientTlsConfig;
#[cfg(feature = "tls")]
use crate::transport::service::TlsConnector;
use crate::transport::{service::SharedExec, Error, Executor};
use crate::transport::Error;
use bytes::Bytes;
use http::{uri::Uri, HeaderValue};
use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration};
Expand Down Expand Up @@ -301,12 +302,12 @@ impl Endpoint {
self
}

pub(crate) fn connector<C>(&self, c: C) -> service::Connector<C> {
pub(crate) fn connector<C>(&self, c: C) -> Connector<C> {
#[cfg(feature = "tls")]
let connector = service::Connector::new(c, self.tls.clone());
let connector = Connector::new(c, self.tls.clone());

#[cfg(not(feature = "tls"))]
let connector = service::Connector::new(c);
let connector = Connector::new(c);

connector
}
Expand Down
3 changes: 2 additions & 1 deletion tonic/src/transport/channel/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Client implementation and builder.
mod endpoint;
pub(crate) mod service;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
Expand All @@ -9,7 +10,7 @@ pub use endpoint::Endpoint;
#[cfg(feature = "tls")]
pub use tls::ClientTlsConfig;

use super::service::{Connection, DynamicServiceStream, SharedExec};
use self::service::{Connection, DynamicServiceStream, SharedExec};
use crate::body::BoxBody;
use crate::transport::Executor;
use bytes::Bytes;
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent};
use super::{reconnect::Reconnect, AddOrigin, UserAgent};
use crate::transport::service::GrpcTimeout;
use crate::{
body::BoxBody,
transport::{BoxFuture, Endpoint},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use super::super::BoxFuture;
use super::io::BoxedIo;
#[cfg(feature = "tls")]
use super::tls::TlsConnector;
use http::Uri;
use std::fmt;
use std::task::{Context, Poll};
use tower::make::MakeConnection;
use tower_service::Service;

use super::io::BoxedIo;
#[cfg(feature = "tls")]
use super::tls::TlsConnector;
use crate::transport::BoxFuture;

pub(crate) struct Connector<C> {
inner: C,
#[cfg(feature = "tls")]
Expand Down
File renamed without changes.
File renamed without changes.
69 changes: 69 additions & 0 deletions tonic/src/transport/channel/service/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::io::{self, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};

use hyper::client::connect::{Connected as HyperConnected, Connection};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

pub(in crate::transport) trait Io:
AsyncRead + AsyncWrite + Send + 'static
{
}

impl<T> Io for T where T: AsyncRead + AsyncWrite + Send + 'static {}

pub(crate) struct BoxedIo(Pin<Box<dyn Io>>);

impl BoxedIo {
pub(in crate::transport) fn new<I: Io>(io: I) -> Self {
BoxedIo(Box::pin(io))
}
}

impl Connection for BoxedIo {
fn connected(&self) -> HyperConnected {
HyperConnected::new()
}
}

#[cfg(feature = "channel")]
impl AsyncRead for BoxedIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

#[cfg(feature = "channel")]
impl AsyncWrite for BoxedIo {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.0.is_write_vectored()
}
}
26 changes: 26 additions & 0 deletions tonic/src/transport/channel/service/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
mod add_origin;
pub(crate) use self::add_origin::AddOrigin;

mod connector;
pub(crate) use self::connector::Connector;

mod connection;
pub(crate) use self::connection::Connection;

mod discover;
pub(crate) use self::discover::DynamicServiceStream;

pub(crate) mod executor;
pub(crate) use self::executor::{Executor, SharedExec};

pub(crate) mod io;

mod reconnect;

mod user_agent;
pub(crate) use self::user_agent::UserAgent;

#[cfg(feature = "tls")]
mod tls;
#[cfg(feature = "tls")]
pub(crate) use self::tls::TlsConnector;
File renamed without changes.
92 changes: 92 additions & 0 deletions tonic/src/transport/channel/service/tls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use std::fmt;
use std::io::Cursor;
use std::sync::Arc;

use rustls_pki_types::ServerName;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::rustls::RootCertStore;
use tokio_rustls::{rustls::ClientConfig, TlsConnector as RustlsConnector};

use super::io::BoxedIo;
use crate::transport::service::tls::{add_certs_from_pem, load_identity, ALPN_H2};
use crate::transport::tls::{Certificate, Identity};

#[derive(Debug)]
enum TlsError {
H2NotNegotiated,
}

impl fmt::Display for TlsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."),
}
}
}

impl std::error::Error for TlsError {}

#[derive(Clone)]
pub(crate) struct TlsConnector {
config: Arc<ClientConfig>,
domain: Arc<ServerName<'static>>,
}

impl TlsConnector {
pub(crate) fn new(
ca_cert: Option<Certificate>,
identity: Option<Identity>,
domain: &str,
) -> Result<Self, crate::Error> {
let builder = ClientConfig::builder();
let mut roots = RootCertStore::empty();

#[cfg(feature = "tls-roots")]
roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);

#[cfg(feature = "tls-webpki-roots")]
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

if let Some(cert) = ca_cert {
add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
}

let builder = builder.with_root_certificates(roots);
let mut config = match identity {
Some(identity) => {
let (client_cert, client_key) = load_identity(identity)?;
builder.with_client_auth_cert(client_cert, client_key)?
}
None => builder.with_no_client_auth(),
};

config.alpn_protocols.push(ALPN_H2.into());
Ok(Self {
config: Arc::new(config),
domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
})
}

pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let io = RustlsConnector::from(self.config.clone())
.connect(self.domain.as_ref().to_owned(), io)
.await?;

let (_, session) = io.get_ref();
if session.alpn_protocol() != Some(ALPN_H2) {
return Err(TlsError::H2NotNegotiated)?;
}

Ok(BoxedIo::new(io))
}
}

#[cfg(feature = "channel")]
impl fmt::Debug for TlsConnector {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TlsConnector").finish()
}
}
File renamed without changes.
2 changes: 1 addition & 1 deletion tonic/src/transport/channel/tls.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::service::TlsConnector;
use crate::transport::{
service::TlsConnector,
tls::{Certificate, Identity},
Error,
};
Expand Down
2 changes: 1 addition & 1 deletion tonic/src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter};
pub use hyper::{Body, Uri};

#[cfg(feature = "channel")]
pub(crate) use self::service::executor::Executor;
pub(crate) use self::channel::service::executor::Executor;

#[cfg(all(feature = "channel", feature = "tls"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "channel", feature = "tls"))))]
Expand Down
81 changes: 0 additions & 81 deletions tonic/src/transport/service/io.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use crate::transport::server::Connected;
#[cfg(feature = "channel")]
use hyper::client::connect::{Connected as HyperConnected, Connection};
use std::io;
use std::io::IoSlice;
use std::pin::Pin;
Expand All @@ -9,85 +7,6 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[cfg(feature = "tls")]
use tokio_rustls::server::TlsStream;

pub(in crate::transport) trait Io:
AsyncRead + AsyncWrite + Send + 'static
{
}

impl<T> Io for T where T: AsyncRead + AsyncWrite + Send + 'static {}

#[cfg(feature = "channel")]
pub(crate) struct BoxedIo(Pin<Box<dyn Io>>);

#[cfg(feature = "channel")]
impl BoxedIo {
pub(in crate::transport) fn new<I: Io>(io: I) -> Self {
BoxedIo(Box::pin(io))
}
}

#[cfg(feature = "channel")]
impl Connection for BoxedIo {
fn connected(&self) -> HyperConnected {
HyperConnected::new()
}
}

#[cfg(feature = "channel")]
impl Connected for BoxedIo {
type ConnectInfo = NoneConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
NoneConnectInfo
}
}

#[cfg(feature = "channel")]
#[derive(Copy, Clone)]
pub(crate) struct NoneConnectInfo;

#[cfg(feature = "channel")]
impl AsyncRead for BoxedIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

#[cfg(feature = "channel")]
impl AsyncWrite for BoxedIo {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.0.is_write_vectored()
}
}

pub(crate) enum ServerIo<IO> {
Io(IO),
#[cfg(feature = "tls")]
Expand Down
Loading

0 comments on commit 4d0d3ea

Please sign in to comment.