diff --git a/src/client/mod.rs b/src/client/mod.rs index 86e8fbf..bc7a84b 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,9 +1,10 @@ -use std::{ - collections::HashMap, - fmt::Display, - io::{Read, Write}, +use std::{collections::HashMap, fmt::Display, time::Duration}; + +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream, - time::Duration, + task::JoinSet, + time::timeout, }; use crate::tracker::{Peer, Tracker}; @@ -141,53 +142,36 @@ impl Clone for Message { } } +struct PeerState { + peer_id: Vec, + connection: TcpStream, + bitfield: Vec, + + am_choking: bool, + am_interested: bool, + peer_choking: bool, + peer_interested: bool, +} + pub struct Client { tracker: Tracker, - connections: HashMap, TcpStream>, - bitfield: Vec, + peers: HashMap, PeerState>, + bitfield: Vec, } impl Client { pub fn new(tracker: Tracker) -> Self { // divide the number of pieces by 8 to get the number of bytes needed to represent the bitfield - let bitfield = vec![0u8; tracker.get_metainfo().get_peices().len().div_ceil(8)]; + let bitfield = vec![false; tracker.get_metainfo().get_peices().len().div_ceil(8)]; Self { tracker, - connections: HashMap::new(), + peers: HashMap::new(), bitfield, } } pub async fn download(&mut self) -> Result<(), ClientError> { - println!("Starting download..."); - // self.connect_to_peers(30).await?; - self.connect_to_peers(1).await?; - self.send_message(Message::new(MessageId::Bitfield, &self.bitfield))?; - self.send_message(Message::new(MessageId::Interested, &Vec::new()))?; - Ok(()) - } - - fn send_message(&mut self, message: Message) -> Result<(), ClientError> { - let serialized_message = message.serialize(); - println!("Sending message: {:?}", serialized_message); - for (_, stream) in self.connections.iter_mut() { - stream.write_all(&serialized_message).map_err(|e| { - ClientError::SendMessageError(SendMessageError { - peer: Peer { - peer_id: None, - addr: stream.peer_addr().unwrap(), - }, - message: message.clone(), - error: e.to_string(), - }) - })?; - - let mut response = vec![0u8; serialized_message.len()]; - stream.read_exact(&mut response).map_err(|e| { - ClientError::GetPeersError(format!("Failed to read response: {}", e)) - })?; - println!("Response: {:?}", response); - } + self.connect_to_peers(30).await?; Ok(()) } @@ -244,89 +228,120 @@ impl Client { Ok(peer_id) } + async fn initiate_handshake( + stream: &mut TcpStream, + handshake: &Vec, + info_hash: &Vec, + peer: &Peer, + ) -> Result, ClientError> { + stream.write_all(handshake).await.map_err(|e| { + ClientError::HandshakeError(HandshakeError { + peer: peer.clone(), + handshake: handshake.to_vec(), + status: HandshakePhase::Send, + message: format!("Failed to send handshake: {}", e), + }) + })?; + + let mut response = vec![0u8; HANDSHAKE_LEN]; + stream.read_exact(&mut response).await.map_err(|e| { + ClientError::HandshakeError(HandshakeError { + peer: peer.clone(), + handshake: handshake.to_vec(), + status: HandshakePhase::Receive, + message: format!("Failed to receive handshake: {}", e), + }) + })?; + + Self::validate_handshake(&response, info_hash) + } + async fn connect_to_peers(&mut self, min_connections: usize) -> Result<(), ClientError> { println!("Connecting to peers..."); - while self.connections.len() < min_connections { - let mut handles = Vec::new(); + + while self.peers.len() < min_connections { + let mut handles = JoinSet::new(); for peer in self .tracker .get_peers() .await .map_err(|_| ClientError::GetPeersError(String::from("Failed to get peers")))? { - if self.connections.len() >= min_connections { - return Ok(()); - } - let handshake = self.get_handshake()?; let info_hash = self.tracker.get_metainfo().get_info_hash().map_err(|_| { ClientError::GetPeersError(String::from("Failed to get info hash")) })?; - let handle = tokio::spawn(async move { - match TcpStream::connect_timeout(&peer.addr, Duration::new(5, 0)) { - Ok(mut stream) => { - stream.write_all(&handshake).map_err(|e| { - ClientError::HandshakeError(HandshakeError { - peer: peer.clone(), - handshake: handshake.clone(), - status: HandshakePhase::Send, - message: e.to_string(), - }) - })?; - - let mut handshake_response = [0u8; HANDSHAKE_LEN]; - stream.read_exact(&mut handshake_response).map_err(|e| { - ClientError::HandshakeError(HandshakeError { - peer: peer.clone(), - handshake: handshake_response.to_vec(), - status: HandshakePhase::Receive, - message: e.to_string(), - }) - })?; - - let peer_id = - Client::validate_handshake(&handshake_response, &info_hash)?; - - Ok((peer_id, stream)) + handles.spawn(async move { + let mut stream = match timeout( + Duration::from_secs(5), + TcpStream::connect(peer.addr), + ) + .await + { + Ok(Ok(stream)) => stream, + Ok(Err(e)) => { + return Err(ClientError::GetPeersError(format!( + "Failed to connect to peer: {}", + e + ))) } - Err(_) => Err(ClientError::GetPeersError(format!( - "Failed to connect to peer {}", - peer.addr - ))), - } + Err(_) => { + return Err(ClientError::GetPeersError(format!( + "Failed to connect to peer: {} - timed out", + peer.addr + ))) + } + }; + + let peer_id = + Self::initiate_handshake(&mut stream, &handshake, &info_hash, &peer) + .await?; + + Ok((peer_id, stream)) }); - handles.push(handle); } - for handle in handles { - if self.connections.len() >= min_connections { - handle.abort(); - } else { - match handle - .await - .map_err(|e| ClientError::GetPeersError(String::from(e.to_string())))? - { - Ok((peer_id, stream)) => { - println!( - "Connected to peer: {} at {}", - String::from_utf8_lossy(&peer_id), - stream - .peer_addr() - .map(|addr| addr.to_string()) - .unwrap_or("Unknown".to_string()) - ); - self.connections.insert(peer_id, stream); - } - Err(e) => { - #[cfg(debug_assertions)] - eprintln!("{}", e); - } + while let Some(handle) = handles.join_next().await { + let conection_result = + handle.map_err(|e| ClientError::GetPeersError(format!("{}", e)))?; + + match conection_result { + Ok((peer_id, stream)) => { + println!( + "Connected to peer: {:?}", + stream.peer_addr().map_err(|e| { + ClientError::GetPeersError(format!( + "Failed to get peer address: {}", + e + )) + })? + ); + self.peers.insert( + peer_id.clone(), + PeerState { + peer_id, + connection: stream, + bitfield: vec![ + false; + self.tracker.get_metainfo().get_peices().len() + ], + am_choking: true, + am_interested: false, + peer_choking: true, + peer_interested: false, + }, + ); + } + Err(e) => { + #[cfg(debug_assertions)] + eprintln!("{}", e); } } } } + println!("Connected to {} peers", self.peers.len()); Ok(()) } }