diff --git a/fireredasr2s/fireredasr2s_cli.py b/fireredasr2s/fireredasr2s_cli.py index 5912290..1d97813 100755 --- a/fireredasr2s/fireredasr2s_cli.py +++ b/fireredasr2s/fireredasr2s_cli.py @@ -9,16 +9,17 @@ import os import soundfile as sf -from textgrid import TextGrid, IntervalTier +from textgrid import IntervalTier, TextGrid from fireredasr2s.fireredasr2 import FireRedAsr2Config +from fireredasr2s.fireredasr2system import (FireRedAsr2System, + FireRedAsr2SystemConfig) from fireredasr2s.fireredlid import FireRedLidConfig from fireredasr2s.fireredpunc import FireRedPuncConfig from fireredasr2s.fireredvad import FireRedVadConfig -from fireredasr2s.fireredasr2system import FireRedAsr2System, FireRedAsr2SystemConfig logging.basicConfig(level=logging.INFO, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") logger = logging.getLogger("fireredasr2s.asr_system") @@ -86,7 +87,8 @@ def main(args): wavs = get_wav_info(args) - if args.outdir: os.makedirs(args.outdir, exist_ok=True) + if args.outdir: + os.makedirs(args.outdir, exist_ok=True) fout = open(args.outdir + "/result.jsonl", "w") if args.outdir else None # Build Models @@ -155,7 +157,8 @@ def main(args): save_segment_dir = os.path.join(args.outdir, "vad_segment") split_and_save_segment(wav_path, result["vad_segments_ms"], save_segment_dir) - if fout: fout.close() + if fout: + fout.close() logger.info("All Done") @@ -164,7 +167,7 @@ def get_wav_info(args): Returns: wavs: list of (uttid, wav_path) """ - base = lambda p: os.path.basename(p).replace(".wav", "") + def base(p): return os.path.basename(p).replace(".wav", "") if args.wav_path: wavs = [(base(args.wav_path), args.wav_path)] elif args.wav_paths and len(args.wav_paths) >= 1: @@ -258,10 +261,13 @@ def split_and_save_segment(wav_path, timestamps_ms, save_segment_dir): start = int(start_ms / 1000 * sample_rate) end = int(end_ms / 1000 * sample_rate) sf.write(seg_path, wav_np[start:end], samplerate=sample_rate) - -if __name__ == "__main__": +def cli_main(): args = parser.parse_args() logger.info(args) main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e0c9128 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,29 @@ +[build-system] +requires = ["setuptools>=64", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "fireredasr2s" +version = "0.0.1" +description = "Industrial-grade ASR system with VAD, LID, and punctuation prediction" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.9" +dependencies = [ + "torch", + "torchaudio", + "transformers", + "numpy", + "cn2an", + "kaldiio", + "kaldi_native_fbank", + "sentencepiece", + "soundfile", + "textgrid", +] + +[project.scripts] +fireredasr2s-cli = "fireredasr2s.fireredasr2s_cli:cli_main" + +[tool.setuptools.packages.find] +include = ["fireredasr2s*"]