-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
46 lines (36 loc) · 1.11 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import pickle
from argparse import ArgumentParser
from pathlib import Path
import jax
import jax.numpy as jnp
import numpy as np
from scipy.io.wavfile import write
import config
from hifigan import Generator
def _build_arg_parser() -> ArgumentParser:
    """Return the command-line parser for this script.

    Every option is mandatory and is converted to a ``pathlib.Path``.
    """
    cli = ArgumentParser()
    options = (
        ("--model", "Path to model file"),
        ("--mel", "Path to mel file"),
        ("--wav", "Path to output wav file"),
    )
    for flag, description in options:
        cli.add_argument(flag, type=Path, required=True, help=description)
    return cli


parser = _build_arg_parser()
args = parser.parse_args()
# Hyper-parameters for the HiFi-GAN generator, read from the project's
# `config` module and passed positionally in the order the constructor takes them.
_GENERATOR_HPARAMS = (
    "num_mels",
    "resblock_kernel_sizes",
    "upsample_rates",
    "upsample_kernel_sizes",
    "upsample_initial_channel",
    "resblock_kind",
    "resblock_dilation_sizes",
)
g = Generator(*(getattr(config, name) for name in _GENERATOR_HPARAMS))
# Load the trained generator weights from the pickled checkpoint at --model.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only pass --model paths that point to checkpoints you trust.
with open(args.model, "rb") as f:
    dic = pickle.load(f)

# NOTE(review): depending on the module library behind hifigan.Generator,
# load_state_dict either updates `g` in place (PyTorch style, returning
# key-mismatch info) or returns a *new* module (functional style). The rest of
# this script already treats module methods functionally (`g = g.eval()`), so
# keep the returned value when it is a module of the same type; otherwise keep
# `g` as is. Confirm which convention Generator follows.
_loaded = g.load_state_dict(dic["generator"])
if isinstance(_loaded, type(g)):
    g = _loaded
# Place the generator on the first CPU device for inference.
g = jax.device_put(g, device=jax.devices("cpu")[0])
# Switch to evaluation mode; eval() returns the module to use from here on.
g = g.eval()
# Run the generator on the mel spectrogram and save the result as 16-bit PCM.
# NOTE(review): the mel layout the generator expects (batch axis? axis order?)
# is defined by hifigan.Generator and is not visible here — confirm.
mel = np.load(args.mel)
wav = g(mel)

# Scale to the int16 range and clip so out-of-range samples saturate instead of
# overflowing on the cast. Bounds are passed positionally because the
# a_min/a_max keyword names are deprecated in recent JAX releases.
wav = wav * 2**15
wav = jnp.clip(wav, -(2**15), 2**15 - 1)
wav = wav.astype(jnp.int16)

# Copy to host memory and drop size-1 axes (e.g. a leading batch/channel axis):
# scipy's `write` treats a 2-D array as (samples, channels), so a (1, T) array
# would otherwise be written as one frame with T channels.
wav = np.squeeze(jax.device_get(wav))
write(args.wav, config.sample_rate, wav)
print("Done!")