From a9d1aa96467109bb9c95ce4a2e89b353093e1807 Mon Sep 17 00:00:00 2001 From: AJ Bagwell Date: Wed, 4 Sep 2024 13:26:19 +0100 Subject: [PATCH 1/8] add method for sending &str request to socket Add helper methods for reading a line from the socket, and for sending a request and waiting for the response. --- src/ap/mod.rs | 12 ++++-------- src/socket_handle.rs | 13 +++++++++++++ src/sta/mod.rs | 26 +++++++++++--------------- src/sta/types.rs | 18 ++++++++---------- 4 files changed, 36 insertions(+), 33 deletions(-) diff --git a/src/ap/mod.rs b/src/ap/mod.rs index 95e3765..f287446 100644 --- a/src/ap/mod.rs +++ b/src/ap/mod.rs @@ -122,8 +122,7 @@ impl WifiAp { match request { Request::Custom(custom, response_channel) => { let _n = socket_handle.socket.send(custom.as_bytes()).await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); + let data_str = socket_handle.recv_line().await?; debug!("Custom request response: {data_str}"); if response_channel.send(Ok(data_str.into())).is_err() { error!("Custom request response channel closed before response sent"); @@ -131,8 +130,7 @@ impl WifiAp { } Request::Status(response_channel) => { let _n = socket_handle.socket.send(b"STATUS").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); + let data_str = socket_handle.recv_line().await?; let status = Status::from_response(data_str)?; if response_channel.send(Ok(status)).is_err() { @@ -141,8 +139,7 @@ impl WifiAp { } Request::Config(response_channel) => { let _n = socket_handle.socket.send(b"GET_CONFIG").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); + let data_str = socket_handle.recv_line().await?; let config = Config::from_response(data_str)?; if response_channel.send(Ok(config)).is_err() { @@ -171,8 +168,7 @@ impl WifiAp { response_channel: oneshot::Sender, ) -> Result { let _n = socket_handle.socket.send(request).await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); + let data_str = socket_handle.recv_line().await?; let response = if data_str == "OK" { Ok(()) } else { diff --git a/src/socket_handle.rs b/src/socket_handle.rs index 69c7835..505120e 100644 --- a/src/socket_handle.rs +++ b/src/socket_handle.rs @@ -92,6 +92,19 @@ impl SocketHandle { self.expect_ok_with_default_timeout().await } + pub async fn request(&mut self, req: &str) -> Result<&str> { + let n = self.socket.send(req.as_bytes()).await?; + if n != req.len() { + return Err(error::Error::DidNotWriteAllBytes(n, req.len())); + } + self.recv_line().await + } + + pub async fn recv_line(&mut self) -> Result<&str> { + let n = self.socket.recv(&mut self.buffer).await?; + Ok(std::str::from_utf8(&self.buffer[..n])?.trim_end()) + } + async fn expect_ok(&mut self) -> Result { match self.socket.recv(&mut self.buffer).await { Ok(n) => { diff --git a/src/sta/mod.rs b/src/sta/mod.rs index 85f0fee..fcde43b 100644 --- a/src/sta/mod.rs +++ b/src/sta/mod.rs @@ -160,10 +160,15 @@ impl WifiStation { Ok(()) } + async fn add_network(socket_handle: &mut SocketHandle) -> Result { + let data_str = socket_handle.request("ADD_NETWORK").await?; + let network_id = usize::from_str(data_str)?; + debug!("wpa_ctrl created network {network_id}"); + Ok(network_id) + } + async fn get_status(socket_handle: &mut SocketHandle) -> Result { - let _n = socket_handle.socket.send(b"STATUS").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); + let data_str = socket_handle.request("STATUS").await?; parse_status(data_str) } @@ -197,11 +202,7 @@ impl WifiStation { } } Request::Networks(response_channel) => { - let _n = socket_handle.socket.send(b"LIST_NETWORKS").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); - let network_list = - NetworkResult::vec_from_str(data_str, &mut socket_handle.socket).await?; + let network_list = NetworkResult::request_results(socket_handle).await?; if response_channel.send(Ok(network_list)).is_err() { error!("Scan request response channel closed before response sent"); } @@ -213,14 +214,9 @@ impl WifiStation { } } Request::AddNetwork(response_channel) => { - let _n = socket_handle.socket.send(b"ADD_NETWORK").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); - let network_id = usize::from_str(data_str)?; - if response_channel.send(Ok(network_id)).is_err() { + let network_id = Self::add_network(socket_handle).await; + if response_channel.send(network_id).is_err() { error!("Scan request response channel closed before response sent"); - } else { - debug!("wpa_ctrl created network {network_id}"); } } Request::SetNetwork(id, param, response) => { diff --git a/src/sta/types.rs b/src/sta/types.rs index 333cf6a..3bc8776 100644 --- a/src/sta/types.rs +++ b/src/sta/types.rs @@ -1,10 +1,10 @@ +use super::SocketHandle; use super::{config, config::unprintf, error, warn, Result}; use serde::Serialize; use std::collections::HashMap; use std::fmt::Display; use std::str::FromStr; -use tokio::net::UnixDatagram; #[derive(Serialize, Debug, Clone)] /// The result from scanning for networks. @@ -69,21 +69,19 @@ pub struct NetworkResult { } impl NetworkResult { - pub async fn vec_from_str( - response: &str, - socket: &mut UnixDatagram, + pub async fn request_results( + socket_handle: &mut SocketHandle, ) -> Result> { - let mut buffer = [0; 256]; + let response = socket_handle.request("LIST_NETWORKS").await?.to_owned(); let mut results = Vec::new(); let split = response.split('\n').skip(1); for line in split { let mut line_split = line.split_whitespace(); if let Some(network_id) = line_split.next() { - let cmd = format!("GET_NETWORK {network_id} ssid"); - let bytes = cmd.into_bytes(); - socket.send(&bytes).await?; - let n = socket.recv(&mut buffer).await?; - let ssid = std::str::from_utf8(&buffer[..n])?.trim_matches('\"'); + let ssid = socket_handle + .request(&format!("GET_NETWORK {network_id} ssid")) + .await? + .trim_matches('\"'); let ssid = unprintf(ssid).map_err(|e| error::Error::ParsingWifiStatus { e, s: ssid.to_string(), From 77f5a9f63451a7511ac10f65acb0bd0d2af91d68 Mon Sep 17 00:00:00 2001 From: AJ Bagwell Date: Wed, 4 Sep 2024 13:28:23 +0100 Subject: [PATCH 2/8] add method for broadcasting a message It adds a debug log message if there is no listener for the boradcasts. --- src/sta/mod.rs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/sta/mod.rs b/src/sta/mod.rs index fcde43b..d480c1a 100644 --- a/src/sta/mod.rs +++ b/src/sta/mod.rs @@ -48,13 +48,19 @@ impl WifiStation { for request in deferred_requests { let _ = self.self_sender.send(request).await; } - self.broadcast_sender.send(Broadcast::Ready)?; + self.broadcast(Broadcast::Ready); tokio::select!( resp = unsolicited.run() => resp, resp = self.run_internal(unsolicited_receiver, socket_handle) => resp, ) } + fn broadcast(&self, event: Broadcast) { + if self.broadcast_sender.send(event).is_err() { + debug!("broadcast listener closed") + } + } + async fn run_internal( mut self, mut unsolicited_receiver: EventReceiver, @@ -82,12 +88,11 @@ impl WifiStation { EventOrRequest::Event(event) => match event { Some(unsolicited_msg) => { debug!("Unsolicited event: {unsolicited_msg:?}"); - Self::handle_event( + self.handle_event( &mut socket_handle, unsolicited_msg, &mut scan_requests, &mut select_request, - &mut self.broadcast_sender, ) .await? } @@ -111,11 +116,11 @@ impl WifiStation { } async fn handle_event( + &mut self, socket_handle: &mut SocketHandle, event: Event, scan_requests: &mut Vec>>>>, select_request: &mut Option, - broadcast_sender: &mut broadcast::Sender, ) -> Result { match event { Event::ScanComplete => { @@ -133,28 +138,28 @@ impl WifiStation { } } Event::Connected => { - broadcast_sender.send(Broadcast::Connected)?; + self.broadcast(Broadcast::Connected); if let Some(sender) = select_request.take() { sender.send(Ok(SelectResult::Success)); } } Event::Disconnected => { - broadcast_sender.send(Broadcast::Disconnected)?; + self.broadcast(Broadcast::Disconnected); } Event::NetworkNotFound => { - broadcast_sender.send(Broadcast::NetworkNotFound)?; + self.broadcast(Broadcast::NetworkNotFound); if let Some(sender) = select_request.take() { sender.send(Ok(SelectResult::NotFound)); } } Event::WrongPsk => { - broadcast_sender.send(Broadcast::WrongPsk)?; + self.broadcast(Broadcast::WrongPsk); if let Some(sender) = select_request.take() { sender.send(Ok(SelectResult::WrongPsk)); } } Event::Unknown(msg) => { - broadcast_sender.send(Broadcast::Unknown(msg))?; + self.broadcast(Broadcast::Unknown(msg)); } } Ok(()) From af25f04ef3a67b36ddd1acc002007ea533612e36 Mon Sep 17 00:00:00 2001 From: AJ Bagwell Date: Wed, 4 Sep 2024 13:30:04 +0100 Subject: [PATCH 3/8] Send API errors repsonses back to the requester This is to allow the runner to keep going if there is an issue with any single response from the wpa_supplicant. IO errors indicate that the socket is broken so still go to the worker so it can be reconnected. --- src/socket_handle.rs | 9 +++- src/sta/client.rs | 2 + src/sta/mod.rs | 120 +++++++++++++++++-------------------------- 3 files changed, 55 insertions(+), 76 deletions(-) diff --git a/src/socket_handle.rs b/src/socket_handle.rs index 505120e..bc4991e 100644 --- a/src/socket_handle.rs +++ b/src/socket_handle.rs @@ -1,4 +1,5 @@ use super::*; +use crate::error::Error::UnexpectedWifiApRepsonse; use std::io::ErrorKind; use tokio::net::UnixDatagram; @@ -84,12 +85,16 @@ impl SocketHandle { )) } - pub async fn command(&mut self, cmd: &[u8]) -> Result { + pub async fn command(&mut self, cmd: &[u8]) -> Result { let n = self.socket.send(cmd).await?; if n != cmd.len() { return Err(error::Error::DidNotWriteAllBytes(n, cmd.len())); } - self.expect_ok_with_default_timeout().await + match self.expect_ok_with_default_timeout().await { + Ok(()) => Ok(Ok(())), + Err(e @ UnexpectedWifiApRepsonse(_)) => Ok(Err(e)), + Err(e) => Err(e), + } } pub async fn request(&mut self, req: &str) -> Result<&str> { diff --git a/src/sta/client.rs b/src/sta/client.rs index 75345cc..cd3766f 100644 --- a/src/sta/client.rs +++ b/src/sta/client.rs @@ -47,6 +47,7 @@ pub(crate) enum Request { Status(oneshot::Sender>), Networks(oneshot::Sender>>), Scan(oneshot::Sender>), + ScanResults, AddNetwork(oneshot::Sender>), SetNetwork(usize, SetNetwork, oneshot::Sender), SaveConfig(oneshot::Sender), @@ -75,6 +76,7 @@ impl ShutdownSignal for Request { Request::Scan(response) => { let _ = response.send(Err(error::Error::StartupAborted)); } + Request::ScanResults => {} Request::AddNetwork(response) => { let _ = response.send(Err(error::Error::StartupAborted)); } diff --git a/src/sta/mod.rs b/src/sta/mod.rs index d480c1a..ffa2591 100644 --- a/src/sta/mod.rs +++ b/src/sta/mod.rs @@ -88,13 +88,8 @@ impl WifiStation { EventOrRequest::Event(event) => match event { Some(unsolicited_msg) => { debug!("Unsolicited event: {unsolicited_msg:?}"); - self.handle_event( - &mut socket_handle, - unsolicited_msg, - &mut scan_requests, - &mut select_request, - ) - .await? + self.handle_event(unsolicited_msg, &mut select_request) + .await } None => return Err(error::Error::WifiStationEventChannelClosed), }, @@ -115,27 +110,10 @@ impl WifiStation { } } - async fn handle_event( - &mut self, - socket_handle: &mut SocketHandle, - event: Event, - scan_requests: &mut Vec>>>>, - select_request: &mut Option, - ) -> Result { + async fn handle_event(&mut self, event: Event, select_request: &mut Option) { match event { Event::ScanComplete => { - let _n = socket_handle.socket.send(b"SCAN_RESULTS").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?; - let mut scan_results = ScanResult::vec_from_str(data_str); - scan_results.sort_by(|a, b| a.signal.cmp(&b.signal)); - - let results = Arc::new(scan_results); - while let Some(scan_request) = scan_requests.pop() { - if scan_request.send(Ok(results.clone())).is_err() { - error!("Scan request response channel closed before response sent"); - } - } + let _ = self.self_sender.send(Request::ScanResults).await; } Event::Connected => { self.broadcast(Broadcast::Connected); @@ -162,19 +140,13 @@ impl WifiStation { self.broadcast(Broadcast::Unknown(msg)); } } - Ok(()) - } - - async fn add_network(socket_handle: &mut SocketHandle) -> Result { - let data_str = socket_handle.request("ADD_NETWORK").await?; - let network_id = usize::from_str(data_str)?; - debug!("wpa_ctrl created network {network_id}"); - Ok(network_id) } - async fn get_status(socket_handle: &mut SocketHandle) -> Result { + async fn get_status( + socket_handle: &mut SocketHandle, + ) -> Result> { let data_str = socket_handle.request("STATUS").await?; - parse_status(data_str) + Ok(parse_status(data_str)) } async fn handle_request( @@ -189,11 +161,11 @@ impl WifiStation { Request::Custom(custom, response_channel) => { let _n = socket_handle.socket.send(custom.as_bytes()).await?; let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); - debug!("Custom request response: {data_str}"); - if response_channel.send(Ok(data_str.into())).is_err() { - error!("Custom request response channel closed before response sent"); - } + let data_str = std::str::from_utf8(&socket_handle.buffer[..n]) + .map(|s| s.trim_end().to_owned()) + .map_err(Into::into); + debug!("Custom request response: {data_str:?}"); + let _ = response_channel.send(data_str); } Request::SelectTimeout => { if let Some(sender) = select_request.take() { @@ -201,28 +173,40 @@ impl WifiStation { } } Request::Scan(response_channel) => { - scan_requests.push(response_channel); - if let Err(e) = socket_handle.command(b"SCAN").await { - debug!("Error while requesting SCAN: {e}"); - } + match socket_handle.command(b"SCAN").await? { + Ok(_) => { + scan_requests.push(response_channel); + } + Err(e) => { + let _ = response_channel.send(Err(e)); + } + }; } + Request::ScanResults => match socket_handle.request("SCAN_RESULTS").await { + Ok(resp) => { + let scan_results = Arc::new(ScanResult::vec_from_str(resp)); + while let Some(scan_request) = scan_requests.pop() { + let _ = scan_request.send(Ok(scan_results.clone())); + } + } + Err(e) => { + scan_requests.clear(); + return Err(e); + } + }, Request::Networks(response_channel) => { let network_list = NetworkResult::request_results(socket_handle).await?; - if response_channel.send(Ok(network_list)).is_err() { - error!("Scan request response channel closed before response sent"); - } + let _ = response_channel.send(Ok(network_list)); } Request::Status(response_channel) => { - let status = Self::get_status(socket_handle).await; - if response_channel.send(status).is_err() { - error!("Scan request response channel closed before response sent"); - } + let status = Self::get_status(socket_handle).await?; + let _ = response_channel.send(status); } Request::AddNetwork(response_channel) => { - let network_id = Self::add_network(socket_handle).await; - if response_channel.send(network_id).is_err() { - error!("Scan request response channel closed before response sent"); - } + let data_str = socket_handle.request("ADD_NETWORK").await?; + let network_id = usize::from_str(data_str).map_err(Into::into); + debug!("wpa_ctrl created network {network_id:?}"); + let _ = response_channel.send(network_id); } Request::SetNetwork(id, param, response) => { let cmd = format!( @@ -236,24 +220,15 @@ impl WifiStation { ); debug!("wpa_ctrl \"{cmd}\""); let bytes = cmd.into_bytes(); - if let Err(e) = socket_handle.command(&bytes).await { - warn!("Error while setting network parameter: {e}"); - } - let _ = response.send(Ok(())); + let _ = response.send(socket_handle.command(&bytes).await?); } Request::SaveConfig(response) => { - if let Err(e) = socket_handle.command(b"SAVE_CONFIG").await { - warn!("Error while saving config: {e}"); - } debug!("wpa_ctrl config saved"); - let _ = response.send(Ok(())); + let _ = response.send(socket_handle.command(b"SAVE_CONFIG").await?); } Request::ReloadConfig(response) => { - if let Err(e) = socket_handle.command(b"RECONFIGURE").await { - warn!("Error while reloading config: {e}"); - } debug!("wpa_ctrl config reloaded"); - let _ = response.send(Ok(())); + let _ = response.send(socket_handle.command(b"RECONFIGURE").await?); } Request::RemoveNetwork(remove_network, response) => { let str = match remove_network { @@ -262,24 +237,21 @@ impl WifiStation { }; let cmd = format!("REMOVE_NETWORK {str}"); let bytes = cmd.into_bytes(); - if let Err(e) = socket_handle.command(&bytes).await { - warn!("Error while removing network {str}: {e}"); - } debug!("wpa_ctrl removed network {str}"); - let _ = response.send(Ok(())); + let _ = response.send(socket_handle.command(&bytes).await?); } Request::SelectNetwork(id, response_sender) => { let response_sender = match select_request { None => { let cmd = format!("SELECT_NETWORK {id}"); let bytes = cmd.into_bytes(); - if let Err(e) = socket_handle.command(&bytes).await { + if let Err(e) = socket_handle.command(&bytes).await? { warn!("Error while selecting network {id}: {e}"); let _ = response_sender.send(Ok(SelectResult::InvalidNetworkId)); None } else { debug!("wpa_ctrl selected network {id}"); - let status = Self::get_status(socket_handle).await?; + let status = Self::get_status(socket_handle).await?.unwrap_or_default(); if let Some(current_id) = status.get("id") { if current_id == &id.to_string() { let _ = From 6600d5791befe80993b5e59229f902f4d24ec9c7 Mon Sep 17 00:00:00 2001 From: AJ Bagwell Date: Thu, 6 Mar 2025 09:01:53 +0000 Subject: [PATCH 4/8] Handle scan failed. Sometime the scan request will return OK but the WiFi driver will fail to start a scan. This results in a scan failed event. The request then needs to get an error to aviod hanging forever. --- src/sta/event_socket.rs | 3 +++ src/sta/mod.rs | 15 +++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/sta/event_socket.rs b/src/sta/event_socket.rs index 3986e7c..a3234b5 100644 --- a/src/sta/event_socket.rs +++ b/src/sta/event_socket.rs @@ -9,6 +9,7 @@ pub(crate) struct EventSocket { #[derive(Debug)] pub(crate) enum Event { ScanComplete, + ScanFailed(String), Connected, Disconnected, NetworkNotFound, @@ -63,6 +64,8 @@ impl EventSocket { debug!("wpa_ctrl event: {data_str}"); if data_str.ends_with("CTRL-EVENT-SCAN-RESULTS") { self.send_event(Event::ScanComplete).await?; + } else if data_str.contains("CTRL-EVENT-SCAN-FAILED") { + self.send_event(Event::ScanFailed(data_str.into())).await?; } else if data_str.contains("CTRL-EVENT-CONNECTED") { self.send_event(Event::Connected).await?; } else if data_str.contains("CTRL-EVENT-DISCONNECTED") { diff --git a/src/sta/mod.rs b/src/sta/mod.rs index ffa2591..06d71b1 100644 --- a/src/sta/mod.rs +++ b/src/sta/mod.rs @@ -88,7 +88,7 @@ impl WifiStation { EventOrRequest::Event(event) => match event { Some(unsolicited_msg) => { debug!("Unsolicited event: {unsolicited_msg:?}"); - self.handle_event(unsolicited_msg, &mut select_request) + self.handle_event(unsolicited_msg, &mut scan_requests, &mut select_request) .await } None => return Err(error::Error::WifiStationEventChannelClosed), @@ -110,11 +110,22 @@ impl WifiStation { } } - async fn handle_event(&mut self, event: Event, select_request: &mut Option) { + async fn handle_event( + &mut self, + event: Event, + scan_requests: &mut Vec>>>>, + select_request: &mut Option, + ) { match event { Event::ScanComplete => { let _ = self.self_sender.send(Request::ScanResults).await; } + Event::ScanFailed(s) => { + while let Some(scan_request) = scan_requests.pop() { + let _ = + scan_request.send(Err(error::Error::UnexpectedWifiApRepsonse(s.clone()))); + } + } Event::Connected => { self.broadcast(Broadcast::Connected); if let Some(sender) = select_request.take() { From ab91cd3cb0a0f45404a77fa43a16c06814ef51b7 Mon Sep 17 00:00:00 2001 From: AJ Bagwell Date: Thu, 6 Mar 2025 11:32:39 +0000 Subject: [PATCH 5/8] Split Error into two enums This splits the error type into two: SocketError: for anything IO related Error: for anything user/protocol related Socket errors go stop the runners which can then reconnect or what ever other apropriate action is needed Other errors go to the client caller who can try and correct the malformed inputs. I also reorganised and simplified the errors. --- examples/wifi-ap.rs | 9 ++- examples/wifi-sta.rs | 6 +- src/ap/client.rs | 46 +++------------ src/ap/event_socket.rs | 49 ++++++---------- src/ap/mod.rs | 110 +++++++++++++----------------------- src/ap/setup.rs | 12 +++- src/ap/types.rs | 16 ++---- src/error.rs | 107 +++++++++++++++++++++++++---------- src/lib.rs | 5 +- src/socket_handle.rs | 100 +++++++++++++++++++------------- src/sta/client.rs | 122 ++++++++++++++-------------------------- src/sta/event_socket.rs | 60 ++++++++------------ src/sta/mod.rs | 59 +++++++++---------- src/sta/setup.rs | 12 +++- src/sta/types.rs | 58 ++++++++++--------- 15 files changed, 362 insertions(+), 409 deletions(-) diff --git a/examples/wifi-ap.rs b/examples/wifi-ap.rs index 1f7c5c9..2b61cdf 100644 --- a/examples/wifi-ap.rs +++ b/examples/wifi-ap.rs @@ -16,8 +16,11 @@ async fn main() -> Result { info!("[{:?}] {:?}", i, itf.name); } let user_input = read_until_break().await; - let index = user_input.trim().parse::()?; - let mut setup = ap::WifiSetup::new()?; + let index = user_input + .trim() + .parse::() + .expect("user should enter a number"); + let mut setup = ap::WifiSetup::new(); let proposed_path = format!("/var/run/hostapd/{}", network_interfaces[index].name); info!("Connect to \"{proposed_path}\"? Type full new path or just press enter to accept."); @@ -31,7 +34,7 @@ async fn main() -> Result { let broadcast = setup.get_broadcast_receiver(); let requester = setup.get_request_client(); - let runtime = setup.complete(); + let mut runtime = setup.complete(); let (_runtime, _app, _broadcast) = tokio::join!( async move { diff --git a/examples/wifi-sta.rs b/examples/wifi-sta.rs index 9519ed3..4728ac5 100644 --- a/examples/wifi-sta.rs +++ b/examples/wifi-sta.rs @@ -16,8 +16,8 @@ async fn main() -> Result { info!("[{:?}] {:?}", i, itf.name); } let user_input = read_until_break().await; - let index = user_input.trim().parse::()?; - let mut setup = sta::WifiSetup::new()?; + let index = user_input.trim().parse::().expect(""); + let mut setup = sta::WifiSetup::new(); let proposed_path = format!("/var/run/wpa_supplicant/{}", network_interfaces[index].name); info!("Connect to \"{proposed_path}\"? Type full new path or just press enter to accept."); @@ -31,7 +31,7 @@ async fn main() -> Result { let broadcast = setup.get_broadcast_receiver(); let requester = setup.get_request_client(); - let runtime = setup.complete(); + let mut runtime = setup.complete(); let (_runtime, _app, _broadcast) = tokio::join!( async move { diff --git a/src/ap/client.rs b/src/ap/client.rs index 5ed3010..3e2fdbd 100644 --- a/src/ap/client.rs +++ b/src/ap/client.rs @@ -15,26 +15,6 @@ impl ShutdownSignal for Request { fn is_shutdown(&self) -> bool { matches!(self, Request::Shutdown) } - fn inform_of_shutdown(self) { - match self { - Request::Custom(_, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Status(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Config(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Enable(response) | Request::Disable(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::SetValue(_, _, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Shutdown => {} - } - } } #[derive(Clone)] @@ -48,56 +28,46 @@ impl RequestClient { RequestClient { sender } } - async fn send_request(&self, request: Request) -> Result { - self.sender - .send(request) - .await - .map_err(|_| error::Error::WifiApRequestChannelClosed)?; - Ok(()) - } - pub async fn send_custom(&self, custom: String) -> Result { let (response, request) = oneshot::channel(); - self.sender - .send(Request::Custom(custom, response)) - .await - .map_err(|_| error::Error::WifiApRequestChannelClosed)?; + self.sender.send(Request::Custom(custom, response)).await?; request.await? } pub async fn get_status(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Status(response)).await?; + self.sender.send(Request::Status(response)).await?; request.await? } pub async fn get_config(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Config(response)).await?; + self.sender.send(Request::Config(response)).await?; request.await? } pub async fn enable(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Enable(response)).await?; + self.sender.send(Request::Enable(response)).await?; request.await? } pub async fn disable(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Disable(response)).await?; + self.sender.send(Request::Disable(response)).await?; request.await? } pub async fn set_value(&self, key: &str, value: &str) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SetValue(key.into(), value.into(), response)) + self.sender + .send(Request::SetValue(key.into(), value.into(), response)) .await?; request.await? } pub async fn shutdown(&self) -> Result { - self.send_request(Request::Shutdown).await + Ok(self.sender.send(Request::Shutdown).await?) } } diff --git a/src/ap/event_socket.rs b/src/ap/event_socket.rs index 001562a..6ab13f1 100644 --- a/src/ap/event_socket.rs +++ b/src/ap/event_socket.rs @@ -21,7 +21,7 @@ impl EventSocket { socket: P, request_receiver: &mut mpsc::Receiver, attach_options: &[String], - ) -> Result<(EventReceiver, Vec, Self)> + ) -> SocketResult<(EventReceiver, Vec, Self)> where P: AsRef + std::fmt::Debug, { @@ -41,15 +41,15 @@ impl EventSocket { )) } - async fn send_event(&self, event: Event) -> Result { + async fn send_event(&self, event: Event) -> SocketResult { self.sender .send(event) .await - .map_err(|_| error::Error::WifiApEventChannelClosed)?; + .map_err(|_| error::SocketError::EventChannelClosed)?; Ok(()) } - pub(crate) async fn run(mut self) -> Result { + pub(crate) async fn run(mut self) -> SocketResult { let mut command = "ATTACH".to_string(); for o in &self.attach_options { command.push(' '); @@ -69,33 +69,20 @@ impl EventSocket { info!("hostapd event stream registered"); loop { - match self - .socket_handle - .socket - .recv(&mut self.socket_handle.buffer) - .await - { - Ok(n) => { - let data_str = std::str::from_utf8(&self.socket_handle.buffer[..n])?.trim_end(); - if let Some(n) = data_str.find("AP-STA-DISCONNECTED") { - let index = n + "AP-STA-DISCONNECTED".len(); - let mac = &data_str[index..]; - self.send_event(Event::ApStaDisconnected(mac.to_string())) - .await?; - } else if let Some(n) = data_str.find("AP-STA-CONNECTED") { - let index = n + "AP-STA-CONNECTED".len(); - let mac = &data_str[index..]; - self.send_event(Event::ApStaConnected(mac.to_string())) - .await?; - } else { - self.send_event(Event::Unknown(data_str.to_string())) - .await?; - } - } - Err(e) => { - return Err(error::Error::UnsolicitedIoError(e)); - } - } + let bytes = self.socket_handle.recv().await?; + let data_str = String::from_utf8_lossy(bytes); + let event = if let Some(n) = data_str.find("AP-STA-DISCONNECTED") { + let index = n + "AP-STA-DISCONNECTED".len(); + let mac = &data_str[index..].trim(); + Event::ApStaDisconnected(mac.to_string()) + } else if let Some(n) = data_str.find("AP-STA-CONNECTED") { + let index = n + "AP-STA-CONNECTED".len(); + let mac = &data_str[index..].trim(); + Event::ApStaConnected(mac.to_string()) + } else { + Event::Unknown(data_str.to_string()) + }; + self.send_event(event).await?; } } } diff --git a/src/ap/mod.rs b/src/ap/mod.rs index f287446..82a75fe 100644 --- a/src/ap/mod.rs +++ b/src/ap/mod.rs @@ -30,7 +30,7 @@ pub struct WifiAp { } impl WifiAp { - pub async fn run(mut self) -> Result { + pub async fn run(&mut self) -> SocketResult { info!("Starting Wifi AP process"); let (event_receiver, mut deferred_requests, event_socket) = EventSocket::new( &self.socket_path, @@ -48,20 +48,29 @@ impl WifiAp { .await?; deferred_requests.extend(next_deferred_requests); for request in deferred_requests { - let _ = self.self_sender.send(request).await; + self.self_sender + .send(request) + .await + .expect("self_sender should never close as same struct owns both ends"); } - self.broadcast_sender.send(Broadcast::Ready)?; + self.broadcast(Broadcast::Ready); tokio::select!( resp = event_socket.run() => resp, resp = self.run_internal(event_receiver, socket_handle) => resp, ) } + fn broadcast(&self, event: Broadcast) { + if self.broadcast_sender.send(event).is_err() { + debug!("broadcast listener closed") + } + } + async fn run_internal( - mut self, + &mut self, mut event_receiver: EventReceiver, mut socket_handle: SocketHandle<2048>, - ) -> Result { + ) -> SocketResult { enum EventOrRequest { Event(Option), Request(Option), @@ -74,110 +83,67 @@ impl WifiAp { ); match event_or_request { EventOrRequest::Event(event) => match event { - Some(event) => { - Self::handle_event(&mut socket_handle, &self.broadcast_sender, event) - .await? - } - None => return Err(error::Error::WifiApEventChannelClosed), + Some(event) => self.handle_event(&mut socket_handle, event).await, + None => return Err(error::SocketError::EventChannelClosed), }, EventOrRequest::Request(request) => match request { Some(Request::Shutdown) => return Ok(()), Some(request) => Self::handle_request(&mut socket_handle, request).await?, - None => return Err(error::Error::WifiApRequestChannelClosed), + None => return Err(error::SocketError::ClientChannelClosed), }, } } } async fn handle_event( + &self, _socket_handle: &mut SocketHandle, - broadcast_sender: &broadcast::Sender, event_msg: Event, - ) -> Result { + ) { match event_msg { - Event::ApStaConnected(mac) => { - if let Err(e) = broadcast_sender.send(Broadcast::Connected(mac)) { - warn!("error broadcasting: {e}"); - } - } - Event::ApStaDisconnected(mac) => { - if let Err(e) = broadcast_sender.send(Broadcast::Disconnected(mac)) { - warn!("error broadcasting: {e}"); - } - } - Event::Unknown(msg) => { - if let Err(e) = broadcast_sender.send(Broadcast::UnknownEvent(msg)) { - warn!("error broadcasting: {e}"); - } - } + Event::ApStaConnected(mac) => self.broadcast(Broadcast::Connected(mac)), + Event::ApStaDisconnected(mac) => self.broadcast(Broadcast::Disconnected(mac)), + Event::Unknown(msg) => self.broadcast(Broadcast::UnknownEvent(msg)), }; - Ok(()) } async fn handle_request( socket_handle: &mut SocketHandle, request: Request, - ) -> Result { + ) -> SocketResult { debug!("Handling request: {request:?}"); match request { Request::Custom(custom, response_channel) => { - let _n = socket_handle.socket.send(custom.as_bytes()).await?; - let data_str = socket_handle.recv_line().await?; - debug!("Custom request response: {data_str}"); - if response_channel.send(Ok(data_str.into())).is_err() { - error!("Custom request response channel closed before response sent"); - } + let data_str = socket_handle.request(&custom, TryInto::try_into).await?; + debug!("Custom request response: {data_str:?}"); + let _ = response_channel.send(data_str); } Request::Status(response_channel) => { - let _n = socket_handle.socket.send(b"STATUS").await?; - let data_str = socket_handle.recv_line().await?; - let status = Status::from_response(data_str)?; + let status = socket_handle + .request("STATUS", Status::from_response) + .await?; - if response_channel.send(Ok(status)).is_err() { - error!("Status request response channel closed before response sent"); - } + let _ = response_channel.send(status); } Request::Config(response_channel) => { - let _n = socket_handle.socket.send(b"GET_CONFIG").await?; - let data_str = socket_handle.recv_line().await?; - let config = Config::from_response(data_str)?; - - if response_channel.send(Ok(config)).is_err() { - error!("Config request response channel closed before response sent"); - } + let config = socket_handle + .request("GET_CONFIG", Config::from_response) + .await?; + let _ = response_channel.send(config); } Request::Enable(response_channel) => { - Self::ok_fail_request(socket_handle, b"ENABLE", response_channel).await? + let _ = response_channel.send(socket_handle.command(b"ENABLE").await?); } Request::Disable(response_channel) => { - Self::ok_fail_request(socket_handle, b"DISABLE", response_channel).await? + let _ = response_channel.send(socket_handle.command(b"DISABLE").await?); } Request::SetValue(key, value, response_channel) => { let request_string = format!("SET {key} {value}"); - Self::ok_fail_request(socket_handle, request_string.as_bytes(), response_channel) - .await? + let _ = + response_channel.send(socket_handle.command(request_string.as_bytes()).await?); } Request::Shutdown => (), //shutdown is handled at the scope above } Ok(()) } - - async fn ok_fail_request( - socket_handle: &mut SocketHandle, - request: &[u8], - response_channel: oneshot::Sender, - ) -> Result { - let _n = socket_handle.socket.send(request).await?; - let data_str = socket_handle.recv_line().await?; - let response = if data_str == "OK" { - Ok(()) - } else { - Err(error::Error::UnexpectedWifiApRepsonse(data_str.into())) - }; - - if response_channel.send(response).is_err() { - error!("Config request response channel closed before response sent"); - } - Ok(()) - } } diff --git a/src/ap/setup.rs b/src/ap/setup.rs index f3a3790..aa00349 100644 --- a/src/ap/setup.rs +++ b/src/ap/setup.rs @@ -16,14 +16,14 @@ pub struct WifiSetupGeneric { } impl WifiSetupGeneric { - pub fn new() -> Result { + pub fn new() -> Self { // setup the channel for client requests let (self_sender, request_receiver) = mpsc::channel(C); let request_client = RequestClient::new(self_sender.clone()); // setup the channel for broadcasts let (broadcast_sender, broadcast_receiver) = broadcast::channel(B); - Ok(Self { + Self { wifi: WifiAp { socket_path: PATH_DEFAULT_SERVER.into(), attach_options: vec![], @@ -33,7 +33,7 @@ impl WifiSetupGeneric { }, request_client, broadcast_receiver, - }) + } } pub fn set_socket_path>(&mut self, path: S) { @@ -57,3 +57,9 @@ impl WifiSetupGeneric { self.wifi } } + +impl Default for WifiSetupGeneric { + fn default() -> Self { + Self::new() + } +} diff --git a/src/ap/types.rs b/src/ap/types.rs index ceef15f..03fc981 100644 --- a/src/ap/types.rs +++ b/src/ap/types.rs @@ -1,4 +1,4 @@ -use super::{error, Result}; +use super::config::ConfigError; use serde::{Deserialize, Serialize}; /// Status of the WiFi Station @@ -84,11 +84,8 @@ impl Status { /// assert_eq!(status.ssid, vec![r"WiFi-SSID", r#"¯\_(ツ)_/¯"#]); /// assert_eq!(status.num_sta, vec![0, 1]); /// ``` - pub fn from_response(response: &str) -> Result { - crate::config::from_str(response).map_err(|e| error::Error::ParsingWifiStatus { - e, - s: response.into(), - }) + pub fn from_response(response: &str) -> std::result::Result { + crate::config::from_str(response) } } @@ -126,11 +123,8 @@ impl Config { /// assert_eq!(config.wpa, 2); /// assert_eq!(config.ssid, "WiFi-SSID"); /// ``` - pub fn from_response(response: &str) -> Result { - crate::config::from_str(response).map_err(|e| error::Error::ParsingWifiConfig { - e, - s: response.into(), - }) + pub fn from_response(response: &str) -> std::result::Result { + crate::config::from_str(response) } } diff --git a/src/error.rs b/src/error.rs index f9a7470..f5c3ce6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,44 +1,93 @@ use super::*; +use std::convert::Infallible; use thiserror::Error; - +use tokio::sync::mpsc::error::SendError; +use tokio::sync::oneshot::error::RecvError; #[derive(Error, Debug)] -pub enum Error { + +/// Error returned by [access point](crate::ap::WifiAp::run) and [station](crate::sta::WifiStation::run) runners if there is +/// problem with the control socket. e.g. if `wpa_suplicant` is restarted +pub enum SocketError { + /// IO error from control socket #[error("io error: {0}")] Io(#[from] std::io::Error), + /// Client asked runner to shutdown #[error("start-up aborted")] StartupAborted, - #[error("error parsing wifi status {e}: \n{s}")] - ParsingWifiStatus { e: config::ConfigError, s: String }, - #[error("error parsing wifi config {e}: \n{s}")] - ParsingWifiConfig { e: config::ConfigError, s: String }, - #[error("unexpected wifi ap response: {0}")] - UnexpectedWifiApRepsonse(String), + /// Only one of the two control sockets is open, this should not happen + #[error("internal event channel unexpectedly closed")] + EventChannelClosed, + /// `RequestClient` dropped without shutting down runner + #[error("internal client channel unexpectedly closed")] + ClientChannelClosed, + /// Timeout trying to open the control socket even after retrying + #[error("timeout opening socket {0}")] + TimeoutOpeningSocket(String), + /// Permission denied opening control socket + #[error("permission denied opening socket {0}")] + PermissionDeniedOpeningSocket(String), +} + +/// Error returned by [access point](crate::ap::RequestClient) and [station](crate::sta::RequestClient) clients if there is +/// problem with the request e.g. asking to select a network you have not created a config for +#[derive(Error, Debug, Clone)] +pub enum ClientError { + /// Request failed e.g. asking to select a network you have not created a config for + #[error("Supplicant reported request failed")] + Failed, + /// Error parsing the reponse from the socket. This is probably a bug in the [`wifi_ctrl`](crate) code. + #[error("error {error} parsing response: \n{failed_response}")] + ParsingResponse { + #[source] + error: ParseError, + failed_response: String, + }, + /// Timeout waiting for response to request on control socket #[error("timeout waiting for response")] Timeout, - #[error("did not write all bytes {0}/{1}")] + /// Request was too big to fit in a datagram, mostlikely seen on bad custom requests + #[error("Request was too big only sent {0} of {1} bytes")] DidNotWriteAllBytes(usize, usize), + /// The control socket is not connected at the moment, reconnect and try again + #[error("Runner task not runnning")] + RunnerNotRunning, +} + +/// A sub error of [`ClientError`] returned when there is a problem parsing the reponse from +/// the socket. This is probably a bug in the [`wifi_ctrl`](crate) code. +#[derive(Error, Debug, Clone)] +pub enum ParseError { + #[error("Didn't get expected literal \"OK\" response")] + NotOK, + #[error("Too few columns in scan response")] + ScanResult, + #[error("error parsing config: {0}")] + ParseConfig(#[from] config::ConfigError), #[error("error parsing int: {0}")] ParseInt(#[from] std::num::ParseIntError), #[error("utf8 error: {0}")] Utf8Parse(#[from] std::str::Utf8Error), - #[error("recv error: {0}")] - Recv(#[from] oneshot::error::RecvError), - #[error("unsolicited socket io error: {0}")] - UnsolicitedIoError(std::io::Error), - #[error("wifi_ctrl::station internal request channel unexpectedly closed")] - WifiStationRequestChannelClosed, - #[error("wifi_ctrl::station internal event channel unexpectedly closed")] - WifiStationEventChannelClosed, - #[error("wifi_ctrl::ap internal request channel unexpectedly closed")] - WifiApRequestChannelClosed, - #[error("wifi_ctrl::ap internal event channel unexpectedly closed")] - WifiApEventChannelClosed, - #[error("wifi ap broadcast: {0}")] - WifiApBroadcast(#[from] broadcast::error::SendError), - #[error("wifi::sta broadcast: {0}")] - WifiStaBroadcast(#[from] broadcast::error::SendError), - #[error("timeout opening socket {0}")] - TimeoutOpeningSocket(String), - #[error("permission denied opening socket {0}")] - PermissionDeniedOpeningSocket(String), +} + +// Needed to make TryFrom happy when it can't fail +impl From for ParseError { + fn from(_: Infallible) -> Self { + unreachable!() + } +} + +// Happens when the runner half of a request channel gets dropped +// e.g. if it is asked to shut down, or the socket dies +impl From> for ClientError { + fn from(_: SendError) -> Self { + ClientError::RunnerNotRunning + } +} + +// Happens when the runner half of a repsonse channel gets dropped +// e.g. if it is asked to shut down, or the socket dies +impl From for ClientError { + fn from(_: RecvError) -> Self { + ClientError::RunnerNotRunning + } } diff --git a/src/lib.rs b/src/lib.rs index c21e272..0648c4e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,11 +27,12 @@ pub(crate) mod config; pub(crate) mod socket_handle; use socket_handle::SocketHandle; -pub type Result = std::result::Result; +pub type Result = std::result::Result; +pub type SocketResult = std::result::Result; +pub type ParseResult = std::result::Result; use log::{debug, error, info, warn}; pub(crate) trait ShutdownSignal { fn is_shutdown(&self) -> bool; - fn inform_of_shutdown(self); } diff --git a/src/socket_handle.rs b/src/socket_handle.rs index bc4991e..d0c0087 100644 --- a/src/socket_handle.rs +++ b/src/socket_handle.rs @@ -1,5 +1,5 @@ use super::*; -use crate::error::Error::UnexpectedWifiApRepsonse; +use error::{ClientError, ParseError}; use std::io::ErrorKind; use tokio::net::UnixDatagram; @@ -19,7 +19,7 @@ impl SocketHandle { path: P, label: &str, request_channel: &mut mpsc::Receiver, - ) -> Result<(Self, Vec)> + ) -> SocketResult<(Self, Vec)> where P: AsRef + std::fmt::Debug, S: ShutdownSignal, @@ -34,13 +34,13 @@ impl SocketHandle { let socket = tokio::select!( resp = async move { let mut loop_count = 0; - let s: Result = loop { + let s: SocketResult = loop { match socket.connect(&path) { Ok(()) => break Ok(socket), Err(e) => { // if socket is there but permission denied, fail fast if e.kind() == ErrorKind::PermissionDenied { - break Err(error::Error::PermissionDeniedOpeningSocket(socket_debug.to_string())); + break Err(error::SocketError::PermissionDeniedOpeningSocket(socket_debug.to_string())); } if loop_count % 60 == 0 { info!("Failed to connect to {socket_debug}, retrying for {} more minutes", RETRY_MINUTES-(loop_count+1)/60); @@ -54,7 +54,7 @@ impl SocketHandle { } => resp, _ = async move { tokio::time::sleep(tokio::time::Duration::from_secs(60*RETRY_MINUTES)).await; - } => Err(error::Error::TimeoutOpeningSocket(socket_debug.to_string())), + } => Err(error::SocketError::TimeoutOpeningSocket(socket_debug.to_string())), _ = async move { loop { if let Some(request) = request_channel.recv().await { @@ -65,16 +65,9 @@ impl SocketHandle { } } } - } => Err(error::Error::StartupAborted), + } => Err(error::SocketError::StartupAborted), ); - if let Err(error::Error::StartupAborted) = socket { - for request in deferred_requests { - request.inform_of_shutdown(); - } - return Err(error::Error::StartupAborted); - } - Ok(( Self { tmp_dir, @@ -85,55 +78,82 @@ impl SocketHandle { )) } - pub async fn command(&mut self, cmd: &[u8]) -> Result { + pub async fn recv(&mut self) -> SocketResult<&[u8]> { + let n = self.socket.recv(&mut self.buffer).await?; + Ok(&self.buffer[..n]) + } + + pub async fn command(&mut self, cmd: &[u8]) -> SocketResult { let n = self.socket.send(cmd).await?; if n != cmd.len() { - return Err(error::Error::DidNotWriteAllBytes(n, cmd.len())); - } - match self.expect_ok_with_default_timeout().await { - Ok(()) => Ok(Ok(())), - Err(e @ UnexpectedWifiApRepsonse(_)) => Ok(Err(e)), - Err(e) => Err(e), + return Ok(Err(error::ClientError::DidNotWriteAllBytes(n, cmd.len()))); } + self.expect_ok_with_default_timeout().await } - pub async fn request(&mut self, req: &str) -> Result<&str> { + pub(crate) async fn request<'a, T, E, F>( + &'a mut self, + req: &str, + parse: F, + ) -> SocketResult> + where + ParseError: From, + F: FnOnce(&'a str) -> std::result::Result, + { let n = self.socket.send(req.as_bytes()).await?; if n != req.len() { - return Err(error::Error::DidNotWriteAllBytes(n, req.len())); + return Ok(Err(error::ClientError::DidNotWriteAllBytes(n, req.len()))); } - self.recv_line().await + self.parse_resp(parse).await } - pub async fn recv_line(&mut self) -> Result<&str> { - let n = self.socket.recv(&mut self.buffer).await?; - Ok(std::str::from_utf8(&self.buffer[..n])?.trim_end()) - } - - async fn expect_ok(&mut self) -> Result { - match self.socket.recv(&mut self.buffer).await { - Ok(n) => { - let data_str = std::str::from_utf8(&self.buffer[..n])?.trim_end(); - if data_str.trim() == "OK" { - Ok(()) + async fn parse_resp<'a, T, E, F>(&'a mut self, parse: F) -> SocketResult> + where + ParseError: From, + F: FnOnce(&'a str) -> std::result::Result, + { + let bytes = self.recv().await?; + let str = std::str::from_utf8(bytes).map(str::trim_end); + Ok(str + .map_err(Into::::into) + .and_then(|s| parse(s).map_err(Into::::into)) + .map_err(|error| { + if str == Ok("FAIL") { + ClientError::Failed } else { - Err(error::Error::UnexpectedWifiApRepsonse(data_str.into())) + ClientError::ParsingResponse { + error, + failed_response: String::from_utf8_lossy(bytes).to_string(), + } } + })) + } + + async fn expect_ok(&mut self) -> SocketResult { + self.parse_resp(|data| { + // Scan (and only scan) return FAIL-BUSY when already scanning + if data == "OK" || data == "FAIL-BUSY" { + Ok(()) + } else { + Err(error::ParseError::NotOK) } - Err(e) => Err(error::Error::UnsolicitedIoError(e)), - } + }) + .await } - async fn expect_ok_with_default_timeout(&mut self) -> Result { + async fn expect_ok_with_default_timeout(&mut self) -> SocketResult { self.expect_ok_with_timeout(tokio::time::Duration::from_secs(1)) .await } - pub async fn expect_ok_with_timeout(&mut self, timeout: tokio::time::Duration) -> Result { + pub async fn expect_ok_with_timeout( + &mut self, + timeout: tokio::time::Duration, + ) -> SocketResult { tokio::select!( resp = self.expect_ok() => resp, _ = - tokio::time::sleep(timeout) => Err(error::Error::Timeout) + tokio::time::sleep(timeout) => Ok(Err(error::ClientError::Timeout)) ) } } diff --git a/src/sta/client.rs b/src/sta/client.rs index cd3766f..5273bc7 100644 --- a/src/sta/client.rs +++ b/src/sta/client.rs @@ -62,43 +62,6 @@ impl ShutdownSignal for Request { fn is_shutdown(&self) -> bool { matches!(self, Request::Shutdown) } - fn inform_of_shutdown(self) { - match self { - Request::Custom(_, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Status(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Networks(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Scan(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::ScanResults => {} - Request::AddNetwork(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::SetNetwork(_, _, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::SaveConfig(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::ReloadConfig(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::RemoveNetwork(_, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::SelectNetwork(_, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Shutdown => {} - Request::SelectTimeout => {} - } - } } #[derive(Debug)] @@ -120,123 +83,122 @@ impl RequestClient { RequestClient { sender } } - async fn send_request(&self, request: Request) -> Result { - self.sender - .send(request) - .await - .map_err(|_| error::Error::WifiStationRequestChannelClosed)?; - Ok(()) - } - pub async fn send_custom(&self, custom: String) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Custom(custom, response)).await?; + self.sender.send(Request::Custom(custom, response)).await?; request.await? } pub async fn get_scan(&self) -> Result>> { let (response, request) = oneshot::channel(); - self.send_request(Request::Scan(response)).await?; + self.sender.send(Request::Scan(response)).await?; request.await? } pub async fn get_networks(&self) -> Result> { let (response, request) = oneshot::channel(); - self.send_request(Request::Networks(response)).await?; + self.sender.send(Request::Networks(response)).await?; request.await? } pub async fn get_status(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Status(response)).await?; + self.sender.send(Request::Status(response)).await?; request.await? } pub async fn add_network(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::AddNetwork(response)).await?; + self.sender.send(Request::AddNetwork(response)).await?; request.await? } pub async fn set_network_psk(&self, network_id: usize, psk: String) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SetNetwork( - network_id, - SetNetwork::Psk(psk), - response, - )) - .await?; + self.sender + .send(Request::SetNetwork( + network_id, + SetNetwork::Psk(psk), + response, + )) + .await?; request.await? } pub async fn set_network_ssid(&self, network_id: usize, ssid: String) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SetNetwork( - network_id, - SetNetwork::Ssid(ssid), - response, - )) - .await?; + self.sender + .send(Request::SetNetwork( + network_id, + SetNetwork::Ssid(ssid), + response, + )) + .await?; request.await? } pub async fn set_network_bssid(&self, network_id: usize, bssid: String) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SetNetwork( - network_id, - SetNetwork::Bssid(bssid), - response, - )) - .await?; + self.sender + .send(Request::SetNetwork( + network_id, + SetNetwork::Bssid(bssid), + response, + )) + .await?; request.await? } pub async fn set_network_keymgmt(&self, network_id: usize, mgmt: KeyMgmt) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SetNetwork( - network_id, - SetNetwork::KeyMgmt(mgmt), - response, - )) - .await?; + self.sender + .send(Request::SetNetwork( + network_id, + SetNetwork::KeyMgmt(mgmt), + response, + )) + .await?; request.await? } pub async fn save_config(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SaveConfig(response)).await?; + self.sender.send(Request::SaveConfig(response)).await?; request.await? } pub async fn reload_config(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::ReloadConfig(response)).await?; + self.sender.send(Request::ReloadConfig(response)).await?; request.await? } pub async fn remove_network(&self, id: usize) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::RemoveNetwork(RemoveNetwork::Id(id), response)) + self.sender + .send(Request::RemoveNetwork(RemoveNetwork::Id(id), response)) .await?; request.await? } pub async fn remove_all_networks(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::RemoveNetwork(RemoveNetwork::All, response)) + self.sender + .send(Request::RemoveNetwork(RemoveNetwork::All, response)) .await?; request.await? } pub async fn select_network(&self, network_id: usize) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SelectNetwork(network_id, response)) + self.sender + .send(Request::SelectNetwork(network_id, response)) .await?; request.await? } pub async fn shutdown(&self) -> Result { - self.send_request(Request::Shutdown).await?; + self.sender.send(Request::Shutdown).await?; Ok(()) } } diff --git a/src/sta/event_socket.rs b/src/sta/event_socket.rs index a3234b5..fa2a985 100644 --- a/src/sta/event_socket.rs +++ b/src/sta/event_socket.rs @@ -9,7 +9,7 @@ pub(crate) struct EventSocket { #[derive(Debug)] pub(crate) enum Event { ScanComplete, - ScanFailed(String), + ScanFailed, Connected, Disconnected, NetworkNotFound, @@ -23,7 +23,7 @@ impl EventSocket { pub(crate) async fn new

( socket: P, request_receiver: &mut mpsc::Receiver, - ) -> Result<(EventReceiver, Vec, Self)> + ) -> SocketResult<(EventReceiver, Vec, Self)> where P: AsRef + std::fmt::Debug, { @@ -41,49 +41,39 @@ impl EventSocket { )) } - async fn send_event(&self, event: Event) -> Result { + async fn send_event(&self, event: Event) -> SocketResult { self.sender .send(event) .await - .map_err(|_| error::Error::WifiStationEventChannelClosed)?; + .map_err(|_| error::SocketError::EventChannelClosed)?; Ok(()) } - pub(crate) async fn run(mut self) -> Result { + pub(crate) async fn run(mut self) -> SocketResult { info!("wpa_ctrl attempting attach"); self.socket_handle.socket.send(b"ATTACH").await?; loop { - match self - .socket_handle - .socket - .recv(&mut self.socket_handle.buffer) - .await + let bytes = self.socket_handle.recv().await?; + let data_str = String::from_utf8_lossy(bytes); + debug!("wpa_ctrl event: {data_str}"); + let event = if data_str.trim_end().ends_with("CTRL-EVENT-SCAN-RESULTS") { + Event::ScanComplete + } else if data_str.contains("CTRL-EVENT-SCAN-FAILED") { + Event::ScanFailed + } else if data_str.contains("CTRL-EVENT-CONNECTED") { + Event::Connected + } else if data_str.contains("CTRL-EVENT-DISCONNECTED") { + Event::Disconnected + } else if data_str.contains("CTRL-EVENT-NETWORK-NOT-FOUND") { + Event::NetworkNotFound + } else if data_str.contains("CTRL-EVENT-SSID-TEMP-DISABLED") + && data_str.contains("reason=WRONG_KEY") { - Ok(n) => { - let data_str = std::str::from_utf8(&self.socket_handle.buffer[..n])?.trim_end(); - debug!("wpa_ctrl event: {data_str}"); - if data_str.ends_with("CTRL-EVENT-SCAN-RESULTS") { - self.send_event(Event::ScanComplete).await?; - } else if data_str.contains("CTRL-EVENT-SCAN-FAILED") { - self.send_event(Event::ScanFailed(data_str.into())).await?; - } else if data_str.contains("CTRL-EVENT-CONNECTED") { - self.send_event(Event::Connected).await?; - } else if data_str.contains("CTRL-EVENT-DISCONNECTED") { - self.send_event(Event::Disconnected).await?; - } else if data_str.contains("CTRL-EVENT-NETWORK-NOT-FOUND") { - self.send_event(Event::NetworkNotFound).await?; - } else if data_str.contains("CTRL-EVENT-SSID-TEMP-DISABLED") - && data_str.contains("reason=WRONG_KEY") - { - self.send_event(Event::WrongPsk).await?; - } else { - self.send_event(Event::Unknown(data_str.into())).await?; - } - } - Err(e) => { - return Err(error::Error::UnsolicitedIoError(e)); - } - } + Event::WrongPsk + } else { + Event::Unknown(data_str.into()) + }; + self.send_event(event).await?; } } } diff --git a/src/sta/mod.rs b/src/sta/mod.rs index 06d71b1..77e008a 100644 --- a/src/sta/mod.rs +++ b/src/sta/mod.rs @@ -1,3 +1,5 @@ +use core::str; + use super::*; use tokio::time::Duration; @@ -32,7 +34,7 @@ pub struct WifiStation { } impl WifiStation { - pub async fn run(mut self) -> Result { + pub async fn run(&mut self) -> SocketResult { info!("Starting Wifi Station process"); let (socket_handle, mut deferred_requests) = SocketHandle::open( &self.socket_path, @@ -62,10 +64,10 @@ impl WifiStation { } async fn run_internal( - mut self, + &mut self, mut unsolicited_receiver: EventReceiver, mut socket_handle: SocketHandle<10240>, - ) -> Result { + ) -> SocketResult { // We will collect scan requests and batch respond to them when results are ready let mut scan_requests = Vec::new(); let mut select_request = None; @@ -91,7 +93,7 @@ impl WifiStation { self.handle_event(unsolicited_msg, &mut scan_requests, &mut select_request) .await } - None => return Err(error::Error::WifiStationEventChannelClosed), + None => return Err(error::SocketError::EventChannelClosed), }, EventOrRequest::Request(request) => match request { Some(Request::Shutdown) => return Ok(()), @@ -104,7 +106,7 @@ impl WifiStation { ) .await?; } - None => return Err(error::Error::WifiStationRequestChannelClosed), + None => return Err(error::SocketError::ClientChannelClosed), }, } } @@ -120,10 +122,9 @@ impl WifiStation { Event::ScanComplete => { let _ = self.self_sender.send(Request::ScanResults).await; } - Event::ScanFailed(s) => { + Event::ScanFailed => { while let Some(scan_request) = scan_requests.pop() { - let _ = - scan_request.send(Err(error::Error::UnexpectedWifiApRepsonse(s.clone()))); + let _ = scan_request.send(Err(error::ClientError::Failed)); } } Event::Connected => { @@ -155,9 +156,8 @@ impl WifiStation { async fn get_status( socket_handle: &mut SocketHandle, - ) -> Result> { - let data_str = socket_handle.request("STATUS").await?; - Ok(parse_status(data_str)) + ) -> SocketResult> { + socket_handle.request("STATUS", parse_status).await } async fn handle_request( @@ -166,15 +166,11 @@ impl WifiStation { request: Request, scan_requests: &mut Vec>>>>, select_request: &mut Option, - ) -> Result { + ) -> SocketResult { debug!("Handling request: {request:?}"); match request { Request::Custom(custom, response_channel) => { - let _n = socket_handle.socket.send(custom.as_bytes()).await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n]) - .map(|s| s.trim_end().to_owned()) - .map_err(Into::into); + let data_str = socket_handle.request(&custom, TryInto::try_into).await?; debug!("Custom request response: {data_str:?}"); let _ = response_channel.send(data_str); } @@ -193,29 +189,26 @@ impl WifiStation { } }; } - Request::ScanResults => match socket_handle.request("SCAN_RESULTS").await { - Ok(resp) => { - let scan_results = Arc::new(ScanResult::vec_from_str(resp)); - while let Some(scan_request) = scan_requests.pop() { - let _ = scan_request.send(Ok(scan_results.clone())); - } - } - Err(e) => { - scan_requests.clear(); - return Err(e); + Request::ScanResults => { + let scan_results = socket_handle + .request("SCAN_RESULTS", ScanResult::vec_from_str) + .await?; + while let Some(scan_request) = scan_requests.pop() { + let _ = scan_request.send(scan_results.clone()); } - }, + } Request::Networks(response_channel) => { let network_list = NetworkResult::request_results(socket_handle).await?; - let _ = response_channel.send(Ok(network_list)); + let _ = response_channel.send(network_list); } Request::Status(response_channel) => { let status = Self::get_status(socket_handle).await?; let _ = response_channel.send(status); } Request::AddNetwork(response_channel) => { - let data_str = socket_handle.request("ADD_NETWORK").await?; - let network_id = usize::from_str(data_str).map_err(Into::into); + let network_id = socket_handle + .request("ADD_NETWORK", usize::from_str) + .await?; debug!("wpa_ctrl created network {network_id:?}"); let _ = response_channel.send(network_id); } @@ -224,12 +217,12 @@ impl WifiStation { "SET_NETWORK {id} {}", match param { SetNetwork::Ssid(ssid) => format!("ssid {}", conf_escape(&ssid)), - SetNetwork::Bssid(bssid) => format!("bssid \"{bssid}\""), + SetNetwork::Bssid(bssid) => format!("bssid {}", conf_escape(&bssid)), SetNetwork::Psk(psk) => format!("psk {}", conf_escape(&psk)), SetNetwork::KeyMgmt(mgmt) => format!("key_mgmt {}", mgmt), } ); - debug!("wpa_ctrl \"{cmd}\""); + debug!("wpa_ctrl {cmd:?}"); let bytes = cmd.into_bytes(); let _ = response.send(socket_handle.command(&bytes).await?); } diff --git a/src/sta/setup.rs b/src/sta/setup.rs index 2571008..4deccf4 100644 --- a/src/sta/setup.rs +++ b/src/sta/setup.rs @@ -16,14 +16,14 @@ pub struct WifiSetupGeneric { } impl WifiSetupGeneric { - pub fn new() -> Result { + pub fn new() -> Self { // setup the channel for client requests let (self_sender, request_receiver) = mpsc::channel(C); let request_client = RequestClient::new(self_sender.clone()); // setup the channel for broadcasts let (broadcast_sender, broadcast_receiver) = broadcast::channel(B); - Ok(Self { + Self { wifi: WifiStation { socket_path: PATH_DEFAULT_SERVER.into(), request_receiver, @@ -33,7 +33,7 @@ impl WifiSetupGeneric { }, request_client, broadcast_receiver, - }) + } } pub fn set_socket_path>(&mut self, path: S) { @@ -55,3 +55,9 @@ impl WifiSetupGeneric { self.wifi } } + +impl Default for WifiSetupGeneric { + fn default() -> Self { + Self::new() + } +} diff --git a/src/sta/types.rs b/src/sta/types.rs index 3bc8776..ca98ac5 100644 --- a/src/sta/types.rs +++ b/src/sta/types.rs @@ -1,10 +1,12 @@ -use super::SocketHandle; -use super::{config, config::unprintf, error, warn, Result}; +use super::error::ParseError; +use super::{config, config::unprintf, warn, Result, SocketHandle}; +use super::{ParseResult, SocketResult}; use serde::Serialize; use std::collections::HashMap; use std::fmt::Display; use std::str::FromStr; +use std::sync::Arc; #[derive(Serialize, Debug, Clone)] /// The result from scanning for networks. @@ -41,22 +43,19 @@ impl ScanResult { ///let results = ScanResult::vec_from_str(r#"bssid / frequency / signal level / flags / ssid ///00:5f:67:90:da:64 2417 -35 [WPA-PSK-CCMP][WPA2-PSK-CCMP][ESS] TP-Link DA64 ///e0:91:f5:7d:11:c0 2462 -33 [WPA2-PSK-CCMP][WPS][ESS] ¯\\_(\xe3\x83\x84)_/¯ - ///"#); + ///"#).unwrap(); ///assert_eq!(results[0].mac, "00:5f:67:90:da:64"); ///assert_eq!(results[0].name, "TP-Link DA64"); ///assert_eq!(results[1].signal, -33); ///assert_eq!(results[1].name, r#"¯\_(ツ)_/¯"#); ///``` - pub fn vec_from_str(response: &str) -> Vec { + pub fn vec_from_str(response: &str) -> ParseResult>> { let mut results = Vec::new(); for line in response.lines().skip(1) { - if let Some(scan_result) = ScanResult::from_line(line) { - results.push(scan_result); - } else { - warn!("Invalid result from scan: {line}"); - } + results.push(ScanResult::from_line(line).ok_or(ParseError::ScanResult)?); } - results + results.sort_by(|a, b| a.signal.cmp(&b.signal)); + Ok(Arc::new(results)) } } @@ -68,25 +67,35 @@ pub struct NetworkResult { pub flags: String, } +fn parse_get_network(resp: &str) -> ParseResult { + let escaped = resp.trim_matches('\"'); + Ok(unprintf(escaped)?) +} + impl NetworkResult { pub async fn request_results( socket_handle: &mut SocketHandle, - ) -> Result> { - let response = socket_handle.request("LIST_NETWORKS").await?.to_owned(); + ) -> SocketResult>> { + let response: String = match socket_handle + .request("LIST_NETWORKS", TryInto::try_into) + .await? + { + Ok(x) => x, + Err(e) => return Ok(Err(e)), + }; let mut results = Vec::new(); let split = response.split('\n').skip(1); for line in split { let mut line_split = line.split_whitespace(); if let Some(network_id) = line_split.next() { - let ssid = socket_handle - .request(&format!("GET_NETWORK {network_id} ssid")) - .await? - .trim_matches('\"'); - let ssid = unprintf(ssid).map_err(|e| error::Error::ParsingWifiStatus { - e, - s: ssid.to_string(), - })?; if let Ok(network_id) = usize::from_str(network_id) { + let ssid = match socket_handle + .request(&format!("GET_NETWORK {network_id} ssid"), parse_get_network) + .await? + { + Ok(x) => x, + Err(e) => return Ok(Err(e)), + }; if let Some(flags) = line_split.last() { results.push(NetworkResult { flags: flags.into(), @@ -99,18 +108,15 @@ impl NetworkResult { } } } - Ok(results) + Ok(Ok(results)) } } /// A HashMap of what is returned when running `wpa_cli status`. pub type Status = HashMap; -pub(crate) fn parse_status(response: &str) -> Result { - config::from_str(response).map_err(|e| error::Error::ParsingWifiStatus { - e, - s: response.into(), - }) +pub(crate) fn parse_status(response: &str) -> ParseResult { + Ok(config::from_str(response)?) } #[derive(Debug)] From c12e170081eb34818b6f728ed1d9c7fb5e562dee Mon Sep 17 00:00:00 2001 From: AJ Bagwell Date: Thu, 6 Mar 2025 14:16:54 +0000 Subject: [PATCH 6/8] Remove channel for events The channel was complicating error handling and not adding much as both sides of the channel were going into the same select!() call. The channel for broadcast requests remains. --- src/ap/event_socket.rs | 73 +++++++++++++---------------------------- src/ap/mod.rs | 18 ++++------ src/error.rs | 3 -- src/sta/event_socket.rs | 46 +++++++------------------- src/sta/mod.rs | 28 +++++++--------- 5 files changed, 52 insertions(+), 116 deletions(-) diff --git a/src/ap/event_socket.rs b/src/ap/event_socket.rs index 6ab13f1..25fcf65 100644 --- a/src/ap/event_socket.rs +++ b/src/ap/event_socket.rs @@ -2,9 +2,6 @@ use super::*; pub(crate) struct EventSocket { socket_handle: SocketHandle<1024>, - attach_options: Vec, - /// Sends messages to client - sender: mpsc::Sender, } #[derive(Debug)] @@ -14,75 +11,51 @@ pub(crate) enum Event { Unknown(String), } -pub(crate) type EventReceiver = mpsc::Receiver; - impl EventSocket { pub(crate) async fn new

( socket: P, request_receiver: &mut mpsc::Receiver, attach_options: &[String], - ) -> SocketResult<(EventReceiver, Vec, Self)> + ) -> SocketResult<(Vec, Self)> where P: AsRef + std::fmt::Debug, { - let (socket_handle, deferred_requests) = + let (mut socket_handle, deferred_requests) = SocketHandle::open(socket, "hostapd_async.sock", request_receiver).await?; - // setup the channel for client requests - let (sender, receiver) = mpsc::channel(32); - Ok(( - receiver, - deferred_requests, - Self { - socket_handle, - sender, - attach_options: attach_options.to_vec(), - }, - )) - } - - async fn send_event(&self, event: Event) -> SocketResult { - self.sender - .send(event) - .await - .map_err(|_| error::SocketError::EventChannelClosed)?; - Ok(()) - } - - pub(crate) async fn run(mut self) -> SocketResult { let mut command = "ATTACH".to_string(); - for o in &self.attach_options { + for o in attach_options { command.push(' '); command.push_str(o); } - let mut attach = self.socket_handle.command(command.as_bytes()).await; + let mut attach = socket_handle.command(command.as_bytes()).await?; while attach.is_err() { tokio::time::sleep(tokio::time::Duration::from_millis(250)).await; - attach = self.socket_handle.command(command.as_bytes()).await; + attach = socket_handle.command(command.as_bytes()).await?; } - let mut log_level = self.socket_handle.command(b"LOG_LEVEL DEBUG").await; + let mut log_level = socket_handle.command(b"LOG_LEVEL DEBUG").await?; while log_level.is_err() { tokio::time::sleep(tokio::time::Duration::from_millis(250)).await; - log_level = self.socket_handle.command(b"LOG_LEVEL DEBUG").await; + log_level = socket_handle.command(b"LOG_LEVEL DEBUG").await?; } info!("hostapd event stream registered"); + Ok((deferred_requests, Self { socket_handle })) + } - loop { - let bytes = self.socket_handle.recv().await?; - let data_str = String::from_utf8_lossy(bytes); - let event = if let Some(n) = data_str.find("AP-STA-DISCONNECTED") { - let index = n + "AP-STA-DISCONNECTED".len(); - let mac = &data_str[index..].trim(); - Event::ApStaDisconnected(mac.to_string()) - } else if let Some(n) = data_str.find("AP-STA-CONNECTED") { - let index = n + "AP-STA-CONNECTED".len(); - let mac = &data_str[index..].trim(); - Event::ApStaConnected(mac.to_string()) - } else { - Event::Unknown(data_str.to_string()) - }; - self.send_event(event).await?; - } + pub(crate) async fn recv(&mut self) -> SocketResult { + let bytes = self.socket_handle.recv().await?; + let data_str = String::from_utf8_lossy(bytes); + Ok(if let Some(n) = data_str.find("AP-STA-DISCONNECTED") { + let index = n + "AP-STA-DISCONNECTED".len(); + let mac = data_str[index..].trim(); + Event::ApStaDisconnected(mac.to_string()) + } else if let Some(n) = data_str.find("AP-STA-CONNECTED") { + let index = n + "AP-STA-CONNECTED".len(); + let mac = &data_str[index..].trim(); + Event::ApStaConnected(mac.to_string()) + } else { + Event::Unknown(data_str.to_string()) + }) } } diff --git a/src/ap/mod.rs b/src/ap/mod.rs index 82a75fe..d8d452a 100644 --- a/src/ap/mod.rs +++ b/src/ap/mod.rs @@ -32,7 +32,7 @@ pub struct WifiAp { impl WifiAp { pub async fn run(&mut self) -> SocketResult { info!("Starting Wifi AP process"); - let (event_receiver, mut deferred_requests, event_socket) = EventSocket::new( + let (mut deferred_requests, event_socket) = EventSocket::new( &self.socket_path, &mut self.request_receiver, &self.attach_options, @@ -54,10 +54,7 @@ impl WifiAp { .expect("self_sender should never close as same struct owns both ends"); } self.broadcast(Broadcast::Ready); - tokio::select!( - resp = event_socket.run() => resp, - resp = self.run_internal(event_receiver, socket_handle) => resp, - ) + self.run_internal(event_socket, socket_handle).await } fn broadcast(&self, event: Broadcast) { @@ -68,24 +65,21 @@ impl WifiAp { async fn run_internal( &mut self, - mut event_receiver: EventReceiver, + mut event_socket: EventSocket, mut socket_handle: SocketHandle<2048>, ) -> SocketResult { enum EventOrRequest { - Event(Option), + Event(Event), Request(Option), } loop { let event_or_request = tokio::select!( - event = event_receiver.recv() => EventOrRequest::Event(event), + event = event_socket.recv() => EventOrRequest::Event(event?), request = self.request_receiver.recv() => EventOrRequest::Request(request), ); match event_or_request { - EventOrRequest::Event(event) => match event { - Some(event) => self.handle_event(&mut socket_handle, event).await, - None => return Err(error::SocketError::EventChannelClosed), - }, + EventOrRequest::Event(event) => self.handle_event(&mut socket_handle, event).await, EventOrRequest::Request(request) => match request { Some(Request::Shutdown) => return Ok(()), Some(request) => Self::handle_request(&mut socket_handle, request).await?, diff --git a/src/error.rs b/src/error.rs index f5c3ce6..d42af03 100644 --- a/src/error.rs +++ b/src/error.rs @@ -14,9 +14,6 @@ pub enum SocketError { /// Client asked runner to shutdown #[error("start-up aborted")] StartupAborted, - /// Only one of the two control sockets is open, this should not happen - #[error("internal event channel unexpectedly closed")] - EventChannelClosed, /// `RequestClient` dropped without shutting down runner #[error("internal client channel unexpectedly closed")] ClientChannelClosed, diff --git a/src/sta/event_socket.rs b/src/sta/event_socket.rs index fa2a985..98ebe05 100644 --- a/src/sta/event_socket.rs +++ b/src/sta/event_socket.rs @@ -2,8 +2,6 @@ use super::*; pub(crate) struct EventSocket { socket_handle: SocketHandle<1024>, - /// Sends messages to client - sender: mpsc::Sender, } #[derive(Debug)] @@ -17,46 +15,27 @@ pub(crate) enum Event { Unknown(String), } -pub(crate) type EventReceiver = mpsc::Receiver; - impl EventSocket { pub(crate) async fn new

( socket: P, request_receiver: &mut mpsc::Receiver, - ) -> SocketResult<(EventReceiver, Vec, Self)> + ) -> SocketResult<(Vec, Self)> where P: AsRef + std::fmt::Debug, { let (socket_handle, deferred_requests) = SocketHandle::open(socket, "wpa_ctrl_async.sock", request_receiver).await?; - // setup the channel for client requests - let (sender, receiver) = mpsc::channel(32); - Ok(( - receiver, - deferred_requests, - Self { - socket_handle, - sender, - }, - )) - } - - async fn send_event(&self, event: Event) -> SocketResult { - self.sender - .send(event) - .await - .map_err(|_| error::SocketError::EventChannelClosed)?; - Ok(()) + info!("wpa_ctrl attempting attach"); + socket_handle.socket.send(b"ATTACH").await?; + Ok((deferred_requests, Self { socket_handle })) } - pub(crate) async fn run(mut self) -> SocketResult { - info!("wpa_ctrl attempting attach"); - self.socket_handle.socket.send(b"ATTACH").await?; - loop { - let bytes = self.socket_handle.recv().await?; - let data_str = String::from_utf8_lossy(bytes); - debug!("wpa_ctrl event: {data_str}"); - let event = if data_str.trim_end().ends_with("CTRL-EVENT-SCAN-RESULTS") { + pub(crate) async fn recv(&mut self) -> SocketResult { + let bytes = self.socket_handle.recv().await?; + let data_str = String::from_utf8_lossy(bytes); + debug!("wpa_ctrl event: {data_str}"); + Ok( + if data_str.trim_end().ends_with("CTRL-EVENT-SCAN-RESULTS") { Event::ScanComplete } else if data_str.contains("CTRL-EVENT-SCAN-FAILED") { Event::ScanFailed @@ -72,8 +51,7 @@ impl EventSocket { Event::WrongPsk } else { Event::Unknown(data_str.into()) - }; - self.send_event(event).await?; - } + }, + ) } } diff --git a/src/sta/mod.rs b/src/sta/mod.rs index 77e008a..277df3e 100644 --- a/src/sta/mod.rs +++ b/src/sta/mod.rs @@ -44,17 +44,14 @@ impl WifiStation { .await?; // We start up a separate socket for receiving the "unexpected" events that // gets forwarded to us via the unsolicited_receiver - let (unsolicited_receiver, next_deferred_requests, unsolicited) = + let (next_deferred_requests, unsolicited) = EventSocket::new(&self.socket_path, &mut self.request_receiver).await?; deferred_requests.extend(next_deferred_requests); for request in deferred_requests { let _ = self.self_sender.send(request).await; } self.broadcast(Broadcast::Ready); - tokio::select!( - resp = unsolicited.run() => resp, - resp = self.run_internal(unsolicited_receiver, socket_handle) => resp, - ) + self.run_internal(unsolicited, socket_handle).await } fn broadcast(&self, event: Broadcast) { @@ -65,7 +62,7 @@ impl WifiStation { async fn run_internal( &mut self, - mut unsolicited_receiver: EventReceiver, + mut unsolicited: EventSocket, mut socket_handle: SocketHandle<10240>, ) -> SocketResult { // We will collect scan requests and batch respond to them when results are ready @@ -73,13 +70,13 @@ impl WifiStation { let mut select_request = None; loop { enum EventOrRequest { - Event(Option), + Event(Event), Request(Option), } let event_or_request = tokio::select!( - unsolicited_msg = unsolicited_receiver.recv() => { - EventOrRequest::Event(unsolicited_msg) + unsolicited_msg = unsolicited.recv() => { + EventOrRequest::Event(unsolicited_msg?) }, request = self.request_receiver.recv() => { EventOrRequest::Request(request) @@ -87,14 +84,11 @@ impl WifiStation { ); match event_or_request { - EventOrRequest::Event(event) => match event { - Some(unsolicited_msg) => { - debug!("Unsolicited event: {unsolicited_msg:?}"); - self.handle_event(unsolicited_msg, &mut scan_requests, &mut select_request) - .await - } - None => return Err(error::SocketError::EventChannelClosed), - }, + EventOrRequest::Event(unsolicited_msg) => { + debug!("Unsolicited event: {unsolicited_msg:?}"); + self.handle_event(unsolicited_msg, &mut scan_requests, &mut select_request) + .await + } EventOrRequest::Request(request) => match request { Some(Request::Shutdown) => return Ok(()), Some(request) => { From de456910d24c3fd34ce4f405e0f399ffee2f73ae Mon Sep 17 00:00:00 2001 From: AJ Bagwell Date: Mon, 24 Mar 2025 10:55:54 +0000 Subject: [PATCH 7/8] Move some errors out of select result --- src/error.rs | 2 ++ src/sta/client.rs | 6 ------ src/sta/mod.rs | 35 +++++++++++++---------------------- 3 files changed, 15 insertions(+), 28 deletions(-) diff --git a/src/error.rs b/src/error.rs index d42af03..80f1c56 100644 --- a/src/error.rs +++ b/src/error.rs @@ -48,6 +48,8 @@ pub enum ClientError { /// The control socket is not connected at the moment, reconnect and try again #[error("Runner task not runnning")] RunnerNotRunning, + #[error("Select already pending")] + PendingSelect, } /// A sub error of [`ClientError`] returned when there is a problem parsing the reponse from diff --git a/src/sta/client.rs b/src/sta/client.rs index 5273bc7..92baece 100644 --- a/src/sta/client.rs +++ b/src/sta/client.rs @@ -12,9 +12,6 @@ pub enum SelectResult { Success, WrongPsk, NotFound, - PendingSelect, - InvalidNetworkId, - Timeout, AlreadyConnected, } @@ -26,9 +23,6 @@ impl fmt::Display for SelectResult { SelectResult::Success => "success", SelectResult::WrongPsk => "wrong_psk", SelectResult::NotFound => "network_not_found", - SelectResult::PendingSelect => "select_already_pending", - SelectResult::InvalidNetworkId => "invalid_network_id", - SelectResult::Timeout => "select_timeout", SelectResult::AlreadyConnected => "already_connected", }; write!(f, "{s}") diff --git a/src/sta/mod.rs b/src/sta/mod.rs index 277df3e..0f9eeb4 100644 --- a/src/sta/mod.rs +++ b/src/sta/mod.rs @@ -1,5 +1,7 @@ use core::str; +use crate::error::ClientError; + use super::*; use tokio::time::Duration; @@ -170,7 +172,7 @@ impl WifiStation { } Request::SelectTimeout => { if let Some(sender) = select_request.take() { - sender.send(Ok(SelectResult::Timeout)); + sender.send(Err(ClientError::Timeout)); } } Request::Scan(response_channel) => { @@ -239,44 +241,33 @@ impl WifiStation { let _ = response.send(socket_handle.command(&bytes).await?); } Request::SelectNetwork(id, response_sender) => { - let response_sender = match select_request { + match select_request { None => { let cmd = format!("SELECT_NETWORK {id}"); let bytes = cmd.into_bytes(); if let Err(e) = socket_handle.command(&bytes).await? { warn!("Error while selecting network {id}: {e}"); - let _ = response_sender.send(Ok(SelectResult::InvalidNetworkId)); - None + let _ = response_sender.send(Err(e)); } else { debug!("wpa_ctrl selected network {id}"); let status = Self::get_status(socket_handle).await?.unwrap_or_default(); - if let Some(current_id) = status.get("id") { - if current_id == &id.to_string() { - let _ = - response_sender.send(Ok(SelectResult::AlreadyConnected)); - None - } else { - Some(response_sender) - } + if status.get("id") == Some(&id.to_string()) { + let _ = response_sender.send(Ok(SelectResult::AlreadyConnected)); } else { - Some(response_sender) + *select_request = Some(SelectRequest::new( + self.self_sender.clone(), + response_sender, + self.select_timeout, + )); } } } Some(_) => { warn!("Select request already pending! Dropping this one."); - let _ = response_sender.send(Ok(SelectResult::PendingSelect)); + let _ = response_sender.send(Err(ClientError::PendingSelect)); debug!("wpa_ctrl removed network {id}"); - None } }; - if let Some(response_sender) = response_sender { - *select_request = Some(SelectRequest::new( - self.self_sender.clone(), - response_sender, - self.select_timeout, - )); - } } Request::Shutdown => (), //shutdown is handled at the scope above } From feb28d2d88aab41b976f4595445f14c1df56e1fa Mon Sep 17 00:00:00 2001 From: AJ Bagwell Date: Tue, 26 Aug 2025 09:48:52 +0100 Subject: [PATCH 8/8] don't strip tailing tab off scan results Only strip the trailing new line on reponses before parsing The scan results are tab separated and if the ssid is empty it can end with a tab which breaks the parsing if removed. --- src/socket_handle.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/socket_handle.rs b/src/socket_handle.rs index d0c0087..a8ef41b 100644 --- a/src/socket_handle.rs +++ b/src/socket_handle.rs @@ -113,7 +113,7 @@ impl SocketHandle { F: FnOnce(&'a str) -> std::result::Result, { let bytes = self.recv().await?; - let str = std::str::from_utf8(bytes).map(str::trim_end); + let str = std::str::from_utf8(bytes).map(|r| r.trim_end_matches('\n')); Ok(str .map_err(Into::::into) .and_then(|s| parse(s).map_err(Into::::into))