@@ -52,3 +52,209 @@ def __init__(self, common_config, arch_config):
        self.audio_file_base = None
        self.secondary_source_map = None
        self.primary_source_map = None
+
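+    # Full separation pass: selects and loads the VR network, runs spectrogram inference,
+    # then converts the vocal/instrumental spectrograms to audio and exports the stems.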
+    def seperate(self):
+        self.logger.debug("Starting separation process in SeperateVR...")
+        if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, tuple):
+            self.logger.debug("Using cached primary sources...")
+            y_spec, v_spec = self.primary_sources
+            self.load_cached_sources()
+        else:
+            self.logger.debug("Starting inference...")
+            self.start_inference_console_write()
+
+            device = self.device
+            self.logger.debug(f"Device set to: {device}")
+
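+            # The network architecture is inferred from the checkpoint's file size in KiB,
+            # matched to the nearest known size; 56817 and 218409 identify VR 5.1 models.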
+            nn_arch_sizes = [31191, 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227]  # default
+            vr_5_1_models = [56817, 218409]
+            model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
+            nn_arch_size = min(nn_arch_sizes, key=lambda x: abs(x - model_size))
+            self.logger.debug(f"Model size determined: {model_size}, NN architecture size: {nn_arch_size}")
+
+            if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
+                self.logger.debug("Using CascadedNet for VR 5.1 model...")
+                self.model_run = nets_new.CascadedNet(self.mp.param["bins"] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
+                self.is_vr_51_model = True
+            else:
+                self.logger.debug("Determining model capacity...")
+                self.model_run = nets.determine_model_capacity(self.mp.param["bins"] * 2, nn_arch_size)
+
+            self.model_run.load_state_dict(torch.load(self.model_path, map_location=cpu))
+            self.model_run.to(device)
+            self.logger.debug("Model loaded and moved to device.")
+
+            self.running_inference_console_write()
+
+            y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness)
+            self.logger.debug("Inference completed.")
+            if not self.is_vocal_split_model:
+                self.cache_source((y_spec, v_spec))
+            self.write_to_console(DONE, base_text="")
+
+        if self.is_secondary_model_activated and self.secondary_model:
+            self.logger.debug("Processing secondary model...")
+            self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(
+                self.secondary_model, self.process_data, main_process_method=self.process_method, main_model_primary=self.primary_stem
+            )
+
+        if not self.is_secondary_stem_only:
+            primary_stem_path = os.path.join(self.export_path, f"{self.audio_file_base}_({self.primary_stem}).wav")
+            self.logger.debug(f"Processing primary stem: {self.primary_stem}")
+            if not isinstance(self.primary_source, np.ndarray):
+                self.primary_source = self.spec_to_wav(y_spec).T
+                self.logger.debug("Converting primary source spectrogram to waveform.")
+                if not self.model_samplerate == 44100:
+                    self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
+                    self.logger.debug("Resampling primary source to 44100Hz.")
+
+            self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, 44100)
+            self.logger.debug("Primary stem processed.")
+
+        if not self.is_primary_stem_only:
+            secondary_stem_path = os.path.join(self.export_path, f"{self.audio_file_base}_({self.secondary_stem}).wav")
+            self.logger.debug(f"Processing secondary stem: {self.secondary_stem}")
+            if not isinstance(self.secondary_source, np.ndarray):
+                self.secondary_source = self.spec_to_wav(v_spec).T
+                self.logger.debug("Converting secondary source spectrogram to waveform.")
+                if not self.model_samplerate == 44100:
+                    self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
+                    self.logger.debug("Resampling secondary source to 44100Hz.")
+
+            self.secondary_source_map = self.final_process(secondary_stem_path, self.secondary_source, self.secondary_source_secondary, self.secondary_stem, 44100)
+            self.logger.debug("Secondary stem processed.")
+
+        clear_gpu_cache()
+        self.logger.debug("GPU cache cleared.")
+        secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
+
+        self.process_vocal_split_chain(secondary_sources)
+        self.logger.debug("Vocal split chain processed.")
+
+        if self.is_secondary_model:
+            self.logger.debug("Returning secondary sources...")
+            return secondary_sources
+
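+    # Builds the multi-band input spectrogram: the highest band is loaded from the audio file
+    # and each lower band is derived by resampling the band above it, after which the per-band
+    # spectrograms are combined into the single spectrogram passed to inference_vr.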
+    def loading_mix(self):
+        X_wave, X_spec_s = {}, {}
+
+        bands_n = len(self.mp.param["band"])
+
+        audio_file = spec_utils.write_array_to_mem(self.audio_file, subtype=self.wav_type_set)
+        is_mp3 = audio_file.endswith(".mp3") if isinstance(audio_file, str) else False
+
+        for d in range(bands_n, 0, -1):
+            bp = self.mp.param["band"][d]
+
+            if OPERATING_SYSTEM == "Darwin":
+                wav_resolution = "polyphase" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else bp["res_type"]
+            else:
+                wav_resolution = bp["res_type"]
+
+            if d == bands_n:  # high-end band
+                X_wave[d], _ = librosa.load(audio_file, bp["sr"], False, dtype=np.float32, res_type=wav_resolution)
+                X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.mp, band=d, is_v51_model=self.is_vr_51_model)
+
+                if not np.any(X_wave[d]) and is_mp3:
+                    X_wave[d] = rerun_mp3(audio_file, bp["sr"])
+
+                if X_wave[d].ndim == 1:
+                    X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
+            else:  # lower bands
+                X_wave[d] = librosa.resample(X_wave[d + 1], self.mp.param["band"][d + 1]["sr"], bp["sr"], res_type=wav_resolution)
+                X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.mp, band=d, is_v51_model=self.is_vr_51_model)
+
+            if d == bands_n and self.high_end_process != "none":
+                self.input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (self.mp.param["pre_filter_stop"] - self.mp.param["pre_filter_start"])
+                self.input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - self.input_high_end_h : bp["n_fft"] // 2, :]
+
+        X_spec = spec_utils.combine_spectrograms(X_spec_s, self.mp, is_v51_model=self.is_vr_51_model)
+
+        del X_wave, X_spec_s, audio_file
+
+        return X_spec
+
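+    # Runs the model over the padded magnitude spectrogram in fixed-size windows, predicting a
+    # mask batch by batch and concatenating the per-patch masks back along the time axis.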
+    def inference_vr(self, X_spec, device, aggressiveness):
+        def _execute(X_mag_pad, roi_size):
+            X_dataset = []
+            patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size
+            total_iterations = patches // self.batch_size if not self.is_tta else (patches // self.batch_size) * 2
+            for i in range(patches):
+                start = i * roi_size
+                X_mag_window = X_mag_pad[:, :, start : start + self.window_size]
+                X_dataset.append(X_mag_window)
+
+            X_dataset = np.asarray(X_dataset)
+            self.model_run.eval()
+            with torch.no_grad():
+                mask = []
+                for i in range(0, patches, self.batch_size):
+                    self.progress_value += 1
+                    if self.progress_value >= total_iterations:
+                        self.progress_value = total_iterations
+                    self.set_progress_bar(0.1, 0.8 / total_iterations * self.progress_value)
+                    X_batch = X_dataset[i : i + self.batch_size]
+                    X_batch = torch.from_numpy(X_batch).to(device)
+                    pred = self.model_run.predict_mask(X_batch)
+                    if not pred.size()[3] > 0:
+                        raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR])
+                    pred = pred.detach().cpu().numpy()
+                    pred = np.concatenate(pred, axis=2)
+                    mask.append(pred)
+                if len(mask) == 0:
+                    raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR])
+
+                mask = np.concatenate(mask, axis=2)
+            return mask
+
+        def postprocess(mask, X_mag, X_phase):
+            is_non_accom_stem = False
+            for stem in NON_ACCOM_STEMS:
+                if stem == self.primary_stem:
+                    is_non_accom_stem = True
+
+            mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness)
+
+            if self.is_post_process:
+                mask = spec_utils.merge_artifacts(mask, thres=self.post_process_threshold)
+
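+            # Apply the mask to the magnitude and reuse the original phase; the complement
+            # (1 - mask) yields the opposite stem.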
+            y_spec = mask * X_mag * np.exp(1.0j * X_phase)
+            v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)
+
+            return y_spec, v_spec
+
+        X_mag, X_phase = spec_utils.preprocess(X_spec)
+        n_frame = X_mag.shape[2]
+        pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)
+        X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
+        X_mag_pad /= X_mag_pad.max()
+        mask = _execute(X_mag_pad, roi_size)
+
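+        # Test-time augmentation: a second pass shifted by half a window is averaged with the
+        # first mask (intended to reduce artifacts at patch boundaries).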
+        if self.is_tta:
+            pad_l += roi_size // 2
+            pad_r += roi_size // 2
+            X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
+            X_mag_pad /= X_mag_pad.max()
+            mask_tta = _execute(X_mag_pad, roi_size)
+            mask_tta = mask_tta[:, :, roi_size // 2 :]
+            mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5
+        else:
+            mask = mask[:, :, :n_frame]
+
+        y_spec, v_spec = postprocess(mask, X_mag, X_phase)
+
+        return y_spec, v_spec
+
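+    # Converts a separated spectrogram back to a waveform; in "mirroring" modes the
+    # high-frequency band captured in loading_mix is mirrored back in during reconstruction.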
+    def spec_to_wav(self, spec):
+        if self.high_end_process.startswith("mirroring") and isinstance(self.input_high_end, np.ndarray) and self.input_high_end_h:
+            input_high_end_ = spec_utils.mirroring(self.high_end_process, spec, self.input_high_end, self.mp)
+            wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, self.input_high_end_h, input_high_end_, is_v51_model=self.is_vr_51_model)
+        else:
+            wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, is_v51_model=self.is_vr_51_model)
+
+        return wav