-
Notifications
You must be signed in to change notification settings - Fork 0
/
apply_attack_semiinformed.py
55 lines (40 loc) · 1.67 KB
/
apply_attack_semiinformed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Help text for this command-line script; used as the ArgumentParser
# description, so it is what ``--help`` prints.
DESC="""
This thing takes a reverser checkpoint, a folder full of wavs, and tries to reverse their drift.
Will plunge me into depression. This time for real.
"""
import os
from argparse import ArgumentParser
import torch
import numpy as np
from atk_tools import SemiInformedReverser as Reverser, SupervisedDataset
###########
# Input files are expected to be named ``<utterance id>_gen.wav``.
GEN_WAV_SUFFIX = '_gen.wav'


def build_parser():
    """Build the command-line parser for this script."""
    parser = ArgumentParser(description=DESC)
    parser.add_argument('--reverser_path', required=True, type=str, help='Path to a state dict of a Reverser.')
    parser.add_argument('--in_fold', required=True, type=str, help='Where to find wavs to reverse.')
    parser.add_argument('--out_fold', required=True, type=str, help='Where to output the files in numpy xvector format.')
    parser.add_argument('--device', required=False, type=str, default='cuda')
    return parser


def collect_gen_ids(in_fold):
    """Return the sorted utterance ids of the ``<id>_gen.wav`` files in *in_fold*.

    Only names that END with ``_gen.wav`` are considered, and only that
    trailing suffix is stripped. An id that itself contains ``_gen``
    (e.g. ``x_gen_gen.wav`` -> ``x_gen``) is therefore preserved. The result
    is sorted so that processing order does not depend on ``os.listdir``.
    """
    return sorted(
        fn[:-len(GEN_WAV_SUFFIX)]
        for fn in os.listdir(in_fold)
        if fn.endswith(GEN_WAV_SUFFIX)
    )


def main(argv=None):
    """Run the reverser over every ``*_gen.wav`` in ``--in_fold`` and save the outputs.

    For each utterance, the reverser's output (called with ``embeddings=True``)
    is written with ``np.save`` into ``--out_fold``.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    device = args.device
    os.makedirs(args.out_fold, exist_ok=True)

    ids = collect_gen_ids(args.in_fold)
    # NOTE(review): presumably SupervisedDataset resolves each id back to
    # '<id>_gen.wav' in in_fold and yields (wav, label, ids) - confirm.
    ds = SupervisedDataset(args.in_fold, ids, return_ids=True)
    dl = torch.utils.data.DataLoader(
        ds,
        batch_size=1,
        shuffle=False,
    )

    reverser = Reverser().eval().to(device)
    # Map the checkpoint onto the requested device (previously hard-coded to
    # 'cuda', which broke --device cpu on machines without CUDA).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    reverser.load_state_dict(torch.load(args.reverser_path, map_location=device))

    tot = len(dl)
    with torch.no_grad():
        for i, (wav, _, uttids) in enumerate(dl, start=1):
            wav = wav.to(device)
            uttid = uttids[0]  # batch size is hardcoded to 1
            reco_xv = reverser(wav, embeddings=True)
            fp = os.path.join(args.out_fold, f'{uttid}_gen.xvector')
            print(f'[{i}/{tot}] Saving to {fp}')
            # np.save appends '.npy' because fp does not already end in
            # '.npy', so the file on disk is '<uttid>_gen.xvector.npy'.
            np.save(fp, reco_xv.cpu().numpy())
    print('Done.')


if __name__ == '__main__':
    main()