
Commit a0bfc27

Refactored existing MDX separation into separate class with common class, began adding VR code
1 parent 64f1982 commit a0bfc27


46 files changed: +4349 -251 lines
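The "common class" mentioned in the commit message lives in audio_separator/separator/common_separator.py (it is imported in the MDX diff below) but its contents are not part of this excerpt. As orientation only, here is a minimal sketch of what such a CommonSeparator base might hold, inferred from the attributes the MDX code reads off self right after super().__init__(); the config keys and the way values are derived from model_data are assumptions, not the committed code:

class CommonSeparator:
    """Hypothetical sketch of the shared base class; not the code from this commit."""

    def __init__(self, config):
        # Plumbing and I/O settings shared by every architecture (keys assumed)
        self.logger = config.get("logger")
        self.write_audio = config.get("write_audio")
        self.torch_device = config.get("torch_device")
        self.onnx_execution_provider = config.get("onnx_execution_provider")

        self.model_name = config.get("model_name")
        self.model_path = config.get("model_path")
        self.model_data = config.get("model_data")

        self.output_format = config.get("output_format")
        self.output_single_stem = config.get("output_single_stem")
        self.denoise_enabled = config.get("denoise_enabled")
        self.sample_rate = config.get("sample_rate")
        self.primary_stem_path = config.get("primary_stem_path")
        self.secondary_stem_path = config.get("secondary_stem_path")

        # The MDX subclass logs these immediately after super().__init__, so the base
        # must populate them; deriving them from model_data here is an assumption.
        self.primary_stem_name = self.model_data["primary_stem"]
        self.secondary_stem_name = "Vocals" if self.primary_stem_name == "Instrumental" else "Instrumental"
        self.compensate = self.model_data["compensate"]
        self.dim_f = self.model_data["mdx_dim_f_set"]
        self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
        self.n_fft = self.model_data["mdx_n_fft_scale_set"]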
@@ -1 +1,3 @@
 from .mdx_separator import MDXSeparator
+from .vr_separator import VRSeparator
+

audio_separator/separator/architectures/mdx_separator.py

+37 -94
@@ -1,77 +1,32 @@
+"""Module for separating audio sources using MDX architecture models."""
+
 import os
 import torch
 import librosa
 import onnxruntime as ort
 import numpy as np
-from onnx2torch import convert
+import onnx2torch
 from audio_separator.separator import spec_utils
 from audio_separator.separator.stft import STFT
+from audio_separator.separator.common_separator import CommonSeparator


-class MDXSeparator:
+class MDXSeparator(CommonSeparator):
     """
-    MDXSeparator is responsible for separating audio sources using the MDX model.
+    MDXSeparator is responsible for separating audio sources using MDX models.
     It initializes with configuration parameters and prepares the model for separation tasks.
     """

-    def __init__(self, logger, write_audio, separator_params):
-        self.logger = logger
-        self.write_audio = write_audio
-        self.separator_params = separator_params
-
-        self.model_name = separator_params["model_name"]
-        self.model_data = separator_params["model_data"]
-        self.model_path = separator_params["model_path"]
-
-        self.primary_stem_path = separator_params["primary_stem_path"]
-        self.secondary_stem_path = separator_params["secondary_stem_path"]
-        self.output_format = separator_params["output_format"]
-        self.output_subtype = separator_params["output_subtype"]
-        self.normalization_threshold = separator_params["normalization_threshold"]
-        self.denoise_enabled = separator_params["denoise_enabled"]
-        self.output_single_stem = separator_params["output_single_stem"]
-        self.invert_using_spec = separator_params["invert_using_spec"]
-        self.sample_rate = separator_params["sample_rate"]
-        self.hop_length = separator_params["hop_length"]
-        self.segment_size = separator_params["segment_size"]
-        self.overlap = separator_params["overlap"]
-        self.batch_size = separator_params["batch_size"]
-        self.device = separator_params["device"]
-        self.onnx_execution_provider = separator_params["onnx_execution_provider"]
-
-        # Initializing model parameters
-        self.compensate, self.dim_f, self.dim_t, self.n_fft, self.model_primary_stem = (
-            self.model_data["compensate"],
-            self.model_data["mdx_dim_f_set"],
-            2 ** self.model_data["mdx_dim_t_set"],
-            self.model_data["mdx_n_fft_scale_set"],
-            self.model_data["primary_stem"],
-        )
-        self.model_secondary_stem = "Vocals" if self.model_primary_stem == "Instrumental" else "Instrumental"
-
-        # In UVR, these variables are set but either aren't useful or are better handled in audio-separator.
-        # Leaving these comments explaining to help myself or future developers understand why these aren't in audio-separator.
-
-        # "chunks" is not actually used for anything in UVR...
-        # self.chunks = 0
-
-        # "adjust" is hard-coded to 1 in UVR, and only used as a multiplier in run_model, so it does nothing.
-        # self.adjust = 1
-
-        # "hop" is hard-coded to 1024 in UVR. We have a "hop_length" parameter instead
-        # self.hop = 1024
-
-        # "margin" maps to sample rate and is set from the GUI in UVR (default: 44100). We have a "sample_rate" parameter instead.
-        # self.margin = 44100
-
-        # "dim_c" is hard-coded to 4 in UVR, seems to be a parameter for the number of channels, and is only used for checkpoint models.
-        # We haven't implemented support for the checkpoint models here, so we're not using it.
-        # self.dim_c = 4
-
-        self.logger.debug(f"Model params: primary_stem={self.model_primary_stem}, secondary_stem={self.model_secondary_stem}")
-        self.logger.debug(
-            f"Model params: batch_size={self.batch_size}, compensate={self.compensate}, segment_size={self.segment_size}, dim_f={self.dim_f}, dim_t={self.dim_t}"
-        )
+    def __init__(self, common_config, arch_config):
+        super().__init__(config=common_config)
+
+        self.hop_length = arch_config.get("hop_length")
+        self.segment_size = arch_config.get("segment_size")
+        self.overlap = arch_config.get("overlap")
+        self.batch_size = arch_config.get("batch_size")
+
+        self.logger.debug(f"Model params: primary_stem={self.primary_stem_name}, secondary_stem={self.secondary_stem_name}")
+        self.logger.debug(f"Model params: batch_size={self.batch_size}, compensate={self.compensate}, segment_size={self.segment_size}, dim_f={self.dim_f}, dim_t={self.dim_t}")
         self.logger.debug(f"Model params: n_fft={self.n_fft}, hop={self.hop_length}")

         # Loading the model for inference
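For orientation, a hedged caller-side sketch of the new split-config constructor. Every key and value below is illustrative, inferred from the attributes read in __init__ above and in the assumed CommonSeparator sketch earlier; this is not the project's actual wiring.

import logging

import torch

from audio_separator.separator.architectures import MDXSeparator

logger = logging.getLogger(__name__)

# Keys are assumptions; values are placeholders, not real model metadata.
common_config = {
    "logger": logger,
    "write_audio": lambda *args, **kwargs: None,  # stand-in for the real writer callback
    "torch_device": torch.device("cpu"),
    "onnx_execution_provider": ["CPUExecutionProvider"],
    "model_name": "example_mdx_model",
    "model_path": "/path/to/example_mdx_model.onnx",
    "model_data": {
        "compensate": 1.0,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Instrumental",
    },
    "output_format": "WAV",
    "output_single_stem": None,
    "denoise_enabled": False,
    "sample_rate": 44100,
    "primary_stem_path": None,
    "secondary_stem_path": None,
}

# Architecture-specific knobs now travel in their own dict.
arch_config = {"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1}

separator = MDXSeparator(common_config=common_config, arch_config=arch_config)
output_files = separator.separate("example_song.wav")  # assumed to return the written stem paths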
@@ -81,8 +36,8 @@ def __init__(self, logger, write_audio, separator_params):
             self.model_run = lambda spek: ort_.run(None, {"input": spek.cpu().numpy()})[0]
             self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.")
         else:
-            self.model_run = convert(self.model_path)
-            self.model_run.to(self.device).eval()
+            self.model_run = onnx2torch.convert(self.model_path)
+            self.model_run.to(self.torch_device).eval()
             self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")

         self.n_bins = None
@@ -149,29 +104,21 @@ def separate(self, audio_file_path):
         self.secondary_source = mix.T - source.T

         # Save and process the secondary stem if needed
-        if not self.output_single_stem or self.output_single_stem.lower() == self.model_secondary_stem.lower():
-            self.logger.info(f"Saving {self.model_secondary_stem} stem...")
+        if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
+            self.logger.info(f"Saving {self.secondary_stem_name} stem...")
             if not self.secondary_stem_path:
-                self.secondary_stem_path = os.path.join(
-                    f"{self.audio_file_base}_({self.model_secondary_stem})_{self.model_name}.{self.output_format.lower()}"
-                )
-            self.secondary_source_map = self.final_process(
-                self.secondary_stem_path, self.secondary_source, self.model_secondary_stem, self.sample_rate
-            )
+                self.secondary_stem_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
+            self.secondary_source_map = self.final_process(self.secondary_stem_path, self.secondary_source, self.secondary_stem_name)
             output_files.append(self.secondary_stem_path)

         # Save and process the primary stem if needed
-        if not self.output_single_stem or self.output_single_stem.lower() == self.model_primary_stem.lower():
-            self.logger.info(f"Saving {self.model_primary_stem} stem...")
+        if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
+            self.logger.info(f"Saving {self.primary_stem_name} stem...")
             if not self.primary_stem_path:
-                self.primary_stem_path = os.path.join(
-                    f"{self.audio_file_base}_({self.model_primary_stem})_{self.model_name}.{self.output_format.lower()}"
-                )
+                self.primary_stem_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
             if not isinstance(self.primary_source, np.ndarray):
                 self.primary_source = source.T
-            self.primary_source_map = self.final_process(
-                self.primary_stem_path, self.primary_source, self.model_primary_stem, self.sample_rate
-            )
+            self.primary_source_map = self.final_process(self.primary_stem_path, self.primary_source, self.primary_stem_name)
             output_files.append(self.primary_stem_path)

         # TODO: In UVR, this is where the vocal split chain gets processed - see process_vocal_split_chain()
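The stem naming pattern is unchanged by the refactor; only the attribute names differ. A worked example with made-up values:

import os

# Illustrative values only; the f-string pattern is the one used in separate() above.
audio_file_base, stem_name, model_name, output_format = "song", "Vocals", "example_model", "FLAC"
stem_path = os.path.join(f"{audio_file_base}_({stem_name})_{model_name}.{output_format.lower()}")
# -> "song_(Vocals)_example_model.flac"; os.path.join with a single argument returns it unchanged.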
@@ -198,7 +145,7 @@ def initialize_model_settings(self):
         # gen_size is the chunk size minus twice the trim size
         self.gen_size = self.chunk_size - 2 * self.trim

-        self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.device)
+        self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)

         self.logger.debug(f"Model input params: n_fft={self.n_fft} hop_length={self.hop_length} dim_f={self.dim_f}")
         self.logger.debug(f"Model settings: n_bins={self.n_bins}, trim={self.trim}, chunk_size={self.chunk_size}, gen_size={self.gen_size}")
@@ -253,7 +200,7 @@ def initialize_mix(self, mix, is_ckpt=False):
             i += self.gen_size

         # Convert the list of wave chunks into a tensor for processing on the specified device
-        mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
+        mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
         self.logger.debug(f"Converted mix_waves to tensor. Tensor shape: {mix_waves_tensor.shape}")

         return mix_waves_tensor, pad
@@ -334,7 +281,7 @@ def demix(self, mix, is_match_mix=False):
                 mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)

             # Converts the chunk to a tensor for processing.
-            mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.device)
+            mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device)
             # Splits the chunk into smaller batches if necessary.
             mix_waves = mix_part.split(self.batch_size)
             total_batches = len(mix_waves)
@@ -376,6 +323,7 @@ def demix(self, mix, is_match_mix=False):

         # Compensates the source if not matching the mix.
         if not is_match_mix:
+            # TODO: Investigate whether fixing this bug actually does anything!
             source * self.compensate
             self.logger.debug("Match mix mode; compensate multiplier applied.")

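The TODO added above points at a real no-op: the bare expression `source * self.compensate` computes a product and discards it, and the debug message says "Match mix mode" even though this is the non-match branch. The fix the TODO presumably has in mind (not made in this commit) would be:

        if not is_match_mix:
            # Keep the result of the multiplication so the compensation factor actually applies.
            source = source * self.compensate
            self.logger.debug("Compensate multiplier applied.")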

@@ -391,7 +339,7 @@ def run_model(self, mix, is_match_mix=False):
         """
         # Applying the STFT to the mix. The mix is moved to the specified device (e.g., GPU) before processing.
         # self.logger.debug(f"Running STFT on the mix. Mix shape before STFT: {mix.shape}")
-        spek = self.stft(mix.to(self.device))
+        spek = self.stft(mix.to(self.torch_device))
         self.logger.debug(f"STFT applied on mix. Spectrum shape: {spek.shape}")

         # Zeroing out the first 3 bins of the spectrum. This is often done to reduce low-frequency noise.
@@ -406,14 +354,18 @@
         else:
             # If denoising is enabled, the model is run on both the negative and positive spectrums.
             if self.denoise_enabled:
-                spec_pred = -self.model_run(-spek) * 0.5 + self.model_run(spek) * 0.5
+                # Assuming spek is a tensor and self.model_run can process it directly
+                spec_pred_neg = self.model_run(-spek)  # Ensure this line correctly negates spek and runs the model
+                spec_pred_pos = self.model_run(spek)
+                # Ensure both spec_pred_neg and spec_pred_pos are tensors before applying operations
+                spec_pred = (-spec_pred_neg * 0.5) + (spec_pred_pos * 0.5)  # [invalid-unary-operand-type]
                 self.logger.debug("Model run on both negative and positive spectrums for denoising.")
             else:
                 spec_pred = self.model_run(spek)
                 self.logger.debug("Model run on the spectrum without denoising.")

         # Applying the inverse STFT to convert the spectrum back to the time domain.
-        result = self.stft.inverse(torch.tensor(spec_pred).to(self.device)).cpu().detach().numpy()
+        result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
         self.logger.debug(f"Inverse STFT applied. Returning result with shape: {result.shape}")

         return result
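The denoise change only splits the original one-liner into named intermediates; numerically it is still the same estimate. For a model f and spectrogram x it computes 0.5 * (f(x) - f(-x)), so any part of the model output that does not flip sign with the input (for example a constant noise floor) cancels out of the average. A minimal restatement, assuming model_run accepts and returns array-like tensors:

def denoised_prediction(model_run, spek):
    # Equivalent to: -model_run(-spek) * 0.5 + model_run(spek) * 0.5
    return 0.5 * (model_run(spek) - model_run(-spek))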
@@ -455,12 +407,3 @@ def prepare_mix(self, mix):
         # Final log indicating successful preparation of the mix
         self.logger.debug("Mix preparation completed.")
         return mix
-
-    def final_process(self, stem_path, source, stem_name, sample_rate):
-        """
-        Finalizes the processing of a stem by writing the audio to a file and returning the processed source.
-        """
-        self.logger.debug(f"Finalizing {stem_name} stem processing and writing audio...")
-        self.write_audio(stem_path, source, sample_rate, stem_name=stem_name)
-
-        return {stem_name: source}
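final_process is removed from MDXSeparator here, and the calls earlier in the diff drop the sample_rate argument, so the method presumably now lives on the common class and reads self.sample_rate itself. A hedged sketch of that assumed destination (not shown in this excerpt); whether write_audio is still a callback held on self is also an assumption carried over from the old MDXSeparator.__init__:

    def final_process(self, stem_path, source, stem_name):
        """Write the stem audio to disk and return the processed source keyed by stem name."""
        self.logger.debug(f"Finalizing {stem_name} stem processing and writing audio...")
        self.write_audio(stem_path, source, self.sample_rate, stem_name=stem_name)
        return {stem_name: source}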
@@ -0,0 +1,54 @@
+"""Module for separating audio sources using VR architecture models."""
+
+import os
+import torch
+import librosa
+import onnxruntime as ort
+import numpy as np
+import onnx2torch
+from audio_separator.separator import spec_utils
+from audio_separator.separator.stft import STFT
+from audio_separator.separator.common_separator import CommonSeparator
+
+
+class VRSeparator(CommonSeparator):
+    """
+    VRSeparator is responsible for separating audio sources using VR models.
+    It initializes with configuration parameters and prepares the model for separation tasks.
+    """
+
+    def __init__(self, common_config, arch_config):
+        super().__init__(config=common_config)
+
+        self.hop_length = arch_config.get("hop_length")
+        self.segment_size = arch_config.get("segment_size")
+        self.overlap = arch_config.get("overlap")
+        self.batch_size = arch_config.get("batch_size")
+
+        self.logger.debug(f"Model params: primary_stem={self.primary_stem_name}, secondary_stem={self.secondary_stem_name}")
+        self.logger.debug(f"Model params: batch_size={self.batch_size}, compensate={self.compensate}, segment_size={self.segment_size}, dim_f={self.dim_f}, dim_t={self.dim_t}")
+        self.logger.debug(f"Model params: n_fft={self.n_fft}, hop={self.hop_length}")
+
+        # Loading the model for inference
+        self.logger.debug("Loading ONNX model for inference...")
+        if self.segment_size == self.dim_t:
+            ort_ = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider)
+            self.model_run = lambda spek: ort_.run(None, {"input": spek.cpu().numpy()})[0]
+            self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.")
+        else:
+            self.model_run = onnx2torch.convert(self.model_path)
+            self.model_run.to(self.torch_device).eval()
+            self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")
+
+        self.n_bins = None
+        self.trim = None
+        self.chunk_size = None
+        self.gen_size = None
+        self.stft = None
+
+        self.primary_source = None
+        self.secondary_source = None
+        self.audio_file_path = None
+        self.audio_file_base = None
+        self.secondary_source_map = None
+        self.primary_source_map = None
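With both separator classes now re-exported as shown in the first hunk of this view, the natural next step is for the top-level Separator to pick the class by model type and pass the two config dicts through. A hypothetical dispatch sketch, not part of the visible diff; the import path assumes both classes are exposed by the architectures package:

from audio_separator.separator.architectures import MDXSeparator, VRSeparator

# Hypothetical registry; the real selection logic in this commit is not shown here.
ARCHITECTURE_CLASSES = {
    "MDX": MDXSeparator,
    "VR": VRSeparator,
}

def build_separator(model_arch, common_config, arch_config):
    try:
        separator_class = ARCHITECTURE_CLASSES[model_arch]
    except KeyError as err:
        raise ValueError(f"Unsupported architecture: {model_arch}") from err
    return separator_class(common_config, arch_config)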

0 commit comments