diff --git a/sys/build.rs b/sys/build.rs index 1659991..84b394b 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -1,4 +1,5 @@ use cmake::Config; +use glob::glob; use std::env; use std::path::{Path, PathBuf}; use std::process::Command; @@ -24,6 +25,60 @@ fn copy_folder(src: &Path, dst: &Path) { } } +fn extract_lib_names(out_dir: &Path) -> Vec { + // Construct the pattern based on the target platform + let lib_suffix = if cfg!(windows) { "*.lib" } else { "*.a" }; + let pattern = out_dir.join(format!("build/lib/{}", lib_suffix)); + + let mut lib_names = Vec::new(); + + // Process the libraries based on the pattern + for entry in glob(pattern.to_str().unwrap()).unwrap() { + match entry { + Ok(path) => { + let stem = path.file_stem().unwrap(); + let stem_str = stem.to_str().unwrap(); + + // Remove the "lib" prefix if present + let lib_name = if stem_str.starts_with("lib") { + stem_str.strip_prefix("lib").unwrap_or(stem_str) + } else { + stem_str + }; + + lib_names.push(lib_name.to_string()); + } + Err(e) => println!("cargo:warning=error={}", e), + } + } + + lib_names +} + +fn extract_lib_assets(out_dir: &Path) -> Vec { + let shared_lib_suffix = if cfg!(windows) { + ".dll" + } else if cfg!(target_os = "macos") { + ".dylib" + } else { + ".so" + }; + + let pattern = out_dir.join(format!("lib/{}", shared_lib_suffix)); + let mut files = Vec::new(); + + for entry in glob(pattern.to_str().unwrap()).unwrap() { + match entry { + Ok(path) => { + files.push(path); + } + Err(e) => eprintln!("cargo:warning=error={}", e), + } + } + + files +} + fn main() { let target = env::var("TARGET").unwrap(); let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); @@ -38,47 +93,6 @@ fn main() { "Release" }; - let shared_lib_suffix = if cfg!(windows) { - ".dll" - } else if cfg!(target_os = "macos") { - ".dylib" - } else { - ".so" - }; - let sherpa_libs_kind = if build_shared_libs { "dylib" } else { "static" }; - let sherpa_libs: &[&str] = if build_shared_libs { - // shared - &["sherpa-onnx-c-api", "onnxruntime"] - } else if cfg!(feature = "tts") { - // static with tts - &[ - "sherpa-onnx-c-api", - "sherpa-onnx-core", - "kaldi-decoder-core", - "sherpa-onnx-kaldifst-core", - "sherpa-onnx-fstfar", - "sherpa-onnx-fst", - "kaldi-native-fbank-core", - "piper_phonemize", - "espeak-ng", - "ucd", - "onnxruntime", - "ssentencepiece_core", - ] - } else { - // static without tts - &[ - "sherpa-onnx-c-api", - "sherpa-onnx-core", - "kaldi-decoder-core", - "sherpa-onnx-kaldifst-core", - "sherpa-onnx-fst", - "kaldi-native-fbank-core", - "onnxruntime", - "ssentencepiece_core", - ] - }; - // Prepare sherpa-onnx source if !sherpa_dst.exists() { copy_folder(&sherpa_src, &sherpa_dst); @@ -155,29 +169,15 @@ fn main() { // Search paths println!("cargo:rustc-link-search={}", out_dir.join("lib").display()); - println!("cargo:rustc-link-search=native={}", bindings_dir.display()); - - // Cuda - if cfg!(feature = "cuda") { - println!( - "cargo:rustc-link-search={}", - out_dir.join(format!("build/lib/{}", profile)).display() - ); - if cfg!(windows) { - println!( - "cargo:rustc-link-search=native={}", - out_dir.join("build/_deps/onnxruntime-src/lib").display() - ); - } - if cfg!(target_os = "linux") { - println!( - "cargo:rustc-link-search=native={}", - out_dir.join("build/lib").display() - ); - } - } + println!( + "cargo:rustc-link-search={}", + out_dir.join("build/lib").display() + ); + println!("cargo:rustc-link-search={}", bindings_dir.display()); // Link libraries + let sherpa_libs_kind = if build_shared_libs { "dylib" } else { "static" }; + let sherpa_libs = extract_lib_names(&out_dir); for lib in sherpa_libs { println!( "{}", @@ -190,13 +190,6 @@ fn main() { println!("cargo:rustc-link-lib=dylib=msvcrtd"); } - // Cuda - if cfg!(feature = "cuda") && cfg!(windows) { - println!("cargo:rustc-link-lib=static=onnxruntime_providers_cuda"); - println!("cargo:rustc-link-lib=static=onnxruntime_providers_shared"); - println!("cargo:rustc-link-lib=static=onnxruntime_providers_tensorrt"); - } - // macOS if cfg!(target_os = "macos") { println!("cargo:rustc-link-lib=framework=Foundation"); @@ -221,16 +214,13 @@ fn main() { // copy DLLs to target if build_shared_libs { - for entry in glob::glob(&format!( - "{}/*{}", - out_dir.join("lib").to_str().unwrap(), - shared_lib_suffix - )) - .unwrap() - .flatten() - { - let dst = target_dir.join(entry.file_name().unwrap()); - std::fs::copy(entry, dst).unwrap(); + let libs_assets = extract_lib_assets(&out_dir); + for asset in libs_assets { + let asset_clone = asset.clone(); + let filename = asset_clone.file_name().unwrap(); + let filename = filename.to_str().unwrap(); + let dst = target_dir.join(filename); + std::fs::copy(asset.clone(), dst).unwrap(); } } }