From 30c9bb242268eee90c248aeca7cd70a4403c3e25 Mon Sep 17 00:00:00 2001
From: ZJaume
Date: Mon, 16 Sep 2024 09:48:34 +0000
Subject: [PATCH] Support for i/o files other than stdin/out

---
 src/cli.rs | 53 +++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 43 insertions(+), 10 deletions(-)

diff --git a/src/cli.rs b/src/cli.rs
index 13ea294..62034ea 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -1,7 +1,9 @@
-use std::io::{self, BufRead};
+use std::io::{self, BufRead, BufReader, Write, BufWriter};
+use std::fs::File;
 use std::path::{Path, PathBuf};
 use std::env;
 
+use anyhow::{Context, Result};
 use clap::{Parser, Subcommand, Args};
 use itertools::Itertools;
 use pyo3::prelude::*;
@@ -98,25 +100,55 @@ struct IdentifyCmd {
         default_value_t=100000,
         help="Number of text segments to pre-load for parallel processing")]
     batch_size: usize,
+
+    #[arg(help="Input file, default: stdin", )]
+    input_file: Option<PathBuf>,
+    #[arg(help="Output file, default: stdout", )]
+    output_file: Option<PathBuf>,
+}
+
+fn open_reader(p: &Path) -> Result<Box<dyn BufRead>> {
+    let file = File::open(&p)
+        .with_context(|| format!("Error opening input file {} for reading", p.display()))?;
+    Ok(Box::new(BufReader::new(file)))
+}
+
+fn open_writer(p: &Path) -> Result<Box<dyn Write>> {
+    let file = File::create(&p)
+        .with_context(|| format!("Error opening input file {} for writing", p.display()))?;
+    Ok(Box::new(BufWriter::new(file)))
 }
 
 impl IdentifyCmd {
     fn cli(self) -> PyResult<()> {
         let identifier = Identifier::load(&module_path().unwrap().to_str().unwrap())
             .or_abort(1);
+        let (input_file, output_file);
+        if let Some(p) = &self.input_file {
+            input_file = open_reader(&p).or_abort(1);
+        } else {
+            input_file = Box::new(io::stdin().lock());
+        }
+        if let Some(p) = &self.output_file {
+            output_file = open_writer(&p).or_abort(1);
+        } else {
+            output_file = Box::new(io::stdout().lock());
+        }
+
 
-        let stdin = io::stdin().lock();
         if self.threads == 0 {
-            return self.run_single(stdin, identifier)
+            self.run_single(identifier, input_file, output_file).or_abort(1);
         } else {
-            self.run_parallel(stdin, identifier)
+            self.run_parallel(identifier, input_file, output_file).or_abort(1);
         }
+        Ok(())
     }
 
     // Run using the parallel identification method
     // read in batches
-    fn run_parallel<F>(self, reader: F, identifier: Identifier) -> PyResult<()>
-        where F: BufRead
+    fn run_parallel<R, W>(self, identifier: Identifier, reader: R, mut writer: W) -> Result<()>
+        where R: BufRead,
+              W: Write,
     {
         // Initialize global thread pool with the number of threads
         // provided by the user
@@ -138,19 +170,20 @@ impl IdentifyCmd {
                 })
                 .collect();
             for b in identifier.par_identify(batch) {
-                println!("{}", b.0);
+                writeln!(writer, "{}", b.0)?;
             }
         }
         Ok(())
     }
 
     // Run using the single-threaded indetification method
-    fn run_single<F>(self, reader: F, mut identifier: Identifier) -> PyResult<()>
-        where F: BufRead
+    fn run_single<R, W>(self, mut identifier: Identifier, reader: R, mut writer: W) -> Result<()>
+        where R: BufRead,
+              W: Write,
     {
         // Process line by line
         for line in reader.lines() {
-            println!("{}", identifier.identify(&line?).0);
+            writeln!(writer, "{}", identifier.identify(&line?).0)?;
         }
         Ok(())
     }