Skip to content

Commit b6e0547

Browse files
committed
Fixed minor lint issues
1 parent a0bfc27 commit b6e0547

File tree

2 files changed

+214
-9
lines changed

2 files changed

+214
-9
lines changed

audio_separator/separator/architectures/mdx_separator.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,18 @@ def __init__(self, common_config, arch_config):
3232
# Loading the model for inference
3333
self.logger.debug("Loading ONNX model for inference...")
3434
if self.segment_size == self.dim_t:
35-
ort_ = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider)
36-
self.model_run = lambda spek: ort_.run(None, {"input": spek.cpu().numpy()})[0]
35+
ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider)
36+
self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
3737
self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.")
3838
else:
3939
self.model_run = onnx2torch.convert(self.model_path)
4040
self.model_run.to(self.torch_device).eval()
4141
self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")
4242

43-
self.n_bins = None
44-
self.trim = None
45-
self.chunk_size = None
46-
self.gen_size = None
43+
self.n_bins = 0
44+
self.trim = 0
45+
self.chunk_size = 0
46+
self.gen_size = 0
4747
self.stft = None
4848

4949
self.primary_source = None
@@ -323,8 +323,7 @@ def demix(self, mix, is_match_mix=False):
323323

324324
# Compensates the source if not matching the mix.
325325
if not is_match_mix:
326-
# TODO: Investigate whether fixing this bug actually does anything!
327-
source * self.compensate
326+
source *= self.compensate
328327
self.logger.debug("Match mix mode; compensate multiplier applied.")
329328

330329
# TODO: In UVR, VR denoise model gets applied here. Consider implementing this as a feature.
@@ -358,7 +357,7 @@ def run_model(self, mix, is_match_mix=False):
358357
spec_pred_neg = self.model_run(-spek) # Ensure this line correctly negates spek and runs the model
359358
spec_pred_pos = self.model_run(spek)
360359
# Ensure both spec_pred_neg and spec_pred_pos are tensors before applying operations
361-
spec_pred = (-spec_pred_neg * 0.5) + (spec_pred_pos * 0.5) # [invalid-unary-operand-type]
360+
spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5) # [invalid-unary-operand-type]
362361
self.logger.debug("Model run on both negative and positive spectrums for denoising.")
363362
else:
364363
spec_pred = self.model_run(spek)

audio_separator/separator/architectures/vr_separator.py

+206
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,209 @@ def __init__(self, common_config, arch_config):
5252
self.audio_file_base = None
5353
self.secondary_source_map = None
5454
self.primary_source_map = None
55+
56+
57+
58+
59+
60+
61+
def seperate(self):
62+
self.logger.debug("Starting separation process in SeperateVR...")
63+
if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, tuple):
64+
self.logger.debug("Using cached primary sources...")
65+
y_spec, v_spec = self.primary_sources
66+
self.load_cached_sources()
67+
else:
68+
self.logger.debug("Starting inference...")
69+
self.start_inference_console_write()
70+
71+
device = self.device
72+
self.logger.debug(f"Device set to: {device}")
73+
74+
nn_arch_sizes = [31191, 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227] # default
75+
vr_5_1_models = [56817, 218409]
76+
model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
77+
nn_arch_size = min(nn_arch_sizes, key=lambda x: abs(x - model_size))
78+
self.logger.debug(f"Model size determined: {model_size}, NN architecture size: {nn_arch_size}")
79+
80+
if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
81+
self.logger.debug("Using CascadedNet for VR 5.1 model...")
82+
self.model_run = nets_new.CascadedNet(self.mp.param["bins"] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
83+
self.is_vr_51_model = True
84+
else:
85+
self.logger.debug("Determining model capacity...")
86+
self.model_run = nets.determine_model_capacity(self.mp.param["bins"] * 2, nn_arch_size)
87+
88+
self.model_run.load_state_dict(torch.load(self.model_path, map_location=cpu))
89+
self.model_run.to(device)
90+
self.logger.debug("Model loaded and moved to device.")
91+
92+
self.running_inference_console_write()
93+
94+
y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness)
95+
self.logger.debug("Inference completed.")
96+
if not self.is_vocal_split_model:
97+
self.cache_source((y_spec, v_spec))
98+
self.write_to_console(DONE, base_text="")
99+
100+
if self.is_secondary_model_activated and self.secondary_model:
101+
self.logger.debug("Processing secondary model...")
102+
self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(
103+
self.secondary_model, self.process_data, main_process_method=self.process_method, main_model_primary=self.primary_stem
104+
)
105+
106+
if not self.is_secondary_stem_only:
107+
primary_stem_path = os.path.join(self.export_path, f"{self.audio_file_base}_({self.primary_stem}).wav")
108+
self.logger.debug(f"Processing primary stem: {self.primary_stem}")
109+
if not isinstance(self.primary_source, np.ndarray):
110+
self.primary_source = self.spec_to_wav(y_spec).T
111+
self.logger.debug("Converting primary source spectrogram to waveform.")
112+
if not self.model_samplerate == 44100:
113+
self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
114+
self.logger.debug("Resampling primary source to 44100Hz.")
115+
116+
self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, 44100)
117+
self.logger.debug("Primary stem processed.")
118+
119+
if not self.is_primary_stem_only:
120+
secondary_stem_path = os.path.join(self.export_path, f"{self.audio_file_base}_({self.secondary_stem}).wav")
121+
self.logger.debug(f"Processing secondary stem: {self.secondary_stem}")
122+
if not isinstance(self.secondary_source, np.ndarray):
123+
self.secondary_source = self.spec_to_wav(v_spec).T
124+
self.logger.debug("Converting secondary source spectrogram to waveform.")
125+
if not self.model_samplerate == 44100:
126+
self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
127+
self.logger.debug("Resampling secondary source to 44100Hz.")
128+
129+
self.secondary_source_map = self.final_process(secondary_stem_path, self.secondary_source, self.secondary_source_secondary, self.secondary_stem, 44100)
130+
self.logger.debug("Secondary stem processed.")
131+
132+
clear_gpu_cache()
133+
self.logger.debug("GPU cache cleared.")
134+
secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
135+
136+
self.process_vocal_split_chain(secondary_sources)
137+
self.logger.debug("Vocal split chain processed.")
138+
139+
if self.is_secondary_model:
140+
self.logger.debug("Returning secondary sources...")
141+
return secondary_sources
142+
143+
def loading_mix(self):
144+
X_wave, X_spec_s = {}, {}
145+
146+
bands_n = len(self.mp.param["band"])
147+
148+
audio_file = spec_utils.write_array_to_mem(self.audio_file, subtype=self.wav_type_set)
149+
is_mp3 = audio_file.endswith(".mp3") if isinstance(audio_file, str) else False
150+
151+
for d in range(bands_n, 0, -1):
152+
bp = self.mp.param["band"][d]
153+
154+
if OPERATING_SYSTEM == "Darwin":
155+
wav_resolution = "polyphase" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else bp["res_type"]
156+
else:
157+
wav_resolution = bp["res_type"]
158+
159+
if d == bands_n: # high-end band
160+
X_wave[d], _ = librosa.load(audio_file, bp["sr"], False, dtype=np.float32, res_type=wav_resolution)
161+
X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.mp, band=d, is_v51_model=self.is_vr_51_model)
162+
163+
if not np.any(X_wave[d]) and is_mp3:
164+
X_wave[d] = rerun_mp3(audio_file, bp["sr"])
165+
166+
if X_wave[d].ndim == 1:
167+
X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
168+
else: # lower bands
169+
X_wave[d] = librosa.resample(X_wave[d + 1], self.mp.param["band"][d + 1]["sr"], bp["sr"], res_type=wav_resolution)
170+
X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.mp, band=d, is_v51_model=self.is_vr_51_model)
171+
172+
if d == bands_n and self.high_end_process != "none":
173+
self.input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (self.mp.param["pre_filter_stop"] - self.mp.param["pre_filter_start"])
174+
self.input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - self.input_high_end_h : bp["n_fft"] // 2, :]
175+
176+
X_spec = spec_utils.combine_spectrograms(X_spec_s, self.mp, is_v51_model=self.is_vr_51_model)
177+
178+
del X_wave, X_spec_s, audio_file
179+
180+
return X_spec
181+
182+
def inference_vr(self, X_spec, device, aggressiveness):
183+
def _execute(X_mag_pad, roi_size):
184+
X_dataset = []
185+
patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size
186+
total_iterations = patches // self.batch_size if not self.is_tta else (patches // self.batch_size) * 2
187+
for i in range(patches):
188+
start = i * roi_size
189+
X_mag_window = X_mag_pad[:, :, start : start + self.window_size]
190+
X_dataset.append(X_mag_window)
191+
192+
X_dataset = np.asarray(X_dataset)
193+
self.model_run.eval()
194+
with torch.no_grad():
195+
mask = []
196+
for i in range(0, patches, self.batch_size):
197+
self.progress_value += 1
198+
if self.progress_value >= total_iterations:
199+
self.progress_value = total_iterations
200+
self.set_progress_bar(0.1, 0.8 / total_iterations * self.progress_value)
201+
X_batch = X_dataset[i : i + self.batch_size]
202+
X_batch = torch.from_numpy(X_batch).to(device)
203+
pred = self.model_run.predict_mask(X_batch)
204+
if not pred.size()[3] > 0:
205+
raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR])
206+
pred = pred.detach().cpu().numpy()
207+
pred = np.concatenate(pred, axis=2)
208+
mask.append(pred)
209+
if len(mask) == 0:
210+
raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR])
211+
212+
mask = np.concatenate(mask, axis=2)
213+
return mask
214+
215+
def postprocess(mask, X_mag, X_phase):
216+
is_non_accom_stem = False
217+
for stem in NON_ACCOM_STEMS:
218+
if stem == self.primary_stem:
219+
is_non_accom_stem = True
220+
221+
mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness)
222+
223+
if self.is_post_process:
224+
mask = spec_utils.merge_artifacts(mask, thres=self.post_process_threshold)
225+
226+
y_spec = mask * X_mag * np.exp(1.0j * X_phase)
227+
v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)
228+
229+
return y_spec, v_spec
230+
231+
X_mag, X_phase = spec_utils.preprocess(X_spec)
232+
n_frame = X_mag.shape[2]
233+
pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)
234+
X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
235+
X_mag_pad /= X_mag_pad.max()
236+
mask = _execute(X_mag_pad, roi_size)
237+
238+
if self.is_tta:
239+
pad_l += roi_size // 2
240+
pad_r += roi_size // 2
241+
X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
242+
X_mag_pad /= X_mag_pad.max()
243+
mask_tta = _execute(X_mag_pad, roi_size)
244+
mask_tta = mask_tta[:, :, roi_size // 2 :]
245+
mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5
246+
else:
247+
mask = mask[:, :, :n_frame]
248+
249+
y_spec, v_spec = postprocess(mask, X_mag, X_phase)
250+
251+
return y_spec, v_spec
252+
253+
def spec_to_wav(self, spec):
254+
if self.high_end_process.startswith("mirroring") and isinstance(self.input_high_end, np.ndarray) and self.input_high_end_h:
255+
input_high_end_ = spec_utils.mirroring(self.high_end_process, spec, self.input_high_end, self.mp)
256+
wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, self.input_high_end_h, input_high_end_, is_v51_model=self.is_vr_51_model)
257+
else:
258+
wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, is_v51_model=self.is_vr_51_model)
259+
260+
return wav

0 commit comments

Comments
 (0)