+ """Module for separating audio sources using MDX architecture models."""
+
import os
import torch
import librosa
import onnxruntime as ort
import numpy as np
- from onnx2torch import convert
+ import onnx2torch
from audio_separator.separator import spec_utils
from audio_separator.separator.stft import STFT
+ from audio_separator.separator.common_separator import CommonSeparator


- class MDXSeparator:
+ class MDXSeparator(CommonSeparator):
    """
-     MDXSeparator is responsible for separating audio sources using the MDX model.
+     MDXSeparator is responsible for separating audio sources using MDX models.
    It initializes with configuration parameters and prepares the model for separation tasks.
    """
-     def __init__(self, logger, write_audio, separator_params):
-         self.logger = logger
-         self.write_audio = write_audio
-         self.separator_params = separator_params
-
-         self.model_name = separator_params["model_name"]
-         self.model_data = separator_params["model_data"]
-         self.model_path = separator_params["model_path"]
-
-         self.primary_stem_path = separator_params["primary_stem_path"]
-         self.secondary_stem_path = separator_params["secondary_stem_path"]
-         self.output_format = separator_params["output_format"]
-         self.output_subtype = separator_params["output_subtype"]
-         self.normalization_threshold = separator_params["normalization_threshold"]
-         self.denoise_enabled = separator_params["denoise_enabled"]
-         self.output_single_stem = separator_params["output_single_stem"]
-         self.invert_using_spec = separator_params["invert_using_spec"]
-         self.sample_rate = separator_params["sample_rate"]
-         self.hop_length = separator_params["hop_length"]
-         self.segment_size = separator_params["segment_size"]
-         self.overlap = separator_params["overlap"]
-         self.batch_size = separator_params["batch_size"]
-         self.device = separator_params["device"]
-         self.onnx_execution_provider = separator_params["onnx_execution_provider"]
-
-         # Initializing model parameters
-         self.compensate, self.dim_f, self.dim_t, self.n_fft, self.model_primary_stem = (
-             self.model_data["compensate"],
-             self.model_data["mdx_dim_f_set"],
-             2 ** self.model_data["mdx_dim_t_set"],
-             self.model_data["mdx_n_fft_scale_set"],
-             self.model_data["primary_stem"],
-         )
-         self.model_secondary_stem = "Vocals" if self.model_primary_stem == "Instrumental" else "Instrumental"
-
-         # In UVR, these variables are set but either aren't useful or are better handled in audio-separator.
-         # Leaving these comments explaining to help myself or future developers understand why these aren't in audio-separator.
-
-         # "chunks" is not actually used for anything in UVR...
-         # self.chunks = 0
-
-         # "adjust" is hard-coded to 1 in UVR, and only used as a multiplier in run_model, so it does nothing.
-         # self.adjust = 1
-
-         # "hop" is hard-coded to 1024 in UVR. We have a "hop_length" parameter instead.
-         # self.hop = 1024
-
-         # "margin" maps to sample rate and is set from the GUI in UVR (default: 44100). We have a "sample_rate" parameter instead.
-         # self.margin = 44100
-
-         # "dim_c" is hard-coded to 4 in UVR, seems to be a parameter for the number of channels, and is only used for checkpoint models.
-         # We haven't implemented support for the checkpoint models here, so we're not using it.
-         # self.dim_c = 4
-
-         self.logger.debug(f"Model params: primary_stem={self.model_primary_stem}, secondary_stem={self.model_secondary_stem}")
-         self.logger.debug(
-             f"Model params: batch_size={self.batch_size}, compensate={self.compensate}, segment_size={self.segment_size}, dim_f={self.dim_f}, dim_t={self.dim_t}"
-         )
+     def __init__(self, common_config, arch_config):
+         super().__init__(config=common_config)
+
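+         # Shared configuration (stem names, model data such as compensate/dim_f/dim_t/n_fft,
+         # output settings, torch_device) is now initialized by CommonSeparator; only the
+         # MDX-architecture-specific parameters below are read from arch_config.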
+         self.hop_length = arch_config.get("hop_length")
+         self.segment_size = arch_config.get("segment_size")
+         self.overlap = arch_config.get("overlap")
+         self.batch_size = arch_config.get("batch_size")
+
+         self.logger.debug(f"Model params: primary_stem={self.primary_stem_name}, secondary_stem={self.secondary_stem_name}")
+         self.logger.debug(f"Model params: batch_size={self.batch_size}, compensate={self.compensate}, segment_size={self.segment_size}, dim_f={self.dim_f}, dim_t={self.dim_t}")
        self.logger.debug(f"Model params: n_fft={self.n_fft}, hop={self.hop_length}")

        # Loading the model for inference
@@ -81,8 +36,8 @@ def __init__(self, logger, write_audio, separator_params):
            self.model_run = lambda spek: ort_.run(None, {"input": spek.cpu().numpy()})[0]
            self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.")
        else:
-             self.model_run = convert(self.model_path)
-             self.model_run.to(self.device).eval()
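+             # onnx2torch.convert builds a torch.nn.Module from the ONNX graph; .eval() switches it
+             # to inference mode on the selected torch device.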
+             self.model_run = onnx2torch.convert(self.model_path)
+             self.model_run.to(self.torch_device).eval()
            self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")

        self.n_bins = None
@@ -149,29 +104,21 @@ def separate(self, audio_file_path):
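+         # The secondary stem is derived by subtracting the separated source from the original mix
+         # in the time domain.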
        self.secondary_source = mix.T - source.T

        # Save and process the secondary stem if needed
-         if not self.output_single_stem or self.output_single_stem.lower() == self.model_secondary_stem.lower():
-             self.logger.info(f"Saving {self.model_secondary_stem} stem...")
+         if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
+             self.logger.info(f"Saving {self.secondary_stem_name} stem...")
            if not self.secondary_stem_path:
-                 self.secondary_stem_path = os.path.join(
-                     f"{self.audio_file_base}_({self.model_secondary_stem})_{self.model_name}.{self.output_format.lower()}"
-                 )
-             self.secondary_source_map = self.final_process(
-                 self.secondary_stem_path, self.secondary_source, self.model_secondary_stem, self.sample_rate
-             )
+                 self.secondary_stem_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
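+             # final_process is now inherited from CommonSeparator and no longer takes sample_rate.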
+             self.secondary_source_map = self.final_process(self.secondary_stem_path, self.secondary_source, self.secondary_stem_name)
            output_files.append(self.secondary_stem_path)

        # Save and process the primary stem if needed
-         if not self.output_single_stem or self.output_single_stem.lower() == self.model_primary_stem.lower():
-             self.logger.info(f"Saving {self.model_primary_stem} stem...")
+         if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
+             self.logger.info(f"Saving {self.primary_stem_name} stem...")
            if not self.primary_stem_path:
-                 self.primary_stem_path = os.path.join(
-                     f"{self.audio_file_base}_({self.model_primary_stem})_{self.model_name}.{self.output_format.lower()}"
-                 )
+                 self.primary_stem_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
            if not isinstance(self.primary_source, np.ndarray):
                self.primary_source = source.T
-             self.primary_source_map = self.final_process(
-                 self.primary_stem_path, self.primary_source, self.model_primary_stem, self.sample_rate
-             )
+             self.primary_source_map = self.final_process(self.primary_stem_path, self.primary_source, self.primary_stem_name)
            output_files.append(self.primary_stem_path)

        # TODO: In UVR, this is where the vocal split chain gets processed - see process_vocal_split_chain()
@@ -198,7 +145,7 @@ def initialize_model_settings(self):
        # gen_size is the chunk size minus twice the trim size
        self.gen_size = self.chunk_size - 2 * self.trim

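+         # This STFT instance provides the forward transform applied in run_model() and the inverse
+         # transform used there to reconstruct the time-domain result.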
-         self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.device)
+         self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)

        self.logger.debug(f"Model input params: n_fft={self.n_fft} hop_length={self.hop_length} dim_f={self.dim_f}")
        self.logger.debug(f"Model settings: n_bins={self.n_bins}, trim={self.trim}, chunk_size={self.chunk_size}, gen_size={self.gen_size}")
@@ -253,7 +200,7 @@ def initialize_mix(self, mix, is_ckpt=False):
            i += self.gen_size

        # Convert the list of wave chunks into a tensor for processing on the specified device
-         mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
+         mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
        self.logger.debug(f"Converted mix_waves to tensor. Tensor shape: {mix_waves_tensor.shape}")

        return mix_waves_tensor, pad
@@ -334,7 +281,7 @@ def demix(self, mix, is_match_mix=False):
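+             # Zero-pads the stereo chunk along the time axis so it reaches the expected chunk
+             # length before batching.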
            mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)

            # Converts the chunk to a tensor for processing.
-             mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.device)
+             mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device)
            # Splits the chunk into smaller batches if necessary.
            mix_waves = mix_part.split(self.batch_size)
            total_batches = len(mix_waves)
@@ -376,6 +323,7 @@ def demix(self, mix, is_match_mix=False):

        # Compensates the source if not matching the mix.
        if not is_match_mix:
+             # TODO: Investigate whether fixing this bug actually does anything!
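+             # (As written, the statement below computes the product and discards the result;
+             # the intended fix is presumably `source *= self.compensate`.)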
            source * self.compensate
            self.logger.debug("Match mix mode; compensate multiplier applied.")

@@ -391,7 +339,7 @@ def run_model(self, mix, is_match_mix=False):
        """
        # Applying the STFT to the mix. The mix is moved to the specified device (e.g., GPU) before processing.
        # self.logger.debug(f"Running STFT on the mix. Mix shape before STFT: {mix.shape}")
-         spek = self.stft(mix.to(self.device))
+         spek = self.stft(mix.to(self.torch_device))
        self.logger.debug(f"STFT applied on mix. Spectrum shape: {spek.shape}")

        # Zeroing out the first 3 bins of the spectrum. This is often done to reduce low-frequency noise.
@@ -406,14 +354,18 @@ def run_model(self, mix, is_match_mix=False):
        else:
            # If denoising is enabled, the model is run on both the negative and positive spectrums.
            if self.denoise_enabled:
-                 spec_pred = -self.model_run(-spek) * 0.5 + self.model_run(spek) * 0.5
+                 # Run the model on the negated and the original spectrum separately, then average
+                 # the two predictions (negating the first back). Any component of the model output
+                 # that does not flip sign with the input cancels out, acting as a simple denoiser.
+                 spec_pred_neg = self.model_run(-spek)
+                 spec_pred_pos = self.model_run(spek)
+                 spec_pred = (-spec_pred_neg * 0.5) + (spec_pred_pos * 0.5)
                self.logger.debug("Model run on both negative and positive spectrums for denoising.")
            else:
                spec_pred = self.model_run(spek)
                self.logger.debug("Model run on the spectrum without denoising.")

        # Applying the inverse STFT to convert the spectrum back to the time domain.
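+         # torch.tensor(spec_pred) accepts either output type here: the ONNX runtime path returns
+         # a numpy array, while the converted-PyTorch path returns a tensor.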
-         result = self.stft.inverse(torch.tensor(spec_pred).to(self.device)).cpu().detach().numpy()
+         result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
        self.logger.debug(f"Inverse STFT applied. Returning result with shape: {result.shape}")

        return result
@@ -455,12 +407,3 @@ def prepare_mix(self, mix):
        # Final log indicating successful preparation of the mix
        self.logger.debug("Mix preparation completed.")
        return mix
-
-     def final_process(self, stem_path, source, stem_name, sample_rate):
-         """
-         Finalizes the processing of a stem by writing the audio to a file and returning the processed source.
-         """
-         self.logger.debug(f"Finalizing {stem_name} stem processing and writing audio...")
-         self.write_audio(stem_path, source, sample_rate, stem_name=stem_name)
-
-         return {stem_name: source}