Skip to content

Commit ae7e422

Browse files
committed
Updated model data fetching to support VR models, added method to list available models, solidified properties in common class vs. arch-specific
1 parent e7e8d45 commit ae7e422

File tree

6 files changed

+306
-214
lines changed

6 files changed

+306
-214
lines changed

README.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,10 @@ separator = Separator()
168168
separator.load_model()
169169

170170
# Perform the separation on specific audio files without reloading the model
171-
primary_stem_path, secondary_stem_path = separator.separate('audio1.wav')
171+
primary_stem_output_path, secondary_stem_output_path = separator.separate('audio1.wav')
172172

173-
print(f'Primary stem saved at {primary_stem_path}')
174-
print(f'Secondary stem saved at {secondary_stem_path}')
173+
print(f'Primary stem saved at {primary_stem_output_path}')
174+
print(f'Secondary stem saved at {secondary_stem_output_path}')
175175
```
176176
177177
#### Batch processing, or processing with multiple models

audio_separator/separator/architectures/mdx_separator.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ def __init__(self, common_config, arch_config):
2626
self.overlap = arch_config.get("overlap")
2727
self.batch_size = arch_config.get("batch_size")
2828

29+
# Initializing model parameters
30+
self.compensate = self.model_data["compensate"]
31+
self.dim_f = self.model_data["mdx_dim_f_set"]
32+
self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
33+
self.n_fft = self.model_data["mdx_n_fft_scale_set"]
34+
35+
self.config_yaml = self.model_data.get("config_yaml", None)
36+
2937
self.logger.debug(f"Model params: primary_stem={self.primary_stem_name}, secondary_stem={self.secondary_stem_name}")
3038
self.logger.debug(f"Model params: batch_size={self.batch_size}, compensate={self.compensate}, segment_size={self.segment_size}, dim_f={self.dim_f}, dim_t={self.dim_t}")
3139
self.logger.debug(f"Model params: n_fft={self.n_fft}, hop={self.hop_length}")
@@ -107,20 +115,20 @@ def separate(self, audio_file_path):
107115
# Save and process the secondary stem if needed
108116
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
109117
self.logger.info(f"Saving {self.secondary_stem_name} stem...")
110-
if not self.secondary_stem_path:
111-
self.secondary_stem_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
112-
self.secondary_source_map = self.final_process(self.secondary_stem_path, self.secondary_source, self.secondary_stem_name)
113-
output_files.append(self.secondary_stem_path)
118+
if not self.secondary_stem_output_path:
119+
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
120+
self.secondary_source_map = self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
121+
output_files.append(self.secondary_stem_output_path)
114122

115123
# Save and process the primary stem if needed
116124
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
117125
self.logger.info(f"Saving {self.primary_stem_name} stem...")
118-
if not self.primary_stem_path:
119-
self.primary_stem_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
126+
if not self.primary_stem_output_path:
127+
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
120128
if not isinstance(self.primary_source, np.ndarray):
121129
self.primary_source = source.T
122-
self.primary_source_map = self.final_process(self.primary_stem_path, self.primary_source, self.primary_stem_name)
123-
output_files.append(self.primary_stem_path)
130+
self.primary_source_map = self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
131+
output_files.append(self.primary_stem_output_path)
124132

125133
# TODO: In UVR, this is where the vocal split chain gets processed - see process_vocal_split_chain()
126134

audio_separator/separator/architectures/vr_separator.py

+46-63
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""Module for separating audio sources using VR architecture models."""
22

33
import os
4+
import sys
5+
import math
6+
47
import torch
58
import librosa
6-
import onnxruntime as ort
79
import numpy as np
8-
import onnx2torch
10+
911
from audio_separator.separator import spec_utils
10-
from audio_separator.separator.stft import STFT
1112
from audio_separator.separator.common_separator import CommonSeparator
1213

1314

@@ -20,31 +21,10 @@ class VRSeparator(CommonSeparator):
2021
def __init__(self, common_config, arch_config):
2122
super().__init__(config=common_config)
2223

23-
self.hop_length = arch_config.get("hop_length")
24-
self.segment_size = arch_config.get("segment_size")
25-
self.overlap = arch_config.get("overlap")
26-
self.batch_size = arch_config.get("batch_size")
24+
self.logger.debug(f"Model data: ", self.model_data)
2725

2826
self.logger.debug(f"Model params: primary_stem={self.primary_stem_name}, secondary_stem={self.secondary_stem_name}")
29-
self.logger.debug(f"Model params: batch_size={self.batch_size}, compensate={self.compensate}, segment_size={self.segment_size}, dim_f={self.dim_f}, dim_t={self.dim_t}")
30-
self.logger.debug(f"Model params: n_fft={self.n_fft}, hop={self.hop_length}")
31-
32-
# Loading the model for inference
33-
self.logger.debug("Loading ONNX model for inference...")
34-
if self.segment_size == self.dim_t:
35-
ort_ = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider)
36-
self.model_run = lambda spek: ort_.run(None, {"input": spek.cpu().numpy()})[0]
37-
self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.")
38-
else:
39-
self.model_run = onnx2torch.convert(self.model_path)
40-
self.model_run.to(self.torch_device).eval()
41-
self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")
4227

43-
self.n_bins = None
44-
self.trim = None
45-
self.chunk_size = None
46-
self.gen_size = None
47-
self.stft = None
4828

4929
self.primary_source = None
5030
self.secondary_source = None
@@ -53,49 +33,52 @@ def __init__(self, common_config, arch_config):
5333
self.secondary_source_map = None
5434
self.primary_source_map = None
5535

36+
self.is_vr_51_model = model_data.is_vr_51_model
37+
38+
def separate(self, audio_file_path):
39+
"""
40+
Separates the audio file into primary and secondary sources based on the model's configuration.
41+
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
5642
43+
Args:
44+
audio_file_path (str): The path to the audio file to be processed.
5745
46+
Returns:
47+
list: A list of paths to the output files generated by the separation process.
48+
"""
49+
self.primary_source = None
50+
self.secondary_source = None
5851

52+
self.audio_file_path = audio_file_path
53+
self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
5954

55+
self.logger.debug("Starting inference...")
6056

61-
def seperate(self):
62-
self.logger.debug("Starting separation process in SeperateVR...")
63-
if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, tuple):
64-
self.logger.debug("Using cached primary sources...")
65-
y_spec, v_spec = self.primary_sources
66-
self.load_cached_sources()
57+
nn_arch_sizes = [31191, 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227] # default
58+
vr_5_1_models = [56817, 218409]
59+
model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
60+
nn_arch_size = min(nn_arch_sizes, key=lambda x: abs(x - model_size))
61+
self.logger.debug(f"Model size determined: {model_size}, NN architecture size: {nn_arch_size}")
62+
63+
if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
64+
self.logger.debug("Using CascadedNet for VR 5.1 model...")
65+
self.model_run = nets_new.CascadedNet(self.mp.param["bins"] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
66+
self.is_vr_51_model = True
6767
else:
68-
self.logger.debug("Starting inference...")
69-
self.start_inference_console_write()
70-
71-
device = self.device
72-
self.logger.debug(f"Device set to: {device}")
73-
74-
nn_arch_sizes = [31191, 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227] # default
75-
vr_5_1_models = [56817, 218409]
76-
model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
77-
nn_arch_size = min(nn_arch_sizes, key=lambda x: abs(x - model_size))
78-
self.logger.debug(f"Model size determined: {model_size}, NN architecture size: {nn_arch_size}")
79-
80-
if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
81-
self.logger.debug("Using CascadedNet for VR 5.1 model...")
82-
self.model_run = nets_new.CascadedNet(self.mp.param["bins"] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
83-
self.is_vr_51_model = True
84-
else:
85-
self.logger.debug("Determining model capacity...")
86-
self.model_run = nets.determine_model_capacity(self.mp.param["bins"] * 2, nn_arch_size)
68+
self.logger.debug("Determining model capacity...")
69+
self.model_run = nets.determine_model_capacity(self.mp.param["bins"] * 2, nn_arch_size)
8770

88-
self.model_run.load_state_dict(torch.load(self.model_path, map_location=cpu))
89-
self.model_run.to(device)
90-
self.logger.debug("Model loaded and moved to device.")
71+
self.model_run.load_state_dict(torch.load(self.model_path, map_location=cpu))
72+
self.model_run.to(device)
73+
self.logger.debug("Model loaded and moved to device.")
9174

92-
self.running_inference_console_write()
75+
self.running_inference_console_write()
9376

94-
y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness)
95-
self.logger.debug("Inference completed.")
96-
if not self.is_vocal_split_model:
97-
self.cache_source((y_spec, v_spec))
98-
self.write_to_console(DONE, base_text="")
77+
y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness)
78+
self.logger.debug("Inference completed.")
79+
if not self.is_vocal_split_model:
80+
self.cache_source((y_spec, v_spec))
81+
self.write_to_console(DONE, base_text="")
9982

10083
if self.is_secondary_model_activated and self.secondary_model:
10184
self.logger.debug("Processing secondary model...")
@@ -104,7 +87,7 @@ def seperate(self):
10487
)
10588

10689
if not self.is_secondary_stem_only:
107-
primary_stem_path = os.path.join(self.export_path, f"{self.audio_file_base}_({self.primary_stem}).wav")
90+
primary_stem_output_path = os.path.join(self.export_path, f"{self.audio_file_base}_({self.primary_stem}).wav")
10891
self.logger.debug(f"Processing primary stem: {self.primary_stem}")
10992
if not isinstance(self.primary_source, np.ndarray):
11093
self.primary_source = self.spec_to_wav(y_spec).T
@@ -113,11 +96,11 @@ def seperate(self):
11396
self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
11497
self.logger.debug("Resampling primary source to 44100Hz.")
11598

116-
self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, 44100)
99+
self.primary_source_map = self.final_process(primary_stem_output_path, self.primary_source, self.secondary_source_primary, self.primary_stem, 44100)
117100
self.logger.debug("Primary stem processed.")
118101

119102
if not self.is_primary_stem_only:
120-
secondary_stem_path = os.path.join(self.export_path, f"{self.audio_file_base}_({self.secondary_stem}).wav")
103+
secondary_stem_output_path = os.path.join(self.export_path, f"{self.audio_file_base}_({self.secondary_stem}).wav")
121104
self.logger.debug(f"Processing secondary stem: {self.secondary_stem}")
122105
if not isinstance(self.secondary_source, np.ndarray):
123106
self.secondary_source = self.spec_to_wav(v_spec).T
@@ -126,7 +109,7 @@ def seperate(self):
126109
self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
127110
self.logger.debug("Resampling secondary source to 44100Hz.")
128111

129-
self.secondary_source_map = self.final_process(secondary_stem_path, self.secondary_source, self.secondary_source_secondary, self.secondary_stem, 44100)
112+
self.secondary_source_map = self.final_process(secondary_stem_output_path, self.secondary_source, self.secondary_source_secondary, self.secondary_stem, 44100)
130113
self.logger.debug("Secondary stem processed.")
131114

132115
clear_gpu_cache()

audio_separator/separator/common_separator.py

+22-21
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,38 @@ class CommonSeparator:
1515
def __init__(self, config):
1616

1717
self.logger: Logger = config.get("logger")
18+
19+
# Inferencing device / acceleration config
1820
self.torch_device = config.get("torch_device")
1921
self.onnx_execution_provider = config.get("onnx_execution_provider")
22+
23+
# Model data
2024
self.model_name = config.get("model_name")
2125
self.model_path = config.get("model_path")
2226
self.model_data = config.get("model_data")
23-
self.primary_stem_path = config.get("primary_stem_path")
24-
self.secondary_stem_path = config.get("secondary_stem_path")
25-
self.output_format = config.get("output_format")
26-
self.output_subtype = config.get("output_subtype")
27+
28+
# Optional custom output paths for the primary and secondary stems
29+
# If left as None, the arch-specific class decides the output filename, e.g. something like:
30+
# f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}"
31+
self.primary_stem_output_path = config.get("primary_stem_output_path")
32+
self.secondary_stem_output_path = config.get("secondary_stem_output_path")
33+
34+
# Output directory and format
2735
self.output_dir = config.get("output_dir")
36+
self.output_format = config.get("output_format")
37+
38+
# Functional options which are applicable to all architectures and the user may tweak to affect the output
2839
self.normalization_threshold = config.get("normalization_threshold")
2940
self.denoise_enabled = config.get("denoise_enabled")
3041
self.output_single_stem = config.get("output_single_stem")
3142
self.invert_using_spec = config.get("invert_using_spec")
3243
self.sample_rate = config.get("sample_rate")
3344

34-
# Initializing model parameters
35-
self.compensate, self.dim_f, self.dim_t, self.n_fft, self.primary_stem_name = (
36-
self.model_data["compensate"],
37-
self.model_data["mdx_dim_f_set"],
38-
2 ** self.model_data["mdx_dim_t_set"],
39-
self.model_data["mdx_n_fft_scale_set"],
40-
self.model_data["primary_stem"],
41-
)
45+
# Model specific properties
46+
self.primary_stem_name = self.model_data["primary_stem"]
4247
self.secondary_stem_name = "Vocals" if self.primary_stem_name == "Instrumental" else "Instrumental"
48+
self.is_karaoke = self.model_data.get("is_karaoke", False)
49+
self.is_bv_model = self.model_data.get("is_bv_model", False)
4350

4451
# In UVR, these variables are set but either aren't useful or are better handled in audio-separator.
4552
# Leaving these comments explaining to help myself or future developers understand why these aren't in audio-separator.
@@ -62,12 +69,6 @@ def __init__(self, config):
6269

6370
self.cached_sources_map = {}
6471

65-
def prepare_mix(self, mix):
66-
"""
67-
Placeholder method for preparing the mix. Should be overridden by subclasses.
68-
"""
69-
raise NotImplementedError("This method should be overridden by subclasses.")
70-
7172
def separate(self, audio_file_path):
7273
"""
7374
Placeholder method for separating audio sources. Should be overridden by subclasses.
@@ -79,7 +80,7 @@ def final_process(self, stem_path, source, stem_name):
7980
Finalizes the processing of a stem by writing the audio to a file and returning the processed source.
8081
"""
8182
self.logger.debug(f"Finalizing {stem_name} stem processing and writing audio...")
82-
self.write_audio(stem_path, source, stem_name=stem_name)
83+
self.write_audio(stem_path, source)
8384

8485
return {stem_name: source}
8586

@@ -126,11 +127,11 @@ def cached_model_source_holder(self, model_architecture, sources, model_name=Non
126127
"""
127128
self.cached_sources_map[model_architecture] = {**self.cached_sources_map.get(model_architecture, {}), **{model_name: sources}}
128129

129-
def write_audio(self, stem_path: str, stem_source, stem_name=None):
130+
def write_audio(self, stem_path: str, stem_source):
130131
"""
131132
Writes the separated audio source to a file.
132133
"""
133-
self.logger.debug(f"Entering write_audio with stem_name: {stem_name} and stem_path: {stem_path}")
134+
self.logger.debug(f"Entering write_audio with stem_path: {stem_path}")
134135

135136
stem_source = spec_utils.normalize(self.logger, wave=stem_source, max_peak=self.normalization_threshold)
136137

0 commit comments

Comments
 (0)