-
Notifications
You must be signed in to change notification settings - Fork 1
/
melodygenerator.py
114 lines (77 loc) · 3.57 KB
/
melodygenerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import json
import music21 as m21
import tensorflow.keras as keras
import numpy as np
from preprocess import SEQUENCE_LENGTH, MAPPING_PATH
class MelodyGenerator:
    """Generate melodies with a trained Keras model and save them with music21.

    A melody is a sequence of string symbols: "r" marks a rest, "_" prolongs
    the previous event by one time step, "/" is the start/end delimiter, and
    any other symbol is converted with ``int`` and handed to
    ``music21.note.Note`` as the pitch.
    """

    def __init__(self, model_path="model.h5"):
        """Load the trained model and the symbol-to-integer mappings.

        Args:
            model_path: Path of the saved Keras model to load.
        """
        self.model_path = model_path
        self.model = keras.models.load_model(model_path)

        with open(MAPPING_PATH, "r") as fp:
            self._mappings = json.load(fp)

        # Delimiter padding that is prepended to every seed before prediction.
        self._start_symbols = ["/"] * SEQUENCE_LENGTH

    # Kept as a regular method (not a staticmethod): it is invoked as
    # ``self._sample_with_temperature(probabilities, temperature)``, which
    # only binds correctly when ``self`` is a real instance parameter.
    def _sample_with_temperature(self, probabilities, temperature):
        """Sample an index from a distribution re-scaled by a temperature.

        Temperatures below 1 sharpen the distribution (the most likely index
        is picked more often); temperatures above 1 flatten it.

        Args:
            probabilities: 1-D sequence of probabilities, one per class.
            temperature: Positive scaling factor applied to the
                log-probabilities.

        Returns:
            The sampled index as a plain ``int``.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be a positive number")

        # Work in float64: normalised float32 values can miss the sum-to-one
        # tolerance that np.random.choice enforces on ``p``.
        probabilities = np.asarray(probabilities, dtype=np.float64)

        # log(0) is -inf, which correctly maps back to a zero probability.
        with np.errstate(divide="ignore"):
            logits = np.log(probabilities) / temperature

        # Subtract the maximum before exponentiating so exp() cannot overflow;
        # the normalised result is mathematically unchanged.
        weights = np.exp(logits - np.max(logits))
        rescaled = weights / np.sum(weights)

        return int(np.random.choice(len(rescaled), p=rescaled))

    def generate_melody(self, seed, num_steps, max_sequence_length, temperature):
        """Extend a seed melody by repeatedly sampling from the model.

        Args:
            seed: Whitespace-separated melody symbols, e.g. "55 _ _ _ 60 _".
            num_steps: Maximum number of symbols to generate.
            max_sequence_length: Number of most recent symbols fed to the
                model at each step.
            temperature: Sampling temperature passed to
                ``_sample_with_temperature``.

        Returns:
            The seed symbols followed by the generated symbols, as a list of
            strings. Generation stops early when the model emits "/".
        """
        melody = seed.split()

        # Concatenation builds a new list, so ``melody`` keeps only the
        # caller-visible symbols while the model input carries the padding.
        encoded = [self._mappings[symbol] for symbol in self._start_symbols + melody]

        vocabulary_size = len(self._mappings)

        # Build the integer -> symbol lookup once instead of scanning the
        # mapping on every step. setdefault keeps the first symbol found for
        # a given integer.
        symbol_by_int = {}
        for symbol, value in self._mappings.items():
            symbol_by_int.setdefault(value, symbol)

        for _ in range(num_steps):
            # Only the most recent symbols are fed to the model.
            encoded = encoded[-max_sequence_length:]

            # One-hot encode and add a leading batch axis.
            onehot_seed = keras.utils.to_categorical(encoded, num_classes=vocabulary_size)
            onehot_seed = onehot_seed[np.newaxis, ...]

            # Predict the distribution of the next symbol and sample from it.
            probabilities = self.model.predict(onehot_seed)[0]
            output_int = self._sample_with_temperature(probabilities, temperature)
            encoded.append(output_int)

            output_symbol = symbol_by_int[output_int]

            # "/" marks the end of a melody.
            if output_symbol == "/":
                break

            melody.append(output_symbol)

        return melody

    @staticmethod
    def _create_event(symbol, quarter_length):
        """Build the music21 rest or note for ``symbol`` with the given length."""
        if symbol == "r":
            return m21.note.Rest(quarterLength=quarter_length)
        return m21.note.Note(int(symbol), quarterLength=quarter_length)

    def save_melody(self, melody, step_duration=0.25, format="midi", file_name = "mel.midi"):
        """Convert a melody to a music21 stream and write it to disk.

        Args:
            melody: Sequence of melody symbols (notes, "r" and "_").
            step_duration: Length of one time step, in quarter lengths.
            format: Output format passed to ``stream.write``.
            file_name: Destination path passed to ``stream.write``.
        """
        stream = m21.stream.Stream()

        start_symbol = None  # note/rest currently being held
        step_counter = 1     # time steps covered by ``start_symbol`` so far

        for symbol in melody:
            # A prolongation sign extends the event currently being held.
            if symbol == "_":
                step_counter += 1
                continue

            # A new note/rest starts: emit the one that just finished.
            if start_symbol is not None:
                stream.append(self._create_event(start_symbol, step_duration * step_counter))

            # Prolongation signs seen before any note/rest have nothing to
            # extend, so the counter always restarts with the new event.
            start_symbol = symbol
            step_counter = 1

        # Emit the event still pending when the melody ends; without this the
        # final note/rest would be missing or cut short.
        if start_symbol is not None:
            stream.append(self._create_event(start_symbol, step_duration * step_counter))

        # write the m21 stream into the requested file
        stream.write(format, file_name)
if __name__ == "__main__":
    # Demo run: load the default model, continue a short seed and save the
    # resulting melody with the default output settings.
    generator = MelodyGenerator()
    seed_melody = "55 _ _ _ 60 _ _ _ 55 _ _ _ 55 _ "
    generated = generator.generate_melody(
        seed=seed_melody,
        num_steps=500,
        max_sequence_length=SEQUENCE_LENGTH,
        temperature=0.4,
    )
    print(generated)
    generator.save_melody(generated)