diff --git a/examples/wifi-ap.rs b/examples/wifi-ap.rs index 1f7c5c9..2b61cdf 100644 --- a/examples/wifi-ap.rs +++ b/examples/wifi-ap.rs @@ -16,8 +16,11 @@ async fn main() -> Result { info!("[{:?}] {:?}", i, itf.name); } let user_input = read_until_break().await; - let index = user_input.trim().parse::()?; - let mut setup = ap::WifiSetup::new()?; + let index = user_input + .trim() + .parse::() + .expect("user should enter a number"); + let mut setup = ap::WifiSetup::new(); let proposed_path = format!("/var/run/hostapd/{}", network_interfaces[index].name); info!("Connect to \"{proposed_path}\"? Type full new path or just press enter to accept."); @@ -31,7 +34,7 @@ async fn main() -> Result { let broadcast = setup.get_broadcast_receiver(); let requester = setup.get_request_client(); - let runtime = setup.complete(); + let mut runtime = setup.complete(); let (_runtime, _app, _broadcast) = tokio::join!( async move { diff --git a/examples/wifi-sta.rs b/examples/wifi-sta.rs index 9519ed3..4728ac5 100644 --- a/examples/wifi-sta.rs +++ b/examples/wifi-sta.rs @@ -16,8 +16,8 @@ async fn main() -> Result { info!("[{:?}] {:?}", i, itf.name); } let user_input = read_until_break().await; - let index = user_input.trim().parse::()?; - let mut setup = sta::WifiSetup::new()?; + let index = user_input.trim().parse::().expect(""); + let mut setup = sta::WifiSetup::new(); let proposed_path = format!("/var/run/wpa_supplicant/{}", network_interfaces[index].name); info!("Connect to \"{proposed_path}\"? Type full new path or just press enter to accept."); @@ -31,7 +31,7 @@ async fn main() -> Result { let broadcast = setup.get_broadcast_receiver(); let requester = setup.get_request_client(); - let runtime = setup.complete(); + let mut runtime = setup.complete(); let (_runtime, _app, _broadcast) = tokio::join!( async move { diff --git a/src/ap/client.rs b/src/ap/client.rs index 5ed3010..3e2fdbd 100644 --- a/src/ap/client.rs +++ b/src/ap/client.rs @@ -15,26 +15,6 @@ impl ShutdownSignal for Request { fn is_shutdown(&self) -> bool { matches!(self, Request::Shutdown) } - fn inform_of_shutdown(self) { - match self { - Request::Custom(_, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Status(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Config(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Enable(response) | Request::Disable(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::SetValue(_, _, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Shutdown => {} - } - } } #[derive(Clone)] @@ -48,56 +28,46 @@ impl RequestClient { RequestClient { sender } } - async fn send_request(&self, request: Request) -> Result { - self.sender - .send(request) - .await - .map_err(|_| error::Error::WifiApRequestChannelClosed)?; - Ok(()) - } - pub async fn send_custom(&self, custom: String) -> Result { let (response, request) = oneshot::channel(); - self.sender - .send(Request::Custom(custom, response)) - .await - .map_err(|_| error::Error::WifiApRequestChannelClosed)?; + self.sender.send(Request::Custom(custom, response)).await?; request.await? } pub async fn get_status(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Status(response)).await?; + self.sender.send(Request::Status(response)).await?; request.await? } pub async fn get_config(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Config(response)).await?; + self.sender.send(Request::Config(response)).await?; request.await? } pub async fn enable(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Enable(response)).await?; + self.sender.send(Request::Enable(response)).await?; request.await? } pub async fn disable(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Disable(response)).await?; + self.sender.send(Request::Disable(response)).await?; request.await? } pub async fn set_value(&self, key: &str, value: &str) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SetValue(key.into(), value.into(), response)) + self.sender + .send(Request::SetValue(key.into(), value.into(), response)) .await?; request.await? } pub async fn shutdown(&self) -> Result { - self.send_request(Request::Shutdown).await + Ok(self.sender.send(Request::Shutdown).await?) } } diff --git a/src/ap/event_socket.rs b/src/ap/event_socket.rs index 001562a..25fcf65 100644 --- a/src/ap/event_socket.rs +++ b/src/ap/event_socket.rs @@ -2,9 +2,6 @@ use super::*; pub(crate) struct EventSocket { socket_handle: SocketHandle<1024>, - attach_options: Vec, - /// Sends messages to client - sender: mpsc::Sender, } #[derive(Debug)] @@ -14,88 +11,51 @@ pub(crate) enum Event { Unknown(String), } -pub(crate) type EventReceiver = mpsc::Receiver; - impl EventSocket { pub(crate) async fn new

( socket: P, request_receiver: &mut mpsc::Receiver, attach_options: &[String], - ) -> Result<(EventReceiver, Vec, Self)> + ) -> SocketResult<(Vec, Self)> where P: AsRef + std::fmt::Debug, { - let (socket_handle, deferred_requests) = + let (mut socket_handle, deferred_requests) = SocketHandle::open(socket, "hostapd_async.sock", request_receiver).await?; - // setup the channel for client requests - let (sender, receiver) = mpsc::channel(32); - Ok(( - receiver, - deferred_requests, - Self { - socket_handle, - sender, - attach_options: attach_options.to_vec(), - }, - )) - } - - async fn send_event(&self, event: Event) -> Result { - self.sender - .send(event) - .await - .map_err(|_| error::Error::WifiApEventChannelClosed)?; - Ok(()) - } - - pub(crate) async fn run(mut self) -> Result { let mut command = "ATTACH".to_string(); - for o in &self.attach_options { + for o in attach_options { command.push(' '); command.push_str(o); } - let mut attach = self.socket_handle.command(command.as_bytes()).await; + let mut attach = socket_handle.command(command.as_bytes()).await?; while attach.is_err() { tokio::time::sleep(tokio::time::Duration::from_millis(250)).await; - attach = self.socket_handle.command(command.as_bytes()).await; + attach = socket_handle.command(command.as_bytes()).await?; } - let mut log_level = self.socket_handle.command(b"LOG_LEVEL DEBUG").await; + let mut log_level = socket_handle.command(b"LOG_LEVEL DEBUG").await?; while log_level.is_err() { tokio::time::sleep(tokio::time::Duration::from_millis(250)).await; - log_level = self.socket_handle.command(b"LOG_LEVEL DEBUG").await; + log_level = socket_handle.command(b"LOG_LEVEL DEBUG").await?; } info!("hostapd event stream registered"); + Ok((deferred_requests, Self { socket_handle })) + } - loop { - match self - .socket_handle - .socket - .recv(&mut self.socket_handle.buffer) - .await - { - Ok(n) => { - let data_str = std::str::from_utf8(&self.socket_handle.buffer[..n])?.trim_end(); - if let Some(n) = data_str.find("AP-STA-DISCONNECTED") { - let index = n + "AP-STA-DISCONNECTED".len(); - let mac = &data_str[index..]; - self.send_event(Event::ApStaDisconnected(mac.to_string())) - .await?; - } else if let Some(n) = data_str.find("AP-STA-CONNECTED") { - let index = n + "AP-STA-CONNECTED".len(); - let mac = &data_str[index..]; - self.send_event(Event::ApStaConnected(mac.to_string())) - .await?; - } else { - self.send_event(Event::Unknown(data_str.to_string())) - .await?; - } - } - Err(e) => { - return Err(error::Error::UnsolicitedIoError(e)); - } - } - } + pub(crate) async fn recv(&mut self) -> SocketResult { + let bytes = self.socket_handle.recv().await?; + let data_str = String::from_utf8_lossy(bytes); + Ok(if let Some(n) = data_str.find("AP-STA-DISCONNECTED") { + let index = n + "AP-STA-DISCONNECTED".len(); + let mac = data_str[index..].trim(); + Event::ApStaDisconnected(mac.to_string()) + } else if let Some(n) = data_str.find("AP-STA-CONNECTED") { + let index = n + "AP-STA-CONNECTED".len(); + let mac = &data_str[index..].trim(); + Event::ApStaConnected(mac.to_string()) + } else { + Event::Unknown(data_str.to_string()) + }) } } diff --git a/src/ap/mod.rs b/src/ap/mod.rs index 95e3765..d8d452a 100644 --- a/src/ap/mod.rs +++ b/src/ap/mod.rs @@ -30,9 +30,9 @@ pub struct WifiAp { } impl WifiAp { - pub async fn run(mut self) -> Result { + pub async fn run(&mut self) -> SocketResult { info!("Starting Wifi AP process"); - let (event_receiver, mut deferred_requests, event_socket) = EventSocket::new( + let (mut deferred_requests, event_socket) = EventSocket::new( &self.socket_path, &mut self.request_receiver, &self.attach_options, @@ -48,140 +48,96 @@ impl WifiAp { .await?; deferred_requests.extend(next_deferred_requests); for request in deferred_requests { - let _ = self.self_sender.send(request).await; + self.self_sender + .send(request) + .await + .expect("self_sender should never close as same struct owns both ends"); + } + self.broadcast(Broadcast::Ready); + self.run_internal(event_socket, socket_handle).await + } + + fn broadcast(&self, event: Broadcast) { + if self.broadcast_sender.send(event).is_err() { + debug!("broadcast listener closed") } - self.broadcast_sender.send(Broadcast::Ready)?; - tokio::select!( - resp = event_socket.run() => resp, - resp = self.run_internal(event_receiver, socket_handle) => resp, - ) } async fn run_internal( - mut self, - mut event_receiver: EventReceiver, + &mut self, + mut event_socket: EventSocket, mut socket_handle: SocketHandle<2048>, - ) -> Result { + ) -> SocketResult { enum EventOrRequest { - Event(Option), + Event(Event), Request(Option), } loop { let event_or_request = tokio::select!( - event = event_receiver.recv() => EventOrRequest::Event(event), + event = event_socket.recv() => EventOrRequest::Event(event?), request = self.request_receiver.recv() => EventOrRequest::Request(request), ); match event_or_request { - EventOrRequest::Event(event) => match event { - Some(event) => { - Self::handle_event(&mut socket_handle, &self.broadcast_sender, event) - .await? - } - None => return Err(error::Error::WifiApEventChannelClosed), - }, + EventOrRequest::Event(event) => self.handle_event(&mut socket_handle, event).await, EventOrRequest::Request(request) => match request { Some(Request::Shutdown) => return Ok(()), Some(request) => Self::handle_request(&mut socket_handle, request).await?, - None => return Err(error::Error::WifiApRequestChannelClosed), + None => return Err(error::SocketError::ClientChannelClosed), }, } } } async fn handle_event( + &self, _socket_handle: &mut SocketHandle, - broadcast_sender: &broadcast::Sender, event_msg: Event, - ) -> Result { + ) { match event_msg { - Event::ApStaConnected(mac) => { - if let Err(e) = broadcast_sender.send(Broadcast::Connected(mac)) { - warn!("error broadcasting: {e}"); - } - } - Event::ApStaDisconnected(mac) => { - if let Err(e) = broadcast_sender.send(Broadcast::Disconnected(mac)) { - warn!("error broadcasting: {e}"); - } - } - Event::Unknown(msg) => { - if let Err(e) = broadcast_sender.send(Broadcast::UnknownEvent(msg)) { - warn!("error broadcasting: {e}"); - } - } + Event::ApStaConnected(mac) => self.broadcast(Broadcast::Connected(mac)), + Event::ApStaDisconnected(mac) => self.broadcast(Broadcast::Disconnected(mac)), + Event::Unknown(msg) => self.broadcast(Broadcast::UnknownEvent(msg)), }; - Ok(()) } async fn handle_request( socket_handle: &mut SocketHandle, request: Request, - ) -> Result { + ) -> SocketResult { debug!("Handling request: {request:?}"); match request { Request::Custom(custom, response_channel) => { - let _n = socket_handle.socket.send(custom.as_bytes()).await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); - debug!("Custom request response: {data_str}"); - if response_channel.send(Ok(data_str.into())).is_err() { - error!("Custom request response channel closed before response sent"); - } + let data_str = socket_handle.request(&custom, TryInto::try_into).await?; + debug!("Custom request response: {data_str:?}"); + let _ = response_channel.send(data_str); } Request::Status(response_channel) => { - let _n = socket_handle.socket.send(b"STATUS").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); - let status = Status::from_response(data_str)?; + let status = socket_handle + .request("STATUS", Status::from_response) + .await?; - if response_channel.send(Ok(status)).is_err() { - error!("Status request response channel closed before response sent"); - } + let _ = response_channel.send(status); } Request::Config(response_channel) => { - let _n = socket_handle.socket.send(b"GET_CONFIG").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); - let config = Config::from_response(data_str)?; - - if response_channel.send(Ok(config)).is_err() { - error!("Config request response channel closed before response sent"); - } + let config = socket_handle + .request("GET_CONFIG", Config::from_response) + .await?; + let _ = response_channel.send(config); } Request::Enable(response_channel) => { - Self::ok_fail_request(socket_handle, b"ENABLE", response_channel).await? + let _ = response_channel.send(socket_handle.command(b"ENABLE").await?); } Request::Disable(response_channel) => { - Self::ok_fail_request(socket_handle, b"DISABLE", response_channel).await? + let _ = response_channel.send(socket_handle.command(b"DISABLE").await?); } Request::SetValue(key, value, response_channel) => { let request_string = format!("SET {key} {value}"); - Self::ok_fail_request(socket_handle, request_string.as_bytes(), response_channel) - .await? + let _ = + response_channel.send(socket_handle.command(request_string.as_bytes()).await?); } Request::Shutdown => (), //shutdown is handled at the scope above } Ok(()) } - - async fn ok_fail_request( - socket_handle: &mut SocketHandle, - request: &[u8], - response_channel: oneshot::Sender, - ) -> Result { - let _n = socket_handle.socket.send(request).await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); - let response = if data_str == "OK" { - Ok(()) - } else { - Err(error::Error::UnexpectedWifiApRepsonse(data_str.into())) - }; - - if response_channel.send(response).is_err() { - error!("Config request response channel closed before response sent"); - } - Ok(()) - } } diff --git a/src/ap/setup.rs b/src/ap/setup.rs index f3a3790..aa00349 100644 --- a/src/ap/setup.rs +++ b/src/ap/setup.rs @@ -16,14 +16,14 @@ pub struct WifiSetupGeneric { } impl WifiSetupGeneric { - pub fn new() -> Result { + pub fn new() -> Self { // setup the channel for client requests let (self_sender, request_receiver) = mpsc::channel(C); let request_client = RequestClient::new(self_sender.clone()); // setup the channel for broadcasts let (broadcast_sender, broadcast_receiver) = broadcast::channel(B); - Ok(Self { + Self { wifi: WifiAp { socket_path: PATH_DEFAULT_SERVER.into(), attach_options: vec![], @@ -33,7 +33,7 @@ impl WifiSetupGeneric { }, request_client, broadcast_receiver, - }) + } } pub fn set_socket_path>(&mut self, path: S) { @@ -57,3 +57,9 @@ impl WifiSetupGeneric { self.wifi } } + +impl Default for WifiSetupGeneric { + fn default() -> Self { + Self::new() + } +} diff --git a/src/ap/types.rs b/src/ap/types.rs index ceef15f..03fc981 100644 --- a/src/ap/types.rs +++ b/src/ap/types.rs @@ -1,4 +1,4 @@ -use super::{error, Result}; +use super::config::ConfigError; use serde::{Deserialize, Serialize}; /// Status of the WiFi Station @@ -84,11 +84,8 @@ impl Status { /// assert_eq!(status.ssid, vec![r"WiFi-SSID", r#"¯\_(ツ)_/¯"#]); /// assert_eq!(status.num_sta, vec![0, 1]); /// ``` - pub fn from_response(response: &str) -> Result { - crate::config::from_str(response).map_err(|e| error::Error::ParsingWifiStatus { - e, - s: response.into(), - }) + pub fn from_response(response: &str) -> std::result::Result { + crate::config::from_str(response) } } @@ -126,11 +123,8 @@ impl Config { /// assert_eq!(config.wpa, 2); /// assert_eq!(config.ssid, "WiFi-SSID"); /// ``` - pub fn from_response(response: &str) -> Result { - crate::config::from_str(response).map_err(|e| error::Error::ParsingWifiConfig { - e, - s: response.into(), - }) + pub fn from_response(response: &str) -> std::result::Result { + crate::config::from_str(response) } } diff --git a/src/error.rs b/src/error.rs index f9a7470..80f1c56 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,44 +1,92 @@ use super::*; +use std::convert::Infallible; use thiserror::Error; - +use tokio::sync::mpsc::error::SendError; +use tokio::sync::oneshot::error::RecvError; #[derive(Error, Debug)] -pub enum Error { + +/// Error returned by [access point](crate::ap::WifiAp::run) and [station](crate::sta::WifiStation::run) runners if there is +/// problem with the control socket. e.g. if `wpa_suplicant` is restarted +pub enum SocketError { + /// IO error from control socket #[error("io error: {0}")] Io(#[from] std::io::Error), + /// Client asked runner to shutdown #[error("start-up aborted")] StartupAborted, - #[error("error parsing wifi status {e}: \n{s}")] - ParsingWifiStatus { e: config::ConfigError, s: String }, - #[error("error parsing wifi config {e}: \n{s}")] - ParsingWifiConfig { e: config::ConfigError, s: String }, - #[error("unexpected wifi ap response: {0}")] - UnexpectedWifiApRepsonse(String), + /// `RequestClient` dropped without shutting down runner + #[error("internal client channel unexpectedly closed")] + ClientChannelClosed, + /// Timeout trying to open the control socket even after retrying + #[error("timeout opening socket {0}")] + TimeoutOpeningSocket(String), + /// Permission denied opening control socket + #[error("permission denied opening socket {0}")] + PermissionDeniedOpeningSocket(String), +} + +/// Error returned by [access point](crate::ap::RequestClient) and [station](crate::sta::RequestClient) clients if there is +/// problem with the request e.g. asking to select a network you have not created a config for +#[derive(Error, Debug, Clone)] +pub enum ClientError { + /// Request failed e.g. asking to select a network you have not created a config for + #[error("Supplicant reported request failed")] + Failed, + /// Error parsing the reponse from the socket. This is probably a bug in the [`wifi_ctrl`](crate) code. + #[error("error {error} parsing response: \n{failed_response}")] + ParsingResponse { + #[source] + error: ParseError, + failed_response: String, + }, + /// Timeout waiting for response to request on control socket #[error("timeout waiting for response")] Timeout, - #[error("did not write all bytes {0}/{1}")] + /// Request was too big to fit in a datagram, mostlikely seen on bad custom requests + #[error("Request was too big only sent {0} of {1} bytes")] DidNotWriteAllBytes(usize, usize), + /// The control socket is not connected at the moment, reconnect and try again + #[error("Runner task not runnning")] + RunnerNotRunning, + #[error("Select already pending")] + PendingSelect, +} + +/// A sub error of [`ClientError`] returned when there is a problem parsing the reponse from +/// the socket. This is probably a bug in the [`wifi_ctrl`](crate) code. +#[derive(Error, Debug, Clone)] +pub enum ParseError { + #[error("Didn't get expected literal \"OK\" response")] + NotOK, + #[error("Too few columns in scan response")] + ScanResult, + #[error("error parsing config: {0}")] + ParseConfig(#[from] config::ConfigError), #[error("error parsing int: {0}")] ParseInt(#[from] std::num::ParseIntError), #[error("utf8 error: {0}")] Utf8Parse(#[from] std::str::Utf8Error), - #[error("recv error: {0}")] - Recv(#[from] oneshot::error::RecvError), - #[error("unsolicited socket io error: {0}")] - UnsolicitedIoError(std::io::Error), - #[error("wifi_ctrl::station internal request channel unexpectedly closed")] - WifiStationRequestChannelClosed, - #[error("wifi_ctrl::station internal event channel unexpectedly closed")] - WifiStationEventChannelClosed, - #[error("wifi_ctrl::ap internal request channel unexpectedly closed")] - WifiApRequestChannelClosed, - #[error("wifi_ctrl::ap internal event channel unexpectedly closed")] - WifiApEventChannelClosed, - #[error("wifi ap broadcast: {0}")] - WifiApBroadcast(#[from] broadcast::error::SendError), - #[error("wifi::sta broadcast: {0}")] - WifiStaBroadcast(#[from] broadcast::error::SendError), - #[error("timeout opening socket {0}")] - TimeoutOpeningSocket(String), - #[error("permission denied opening socket {0}")] - PermissionDeniedOpeningSocket(String), +} + +// Needed to make TryFrom happy when it can't fail +impl From for ParseError { + fn from(_: Infallible) -> Self { + unreachable!() + } +} + +// Happens when the runner half of a request channel gets dropped +// e.g. if it is asked to shut down, or the socket dies +impl From> for ClientError { + fn from(_: SendError) -> Self { + ClientError::RunnerNotRunning + } +} + +// Happens when the runner half of a repsonse channel gets dropped +// e.g. if it is asked to shut down, or the socket dies +impl From for ClientError { + fn from(_: RecvError) -> Self { + ClientError::RunnerNotRunning + } } diff --git a/src/lib.rs b/src/lib.rs index c21e272..0648c4e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,11 +27,12 @@ pub(crate) mod config; pub(crate) mod socket_handle; use socket_handle::SocketHandle; -pub type Result = std::result::Result; +pub type Result = std::result::Result; +pub type SocketResult = std::result::Result; +pub type ParseResult = std::result::Result; use log::{debug, error, info, warn}; pub(crate) trait ShutdownSignal { fn is_shutdown(&self) -> bool; - fn inform_of_shutdown(self); } diff --git a/src/socket_handle.rs b/src/socket_handle.rs index 69c7835..a8ef41b 100644 --- a/src/socket_handle.rs +++ b/src/socket_handle.rs @@ -1,4 +1,5 @@ use super::*; +use error::{ClientError, ParseError}; use std::io::ErrorKind; use tokio::net::UnixDatagram; @@ -18,7 +19,7 @@ impl SocketHandle { path: P, label: &str, request_channel: &mut mpsc::Receiver, - ) -> Result<(Self, Vec)> + ) -> SocketResult<(Self, Vec)> where P: AsRef + std::fmt::Debug, S: ShutdownSignal, @@ -33,13 +34,13 @@ impl SocketHandle { let socket = tokio::select!( resp = async move { let mut loop_count = 0; - let s: Result = loop { + let s: SocketResult = loop { match socket.connect(&path) { Ok(()) => break Ok(socket), Err(e) => { // if socket is there but permission denied, fail fast if e.kind() == ErrorKind::PermissionDenied { - break Err(error::Error::PermissionDeniedOpeningSocket(socket_debug.to_string())); + break Err(error::SocketError::PermissionDeniedOpeningSocket(socket_debug.to_string())); } if loop_count % 60 == 0 { info!("Failed to connect to {socket_debug}, retrying for {} more minutes", RETRY_MINUTES-(loop_count+1)/60); @@ -53,7 +54,7 @@ impl SocketHandle { } => resp, _ = async move { tokio::time::sleep(tokio::time::Duration::from_secs(60*RETRY_MINUTES)).await; - } => Err(error::Error::TimeoutOpeningSocket(socket_debug.to_string())), + } => Err(error::SocketError::TimeoutOpeningSocket(socket_debug.to_string())), _ = async move { loop { if let Some(request) = request_channel.recv().await { @@ -64,16 +65,9 @@ impl SocketHandle { } } } - } => Err(error::Error::StartupAborted), + } => Err(error::SocketError::StartupAborted), ); - if let Err(error::Error::StartupAborted) = socket { - for request in deferred_requests { - request.inform_of_shutdown(); - } - return Err(error::Error::StartupAborted); - } - Ok(( Self { tmp_dir, @@ -84,38 +78,82 @@ impl SocketHandle { )) } - pub async fn command(&mut self, cmd: &[u8]) -> Result { + pub async fn recv(&mut self) -> SocketResult<&[u8]> { + let n = self.socket.recv(&mut self.buffer).await?; + Ok(&self.buffer[..n]) + } + + pub async fn command(&mut self, cmd: &[u8]) -> SocketResult { let n = self.socket.send(cmd).await?; if n != cmd.len() { - return Err(error::Error::DidNotWriteAllBytes(n, cmd.len())); + return Ok(Err(error::ClientError::DidNotWriteAllBytes(n, cmd.len()))); } self.expect_ok_with_default_timeout().await } - async fn expect_ok(&mut self) -> Result { - match self.socket.recv(&mut self.buffer).await { - Ok(n) => { - let data_str = std::str::from_utf8(&self.buffer[..n])?.trim_end(); - if data_str.trim() == "OK" { - Ok(()) + pub(crate) async fn request<'a, T, E, F>( + &'a mut self, + req: &str, + parse: F, + ) -> SocketResult> + where + ParseError: From, + F: FnOnce(&'a str) -> std::result::Result, + { + let n = self.socket.send(req.as_bytes()).await?; + if n != req.len() { + return Ok(Err(error::ClientError::DidNotWriteAllBytes(n, req.len()))); + } + self.parse_resp(parse).await + } + + async fn parse_resp<'a, T, E, F>(&'a mut self, parse: F) -> SocketResult> + where + ParseError: From, + F: FnOnce(&'a str) -> std::result::Result, + { + let bytes = self.recv().await?; + let str = std::str::from_utf8(bytes).map(|r| r.trim_end_matches('\n')); + Ok(str + .map_err(Into::::into) + .and_then(|s| parse(s).map_err(Into::::into)) + .map_err(|error| { + if str == Ok("FAIL") { + ClientError::Failed } else { - Err(error::Error::UnexpectedWifiApRepsonse(data_str.into())) + ClientError::ParsingResponse { + error, + failed_response: String::from_utf8_lossy(bytes).to_string(), + } } + })) + } + + async fn expect_ok(&mut self) -> SocketResult { + self.parse_resp(|data| { + // Scan (and only scan) return FAIL-BUSY when already scanning + if data == "OK" || data == "FAIL-BUSY" { + Ok(()) + } else { + Err(error::ParseError::NotOK) } - Err(e) => Err(error::Error::UnsolicitedIoError(e)), - } + }) + .await } - async fn expect_ok_with_default_timeout(&mut self) -> Result { + async fn expect_ok_with_default_timeout(&mut self) -> SocketResult { self.expect_ok_with_timeout(tokio::time::Duration::from_secs(1)) .await } - pub async fn expect_ok_with_timeout(&mut self, timeout: tokio::time::Duration) -> Result { + pub async fn expect_ok_with_timeout( + &mut self, + timeout: tokio::time::Duration, + ) -> SocketResult { tokio::select!( resp = self.expect_ok() => resp, _ = - tokio::time::sleep(timeout) => Err(error::Error::Timeout) + tokio::time::sleep(timeout) => Ok(Err(error::ClientError::Timeout)) ) } } diff --git a/src/sta/client.rs b/src/sta/client.rs index 75345cc..92baece 100644 --- a/src/sta/client.rs +++ b/src/sta/client.rs @@ -12,9 +12,6 @@ pub enum SelectResult { Success, WrongPsk, NotFound, - PendingSelect, - InvalidNetworkId, - Timeout, AlreadyConnected, } @@ -26,9 +23,6 @@ impl fmt::Display for SelectResult { SelectResult::Success => "success", SelectResult::WrongPsk => "wrong_psk", SelectResult::NotFound => "network_not_found", - SelectResult::PendingSelect => "select_already_pending", - SelectResult::InvalidNetworkId => "invalid_network_id", - SelectResult::Timeout => "select_timeout", SelectResult::AlreadyConnected => "already_connected", }; write!(f, "{s}") @@ -47,6 +41,7 @@ pub(crate) enum Request { Status(oneshot::Sender>), Networks(oneshot::Sender>>), Scan(oneshot::Sender>), + ScanResults, AddNetwork(oneshot::Sender>), SetNetwork(usize, SetNetwork, oneshot::Sender), SaveConfig(oneshot::Sender), @@ -61,42 +56,6 @@ impl ShutdownSignal for Request { fn is_shutdown(&self) -> bool { matches!(self, Request::Shutdown) } - fn inform_of_shutdown(self) { - match self { - Request::Custom(_, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Status(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Networks(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Scan(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::AddNetwork(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::SetNetwork(_, _, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::SaveConfig(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::ReloadConfig(response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::RemoveNetwork(_, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::SelectNetwork(_, response) => { - let _ = response.send(Err(error::Error::StartupAborted)); - } - Request::Shutdown => {} - Request::SelectTimeout => {} - } - } } #[derive(Debug)] @@ -118,123 +77,122 @@ impl RequestClient { RequestClient { sender } } - async fn send_request(&self, request: Request) -> Result { - self.sender - .send(request) - .await - .map_err(|_| error::Error::WifiStationRequestChannelClosed)?; - Ok(()) - } - pub async fn send_custom(&self, custom: String) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Custom(custom, response)).await?; + self.sender.send(Request::Custom(custom, response)).await?; request.await? } pub async fn get_scan(&self) -> Result>> { let (response, request) = oneshot::channel(); - self.send_request(Request::Scan(response)).await?; + self.sender.send(Request::Scan(response)).await?; request.await? } pub async fn get_networks(&self) -> Result> { let (response, request) = oneshot::channel(); - self.send_request(Request::Networks(response)).await?; + self.sender.send(Request::Networks(response)).await?; request.await? } pub async fn get_status(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::Status(response)).await?; + self.sender.send(Request::Status(response)).await?; request.await? } pub async fn add_network(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::AddNetwork(response)).await?; + self.sender.send(Request::AddNetwork(response)).await?; request.await? } pub async fn set_network_psk(&self, network_id: usize, psk: String) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SetNetwork( - network_id, - SetNetwork::Psk(psk), - response, - )) - .await?; + self.sender + .send(Request::SetNetwork( + network_id, + SetNetwork::Psk(psk), + response, + )) + .await?; request.await? } pub async fn set_network_ssid(&self, network_id: usize, ssid: String) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SetNetwork( - network_id, - SetNetwork::Ssid(ssid), - response, - )) - .await?; + self.sender + .send(Request::SetNetwork( + network_id, + SetNetwork::Ssid(ssid), + response, + )) + .await?; request.await? } pub async fn set_network_bssid(&self, network_id: usize, bssid: String) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SetNetwork( - network_id, - SetNetwork::Bssid(bssid), - response, - )) - .await?; + self.sender + .send(Request::SetNetwork( + network_id, + SetNetwork::Bssid(bssid), + response, + )) + .await?; request.await? } pub async fn set_network_keymgmt(&self, network_id: usize, mgmt: KeyMgmt) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SetNetwork( - network_id, - SetNetwork::KeyMgmt(mgmt), - response, - )) - .await?; + self.sender + .send(Request::SetNetwork( + network_id, + SetNetwork::KeyMgmt(mgmt), + response, + )) + .await?; request.await? } pub async fn save_config(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SaveConfig(response)).await?; + self.sender.send(Request::SaveConfig(response)).await?; request.await? } pub async fn reload_config(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::ReloadConfig(response)).await?; + self.sender.send(Request::ReloadConfig(response)).await?; request.await? } pub async fn remove_network(&self, id: usize) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::RemoveNetwork(RemoveNetwork::Id(id), response)) + self.sender + .send(Request::RemoveNetwork(RemoveNetwork::Id(id), response)) .await?; request.await? } pub async fn remove_all_networks(&self) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::RemoveNetwork(RemoveNetwork::All, response)) + self.sender + .send(Request::RemoveNetwork(RemoveNetwork::All, response)) .await?; request.await? } pub async fn select_network(&self, network_id: usize) -> Result { let (response, request) = oneshot::channel(); - self.send_request(Request::SelectNetwork(network_id, response)) + self.sender + .send(Request::SelectNetwork(network_id, response)) .await?; request.await? } pub async fn shutdown(&self) -> Result { - self.send_request(Request::Shutdown).await?; + self.sender.send(Request::Shutdown).await?; Ok(()) } } diff --git a/src/sta/event_socket.rs b/src/sta/event_socket.rs index 3986e7c..98ebe05 100644 --- a/src/sta/event_socket.rs +++ b/src/sta/event_socket.rs @@ -2,13 +2,12 @@ use super::*; pub(crate) struct EventSocket { socket_handle: SocketHandle<1024>, - /// Sends messages to client - sender: mpsc::Sender, } #[derive(Debug)] pub(crate) enum Event { ScanComplete, + ScanFailed, Connected, Disconnected, NetworkNotFound, @@ -16,71 +15,43 @@ pub(crate) enum Event { Unknown(String), } -pub(crate) type EventReceiver = mpsc::Receiver; - impl EventSocket { pub(crate) async fn new

( socket: P, request_receiver: &mut mpsc::Receiver, - ) -> Result<(EventReceiver, Vec, Self)> + ) -> SocketResult<(Vec, Self)> where P: AsRef + std::fmt::Debug, { let (socket_handle, deferred_requests) = SocketHandle::open(socket, "wpa_ctrl_async.sock", request_receiver).await?; - // setup the channel for client requests - let (sender, receiver) = mpsc::channel(32); - Ok(( - receiver, - deferred_requests, - Self { - socket_handle, - sender, - }, - )) - } - - async fn send_event(&self, event: Event) -> Result { - self.sender - .send(event) - .await - .map_err(|_| error::Error::WifiStationEventChannelClosed)?; - Ok(()) + info!("wpa_ctrl attempting attach"); + socket_handle.socket.send(b"ATTACH").await?; + Ok((deferred_requests, Self { socket_handle })) } - pub(crate) async fn run(mut self) -> Result { - info!("wpa_ctrl attempting attach"); - self.socket_handle.socket.send(b"ATTACH").await?; - loop { - match self - .socket_handle - .socket - .recv(&mut self.socket_handle.buffer) - .await + pub(crate) async fn recv(&mut self) -> SocketResult { + let bytes = self.socket_handle.recv().await?; + let data_str = String::from_utf8_lossy(bytes); + debug!("wpa_ctrl event: {data_str}"); + Ok( + if data_str.trim_end().ends_with("CTRL-EVENT-SCAN-RESULTS") { + Event::ScanComplete + } else if data_str.contains("CTRL-EVENT-SCAN-FAILED") { + Event::ScanFailed + } else if data_str.contains("CTRL-EVENT-CONNECTED") { + Event::Connected + } else if data_str.contains("CTRL-EVENT-DISCONNECTED") { + Event::Disconnected + } else if data_str.contains("CTRL-EVENT-NETWORK-NOT-FOUND") { + Event::NetworkNotFound + } else if data_str.contains("CTRL-EVENT-SSID-TEMP-DISABLED") + && data_str.contains("reason=WRONG_KEY") { - Ok(n) => { - let data_str = std::str::from_utf8(&self.socket_handle.buffer[..n])?.trim_end(); - debug!("wpa_ctrl event: {data_str}"); - if data_str.ends_with("CTRL-EVENT-SCAN-RESULTS") { - self.send_event(Event::ScanComplete).await?; - } else if data_str.contains("CTRL-EVENT-CONNECTED") { - self.send_event(Event::Connected).await?; - } else if data_str.contains("CTRL-EVENT-DISCONNECTED") { - self.send_event(Event::Disconnected).await?; - } else if data_str.contains("CTRL-EVENT-NETWORK-NOT-FOUND") { - self.send_event(Event::NetworkNotFound).await?; - } else if data_str.contains("CTRL-EVENT-SSID-TEMP-DISABLED") - && data_str.contains("reason=WRONG_KEY") - { - self.send_event(Event::WrongPsk).await?; - } else { - self.send_event(Event::Unknown(data_str.into())).await?; - } - } - Err(e) => { - return Err(error::Error::UnsolicitedIoError(e)); - } - } - } + Event::WrongPsk + } else { + Event::Unknown(data_str.into()) + }, + ) } } diff --git a/src/sta/mod.rs b/src/sta/mod.rs index 85f0fee..0f9eeb4 100644 --- a/src/sta/mod.rs +++ b/src/sta/mod.rs @@ -1,3 +1,7 @@ +use core::str; + +use crate::error::ClientError; + use super::*; use tokio::time::Duration; @@ -32,7 +36,7 @@ pub struct WifiStation { } impl WifiStation { - pub async fn run(mut self) -> Result { + pub async fn run(&mut self) -> SocketResult { info!("Starting Wifi Station process"); let (socket_handle, mut deferred_requests) = SocketHandle::open( &self.socket_path, @@ -42,36 +46,39 @@ impl WifiStation { .await?; // We start up a separate socket for receiving the "unexpected" events that // gets forwarded to us via the unsolicited_receiver - let (unsolicited_receiver, next_deferred_requests, unsolicited) = + let (next_deferred_requests, unsolicited) = EventSocket::new(&self.socket_path, &mut self.request_receiver).await?; deferred_requests.extend(next_deferred_requests); for request in deferred_requests { let _ = self.self_sender.send(request).await; } - self.broadcast_sender.send(Broadcast::Ready)?; - tokio::select!( - resp = unsolicited.run() => resp, - resp = self.run_internal(unsolicited_receiver, socket_handle) => resp, - ) + self.broadcast(Broadcast::Ready); + self.run_internal(unsolicited, socket_handle).await + } + + fn broadcast(&self, event: Broadcast) { + if self.broadcast_sender.send(event).is_err() { + debug!("broadcast listener closed") + } } async fn run_internal( - mut self, - mut unsolicited_receiver: EventReceiver, + &mut self, + mut unsolicited: EventSocket, mut socket_handle: SocketHandle<10240>, - ) -> Result { + ) -> SocketResult { // We will collect scan requests and batch respond to them when results are ready let mut scan_requests = Vec::new(); let mut select_request = None; loop { enum EventOrRequest { - Event(Option), + Event(Event), Request(Option), } let event_or_request = tokio::select!( - unsolicited_msg = unsolicited_receiver.recv() => { - EventOrRequest::Event(unsolicited_msg) + unsolicited_msg = unsolicited.recv() => { + EventOrRequest::Event(unsolicited_msg?) }, request = self.request_receiver.recv() => { EventOrRequest::Request(request) @@ -79,20 +86,11 @@ impl WifiStation { ); match event_or_request { - EventOrRequest::Event(event) => match event { - Some(unsolicited_msg) => { - debug!("Unsolicited event: {unsolicited_msg:?}"); - Self::handle_event( - &mut socket_handle, - unsolicited_msg, - &mut scan_requests, - &mut select_request, - &mut self.broadcast_sender, - ) - .await? - } - None => return Err(error::Error::WifiStationEventChannelClosed), - }, + EventOrRequest::Event(unsolicited_msg) => { + debug!("Unsolicited event: {unsolicited_msg:?}"); + self.handle_event(unsolicited_msg, &mut scan_requests, &mut select_request) + .await + } EventOrRequest::Request(request) => match request { Some(Request::Shutdown) => return Ok(()), Some(request) => { @@ -104,67 +102,58 @@ impl WifiStation { ) .await?; } - None => return Err(error::Error::WifiStationRequestChannelClosed), + None => return Err(error::SocketError::ClientChannelClosed), }, } } } - async fn handle_event( - socket_handle: &mut SocketHandle, + async fn handle_event( + &mut self, event: Event, scan_requests: &mut Vec>>>>, select_request: &mut Option, - broadcast_sender: &mut broadcast::Sender, - ) -> Result { + ) { match event { Event::ScanComplete => { - let _n = socket_handle.socket.send(b"SCAN_RESULTS").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?; - let mut scan_results = ScanResult::vec_from_str(data_str); - scan_results.sort_by(|a, b| a.signal.cmp(&b.signal)); - - let results = Arc::new(scan_results); + let _ = self.self_sender.send(Request::ScanResults).await; + } + Event::ScanFailed => { while let Some(scan_request) = scan_requests.pop() { - if scan_request.send(Ok(results.clone())).is_err() { - error!("Scan request response channel closed before response sent"); - } + let _ = scan_request.send(Err(error::ClientError::Failed)); } } Event::Connected => { - broadcast_sender.send(Broadcast::Connected)?; + self.broadcast(Broadcast::Connected); if let Some(sender) = select_request.take() { sender.send(Ok(SelectResult::Success)); } } Event::Disconnected => { - broadcast_sender.send(Broadcast::Disconnected)?; + self.broadcast(Broadcast::Disconnected); } Event::NetworkNotFound => { - broadcast_sender.send(Broadcast::NetworkNotFound)?; + self.broadcast(Broadcast::NetworkNotFound); if let Some(sender) = select_request.take() { sender.send(Ok(SelectResult::NotFound)); } } Event::WrongPsk => { - broadcast_sender.send(Broadcast::WrongPsk)?; + self.broadcast(Broadcast::WrongPsk); if let Some(sender) = select_request.take() { sender.send(Ok(SelectResult::WrongPsk)); } } Event::Unknown(msg) => { - broadcast_sender.send(Broadcast::Unknown(msg))?; + self.broadcast(Broadcast::Unknown(msg)); } } - Ok(()) } - async fn get_status(socket_handle: &mut SocketHandle) -> Result { - let _n = socket_handle.socket.send(b"STATUS").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); - parse_status(data_str) + async fn get_status( + socket_handle: &mut SocketHandle, + ) -> SocketResult> { + socket_handle.request("STATUS", parse_status).await } async fn handle_request( @@ -173,86 +162,73 @@ impl WifiStation { request: Request, scan_requests: &mut Vec>>>>, select_request: &mut Option, - ) -> Result { + ) -> SocketResult { debug!("Handling request: {request:?}"); match request { Request::Custom(custom, response_channel) => { - let _n = socket_handle.socket.send(custom.as_bytes()).await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); - debug!("Custom request response: {data_str}"); - if response_channel.send(Ok(data_str.into())).is_err() { - error!("Custom request response channel closed before response sent"); - } + let data_str = socket_handle.request(&custom, TryInto::try_into).await?; + debug!("Custom request response: {data_str:?}"); + let _ = response_channel.send(data_str); } Request::SelectTimeout => { if let Some(sender) = select_request.take() { - sender.send(Ok(SelectResult::Timeout)); + sender.send(Err(ClientError::Timeout)); } } Request::Scan(response_channel) => { - scan_requests.push(response_channel); - if let Err(e) = socket_handle.command(b"SCAN").await { - debug!("Error while requesting SCAN: {e}"); + match socket_handle.command(b"SCAN").await? { + Ok(_) => { + scan_requests.push(response_channel); + } + Err(e) => { + let _ = response_channel.send(Err(e)); + } + }; + } + Request::ScanResults => { + let scan_results = socket_handle + .request("SCAN_RESULTS", ScanResult::vec_from_str) + .await?; + while let Some(scan_request) = scan_requests.pop() { + let _ = scan_request.send(scan_results.clone()); } } Request::Networks(response_channel) => { - let _n = socket_handle.socket.send(b"LIST_NETWORKS").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); - let network_list = - NetworkResult::vec_from_str(data_str, &mut socket_handle.socket).await?; - if response_channel.send(Ok(network_list)).is_err() { - error!("Scan request response channel closed before response sent"); - } + let network_list = NetworkResult::request_results(socket_handle).await?; + let _ = response_channel.send(network_list); } Request::Status(response_channel) => { - let status = Self::get_status(socket_handle).await; - if response_channel.send(status).is_err() { - error!("Scan request response channel closed before response sent"); - } + let status = Self::get_status(socket_handle).await?; + let _ = response_channel.send(status); } Request::AddNetwork(response_channel) => { - let _n = socket_handle.socket.send(b"ADD_NETWORK").await?; - let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; - let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end(); - let network_id = usize::from_str(data_str)?; - if response_channel.send(Ok(network_id)).is_err() { - error!("Scan request response channel closed before response sent"); - } else { - debug!("wpa_ctrl created network {network_id}"); - } + let network_id = socket_handle + .request("ADD_NETWORK", usize::from_str) + .await?; + debug!("wpa_ctrl created network {network_id:?}"); + let _ = response_channel.send(network_id); } Request::SetNetwork(id, param, response) => { let cmd = format!( "SET_NETWORK {id} {}", match param { SetNetwork::Ssid(ssid) => format!("ssid {}", conf_escape(&ssid)), - SetNetwork::Bssid(bssid) => format!("bssid \"{bssid}\""), + SetNetwork::Bssid(bssid) => format!("bssid {}", conf_escape(&bssid)), SetNetwork::Psk(psk) => format!("psk {}", conf_escape(&psk)), SetNetwork::KeyMgmt(mgmt) => format!("key_mgmt {}", mgmt), } ); - debug!("wpa_ctrl \"{cmd}\""); + debug!("wpa_ctrl {cmd:?}"); let bytes = cmd.into_bytes(); - if let Err(e) = socket_handle.command(&bytes).await { - warn!("Error while setting network parameter: {e}"); - } - let _ = response.send(Ok(())); + let _ = response.send(socket_handle.command(&bytes).await?); } Request::SaveConfig(response) => { - if let Err(e) = socket_handle.command(b"SAVE_CONFIG").await { - warn!("Error while saving config: {e}"); - } debug!("wpa_ctrl config saved"); - let _ = response.send(Ok(())); + let _ = response.send(socket_handle.command(b"SAVE_CONFIG").await?); } Request::ReloadConfig(response) => { - if let Err(e) = socket_handle.command(b"RECONFIGURE").await { - warn!("Error while reloading config: {e}"); - } debug!("wpa_ctrl config reloaded"); - let _ = response.send(Ok(())); + let _ = response.send(socket_handle.command(b"RECONFIGURE").await?); } Request::RemoveNetwork(remove_network, response) => { let str = match remove_network { @@ -261,51 +237,37 @@ impl WifiStation { }; let cmd = format!("REMOVE_NETWORK {str}"); let bytes = cmd.into_bytes(); - if let Err(e) = socket_handle.command(&bytes).await { - warn!("Error while removing network {str}: {e}"); - } debug!("wpa_ctrl removed network {str}"); - let _ = response.send(Ok(())); + let _ = response.send(socket_handle.command(&bytes).await?); } Request::SelectNetwork(id, response_sender) => { - let response_sender = match select_request { + match select_request { None => { let cmd = format!("SELECT_NETWORK {id}"); let bytes = cmd.into_bytes(); - if let Err(e) = socket_handle.command(&bytes).await { + if let Err(e) = socket_handle.command(&bytes).await? { warn!("Error while selecting network {id}: {e}"); - let _ = response_sender.send(Ok(SelectResult::InvalidNetworkId)); - None + let _ = response_sender.send(Err(e)); } else { debug!("wpa_ctrl selected network {id}"); - let status = Self::get_status(socket_handle).await?; - if let Some(current_id) = status.get("id") { - if current_id == &id.to_string() { - let _ = - response_sender.send(Ok(SelectResult::AlreadyConnected)); - None - } else { - Some(response_sender) - } + let status = Self::get_status(socket_handle).await?.unwrap_or_default(); + if status.get("id") == Some(&id.to_string()) { + let _ = response_sender.send(Ok(SelectResult::AlreadyConnected)); } else { - Some(response_sender) + *select_request = Some(SelectRequest::new( + self.self_sender.clone(), + response_sender, + self.select_timeout, + )); } } } Some(_) => { warn!("Select request already pending! Dropping this one."); - let _ = response_sender.send(Ok(SelectResult::PendingSelect)); + let _ = response_sender.send(Err(ClientError::PendingSelect)); debug!("wpa_ctrl removed network {id}"); - None } }; - if let Some(response_sender) = response_sender { - *select_request = Some(SelectRequest::new( - self.self_sender.clone(), - response_sender, - self.select_timeout, - )); - } } Request::Shutdown => (), //shutdown is handled at the scope above } diff --git a/src/sta/setup.rs b/src/sta/setup.rs index 2571008..4deccf4 100644 --- a/src/sta/setup.rs +++ b/src/sta/setup.rs @@ -16,14 +16,14 @@ pub struct WifiSetupGeneric { } impl WifiSetupGeneric { - pub fn new() -> Result { + pub fn new() -> Self { // setup the channel for client requests let (self_sender, request_receiver) = mpsc::channel(C); let request_client = RequestClient::new(self_sender.clone()); // setup the channel for broadcasts let (broadcast_sender, broadcast_receiver) = broadcast::channel(B); - Ok(Self { + Self { wifi: WifiStation { socket_path: PATH_DEFAULT_SERVER.into(), request_receiver, @@ -33,7 +33,7 @@ impl WifiSetupGeneric { }, request_client, broadcast_receiver, - }) + } } pub fn set_socket_path>(&mut self, path: S) { @@ -55,3 +55,9 @@ impl WifiSetupGeneric { self.wifi } } + +impl Default for WifiSetupGeneric { + fn default() -> Self { + Self::new() + } +} diff --git a/src/sta/types.rs b/src/sta/types.rs index 333cf6a..ca98ac5 100644 --- a/src/sta/types.rs +++ b/src/sta/types.rs @@ -1,10 +1,12 @@ -use super::{config, config::unprintf, error, warn, Result}; +use super::error::ParseError; +use super::{config, config::unprintf, warn, Result, SocketHandle}; +use super::{ParseResult, SocketResult}; use serde::Serialize; use std::collections::HashMap; use std::fmt::Display; use std::str::FromStr; -use tokio::net::UnixDatagram; +use std::sync::Arc; #[derive(Serialize, Debug, Clone)] /// The result from scanning for networks. @@ -41,22 +43,19 @@ impl ScanResult { ///let results = ScanResult::vec_from_str(r#"bssid / frequency / signal level / flags / ssid ///00:5f:67:90:da:64 2417 -35 [WPA-PSK-CCMP][WPA2-PSK-CCMP][ESS] TP-Link DA64 ///e0:91:f5:7d:11:c0 2462 -33 [WPA2-PSK-CCMP][WPS][ESS] ¯\\_(\xe3\x83\x84)_/¯ - ///"#); + ///"#).unwrap(); ///assert_eq!(results[0].mac, "00:5f:67:90:da:64"); ///assert_eq!(results[0].name, "TP-Link DA64"); ///assert_eq!(results[1].signal, -33); ///assert_eq!(results[1].name, r#"¯\_(ツ)_/¯"#); ///``` - pub fn vec_from_str(response: &str) -> Vec { + pub fn vec_from_str(response: &str) -> ParseResult>> { let mut results = Vec::new(); for line in response.lines().skip(1) { - if let Some(scan_result) = ScanResult::from_line(line) { - results.push(scan_result); - } else { - warn!("Invalid result from scan: {line}"); - } + results.push(ScanResult::from_line(line).ok_or(ParseError::ScanResult)?); } - results + results.sort_by(|a, b| a.signal.cmp(&b.signal)); + Ok(Arc::new(results)) } } @@ -68,27 +67,35 @@ pub struct NetworkResult { pub flags: String, } +fn parse_get_network(resp: &str) -> ParseResult { + let escaped = resp.trim_matches('\"'); + Ok(unprintf(escaped)?) +} + impl NetworkResult { - pub async fn vec_from_str( - response: &str, - socket: &mut UnixDatagram, - ) -> Result> { - let mut buffer = [0; 256]; + pub async fn request_results( + socket_handle: &mut SocketHandle, + ) -> SocketResult>> { + let response: String = match socket_handle + .request("LIST_NETWORKS", TryInto::try_into) + .await? + { + Ok(x) => x, + Err(e) => return Ok(Err(e)), + }; let mut results = Vec::new(); let split = response.split('\n').skip(1); for line in split { let mut line_split = line.split_whitespace(); if let Some(network_id) = line_split.next() { - let cmd = format!("GET_NETWORK {network_id} ssid"); - let bytes = cmd.into_bytes(); - socket.send(&bytes).await?; - let n = socket.recv(&mut buffer).await?; - let ssid = std::str::from_utf8(&buffer[..n])?.trim_matches('\"'); - let ssid = unprintf(ssid).map_err(|e| error::Error::ParsingWifiStatus { - e, - s: ssid.to_string(), - })?; if let Ok(network_id) = usize::from_str(network_id) { + let ssid = match socket_handle + .request(&format!("GET_NETWORK {network_id} ssid"), parse_get_network) + .await? + { + Ok(x) => x, + Err(e) => return Ok(Err(e)), + }; if let Some(flags) = line_split.last() { results.push(NetworkResult { flags: flags.into(), @@ -101,18 +108,15 @@ impl NetworkResult { } } } - Ok(results) + Ok(Ok(results)) } } /// A HashMap of what is returned when running `wpa_cli status`. pub type Status = HashMap; -pub(crate) fn parse_status(response: &str) -> Result { - config::from_str(response).map_err(|e| error::Error::ParsingWifiStatus { - e, - s: response.into(), - }) +pub(crate) fn parse_status(response: &str) -> ParseResult { + Ok(config::from_str(response)?) } #[derive(Debug)]