-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathrun_inference_for_challenge.py
38 lines (30 loc) · 1.59 KB
/
run_inference_for_challenge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# ==============================================================================
# Copyright (c) 2021, Yamagishi Laboratory, National Institute of Informatics
# Author: Erica Cooper
# All rights reserved.
# ==============================================================================
## Download a pretrained model and run inference with it, producing an
## answer_main.txt file that can be submitted to the challenge server.
import argparse
import os
import shutil
import subprocess
import sys
import tarfile
import urllib.request


# Download locations of the model files used below.
_BASE_MODEL_URL = 'https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt'
_FAIRSEQ_LICENSE_URL = 'https://raw.githubusercontent.com/pytorch/fairseq/main/LICENSE'
_CHECKPOINT_URL = 'https://zenodo.org/record/6785056/files/ckpt_w2vsmall.tar.gz'

# Local paths, relative to the current working directory.
_BASE_MODEL = os.path.join('fairseq', 'wav2vec_small.pt')
_FAIRSEQ_LICENSE = os.path.join('fairseq', 'LICENSE')
_FINETUNED_CKPT = os.path.join('pretrained', 'ckpt_w2vsmall')


def _download(url, dest_dir='.'):
    """Download *url* into *dest_dir* under the URL's basename; return the path.

    The data is written to a ``.part`` file that is renamed only once the
    transfer succeeds, so an interrupted download never leaves a truncated
    file that a later run's existence check would treat as complete.
    """
    os.makedirs(dest_dir, exist_ok=True)
    dest = os.path.join(dest_dir, url.rsplit('/', 1)[-1])
    partial = dest + '.part'
    try:
        urllib.request.urlretrieve(url, partial)
    except BaseException:
        # Clean up on any failure (including Ctrl-C), then propagate it.
        if os.path.exists(partial):
            os.remove(partial)
        raise
    os.replace(partial, dest)
    return dest


def _extract_tar_gz(archive):
    """Unpack the gzip-compressed tarball *archive* into the current directory."""
    with tarfile.open(archive, 'r:gz') as tar:
        if hasattr(tarfile, 'data_filter'):
            # Python 3.12+ (and security backports): reject members that would
            # be written outside the target directory (absolute paths, '..').
            tar.extractall(filter='data')
        else:
            tar.extractall()


def _fetch_base_model():
    """Step 1: download the fairseq wav2vec_small base model and its LICENSE."""
    if os.path.exists(_BASE_MODEL):
        return
    # Fetch the LICENSE first: the model file is this step's "done" marker,
    # so it must be the last thing written, or a failed LICENSE download
    # would never be retried.
    _download(_FAIRSEQ_LICENSE_URL, 'fairseq')
    _download(_BASE_MODEL_URL, 'fairseq')


def _fetch_finetuned_checkpoint():
    """Step 2: download and unpack the finetuned checkpoint into pretrained/."""
    if os.path.exists(_FINETUNED_CKPT):
        return
    os.makedirs('pretrained', exist_ok=True)
    archive = _download(_CHECKPOINT_URL)
    try:
        _extract_tar_gz(archive)
    finally:
        os.remove(archive)
    # Keep the fairseq LICENSE next to the checkpoint. This is best effort:
    # the file only exists if step 1 downloaded it.
    if os.path.exists(_FAIRSEQ_LICENSE):
        shutil.copy(_FAIRSEQ_LICENSE, 'pretrained')
    else:
        print('warning: fairseq/LICENSE not found; not copied to pretrained/',
              file=sys.stderr)
    # Assumes the archive unpacks to ./ckpt_w2vsmall (as the original `mv`
    # step did). This move is the step's "done" marker, so it comes last.
    shutil.move('ckpt_w2vsmall', _FINETUNED_CKPT)


def main(argv=None):
    """Fetch the model files if missing, then run inference for the challenge.

    1. Download the fairseq wav2vec_small base model and the fairseq LICENSE
       into ``fairseq/``, unless ``fairseq/wav2vec_small.pt`` already exists.
    2. Download and unpack the finetuned checkpoint to
       ``pretrained/ckpt_w2vsmall``, unless that path already exists.
    3. Run ``predict.py`` with the current interpreter. It is told to write
       ``answer_main.txt``.

    All paths are relative to the current working directory.

    Args:
        argv: Arguments to parse. ``None`` (the default) means ``sys.argv[1:]``.

    Raises:
        OSError: A download (``urllib.error.URLError``) or a file operation
            failed.
        subprocess.CalledProcessError: ``predict.py`` exited with a non-zero
            status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--datadir', type=str, required=True, help='Path of your DATA/ directory')
    args = parser.parse_args(argv)

    ## 1. download the base model from fairseq
    _fetch_base_model()
    ## 2. download the finetuned checkpoint
    _fetch_finetuned_checkpoint()
    ## 3. run inference
    # Arguments go as a list (no shell), so a --datadir with spaces or shell
    # metacharacters is passed verbatim. check=True makes a failed prediction
    # abort the script instead of being silently ignored.
    subprocess.run(
        [
            sys.executable, 'predict.py',
            '--fairseq_base_model', _BASE_MODEL,
            '--outfile', 'answer_main.txt',
            '--finetuned_checkpoint', _FINETUNED_CKPT,
            '--datadir', args.datadir,
        ],
        check=True,
    )
# Run only when executed as a script, so importing this module does not
# start downloads or parse the importer's command line.
if __name__ == '__main__':
    main()