Skip to content

Commit

Permalink
refactor build
Browse files Browse the repository at this point in the history
  • Loading branch information
thewh1teagle committed Aug 3, 2024
1 parent 143087b commit 8588097
Showing 1 changed file with 69 additions and 79 deletions.
148 changes: 69 additions & 79 deletions sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use cmake::Config;
use glob::glob;
use std::env;
use std::path::{Path, PathBuf};
use std::process::Command;
Expand All @@ -24,6 +25,60 @@ fn copy_folder(src: &Path, dst: &Path) {
}
}

fn extract_lib_names(out_dir: &Path) -> Vec<String> {
// Construct the pattern based on the target platform
let lib_suffix = if cfg!(windows) { "*.lib" } else { "*.a" };
let pattern = out_dir.join(format!("build/lib/{}", lib_suffix));

let mut lib_names = Vec::new();

// Process the libraries based on the pattern
for entry in glob(pattern.to_str().unwrap()).unwrap() {
match entry {
Ok(path) => {
let stem = path.file_stem().unwrap();
let stem_str = stem.to_str().unwrap();

// Remove the "lib" prefix if present
let lib_name = if stem_str.starts_with("lib") {
stem_str.strip_prefix("lib").unwrap_or(stem_str)
} else {
stem_str
};

lib_names.push(lib_name.to_string());
}
Err(e) => println!("cargo:warning=error={}", e),
}
}

lib_names
}

fn extract_lib_assets(out_dir: &Path) -> Vec<PathBuf> {
let shared_lib_suffix = if cfg!(windows) {
".dll"
} else if cfg!(target_os = "macos") {
".dylib"
} else {
".so"
};

let pattern = out_dir.join(format!("lib/{}", shared_lib_suffix));
let mut files = Vec::new();

for entry in glob(pattern.to_str().unwrap()).unwrap() {
match entry {
Ok(path) => {
files.push(path);
}
Err(e) => eprintln!("cargo:warning=error={}", e),
}
}

files
}

fn main() {
let target = env::var("TARGET").unwrap();
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
Expand All @@ -38,47 +93,6 @@ fn main() {
"Release"
};

let shared_lib_suffix = if cfg!(windows) {
".dll"
} else if cfg!(target_os = "macos") {
".dylib"
} else {
".so"
};
let sherpa_libs_kind = if build_shared_libs { "dylib" } else { "static" };
let sherpa_libs: &[&str] = if build_shared_libs {
// shared
&["sherpa-onnx-c-api", "onnxruntime"]
} else if cfg!(feature = "tts") {
// static with tts
&[
"sherpa-onnx-c-api",
"sherpa-onnx-core",
"kaldi-decoder-core",
"sherpa-onnx-kaldifst-core",
"sherpa-onnx-fstfar",
"sherpa-onnx-fst",
"kaldi-native-fbank-core",
"piper_phonemize",
"espeak-ng",
"ucd",
"onnxruntime",
"ssentencepiece_core",
]
} else {
// static without tts
&[
"sherpa-onnx-c-api",
"sherpa-onnx-core",
"kaldi-decoder-core",
"sherpa-onnx-kaldifst-core",
"sherpa-onnx-fst",
"kaldi-native-fbank-core",
"onnxruntime",
"ssentencepiece_core",
]
};

// Prepare sherpa-onnx source
if !sherpa_dst.exists() {
copy_folder(&sherpa_src, &sherpa_dst);
Expand Down Expand Up @@ -155,29 +169,15 @@ fn main() {

// Search paths
println!("cargo:rustc-link-search={}", out_dir.join("lib").display());
println!("cargo:rustc-link-search=native={}", bindings_dir.display());

// Cuda
if cfg!(feature = "cuda") {
println!(
"cargo:rustc-link-search={}",
out_dir.join(format!("build/lib/{}", profile)).display()
);
if cfg!(windows) {
println!(
"cargo:rustc-link-search=native={}",
out_dir.join("build/_deps/onnxruntime-src/lib").display()
);
}
if cfg!(target_os = "linux") {
println!(
"cargo:rustc-link-search=native={}",
out_dir.join("build/lib").display()
);
}
}
println!(
"cargo:rustc-link-search={}",
out_dir.join("build/lib").display()
);
println!("cargo:rustc-link-search={}", bindings_dir.display());

// Link libraries
let sherpa_libs_kind = if build_shared_libs { "dylib" } else { "static" };
let sherpa_libs = extract_lib_names(&out_dir);
for lib in sherpa_libs {
println!(
"{}",
Expand All @@ -190,13 +190,6 @@ fn main() {
println!("cargo:rustc-link-lib=dylib=msvcrtd");
}

// Cuda
if cfg!(feature = "cuda") && cfg!(windows) {
println!("cargo:rustc-link-lib=static=onnxruntime_providers_cuda");
println!("cargo:rustc-link-lib=static=onnxruntime_providers_shared");
println!("cargo:rustc-link-lib=static=onnxruntime_providers_tensorrt");
}

// macOS
if cfg!(target_os = "macos") {
println!("cargo:rustc-link-lib=framework=Foundation");
Expand All @@ -221,16 +214,13 @@ fn main() {

// copy DLLs to target
if build_shared_libs {
for entry in glob::glob(&format!(
"{}/*{}",
out_dir.join("lib").to_str().unwrap(),
shared_lib_suffix
))
.unwrap()
.flatten()
{
let dst = target_dir.join(entry.file_name().unwrap());
std::fs::copy(entry, dst).unwrap();
let libs_assets = extract_lib_assets(&out_dir);
for asset in libs_assets {
let asset_clone = asset.clone();
let filename = asset_clone.file_name().unwrap();
let filename = filename.to_str().unwrap();
let dst = target_dir.join(filename);
std::fs::copy(asset.clone(), dst).unwrap();
}
}
}
Expand Down

0 comments on commit 8588097

Please sign in to comment.