diff --git a/Cargo.lock b/Cargo.lock index 93ad605..ca4d7e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -190,6 +190,46 @@ dependencies = [ "windows-link", ] +[[package]] +name = "clap" +version = "4.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" + [[package]] name = "colorchoice" version = "1.0.4" @@ -505,6 +545,12 @@ version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "http" version = "1.3.1" @@ -1310,6 +1356,17 @@ dependencies = [ "libc", ] +[[package]] +name = "self-replace" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03ec815b5eab420ab893f63393878d89c90fdd94c0bcc44c07abb8ad95552fb7" +dependencies = [ + "fastrand", + "tempfile", + "windows-sys 0.52.0", +] + [[package]] name = "serde" version = "1.0.219" @@ -1421,6 +1478,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" @@ -1716,17 +1779,21 @@ version = "0.1.0" dependencies = [ "anyhow", "chrono", + "clap", "embed-manifest", "env_logger", "futures", "inquire", "log", "reqwest", + "self-replace", "serde", "serde_json", + "tempfile", "thiserror", "tokio", "uuid", + "windows", "windows-registry 0.5.3", "windows-result", "winres", @@ -1878,11 +1945,33 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +dependencies = [ + "windows-collections", + "windows-core", + "windows-future", + "windows-link", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core", +] + [[package]] name = "windows-core" -version = "0.61.1" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46ec44dc15085cea82cf9c78f85a9114c463a369786585ad2882d1ff0b0acf40" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement", "windows-interface", @@ -1891,6 +1980,17 @@ dependencies = [ "windows-strings 0.4.2", ] +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core", + "windows-link", + "windows-threading", +] + [[package]] name = "windows-implement" version = "0.60.0" @@ -1919,6 +2019,16 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core", + "windows-link", +] + [[package]] name = "windows-registry" version = "0.4.0" @@ -2042,6 +2152,15 @@ dependencies = [ "windows_x86_64_msvc 0.53.0", ] +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" diff --git a/Cargo.toml b/Cargo.toml index 5a971fe..5fd48c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,16 +6,23 @@ edition = "2024" [dependencies] anyhow = "1.0.98" chrono = { version = "0.4.41", features = ["serde"] } +clap = { version = "4.5.40", features = ["derive"] } env_logger = "0.11.8" futures = "0.3.31" inquire = "0.7.5" log = "0.4.27" reqwest = { version = "0.12.15", features = ["json", "stream"] } +self-replace = "1.5.0" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" +tempfile = "3.20.0" thiserror = "2.0.12" tokio = { version = "1.45.0", features = ["full"] } uuid = { version = "1.16.0", features = ["serde"] } +windows = { version = "0.61.3", features = [ + "Win32_System_Console", + "Win32_System_Services", +] } windows-registry = "0.5.3" windows-result = "0.3.4" diff --git a/README.md b/README.md new file mode 100644 index 0000000..400cfc5 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# Valthrun Loader + +The Valthrun Loader automatically maps the [Valthrun Kernel Driver](https://github.com/valthrun/valthrun-driver-kernel) using [KDMapper](https://github.com/TheCruZ/kdmapper) and starts the [Valthrun Overlay](https://github.com/valthrun/valthrun). diff --git a/build.rs b/build.rs index d8c216f..831965b 100644 --- a/build.rs +++ b/build.rs @@ -9,7 +9,7 @@ use winres::WindowsResource; fn main() -> Result<(), Box> { { - let git_hash = if Path::new("../.git").exists() { + let git_hash = if Path::new(".git").exists() { match { Command::new("git").args(&["rev-parse", "HEAD"]).output() } { Ok(output) => String::from_utf8(output.stdout).expect("the git hash to be utf-8"), Err(error) => { diff --git a/src/api.rs b/src/api.rs index 872af69..bbd7e3a 100644 --- a/src/api.rs +++ b/src/api.rs @@ -7,7 +7,8 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - util::{self}, + components, + utils::{self}, version::{self}, }; @@ -110,9 +111,18 @@ pub async fn get_latest_artifact_version( http: &Client, artifact_slug: &str, ) -> anyhow::Result { - let artifact = get_artifact(http, &artifact_slug).await?.artifact; - let track_response = - get_track(http, &artifact_slug, &artifact.default_track.to_string()).await?; + let artifact = get_artifact(http, artifact_slug).await?.artifact; + + get_latest_artifact_track_version(http, artifact_slug, &artifact.default_track.to_string()) + .await +} + +pub async fn get_latest_artifact_track_version( + http: &Client, + artifact_slug: &str, + track_slug: &str, +) -> anyhow::Result { + let track_response = get_track(http, &artifact_slug, &track_slug).await?; let latest_version = track_response .versions @@ -126,40 +136,62 @@ pub async fn get_latest_artifact_version( pub async fn download_latest_artifact_version( http: &Client, - artifact_slug: &str, - output_name: &str, + artifact: &components::Artifact, ) -> anyhow::Result { - let latest_version = get_latest_artifact_version(http, artifact_slug) + let latest_version = get_latest_artifact_version(http, artifact.slug()) .await .context("get latest artifact version")?; - let stored_hash = version::get_stored_version_hash(artifact_slug) + let stored_hash = version::get_stored_version_hash(artifact.slug()) .await .context("get stored version hash")?; - let output_path = util::get_downloads_path() + let output_path = utils::get_downloads_path() .context("get downloads path")? - .join(output_name); + .join(artifact.file_name()); let should_download = !output_path.is_file() || stored_hash .is_none_or(|hash| !version::compare_hashes(&hash, &latest_version.version_hash)); if should_download { - util::download_file( + if output_path.is_file() { + log::info!( + "{} is outdated. Downloading new version {} ({}).", + artifact.name(), + latest_version.version, + latest_version.version_hash + ); + } else { + log::info!( + "{} not found locally. Downloading version {} ({}).", + artifact.name(), + latest_version.version, + latest_version.version_hash + ); + } + + utils::download_file( http, format!( "https://valth.run/api/artifacts/{}/{}/{}/download", - artifact_slug, latest_version.track, latest_version.id + artifact.slug(), + latest_version.track, + latest_version.id ), &output_path, ) .await .context("download file")?; - version::set_stored_version_hash(artifact_slug, &latest_version.version_hash) + version::set_stored_version_hash(artifact.slug(), &latest_version.version_hash) .await .context("set stored version hash")?; + } else { + log::info!( + "Latest version of {} found locally. Skipping download.", + artifact.name() + ); } Ok(output_path) diff --git a/src/commands/launch.rs b/src/commands/launch.rs new file mode 100644 index 0000000..c1a4ea1 --- /dev/null +++ b/src/commands/launch.rs @@ -0,0 +1,46 @@ +use anyhow::Context; + +use crate::{api, components, game, utils}; + +pub async fn launch(http: &reqwest::Client, enhancer: components::Enhancer) -> anyhow::Result<()> { + for artifact in enhancer.required_artifacts() { + api::download_latest_artifact_version(http, &artifact) + .await + .context("failed to download {}")?; + } + + // TODO: Make it game-independent to also allow PUBG, for example + if game::is_running() + .await + .context("failed to check if game is running")? + { + log::info!("Counter-Strike 2 is already running."); + } else { + log::info!("Counter-Strike 2 is not running."); + + if utils::confirm_default("Do you want to launch the game?", true)? { + log::info!("Waiting for Counter-Strike 2 to start"); + game::launch_and_wait() + .await + .context("failed to wait for cs2 to launch")?; + } + } + + utils::invoke_ps_command(&format!( + "Start-Process -FilePath '{}' -WorkingDirectory '{}'", + utils::get_downloads_path()? + .join(enhancer.artifact_to_execute().file_name()) + .display(), + std::env::current_exe() + .context("get current exe")? + .parent() + .context("get parent path")? + .display() + )) + .await + .context("failed to start overlay")?; + + log::info!("Valthrun will now load. Have fun!"); + + Ok(()) +} diff --git a/src/commands/map_driver.rs b/src/commands/map_driver.rs new file mode 100644 index 0000000..28a2916 --- /dev/null +++ b/src/commands/map_driver.rs @@ -0,0 +1,44 @@ +use anyhow::Context; + +use crate::{api, components, driver, fixes, utils}; + +pub async fn map_driver(http: &reqwest::Client) -> anyhow::Result<()> { + log::info!("Checking for interfering services"); + + for service in [c"faceit", c"vgc", c"vgk", c"ESEADriver2"] { + if fixes::is_service_running(service).context("check service running")? { + log::error!( + "The service '{}' will cause the driver mapping to fail. In order to proceed, you need to stop this service.", + service.to_str()? + ); + + if utils::confirm_default("Do you want to stop this service?", true)? { + fixes::stop_service(service.to_str()?) + .await + .context("stop service")?; + } + } + } + + api::download_latest_artifact_version(http, &components::Artifact::KernelDriver) + .await + .context("failed to download kernel driver")?; + + log::info!("Downloading KDMapper"); + + utils::download_file( + &http, + "https://github.com/sinjs/kdmapper/releases/latest/download/kdmapper.exe", + &utils::get_downloads_path()?.join("kdmapper.exe"), + ) + .await + .context("failed to download kdmapper")?; + + driver::ui_map_driver(&http) + .await + .context("failed to map driver")?; + + log::info!("Driver successfully mapped"); + + Ok(()) +} diff --git a/src/commands/mod.rs b/src/commands/mod.rs new file mode 100644 index 0000000..e386167 --- /dev/null +++ b/src/commands/mod.rs @@ -0,0 +1,5 @@ +mod map_driver; +pub use map_driver::*; + +mod launch; +pub use launch::*; diff --git a/src/components.rs b/src/components.rs new file mode 100644 index 0000000..6bf744e --- /dev/null +++ b/src/components.rs @@ -0,0 +1,65 @@ +use clap::ValueEnum; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum Artifact { + Cs2Overlay, + Cs2RadarClient, + DriverInterfaceKernel, + KernelDriver, +} + +impl Artifact { + pub const fn name(&self) -> &'static str { + match self { + Artifact::Cs2Overlay => "CS2 Overlay", + Artifact::Cs2RadarClient => "CS2 Radar Client", + Artifact::DriverInterfaceKernel => "Driver Interface Kernel", + Artifact::KernelDriver => "Kernel Driver", + } + } + + pub const fn slug(&self) -> &'static str { + match self { + Artifact::Cs2Overlay => "cs2-overlay", + Artifact::Cs2RadarClient => "cs2-radar-client", + Artifact::DriverInterfaceKernel => "driver-interface-kernel", + Artifact::KernelDriver => "kernel-driver", + } + } + + pub const fn file_name(&self) -> &'static str { + match self { + Artifact::Cs2Overlay => "cs2_overlay.exe", + Artifact::Cs2RadarClient => "cs2_radar_client.exe", + Artifact::DriverInterfaceKernel => "driver_interface_kernel.dll", + Artifact::KernelDriver => "kernel_driver.sys", + } + } +} + +#[derive(ValueEnum, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[clap(rename_all = "kebab-case")] +pub enum Enhancer { + Cs2Overlay, + Cs2StandaloneRadar, +} + +impl Enhancer { + pub const fn required_artifacts(&self) -> &'static [&'static Artifact] { + match self { + Enhancer::Cs2Overlay => &[&Artifact::Cs2Overlay, &Artifact::DriverInterfaceKernel], + Enhancer::Cs2StandaloneRadar => { + &[&Artifact::Cs2RadarClient, &Artifact::DriverInterfaceKernel] + } + } + } + + pub const fn artifact_to_execute(&self) -> &'static Artifact { + match self { + Enhancer::Cs2Overlay => &Artifact::Cs2Overlay, + Enhancer::Cs2StandaloneRadar => &Artifact::Cs2RadarClient, + } + } +} diff --git a/src/driver.rs b/src/driver.rs index 2cef86b..1b2432e 100644 --- a/src/driver.rs +++ b/src/driver.rs @@ -2,7 +2,7 @@ use anyhow::Context; use thiserror::Error; use tokio::process::Command; -use crate::{fixes, util}; +use crate::{fixes, utils}; #[derive(Debug, Error)] pub enum MapDriverError { @@ -27,7 +27,7 @@ pub enum MapDriverError { } pub async fn map_driver() -> Result { - let downloads_path = util::get_downloads_path() + let downloads_path = utils::get_downloads_path() .context("get downloads path") .unwrap(); let kdmapper_path = downloads_path.join("kdmapper.exe"); @@ -37,11 +37,7 @@ pub async fn map_driver() -> Result { log::warn!("Failed to add exclusion for Windows Defender: {:#}", e); }; - for service in ["faceit", "vgc", "vgk", "ESEADriver2"] { - let _ = fixes::disable_service(service).await; - } - - let output = util::invoke_command(Command::new(kdmapper_path).arg(driver_path)).await?; + let output = utils::invoke_command(Command::new(kdmapper_path).arg(driver_path)).await?; let stdout = String::from_utf8_lossy(&output.stdout); match stdout.as_ref() { @@ -55,7 +51,27 @@ pub async fn map_driver() -> Result { } } -pub async fn map_driver_handled(http: &reqwest::Client) -> anyhow::Result<()> { +pub async fn ui_map_driver(http: &reqwest::Client) -> anyhow::Result<()> { + let downloads_path = utils::get_downloads_path() + .context("get downloads path") + .unwrap(); + let kdmapper_path = downloads_path.join("kdmapper.exe"); + + if fixes::is_defender_enabled() + .await + .context("check is defender enabled")? + && !fixes::has_defender_exclusion(&kdmapper_path) + .await + .context("check defender exclusion")? + { + log::warn!("Windows Defender is enabled and there is no exclusion for the driver mapper."); + if utils::confirm_default("Do you want to add an exclusion?", true)? { + fixes::add_defender_exclusion(&kdmapper_path) + .await + .context("failed to add defender exclusion")? + } + } + if let Err(e) = map_driver().await { match e { MapDriverError::DeviceNalInUse => { @@ -65,22 +81,38 @@ pub async fn map_driver_handled(http: &reqwest::Client) -> anyhow::Result<()> { map_driver().await?; } MapDriverError::DriverBlocklist => { - if let Err(e) = fixes::set_driver_blocklist(false) { - log::warn!("Failed to disable vulnerable driver blocklist: {:#}", e); - } - if let Err(e) = fixes::set_hvci(false) { - log::warn!("Failed to disable HVCI: {:#}", e); - } + log::error!( + "Failed to load the driver due to the Vulnerable Driver Blocklist or HVCI being enabled." + ); - log::warn!("The system must restart to continue changing system settings."); - let should_restart = inquire::prompt_confirmation("Do you want to restart now?") - .context("prompt for restart")?; + if utils::confirm_default( + "Do you want to disable these Windows security features?", + true, + )? { + if let Err(e) = fixes::set_driver_blocklist(false) { + log::error!("Failed to disable vulnerable driver blocklist: {:#}", e); + } + if let Err(e) = fixes::set_hvci(false) { + log::error!("Failed to disable HVCI: {:#}", e); + } - if should_restart { - util::schedule_restart().await.context("schedule restart")?; - } + log::info!("The system must restart to apply changes to the system settings."); + let should_restart = + utils::confirm_default("Do you want to restart now?", true) + .context("prompt for restart")?; - std::process::exit(0); + if should_restart { + log::info!("Restarting system"); + + utils::schedule_restart() + .await + .context("schedule restart")?; + + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + } + + anyhow::bail!("Please restart the system yourself and try again."); + } } e => anyhow::bail!(e), } diff --git a/src/fixes.rs b/src/fixes.rs index a4e76c4..64e1065 100644 --- a/src/fixes.rs +++ b/src/fixes.rs @@ -1,23 +1,33 @@ -use std::path::Path; +use std::{ffi::CStr, path::Path}; use anyhow::Context; use tokio::process::Command; +use windows::{ + Win32::{ + Foundation::ERROR_SERVICE_DOES_NOT_EXIST, + System::Services::{ + OpenSCManagerA, OpenServiceA, QueryServiceStatus, SC_MANAGER_CONNECT, + SERVICE_QUERY_STATUS, SERVICE_RUNNING, SERVICE_STATUS, + }, + }, + core::PCSTR, +}; use windows_registry::LOCAL_MACHINE; -use crate::util::{self}; +use crate::utils::{self}; pub async fn execute_nal_fix(http: &reqwest::Client) -> anyhow::Result<()> { - let path = util::get_downloads_path()?.join("nalfix.exe"); + let path = utils::get_downloads_path()?.join("nalfix.exe"); - util::download_file( + utils::download_file( http, - "https://github.com/VollRagm/NalFix/releases/latest/download/NalFix.exe", + "https://github.com/sinjs/NalFix/releases/latest/download/NalFix.exe", &path, ) .await .context("download file")?; - util::invoke_command(&mut Command::new(path)) + utils::invoke_command(&mut Command::new(path)) .await .context("execute command")?; @@ -36,14 +46,82 @@ pub fn set_driver_blocklist(enabled: bool) -> windows_registry::Result<()> { Ok(()) } -pub async fn disable_service(name: &str) -> anyhow::Result<()> { - util::invoke_command(Command::new("sc").args(["stop", name])).await?; +pub fn is_service_running(name: &CStr) -> anyhow::Result { + let running = unsafe { + let hsc_manager = OpenSCManagerA(None, None, SC_MANAGER_CONNECT | SERVICE_QUERY_STATUS) + .context("OpenSCManagerA")?; + + let service = match OpenServiceA( + hsc_manager, + PCSTR(name.as_ptr() as *const u8), + SERVICE_QUERY_STATUS, + ) { + Ok(handle) => handle, + Err(error) if error.code() == ERROR_SERVICE_DOES_NOT_EXIST.to_hresult() => { + return Ok(false); + } + Err(error) => { + anyhow::bail!( + "failed to open service '{}': {}", + name.to_string_lossy(), + error + ) + } + }; + + let mut status = SERVICE_STATUS::default(); + QueryServiceStatus(service, &mut status).context("QueryServiceStatus")?; + + status.dwCurrentState == SERVICE_RUNNING + }; + + Ok(running) +} + +pub async fn stop_service(name: &str) -> anyhow::Result<()> { + utils::invoke_command(Command::new("sc").args(["stop", name])).await?; Ok(()) } +fn parse_powershell_boolean(output: impl AsRef) -> anyhow::Result { + let output = output.as_ref(); + if output.contains("True") { + Ok(true) + } else if output.contains("False") { + Ok(false) + } else { + anyhow::bail!( + "failed to parse command output: (expected powershell boolean, got: '{}')", + output + ) + } +} + +pub async fn is_defender_enabled() -> anyhow::Result { + let output = + utils::invoke_ps_command(&format!("(Get-MpComputerStatus).RealTimeProtectionEnabled")) + .await?; + + let output = String::from_utf8_lossy(&output.stdout); + + parse_powershell_boolean(output) +} + +pub async fn has_defender_exclusion(path: &Path) -> anyhow::Result { + let output = utils::invoke_ps_command(&format!( + "(Get-MpPreference).ExclusionPath -contains '{}'", + path.display() + )) + .await?; + + let output = String::from_utf8_lossy(&output.stdout); + + parse_powershell_boolean(output) +} + pub async fn add_defender_exclusion(path: &Path) -> anyhow::Result<()> { - util::invoke_ps_command(&format!( + utils::invoke_ps_command(&format!( "Add-MpPreference -ExclusionPath '{}' -ErrorAction SilentlyContinue", path.display() )) diff --git a/src/game.rs b/src/game.rs index add3f96..46240d3 100644 --- a/src/game.rs +++ b/src/game.rs @@ -1,17 +1,13 @@ -use std::path::Path; - -use anyhow::Context; - -use crate::util; +use crate::utils; pub async fn is_running() -> anyhow::Result { let output = - util::invoke_ps_command("Get-Process -Name cs2 -ErrorAction SilentlyContinue").await?; + utils::invoke_ps_command("Get-Process -Name cs2 -ErrorAction SilentlyContinue").await?; Ok(output.status.success()) } pub async fn launch_and_wait() -> anyhow::Result<()> { - util::invoke_ps_command("Start-Process 'steam://run/730'").await?; + utils::invoke_ps_command("Start-Process 'steam://run/730'").await?; while !is_running().await? { tokio::time::sleep(std::time::Duration::from_secs(1)).await; @@ -21,31 +17,3 @@ pub async fn launch_and_wait() -> anyhow::Result<()> { Ok(()) } - -// FIXME: Using PowerShell here is very much easier than working with the Windows API or some library. Might want to implement my own at some point.. -// If any issues regarding launching as administrator arise, this function is ready to use -#[allow(dead_code)] -pub async fn create_and_run_task(name: &str, path: &Path) -> anyhow::Result<()> { - let path = tokio::fs::canonicalize(path) - .await - .context("canonicalize path")?; - - let path = path.to_string_lossy(); - - let script = format!( - r#"$taskName = '{name}'; -$trigger = New-ScheduledTaskTrigger -Once -At (Get-Date).Date.AddMinutes(1); -$action = New-ScheduledTaskAction -Execute $taskPath -WorkingDirectory '{path}'; - -Register-ScheduledTask -TaskName $taskName -Trigger $trigger -Action $action -User "$env:COMPUTERNAME\$env:USERNAME" -RunLevel Highest -Force -ErrorAction Stop | Out-Null; -Start-ScheduledTask -TaskName $taskName; -Start-Sleep -Seconds 2; - -Unregister-ScheduledTask -TaskName $taskName -Confirm:$false -ErrorAction SilentlyContinue | Out-Null; -"# - ); - - util::invoke_ps_command(&script).await?; - - todo!() -} diff --git a/src/main.rs b/src/main.rs index 17219b6..2f0c654 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,102 +1,116 @@ -use std::collections::HashMap; +use std::process::ExitCode; -use crate::api::download_latest_artifact_version; use anyhow::{Context, Result}; +use clap::{Parser, Subcommand}; mod api; +mod commands; +mod components; mod driver; mod fixes; mod game; -mod util; +mod ui; +mod updater; +mod utils; mod version; -async fn real_main() -> Result<()> { +#[derive(Parser, Debug)] +pub struct AppArgs { + /// Enable verbose logging ($env:RUST_LOG="trace") + #[clap(short, long)] + verbose: bool, + + #[command(subcommand)] + command: Option, +} + +#[derive(Subcommand, Debug, Clone)] +pub enum AppCommand { + /// Quickly launch Valthrun with all the default settings and commands + QuickStart, + + /// Download and map the driver + MapDriver, + + /// Download and launch a enhancer + Launch { enhancer: components::Enhancer }, + + /// Display the version + Version, +} + +async fn real_main(args: AppArgs) -> Result { let http = reqwest::Client::new(); - log::info!( - "Valthrun Loader v{} ({})", - env!("CARGO_PKG_VERSION"), - env!("GIT_HASH") - ); - log::info!("Current executable was built on {}", env!("BUILD_TIME")); - - // Download all artifacts from the Valthrun Portal - log::info!("Starting download process..."); - let artifact_file_names = HashMap::from([ - ("cs2-overlay", "cs2_overlay.exe"), - ("driver-interface-kernel", "driver_interface_kernel.dll"), - ("kernel-driver", "kernel_driver.sys"), - ]); - for (artifact_slug, file_name) in artifact_file_names.iter() { - download_latest_artifact_version(&http, artifact_slug, file_name) - .await - .with_context(|| { - format!( - "failed to download latest artifact version for '{}'", - artifact_slug - ) - })?; - } + updater::ui_updater(&http).await?; + + let command = args.command.map(Ok).unwrap_or_else(ui::app_menu)?; + + match command { + AppCommand::QuickStart => { + commands::map_driver(&http) + .await + .context("execute map driver command")?; - // Download kdmapper - log::info!("Downloading additional components..."); - util::download_file( - &http, - "https://github.com/sinjs/kdmapper/releases/latest/download/kdmapper.exe", - &util::get_downloads_path()?.join("kdmapper.exe"), - ) - .await - .context("failed to download kdmapper")?; - log::info!("All files downloaded and processed successfully."); - - // Map the driver - log::info!("Mapping driver..."); - driver::map_driver_handled(&http) - .await - .context("failed to map driver with error handling")?; - - // Launch the game - if game::is_running() - .await - .context("failed to check if game is running")? - { - log::info!("Counter-Strike 2 is already running."); - } else { - log::info!("Waiting for Counter-Strike 2 to start..."); - game::launch_and_wait() - .await - .context("failed to wait for cs2 to launch")?; + commands::launch(&http, components::Enhancer::Cs2Overlay) + .await + .context("execute launch enhancer command")?; + } + AppCommand::Launch { enhancer } => { + commands::launch(&http, enhancer) + .await + .context("execute launch enhancer command")?; + } + AppCommand::MapDriver => { + commands::map_driver(&http) + .await + .context("execute map driver command")?; + } + AppCommand::Version => { + log::info!("Valthrun Loader"); + log::info!(" Version: v{}", env!("CARGO_PKG_VERSION")); + log::info!(" Build: {} ({})", env!("GIT_HASH"), env!("BUILD_TIME")) + } } - // Launch the overlay - log::info!("Valthrun will now load. Have fun!"); - util::invoke_ps_command(&format!( - "Start-Process -FilePath '{}' -WorkingDirectory '{}'", - util::get_downloads_path()? - .join("cs2_overlay.exe") - .display(), - std::env::current_exe() - .context("get current exe")? - .parent() - .context("get parent path")? - .display() - )) - .await - .context("failed to start overlay")?; - - Ok(()) + Ok(ExitCode::SUCCESS) } #[tokio::main] -async fn main() { +async fn main() -> ExitCode { + let args = match AppArgs::try_parse() { + Ok(args) => args, + Err(e) => { + eprintln!("Failed to parse arguments:\n{:#}", e); + + if !utils::is_console_invoked() { + utils::console_pause(); + } + return ExitCode::FAILURE; + } + }; + env_logger::builder() - .filter_level(log::LevelFilter::Info) + .filter_level(if args.verbose { + log::LevelFilter::Trace + } else { + log::LevelFilter::Info + }) + .format_target(args.verbose || cfg!(debug_assertions)) .parse_default_env() .init(); - if let Err(e) = real_main().await { - log::error!("{:#}", e); + let status = match real_main(args).await { + Ok(status) => status, + Err(e) => { + log::error!("{:#}", e); + ExitCode::FAILURE + } + }; + + if !utils::is_console_invoked() { + utils::console_pause(); } - inquire::prompt_text("Press enter to continue...").expect("failed to prompt user"); + status } diff --git a/src/ui.rs b/src/ui.rs new file mode 100644 index 0000000..3197494 --- /dev/null +++ b/src/ui.rs @@ -0,0 +1,43 @@ +use crate::{AppCommand, components}; + +const MENU_OPTIONS: &[(&'static str, AppCommand)] = &[ + ( + "Launch Valthrun with default settings", + AppCommand::QuickStart, + ), + ("Map Driver", AppCommand::MapDriver), + ( + "Launch Overlay", + AppCommand::Launch { + enhancer: components::Enhancer::Cs2Overlay, + }, + ), + ( + "Launch Standalone Radar", + AppCommand::Launch { + enhancer: components::Enhancer::Cs2StandaloneRadar, + }, + ), + ("Show Version", AppCommand::Version), +]; + +pub fn app_menu() -> anyhow::Result { + log::info!( + "Welcome to the Valthrun Loader v{} ({})", + env!("CARGO_PKG_VERSION"), + env!("GIT_HASH") + ); + + let choice = inquire::Select::new( + "Please select the command you want to execute:\n", + MENU_OPTIONS + .iter() + .map(|(name, _value)| *name) + .collect::>(), + ) + .with_help_message("↑↓ to move, enter to select") + .without_filtering() + .raw_prompt()?; + + Ok(MENU_OPTIONS[choice.index].1.clone()) +} diff --git a/src/updater.rs b/src/updater.rs new file mode 100644 index 0000000..d0a56f0 --- /dev/null +++ b/src/updater.rs @@ -0,0 +1,145 @@ +use anyhow::Context; +use futures::StreamExt; + +use crate::{api, utils}; + +#[derive(Debug, Clone)] +struct Update(api::Version); + +impl Update { + pub fn download_url(&self) -> String { + format!( + "https://valth.run/api/artifacts/{}/{}/{}/download", + self.0.artifact, self.0.track, self.0.id + ) + } + + pub async fn download_and_install(&self, http: &reqwest::Client) -> anyhow::Result<()> { + let mut stream = http + .get(self.download_url()) + .send() + .await + .context("send request")? + .error_for_status()? + .bytes_stream(); + + let file = tempfile::NamedTempFile::new().context("create tempfile")?; + let mut buf = std::io::BufWriter::new(&file); + + while let Some(item) = stream.next().await { + std::io::copy(&mut item?.as_ref(), &mut buf).context("copy data")?; + } + + log::debug!("Downloaded update to {}", file.path().display()); + + self_replace::self_replace(file.path()).context("replace self")?; + + Ok(()) + } +} + +async fn check_for_updates(http: &reqwest::Client) -> anyhow::Result> { + log::debug!("Checking for updates"); + + if cfg!(debug_assertions) { + log::debug!("Running in debug version, skipping update check"); + return Ok(None); + } + + let latest_version = api::get_latest_artifact_track_version(http, "valthrun-loader", "win32") + .await + .context("failed to get latest version")?; + + let has_update = env!("GIT_HASH") != latest_version.version_hash; + + log::debug!( + "Has update: {has_update} (Latest: {}, Current: {})", + latest_version.version_hash, + env!("GIT_HASH") + ); + + Ok(if has_update { + Some(Update(latest_version)) + } else { + None + }) +} + +pub async fn ui_updater(http: &reqwest::Client) -> anyhow::Result<()> { + let Some(update) = check_for_updates(http).await.context("check for updates")? else { + return Ok(()); + }; + + log::info!("A new update for the loader is available."); + log::info!( + " Installed version: {} ({})", + env!("CARGO_PKG_VERSION"), + env!("GIT_HASH") + ); + log::info!( + " Available version: {} ({})", + update.0.version, + update.0.version_hash + ); + + if !utils::confirm_default( + "Do you want to download and install the latest version?", + true, + )? { + return Ok(()); + } + + update + .download_and_install(http) + .await + .context("download and install update")?; + + log::debug!("Update installed successfully. Restarting process"); + + restart().await; +} + +async fn restart() -> ! { + async fn restart_internal() -> anyhow::Result<()> { + let current_exe = std::env::current_exe()?; + + if utils::is_console_invoked() { + // If the loader is invoked from the command line, just spawn the process with the stdio inherited + // and wait for it to exit. + + let exit = std::process::Command::new(current_exe) + .args(std::env::args_os().skip(1)) + .spawn()? + .wait()? + .code() + .unwrap_or(1); + + std::process::exit(exit); + } else { + // If the loader is invoked normally, use Start-Process since that will not break is_console_invoked(). + // Arguments do not matter in this case, and the current process exits after spawning the new one. + + utils::invoke_ps_command(&format!( + "Start-Process -FilePath '{}'", + current_exe.display(), + )) + .await?; + + std::process::exit(0) + } + } + + if let Err(e) = restart_internal() + .await + .context("Failed to restart the loader") + { + log::error!("{:#}", e); + log::error!("Please restart the loader manually."); + + if utils::is_console_invoked() { + utils::console_pause(); + } + } + + std::process::exit(0); +} diff --git a/src/util.rs b/src/utils.rs similarity index 78% rename from src/util.rs rename to src/utils.rs index 0a707da..55d82f6 100644 --- a/src/util.rs +++ b/src/utils.rs @@ -6,6 +6,7 @@ use std::{ use anyhow::Context; use futures::StreamExt; use tokio::process::Command; +use windows::Win32::System::Console::GetConsoleProcessList; pub async fn invoke_ps_command(command: &str) -> tokio::io::Result { self::invoke_command(Command::new("powershell").args(&["-Command", &command])).await @@ -43,7 +44,7 @@ pub fn get_data_path() -> anyhow::Result { .context("get current exe")? .parent() .context("get parent path")? - .join(".vthl"); + .join(".vtl"); std::fs::create_dir_all(&path)?; @@ -71,6 +72,12 @@ pub async fn download_file( url: impl reqwest::IntoUrl, path: &Path, ) -> anyhow::Result<()> { + log::debug!( + "Downloading file from {} to {}", + url.as_str(), + path.display() + ); + let mut stream = http .get(url) .send() @@ -98,3 +105,21 @@ pub async fn schedule_restart() -> anyhow::Result<()> { std::process::exit(1); } + +pub fn is_console_invoked() -> bool { + let mut result: [u32; 128] = [0u32; 128]; + + let console_count = unsafe { GetConsoleProcessList(&mut result) }; + console_count > 1 +} + +pub fn console_pause() { + inquire::prompt_text("Press enter to continue...").expect("failed to prompt user"); +} + +pub fn confirm_default(message: impl AsRef, default: bool) -> anyhow::Result { + inquire::Confirm::new(message.as_ref()) + .with_default(default) + .prompt() + .context("prompt user") +} diff --git a/src/version.rs b/src/version.rs index 4602676..3bc32ad 100644 --- a/src/version.rs +++ b/src/version.rs @@ -1,6 +1,6 @@ use anyhow::Context; -use crate::util; +use crate::utils; pub fn compare_hashes(first: &str, second: &str) -> bool { let first = normalize_hash(first); @@ -14,7 +14,7 @@ pub fn normalize_hash(hash: &str) -> String { } pub async fn get_stored_version_hash(artifact_slug: &str) -> anyhow::Result> { - let path = util::get_versions_path() + let path = utils::get_versions_path() .context("get versions path")? .join(artifact_slug); @@ -30,7 +30,7 @@ pub async fn get_stored_version_hash(artifact_slug: &str) -> anyhow::Result anyhow::Result<()> { - let path = util::get_versions_path() + let path = utils::get_versions_path() .context("get versions path")? .join(artifact_slug);