diff --git a/src/client/file_manager.rs b/src/client/file_manager.rs index 950d43f..104a70e 100644 --- a/src/client/file_manager.rs +++ b/src/client/file_manager.rs @@ -3,6 +3,8 @@ use std::{ os::unix::fs::FileExt, }; +use sha1::Digest; + use crate::metainfo::Info; #[derive(Debug)] @@ -60,4 +62,22 @@ impl FileManager { accumulated_size += *file_size; } } + + pub fn verify_piece(&self, piece_index: usize, hash: &[u8]) -> bool { + let offset = self.piece_length * piece_index as u64; + let mut file_index = 0; + let mut accumulated_size = 0; + while offset >= self.files[file_index].1 + accumulated_size { + accumulated_size += self.files[file_index].1; + file_index += 1; + } + let file = &self.files[file_index].0; + let mut buf = vec![0; self.piece_length as usize]; + file.read_at(&mut buf, offset).unwrap(); + + let mut hasher = sha1::Sha1::new(); + hasher.update(&buf); + let result = hasher.finalize().to_vec(); + hash == result.as_slice() + } } diff --git a/src/client/message.rs b/src/client/message.rs index a2fec8e..40747e5 100644 --- a/src/client/message.rs +++ b/src/client/message.rs @@ -224,7 +224,6 @@ pub async fn receive_message(stream: &TcpStream) -> Result Result { bytes_read += n; } - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + yield_now().await; + } Err(e) => { return Err(ReceiveError::ReceiveError(ReceiveMessageError { error: format!("Failed to read message: {}", e), })); } } - yield_now().await; } let id = message[0]; let payload = message[1..].to_vec(); diff --git a/src/client/mod.rs b/src/client/mod.rs index cb7e938..a2cd5df 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -153,12 +153,13 @@ impl Client { } } - pub async fn download(&mut self) -> Result<(), ClientError> { - self.connect_to_peers(30).await?; + pub async fn download(&mut self, num_peers: u32) -> Result<(), ClientError> { + self.connect_to_peers(num_peers).await?; let _ = tokio::join!( self.send_messages(), self.retrieve_messages(), + self.retrieve_messages(), self.process_messages(), self.keep_alive(), ); @@ -179,6 +180,7 @@ impl Client { tokio::spawn(async move { while *total_downloaded.lock().await < total_length { let Some((peer_id, message)) = receive_queue.lock().await.pop_front() else { + yield_now().await; continue; }; @@ -314,33 +316,24 @@ impl Client { as f64, ); - let peer = peer.lock().await; - if piece_scheduler - .read() - .await - .is_interested(peer.bitfield.as_ref().unwrap()) - { - if peer.peer_choking { + if peer.lock().await.peer_choking { + send_queue.lock().await.push_back(( + peer_id.clone(), + Message::new(MessageId::Interested, &Vec::new()), + )); + } else { + if let Some((index, begin, length)) = + piece_scheduler.write().await.schedule_piece(&peer_id) + { + let mut payload = Vec::new(); + payload.extend_from_slice(&index.to_be_bytes()); + payload.extend_from_slice(&begin.to_be_bytes()); + payload.extend_from_slice(&length.to_be_bytes()); send_queue.lock().await.push_back(( peer_id.clone(), - Message::new(MessageId::Interested, &Vec::new()), + Message::new(MessageId::Request, &payload), )); } else { - if let Some((index, begin, length)) = - piece_scheduler.write().await.schedule_piece(&peer_id) - { - let mut payload = Vec::new(); - payload.extend_from_slice(&index.to_be_bytes()); - payload.extend_from_slice(&begin.to_be_bytes()); - payload.extend_from_slice(&length.to_be_bytes()); - send_queue.lock().await.push_back(( - peer_id.clone(), - Message::new(MessageId::Request, &payload), - )); - } - } - } else { - if !peer.peer_choking { send_queue.lock().await.push_back(( peer_id.clone(), Message::new(MessageId::NotInterested, &Vec::new()), @@ -403,6 +396,7 @@ impl Client { .push_back((peer_id.clone(), message)); } Err(ReceiveError::WouldBlock) => { + yield_now().await; continue; } Err(e) => { @@ -441,6 +435,7 @@ impl Client { tokio::spawn(async move { while *total_downloaded.lock().await < total_length { let Some((peer_id, message)) = send_queue.lock().await.pop_front() else { + yield_now().await; continue; }; @@ -568,9 +563,9 @@ impl Client { Self::validate_handshake(&response, info_hash) } - async fn connect_to_peers(&mut self, min_connections: usize) -> Result<(), ClientError> { + async fn connect_to_peers(&mut self, min_connections: u32) -> Result<(), ClientError> { println!("Connecting to peers..."); - while self.peers.read().await.len() < min_connections { + while self.peers.read().await.len() < min_connections as usize { let mut handles = JoinSet::new(); for peer in self.tracker.get_peers().await.map_err(|e| { @@ -612,7 +607,7 @@ impl Client { Self::initiate_handshake(&mut stream, &handshake, &info_hash, &peer) .await?; - if peers.read().await.len() >= min_connections { + if peers.read().await.len() >= min_connections as usize { return Err(ClientError::GetPeersError(String::from( "Already connected to minimum number of peers", ))); diff --git a/src/client/pieces.rs b/src/client/pieces.rs index 4146c56..5843ee2 100644 --- a/src/client/pieces.rs +++ b/src/client/pieces.rs @@ -55,12 +55,8 @@ impl PieceScheduler { for (i, hash) in piece_hashes.iter().enumerate() { let mut blocks = Vec::new(); let mut offset: u32 = 0; - while offset < (piece_length as u32).min(remaining_size) { - let length = if remaining_size < BLOCK_SIZE { - remaining_size - } else { - BLOCK_SIZE - }; + while offset < piece_length as u32 && remaining_size > 0 { + let length = BLOCK_SIZE.min(remaining_size); let block = Block { begin: offset, length, @@ -128,6 +124,19 @@ impl PieceScheduler { let block = &mut piece.blocks[block_bucket]; self.file_manager.save_block(index, begin, data); block.completed = true; + if piece.blocks.iter().all(|b| b.completed) { + println!("Piece {} completed", piece.index); + piece.completed = true; + self.any_complete = true; + + // if !self.file_manager.verify_piece(index, &piece.hash) { + // println!("Piece {} failed verification", piece.index); + // for block in &mut piece.blocks { + // block.completed = false; + // } + // piece.completed = false; + // } + } } pub fn add_peer_count(&mut self, peer_id: &Vec, bitfield: &Bitfield) { @@ -178,8 +187,8 @@ impl PieceScheduler { (piece.index as u32, block.begin, block.length) }); - if let Some((piece_index, block_begin, _)) = &request { - self.set_requested(*piece_index as usize, *block_begin); + if let Some((piece_index, block_begin, _)) = request { + self.set_requested(piece_index as usize, block_begin); } request diff --git a/src/main.rs b/src/main.rs index 5f9d0f1..e58c61d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,9 @@ struct Args { #[arg(short, long)] output_dir: String, + + #[arg(short, long, default_value_t = 30)] + num_peers: u32, } fn read_file(filename: &str) -> Result, std::io::Error> { @@ -43,7 +46,7 @@ async fn main() { let tracker = Tracker::new(bencode_value).expect("Failed to create tracker"); let mut client = Client::new(tracker, args.output_dir); - match client.download().await { + match client.download(args.num_peers).await { Ok(()) => println!("Download completed"), Err(e) => eprintln!("Error downloading: {}", e), }