Skip to content

Commit

Permalink
add peices schedular
Browse files Browse the repository at this point in the history
  • Loading branch information
menghaoyu2002 committed Jun 5, 2024
1 parent 33cea86 commit e6916a4
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 16 deletions.
4 changes: 4 additions & 0 deletions src/client/bitfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ impl Bitfield {
self.bitfield.len()
}

pub fn iter(&self) -> std::slice::Iter<bool> {
self.bitfield.iter()
}

pub fn set(&mut self, index: usize, value: bool) -> Result<(), OutOfBoundsError> {
if index >= self.bitfield.len() {
return Err(OutOfBoundsError {
Expand Down
38 changes: 24 additions & 14 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{
};

use chrono::{DateTime, Utc};
use pieces::Pieces;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
Expand All @@ -16,6 +17,7 @@ use tokio::{

mod bitfield;
mod message;
mod pieces;

use crate::{
client::message::{receive_message, send_message},
Expand All @@ -29,7 +31,6 @@ use self::{

const PSTR: &[u8; 19] = b"BitTorrent protocol";
const HANDSHAKE_LEN: usize = 49 + PSTR.len();
const REQUEST_LEN: u32 = 2 << 14;

pub struct PeerConnectionError {
pub peer: Peer,
Expand Down Expand Up @@ -129,43 +130,46 @@ impl PeerState {
pub struct Client {
tracker: Tracker,
peers: Arc<RwLock<HashMap<Vec<u8>, Arc<RwLock<PeerState>>>>>,
bitfield: Bitfield,
pieces: Arc<RwLock<Pieces>>,
send_queue: Arc<Mutex<VecDeque<(Vec<u8>, Message)>>>,
receive_queue: Arc<Mutex<VecDeque<(Vec<u8>, Message)>>>,
}

impl Client {
pub fn new(tracker: Tracker) -> Self {
let bitfield = Bitfield::new(tracker.get_metainfo().get_peices().len());
let pieces = Pieces::new(&tracker.get_metainfo().info);
Self {
tracker,
peers: Arc::new(RwLock::new(HashMap::new())),
bitfield,
pieces: Arc::new(RwLock::new(pieces)),
send_queue: Arc::new(Mutex::new(VecDeque::new())),
receive_queue: Arc::new(Mutex::new(VecDeque::new())),
}
}

pub async fn download(&mut self) -> Result<(), ClientError> {
self.connect_to_peers(30).await?;
self.connect_to_peers(10).await?;

let _ = tokio::join!(
self.send_messages(),
self.retrieve_messages(),
self.keep_alive(),
self.process_messages(),
self.keep_alive(),
);

Ok(())
}

fn process_messages(&self) -> JoinHandle<()> {
async fn process_messages(&self) -> JoinHandle<()> {
let peers = Arc::clone(&self.peers);
let receive_queue = Arc::clone(&self.receive_queue);
let bitfield_len = self.bitfield.len();
let pieces = Arc::clone(&self.pieces);
let num_pieces = self.pieces.read().await.len();

tokio::spawn(async move {
loop {
let Some((peer_id, message)) = receive_queue.lock().await.pop_front() else {
yield_now().await;
continue;
};

Expand Down Expand Up @@ -200,7 +204,7 @@ impl Client {
let payload = message.get_payload();
let piece_index = u32::from_be_bytes(payload[0..4].try_into().unwrap());
if peer.write().await.bitfield.is_none() {
peer.write().await.bitfield = Some(Bitfield::new(bitfield_len));
peer.write().await.bitfield = Some(Bitfield::new(num_pieces));
};

should_remove = peer
Expand All @@ -211,13 +215,20 @@ impl Client {
.unwrap()
.set(piece_index as usize, true)
.is_err();

pieces
.write()
.await
.add_peer_have(&peer_id, piece_index as usize);
}
MessageId::Bitfield => {
let payload = message.get_payload();
if payload.len() * 8 < bitfield_len {
if payload.len() * 8 < num_pieces {
println!("Invalid bitfield length, disconnecting peer...");
should_remove = true;
} else {
let bitfield = Bitfield::from_bytes(payload, bitfield_len);
let bitfield = Bitfield::from_bytes(payload, num_pieces);
pieces.write().await.add_peer_count(&peer_id, &bitfield);
peer.write().await.bitfield = Some(bitfield);
}
}
Expand Down Expand Up @@ -275,6 +286,7 @@ impl Client {
.push_back((peer_id.clone(), message));
}
Err(ReceiveError::WouldBlock) => {
yield_now().await;
continue;
}
Err(e) => {
Expand Down Expand Up @@ -351,8 +363,6 @@ impl Client {
}
}
}

// yield_now().await;
}
})
}
Expand Down Expand Up @@ -451,7 +461,7 @@ impl Client {
let info_hash = self.tracker.get_metainfo().get_info_hash().map_err(|_| {
ClientError::GetPeersError(String::from("Failed to get info hash"))
})?;
let bitfield = self.bitfield.to_bytes();
let bitfield = self.pieces.read().await.to_bitfield().to_bytes();

let peers = Arc::clone(&mut self.peers);
let send_queue = Arc::clone(&self.send_queue);
Expand Down
167 changes: 167 additions & 0 deletions src/client/pieces.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
use std::collections::HashSet;

use crate::metainfo::Info;

use super::bitfield::Bitfield;

pub const BLOCK_SIZE: u32 = 2 << 13; // 16KB

#[derive(Debug)]
pub struct Block {
begin: u32,
length: u32,
data: Vec<u8>,
requested: bool,
}

#[derive(Debug)]
pub struct Piece {
index: usize,
blocks: Vec<Block>,
hash: Vec<u8>,
completed: bool,
peers: HashSet<Vec<u8>>,
}

#[derive(Debug)]
pub struct Pieces {
pieces: Vec<Piece>,
any_complete: bool,
}

impl Pieces {
pub fn new(info_dict: &Info) -> Self {
let (piece_hashes, piece_length, file_size) = match info_dict {
Info::SingleFile(info) => (
info.base_info.pieces.clone(),
info.base_info.piece_length,
info.length,
),
Info::MultiFile(info) => (
info.base_info.pieces.clone(),
info.base_info.piece_length,
info.files.iter().map(|f| f.length).sum(),
),
};

assert!(
piece_length as u32 % BLOCK_SIZE == 0,
"piece length must be a multiple of the block size"
);

let mut remaining_size = file_size as u32;
let mut pieces = Vec::new();
for (i, hash) in piece_hashes.iter().enumerate() {
let mut blocks = Vec::new();
let mut offset: u32 = 0;
while offset < (piece_length as u32).min(remaining_size) {
let length = if remaining_size < BLOCK_SIZE {
remaining_size
} else {
BLOCK_SIZE
};
let block = Block {
begin: offset,
length,
data: Vec::new(),
requested: false,
};
blocks.push(block);

remaining_size -= length;
offset += length;
}

let piece = Piece {
index: i,
blocks,
hash: hash.to_vec(),
completed: false,
peers: HashSet::new(),
};
pieces.push(piece);
}

Self {
pieces,
any_complete: false,
}
}

pub fn len(&self) -> usize {
self.pieces.len()
}

pub fn to_bitfield(&self) -> Bitfield {
let mut bitfield = Bitfield::new(self.len());
for piece in &self.pieces {
bitfield.set(piece.index, piece.completed).unwrap();
}
bitfield
}

fn get_rarest_noncompleted_piece(&self) -> &Piece {
self.pieces
.iter()
.filter(|p| !p.completed && p.blocks.iter().any(|b| !b.requested))
.min_by_key(|p| p.peers.len())
.unwrap()
}

pub fn set_requested(&mut self, index: usize, begin: u32) {
let piece = &mut self.pieces[index];

let block_bucket: usize = begin.div_ceil(BLOCK_SIZE).try_into().unwrap();
let block = &mut piece.blocks[block_bucket];
block.requested = true;
}

pub fn set_block(&mut self, index: usize, begin: u32, data: Vec<u8>) {
let piece = &mut self.pieces[index];

let block_bucket: usize = begin.div_ceil(BLOCK_SIZE).try_into().unwrap();
let block = &mut piece.blocks[block_bucket];
block.data = data;
}

pub fn add_peer_count(&mut self, peer_id: &Vec<u8>, bitfield: &Bitfield) {
for (i, bit) in bitfield.iter().enumerate() {
if *bit {
self.pieces[i].peers.insert(peer_id.clone());
}
}
}

pub fn add_peer_have(&mut self, peer_id: &Vec<u8>, i: usize) {
self.pieces[i].peers.insert(peer_id.clone());
}

pub fn remove_peer_count(&mut self, peer_id: &Vec<u8>, bitfield: &Bitfield) {
for (i, bit) in bitfield.iter().enumerate() {
if *bit {
self.pieces[i].peers.remove(peer_id);
}
}
}

pub fn get_block_to_download(&mut self) -> (usize, u32, u32) {
let piece = if !self.any_complete {
let pieces = self
.pieces
.iter()
.filter(|p| !p.completed && p.blocks.iter().any(|b| !b.requested))
.collect::<Vec<&Piece>>();
pieces[rand::random::<usize>() % pieces.len()]
} else {
self.get_rarest_noncompleted_piece()
};

let block = piece
.blocks
.iter()
.find(|b| !b.requested && b.data.is_empty())
.unwrap();

(piece.index, block.begin, block.length)
}
}
2 changes: 0 additions & 2 deletions src/tracker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ use std::{
fmt::{Debug, Display},
net::{IpAddr, Ipv4Addr, SocketAddr},
str::FromStr,
time::Duration,
};

use chrono::{DateTime, Utc};
use rand::Rng;
use tokio::time::sleep;

use crate::{
bencode::{BencodeString, BencodeValue},
Expand Down

0 comments on commit e6916a4

Please sign in to comment.