Skip to content

Commit

Permalink
add messages and bitfield
Browse files Browse the repository at this point in the history
  • Loading branch information
menghaoyu2002 committed May 31, 2024
1 parent e27b6bb commit d730e43
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 10 deletions.
132 changes: 132 additions & 0 deletions src/client/bitfield.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
use core::fmt;
use std::fmt::{Debug, Display, Formatter};

#[derive(Debug)]
pub struct Bitfield {
bitfield: Vec<bool>,
}

#[derive(Debug, PartialEq, Eq)]
pub struct OutOfBoundsError {
index: usize,
len: usize,
}

impl Display for OutOfBoundsError {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"Index out of bounds: index = {}, len = {}",
self.index, self.len
)
}
}

impl Bitfield {
pub fn new(size: usize) -> Self {
let bitfield = vec![false; size];
Self { bitfield }
}

pub fn len(&self) -> usize {
self.bitfield.len()
}

pub fn set(&mut self, index: usize, value: bool) -> Result<(), OutOfBoundsError> {
if index >= self.bitfield.len() {
return Err(OutOfBoundsError {
index,
len: self.len(),
});
}
self.bitfield[index] = value;
Ok(())
}

pub fn is_set(&self, index: usize) -> Result<bool, OutOfBoundsError> {
if index >= self.bitfield.len() {
return Err(OutOfBoundsError {
index,
len: self.len(),
});
}

Ok(self.bitfield[index])
}

pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::new();
for chunk in self.bitfield.chunks(8) {
let mut byte = 0;
for (i, &bit) in chunk.iter().enumerate() {
if bit {
byte |= 1 << (7 - i);
}
}
bytes.push(byte);
}
bytes
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_bitfield() {
let mut bitfield = Bitfield::new(10);
assert_eq!(bitfield.len(), 10);
assert_eq!(bitfield.is_set(0).unwrap(), false);
assert_eq!(bitfield.is_set(1).unwrap(), false);
assert_eq!(bitfield.is_set(2).unwrap(), false);
assert_eq!(bitfield.is_set(3).unwrap(), false);
assert_eq!(bitfield.is_set(4).unwrap(), false);
assert_eq!(bitfield.is_set(5).unwrap(), false);
assert_eq!(bitfield.is_set(6).unwrap(), false);
assert_eq!(bitfield.is_set(7).unwrap(), false);
assert_eq!(bitfield.is_set(8).unwrap(), false);
assert_eq!(bitfield.is_set(9).unwrap(), false);

bitfield.set(0, true).unwrap();
bitfield.set(1, true).unwrap();
bitfield.set(2, true).unwrap();
bitfield.set(3, true).unwrap();
bitfield.set(4, true).unwrap();
bitfield.set(5, true).unwrap();
bitfield.set(6, true).unwrap();
bitfield.set(7, true).unwrap();
bitfield.set(8, true).unwrap();
bitfield.set(9, true).unwrap();

assert_eq!(bitfield.is_set(0).unwrap(), true);
assert_eq!(bitfield.is_set(1).unwrap(), true);
assert_eq!(bitfield.is_set(2).unwrap(), true);
assert_eq!(bitfield.is_set(3).unwrap(), true);
assert_eq!(bitfield.is_set(4).unwrap(), true);
assert_eq!(bitfield.is_set(5).unwrap(), true);
assert_eq!(bitfield.is_set(6).unwrap(), true);
assert_eq!(bitfield.is_set(7).unwrap(), true);
assert_eq!(bitfield.is_set(8).unwrap(), true);
assert_eq!(bitfield.is_set(9).unwrap(), true);

let bytes = bitfield.to_bytes();
assert_eq!(bytes, vec![0b11111111, 0b11000000]);

bitfield.set(7, false).unwrap();
bitfield.set(3, false).unwrap();

assert_eq!(
bitfield.set(23, false),
Err(OutOfBoundsError { index: 23, len: 10 })
);

assert_eq!(
bitfield.is_set(23),
Err(OutOfBoundsError { index: 23, len: 10 })
);

let bytes = bitfield.to_bytes();
assert_eq!(bytes, vec![0b11101110, 0b11000000]);
}
}
96 changes: 86 additions & 10 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ use tokio::{
time::timeout,
};

mod bitfield;

use crate::tracker::{Peer, Tracker};

use self::bitfield::Bitfield;

const PSTR: &[u8; 19] = b"BitTorrent protocol";
const HANDSHAKE_LEN: usize = 49 + PSTR.len();
const REQUEST_LEN: u32 = 2 << 14;
Expand Down Expand Up @@ -69,7 +73,7 @@ pub struct HandshakeError {
}

pub struct SendMessageError {
peer: Peer,
peer_id: Vec<u8>,
message: Message,
error: String,
}
Expand Down Expand Up @@ -100,7 +104,7 @@ impl Display for ClientError {
ClientError::SendMessageError(e) => write!(
f,
"SendMessageError: Peer: {}, Message: {:02x?}, Error: {}",
e.peer,
String::from_utf8_lossy(&e.peer_id),
e.message.serialize(),
e.error
),
Expand Down Expand Up @@ -145,7 +149,7 @@ impl Clone for Message {
struct PeerState {
peer_id: Vec<u8>,
connection: TcpStream,
bitfield: Vec<bool>,
bitfield: Bitfield,

am_choking: bool,
am_interested: bool,
Expand All @@ -156,13 +160,12 @@ struct PeerState {
pub struct Client {
tracker: Tracker,
peers: HashMap<Vec<u8>, PeerState>,
bitfield: Vec<bool>,
bitfield: Bitfield,
}

impl Client {
pub fn new(tracker: Tracker) -> Self {
// divide the number of pieces by 8 to get the number of bytes needed to represent the bitfield
let bitfield = vec![false; tracker.get_metainfo().get_peices().len().div_ceil(8)];
let bitfield = Bitfield::new(tracker.get_metainfo().get_peices().len());
Self {
tracker,
peers: HashMap::new(),
Expand All @@ -175,6 +178,80 @@ impl Client {
Ok(())
}

async fn initialize_bitfields(&mut self, peer_id: &Vec<u8>) -> Result<(), ClientError> {
self.send_message(
peer_id,
Message::new(MessageId::Bitfield, &self.bitfield.to_bytes()),
)
.await?;

Ok(())
}

async fn send_message(
&mut self,
peer_id: &Vec<u8>,
message: Message,
) -> Result<(), ClientError> {
let peer = self.peers.get_mut(peer_id).ok_or_else(|| {
ClientError::SendMessageError(SendMessageError {
peer_id: peer_id.clone(),
message: message.clone(),
error: "Peer not found".to_string(),
})
})?;

peer.connection
.write_all(&message.serialize())
.await
.map_err(|e| {
ClientError::SendMessageError(SendMessageError {
peer_id: peer_id.clone(),
message: message.clone(),
error: format!("Failed to send message: {}", e),
})
})?;

Ok(())
}

async fn receive_message(&mut self, peer_id: &Vec<u8>) -> Result<Message, ClientError> {
let peer = self.peers.get_mut(peer_id).ok_or_else(|| {
ClientError::SendMessageError(SendMessageError {
peer_id: peer_id.clone(),
message: Message::new(MessageId::Choke, &Vec::new()),
error: "Peer not found".to_string(),
})
})?;

let mut len = [0u8; 4];
peer.connection.read_exact(&mut len).await.map_err(|e| {
ClientError::SendMessageError(SendMessageError {
peer_id: peer_id.clone(),
message: Message::new(MessageId::Choke, &Vec::new()),
error: format!("Failed to read message length: {}", e),
})
})?;

let len = u32::from_be_bytes(len);
let mut message = vec![0u8; len as usize];
peer.connection
.read_exact(&mut message)
.await
.map_err(|e| {
ClientError::SendMessageError(SendMessageError {
peer_id: peer_id.clone(),
message: Message::new(MessageId::Choke, &Vec::new()),
error: format!("Failed to read message: {}", e),
})
})?;

let id = message[0];
let payload = message[1..].to_vec();

Ok(Message { len, id, payload })
}

fn get_handshake(&self) -> Result<Vec<u8>, ClientError> {
let mut handshake = Vec::new();

Expand Down Expand Up @@ -322,10 +399,9 @@ impl Client {
PeerState {
peer_id,
connection: stream,
bitfield: vec![
false;
self.tracker.get_metainfo().get_peices().len()
],
bitfield: Bitfield::new(
self.tracker.get_metainfo().get_peices().len(),
),
am_choking: true,
am_interested: false,
peer_choking: true,
Expand Down

0 comments on commit d730e43

Please sign in to comment.