Skip to content

Commit ec73a0a

Browse files
authored
Merge pull request #14 from joshtriplett/custom-tls-acceptor
Add support for custom TLS acceptors
2 parents 635ecf0 + 68594fd commit ec73a0a

File tree

5 files changed

+83
-16
lines changed

5 files changed

+83
-16
lines changed

src/custom_tls_acceptor.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use async_rustls::server::TlsStream;
2+
use async_std::net::TcpStream;
3+
4+
/// The CustomTlsAcceptor trait provides a custom implementation of accepting
5+
/// TLS connections from a [`TcpStream`]. tide-rustls will call the
6+
/// [`CustomTlsAcceptor::accept`] function for each new [`TcpStream`] it
7+
/// accepts, to obtain a [`TlsStream`]).
8+
///
9+
/// Implementing this trait gives you control over the TLS negotiation process,
10+
/// and allows you to process some TLS connections internally without passing
11+
/// them through to tide, such as for multiplexing or custom ALPN negotiation.
12+
#[tide::utils::async_trait]
13+
pub trait CustomTlsAcceptor: Send + Sync {
14+
/// Accept a [`TlsStream`] from a [`TcpStream`].
15+
///
16+
/// If TLS negotiation succeeds, but does not result in a stream that tide
17+
/// should process HTTP connections from, return `Ok(None)`.
18+
async fn accept(&self, stream: TcpStream) -> std::io::Result<Option<TlsStream<TcpStream>>>;
19+
}
20+
21+
/// Crate-private adapter to make `async_rustls::TlsAcceptor` implement
22+
/// `CustomTlsAcceptor`, without creating a conflict between the two `accept`
23+
/// methods.
24+
pub(crate) struct StandardTlsAcceptor(pub(crate) async_rustls::TlsAcceptor);
25+
26+
#[tide::utils::async_trait]
27+
impl CustomTlsAcceptor for StandardTlsAcceptor {
28+
async fn accept(&self, stream: TcpStream) -> std::io::Result<Option<TlsStream<TcpStream>>> {
29+
self.0.accept(stream).await.map(Some)
30+
}
31+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
unused_qualifications
2929
)]
3030

31+
mod custom_tls_acceptor;
3132
mod tcp_connection;
3233
mod tls_listener;
3334
mod tls_listener_builder;
@@ -38,6 +39,7 @@ pub(crate) use tcp_connection::TcpConnection;
3839
pub(crate) use tls_listener_config::TlsListenerConfig;
3940
pub(crate) use tls_stream_wrapper::TlsStreamWrapper;
4041

42+
pub use custom_tls_acceptor::CustomTlsAcceptor;
4143
pub use tls_listener::TlsListener;
4244
pub use tls_listener_builder::TlsListenerBuilder;
4345

src/tls_listener.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use crate::{TcpConnection, TlsListenerBuilder, TlsListenerConfig, TlsStreamWrapper};
1+
use crate::custom_tls_acceptor::StandardTlsAcceptor;
2+
use crate::{
3+
CustomTlsAcceptor, TcpConnection, TlsListenerBuilder, TlsListenerConfig, TlsStreamWrapper,
4+
};
25

36
use tide::listener::ListenInfo;
47
use tide::listener::{Listener, ToListener};
@@ -79,12 +82,14 @@ impl<State> TlsListener<State> {
7982
.set_single_cert(certs, keys.remove(0))
8083
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
8184

82-
TlsListenerConfig::Acceptor(TlsAcceptor::from(Arc::new(config)))
85+
TlsListenerConfig::Acceptor(Arc::new(StandardTlsAcceptor(TlsAcceptor::from(
86+
Arc::new(config),
87+
))))
8388
}
8489

85-
TlsListenerConfig::ServerConfig(config) => {
86-
TlsListenerConfig::Acceptor(TlsAcceptor::from(Arc::new(config)))
87-
}
90+
TlsListenerConfig::ServerConfig(config) => TlsListenerConfig::Acceptor(Arc::new(
91+
StandardTlsAcceptor(TlsAcceptor::from(Arc::new(config))),
92+
)),
8893

8994
other @ TlsListenerConfig::Acceptor(_) => other,
9095

@@ -99,7 +104,7 @@ impl<State> TlsListener<State> {
99104
Ok(())
100105
}
101106

102-
fn acceptor(&self) -> Option<&TlsAcceptor> {
107+
fn acceptor(&self) -> Option<&Arc<dyn CustomTlsAcceptor>> {
103108
match self.config {
104109
TlsListenerConfig::Acceptor(ref a) => Some(a),
105110
_ => None,
@@ -125,14 +130,16 @@ impl<State> TlsListener<State> {
125130
fn handle_tls<State: Clone + Send + Sync + 'static>(
126131
app: Server<State>,
127132
stream: TcpStream,
128-
acceptor: TlsAcceptor,
133+
acceptor: Arc<dyn CustomTlsAcceptor>,
129134
) {
130135
task::spawn(async move {
131136
let local_addr = stream.local_addr().ok();
132137
let peer_addr = stream.peer_addr().ok();
133138

134139
match acceptor.accept(stream).await {
135-
Ok(tls_stream) => {
140+
Ok(None) => {}
141+
142+
Ok(Some(tls_stream)) => {
136143
let stream = TlsStreamWrapper::new(tls_stream);
137144
let fut = async_h1::accept(stream, |mut req| async {
138145
if req.url_mut().set_scheme("https").is_err() {

src/tls_listener_builder.rs

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ use async_std::net::TcpListener;
33

44
use rustls::ServerConfig;
55

6-
use super::{TcpConnection, TlsListener, TlsListenerConfig};
6+
use super::{CustomTlsAcceptor, TcpConnection, TlsListener, TlsListenerConfig};
77

88
use std::marker::PhantomData;
99
use std::net::{SocketAddr, ToSocketAddrs};
1010
use std::path::{Path, PathBuf};
11+
use std::sync::Arc;
1112

1213
/// # A builder for TlsListeners
1314
///
@@ -38,6 +39,7 @@ pub struct TlsListenerBuilder<State> {
3839
key: Option<PathBuf>,
3940
cert: Option<PathBuf>,
4041
config: Option<ServerConfig>,
42+
tls_acceptor: Option<Arc<dyn CustomTlsAcceptor>>,
4143
tcp: Option<TcpListener>,
4244
addrs: Option<Vec<SocketAddr>>,
4345
_state: PhantomData<State>,
@@ -49,6 +51,7 @@ impl<State> Default for TlsListenerBuilder<State> {
4951
key: None,
5052
cert: None,
5153
config: None,
54+
tls_acceptor: None,
5255
tcp: None,
5356
addrs: None,
5457
_state: PhantomData,
@@ -69,6 +72,14 @@ impl<State> std::fmt::Debug for TlsListenerBuilder<State> {
6972
"None"
7073
},
7174
)
75+
.field(
76+
"tls_acceptor",
77+
&if self.tls_acceptor.is_some() {
78+
"Some(_)"
79+
} else {
80+
"None"
81+
},
82+
)
7283
.field("tcp", &self.tcp)
7384
.field("addrs", &self.addrs)
7485
.finish()
@@ -108,6 +119,17 @@ impl<State> TlsListenerBuilder<State> {
108119
self
109120
}
110121

122+
/// Provides a custom acceptor for TLS connections. This is mutually
123+
/// exclusive with any of [`TlsListenerBuilder::key`],
124+
/// [`TlsListenerBuilder::cert`], and [`TlsListenerBuilder::config`], but
125+
/// gives total control over accepting TLS connections, including
126+
/// multiplexing other streams or ALPN negotiations on the same TLS
127+
/// connection that tide should ignore.
128+
pub fn tls_acceptor(mut self, acceptor: Arc<dyn CustomTlsAcceptor>) -> Self {
129+
self.tls_acceptor = Some(acceptor);
130+
self
131+
}
132+
111133
/// Provides a bound tcp listener (either async-std or std) to
112134
/// build this tls listener on. This is mutually exclusive with
113135
/// [`TlsListenerBuilder::addrs`], but one of them is mandatory.
@@ -134,26 +156,29 @@ impl<State> TlsListenerBuilder<State> {
134156
/// * either of these is provided, but not both
135157
/// * [`TlsListenerBuilder::tcp`]
136158
/// * [`TlsListenerBuilder::addrs`]
137-
/// * either of these is provided, but not both
159+
/// * exactly one of these is provided
138160
/// * both [`TlsListenerBuilder::cert`] AND [`TlsListenerBuilder::key`]
139161
/// * [`TlsListenerBuilder::config`]
162+
/// * [`TlsListenerBuilder::tls_acceptor`]
140163
pub fn finish(self) -> io::Result<TlsListener<State>> {
141164
let Self {
142165
key,
143166
cert,
144167
config,
168+
tls_acceptor,
145169
tcp,
146170
addrs,
147171
..
148172
} = self;
149173

150-
let config = match (key, cert, config) {
151-
(Some(key), Some(cert), None) => TlsListenerConfig::Paths { key, cert },
152-
(None, None, Some(config)) => TlsListenerConfig::ServerConfig(config),
174+
let config = match (key, cert, config, tls_acceptor) {
175+
(Some(key), Some(cert), None, None) => TlsListenerConfig::Paths { key, cert },
176+
(None, None, Some(config), None) => TlsListenerConfig::ServerConfig(config),
177+
(None, None, None, Some(tls_acceptor)) => TlsListenerConfig::Acceptor(tls_acceptor),
153178
_ => {
154179
return Err(io::Error::new(
155180
io::ErrorKind::InvalidInput,
156-
"either cert + key are required or a ServerConfig",
181+
"need exactly one of cert + key, ServerConfig, or TLS acceptor",
157182
))
158183
}
159184
};

src/tls_listener_config.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use std::fmt::{self, Debug, Formatter};
22

3-
use async_rustls::TlsAcceptor;
43
use rustls::ServerConfig;
54

5+
use super::CustomTlsAcceptor;
6+
67
use std::path::PathBuf;
8+
use std::sync::Arc;
79

810
impl Default for TlsListenerConfig {
911
fn default() -> Self {
@@ -12,7 +14,7 @@ impl Default for TlsListenerConfig {
1214
}
1315
pub(crate) enum TlsListenerConfig {
1416
Unconfigured,
15-
Acceptor(TlsAcceptor),
17+
Acceptor(Arc<dyn CustomTlsAcceptor>),
1618
ServerConfig(ServerConfig),
1719
Paths { cert: PathBuf, key: PathBuf },
1820
}

0 commit comments

Comments
 (0)