-
Notifications
You must be signed in to change notification settings - Fork 0
/
GraphGenerator.py
103 lines (88 loc) · 4.36 KB
/
GraphGenerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import errno
import os
import random
import sys
import argparse
import numpy as np
from tqdm import tqdm
from AudioAugmentation import AudioAugmentation
from DataParser import DataParser
from common import Utils
class GraphGenerator:
    """Render one graph per audio file, optionally with augmentation.

    The list of audio files comes from ``DataParser``. For every file a graph
    of type ``type_graph`` is written below
    ``<cwd>/data/graphs/<folder_type>/<source folder>/<type_graph>/`` and,
    when ``save_raw`` is set, the plotted data is also stored as a NumPy
    array below ``<cwd>/data/raw/...`` using the same layout.
    """

    def __init__(self, type_graph, folder_type, folders=None, augmentation=None,
                 skip_probability=0, save_raw=False):
        """Collect the audio files and remember the generation options.

        Args:
            type_graph: Graph type, forwarded to ``DataParser`` and ``Utils``.
            folder_type: Folder type, forwarded to ``DataParser`` and used as
                a component of the output paths.
            folders: Dataset folder names; defaults to ``["ff1010bird"]``.
            augmentation: Optional augmentation object. It must provide
                ``augment_data(data, sr)`` and ``get_file_label()``.
            skip_probability: Threshold compared against ``random.random()``
                to decide whether a file is skipped.
            save_raw: When true, also save the plotted data with ``np.save``.
        """
        if folders is None:
            folders = ["ff1010bird"]
        parser = DataParser(type_folder=folder_type, folders=folders, graph_type=type_graph)
        self.files = parser.get_audio_files_name()
        self.type_graph = type_graph
        self.folder_type = folder_type
        self.aug = augmentation
        self.save_raw = save_raw
        self.skip_prob = skip_probability

    def generateGraph(self):
        """Write the graph (and optionally the raw data) for every file."""
        for file in tqdm(self.files):
            # Honour skip_probability: the original evaluated the trigger and
            # then did nothing with it, so no file was ever skipped.
            if not self.trigger_with_prob():
                continue

            # The name of the directory holding the file selects the
            # per-dataset output sub-folder.
            folder = os.path.basename(os.path.dirname(file))
            file_name = os.path.splitext(DataParser.path_leaf(file))[0]
            if self.aug is not None:
                # Tag augmented outputs with the augmentation's label.
                file_name += "_" + self.aug.get_file_label()

            graph_dir = self._output_dir("graphs", folder)
            self._makedirs(graph_dir)
            data, sr = self._get_plot_data(file)
            self._write_graph(data, sr, os.path.join(graph_dir, file_name))

            if self.save_raw:
                raw_dir = self._output_dir("raw", folder)
                self._makedirs(raw_dir)
                self._write_raw_file(data, os.path.join(raw_dir, file_name))

    def _output_dir(self, kind, folder):
        """Return the output directory (with trailing slash) for ``kind``.

        ``kind`` is the sub-directory of ``data`` to use ("graphs" or "raw").
        """
        return (os.getcwd() + "/data/" + kind + "/" + self.folder_type + "/"
                + folder + "/" + self.type_graph + "/")

    def _get_plot_data(self, file):
        """Load ``file``, apply the augmentation if any, and build plot data.

        Returns:
            A ``(plot_data, sample_rate)`` tuple.
        """
        data, sr = Utils.read_audio_file(file)
        if self.aug is not None:
            # Use the instance's augmentation; a bare ``aug`` only resolves
            # when this file is executed as a script.
            data = self.aug.augment_data(data, sr)
        return Utils.get_plot_data(data, sr, self.type_graph), sr

    def _write_graph(self, data, sr, save_path):
        """Delegate the rendering of the graph to ``Utils.write_graph``."""
        Utils.write_graph(data, sr, save_path, self.type_graph)

    def _write_raw_file(self, data, save_path):
        """Save ``data`` as a NumPy array (``np.save`` appends ``.npy``)."""
        arr = np.array(data)
        np.save(save_path, arr)

    def _makedirs(self, path):
        """Create the directory part of ``path`` if it does not exist yet."""
        directory = os.path.dirname(path)
        if not os.path.exists(directory):
            try:
                os.makedirs(directory)
            except OSError as exc:  # Guard against race condition
                if exc.errno != errno.EEXIST:
                    raise

    def trigger_with_prob(self):
        """Return True when a fresh random draw exceeds ``skip_prob``."""
        return random.random() > self.skip_prob
def main(type_graph, folder_type, folders, augmentation, skip_probability):
    """Create a GraphGenerator with the given options and run it."""
    generator = GraphGenerator(
        type_graph,
        folder_type,
        folders,
        augmentation,
        skip_probability,
    )
    generator.generateGraph()
def str_to_bool(value):
    """Convert a command-line token such as "true" or "0" into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, which
    is True for every non-empty value (including "False"), so an explicit
    parser is required.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


if __name__ == '__main__':
    # All arguments are optional positionals, so they are consumed strictly in
    # the order they are declared here.
    parser = argparse.ArgumentParser(description="Customization options for the graph generator")
    parser.add_argument("graph_type", nargs='?', default="mfcc",
                        help='Set type of graph (i.e: mfcc/melspectrogram/melspectrogram-energy/spectrogram)')
    parser.add_argument("folder_type", nargs='?', default="testing",
                        help='Set folder type (e.g: training or testing)')
    parser.add_argument("folders", nargs='?', default="BirdVoxDCASE20k,ff1010bird,warblrb10k",
                        help='Folders from which the audio files will be taken (e.g: "folder1,folder2")')
    # NOTE(review): additive_noise and time_stretch_rate are kept as int
    # because AudioAugmentation is not visible here -- confirm whether
    # fractional values should be accepted.
    parser.add_argument("additive_noise", nargs='?', type=int, default=0, help='Additive noise')
    # Parsed explicitly: type=bool would turn "False" into True.
    parser.add_argument("random_noise", nargs='?', type=str_to_bool, default=False, help='Random noise')
    parser.add_argument("time_stretch_rate", nargs='?', type=int, default=1, help='Time stretch')
    # A probability is fractional, so it is parsed as float rather than int.
    parser.add_argument("skip_probability", nargs='?', type=float, default=0,
                        help='Probability of skipping an augmentation regarding a file')
    args = parser.parse_args()
    print(args)

    aug = AudioAugmentation()
    # Tracks whether at least one augmentation step was actually configured.
    check = False
    if args.random_noise:
        aug.add_random_noise()
        check = True
    if args.additive_noise > 0:
        aug.add_noise(args.additive_noise)
        check = True
    if args.time_stretch_rate != 1 and args.time_stretch_rate > 0:
        aug.time_stretch(args.time_stretch_rate)
        check = True
    if not check:
        # No augmentation requested: generate plain graphs.
        aug = None

    folder_names = [item.strip() for item in args.folders.strip().split(',')]
    main(str(args.graph_type).lower(), str(args.folder_type).lower(), folder_names, aug, args.skip_probability)