Skip to content

Commit

Permalink
add ubuntu cuda and refactor build
Browse files Browse the repository at this point in the history
  • Loading branch information
thewh1teagle committed Jul 13, 2024
1 parent 84179d1 commit ee5f9c0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 43 deletions.
25 changes: 18 additions & 7 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,32 @@ jobs:
- platform: "windows-latest"
options: '--features "cuda"'
cuda-version: "12.5.0"
- platform: "ubuntu-22.04"
options: '--features "cuda"'
cuda-version: "12.5.0"

steps:
- uses: actions/checkout@v3
with:
submodules: "true"

- name: Setup cuda
- name: Setup cuda for Windows
run: scripts/setup_cuda.ps1
env:
INPUT_CUDA_VERSION: ${{ matrix.cuda-version }}
if: contains(matrix.options, 'cuda')
if: matrix.platform == 'windows-latest' && contains(matrix.options, 'cuda')

- name: Setup cuda for Ubuntu
uses: Jimver/cuda-toolkit@master
with:
cuda: "${{ matrix.cuda-version }}"
if: contains(matrix.platform, 'ubuntu') && contains(matrix.options, 'cuda')

- uses: Swatinem/rust-cache@v2
- name: Cache Rust
uses: Swatinem/rust-cache@v2

- uses: actions-rs/toolchain@v1
- name: Setup Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
override: true
Expand All @@ -44,8 +56,7 @@ jobs:
- name: Build
run: |
cargo build ${{ matrix.options }}
continue-on-error: true
- name: Find ONNX Runtime Libraries
- name: Find Runtime Libraries
if: matrix.platform == 'windows-latest'
run: |
C:\msys64\usr\bin\find -name "onnxruntime*.lib"
C:\msys64\usr\bin\find -name "*.lib"
73 changes: 37 additions & 36 deletions sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ fn main() {
let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("Failed to get CARGO_MANIFEST_DIR");
let sherpa_src = Path::new(&manifest_dir).join("sherpa-onnx");

// Prepare sherpa-onnx source
if !sherpa_dst.exists() {
std::fs::create_dir_all(&sherpa_dst).expect("Failed to create sherpa-onnx directory");

// There's some invalid files. better to use cp
#[cfg(unix)]
{
if cfg!(unix) {
std::process::Command::new("cp")
.arg("-rf")
.arg(sherpa_src.clone())
Expand All @@ -24,8 +24,7 @@ fn main() {
.expect("Failed to execute cp command");
}

#[cfg(windows)]
{
if cfg!(windows) {
std::process::Command::new("robocopy.exe")
.args(&[
"/e",
Expand All @@ -46,7 +45,8 @@ fn main() {
.to_string(),
);

// Set up bindgen builder
// Bindings

let bindings = bindgen::Builder::default()
.header("wrapper.h")
.clang_arg(format!("-I{}", sherpa_dst.display()))
Expand All @@ -63,6 +63,8 @@ fn main() {
println!("cargo:rerun-if-changed=wrapper.h");
println!("cargo:rerun-if-changed=./sherpa-onnx");

// Build with Cmake

let mut config = Config::new(&sherpa_dst);

config
Expand All @@ -80,63 +82,62 @@ fn main() {

// Cuda
// https://k2-fsa.github.io/k2/installation/cuda-cudnn.html
#[cfg(feature = "cuda")]
{
if cfg!(feature = "cuda") {
config.define("SHERPA_ONNX_ENABLE_GPU", "ON");
config.define("BUILD_SHARED_LIBS", "ON");
}

#[cfg(any(windows, target_os = "linux"))]
{
if cfg!(any(windows, target_os = "linux")) {
config.define("SHERPA_ONNX_ENABLE_PORTAUDIO", "ON");
}

let bindings_dir = config.very_verbose(true).build();

// Common
// Search paths
println!("cargo:rustc-link-search={}", out_dir.join("lib").display());
println!("cargo:rustc-link-search=native={}", bindings_dir.display());

if cfg!(feature = "cuda") && cfg!(windows) {
println!(
"cargo:rustc-link-search=native={}",
out_dir.join("build\\_deps\\onnxruntime-src\\lib").display()
);
}

// Link libraries

println!("cargo:rustc-link-lib=static=onnxruntime");

// Sherpa API
println!("cargo:rustc-link-lib=static=kaldi-native-fbank-core");
println!("cargo:rustc-link-lib=static=sherpa-onnx-core");
println!("cargo:rustc-link-lib=static=sherpa-onnx-c-api");
println!("cargo:rustc-link-lib=static=kaldi-decoder-core");
println!("cargo:rustc-link-lib=static=sherpa-onnx-kaldifst-core");
println!("cargo:rustc-link-lib=static=sherpa-onnx-fstfar");
println!("cargo:rustc-link-lib=static=ssentencepiece_core");

// Cuda
if cfg!(feature = "cuda") && cfg!(windows) {
println!("cargo:rustc-link-lib=static=onnxruntime_providers_cuda");
println!("cargo:rustc-link-lib=static=onnxruntime_providers_shared");
println!("cargo:rustc-link-lib=static=onnxruntime_providers_tensorrt");
}

// macOS
#[cfg(target_os = "macos")]
{
if cfg!(target_os = "macos") {
println!("cargo:rustc-link-lib=framework=Foundation");
println!("cargo:rustc-link-lib=c++");
println!("cargo:rustc-link-lib=static=sherpa-onnx-kaldifst-core");
println!("cargo:rustc-link-lib=static=kaldi-decoder-core");
println!("cargo:rustc-link-lib=static=sherpa-onnx-fst");
println!("cargo:rustc-link-lib=static=sherpa-onnx-fstfar");
println!("cargo:rustc-link-lib=static=ssentencepiece_core");
}

// Linux
#[cfg(target_os = "linux")]
{
if cfg!(target_os = "linux") {
println!("cargo:rustc-link-lib=dylib=stdc++");
}

// Linux and Windows
#[cfg(any(target_os = "linux", windows))]
{
println!("cargo:rustc-link-lib=static=kaldi-decoder-core");
println!("cargo:rustc-link-lib=static=sherpa-onnx-kaldifst-core");
println!("cargo:rustc-link-lib=static=sherpa-onnx-fst");
println!("cargo:rustc-link-lib=static=sherpa-onnx-fstfar");
println!("cargo:rustc-link-lib=static=ssentencepiece_core");
}

// TTS
#[cfg(feature = "tts")]
{
if cfg!(feature = "tts") {
println!("cargo:rustc-link-lib=static=espeak-ng");
println!("cargo:rustc-link-lib=static=kaldi-decoder-core");
println!("cargo:rustc-link-lib=static=sherpa-onnx-kaldifst-core");
println!("cargo:rustc-link-lib=static=sherpa-onnx-fst");
println!("cargo:rustc-link-lib=static=sherpa-onnx-fstfar");
println!("cargo:rustc-link-lib=static=ssentencepiece_core");
println!("cargo:rustc-link-lib=static=piper_phonemize");
println!("cargo:rustc-link-lib=static=ucd");
}
Expand Down

0 comments on commit ee5f9c0

Please sign in to comment.