-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_multi_lyric.py
52 lines (43 loc) · 1.59 KB
/
generate_multi_lyric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import sys, getopt, os
import re
import warnings
warnings.filterwarnings("ignore")
## Use cuda if available
from config import device
from dataset import LyricsDatasetBPE, LyricsDatasetRegex
from train import predict
## Define model
from model import LSTM
# Load the pickled training dataset and the pretrained LSTM weights.
# The model is sized from the dataset's `token_set` and `padding_idx`, so the
# dataset must be the one the weights were trained with.
# Both loads are remapped onto `device` so files saved from a GPU session can
# still be loaded on a CPU-only machine (the original only did this for the
# weights).
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# these files from a trusted source. On PyTorch >= 2.6, torch.load defaults to
# weights_only=True, which may reject the pickled dataset object; it may then
# need weights_only=False -- confirm against the installed torch version.
dataset = torch.load('pretrained_models/train_set_rt3.pt', map_location=torch.device(device))
model = LSTM(
    len(dataset.token_set),
    padding_idx=dataset.padding_idx,
    embedding_size=100,
    hidden_size=256,
    num_layers=2,
)
model.load_state_dict(torch.load('pretrained_models/model_rt3.pt', map_location=torch.device(device)))
model.to(device)
# Inference only: switch to evaluation mode (affects layers such as dropout).
model.eval()
_USAGE = 'generate_multi_lyric.py -t <top_words>'
_PROMPT = "Type a word or phrase and press Enter, to get an AI-generated lyric: "
# Tokenizer for user input: runs of word characters, or a literal period.
# Written as a raw string; the original non-raw "\w+|\\." compiled to this same
# pattern but relied on an invalid "\w" string escape (deprecated, and a
# SyntaxWarning on newer Pythons).
_TOKEN_RE = re.compile(r"\w+|\.")


def _parse_top_words(argv):
    """Return the ``-t/--top-words`` value parsed from *argv* (default 5).

    Prints the usage line and calls ``sys.exit()`` on ``-h``/``--help``.
    Prints the usage line and exits with status 2 on an unknown option,
    a missing value, or a value that is not an integer.
    """
    try:
        opts, _args = getopt.getopt(argv, "ht:", ["help", "top-words="])
    except getopt.GetoptError:
        print(_USAGE)
        sys.exit(2)
    top_words = 5
    for opt, arg in opts:
        if opt in ("-h", "--help"):
            print(_USAGE)
            sys.exit()
        elif opt in ("-t", "--top-words"):
            try:
                top_words = int(arg)
            except ValueError:
                # Report bad input like any other usage error, not a traceback.
                print(_USAGE)
                sys.exit(2)
    return top_words


def main(argv):
    """Run the interactive multi-lyric generator.

    argv: command-line arguments without the program name (``sys.argv[1:]``);
    only ``-t/--top-words <int>`` and ``-h/--help`` are accepted.

    Repeatedly reads a word or phrase from stdin, tokenizes it, and prints the
    phrase followed by the text returned by ``predict`` for the module-level
    ``model`` and ``dataset``. Typing ``quit`` (surrounding whitespace ignored),
    end-of-input, or Ctrl-C ends the loop.
    """
    top_words = _parse_top_words(argv)

    print("Multi-Lyric Generation in greek!")
    print("Type 'quit' to exit the program")

    while True:
        try:
            val = input(_PROMPT)
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C / exhausted piped input: leave cleanly instead of
            # dumping a traceback.
            print()
            break
        if val.strip() == 'quit':
            break
        seq = _TOKEN_RE.findall(val)
        if not seq:
            # NOTE(review): nothing to seed the model with. predict()'s behavior
            # on an empty sequence is not visible here, so re-prompt instead.
            continue
        # The literal 100 is passed through to predict() as in the original;
        # presumably the generation length -- confirm in train.predict.
        print(val + " " + predict(model, dataset, seq, 100, top_only=top_words, no_unk=True))


if __name__ == "__main__":
    main(sys.argv[1:])