Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(iroh): remove flume from iroh-cli and iroh #2543

Merged
merged 11 commits into from
Aug 2, 2024
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions iroh-blobs/src/downloader/progress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ use parking_lot::Mutex;

use crate::{
get::{db::DownloadProgress, progress::TransferState},
util::progress::{FlumeProgressSender, IdGenerator, ProgressSendError, ProgressSender},
util::progress::{AsyncChannelProgressSender, IdGenerator, ProgressSendError, ProgressSender},
};

use super::DownloadKind;

/// The channel that can be used to subscribe to progress updates.
pub type ProgressSubscriber = FlumeProgressSender<DownloadProgress>;
pub type ProgressSubscriber = AsyncChannelProgressSender<DownloadProgress>;

/// Track the progress of downloads.
///
Expand Down
26 changes: 13 additions & 13 deletions iroh-blobs/src/downloader/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
get::{db::BlobId, progress::TransferState},
util::{
local_pool::LocalPool,
progress::{FlumeProgressSender, IdGenerator},
progress::{AsyncChannelProgressSender, IdGenerator},
},
};

Expand Down Expand Up @@ -276,13 +276,13 @@ async fn concurrent_progress() {
let hash = Hash::new([0u8; 32]);
let kind_1 = HashAndFormat::raw(hash);

let (prog_a_tx, prog_a_rx) = flume::bounded(64);
let prog_a_tx = FlumeProgressSender::new(prog_a_tx);
let (prog_a_tx, prog_a_rx) = async_channel::bounded(64);
let prog_a_tx = AsyncChannelProgressSender::new(prog_a_tx);
let req = DownloadRequest::new(kind_1, vec![peer]).progress_sender(prog_a_tx);
let handle_a = downloader.queue(req).await;

let (prog_b_tx, prog_b_rx) = flume::bounded(64);
let prog_b_tx = FlumeProgressSender::new(prog_b_tx);
let (prog_b_tx, prog_b_rx) = async_channel::bounded(64);
let prog_b_tx = AsyncChannelProgressSender::new(prog_b_tx);
let req = DownloadRequest::new(kind_1, vec![peer]).progress_sender(prog_b_tx);
let handle_b = downloader.queue(req).await;

Expand All @@ -292,21 +292,21 @@ async fn concurrent_progress() {
let mut state_b = TransferState::new(hash);
let mut state_c = TransferState::new(hash);

let prog1_a = prog_a_rx.recv_async().await.unwrap();
let prog1_b = prog_b_rx.recv_async().await.unwrap();
let prog1_a = prog_a_rx.recv().await.unwrap();
let prog1_b = prog_b_rx.recv().await.unwrap();
assert!(matches!(prog1_a, DownloadProgress::Found { hash, size: 100, ..} if hash == hash));
assert!(matches!(prog1_b, DownloadProgress::Found { hash, size: 100, ..} if hash == hash));

state_a.on_progress(prog1_a);
state_b.on_progress(prog1_b);
assert_eq!(state_a, state_b);

let (prog_c_tx, prog_c_rx) = flume::bounded(64);
let prog_c_tx = FlumeProgressSender::new(prog_c_tx);
let (prog_c_tx, prog_c_rx) = async_channel::bounded(64);
let prog_c_tx = AsyncChannelProgressSender::new(prog_c_tx);
let req = DownloadRequest::new(kind_1, vec![peer]).progress_sender(prog_c_tx);
let handle_c = downloader.queue(req).await;

let prog1_c = prog_c_rx.recv_async().await.unwrap();
let prog1_c = prog_c_rx.recv().await.unwrap();
assert!(matches!(&prog1_c, DownloadProgress::InitialState(state) if state == &state_a));
state_c.on_progress(prog1_c);

Expand All @@ -317,9 +317,9 @@ async fn concurrent_progress() {
res_b.unwrap();
res_c.unwrap();

let prog_a: Vec<_> = prog_a_rx.into_stream().collect().await;
let prog_b: Vec<_> = prog_b_rx.into_stream().collect().await;
let prog_c: Vec<_> = prog_c_rx.into_stream().collect().await;
let prog_a: Vec<_> = prog_a_rx.collect().await;
let prog_b: Vec<_> = prog_b_rx.collect().await;
let prog_c: Vec<_> = prog_c_rx.collect().await;

assert_eq!(prog_a.len(), 1);
assert_eq!(prog_b.len(), 1);
Expand Down
128 changes: 128 additions & 0 deletions iroh-blobs/src/util/progress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,98 @@ impl<T: Send + Sync + 'static> ProgressSender for FlumeProgressSender<T> {
}
}

/// A progress sender that uses an async channel.
pub struct AsyncChannelProgressSender<T> {
sender: async_channel::Sender<T>,
id: std::sync::Arc<std::sync::atomic::AtomicU64>,
}

impl<T> std::fmt::Debug for AsyncChannelProgressSender<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncChannelProgressSender")
.field("id", &self.id)
.field("sender", &self.sender)
.finish()
}
}

impl<T> Clone for AsyncChannelProgressSender<T> {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
id: self.id.clone(),
}
}
}

impl<T> AsyncChannelProgressSender<T> {
/// Create a new progress sender from an async channel sender.
pub fn new(sender: async_channel::Sender<T>) -> Self {
Self {
sender,
id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)),
}
}

/// Returns true if `other` sends on the same `async_channel` channel as `self`.
pub fn same_channel(&self, other: &AsyncChannelProgressSender<T>) -> bool {
same_channel(&self.sender, &other.sender)
}
}

/// Given a value that is aligned and sized like a pointer, return the value of
/// the pointer as a usize.
fn get_as_ptr<T>(value: &T) -> Option<usize> {
rklaehn marked this conversation as resolved.
Show resolved Hide resolved
use std::mem;
if mem::size_of::<T>() == std::mem::size_of::<usize>()
&& mem::align_of::<T>() == mem::align_of::<usize>()
{
// SAFETY: size and alignment requirements are checked and met
unsafe { Some(mem::transmute_copy(value)) }
} else {
None
}
}

fn same_channel<T>(a: &async_channel::Sender<T>, b: &async_channel::Sender<T>) -> bool {
// This relies on async_channel::Sender being just a newtype wrapper around
// an Arc<Channel<T>>, so if two senders point to the same channel, the
// pointers will be the same.
get_as_ptr(a).unwrap() == get_as_ptr(b).unwrap()
}

impl<T> IdGenerator for AsyncChannelProgressSender<T> {
fn new_id(&self) -> u64 {
self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
}
}

impl<T: Send + Sync + 'static> ProgressSender for AsyncChannelProgressSender<T> {
type Msg = T;

async fn send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
self.sender
.send(msg)
.await
.map_err(|_| ProgressSendError::ReceiverDropped)
}

fn try_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
match self.sender.try_send(msg) {
Ok(_) => Ok(()),
Err(async_channel::TrySendError::Full(_)) => Ok(()),
Err(async_channel::TrySendError::Closed(_)) => Err(ProgressSendError::ReceiverDropped),
}
}

fn blocking_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
match self.sender.send_blocking(msg) {
Ok(_) => Ok(()),
Err(_) => Err(ProgressSendError::ReceiverDropped),
}
}
}

/// An error that can occur when sending progress messages.
///
/// Really the only error that can occur is if the receiver is dropped.
Expand Down Expand Up @@ -628,3 +720,39 @@ impl<W: AsyncSliceWriter + 'static, F: Fn(u64, usize) -> io::Result<()> + 'stati
self.0.set_len(size).await
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use super::*;

#[test]
fn get_as_ptr_works() {
struct Wrapper(Arc<u64>);
let x = Wrapper(Arc::new(1u64));
assert_eq!(
get_as_ptr(&x).unwrap(),
Arc::as_ptr(&x.0) as usize - 2 * std::mem::size_of::<usize>()
);
}

#[test]
fn get_as_ptr_wrong_use() {
struct Wrapper(#[allow(dead_code)] u8);
let x = Wrapper(1);
assert!(get_as_ptr(&x).is_none());
}

#[test]
fn test_sender_is_ptr() {
assert_eq!(
std::mem::size_of::<usize>(),
std::mem::size_of::<async_channel::Sender<u8>>()
);
assert_eq!(
std::mem::align_of::<usize>(),
std::mem::align_of::<async_channel::Sender<u8>>()
);
}
}
2 changes: 1 addition & 1 deletion iroh-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ doc = false

[dependencies]
anyhow = "1.0.81"
async-channel = "2.3.1"
bao-tree = "0.13"
bytes = "1.5.0"
clap = { version = "4", features = ["derive"] }
Expand All @@ -33,7 +34,6 @@ crossterm = "0.27.0"
derive_more = { version = "1.0.0-beta.1", features = ["display"] }
dialoguer = { version = "0.11.0", default-features = false }
dirs-next = "2.0.0"
flume = "0.11.0"
futures-buffered = "0.2.4"
futures-lite = "2.3"
futures-util = { version = "0.3.30", features = ["futures-sink"] }
Expand Down
14 changes: 7 additions & 7 deletions iroh-cli/src/commands/doctor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use iroh::{
base::ticket::{BlobTicket, Ticket},
blobs::{
store::{ReadableStore, Store as _},
util::progress::{FlumeProgressSender, ProgressSender},
util::progress::{AsyncChannelProgressSender, ProgressSender},
},
docs::{Capability, DocTicket},
net::{
Expand Down Expand Up @@ -1145,28 +1145,28 @@ pub async fn run(command: Commands, config: &NodeConfig) -> anyhow::Result<()> {
Commands::TicketInspect { ticket, zbase32 } => inspect_ticket(&ticket, zbase32),
Commands::BlobConsistencyCheck { path, repair } => {
let blob_store = iroh::blobs::store::fs::Store::load(path).await?;
let (send, recv) = flume::bounded(1);
let (send, recv) = async_channel::bounded(1);
let task = tokio::spawn(async move {
while let Ok(msg) = recv.recv_async().await {
while let Ok(msg) = recv.recv().await {
println!("{:?}", msg);
}
});
blob_store
.consistency_check(repair, FlumeProgressSender::new(send).boxed())
.consistency_check(repair, AsyncChannelProgressSender::new(send).boxed())
.await?;
task.await?;
Ok(())
}
Commands::BlobValidate { path, repair } => {
let blob_store = iroh::blobs::store::fs::Store::load(path).await?;
let (send, recv) = flume::bounded(1);
let (send, recv) = async_channel::bounded(1);
let task = tokio::spawn(async move {
while let Ok(msg) = recv.recv_async().await {
while let Ok(msg) = recv.recv().await {
println!("{:?}", msg);
}
});
blob_store
.validate(repair, FlumeProgressSender::new(send).boxed())
.validate(repair, AsyncChannelProgressSender::new(send).boxed())
.await?;
task.await?;
Ok(())
Expand Down
Loading
Loading