-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
66 lines (43 loc) · 2.05 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import print_function, division

import os

# Turn off TensorFlow's C++ debug/log output. This must be set BEFORE keras
# (and therefore the TensorFlow backend, also pulled in by `model`) is
# imported; setting it afterwards cannot filter messages emitted while the
# backend is loading. '3' is TensorFlow's most restrictive C++ log level.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from keras.layers import Input  # noqa: E402  (must follow the env setup above)
from keras.models import Model  # noqa: E402
from keras.optimizers import Adam  # noqa: E402
from optparse import OptionParser  # noqa: E402
from model import train, build_audio_generator, build_audio_discriminator  # noqa: E402
def _build_models(frame_size):
    """Build and compile the audio GAN.

    Returns a tuple ``(audio_generator, audio_discriminator, audio_combined)``
    where ``audio_combined`` is the stacked model noise -> generator ->
    discriminator, used to train the generator.
    """
    # Both networks use (frame_size, 256) inputs; the meaning of the 256 axis
    # is defined by model.py (NOTE(review): presumably feature bins -- confirm).
    audio_shape_disc = (frame_size, 256)
    audio_shape_gen = (frame_size, 256)
    optimizer = Adam(0.0002, 0.5)

    # Build and compile the discriminator.
    audio_discriminator = build_audio_discriminator(audio_shape_disc)
    audio_discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)

    # Build the generator; it takes noise of shape audio_shape_gen.
    audio_generator = build_audio_generator(audio_shape_gen, frame_size)
    noise = Input(shape=audio_shape_gen)
    audio = audio_generator(noise)

    # For the combined model we only train the generator: freeze the
    # discriminator *after* compiling it on its own. Keras takes the
    # `trainable` flag into account at compile time, so the discriminator's
    # own training is unaffected while `audio_combined` leaves its weights
    # alone.
    audio_discriminator.trainable = False

    # The discriminator judges the validity of the generated audio.
    audio_valid = audio_discriminator(audio)

    # Stacked model: noise => generated audio => validity.
    audio_combined = Model(noise, audio_valid)
    audio_combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    return audio_generator, audio_discriminator, audio_combined


def main(argv=None):
    """Parse the command line and run the requested mode.

    Only ``train`` mode is implemented: it builds the audio GAN and passes it
    to ``model.train``. Any other mode (including the legacy default
    ``label``) is rejected through ``parser.error``, which prints the usage
    and exits with status 2 instead of silently doing nothing.

    :param argv: argument list to parse; ``None`` (default) means
        ``sys.argv[1:]``, as before.
    """
    parser = OptionParser()
    # Run mode. NOTE: the default 'label' is a leftover and, like 'gen', is
    # not implemented in this script -- only 'train' does any work.
    parser.add_option('-m', '--mode', dest='mode', default='label',
                      help="run mode; only 'train' is implemented")
    # Model id. NOTE(review): parsed but not used anywhere in this script.
    parser.add_option('-u', '--uid', help='enter model id here')
    parser.add_option('-e', '--epochs', dest='epochs', type='int', default=1000,
                      help='number of training epochs (default: %default)')
    parser.add_option('-d', '--data', dest='data',
                      default='data/cv-valid-train/*.wav',
                      help='glob of training .wav files (default: %default)')
    options, _ = parser.parse_args(argv)

    if options.mode != 'train':
        parser.error("unsupported mode %r: only 'train' is implemented"
                     % (options.mode,))
    if options.epochs < 1:
        parser.error('--epochs must be a positive integer')

    frame_size = 500
    frame_shift = 128
    audio_generator, audio_discriminator, audio_combined = _build_models(frame_size)
    train(options.data, audio_generator, audio_discriminator, audio_combined,
          options.epochs, frame_size, frame_shift)


if __name__ == '__main__':
    main()