Skip to content

Commit

Permalink
fix deadlock kinda
Browse files Browse the repository at this point in the history
  • Loading branch information
menghaoyu2002 committed Jun 1, 2024
1 parent 9e03223 commit aedbbe4
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 46 deletions.
94 changes: 51 additions & 43 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{
time::Duration,
};

use chrono::{DateTime, Utc};
use futures::future::join_all;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
Expand Down Expand Up @@ -102,8 +103,9 @@ struct PeerState {
peer_id: Vec<u8>,
stream: Arc<Mutex<TcpStream>>,
bitfield: Option<Bitfield>,
send_queue: Arc<Mutex<VecDeque<(Vec<u8>, Message)>>>,
receive_queue: Arc<Mutex<VecDeque<(Vec<u8>, Message)>>>,
send_queue: Arc<Mutex<VecDeque<Message>>>,
receive_queue: Arc<Mutex<VecDeque<Message>>>,
last_sent: DateTime<Utc>,

am_choking: bool,
am_interested: bool,
Expand All @@ -118,6 +120,8 @@ impl PeerState {
stream: Arc::new(Mutex::new(stream)),
send_queue: Arc::new(Mutex::new(VecDeque::new())),
receive_queue: Arc::new(Mutex::new(VecDeque::new())),
last_sent: Utc::now(),

bitfield: None,
am_choking: true,
am_interested: false,
Expand All @@ -144,35 +148,36 @@ impl Client {
}

pub async fn download(&mut self) -> Result<(), ClientError> {
let peer_ids = self.connect_to_peers(8).await?;
let peer_ids = self.connect_to_peers(30).await?;

let mut handles = Vec::new();
peer_ids.iter().for_each(|id| {
handles.push(self.send_messages(id));
handles.push(self.retrieve_messages(id));
});
for id in peer_ids {
handles.push(self.send_messages(&id).await);
handles.push(self.retrieve_messages(&id));
}

join_all(handles).await;

Ok(())
}

fn retrieve_messages(&self, peer_id: &Vec<u8>) -> JoinHandle<()> {
fn retrieve_messages(&self, peer_id: &Vec<u8>) -> JoinHandle<Vec<u8>> {
let peers = Arc::clone(&self.peers);
let id = peer_id.clone();

tokio::spawn(async move {
let id_to_peer = peers.read().await;
let peer = id_to_peer.get(&id).unwrap();
loop {
let id_to_peer = peers.read().await;
let Some(peer) = id_to_peer.get(&id) else {
break;
};
let mut stream = peer.stream.lock().await;
let Ok(message) = receive_message(&mut stream).await else {
#[cfg(debug_assertions)]
eprintln!(
"Failed to receive message from peer: {}",
String::from_utf8_lossy(&id)
);
peers.write().await.remove(&id);
break;
};

Expand All @@ -182,52 +187,53 @@ impl Client {
String::from_utf8_lossy(&id)
);

peer.receive_queue
.lock()
.await
.push_back((peer.peer_id.clone(), message));
peer.receive_queue.lock().await.push_back(message);
}
// peers.write().await.remove(&id);
id
})
}

fn send_messages(&self, peer_id: &Vec<u8>) -> JoinHandle<()> {
async fn send_messages(&self, peer_id: &Vec<u8>) -> JoinHandle<Vec<u8>> {
let peers = Arc::clone(&self.peers);
let id = peer_id.clone();

tokio::spawn(async move {
let id_to_peer = peers.read().await;
let peer = id_to_peer.get(&id).unwrap();
loop {
let Some((peer_id, message)) = peer.send_queue.lock().await.pop_front() else {
yield_now().await;
continue;
let id_to_peer = peers.read().await;
let Some(peer) = id_to_peer.get(&id) else {
break;
};
let message = match peer.send_queue.lock().await.pop_front() {
Some(m) => m,
None => {
if (peer.last_sent - Utc::now()).num_seconds() > 120 {
Message::new(MessageId::KeepAlive, &Vec::new())
} else {
yield_now().await;
continue;
}
}
};

println!(
"Sending message {} to peer: {}",
message.get_id(),
String::from_utf8_lossy(&peer_id)
String::from_utf8_lossy(&id)
);
let id_to_peer = peers.read().await;
let Some(peer) = id_to_peer.get(&peer_id) else {
#[cfg(debug_assertions)]
eprintln!(
"Failed to get peer with id: {:?}",
String::from_utf8_lossy(&peer_id)
);
break;
};

let mut stream = peer.stream.lock().await;
if let Err(e) = send_message(&mut stream, message).await {
#[cfg(debug_assertions)]
eprintln!(
"Failed to send message to peer {:?}: {}",
String::from_utf8_lossy(&peer_id),
String::from_utf8_lossy(&id),
e.to_string()
);
break;
}
}
// peers.write().await.remove(&id);
id
})
}

Expand Down Expand Up @@ -320,12 +326,14 @@ impl Client {
let mut new_peers = Vec::new();
while self.peers.read().await.len() < min_connections {
let mut handles = JoinSet::new();
for peer in self
.tracker
.get_peers()
.await
.map_err(|_| ClientError::GetPeersError(String::from("Failed to get peers")))?
for peer in
self.tracker.get_peers().await.map_err(|e| {
ClientError::GetPeersError(format!("Failed to get peers: {}", e))
})?
{
if self.peers.read().await.len() >= min_connections {
break;
}
let handshake = self.get_handshake()?;
let info_hash = self.tracker.get_metainfo().get_info_hash().map_err(|_| {
ClientError::GetPeersError(String::from("Failed to get info hash"))
Expand Down Expand Up @@ -367,10 +375,11 @@ impl Client {
}

let peer_state = PeerState::new(&peer_id, stream);
peer_state.send_queue.lock().await.push_back((
peer_id.clone(),
Message::new(MessageId::Bitfield, &bitfield),
));
peer_state
.send_queue
.lock()
.await
.push_back(Message::new(MessageId::Bitfield, &bitfield));
peers.write().await.insert(peer_id.clone(), peer_state);

println!("Connected to peer: {:?}", peer.addr);
Expand All @@ -386,7 +395,6 @@ impl Client {
match conection_result {
Ok(peer_id) => {
new_peers.push(peer_id);
// self.retrieve_messages(&peer_id);
}
Err(e) => {
// #[cfg(debug_assertions)]
Expand Down
33 changes: 30 additions & 3 deletions src/tracker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ use std::{
fmt::{Debug, Display},
net::{IpAddr, Ipv4Addr, SocketAddr},
str::FromStr,
time::Duration,
};

use chrono::{DateTime, Utc};
use rand::Rng;
use tokio::time::sleep;

use crate::{
bencode::{BencodeString, BencodeValue},
Expand All @@ -27,6 +30,7 @@ impl Debug for InvalidResponseError {
}
}

#[derive(Debug)]
pub enum TrackerError {
InvalidMetainfo,
InvalidInfoHash,
Expand All @@ -36,7 +40,7 @@ pub enum TrackerError {
ResponseParseError(String),
}

impl Debug for TrackerError {
impl Display for TrackerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TrackerError::InvalidMetainfo => write!(f, "InvalidMetainfo"),
Expand All @@ -53,6 +57,9 @@ impl Debug for TrackerError {
pub struct Tracker {
metainfo: Metainfo,
peer_id: Vec<u8>,

last_announce: Option<DateTime<Utc>>,
last_interval: Option<i64>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -109,6 +116,8 @@ impl Tracker {
Ok(Self {
metainfo,
peer_id: Tracker::get_peer_id(),
last_announce: None,
last_interval: None,
})
}

Expand All @@ -120,17 +129,34 @@ impl Tracker {
self.peer_id.clone()
}

pub async fn get_peers(&self) -> Result<Peers, TrackerError> {
pub async fn get_peers(&mut self) -> Result<Peers, TrackerError> {
// if let Some(last_announce) = self.last_announce {
// if let Some(last_interval) = self.last_interval {
// let elapsed = Utc::now()
// .signed_duration_since(last_announce)
// .num_seconds();
// println!("{}, {}", last_interval, elapsed);
// if elapsed < last_interval {
// sleep(Duration::from_secs((last_interval - elapsed) as u64)).await;
// }
// }
// }

let response = self.get_announce().await?;
let peers = match response {
TrackerResponse::Success(success_response) => success_response.peers,
TrackerResponse::Success(success_response) => {
self.last_interval = Some(success_response.interval);
success_response.peers
}
TrackerResponse::Failure(failure_response) => {
return Err(TrackerError::GetPeersFailure(
failure_response.failure_reason,
))
}
};

self.last_announce = Some(Utc::now());

Ok(peers)
}

Expand Down Expand Up @@ -333,6 +359,7 @@ impl Tracker {
.as_str(),
);
url.push_str("&port=6881");
url.push_str("&numwant=100");

println!("GET {}", &url);
let response = reqwest::get(&url)
Expand Down

0 comments on commit aedbbe4

Please sign in to comment.