diff --git a/Cargo.toml b/Cargo.toml index 937e714..22feabe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ readme = "README.md" keywords = ["hostapd", "wpa-supplicant", "wpa_supplicant", "wpa-cli", "wifi"] [dependencies] -config = {version="0", default-features = false, features = ["ini"]} +hex = "0.4" log = { version = "0" } serde = {version = "1", features = ["derive"] } thiserror = "1" diff --git a/src/ap/types.rs b/src/ap/types.rs index af81a01..ceef15f 100644 --- a/src/ap/types.rs +++ b/src/ap/types.rs @@ -1,53 +1,94 @@ use super::{error, Result}; -use serde::{de, Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; /// Status of the WiFi Station #[derive(Serialize, Deserialize, Debug)] pub struct Status { pub state: String, pub phy: String, - pub freq: String, - pub num_sta_non_erp: String, - pub num_sta_no_short_slot_time: String, - pub num_sta_no_short_preamble: String, - pub olbc: String, - pub num_sta_ht_no_gf: String, - pub num_sta_no_ht: String, - pub num_sta_ht_20_mhz: String, - pub num_sta_ht40_intolerant: String, - pub olbc_ht: String, + pub freq: u32, + pub num_sta_non_erp: u64, + pub num_sta_no_short_slot_time: u64, + pub num_sta_no_short_preamble: u64, + pub olbc: u64, + pub num_sta_ht_no_gf: u64, + pub num_sta_no_ht: u64, + pub num_sta_ht_20_mhz: u64, + pub num_sta_ht40_intolerant: u64, + pub olbc_ht: u64, pub ht_op_mode: String, - pub cac_time_seconds: String, - pub cac_time_left_seconds: String, - pub channel: String, - pub secondary_channel: String, - pub ieee80211n: String, - pub ieee80211ac: String, - pub ieee80211ax: String, - pub beacon_int: String, - pub dtim_period: String, - pub ht_caps_info: String, - pub ht_mcs_bitmask: String, + pub cac_time_seconds: u64, + pub cac_time_left_seconds: Option, + pub channel: u64, + pub secondary_channel: u64, + pub ieee80211n: u64, + pub ieee80211ac: u64, + pub ieee80211ax: u64, + pub beacon_int: u64, + pub dtim_period: u64, + // missing if not not ieee80211n + pub ht_caps_info: Option, + pub ht_mcs_bitmask: Option, + #[serde(default)] // missing if there are no rates pub supported_rates: String, - pub max_txpower: String, + pub max_txpower: u64, pub bss: Vec, pub bssid: Vec, pub ssid: Vec, - pub num_sta: Vec, + pub num_sta: Vec, } impl Status { + /// Decode from the response sent from the hostapd + /// ``` + /// # use wifi_ctrl::ap::Status; + /// let resp = r#" + ///state=ENABLED + ///phy=phy0 + ///freq=2437 + ///num_sta_non_erp=0 + ///num_sta_no_short_slot_time=0 + ///num_sta_no_short_preamble=0 + ///olbc=0 + ///num_sta_ht_no_gf=0 + ///num_sta_no_ht=0 + ///num_sta_ht_20_mhz=0 + ///num_sta_ht40_intolerant=0 + ///olbc_ht=0 + ///ht_op_mode=0x0 + ///cac_time_seconds=0 + ///cac_time_left_seconds=N/A + ///channel=6 + ///edmg_enable=0 + ///edmg_channel=0 + ///secondary_channel=0 + ///ieee80211n=0 + ///ieee80211ac=0 + ///ieee80211ax=0 + ///beacon_int=100 + ///dtim_period=2 + ///supported_rates=02 04 0b 16 0c 12 18 24 30 48 60 6c + ///max_txpower=20 + ///bss[0]=wlan0 + ///bssid[0]=cc:7b:5c:1a:d2:21 + ///ssid[0]=WiFi-SSID + ///num_sta[0]=0 + ///bss[1]=wlan1 + ///bssid[1]=cc:7b:5c:4d:ff:5c + ///ssid[1]=¯\\_(\xe3\x83\x84)_/¯ + ///num_sta[1]=1 + ///"#; + /// let status = Status::from_response(resp).unwrap(); + /// assert_eq!(status.state, "ENABLED"); + /// assert_eq!(status.freq, 2437); + /// assert_eq!(status.ssid, vec![r"WiFi-SSID", r#"¯\_(ツ)_/¯"#]); + /// assert_eq!(status.num_sta, vec![0, 1]); + /// ``` pub fn from_response(response: &str) -> Result { - use config::{Config, File, FileFormat}; - let config = Config::builder() - .add_source(File::from_str(response, FileFormat::Ini)) - .build() - .map_err(|e| error::Error::ParsingWifiStatus { - e, - s: response.into(), - })?; - - Ok(config.try_deserialize::().unwrap()) + crate::config::from_str(response).map_err(|e| error::Error::ParsingWifiStatus { + e, + s: response.into(), + }) } } @@ -56,39 +97,111 @@ impl Status { pub struct Config { pub bssid: String, pub ssid: String, - #[serde(deserialize_with = "deserialize_enabled_bool")] - pub wps_state: bool, + pub wps_state: String, + #[serde(default)] // missing if zero pub wpa: i32, - pub ket_mgmt: String, - pub group_cipher: String, - pub rsn_pairwise_cipher: String, - pub wpa_pairwise_cipher: String, + // missing if WPA is not enabled + pub key_mgmt: Option, + pub group_cipher: Option, + pub rsn_pairwise_cipher: Option, + pub wpa_pairwise_cipher: Option, } impl Config { + /// Decode from the response sent from the hostapd + /// ``` + /// # use wifi_ctrl::ap::Config; + /// let resp = r#" + ///bssid=cc:7b:5c:1a:d2:21 + ///ssid=WiFi-SSID + ///wps_state=disabled + ///wpa=2 + ///key_mgmt=WPA-PSK + ///group_cipher=CCMP + ///rsn_pairwise_cipher=CCMP + ///wpa_pairwise_cipher=CCMP + ///"#; + /// let config = Config::from_response(resp).unwrap(); + /// assert_eq!(config.wps_state, "disabled"); + /// assert_eq!(config.wpa, 2); + /// assert_eq!(config.ssid, "WiFi-SSID"); + /// ``` pub fn from_response(response: &str) -> Result { - use config::{File, FileFormat}; - let config = config::Config::builder() - .add_source(File::from_str(response, FileFormat::Ini)) - .build() - .map_err(|e| error::Error::ParsingWifiConfig { - e, - s: response.into(), - })?; - - Ok(config.try_deserialize::().unwrap()) + crate::config::from_str(response).map_err(|e| error::Error::ParsingWifiConfig { + e, + s: response.into(), + }) } } -fn deserialize_enabled_bool<'de, D>(deserializer: D) -> std::result::Result -where - D: de::Deserializer<'de>, -{ - let s: &str = de::Deserialize::deserialize(deserializer)?; +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_wpa_psk() { + let resp = r#" +bssid=cc:7b:5c:1a:d2:21 +ssid=\xc2\xaf\\_(\xe3\x83\x84)_/\xc2\xaf +wps_state=disabled +wpa=2 +key_mgmt=WPA-PSK +group_cipher=CCMP +rsn_pairwise_cipher=CCMP + "#; + let config = Config::from_response(resp).unwrap(); + assert_eq!(config.wpa, 2); + assert_eq!(config.wps_state, "disabled"); + assert_eq!(config.ssid, r#"¯\_(ツ)_/¯"#); + } + + #[test] + fn test_config_wsp_1() { + let resp = r#" +bssid=cc:7b:5c:1a:d2:21 +ssid=MY_SSID +wps_state=not configured +passphrase=MY_PASSPHRASE +psk=8dbbe42cb44f21088fbb9cfbf24dc9b39787d6026d436b01b3ac7d34afb4416d +wpa=2 +key_mgmt=WPA-PSK +group_cipher=CCMP +rsn_pairwise_cipher=CCMP + "#; + let config = Config::from_response(resp).unwrap(); + assert_eq!(config.wpa, 2); + assert_eq!(config.wps_state, "not configured"); + assert_eq!(config.ssid, "MY_SSID"); + } + + #[test] + fn test_config_wsp_2() { + let resp = r#" +bssid=cc:7b:5c:1a:d2:21 +ssid=MY_SSID +wps_state=configured +passphrase=MY_PASSPHRASE +psk=8dbbe42cb44f21088fbb9cfbf24dc9b39787d6026d436b01b3ac7d34afb4416d +wpa=2 +key_mgmt=WPA-PSK +group_cipher=CCMP +rsn_pairwise_cipher=CCMP + "#; + let config = Config::from_response(resp).unwrap(); + assert_eq!(config.wpa, 2); + assert_eq!(config.wps_state, "configured"); + assert_eq!(config.ssid, "MY_SSID"); + } - match s { - "enabled" => Ok(true), - "disabled" => Ok(false), - _ => Err(de::Error::unknown_variant(s, &["enabled", "disabled"])), + #[test] + fn test_config_open() { + let resp = r#" +bssid=cc:7b:5c:1a:d2:21 +ssid=Wi-Fi +wps_state=disabled + "#; + let config = Config::from_response(resp).unwrap(); + assert_eq!(config.wpa, 0); + assert_eq!(config.ssid, "Wi-Fi"); } } diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..5f8237a --- /dev/null +++ b/src/config.rs @@ -0,0 +1,211 @@ +use std::collections::HashMap; +use std::fmt::Display; + +use serde::de::value::MapDeserializer; +use serde::de::{self, Error, IntoDeserializer, Visitor}; +use serde::{forward_to_deserialize_any, Deserialize}; + +type Result = std::result::Result; + +#[derive(Debug, thiserror::Error, PartialEq, Eq, Clone)] +pub enum ConfigError { + #[error("Missing '=' delimiter in config line")] + MissingDelimterEqual, + #[error("escape code is not made up of valid hex code")] + InvalidEscape, + #[error("escape code is incomplete")] + IncompleteEscape, + #[error("escaped value is not valid uft8 after unescaping")] + NonUtf8Escape, + #[error("Value could not be decoded")] + SerdeError(String), +} + +impl Error for ConfigError { + fn custom(msg: T) -> Self + where + T: Display, + { + Self::SerdeError(msg.to_string()) + } +} + +#[derive(Default)] +pub struct Deserializer<'de> { + input: Vec<&'de str>, +} + +impl<'de> Deserializer<'de> { + fn only(&self) -> Result<&'de str> { + if self.input.len() == 1 { + Ok(self.input[0]) + } else { + Err(ConfigError::SerdeError("did not expect seq".to_owned())) + } + } +} + +impl<'de> IntoDeserializer<'de, ConfigError> for Deserializer<'de> { + type Deserializer = Self; + + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +pub fn from_str<'a, T>(s: &'a str) -> Result +where + T: Deserialize<'a>, +{ + let mut map: HashMap<&str, Deserializer<'_>> = HashMap::new(); + for line in s.trim().lines() { + let (k, v) = line + .split_once('=') + .ok_or(ConfigError::MissingDelimterEqual)?; + let (k, i) = if let Some((k, i)) = k.split_once('[') { + if let Some((i, "")) = i.rsplit_once(']') { + (k, i.parse().map_err(ConfigError::custom)?) + } else { + return Err(ConfigError::custom("invalid key")); + } + } else { + (k, 0) + }; + let values = &mut map.entry(k.trim()).or_default().input; + if values.len() != i { + return Err(ConfigError::custom("Duplicate key")); + } + values.push(v); + } + T::deserialize(MapDeserializer::new(map.into_iter())) +} + +macro_rules! forward_to_from_str { + ($func:ident $method:ident) => { + #[inline] + fn $func(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.$method(self.only()?.parse().map_err(ConfigError::custom)?) + } + }; +} + +impl<'de> de::Deserializer<'de> for Deserializer<'de> { + type Error = ConfigError; + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_string(visitor) + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match self.only()? { + "true" | "TRUE" | "enabled" | "ENABLED" => visitor.visit_bool(true), + "false" | "FALSE" | "disabled" | "DISABLED" => visitor.visit_bool(false), + s => Err(ConfigError::SerdeError(format!("Invalid bool {}", s))), + } + } + + forward_to_from_str!(deserialize_i8 visit_i8); + forward_to_from_str!(deserialize_i16 visit_i16); + forward_to_from_str!(deserialize_i32 visit_i32); + forward_to_from_str!(deserialize_i64 visit_i64); + + forward_to_from_str!(deserialize_u8 visit_u8); + forward_to_from_str!(deserialize_u16 visit_u16); + forward_to_from_str!(deserialize_u32 visit_u32); + forward_to_from_str!(deserialize_u64 visit_u64); + + forward_to_from_str!(deserialize_f32 visit_f32); + forward_to_from_str!(deserialize_f64 visit_f64); + + forward_to_from_str!(deserialize_char visit_char); + + // these are not really supported (nor used) as deserialize_any will always deserialize to a String + forward_to_deserialize_any! {str unit unit_struct bytes byte_buf map struct newtype_struct enum tuple tuple_struct identifier ignored_any} + + fn deserialize_string(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_string(unprintf(self.only()?)?) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.only()?.is_empty() || self.only()? == "N/A" { + visitor.visit_none() + } else { + visitor.visit_some(self) + } + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq( + self.input + .into_iter() + .map(|s| Deserializer { input: vec![s] }) + .collect::>() + .into_deserializer(), + ) + } +} + +pub(crate) fn unprintf(escaped: &str) -> std::result::Result { + let mut bytes = escaped.as_bytes().iter().copied(); + let mut unescaped = vec![]; + // undo "printf_encode" + loop { + unescaped.push(match bytes.next() { + Some(b'\\') => match bytes.next().ok_or(ConfigError::IncompleteEscape)? { + b'n' => b'\n', + b'r' => b'\r', + b't' => b'\t', + b'e' => b'\x1b', + b'x' => { + let hex = [ + bytes.next().ok_or(ConfigError::IncompleteEscape)?, + bytes.next().ok_or(ConfigError::IncompleteEscape)?, + ]; + u8::from_str_radix( + std::str::from_utf8(&hex).or(Err(ConfigError::InvalidEscape))?, + 16, + ) + .or(Err(ConfigError::InvalidEscape))? + } + c => c, + }, + Some(c) => c, + None => break, + }) + } + String::from_utf8(unescaped).or(Err(ConfigError::NonUtf8Escape)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_deserializer() { + let resp = r#" + state=ENABLED + shrug=¯\\_(\xe3\x83\x84)_/¯ + "#; + let status: HashMap = from_str(resp).unwrap(); + assert_eq!(status.get("state").unwrap(), "ENABLED"); + assert_eq!(status.get("shrug").unwrap(), r#"¯\_(ツ)_/¯"#); + } +} diff --git a/src/lib.rs b/src/lib.rs index 9ae1ced..c21e272 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,7 @@ pub mod error; /// WiFi Station (network client) runtime and types pub mod sta; +pub(crate) mod config; pub(crate) mod socket_handle; use socket_handle::SocketHandle; diff --git a/src/sta/client.rs b/src/sta/client.rs index eb9938f..75345cc 100644 --- a/src/sta/client.rs +++ b/src/sta/client.rs @@ -50,6 +50,7 @@ pub(crate) enum Request { AddNetwork(oneshot::Sender>), SetNetwork(usize, SetNetwork, oneshot::Sender), SaveConfig(oneshot::Sender), + ReloadConfig(oneshot::Sender), RemoveNetwork(RemoveNetwork, oneshot::Sender), SelectNetwork(usize, oneshot::Sender>), Shutdown, @@ -83,6 +84,9 @@ impl ShutdownSignal for Request { Request::SaveConfig(response) => { let _ = response.send(Err(error::Error::StartupAborted)); } + Request::ReloadConfig(response) => { + let _ = response.send(Err(error::Error::StartupAborted)); + } Request::RemoveNetwork(_, response) => { let _ = response.send(Err(error::Error::StartupAborted)); } @@ -202,6 +206,12 @@ impl RequestClient { request.await? } + pub async fn reload_config(&self) -> Result { + let (response, request) = oneshot::channel(); + self.send_request(Request::ReloadConfig(response)).await?; + request.await? + } + pub async fn remove_network(&self, id: usize) -> Result { let (response, request) = oneshot::channel(); self.send_request(Request::RemoveNetwork(RemoveNetwork::Id(id), response)) diff --git a/src/sta/mod.rs b/src/sta/mod.rs index 6cbff1f..85f0fee 100644 --- a/src/sta/mod.rs +++ b/src/sta/mod.rs @@ -122,7 +122,7 @@ impl WifiStation { let _n = socket_handle.socket.send(b"SCAN_RESULTS").await?; let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?; let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?; - let mut scan_results = ScanResult::vec_from_str(data_str)?; + let mut scan_results = ScanResult::vec_from_str(data_str); scan_results.sort_by(|a, b| a.signal.cmp(&b.signal)); let results = Arc::new(scan_results); @@ -227,9 +227,9 @@ impl WifiStation { let cmd = format!( "SET_NETWORK {id} {}", match param { - SetNetwork::Ssid(ssid) => format!("ssid \"{ssid}\""), + SetNetwork::Ssid(ssid) => format!("ssid {}", conf_escape(&ssid)), SetNetwork::Bssid(bssid) => format!("bssid \"{bssid}\""), - SetNetwork::Psk(psk) => format!("psk \"{psk}\""), + SetNetwork::Psk(psk) => format!("psk {}", conf_escape(&psk)), SetNetwork::KeyMgmt(mgmt) => format!("key_mgmt {}", mgmt), } ); @@ -247,6 +247,13 @@ impl WifiStation { debug!("wpa_ctrl config saved"); let _ = response.send(Ok(())); } + Request::ReloadConfig(response) => { + if let Err(e) = socket_handle.command(b"RECONFIGURE").await { + warn!("Error while reloading config: {e}"); + } + debug!("wpa_ctrl config reloaded"); + let _ = response.send(Ok(())); + } Request::RemoveNetwork(remove_network, response) => { let str = match remove_network { RemoveNetwork::All => "all".to_string(), @@ -306,6 +313,16 @@ impl WifiStation { } } +/// convert to wpa config format idealy a "quoted string" +/// in case of new-lines, quotes or emoji fall back to hex encoding the whole thing +fn conf_escape(raw: &str) -> String { + if raw.bytes().all(|b| b.is_ascii_graphic() && b != b'"') { + format!("\"{raw}\"") + } else { + hex::encode(raw) + } +} + struct SelectRequest { response: oneshot::Sender>, timeout: tokio::task::JoinHandle<()>, diff --git a/src/sta/types.rs b/src/sta/types.rs index c55cdfe..333cf6a 100644 --- a/src/sta/types.rs +++ b/src/sta/types.rs @@ -1,4 +1,5 @@ -use super::{error, warn, Result}; +use super::{config, config::unprintf, error, warn, Result}; + use serde::Serialize; use std::collections::HashMap; use std::fmt::Display; @@ -16,46 +17,46 @@ pub struct ScanResult { } impl ScanResult { - pub fn vec_from_str(response: &str) -> Result> { + fn from_line(line: &str) -> Option { + let (mac, rest) = line.split_once('\t')?; + let (frequency, rest) = rest.split_once('\t')?; + let (signal, rest) = rest.split_once('\t')?; + let signal = isize::from_str(signal).ok()?; + let (flags, escaped_name) = rest.split_once('\t')?; + let name = unprintf(escaped_name).ok()?; + Some(ScanResult { + mac: mac.to_string(), + frequency: frequency.to_string(), + signal, + flags: flags.to_string(), + name, + }) + } + + // Overide to allow tabs in the raw string to avoid double escaping everything + #[allow(clippy::tabs_in_doc_comments)] + /// Parses lines from a scan result + ///``` + ///use wifi_ctrl::sta::ScanResult; + ///let results = ScanResult::vec_from_str(r#"bssid / frequency / signal level / flags / ssid + ///00:5f:67:90:da:64 2417 -35 [WPA-PSK-CCMP][WPA2-PSK-CCMP][ESS] TP-Link DA64 + ///e0:91:f5:7d:11:c0 2462 -33 [WPA2-PSK-CCMP][WPS][ESS] ¯\\_(\xe3\x83\x84)_/¯ + ///"#); + ///assert_eq!(results[0].mac, "00:5f:67:90:da:64"); + ///assert_eq!(results[0].name, "TP-Link DA64"); + ///assert_eq!(results[1].signal, -33); + ///assert_eq!(results[1].name, r#"¯\_(ツ)_/¯"#); + ///``` + pub fn vec_from_str(response: &str) -> Vec { let mut results = Vec::new(); - let split = response.split('\n').skip(1); - for line in split { - let mut line_split = line.split_whitespace(); - if let (Some(mac), Some(frequency), Some(signal), Some(flags)) = ( - line_split.next(), - line_split.next(), - line_split.next(), - line_split.next(), - ) { - let mut name: Option = None; - for text in line_split { - match &mut name { - Some(started) => { - started.push(' '); - started.push_str(text); - } - None => { - name = Some(text.to_string()); - } - } - } - if let Some(name) = name { - if let Ok(signal) = isize::from_str(signal) { - let scan_result = ScanResult { - mac: mac.to_string(), - frequency: frequency.to_string(), - signal, - flags: flags.to_string(), - name, - }; - results.push(scan_result); - } else { - warn!("Invalid string for signal: {signal}"); - } - } + for line in response.lines().skip(1) { + if let Some(scan_result) = ScanResult::from_line(line) { + results.push(scan_result); + } else { + warn!("Invalid result from scan: {line}"); } } - Ok(results) + results } } @@ -83,11 +84,15 @@ impl NetworkResult { socket.send(&bytes).await?; let n = socket.recv(&mut buffer).await?; let ssid = std::str::from_utf8(&buffer[..n])?.trim_matches('\"'); + let ssid = unprintf(ssid).map_err(|e| error::Error::ParsingWifiStatus { + e, + s: ssid.to_string(), + })?; if let Ok(network_id) = usize::from_str(network_id) { if let Some(flags) = line_split.last() { results.push(NetworkResult { flags: flags.into(), - ssid: ssid.into(), + ssid, network_id, }) } @@ -104,15 +109,10 @@ impl NetworkResult { pub type Status = HashMap; pub(crate) fn parse_status(response: &str) -> Result { - use config::{Config, File, FileFormat}; - let config = Config::builder() - .add_source(File::from_str(response, FileFormat::Ini)) - .build() - .map_err(|e| error::Error::ParsingWifiStatus { - e, - s: response.into(), - })?; - Ok(config.try_deserialize::>().unwrap()) + config::from_str(response).map_err(|e| error::Error::ParsingWifiStatus { + e, + s: response.into(), + }) } #[derive(Debug)]