Skip to content

Commit

Permalink
add handshake validation
Browse files Browse the repository at this point in the history
  • Loading branch information
menghaoyu2002 committed May 26, 2024
1 parent c59ea02 commit 698e4e1
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 59 deletions.
149 changes: 92 additions & 57 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,62 @@
use std::{
fmt::Debug,
fmt::{Debug, Display},
io::{Read, Write},
net::TcpStream,
time::Duration,
};

use crate::tracker::{Peer, Tracker};

const HANDSHAKE_LEN: usize = 68;
const PSTR: &[u8; 19] = b"BitTorrent protocol";
const HANDSHAKE_LEN: usize = 49 + PSTR.len();

pub struct PeerConnectionError {
pub peer: Peer,
}

pub enum HandshakePhase {
Send,
Receive,
}

impl Display for HandshakePhase {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HandshakePhase::Send => write!(f, "Send"),
HandshakePhase::Receive => write!(f, "Receive"),
}
}
}

pub struct HandshakeError {
peer: Peer,
handshake: Vec<u8>,
status: HandshakePhase,
message: String,
}

pub enum ClientError {
ValidateHandshakeError(String),
GetPeersError(String),
HandshakeError,
HandshakeError(HandshakeError),
}

impl Debug for ClientError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ClientError::ValidateHandshakeError(e) => write!(f, "ValidateHandshakeError: {}", e),
ClientError::GetPeersError(e) => write!(f, "GetPeersError: {}", e),
ClientError::HandshakeError => write!(f, "HandshakeError"),
ClientError::HandshakeError(e) => write!(
f,
"HandshakeError: Peer: {}, Status: {}, Message: {} Handshake: {}",
e.peer,
e.status,
e.message,
e.handshake
.iter()
.map(|b| format!("{:02x}", b))
.collect::<String>()
),
}
}
}
Expand All @@ -45,16 +77,16 @@ impl Client {
fn get_handshake(&self) -> Result<Vec<u8>, ClientError> {
let mut handshake = Vec::new();

let pstr = b"BitTorrent protocol";
let info_hash = self
.tracker
.get_metainfo()
.get_info_hash()
.map_err(|_| ClientError::HandshakeError)?;
.map_err(|_| ClientError::GetPeersError(String::from("Failed to get info hash")))?;

let peer_id = self.tracker.peer_id();

handshake.push(pstr.len() as u8);
handshake.extend_from_slice(pstr);
handshake.push(PSTR.len() as u8);
handshake.extend_from_slice(PSTR);
handshake.extend_from_slice(&[0; 8]);
handshake.extend_from_slice(&info_hash);
handshake.extend_from_slice(peer_id.as_bytes());
Expand All @@ -63,7 +95,7 @@ impl Client {
Ok(handshake)
}

fn validate_handshake(&self, handshake: &[u8]) -> Result<String, ClientError> {
fn validate_handshake(handshake: &[u8], info_hash: &Vec<u8>) -> Result<String, ClientError> {
if handshake.len() != HANDSHAKE_LEN {
return Err(ClientError::ValidateHandshakeError(
"Invalid handshake length".to_string(),
Expand All @@ -83,26 +115,25 @@ impl Client {
));
}

if &handshake[20..28] != [0u8; 8] {
return Err(ClientError::ValidateHandshakeError(
"Invalid reserved bytes".to_string(),
));
}

let info_hash = self
.tracker
.get_metainfo()
.get_info_hash()
.map_err(|_| ClientError::HandshakeError)?;
// don't need to validate reserved bytes
// if &handshake[20..28] != [0; 8] {
// return Err(ClientError::ValidateHandshakeError(
// "Invalid reserved bytes".to_string(),
// ));
// }

if &handshake[28..48] != info_hash {
return Err(ClientError::ValidateHandshakeError(
"Invalid info hash".to_string(),
));
}

let peer_id = String::from_utf8(handshake[48..68].to_vec())
.map_err(|_| ClientError::HandshakeError)?;
let peer_id = String::from_utf8(handshake[48..68].to_vec()).map_err(|_| {
ClientError::ValidateHandshakeError(String::from(format!(
"Invalid peer id: {:?}",
String::from_utf8(handshake[48..68].to_vec())
)))
})?;

Ok(peer_id)
}
Expand All @@ -121,42 +152,41 @@ impl Client {
}

let handshake = self.get_handshake()?;
let info_hash = self.tracker.get_metainfo().get_info_hash().map_err(|_| {
ClientError::GetPeersError(String::from("Failed to get info hash"))
})?;

let handle = tokio::spawn(async move {
match TcpStream::connect_timeout(&peer.addr, Duration::new(5, 0)) {
Ok(mut stream) => {
println!("Connected to peer {}", peer.addr);
match stream.write_all(&handshake) {
Ok(_) => {
println!("Handshake sent to peer {}", peer.addr,);
}
Err(_) => {
return Err(format!(
"Failed to send handshake to peer {}",
peer.addr
))
}
}

// if self.validate_handshake(&handshake).is_err() {
// return Err(format!("Invalid handshake from peer {}", peer.addr));
// }

let mut buf = [0u8; HANDSHAKE_LEN];
match stream.read_exact(&mut buf) {
Ok(()) => {
println!("Received handshake from peer {}", peer.addr,);
}
Err(_) => {
return Err(format!(
"Failed to read handshake from peer {}",
peer.addr
))
}
}

Ok(stream)
stream.write_all(&handshake).map_err(|e| {
ClientError::HandshakeError(HandshakeError {
peer: peer.clone(),
handshake: handshake.clone(),
status: HandshakePhase::Send,
message: e.to_string(),
})
})?;

let mut handshake_response = [0u8; HANDSHAKE_LEN];
stream.read_exact(&mut handshake_response).map_err(|e| {
ClientError::HandshakeError(HandshakeError {
peer: peer.clone(),
handshake: handshake_response.to_vec(),
status: HandshakePhase::Receive,
message: e.to_string(),
})
})?;

let peer_id =
Client::validate_handshake(&handshake_response, &info_hash)?;

Ok((peer_id, stream))
}
Err(_) => Err(format!("Failed to connect to peer {}", peer.addr)),
Err(_) => Err(ClientError::GetPeersError(format!(
"Failed to connect to peer {}",
peer.addr
))),
}
});
handles.push(handle);
Expand All @@ -169,13 +199,18 @@ impl Client {

match handle
.await
.map_err(|e| ClientError::GetPeersError(e.to_string()))?
.map_err(|_| ClientError::GetPeersError(String::from("Failed to get peer")))?
{
Ok(stream) => {
Ok((peer_id, stream)) => {
println!(
"Connected to peer {} with id {}",
stream.peer_addr().unwrap(),
peer_id
);
self.connections.push(stream);
}
Err(e) => {
eprintln!("{}", e)
eprintln!("{:?}", e);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async fn main() {
let mut client = Client::new(tracker);

client
.connect_to_peers(1)
.connect_to_peers(10)
.await
.expect("Failed to connect to peers");
}
20 changes: 19 additions & 1 deletion src/tracker/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
fmt::Debug,
fmt::{Debug, Display},
net::{IpAddr, Ipv4Addr, SocketAddr},
str::FromStr,
};
Expand Down Expand Up @@ -61,6 +61,24 @@ pub struct Peer {
pub peer_id: Option<String>,
}

impl Clone for Peer {
fn clone(&self) -> Self {
Self {
addr: self.addr,
peer_id: self.peer_id.clone(),
}
}
}

impl Display for Peer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.peer_id {
Some(peer_id) => write!(f, "{}: {}", peer_id, self.addr),
None => write!(f, "{}", self.addr),
}
}
}

pub type Peers = Vec<Peer>;

#[derive(Debug)]
Expand Down

0 comments on commit 698e4e1

Please sign in to comment.