Skip to content

Commit

Permalink
Merge pull request #80 from OSSystems/issue-79
Browse files Browse the repository at this point in the history
h1: Fix connection with multiple IPs for a hostname
  • Loading branch information
Fishrock123 authored Mar 12, 2021
2 parents e7375af + e5bbc27 commit bcbc2b2
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 61 deletions.
134 changes: 73 additions & 61 deletions src/h1/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,72 +134,84 @@ impl HttpClient for H1Client {
));
}

let addr = req
.url()
.socket_addrs(|| match req.url().scheme() {
"http" => Some(80),
#[cfg(any(feature = "native-tls", feature = "rustls"))]
"https" => Some(443),
_ => None,
})?
.into_iter()
.next()
.ok_or_else(|| Error::from_str(StatusCode::BadRequest, "missing valid address"))?;
let addrs = req.url().socket_addrs(|| match req.url().scheme() {
"http" => Some(80),
#[cfg(any(feature = "native-tls", feature = "rustls"))]
"https" => Some(443),
_ => None,
})?;

log::trace!("> Scheme: {}", scheme);

match scheme {
"http" => {
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
pool_ref
} else {
let manager = TcpConnection::new(addr);
let pool = Pool::<TcpStream, std::io::Error>::new(
manager,
self.max_concurrent_connections,
);
self.http_pools.insert(addr, pool);
self.http_pools.get(&addr).unwrap()
};

// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
let pool = pool_ref.clone();
std::mem::drop(pool_ref);

let stream = pool.get().await?;
req.set_peer_addr(stream.peer_addr().ok());
req.set_local_addr(stream.local_addr().ok());
client::connect(TcpConnWrapper::new(stream), req).await
}
#[cfg(any(feature = "native-tls", feature = "rustls"))]
"https" => {
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
pool_ref
} else {
let manager = TlsConnection::new(host.clone(), addr);
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
manager,
self.max_concurrent_connections,
);
self.https_pools.insert(addr, pool);
self.https_pools.get(&addr).unwrap()
};

// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
let pool = pool_ref.clone();
std::mem::drop(pool_ref);

let stream = pool
.get()
.await
.map_err(|e| Error::from_str(400, e.to_string()))?;
req.set_peer_addr(stream.get_ref().peer_addr().ok());
req.set_local_addr(stream.get_ref().local_addr().ok());

client::connect(TlsConnWrapper::new(stream), req).await
let max_addrs_idx = addrs.len() - 1;
for (idx, addr) in addrs.into_iter().enumerate() {
let has_another_addr = idx != max_addrs_idx;

match scheme {
"http" => {
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
pool_ref
} else {
let manager = TcpConnection::new(addr);
let pool = Pool::<TcpStream, std::io::Error>::new(
manager,
self.max_concurrent_connections,
);
self.http_pools.insert(addr, pool);
self.http_pools.get(&addr).unwrap()
};

// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
let pool = pool_ref.clone();
std::mem::drop(pool_ref);

let stream = match pool.get().await {
Ok(s) => s,
Err(_) if has_another_addr => continue,
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
};

req.set_peer_addr(stream.peer_addr().ok());
req.set_local_addr(stream.local_addr().ok());
return client::connect(TcpConnWrapper::new(stream), req).await;
}
#[cfg(any(feature = "native-tls", feature = "rustls"))]
"https" => {
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
pool_ref
} else {
let manager = TlsConnection::new(host.clone(), addr);
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
manager,
self.max_concurrent_connections,
);
self.https_pools.insert(addr, pool);
self.https_pools.get(&addr).unwrap()
};

// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
let pool = pool_ref.clone();
std::mem::drop(pool_ref);

let stream = match pool.get().await {
Ok(s) => s,
Err(_) if has_another_addr => continue,
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
};

req.set_peer_addr(stream.get_ref().peer_addr().ok());
req.set_local_addr(stream.get_ref().local_addr().ok());

return client::connect(TlsConnWrapper::new(stream), req).await;
}
_ => unreachable!(),
}
_ => unreachable!(),
}

Err(Error::from_str(
StatusCode::BadRequest,
"missing valid address",
))
}
}

Expand Down
16 changes: 16 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,19 @@ async fn keep_alive() {
client.send(req.clone()).await.unwrap();
client.send(req.clone()).await.unwrap();
}

#[atest]
async fn fallback_to_ipv4() {
let client = DefaultClient::new();
let _mock_guard = mock("GET", "/")
.with_status(200)
.expect_at_least(2)
.create();

// Kips the initial "http://127.0.0.1:" to get only the port number
let mock_port = &mockito::server_url()[17..];

let url = &format!("http://localhost:{}", mock_port);
let req = Request::new(http_types::Method::Get, Url::parse(url).unwrap());
client.send(req.clone()).await.unwrap();
}

0 comments on commit bcbc2b2

Please sign in to comment.