From a6fdec99aa195de886f3bc4e50404749e46bdb60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=87a=C4=9Fatay=20Yi=C4=9Fit=20=C5=9Eahin?= Date: Tue, 17 Sep 2024 13:47:32 +0300 Subject: [PATCH] refactor(virtq): remove intermediate virtqueue channels Rather than pushing used buffers from the queues to channels, return them directly to the drivers. This removes the dependency on `async-channel`. --- Cargo.lock | 19 ------ Cargo.toml | 1 - src/drivers/net/virtio/mod.rs | 56 +++--------------- src/drivers/virtio/virtqueue/mod.rs | 38 +++++------- src/drivers/virtio/virtqueue/packed.rs | 39 +++++-------- src/drivers/virtio/virtqueue/split.rs | 81 +++++++++++--------------- src/drivers/vsock/mod.rs | 68 ++++----------------- 7 files changed, 78 insertions(+), 224 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8a7b0a07ef..e51e53dc00 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -130,18 +130,6 @@ dependencies = [ "bitflags 2.6.0", ] -[[package]] -name = "async-channel" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a" -dependencies = [ - "concurrent-queue", - "event-listener-strategy", - "futures-core", - "pin-project-lite", -] - [[package]] name = "async-lock" version = "3.4.0" @@ -519,12 +507,6 @@ dependencies = [ "zerocopy-derive", ] -[[package]] -name = "futures-core" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" - [[package]] name = "generic_once_cell" version = "0.1.1" @@ -613,7 +595,6 @@ dependencies = [ "anstyle", "anyhow", "arm-gic", - "async-channel", "async-lock", "async-trait", "bit_field", diff --git a/Cargo.toml b/Cargo.toml index 4f43cde124..8ece97576c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,7 +76,6 @@ virtio = { package = "virtio-spec", version = "0.1", features = ["alloc", "mmio" ahash = { version = "0.8", default-features = false } align-address = "0.3" anstyle = { version = "1", default-features = false } -async-channel = { version = "2.3", default-features = false } async-lock = { version = "3.4.0", default-features = false } async-trait = "0.1.83" bit_field = "0.10" diff --git a/src/drivers/net/virtio/mod.rs b/src/drivers/net/virtio/mod.rs index 4c39d12b8b..b507e39c9b 100644 --- a/src/drivers/net/virtio/mod.rs +++ b/src/drivers/net/virtio/mod.rs @@ -56,15 +56,11 @@ impl CtrlQueue { pub struct RxQueues { vqs: Vec>, - poll_sender: async_channel::Sender, - poll_receiver: async_channel::Receiver, packet_size: u32, } impl RxQueues { pub fn new(vqs: Vec>, dev_cfg: &NetDevCfg) -> Self { - let (poll_sender, poll_receiver) = async_channel::unbounded(); - // See Virtio specification v1.1 - 5.1.6.3.1 // let packet_size = if dev_cfg.features.contains(virtio::net::F::MRG_RXBUF) { @@ -73,12 +69,7 @@ impl RxQueues { dev_cfg.raw.as_ptr().mtu().read().to_ne().into() }; - Self { - vqs, - poll_sender, - poll_receiver, - packet_size, - } + Self { vqs, packet_size } } /// Takes care of handling packets correctly which need some processing after being received. @@ -95,32 +86,12 @@ impl RxQueues { fn add(&mut self, mut vq: Box) { const BUFF_PER_PACKET: u16 = 2; let num_packets: u16 = u16::from(vq.size()) / BUFF_PER_PACKET; - fill_queue( - vq.as_mut(), - num_packets, - self.packet_size, - self.poll_sender.clone(), - ); + fill_queue(vq.as_mut(), num_packets, self.packet_size); self.vqs.push(vq); } fn get_next(&mut self) -> Option { - let transfer = self.poll_receiver.try_recv(); - - transfer - .or_else(|_| { - // Check if any not yet provided transfers are in the queue. - self.poll(); - - self.poll_receiver.try_recv() - }) - .ok() - } - - fn poll(&mut self) { - for vq in &mut self.vqs { - vq.poll(); - } + self.vqs[0].try_recv().ok() } fn enable_notifs(&mut self) { @@ -140,12 +111,7 @@ impl RxQueues { } } -fn fill_queue( - vq: &mut dyn Virtq, - num_packets: u16, - packet_size: u32, - poll_sender: async_channel::Sender, -) { +fn fill_queue(vq: &mut dyn Virtq, num_packets: u16, packet_size: u32) { for _ in 0..num_packets { let buff_tkn = match AvailBufferToken::new( vec![], @@ -167,12 +133,7 @@ fn fill_queue( // BufferTokens are directly provided to the queue // TransferTokens are directly dispatched // Transfers will be awaited at the queue - match vq.dispatch( - buff_tkn, - Some(poll_sender.clone()), - false, - BufferType::Direct, - ) { + match vq.dispatch(buff_tkn, false, BufferType::Direct) { Ok(_) => (), Err(err) => { error!("{:#?}", err); @@ -220,7 +181,9 @@ impl TxQueues { fn poll(&mut self) { for vq in &mut self.vqs { - vq.poll(); + // We don't do anything with the buffers but we need to receive them for the + // ring slots to be emptied and the memory from the previous transfers to be freed. + while vq.try_recv().is_ok() {} } } @@ -339,7 +302,7 @@ impl NetworkDriver for VirtioNetDriver { .unwrap(); self.send_vqs.vqs[0] - .dispatch(buff_tkn, None, false, BufferType::Direct) + .dispatch(buff_tkn, false, BufferType::Direct) .unwrap(); result @@ -373,7 +336,6 @@ impl NetworkDriver for VirtioNetDriver { self.recv_vqs.vqs[0].as_mut(), num_buffers, self.recv_vqs.packet_size, - self.recv_vqs.poll_sender.clone(), ); let vec_data = packets.into_iter().flatten().collect(); diff --git a/src/drivers/virtio/virtqueue/mod.rs b/src/drivers/virtio/virtqueue/mod.rs index 2ae238091d..5463839dcb 100644 --- a/src/drivers/virtio/virtqueue/mod.rs +++ b/src/drivers/virtio/virtqueue/mod.rs @@ -19,7 +19,6 @@ use core::any::Any; use core::mem::MaybeUninit; use core::{mem, ptr}; -use async_channel::TryRecvError; use virtio::{le32, le64, pvirtq, virtq}; use self::error::VirtqError; @@ -88,8 +87,6 @@ impl From for u16 { } } -type UsedBufferTokenSender = async_channel::Sender; - // Public interface of Virtq /// The Virtq trait unifies access to the two different Virtqueue types @@ -106,7 +103,6 @@ pub trait Virtq { fn dispatch( &mut self, tkn: AvailBufferToken, - sender: Option, notif: bool, buffer_type: BufferType, ) -> Result<(), VirtqError>; @@ -125,21 +121,20 @@ pub trait Virtq { tkn: AvailBufferToken, buffer_type: BufferType, ) -> Result { - let (sender, receiver) = async_channel::bounded(1); - self.dispatch(tkn, Some(sender), false, buffer_type)?; + self.dispatch(tkn, false, buffer_type)?; self.disable_notifs(); let result: UsedBufferToken; // Keep Spinning until the receive queue is filled loop { - match receiver.try_recv() { - Ok(buffer_tkn) => { - result = buffer_tkn; - break; - } - Err(TryRecvError::Closed) => return Err(VirtqError::General), - Err(TryRecvError::Empty) => self.poll(), + // TODO: normally, we should check if the used buffer in question is the one + // we just made available. However, this shouldn't be a problem as the queue this + // function is called on makes use of this blocking dispatch function exclusively + // and thus dispatches cannot be interleaved. + if let Ok(buffer_tkn) = self.try_recv() { + result = buffer_tkn; + break; } } @@ -156,10 +151,7 @@ pub trait Virtq { /// Checks if new used descriptors have been written by the device. /// This activates the queue and polls the descriptor ring of the queue. - /// - /// * `TransferTokens` which hold an `await_queue` will be placed into - /// these queues. - fn poll(&mut self); + fn try_recv(&mut self) -> Result; /// Dispatches a batch of [AvailBufferToken]s. The buffers are provided to the queue in /// sequence. After the last buffer has been written, the queue marks the first buffer as available and triggers @@ -189,7 +181,6 @@ pub trait Virtq { fn dispatch_batch_await( &mut self, tkns: Vec<(AvailBufferToken, BufferType)>, - await_queue: UsedBufferTokenSender, notif: bool, ) -> Result<(), VirtqError>; @@ -242,7 +233,6 @@ trait VirtqPrivate { /// After this call, the buffers are no longer writable. fn transfer_token_from_buffer_token( buff_tkn: AvailBufferToken, - await_queue: Option, buffer_type: BufferType, ) -> TransferToken { let ctrl_desc = match buffer_type { @@ -252,7 +242,6 @@ trait VirtqPrivate { TransferToken { buff_tkn, - await_queue, ctrl_desc, } } @@ -334,11 +323,6 @@ pub struct TransferToken { /// Must be some in order to prevent drop /// upon reuse. buff_tkn: AvailBufferToken, - /// Structure which allows to await Transfers - /// If Some, finished TransferTokens will be placed here - /// as finished `Transfers`. If None, only the state - /// of the Token will be changed. - await_queue: Option, // Contains the [MemDescr] for the indirect table if the transfer is indirect. ctrl_desc: Option>, } @@ -616,6 +600,7 @@ pub mod error { FeatureNotSupported(virtio::F), AllocationError, IncompleteWrite, + NoNewUsed, } impl core::fmt::Debug for VirtqError { @@ -645,6 +630,9 @@ pub mod error { VirtqError::IncompleteWrite => { write!(f, "A sized object was partially initialized.") } + VirtqError::NoNewUsed => { + write!(f, "The queue does not contain any new used buffers.") + } } } } diff --git a/src/drivers/virtio/virtqueue/packed.rs b/src/drivers/virtio/virtqueue/packed.rs index b1ea248eee..65b23fe6de 100644 --- a/src/drivers/virtio/virtqueue/packed.rs +++ b/src/drivers/virtio/virtqueue/packed.rs @@ -23,8 +23,8 @@ use super::super::transport::mmio::{ComCfg, NotifCfg, NotifCtrl}; use super::super::transport::pci::{ComCfg, NotifCfg, NotifCtrl}; use super::error::VirtqError; use super::{ - AvailBufferToken, BufferType, MemDescrId, MemPool, TransferToken, UsedBufferToken, - UsedBufferTokenSender, Virtq, VirtqPrivate, VqIndex, VqSize, + AvailBufferToken, BufferType, MemDescrId, MemPool, TransferToken, UsedBufferToken, Virtq, + VirtqPrivate, VqIndex, VqSize, }; use crate::arch::mm::paging::{BasePageSize, PageSize}; use crate::arch::mm::{paging, VirtAddr}; @@ -128,21 +128,14 @@ impl DescriptorRing { } /// Polls poll index and sets the state of any finished TransferTokens. - /// If [TransferToken::await_queue] is available, the [UsedBufferToken] will be moved to the queue. - fn poll(&mut self) { + fn try_recv(&mut self) -> Result { let mut ctrl = self.get_read_ctrler(); - if let Some((mut tkn, written_len)) = ctrl.poll_next() { - if let Some(queue) = tkn.await_queue.take() { - // Place the TransferToken in a Transfer, which will hold ownership of the token - queue - .try_send(UsedBufferToken::from_avail_buffer_token( - tkn.buff_tkn, - written_len, - )) - .unwrap(); - } - } + ctrl.poll_next() + .map(|(tkn, written_len)| { + UsedBufferToken::from_avail_buffer_token(tkn.buff_tkn, written_len) + }) + .ok_or(VirtqError::NoNewUsed) } fn push_batch( @@ -539,8 +532,8 @@ impl Virtq for PackedVq { self.drv_event.disable_notif(); } - fn poll(&mut self) { - self.descr_ring.poll(); + fn try_recv(&mut self) -> Result { + self.descr_ring.try_recv() } fn dispatch_batch( @@ -552,7 +545,7 @@ impl Virtq for PackedVq { assert!(!buffer_tkns.is_empty()); let transfer_tkns = buffer_tkns.into_iter().map(|(buffer_tkn, buffer_type)| { - Self::transfer_token_from_buffer_token(buffer_tkn, None, buffer_type) + Self::transfer_token_from_buffer_token(buffer_tkn, buffer_type) }); let next_idx = self.descr_ring.push_batch(transfer_tkns)?; @@ -581,18 +574,13 @@ impl Virtq for PackedVq { fn dispatch_batch_await( &mut self, buffer_tkns: Vec<(AvailBufferToken, BufferType)>, - await_queue: super::UsedBufferTokenSender, notif: bool, ) -> Result<(), VirtqError> { // Zero transfers are not allowed assert!(!buffer_tkns.is_empty()); let transfer_tkns = buffer_tkns.into_iter().map(|(buffer_tkn, buffer_type)| { - Self::transfer_token_from_buffer_token( - buffer_tkn, - Some(await_queue.clone()), - buffer_type, - ) + Self::transfer_token_from_buffer_token(buffer_tkn, buffer_type) }); let next_idx = self.descr_ring.push_batch(transfer_tkns)?; @@ -621,11 +609,10 @@ impl Virtq for PackedVq { fn dispatch( &mut self, buffer_tkn: AvailBufferToken, - sender: Option, notif: bool, buffer_type: BufferType, ) -> Result<(), VirtqError> { - let transfer_tkn = Self::transfer_token_from_buffer_token(buffer_tkn, sender, buffer_type); + let transfer_tkn = Self::transfer_token_from_buffer_token(buffer_tkn, buffer_type); let next_idx = self.descr_ring.push(transfer_tkn)?; if notif { diff --git a/src/drivers/virtio/virtqueue/split.rs b/src/drivers/virtio/virtqueue/split.rs index fb1d9c5908..95f4a31156 100644 --- a/src/drivers/virtio/virtqueue/split.rs +++ b/src/drivers/virtio/virtqueue/split.rs @@ -19,8 +19,8 @@ use super::super::transport::mmio::{ComCfg, NotifCfg, NotifCtrl}; use super::super::transport::pci::{ComCfg, NotifCfg, NotifCtrl}; use super::error::VirtqError; use super::{ - AvailBufferToken, BufferType, MemPool, TransferToken, UsedBufferToken, UsedBufferTokenSender, - Virtq, VirtqPrivate, VqIndex, VqSize, + AvailBufferToken, BufferType, MemPool, TransferToken, UsedBufferToken, Virtq, VirtqPrivate, + VqIndex, VqSize, }; use crate::arch::memory_barrier; use crate::arch::mm::{paging, VirtAddr}; @@ -98,51 +98,38 @@ impl DescrRing { Ok(next_idx) } - fn poll(&mut self) { - // We cannot use a simple while loop here because Rust cannot tell that [Self::used_ring_ref], - // [Self::read_idx] and [Self::token_ring] access separate fields of `self`. For this reason we - // need to move [Self::used_ring_ref] lines into a separate scope. - loop { - let used_elem; - { - if self.read_idx == self.used_ring().idx.to_ne() { - break; - } else { - let cur_ring_index = self.read_idx as usize % self.token_ring.len(); - used_elem = self.used_ring().ring()[cur_ring_index]; - } - } + fn try_recv(&mut self) -> Result { + if self.read_idx == self.used_ring().idx.to_ne() { + return Err(VirtqError::NoNewUsed); + } + let cur_ring_index = self.read_idx as usize % self.token_ring.len(); + let used_elem = self.used_ring().ring()[cur_ring_index]; - let mut tkn = self.token_ring[used_elem.id.to_ne() as usize] - .take() - .expect( - "The buff_id is incorrect or the reference to the TransferToken was misplaced.", - ); - - if let Some(queue) = tkn.await_queue.take() { - queue - .try_send(UsedBufferToken::from_avail_buffer_token( - tkn.buff_tkn, - used_elem.len.to_ne(), - )) - .unwrap() - } + let tkn = self.token_ring[used_elem.id.to_ne() as usize] + .take() + .expect( + "The buff_id is incorrect or the reference to the TransferToken was misplaced.", + ); - let mut id_ret_idx = u16::try_from(used_elem.id.to_ne()).unwrap(); - loop { - self.mem_pool.ret_id(super::MemDescrId(id_ret_idx)); - let cur_chain_elem = - unsafe { self.descr_table_mut()[usize::from(id_ret_idx)].assume_init() }; - if cur_chain_elem.flags.contains(virtq::DescF::NEXT) { - id_ret_idx = cur_chain_elem.next.to_ne(); - } else { - break; - } + // We return the indices of the now freed ring slots back to `mem_pool.` + let mut id_ret_idx = u16::try_from(used_elem.id.to_ne()).unwrap(); + loop { + self.mem_pool.ret_id(super::MemDescrId(id_ret_idx)); + let cur_chain_elem = + unsafe { self.descr_table_mut()[usize::from(id_ret_idx)].assume_init() }; + if cur_chain_elem.flags.contains(virtq::DescF::NEXT) { + id_ret_idx = cur_chain_elem.next.to_ne(); + } else { + break; } - - memory_barrier(); - self.read_idx = self.read_idx.wrapping_add(1); } + + memory_barrier(); + self.read_idx = self.read_idx.wrapping_add(1); + Ok(UsedBufferToken::from_avail_buffer_token( + tkn.buff_tkn, + used_elem.len.to_ne(), + )) } fn drv_enable_notif(&mut self) { @@ -180,8 +167,8 @@ impl Virtq for SplitVq { self.ring.drv_disable_notif(); } - fn poll(&mut self) { - self.ring.poll() + fn try_recv(&mut self) -> Result { + self.ring.try_recv() } fn dispatch_batch( @@ -195,7 +182,6 @@ impl Virtq for SplitVq { fn dispatch_batch_await( &mut self, _tkns: Vec<(AvailBufferToken, BufferType)>, - _await_queue: super::UsedBufferTokenSender, _notif: bool, ) -> Result<(), VirtqError> { unimplemented!() @@ -204,11 +190,10 @@ impl Virtq for SplitVq { fn dispatch( &mut self, buffer_tkn: AvailBufferToken, - sender: Option, notif: bool, buffer_type: BufferType, ) -> Result<(), VirtqError> { - let transfer_tkn = Self::transfer_token_from_buffer_token(buffer_tkn, sender, buffer_type); + let transfer_tkn = Self::transfer_token_from_buffer_token(buffer_tkn, buffer_type); let next_idx = self.ring.push(transfer_tkn)?; if notif { diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs index 7ab72e3008..5feb91a485 100644 --- a/src/drivers/vsock/mod.rs +++ b/src/drivers/vsock/mod.rs @@ -24,12 +24,7 @@ use crate::drivers::virtio::virtqueue::{ use crate::drivers::vsock::pci::VsockDevCfgRaw; use crate::mm::device_alloc::DeviceAlloc; -fn fill_queue( - vq: &mut dyn Virtq, - num_packets: u16, - packet_size: u32, - poll_sender: async_channel::Sender, -) { +fn fill_queue(vq: &mut dyn Virtq, num_packets: u16, packet_size: u32) { for _ in 0..num_packets { let buff_tkn = match AvailBufferToken::new( vec![], @@ -51,12 +46,7 @@ fn fill_queue( // BufferTokens are directly provided to the queue // TransferTokens are directly dispatched // Transfers will be awaited at the queue - match vq.dispatch( - buff_tkn, - Some(poll_sender.clone()), - false, - BufferType::Direct, - ) { + match vq.dispatch(buff_tkn, false, BufferType::Direct) { Ok(_) => (), Err(err) => { error!("{:#?}", err); @@ -68,19 +58,14 @@ fn fill_queue( pub(crate) struct RxQueue { vq: Option>, - poll_sender: async_channel::Sender, - poll_receiver: async_channel::Receiver, packet_size: u32, } impl RxQueue { pub fn new() -> Self { - let (poll_sender, poll_receiver) = async_channel::unbounded(); - Self { vq: None, - poll_sender, - poll_receiver, + packet_size: crate::VSOCK_PACKET_SIZE, } } @@ -89,12 +74,7 @@ impl RxQueue { const BUFF_PER_PACKET: u16 = 2; let num_packets: u16 = u16::from(vq.size()) / BUFF_PER_PACKET; info!("num_packets {}", num_packets); - fill_queue( - vq.as_mut(), - num_packets, - self.packet_size, - self.poll_sender.clone(), - ); + fill_queue(vq.as_mut(), num_packets, self.packet_size); self.vq = Some(vq); } @@ -112,22 +92,7 @@ impl RxQueue { } fn get_next(&mut self) -> Option { - let transfer = self.poll_receiver.try_recv(); - - transfer - .or_else(|_| { - // Check if any not yet provided transfers are in the queue. - self.poll(); - - self.poll_receiver.try_recv() - }) - .ok() - } - - fn poll(&mut self) { - if let Some(ref mut vq) = self.vq { - vq.poll(); - } + self.vq.as_mut().unwrap().try_recv().ok() } pub fn process_packet(&mut self, mut f: F) @@ -144,7 +109,7 @@ impl RxQueue { if let Some(ref mut vq) = self.vq { f(&header, &packet[..]); - fill_queue(vq.as_mut(), 1, self.packet_size, self.poll_sender.clone()); + fill_queue(vq.as_mut(), 1, self.packet_size); } else { panic!("Invalid length of receive queue"); } @@ -185,7 +150,7 @@ impl TxQueue { fn poll(&mut self) { if let Some(ref mut vq) = self.vq { - vq.poll(); + while vq.try_recv().is_ok() {} } } @@ -198,9 +163,8 @@ impl TxQueue { { // We need to poll to get the queue to remove elements from the table and make space for // what we are about to add + self.poll(); if let Some(ref mut vq) = self.vq { - vq.poll(); - assert!(len < usize::try_from(self.packet_length).unwrap()); let mut packet = Vec::with_capacity_in(len, DeviceAlloc); let result = unsafe { @@ -213,8 +177,7 @@ impl TxQueue { let buff_tkn = AvailBufferToken::new(vec![BufferElem::Vector(packet)], vec![]).unwrap(); - vq.dispatch(buff_tkn, None, false, BufferType::Direct) - .unwrap(); + vq.dispatch(buff_tkn, false, BufferType::Direct).unwrap(); result } else { @@ -225,19 +188,13 @@ impl TxQueue { pub(crate) struct EventQueue { vq: Option>, - poll_sender: async_channel::Sender, - poll_receiver: async_channel::Receiver, packet_size: u32, } impl EventQueue { pub fn new() -> Self { - let (poll_sender, poll_receiver) = async_channel::unbounded(); - Self { vq: None, - poll_sender, - poll_receiver, packet_size: 128u32, } } @@ -248,12 +205,7 @@ impl EventQueue { fn add(&mut self, mut vq: Box) { const BUFF_PER_PACKET: u16 = 2; let num_packets: u16 = u16::from(vq.size()) / BUFF_PER_PACKET; - fill_queue( - vq.as_mut(), - num_packets, - self.packet_size, - self.poll_sender.clone(), - ); + fill_queue(vq.as_mut(), num_packets, self.packet_size); self.vq = Some(vq); }