Skip to content

Commit

Permalink
fix concurrency bottlenecks
Browse files Browse the repository at this point in the history
  • Loading branch information
menghaoyu2002 committed May 30, 2024
1 parent 09f32eb commit e27b6bb
Showing 1 changed file with 114 additions and 99 deletions.
213 changes: 114 additions & 99 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::{
collections::HashMap,
fmt::Display,
io::{Read, Write},
use std::{collections::HashMap, fmt::Display, time::Duration};

use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
time::Duration,
task::JoinSet,
time::timeout,
};

use crate::tracker::{Peer, Tracker};
Expand Down Expand Up @@ -141,53 +142,36 @@ impl Clone for Message {
}
}

struct PeerState {
peer_id: Vec<u8>,
connection: TcpStream,
bitfield: Vec<bool>,

am_choking: bool,
am_interested: bool,
peer_choking: bool,
peer_interested: bool,
}

pub struct Client {
tracker: Tracker,
connections: HashMap<Vec<u8>, TcpStream>,
bitfield: Vec<u8>,
peers: HashMap<Vec<u8>, PeerState>,
bitfield: Vec<bool>,
}

impl Client {
pub fn new(tracker: Tracker) -> Self {
// divide the number of pieces by 8 to get the number of bytes needed to represent the bitfield
let bitfield = vec![0u8; tracker.get_metainfo().get_peices().len().div_ceil(8)];
let bitfield = vec![false; tracker.get_metainfo().get_peices().len().div_ceil(8)];
Self {
tracker,
connections: HashMap::new(),
peers: HashMap::new(),
bitfield,
}
}

pub async fn download(&mut self) -> Result<(), ClientError> {
println!("Starting download...");
// self.connect_to_peers(30).await?;
self.connect_to_peers(1).await?;
self.send_message(Message::new(MessageId::Bitfield, &self.bitfield))?;
self.send_message(Message::new(MessageId::Interested, &Vec::new()))?;
Ok(())
}

fn send_message(&mut self, message: Message) -> Result<(), ClientError> {
let serialized_message = message.serialize();
println!("Sending message: {:?}", serialized_message);
for (_, stream) in self.connections.iter_mut() {
stream.write_all(&serialized_message).map_err(|e| {
ClientError::SendMessageError(SendMessageError {
peer: Peer {
peer_id: None,
addr: stream.peer_addr().unwrap(),
},
message: message.clone(),
error: e.to_string(),
})
})?;

let mut response = vec![0u8; serialized_message.len()];
stream.read_exact(&mut response).map_err(|e| {
ClientError::GetPeersError(format!("Failed to read response: {}", e))
})?;
println!("Response: {:?}", response);
}
self.connect_to_peers(30).await?;
Ok(())
}

Expand Down Expand Up @@ -244,89 +228,120 @@ impl Client {
Ok(peer_id)
}

async fn initiate_handshake(
stream: &mut TcpStream,
handshake: &Vec<u8>,
info_hash: &Vec<u8>,
peer: &Peer,
) -> Result<Vec<u8>, ClientError> {
stream.write_all(handshake).await.map_err(|e| {
ClientError::HandshakeError(HandshakeError {
peer: peer.clone(),
handshake: handshake.to_vec(),
status: HandshakePhase::Send,
message: format!("Failed to send handshake: {}", e),
})
})?;

let mut response = vec![0u8; HANDSHAKE_LEN];
stream.read_exact(&mut response).await.map_err(|e| {
ClientError::HandshakeError(HandshakeError {
peer: peer.clone(),
handshake: handshake.to_vec(),
status: HandshakePhase::Receive,
message: format!("Failed to receive handshake: {}", e),
})
})?;

Self::validate_handshake(&response, info_hash)
}

async fn connect_to_peers(&mut self, min_connections: usize) -> Result<(), ClientError> {
println!("Connecting to peers...");
while self.connections.len() < min_connections {
let mut handles = Vec::new();

while self.peers.len() < min_connections {
let mut handles = JoinSet::new();
for peer in self
.tracker
.get_peers()
.await
.map_err(|_| ClientError::GetPeersError(String::from("Failed to get peers")))?
{
if self.connections.len() >= min_connections {
return Ok(());
}

let handshake = self.get_handshake()?;
let info_hash = self.tracker.get_metainfo().get_info_hash().map_err(|_| {
ClientError::GetPeersError(String::from("Failed to get info hash"))
})?;

let handle = tokio::spawn(async move {
match TcpStream::connect_timeout(&peer.addr, Duration::new(5, 0)) {
Ok(mut stream) => {
stream.write_all(&handshake).map_err(|e| {
ClientError::HandshakeError(HandshakeError {
peer: peer.clone(),
handshake: handshake.clone(),
status: HandshakePhase::Send,
message: e.to_string(),
})
})?;

let mut handshake_response = [0u8; HANDSHAKE_LEN];
stream.read_exact(&mut handshake_response).map_err(|e| {
ClientError::HandshakeError(HandshakeError {
peer: peer.clone(),
handshake: handshake_response.to_vec(),
status: HandshakePhase::Receive,
message: e.to_string(),
})
})?;

let peer_id =
Client::validate_handshake(&handshake_response, &info_hash)?;

Ok((peer_id, stream))
handles.spawn(async move {
let mut stream = match timeout(
Duration::from_secs(5),
TcpStream::connect(peer.addr),
)
.await
{
Ok(Ok(stream)) => stream,
Ok(Err(e)) => {
return Err(ClientError::GetPeersError(format!(
"Failed to connect to peer: {}",
e
)))
}
Err(_) => Err(ClientError::GetPeersError(format!(
"Failed to connect to peer {}",
peer.addr
))),
}
Err(_) => {
return Err(ClientError::GetPeersError(format!(
"Failed to connect to peer: {} - timed out",
peer.addr
)))
}
};

let peer_id =
Self::initiate_handshake(&mut stream, &handshake, &info_hash, &peer)
.await?;

Ok((peer_id, stream))
});
handles.push(handle);
}

for handle in handles {
if self.connections.len() >= min_connections {
handle.abort();
} else {
match handle
.await
.map_err(|e| ClientError::GetPeersError(String::from(e.to_string())))?
{
Ok((peer_id, stream)) => {
println!(
"Connected to peer: {} at {}",
String::from_utf8_lossy(&peer_id),
stream
.peer_addr()
.map(|addr| addr.to_string())
.unwrap_or("Unknown".to_string())
);
self.connections.insert(peer_id, stream);
}
Err(e) => {
#[cfg(debug_assertions)]
eprintln!("{}", e);
}
while let Some(handle) = handles.join_next().await {
let conection_result =
handle.map_err(|e| ClientError::GetPeersError(format!("{}", e)))?;

match conection_result {
Ok((peer_id, stream)) => {
println!(
"Connected to peer: {:?}",
stream.peer_addr().map_err(|e| {
ClientError::GetPeersError(format!(
"Failed to get peer address: {}",
e
))
})?
);
self.peers.insert(
peer_id.clone(),
PeerState {
peer_id,
connection: stream,
bitfield: vec![
false;
self.tracker.get_metainfo().get_peices().len()
],
am_choking: true,
am_interested: false,
peer_choking: true,
peer_interested: false,
},
);
}
Err(e) => {
#[cfg(debug_assertions)]
eprintln!("{}", e);
}
}
}
}

println!("Connected to {} peers", self.peers.len());
Ok(())
}
}

0 comments on commit e27b6bb

Please sign in to comment.