diff --git a/src/add_service.rs b/src/add_service.rs index f5245ba..be62ebe 100644 --- a/src/add_service.rs +++ b/src/add_service.rs @@ -10,7 +10,7 @@ use crate::config::create_owned_dir; use crate::helpers::download_and_extract_safenode; use crate::node::{Node, NodeRegistry, NodeStatus}; use crate::service::{ServiceConfig, ServiceControl}; -use color_eyre::{eyre::eyre, Result}; +use color_eyre::{eyre::eyre, Help, Result}; use colored::Colorize; use libp2p::Multiaddr; use sn_releases::SafeReleaseRepositoryInterface; @@ -22,6 +22,8 @@ pub struct AddServiceOptions { pub service_data_dir_path: PathBuf, pub service_log_dir_path: PathBuf, pub peers: Vec, + pub port: Option, + pub rpc_port: Option, pub url: Option, pub user: String, pub version: Option, @@ -39,6 +41,32 @@ pub async fn add( service_control: &dyn ServiceControl, release_repo: Box, ) -> Result<()> { + if install_options.count.is_some() + && (install_options.port.is_some() || install_options.rpc_port.is_some()) + { + let count = install_options.count.unwrap(); + if count > 1 { + return Err(eyre!( + "Custom ports can only be used when adding a single service" + )); + } + } + + if install_options.port.is_some() { + let port = install_options.port.unwrap(); + if !service_control.is_port_free(port) { + return Err(eyre!("Port {port} is already in use") + .suggestion("Please try again with an available port")); + } + } + if install_options.rpc_port.is_some() { + let rpc_port = install_options.rpc_port.unwrap(); + if !service_control.is_port_free(rpc_port) { + return Err(eyre!("Port {rpc_port} is already in use") + .suggestion("Please try again with an available port")); + } + } + let (safenode_download_path, version) = download_and_extract_safenode(install_options.url, install_options.version, release_repo) .await?; @@ -53,8 +81,16 @@ pub async fn add( let target_node_count = current_node_count + install_options.count.unwrap_or(1); let mut node_number = current_node_count + 1; while node_number <= target_node_count { - let node_port = service_control.get_available_port()?; - let rpc_port = service_control.get_available_port()?; + let node_port = if let Some(port) = install_options.port { + port + } else { + service_control.get_available_port()? + }; + let rpc_port = if let Some(port) = install_options.rpc_port { + port + } else { + service_control.get_available_port()? + }; let service_name = format!("safenode{node_number}"); let service_data_dir_path = install_options @@ -276,6 +312,8 @@ mod tests { service_data_dir_path: node_data_dir.to_path_buf(), service_log_dir_path: node_logs_dir.to_path_buf(), peers: vec![], + port: None, + rpc_port: None, url: None, user: get_username(), version: None, @@ -463,6 +501,8 @@ mod tests { AddServiceOptions { count: Some(3), peers: vec![], + port: None, + rpc_port: None, safenode_dir_path: temp_dir.to_path_buf(), service_data_dir_path: node_data_dir.to_path_buf(), service_log_dir_path: node_logs_dir.to_path_buf(), @@ -610,6 +650,8 @@ mod tests { AddServiceOptions { count: None, peers: vec![], + port: None, + rpc_port: None, safenode_dir_path: temp_dir.to_path_buf(), service_data_dir_path: node_data_dir.to_path_buf(), service_log_dir_path: node_logs_dir.to_path_buf(), @@ -749,6 +791,8 @@ mod tests { AddServiceOptions { count: None, peers: vec![], + port: None, + rpc_port: None, safenode_dir_path: temp_dir.to_path_buf(), service_data_dir_path: node_data_dir.to_path_buf(), service_log_dir_path: node_logs_dir.to_path_buf(), @@ -865,6 +909,8 @@ mod tests { service_data_dir_path: node_data_dir.to_path_buf(), service_log_dir_path: node_logs_dir.to_path_buf(), peers: vec![], + port: None, + rpc_port: None, url: Some(url.to_string()), user: get_username(), version: None, @@ -898,4 +944,294 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn add_node_should_use_custom_ports_for_one_service() -> Result<()> { + let mut mock_service_control = MockServiceControl::new(); + let mut mock_release_repo = MockSafeReleaseRepository::new(); + + let mut node_registry = NodeRegistry { nodes: vec![] }; + let latest_version = "0.96.4"; + let temp_dir = assert_fs::TempDir::new()?; + let node_data_dir = temp_dir.child("data"); + node_data_dir.create_dir_all()?; + let node_logs_dir = temp_dir.child("logs"); + node_logs_dir.create_dir_all()?; + let safenode_download_path = temp_dir.child(SAFENODE_FILE_NAME); + safenode_download_path.write_binary(b"fake safenode bin")?; + + let custom_port = 12000; + let custom_rpc_port = 12001; + + let mut seq = Sequence::new(); + mock_service_control + .expect_is_port_free() + .with(eq(custom_port)) + .times(1) + .returning(|_| true) + .in_sequence(&mut seq); + mock_service_control + .expect_is_port_free() + .with(eq(custom_rpc_port)) + .times(1) + .returning(|_| true) + .in_sequence(&mut seq); + + mock_release_repo + .expect_get_latest_version() + .times(1) + .returning(|_| Ok(latest_version.to_string())) + .in_sequence(&mut seq); + + mock_release_repo + .expect_download_release_from_s3() + .with( + eq(&ReleaseType::Safenode), + eq(latest_version), + always(), // Varies per platform + eq(&ArchiveType::TarGz), + always(), // Temporary directory which doesn't really matter + always(), // Callback for progress bar which also doesn't matter + ) + .times(1) + .returning(move |_, _, _, _, _, _| { + Ok(PathBuf::from(format!( + "/tmp/safenode-{}-x86_64-unknown-linux-musl.tar.gz", + latest_version + ))) + }) + .in_sequence(&mut seq); + + let safenode_download_path_clone = safenode_download_path.to_path_buf().clone(); + mock_release_repo + .expect_extract_release_archive() + .with( + eq(PathBuf::from(format!( + "/tmp/safenode-{}-x86_64-unknown-linux-musl.tar.gz", + latest_version + ))), + always(), // We will extract to a temporary directory + ) + .times(1) + .returning(move |_, _| Ok(safenode_download_path_clone.clone())) + .in_sequence(&mut seq); + + mock_service_control + .expect_install() + .times(1) + .with(eq(ServiceConfig { + name: "safenode1".to_string(), + safenode_path: node_data_dir + .to_path_buf() + .join("safenode1") + .join(SAFENODE_FILE_NAME), + node_port: custom_port, + rpc_port: custom_rpc_port, + service_user: get_username(), + log_dir_path: node_logs_dir.to_path_buf().join("safenode1"), + data_dir_path: node_data_dir.to_path_buf().join("safenode1"), + peers: vec![], + })) + .returning(|_| Ok(())) + .in_sequence(&mut seq); + + add( + AddServiceOptions { + count: None, + safenode_dir_path: temp_dir.to_path_buf(), + service_data_dir_path: node_data_dir.to_path_buf(), + service_log_dir_path: node_logs_dir.to_path_buf(), + peers: vec![], + port: Some(custom_port), + rpc_port: Some(custom_rpc_port), + url: None, + user: get_username(), + version: None, + }, + &mut node_registry, + &mock_service_control, + Box::new(mock_release_repo), + ) + .await?; + + safenode_download_path.assert(predicate::path::missing()); + node_data_dir.assert(predicate::path::is_dir()); + node_logs_dir.assert(predicate::path::is_dir()); + + assert_eq!(node_registry.nodes.len(), 1); + assert_eq!(node_registry.nodes[0].version, latest_version); + assert_eq!(node_registry.nodes[0].service_name, "safenode1"); + assert_eq!(node_registry.nodes[0].user, get_username()); + assert_eq!(node_registry.nodes[0].number, 1); + assert_eq!(node_registry.nodes[0].port, custom_port); + assert_eq!(node_registry.nodes[0].rpc_port, custom_rpc_port); + assert_eq!( + node_registry.nodes[0].log_dir_path, + Some(node_logs_dir.to_path_buf().join("safenode1")) + ); + assert_eq!( + node_registry.nodes[0].data_dir_path, + Some(node_data_dir.to_path_buf().join("safenode1")) + ); + assert_matches!(node_registry.nodes[0].status, NodeStatus::Added); + + Ok(()) + } + + #[tokio::test] + async fn add_node_should_return_error_if_custom_port_is_in_use() -> Result<()> { + let mut mock_service_control = MockServiceControl::new(); + + let mut node_registry = NodeRegistry { nodes: vec![] }; + let temp_dir = assert_fs::TempDir::new()?; + let node_data_dir = temp_dir.child("data"); + node_data_dir.create_dir_all()?; + let node_logs_dir = temp_dir.child("logs"); + node_logs_dir.create_dir_all()?; + + let custom_port = 12000; + let custom_rpc_port = 12001; + + mock_service_control + .expect_is_port_free() + .with(eq(custom_port)) + .times(1) + .returning(|_| false); + + let result = add( + AddServiceOptions { + count: None, + safenode_dir_path: temp_dir.to_path_buf(), + service_data_dir_path: node_data_dir.to_path_buf(), + service_log_dir_path: node_logs_dir.to_path_buf(), + peers: vec![], + port: Some(custom_port), + rpc_port: Some(custom_rpc_port), + url: None, + user: get_username(), + version: None, + }, + &mut node_registry, + &mock_service_control, + Box::new(MockSafeReleaseRepository::new()), + ) + .await; + + match result { + Ok(_) => panic!("This test should result in an error"), + Err(e) => { + assert_eq!( + format!("Port {custom_port} is already in use"), + e.to_string() + ) + } + } + + Ok(()) + } + + #[tokio::test] + async fn add_node_should_return_error_if_custom_rpc_port_is_in_use() -> Result<()> { + let mut mock_service_control = MockServiceControl::new(); + + let mut node_registry = NodeRegistry { nodes: vec![] }; + let temp_dir = assert_fs::TempDir::new()?; + let node_data_dir = temp_dir.child("data"); + node_data_dir.create_dir_all()?; + let node_logs_dir = temp_dir.child("logs"); + node_logs_dir.create_dir_all()?; + + let custom_port = 12000; + let custom_rpc_port = 12001; + + let mut seq = Sequence::new(); + mock_service_control + .expect_is_port_free() + .with(eq(custom_port)) + .times(1) + .returning(|_| true) + .in_sequence(&mut seq); + mock_service_control + .expect_is_port_free() + .with(eq(custom_rpc_port)) + .times(1) + .returning(|_| false) + .in_sequence(&mut seq); + + let result = add( + AddServiceOptions { + count: None, + safenode_dir_path: temp_dir.to_path_buf(), + service_data_dir_path: node_data_dir.to_path_buf(), + service_log_dir_path: node_logs_dir.to_path_buf(), + peers: vec![], + port: Some(custom_port), + rpc_port: Some(custom_rpc_port), + url: None, + user: get_username(), + version: None, + }, + &mut node_registry, + &mock_service_control, + Box::new(MockSafeReleaseRepository::new()), + ) + .await; + + match result { + Ok(_) => panic!("This test should result in an error"), + Err(e) => { + assert_eq!( + format!("Port {custom_rpc_port} is already in use"), + e.to_string() + ) + } + } + + Ok(()) + } + + #[tokio::test] + async fn add_node_should_return_error_if_custom_port_is_used_and_more_than_one_service_is_used( + ) -> Result<()> { + let mut node_registry = NodeRegistry { nodes: vec![] }; + let temp_dir = assert_fs::TempDir::new()?; + let node_data_dir = temp_dir.child("data"); + node_data_dir.create_dir_all()?; + let node_logs_dir = temp_dir.child("logs"); + node_logs_dir.create_dir_all()?; + + let custom_port = 12000; + let custom_rpc_port = 12001; + + let result = add( + AddServiceOptions { + count: Some(3), + safenode_dir_path: temp_dir.to_path_buf(), + service_data_dir_path: node_data_dir.to_path_buf(), + service_log_dir_path: node_logs_dir.to_path_buf(), + peers: vec![], + port: Some(custom_port), + rpc_port: Some(custom_rpc_port), + url: None, + user: get_username(), + version: None, + }, + &mut node_registry, + &MockServiceControl::new(), + Box::new(MockSafeReleaseRepository::new()), + ) + .await; + + match result { + Ok(_) => panic!("This test should result in an error"), + Err(e) => { + assert_eq!( + format!("Custom ports can only be used when adding a single service"), + e.to_string() + ) + } + } + + Ok(()) + } } diff --git a/src/main.rs b/src/main.rs index 2f889d0..be6960e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -70,6 +70,20 @@ pub enum SubCmd { log_dir_path: Option, #[command(flatten)] peers: PeersArgs, + /// Specify a port for the node to run on. + /// + /// If not used, a port will be selected at random. + /// + /// This option only applies when a single service is being added. + #[clap(long)] + port: Option, + /// Specify a port for the node's RPC service to run on. + /// + /// If not used, a port will be selected at random. + /// + /// This option only applies when a single service is being added. + #[clap(long)] + rpc_port: Option, /// Provide a safenode binary using a URL. /// /// The binary must be inside a zip or gzipped tar archive. @@ -168,6 +182,8 @@ async fn main() -> Result<()> { data_dir_path, log_dir_path, peers, + port, + rpc_port, url, user, version, @@ -203,6 +219,8 @@ async fn main() -> Result<()> { AddServiceOptions { count, peers: parse_peers_args(peers).await.unwrap_or(vec![]), + port, + rpc_port, safenode_dir_path: service_data_dir_path.clone(), service_data_dir_path, service_log_dir_path, diff --git a/src/service.rs b/src/service.rs index 3a076ed..5ed44a9 100644 --- a/src/service.rs +++ b/src/service.rs @@ -43,6 +43,7 @@ pub trait ServiceControl { fn create_service_user(&self, username: &str) -> Result<()>; fn get_available_port(&self) -> Result; fn install(&self, config: ServiceConfig) -> Result<()>; + fn is_port_free(&self, port: u16) -> bool; fn is_service_process_running(&self, pid: u32) -> bool; fn start(&self, service_name: &str) -> Result<()>; fn stop(&self, service_name: &str) -> Result<()>; @@ -142,6 +143,10 @@ impl ServiceControl for NodeServiceManager { Ok(()) } + fn is_port_free(&self, port: u16) -> bool { + TcpListener::bind(("127.0.0.1", port)).is_ok() + } + fn is_service_process_running(&self, pid: u32) -> bool { let mut system = System::new_all(); system.refresh_all();