ahclab/Wav2VecSegmenter

Wav2VecSegmenter

This repository contains the code for the paper: Improving Speech Translation Accuracy and Time Efficiency with Fine-tuned wav2vec 2.0-based Speech Segmentation

Part of the code is based on SHAS.

Abstract

Speech translation (ST) automatically converts utterances in a source language into text in another language. Splitting continuous speech into shorter segments, known as speech segmentation, plays an important role in ST. Recent segmentation methods trained to mimic the segmentation of ST corpora have surpassed traditional approaches. Tsiamas et al. [1] proposed a segmentation frame classifier (SFC) based on a pre-trained speech encoder called wav2vec 2.0. Their method, named SHAS, retains 95-98% of the BLEU score obtained with the ST corpus segmentation. However, the segments generated by SHAS are very different from the ST corpus segmentation and tend to be longer, with multiple utterances combined. This is due to SHAS's reliance on length heuristics, i.e., it splits speech into segments of easily translatable length without fully considering the potential for ST improvement by splitting them into even shorter segments. Longer segments often degrade translation quality and ST's time efficiency. In this study, we extended SHAS to improve ST accuracy and efficiency by splitting speech into shorter segments that correspond to sentences. We introduced a simple segmentation algorithm using the moving average of SFC predictions without relying on length heuristics and explored wav2vec 2.0 fine-tuning for improved speech segmentation prediction. Our experimental results reveal that our speech segmentation method significantly improved the quality and the time efficiency of speech translation compared to SHAS.
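The moving-average idea above can be illustrated with a small sketch: smooth the per-frame probabilities produced by the SFC, then treat each run of frames above a threshold as one segment. This is not the repository's implementation. The function names, the rate of 50 frames per second (one wav2vec 2.0 frame per 20 ms), and reading the window size in seconds are assumptions; the pTHR algorithm used in the experiments also takes `max_segment_length`, `max_lerp_range`, and `min_lerp_range`, which are not modeled here.

```python
from typing import List, Optional, Tuple

FRAMES_PER_SECOND = 50  # assumption: one wav2vec 2.0 frame per 20 ms


def moving_average(probs: List[float], window_frames: int) -> List[float]:
    """Centered moving average over 2 * (window_frames // 2) + 1 frames.

    Windows are truncated at the edges; window_frames <= 1 returns a copy of the input.
    """
    if window_frames <= 1:
        return list(probs)
    half = window_frames // 2
    smoothed = []
    for i in range(len(probs)):
        lo = max(0, i - half)
        hi = min(len(probs), i + half + 1)
        smoothed.append(sum(probs[lo:hi]) / (hi - lo))
    return smoothed


def threshold_segments(
    probs: List[float], threshold: float, window_sec: float
) -> List[Tuple[float, float]]:
    """Hypothetical sketch: (start, end) seconds of each run of smoothed probs above threshold."""
    window_frames = int(round(window_sec * FRAMES_PER_SECOND))
    smoothed = moving_average(probs, window_frames)
    segments: List[Tuple[float, float]] = []
    start: Optional[int] = None
    for i, p in enumerate(smoothed):
        if p > threshold and start is None:
            start = i  # a segment opens at the first frame above the threshold
        elif p <= threshold and start is not None:
            segments.append((start / FRAMES_PER_SECOND, i / FRAMES_PER_SECOND))
            start = None  # the segment closes at the first frame back at or below it
    if start is not None:  # a segment still open at the end of the audio
        segments.append((start / FRAMES_PER_SECOND, len(smoothed) / FRAMES_PER_SECOND))
    return segments


# Two speech runs separated by a 2-frame (40 ms) dip in probability.
probs = [0.9] * 20 + [0.0] * 2 + [0.9] * 20
print(threshold_segments(probs, threshold=0.5, window_sec=0))    # [(0.0, 0.4), (0.44, 0.84)]
print(threshold_segments(probs, threshold=0.5, window_sec=0.2))  # [(0.0, 0.84)]
```

With `window_sec=0` no smoothing is applied; a wider window averages over short dips, so a brief pause no longer triggers a split.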

Setup

git clone git@github.com:ahclab/Wav2VecSegmenter.git
cd Wav2VecSegmenter

# install requirements
pip install -r requirements.txt

# install SHAS, fairseq, and mwerSegmenter
bash runs/setup_tools.sh

Example: Segmentation of MuST-C En-De tst-COMMON

Download data

Download MuST-C v2 en-de to $MUSTC_ROOT from here. Environment variables are listed in runs/path.sh.

Download a segmentation model

Download a pretrained model (large+all in the paper) and a config file to $SEG_MODEL_PATH from here.

Segmentation

Segment the MuST-C En-De tst-COMMON set with the segmentation model:

SEG_MODEL_PATH=${PWD}/models/segmentation

ckpt_path=${SEG_MODEL_PATH}/large+all/lna_l24_ft24/ckpts/epoch-15_best_eval_f1.pt
config_path=${SEG_MODEL_PATH}/large+all/.hydra/config.yaml
python segment.py \
  ckpt_path=${ckpt_path} \
  config_path=${config_path} \
  output_dir=results/mustc_ende_tst-COMMON

A custom segmentation yaml file is saved to results/mustc_ende_tst-COMMON/custom_segments.yaml.
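The segment file presumably follows the MuST-C segmentation format, a YAML list with one flow mapping per segment. As a rough sketch, a list of (start, end) times in seconds could be written in that style with the standard library alone. The helper name `to_mustc_yaml` and the `NA` speaker placeholder are hypothetical, and only the core MuST-C fields (`duration`, `offset`, `speaker_id`, `wav`) are shown; the exact fields and precision written by `segment.py` may differ.

```python
from typing import List, Tuple


def to_mustc_yaml(
    segments: List[Tuple[float, float]], wav: str, speaker_id: str = "NA"
) -> str:
    """Hypothetical helper: render (start, end) seconds as MuST-C-style YAML entries."""
    if not segments:
        return "[]\n"  # an empty YAML list
    lines = []
    for start, end in segments:
        lines.append(
            f"- {{duration: {end - start:.6f}, offset: {start:.6f}, "
            f"speaker_id: {speaker_id}, wav: {wav}}}"
        )
    return "\n".join(lines) + "\n"


print(to_mustc_yaml([(0.2, 0.7), (0.8, 1.0)], "ted_1.wav"), end="")
# - {duration: 0.500000, offset: 0.200000, speaker_id: NA, wav: ted_1.wav}
# - {duration: 0.200000, offset: 0.800000, speaker_id: NA, wav: ted_1.wav}
```

MuST-C keeps all talks of a split in a single yaml file, so in practice the entries for every wav would be concatenated.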

We use Hydra to manage configurations. See conf/segment.yaml for details.

Specify the data to be segmented with infer_data. To run segment.py, the corresponding infer_data/*.yaml file must contain the following entries:

  • wav_dir: path to the directory containing wav files
  • orig_seg_yaml: path to the original segmentation yaml file (MuST-C format)
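Before running segment.py, it can help to check that both entries point at something usable. The function below is a hypothetical pre-flight check, not part of the repository. It takes the two values as a plain dict (for instance after loading the yaml file with a YAML library) and reports missing keys, a wav_dir without .wav files, or a missing orig_seg_yaml.

```python
from pathlib import Path
from typing import Dict, List

REQUIRED_KEYS = ("wav_dir", "orig_seg_yaml")


def check_infer_data(config: Dict[str, str]) -> List[str]:
    """Hypothetical check of an infer_data config; an empty list means no problems found."""
    problems = []
    for key in REQUIRED_KEYS:
        if key not in config:
            problems.append(f"missing key: {key}")
    if "wav_dir" in config:
        wav_dir = Path(config["wav_dir"])
        if not wav_dir.is_dir():
            problems.append(f"wav_dir is not a directory: {wav_dir}")
        elif not any(wav_dir.glob("*.wav")):
            problems.append(f"no .wav files in {wav_dir}")
    if "orig_seg_yaml" in config and not Path(config["orig_seg_yaml"]).is_file():
        problems.append(f"orig_seg_yaml not found: {config['orig_seg_yaml']}")
    return problems


print(check_infer_data({}))  # ['missing key: wav_dir', 'missing key: orig_seg_yaml']
```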

Reproduce results in the paper

Prepare training data

bash runs/prep_mustc.sh

Train a segmentation model

The following commands are examples of training segmentation models with different configurations. See the paper and conf/train.yaml for details.

  • middle (0/16)
    python train.py \
      batch_size=4 save_ckpts=True \
      exp_name=lna_l16_ft0 \
      data=mustc_ende \
      task.model.finetune_wav2vec=False \
      task.model.wav2vec_keep_layers=16
    
  • large+all (24/24)
    python train.py \
      batch_size=4 save_ckpts=True \
      exp_name=lna_l24_ft24 \
      data=mustc_ende \
      task.model.finetune_wav2vec=True \
      task.model.wav2vec_keep_layers=24 \
      task.model.wav2vec_ft_layers=24
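Judging from the experiment names (`lna_l16_ft0` for middle (0/16), `lna_l24_ft24` for large+all (24/24)), the labels read as (fine-tuned layers / kept layers) of wav2vec 2.0. The sketch below shows one way the three options could map to a layer plan. It assumes, without confirmation from the source, that layers above `wav2vec_keep_layers` are dropped and that the top `wav2vec_ft_layers` of the kept layers are the ones fine-tuned; the function `layer_plan` is hypothetical.

```python
from typing import Dict, List


def layer_plan(keep_layers: int, ft_layers: int, finetune: bool) -> Dict[str, List[int]]:
    """Split kept Transformer layers (1-indexed) into frozen and fine-tuned groups.

    Assumption: the top ft_layers of the kept layers are trained when finetune is True.
    """
    if not 0 <= ft_layers <= keep_layers:
        raise ValueError("ft_layers must be between 0 and keep_layers")
    n_ft = ft_layers if finetune else 0  # finetune=False freezes every kept layer
    kept = list(range(1, keep_layers + 1))
    return {"frozen": kept[: keep_layers - n_ft], "finetuned": kept[keep_layers - n_ft :]}


# middle (0/16): 16 layers kept, none fine-tuned
print(layer_plan(16, 0, finetune=False)["finetuned"])  # []
# large+all (24/24): all 24 layers kept and fine-tuned
print(len(layer_plan(24, 24, finetune=True)["finetuned"]))  # 24
```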
    

When running train.py, a directory outputs/yyyy-mm-dd/hh-mm-ss/ is created by Hydra, and the configuration files and models are stored there.
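Because Hydra names these directories by date and time with zero-padded fields, the newest run sorts last. A hypothetical helper to locate it might look like this (the `outputs` default mirrors the path above; the function name is an assumption):

```python
import re
from pathlib import Path
from typing import Optional

DATE_DIR = re.compile(r"\d{4}-\d{2}-\d{2}")  # yyyy-mm-dd
TIME_DIR = re.compile(r"\d{2}-\d{2}-\d{2}")  # hh-mm-ss


def latest_run_dir(outputs_root: str = "outputs") -> Optional[Path]:
    """Return the newest outputs/yyyy-mm-dd/hh-mm-ss/ directory, or None if there is none."""
    runs = [
        p
        for p in Path(outputs_root).glob("*/*")
        if p.is_dir() and DATE_DIR.fullmatch(p.parent.name) and TIME_DIR.fullmatch(p.name)
    ]
    # zero-padded names sort chronologically, so the largest (date, time) pair is the newest
    return max(runs, key=lambda p: (p.parent.name, p.name), default=None)


print(latest_run_dir())
```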

Models used in the paper can be downloaded from the following links.

wandb is used for logging. To disable wandb, set log_wandb=False.

Download a pre-trained speech translation model

bash runs/prep_s2t_mustc.sh

Segment, translate, and evaluate

The following commands reproduce the results of large+all (24/24). Extract the downloaded model to ${EXP_PATH}. Results are stored in ${EXP_PATH}/infer_outputs.
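In these commands, `-m` starts a Hydra multirun, and every comma-separated override value is swept: Hydra's basic sweeper launches one job per combination (the Cartesian product). The dev search for pDAC and pSTRM therefore runs 10 jobs each, and the pTHR(+MA) search runs 9 thresholds × 7 windows = 63 jobs. The sketch below reproduces that counting; `expand_sweep` is a simplified stand-in that splits on commas only and does not handle Hydra's quoting or range syntax.

```python
from itertools import product
from typing import Dict, List


def expand_sweep(overrides: Dict[str, str]) -> List[Dict[str, str]]:
    """Expand comma-separated override values into every combination (Cartesian product)."""
    keys = list(overrides)
    choices = [overrides[k].split(",") for k in keys]
    return [dict(zip(keys, combo)) for combo in product(*choices)]


pthr_grid = expand_sweep({
    "algorithm.max_segment_length": "28",
    "algorithm.threshold": "0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9",
    "algorithm.moving_average_window": "0,0.1,0.2,0.4,0.6,0.8,1",
})
print(len(pthr_grid))  # 63
```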

  • Parameter search on the dev set

    export PYTHONPATH=${PWD}/tools/fairseq
    EXP_PATH=${PWD}/outputs/large+all
    ckpt_name=epoch-15_best_eval_f1.pt
    
    # pDAC
    python inference_st_pipe.py -m \
      outputs=${EXP_PATH} ckpt=${ckpt_name} \
      log_wandb=False \
      infer_data=mustc_ende_dev batch_size=14 \
      algorithm=dac \
      algorithm.max_segment_length=10,12,14,16,18,20,22,24,26,28 \
      algorithm.threshold=0.5
    
    # pSTRM
    python inference_st_pipe.py -m \
      outputs=${EXP_PATH} ckpt=${ckpt_name} \
      log_wandb=False \
      infer_data=mustc_ende_dev batch_size=14 \
      algorithm=strm \
      algorithm.max_segment_length=10,12,14,16,18,20,22,24,26,28 \
      algorithm.threshold=0.5
    
    # pTHR(+MA)
    python inference_st_pipe.py -m \
      outputs=${EXP_PATH} ckpt=${ckpt_name} \
      log_wandb=False \
      infer_data=mustc_ende_dev batch_size=14 \
      algorithm=pthr \
      algorithm.max_segment_length=28 \
      algorithm.max_lerp_range=4 \
      algorithm.min_lerp_range=0.4 \
      algorithm.threshold=0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9 \
      algorithm.moving_average_window=0,0.1,0.2,0.4,0.6,0.8,1
    
    
    large+all   BLEU   # segments   Best config
    pDAC        26.3   1391         max_segment_length=10
    pSTRM       26.2    796         max_segment_length=20
    pTHR        26.8   1335         threshold=0.1, moving_average_window=0
    pTHR+MA     26.9   1264         threshold=0.1, moving_average_window=0.1
  • Evaluation on tst-COMMON

    export PYTHONPATH=${PWD}/tools/fairseq
    EXP_PATH=${PWD}/outputs/large+all
    ckpt_name=epoch-15_best_eval_f1.pt
    
    # pDAC
    python inference_st_pipe.py -m \
      outputs=${EXP_PATH} ckpt=${ckpt_name} \
      log_wandb=False \
      batch_size=14 \
      algorithm=dac \
      algorithm.max_segment_length=10 \
      algorithm.threshold=0.5
    
    # pSTRM
    python inference_st_pipe.py -m \
      outputs=${EXP_PATH} ckpt=${ckpt_name} \
      log_wandb=False \
      batch_size=14 \
      algorithm=strm \
      algorithm.max_segment_length=20 \
      algorithm.threshold=0.5
    
    # pTHR(+MA)
    python inference_st_pipe.py -m \
      outputs=${EXP_PATH} ckpt=${ckpt_name} \
      log_wandb=False \
      batch_size=14 \
      algorithm=pthr \
      algorithm.max_segment_length=28 \
      algorithm.max_lerp_range=4 \
      algorithm.min_lerp_range=0.4 \
      algorithm.threshold=0.1 \
      algorithm.moving_average_window=0,0.1
      
    
    large+all   BLEU   # segments   Config
    pDAC        25.9   2279         max_segment_length=10
    pSTRM       25.7   1292         max_segment_length=20
    pTHR        26.3   2149         threshold=0.1, moving_average_window=0
    pTHR+MA     26.3   2044         threshold=0.1, moving_average_window=0.1
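pDAC, the SHAS algorithm compared in the tables above, follows a divide-and-conquer length heuristic: a segment longer than `max_segment_length` is split at its lowest-probability frame, and the halves are processed the same way until every piece fits. The following is a simplified sketch of that idea (written with an explicit stack rather than recursion), not the SHAS code. It assumes 50 frames per second, keeps segments contiguous, and does not model `algorithm.threshold`, which the real algorithm also uses; `pdac_split` is a hypothetical name.

```python
from typing import List, Tuple

FRAMES_PER_SECOND = 50  # assumption: one wav2vec 2.0 frame per 20 ms


def pdac_split(probs: List[float], max_segment_length: float) -> List[Tuple[float, float]]:
    """Simplified divide-and-conquer sketch returning contiguous (start, end) segments in seconds.

    Any segment longer than max_segment_length is cut at its lowest-probability interior
    frame, and both halves are processed again until every segment fits.
    """
    max_frames = max(1, int(max_segment_length * FRAMES_PER_SECOND))
    done: List[Tuple[int, int]] = []
    stack = [(0, len(probs))] if probs else []
    while stack:  # explicit stack instead of recursion, so long recordings are safe
        start, end = stack.pop()
        if end - start <= max_frames:
            done.append((start, end))
            continue
        cut = min(range(start + 1, end), key=lambda i: probs[i])
        stack.append((cut, end))
        stack.append((start, cut))
    done.sort()
    return [(s / FRAMES_PER_SECOND, e / FRAMES_PER_SECOND) for s, e in done]


probs = [0.9] * 30 + [0.1] + [0.9] * 30 + [0.05] + [0.9] * 38  # 100 frames = 2 s
print(pdac_split(probs, max_segment_length=1.0))
# [(0.0, 0.6), (0.6, 1.22), (1.22, 2.0)]
```

Note that the cut points depend only on where the probabilities dip and on the length limit, which is why pDAC's segment count in the tables is driven mainly by `max_segment_length`.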

Citation

Ryo Fukuda, Katsuhito Sudoh and Satoshi Nakamura, "Improving Speech Translation Accuracy and Time Efficiency with Fine-tuned wav2vec 2.0-based Speech Segmentation," in IEEE/ACM Transactions on Audio, Speech, and Language Processing, doi: 10.1109/TASLP.2023.3343614.

@ARTICLE{10361556,
  author={Fukuda, Ryo and Sudoh, Katsuhito and Nakamura, Satoshi},
  journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing}, 
  title={Improving Speech Translation Accuracy and Time Efficiency with Fine-tuned wav2vec 2.0-based Speech Segmentation}, 
  year={2023},
  volume={},
  number={},
  pages={1-12},
  doi={10.1109/TASLP.2023.3343614}}

Contact

If you have any questions about the code in this repository, please contact Ryo Fukuda via email: fukuda.ryo.fo3[at]is.naist.jp
