Skip to content

Commit

Permalink
Merge pull request #59 from Fishrock123/async-h1-pooling
Browse files Browse the repository at this point in the history
feat: h1 connection pooling
  • Loading branch information
Fishrock123 authored Feb 12, 2021
2 parents 06994bb + 93bb1db commit 2eed344
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 32 deletions.
7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,25 @@ rustdoc-args = ["--cfg", "feature=\"docs\""]
[features]
default = ["h1_client"]
docs = ["h1_client", "curl_client", "wasm_client", "hyper_client"]
h1_client = ["async-h1", "async-std", "async-native-tls"]
h1_client_rustls = ["async-h1", "async-std", "async-tls"]
h1_client = ["async-h1", "async-std", "async-native-tls", "deadpool", "futures"]
h1_client_rustls = ["async-h1", "async-std", "async-tls", "deadpool", "futures"]
native_client = ["curl_client", "wasm_client"]
curl_client = ["isahc", "async-std"]
wasm_client = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-futures", "futures"]
hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util"]

[dependencies]
async-trait = "0.1.37"
dashmap = "4.0.2"
http-types = "2.3.0"
log = "0.4.7"

# h1_client
async-h1 = { version = "2.0.0", optional = true }
async-std = { version = "1.6.0", default-features = false, optional = true }
async-native-tls = { version = "0.3.1", optional = true }
deadpool = { version = "0.7.0", optional = true }
futures = { version = "0.3.8", optional = true }

# h1_client_rustls
async-tls = { version = "0.10.0", optional = true }
Expand Down
123 changes: 93 additions & 30 deletions src/h1.rs → src/h1/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,44 @@
//! http-client implementation for async-h1.
//! http-client implementation for async-h1, with connecton pooling ("Keep-Alive").
use super::{async_trait, Error, HttpClient, Request, Response};
use std::fmt::Debug;
use std::net::SocketAddr;

use async_h1::client;
use async_std::net::TcpStream;
use dashmap::DashMap;
use deadpool::managed::Pool;
use http_types::StatusCode;

/// Async-h1 based HTTP Client.
#[derive(Debug)]
#[cfg(not(feature = "h1_client_rustls"))]
use async_native_tls::TlsStream;
#[cfg(feature = "h1_client_rustls")]
use async_tls::client::TlsStream;

use super::{async_trait, Error, HttpClient, Request, Response};

mod tcp;
mod tls;

use tcp::{TcpConnWrapper, TcpConnection};
use tls::{TlsConnWrapper, TlsConnection};

// This number is based on a few random benchmarks and see whatever gave decent perf vs resource use.
const DEFAULT_MAX_CONCURRENT_CONNECTIONS: usize = 50;

type HttpPool = DashMap<SocketAddr, Pool<TcpStream, std::io::Error>>;
type HttpsPool = DashMap<SocketAddr, Pool<TlsStream<TcpStream>, Error>>;

/// Async-h1 based HTTP Client, with connecton pooling ("Keep-Alive").
pub struct H1Client {
_priv: (),
http_pools: HttpPool,
https_pools: HttpsPool,
max_concurrent_connections: usize,
}

impl Debug for H1Client {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("H1Client")
}
}

impl Default for H1Client {
Expand All @@ -20,13 +50,28 @@ impl Default for H1Client {
impl H1Client {
/// Create a new instance.
pub fn new() -> Self {
Self { _priv: () }
Self {
http_pools: DashMap::new(),
https_pools: DashMap::new(),
max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS,
}
}

/// Create a new instance.
pub fn with_max_connections(max: usize) -> Self {
Self {
http_pools: DashMap::new(),
https_pools: DashMap::new(),
max_concurrent_connections: max,
}
}
}

#[async_trait]
impl HttpClient for H1Client {
async fn send(&self, mut req: Request) -> Result<Response, Error> {
req.insert_header("Connection", "keep-alive");

// Insert host
let host = req
.url()
Expand Down Expand Up @@ -57,40 +102,58 @@ impl HttpClient for H1Client {

match scheme {
"http" => {
let stream = async_std::net::TcpStream::connect(addr).await?;
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
pool_ref
} else {
let manager = TcpConnection::new(addr);
let pool = Pool::<TcpStream, std::io::Error>::new(
manager,
self.max_concurrent_connections,
);
self.http_pools.insert(addr, pool);
self.http_pools.get(&addr).unwrap()
};

// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
let pool = pool_ref.clone();
std::mem::drop(pool_ref);

let stream = pool.get().await?;
req.set_peer_addr(stream.peer_addr().ok());
req.set_local_addr(stream.local_addr().ok());
client::connect(stream, req).await
client::connect(TcpConnWrapper::new(stream), req).await
}
"https" => {
let raw_stream = async_std::net::TcpStream::connect(addr).await?;
req.set_peer_addr(raw_stream.peer_addr().ok());
req.set_local_addr(raw_stream.local_addr().ok());
let tls_stream = add_tls(host, raw_stream).await?;
client::connect(tls_stream, req).await
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
pool_ref
} else {
let manager = TlsConnection::new(host.clone(), addr);
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
manager,
self.max_concurrent_connections,
);
self.https_pools.insert(addr, pool);
self.https_pools.get(&addr).unwrap()
};

// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
let pool = pool_ref.clone();
std::mem::drop(pool_ref);

let stream = pool
.get()
.await
.map_err(|e| Error::from_str(400, e.to_string()))?;
req.set_peer_addr(stream.get_ref().peer_addr().ok());
req.set_local_addr(stream.get_ref().local_addr().ok());

client::connect(TlsConnWrapper::new(stream), req).await
}
_ => unreachable!(),
}
}
}

#[cfg(not(feature = "h1_client_rustls"))]
async fn add_tls(
host: String,
stream: async_std::net::TcpStream,
) -> Result<async_native_tls::TlsStream<async_std::net::TcpStream>, async_native_tls::Error> {
async_native_tls::connect(host, stream).await
}

#[cfg(feature = "h1_client_rustls")]
async fn add_tls(
host: String,
stream: async_std::net::TcpStream,
) -> std::io::Result<async_tls::client::TlsStream<async_std::net::TcpStream>> {
let connector = async_tls::TlsConnector::default();
connector.connect(host, stream).await
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
67 changes: 67 additions & 0 deletions src/h1/tcp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use std::fmt::Debug;
use std::net::SocketAddr;
use std::pin::Pin;

use async_std::net::TcpStream;
use async_trait::async_trait;
use deadpool::managed::{Manager, Object, RecycleResult};
use futures::io::{AsyncRead, AsyncWrite};
use futures::task::{Context, Poll};

#[derive(Clone, Debug)]
pub(crate) struct TcpConnection {
addr: SocketAddr,
}
impl TcpConnection {
pub(crate) fn new(addr: SocketAddr) -> Self {
Self { addr }
}
}

pub(crate) struct TcpConnWrapper {
conn: Object<TcpStream, std::io::Error>,
}
impl TcpConnWrapper {
pub(crate) fn new(conn: Object<TcpStream, std::io::Error>) -> Self {
Self { conn }
}
}

impl AsyncRead for TcpConnWrapper {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut *self.conn).poll_read(cx, buf)
}
}

impl AsyncWrite for TcpConnWrapper {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut *self.conn).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut *self.conn).poll_flush(cx)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut *self.conn).poll_close(cx)
}
}

#[async_trait]
impl Manager<TcpStream, std::io::Error> for TcpConnection {
async fn create(&self) -> Result<TcpStream, std::io::Error> {
Ok(TcpStream::connect(self.addr).await?)
}

async fn recycle(&self, _conn: &mut TcpStream) -> RecycleResult<std::io::Error> {
Ok(())
}
}
91 changes: 91 additions & 0 deletions src/h1/tls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use std::fmt::Debug;
use std::net::SocketAddr;
use std::pin::Pin;

use async_std::net::TcpStream;
use async_trait::async_trait;
use deadpool::managed::{Manager, Object, RecycleResult};
use futures::io::{AsyncRead, AsyncWrite};
use futures::task::{Context, Poll};

#[cfg(not(feature = "h1_client_rustls"))]
use async_native_tls::TlsStream;
#[cfg(feature = "h1_client_rustls")]
use async_tls::client::TlsStream;

use crate::Error;

#[derive(Clone, Debug)]
pub(crate) struct TlsConnection {
host: String,
addr: SocketAddr,
}
impl TlsConnection {
pub(crate) fn new(host: String, addr: SocketAddr) -> Self {
Self { host, addr }
}
}

pub(crate) struct TlsConnWrapper {
conn: Object<TlsStream<TcpStream>, Error>,
}
impl TlsConnWrapper {
pub(crate) fn new(conn: Object<TlsStream<TcpStream>, Error>) -> Self {
Self { conn }
}
}

impl AsyncRead for TlsConnWrapper {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut *self.conn).poll_read(cx, buf)
}
}

impl AsyncWrite for TlsConnWrapper {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut *self.conn).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut *self.conn).poll_flush(cx)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut *self.conn).poll_close(cx)
}
}

#[async_trait]
impl Manager<TlsStream<TcpStream>, Error> for TlsConnection {
async fn create(&self) -> Result<TlsStream<TcpStream>, Error> {
let raw_stream = async_std::net::TcpStream::connect(self.addr).await?;
let tls_stream = add_tls(&self.host, raw_stream).await?;
Ok(tls_stream)
}

async fn recycle(&self, _conn: &mut TlsStream<TcpStream>) -> RecycleResult<Error> {
Ok(())
}
}

#[cfg(not(feature = "h1_client_rustls"))]
async fn add_tls(
host: &str,
stream: TcpStream,
) -> Result<async_native_tls::TlsStream<TcpStream>, async_native_tls::Error> {
async_native_tls::connect(host, stream).await
}

#[cfg(feature = "h1_client_rustls")]
async fn add_tls(host: &str, stream: TcpStream) -> Result<TlsStream<TcpStream>, std::io::Error> {
let connector = async_tls::TlsConnector::default();
connector.connect(host, stream).await
}

0 comments on commit 2eed344

Please sign in to comment.