Skip to content

Commit

Permalink
Separate python bindings, cli and download logic into features
Browse files Browse the repository at this point in the history
This will allow easier use of heliport as a crate compiled with less
dependencies. Maturin build tested and running, but crate usage has not
been tested.
  • Loading branch information
ZJaume committed Sep 12, 2024
1 parent 7c12ec4 commit 151db84
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 173 deletions.
19 changes: 11 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,21 @@ env_logger = "0.10"
strum = { version = "0.25", features = ["derive"] }
strum_macros = "0.25"
wyhash2 = "0.2.1"
pyo3 = { version = "0.22", features = ["gil-refs"] }
target = "2.1.0"
tempfile = "3"
reqwest = { version = "0.12", features = ["stream"] }
tokio = { version = "1", features = ["io-util", "rt-multi-thread", "signal", "macros"] }
futures-util = "0.3"
clap = { version = "4.5", features = ["derive"] }
pyo3 = { version = "0.22", features = ["gil-refs"], optional = true }
target = { version = "2.1.0", optional = true }
tempfile = { version = "3", optional = true }
reqwest = { version = "0.12", features = ["stream"], optional = true }
tokio = { version = "1", features = ["io-util", "rt-multi-thread", "signal", "macros"], optional = true }
futures-util = { version = "0.3", optional = true }
clap = { version = "4.5", features = ["derive"], optional = true}
anyhow = "1.0"

[dev-dependencies]
test-log = "0.2.15"

[features]
# Put log features in default, to allow crates using heli as a library, disable them
default = ["log/max_level_debug", "log/release_max_level_debug"]
default = ["cli", "log/max_level_debug", "log/release_max_level_debug"]
cli = ["download", "python", "dep:clap", "dep:target"]
download = ["dep:tokio", "dep:tempfile", "dep:reqwest", "dep:futures-util"]
python = ["dep:pyo3"]
6 changes: 3 additions & 3 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use target;
use crate::languagemodel::{Model, ModelType};
use crate::identifier::Identifier;
use crate::utils::Abort;
use crate::module_path;
use crate::utils;
use crate::python::module_path;
use crate::download;

#[derive(Parser, Clone)]
#[command(version, about, long_about = None)]
Expand Down Expand Up @@ -80,7 +80,7 @@ impl DownloadCmd {
target::os(),
target::arch());

utils::download_file_and_extract(&url, download_path.to_str().unwrap()).unwrap();
download::download_file_and_extract(&url, download_path.to_str().unwrap()).unwrap();
info!("Finished");

Ok(())
Expand Down
97 changes: 97 additions & 0 deletions src/download.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use std::process::{exit, Command};
use std::fs;

use log::{info, warn, debug, error};
use tokio::io::AsyncWriteExt;
use tokio::runtime::Runtime;
use tokio::signal::unix;
use futures_util::StreamExt;
use tempfile::NamedTempFile;
use anyhow::{bail, Context, Result};
use reqwest;

// Run a listener for cancel signals, if received terminate
// if a filename is provided, delete it
async fn run_cancel_handler(filename: Option<String>) {
tokio::spawn(async move {
let mut sigint = unix::signal(unix::SignalKind::interrupt()).unwrap();
let mut sigterm = unix::signal(unix::SignalKind::terminate()).unwrap();
let mut sigalrm = unix::signal(unix::SignalKind::alarm()).unwrap();
let mut sighup = unix::signal(unix::SignalKind::hangup()).unwrap();
loop {
let kind;
tokio::select! {
_ = sigint.recv() => { kind = "SIGINT" },
_ = sigterm.recv() => { kind = "SIGTERM" },
_ = sigalrm.recv() => { kind = "SIGALRM" },
_ = sighup.recv() => { kind = "SIGHUP" },
else => break,
}
error!("Received {}, exiting", kind);
if let Some(f) = filename {
// panic if cannot be deleted?
debug!("Cleaning temp: {}", f);
if fs::remove_file(&f).is_err(){
warn!("Could not remove temporary file: {f}");
}
}
exit(1);
}
});
}

// Download a file to a path
async fn download_file_async(url: &str, filepath: &str) -> Result<()> {
info!("Downloading file from '{url}'");
// Create a download stream
let response = reqwest::get(url).await?;
let status = response.status();
debug!("Response status: {}", status);
if !status.is_success() {
error!("Could not download file, HTTP status code: {status}");
exit(1);
}

let mut response_stream = response.bytes_stream();
let mut outfile = tokio::fs::File::create(filepath).await?;

debug!("Writing file to '{filepath}'");
// asyncronously write to the file every piece of bytes that come from the stream
while let Some(bytes) = response_stream.next().await {
outfile.write_all(&bytes?).await?;
}

Ok(())
}

// Download a .tgz file and extract it, async version
async fn download_file_and_extract_async(url: &str, extractpath: &str) -> Result<()> {
let binding = NamedTempFile::new()?.into_temp_path();
let temp_path = binding
.to_str()
.context("Error converting tempfile name to string")?;
run_cancel_handler(Some(String::from(temp_path))).await;
download_file_async(url, &temp_path).await?;

let mut command = Command::new("/bin/tar");
command.args(["xvfm", temp_path, "-C", extractpath, "--strip-components", "1"]);
debug!("Running command {:?}", command.get_args());
let comm_output = command.output()?;
debug!("Command status: {:?}", comm_output.status);
// If the command fails, return an error, containing command stderr output
if !comm_output.status.success() {
let stderr_out = String::from_utf8_lossy(&comm_output.stderr);
bail!("Command failed during execution: {stderr_out}");
}
debug!("Command stderr: {}", std::str::from_utf8(&comm_output.stderr)?);
debug!("Command stdout: {}", std::str::from_utf8(&comm_output.stdout)?);
Ok(())
}

// Download a .tgz file and extract it, call async version and block on it
pub fn download_file_and_extract(url: &str, extractpath: &str) -> Result<()> {
let runtime = Runtime::new()?;
runtime.block_on(download_file_and_extract_async(url, extractpath))
}


72 changes: 6 additions & 66 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,70 +1,10 @@
use std::path::PathBuf;
use std::env;

use pyo3::prelude::*;

use crate::identifier::Identifier;
use crate::cli::cli_run;
use crate::utils::Abort;

pub mod languagemodel;
pub mod identifier;
pub mod lang;
mod utils;
#[cfg(feature = "download")]
pub mod download;
pub mod utils;
#[cfg(feature = "cli")]
mod cli;

// Call python interpreter and obtain python path of our module
pub fn module_path() -> PyResult<PathBuf> {
let mut path = PathBuf::new();
Python::with_gil(|py| {
// Instead of hardcoding the module name, obtain it from the crate name at compile time
let module = PyModule::import_bound(py, env!("CARGO_PKG_NAME"))?;
let paths: Vec<&str> = module
.getattr("__path__")?
.extract()?;
// __path__ attribute returns a list of paths, return first
path.push(paths[0]);
Ok(path)
})
}


/// Bindings to Python
#[pyclass(name = "Identifier")]
pub struct PyIdentifier {
inner: Identifier,
}

#[pymethods]
impl PyIdentifier {
#[new]
fn new() -> PyResult<Self> {
let modulepath = module_path().expect("Error loading python module path");
let identifier = Identifier::load(&modulepath.to_str().unwrap())
.or_abort(1);

Ok(Self {
inner: identifier,
})
}

fn identify(&mut self, text: &str) -> String {
self.inner.identify(text).0.to_string()
}
}

// #[pyclass(name = "Lang")]
// pub struct PyLang {
// inner: Lang,
// }



#[pymodule]
fn heliport(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(cli_run))?;
m.add_class::<PyIdentifier>()?;
// m.add_class::<PyLang>()?;

Ok(())
}
#[cfg(feature = "python")]
mod python;
63 changes: 63 additions & 0 deletions src/python.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use std::path::PathBuf;
use std::env;

use pyo3::prelude::*;

#[cfg(feature = "cli")]
use crate::cli::cli_run;
use crate::utils::Abort;
use crate::identifier::Identifier;

// Call python interpreter and obtain python path of our module
pub fn module_path() -> PyResult<PathBuf> {
let mut path = PathBuf::new();
Python::with_gil(|py| {
// Instead of hardcoding the module name, obtain it from the crate name at compile time
let module = PyModule::import_bound(py, env!("CARGO_PKG_NAME"))?;
let paths: Vec<&str> = module
.getattr("__path__")?
.extract()?;
// __path__ attribute returns a list of paths, return first
path.push(paths[0]);
Ok(path)
})
}

/// Bindings to Python
#[pyclass(name = "Identifier")]
pub struct PyIdentifier {
inner: Identifier,
}

#[pymethods]
impl PyIdentifier {
#[new]
fn new() -> PyResult<Self> {
let modulepath = module_path().expect("Error loading python module path");
let identifier = Identifier::load(&modulepath.to_str().unwrap())
.or_abort(1);

Ok(Self {
inner: identifier,
})
}

fn identify(&mut self, text: &str) -> String {
self.inner.identify(text).0.to_string()
}
}

// #[pyclass(name = "Lang")]
// pub struct PyLang {
// inner: Lang,
// }

#[pymodule]
fn heliport(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
#[cfg(feature = "cli")]
m.add_wrapped(wrap_pyfunction!(cli_run))?;
m.add_class::<PyIdentifier>()?;
// m.add_class::<PyLang>()?;

Ok(())
}
98 changes: 2 additions & 96 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,106 +1,12 @@
use std::process::{exit, Command};
use std::fs;

use log::{info, warn, debug, error};
use tokio::io::AsyncWriteExt;
use tokio::runtime::Runtime;
use tokio::signal::unix;
use futures_util::StreamExt;
use tempfile::NamedTempFile;
use anyhow::{bail, Context, Result};
use reqwest;

// Run a listener for cancel signals, if received terminate
// if a filename is provided, delete it
async fn run_cancel_handler(filename: Option<String>) {
tokio::spawn(async move {
let mut sigint = unix::signal(unix::SignalKind::interrupt()).unwrap();
let mut sigterm = unix::signal(unix::SignalKind::terminate()).unwrap();
let mut sigalrm = unix::signal(unix::SignalKind::alarm()).unwrap();
let mut sighup = unix::signal(unix::SignalKind::hangup()).unwrap();
loop {
let kind;
tokio::select! {
_ = sigint.recv() => { kind = "SIGINT" },
_ = sigterm.recv() => { kind = "SIGTERM" },
_ = sigalrm.recv() => { kind = "SIGALRM" },
_ = sighup.recv() => { kind = "SIGHUP" },
else => break,
}
error!("Received {}, exiting", kind);
if let Some(f) = filename {
// panic if cannot be deleted?
debug!("Cleaning temp: {}", f);
if fs::remove_file(&f).is_err(){
warn!("Could not remove temporary file: {f}");
}
}
exit(1);
}
});
}

// Download a file to a path
async fn download_file_async(url: &str, filepath: &str) -> Result<()> {
info!("Downloading file from '{url}'");
// Create a download stream
let response = reqwest::get(url).await?;
let status = response.status();
debug!("Response status: {}", status);
if !status.is_success() {
error!("Could not download file, HTTP status code: {status}");
exit(1);
}

let mut response_stream = response.bytes_stream();
let mut outfile = tokio::fs::File::create(filepath).await?;

debug!("Writing file to '{filepath}'");
// asyncronously write to the file every piece of bytes that come from the stream
while let Some(bytes) = response_stream.next().await {
outfile.write_all(&bytes?).await?;
}

Ok(())
}

// Download a .tgz file and extract it, async version
async fn download_file_and_extract_async(url: &str, extractpath: &str) -> Result<()> {
let binding = NamedTempFile::new()?.into_temp_path();
let temp_path = binding
.to_str()
.context("Error converting tempfile name to string")?;
run_cancel_handler(Some(String::from(temp_path))).await;
download_file_async(url, &temp_path).await?;

let mut command = Command::new("/bin/tar");
command.args(["xvfm", temp_path, "-C", extractpath, "--strip-components", "1"]);
debug!("Running command {:?}", command.get_args());
let comm_output = command.output()?;
debug!("Command status: {:?}", comm_output.status);
// If the command fails, return an error, containing command stderr output
if !comm_output.status.success() {
let stderr_out = String::from_utf8_lossy(&comm_output.stderr);
bail!("Command failed during execution: {stderr_out}");
}
debug!("Command stderr: {}", std::str::from_utf8(&comm_output.stderr)?);
debug!("Command stdout: {}", std::str::from_utf8(&comm_output.stdout)?);
Ok(())
}

// Download a .tgz file and extract it, call async version and block on it
pub fn download_file_and_extract(url: &str, extractpath: &str) -> Result<()> {
let runtime = Runtime::new()?;
runtime.block_on(download_file_and_extract_async(url, extractpath))
}
use std::process::exit;
use log::error;

// Trait that extracts the contained ok value or aborts if error
// sending the error message to the log
pub trait Abort<T> {
fn or_abort(self, exit_code: i32) -> T;
}


impl<T, E: std::fmt::Display> Abort<T> for Result<T, E>
{
fn or_abort(self, exit_code: i32) -> T {
Expand Down

0 comments on commit 151db84

Please sign in to comment.