1
1
"""Module for separating audio sources using VR architecture models."""
2
2
3
3
import os
4
+ import sys
5
+ import math
6
+
4
7
import torch
5
8
import librosa
6
- import onnxruntime as ort
7
9
import numpy as np
8
- import onnx2torch
10
+
9
11
from audio_separator .separator import spec_utils
10
- from audio_separator .separator .stft import STFT
11
12
from audio_separator .separator .common_separator import CommonSeparator
12
13
13
14
@@ -20,31 +21,10 @@ class VRSeparator(CommonSeparator):
20
21
def __init__ (self , common_config , arch_config ):
21
22
super ().__init__ (config = common_config )
22
23
23
- self .hop_length = arch_config .get ("hop_length" )
24
- self .segment_size = arch_config .get ("segment_size" )
25
- self .overlap = arch_config .get ("overlap" )
26
- self .batch_size = arch_config .get ("batch_size" )
24
+ self .logger .debug (f"Model data: " , self .model_data )
27
25
28
26
self .logger .debug (f"Model params: primary_stem={ self .primary_stem_name } , secondary_stem={ self .secondary_stem_name } " )
29
- self .logger .debug (f"Model params: batch_size={ self .batch_size } , compensate={ self .compensate } , segment_size={ self .segment_size } , dim_f={ self .dim_f } , dim_t={ self .dim_t } " )
30
- self .logger .debug (f"Model params: n_fft={ self .n_fft } , hop={ self .hop_length } " )
31
-
32
- # Loading the model for inference
33
- self .logger .debug ("Loading ONNX model for inference..." )
34
- if self .segment_size == self .dim_t :
35
- ort_ = ort .InferenceSession (self .model_path , providers = self .onnx_execution_provider )
36
- self .model_run = lambda spek : ort_ .run (None , {"input" : spek .cpu ().numpy ()})[0 ]
37
- self .logger .debug ("Model loaded successfully using ONNXruntime inferencing session." )
38
- else :
39
- self .model_run = onnx2torch .convert (self .model_path )
40
- self .model_run .to (self .torch_device ).eval ()
41
- self .logger .warning ("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower." )
42
27
43
- self .n_bins = None
44
- self .trim = None
45
- self .chunk_size = None
46
- self .gen_size = None
47
- self .stft = None
48
28
49
29
self .primary_source = None
50
30
self .secondary_source = None
@@ -53,49 +33,52 @@ def __init__(self, common_config, arch_config):
53
33
self .secondary_source_map = None
54
34
self .primary_source_map = None
55
35
36
+ self .is_vr_51_model = model_data .is_vr_51_model
37
+
38
+ def separate (self , audio_file_path ):
39
+ """
40
+ Separates the audio file into primary and secondary sources based on the model's configuration.
41
+ It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
56
42
43
+ Args:
44
+ audio_file_path (str): The path to the audio file to be processed.
57
45
46
+ Returns:
47
+ list: A list of paths to the output files generated by the separation process.
48
+ """
49
+ self .primary_source = None
50
+ self .secondary_source = None
58
51
52
+ self .audio_file_path = audio_file_path
53
+ self .audio_file_base = os .path .splitext (os .path .basename (audio_file_path ))[0 ]
59
54
55
+ self .logger .debug ("Starting inference..." )
60
56
61
- def seperate (self ):
62
- self .logger .debug ("Starting separation process in SeperateVR..." )
63
- if self .primary_model_name == self .model_basename and isinstance (self .primary_sources , tuple ):
64
- self .logger .debug ("Using cached primary sources..." )
65
- y_spec , v_spec = self .primary_sources
66
- self .load_cached_sources ()
57
+ nn_arch_sizes = [31191 , 33966 , 56817 , 123821 , 123812 , 129605 , 218409 , 537238 , 537227 ] # default
58
+ vr_5_1_models = [56817 , 218409 ]
59
+ model_size = math .ceil (os .stat (self .model_path ).st_size / 1024 )
60
+ nn_arch_size = min (nn_arch_sizes , key = lambda x : abs (x - model_size ))
61
+ self .logger .debug (f"Model size determined: { model_size } , NN architecture size: { nn_arch_size } " )
62
+
63
+ if nn_arch_size in vr_5_1_models or self .is_vr_51_model :
64
+ self .logger .debug ("Using CascadedNet for VR 5.1 model..." )
65
+ self .model_run = nets_new .CascadedNet (self .mp .param ["bins" ] * 2 , nn_arch_size , nout = self .model_capacity [0 ], nout_lstm = self .model_capacity [1 ])
66
+ self .is_vr_51_model = True
67
67
else :
68
- self .logger .debug ("Starting inference..." )
69
- self .start_inference_console_write ()
70
-
71
- device = self .device
72
- self .logger .debug (f"Device set to: { device } " )
73
-
74
- nn_arch_sizes = [31191 , 33966 , 56817 , 123821 , 123812 , 129605 , 218409 , 537238 , 537227 ] # default
75
- vr_5_1_models = [56817 , 218409 ]
76
- model_size = math .ceil (os .stat (self .model_path ).st_size / 1024 )
77
- nn_arch_size = min (nn_arch_sizes , key = lambda x : abs (x - model_size ))
78
- self .logger .debug (f"Model size determined: { model_size } , NN architecture size: { nn_arch_size } " )
79
-
80
- if nn_arch_size in vr_5_1_models or self .is_vr_51_model :
81
- self .logger .debug ("Using CascadedNet for VR 5.1 model..." )
82
- self .model_run = nets_new .CascadedNet (self .mp .param ["bins" ] * 2 , nn_arch_size , nout = self .model_capacity [0 ], nout_lstm = self .model_capacity [1 ])
83
- self .is_vr_51_model = True
84
- else :
85
- self .logger .debug ("Determining model capacity..." )
86
- self .model_run = nets .determine_model_capacity (self .mp .param ["bins" ] * 2 , nn_arch_size )
68
+ self .logger .debug ("Determining model capacity..." )
69
+ self .model_run = nets .determine_model_capacity (self .mp .param ["bins" ] * 2 , nn_arch_size )
87
70
88
- self .model_run .load_state_dict (torch .load (self .model_path , map_location = cpu ))
89
- self .model_run .to (device )
90
- self .logger .debug ("Model loaded and moved to device." )
71
+ self .model_run .load_state_dict (torch .load (self .model_path , map_location = cpu ))
72
+ self .model_run .to (device )
73
+ self .logger .debug ("Model loaded and moved to device." )
91
74
92
- self .running_inference_console_write ()
75
+ self .running_inference_console_write ()
93
76
94
- y_spec , v_spec = self .inference_vr (self .loading_mix (), device , self .aggressiveness )
95
- self .logger .debug ("Inference completed." )
96
- if not self .is_vocal_split_model :
97
- self .cache_source ((y_spec , v_spec ))
98
- self .write_to_console (DONE , base_text = "" )
77
+ y_spec , v_spec = self .inference_vr (self .loading_mix (), device , self .aggressiveness )
78
+ self .logger .debug ("Inference completed." )
79
+ if not self .is_vocal_split_model :
80
+ self .cache_source ((y_spec , v_spec ))
81
+ self .write_to_console (DONE , base_text = "" )
99
82
100
83
if self .is_secondary_model_activated and self .secondary_model :
101
84
self .logger .debug ("Processing secondary model..." )
@@ -104,7 +87,7 @@ def seperate(self):
104
87
)
105
88
106
89
if not self .is_secondary_stem_only :
107
- primary_stem_path = os .path .join (self .export_path , f"{ self .audio_file_base } _({ self .primary_stem } ).wav" )
90
+ primary_stem_output_path = os .path .join (self .export_path , f"{ self .audio_file_base } _({ self .primary_stem } ).wav" )
108
91
self .logger .debug (f"Processing primary stem: { self .primary_stem } " )
109
92
if not isinstance (self .primary_source , np .ndarray ):
110
93
self .primary_source = self .spec_to_wav (y_spec ).T
@@ -113,11 +96,11 @@ def seperate(self):
113
96
self .primary_source = librosa .resample (self .primary_source .T , orig_sr = self .model_samplerate , target_sr = 44100 ).T
114
97
self .logger .debug ("Resampling primary source to 44100Hz." )
115
98
116
- self .primary_source_map = self .final_process (primary_stem_path , self .primary_source , self .secondary_source_primary , self .primary_stem , 44100 )
99
+ self .primary_source_map = self .final_process (primary_stem_output_path , self .primary_source , self .secondary_source_primary , self .primary_stem , 44100 )
117
100
self .logger .debug ("Primary stem processed." )
118
101
119
102
if not self .is_primary_stem_only :
120
- secondary_stem_path = os .path .join (self .export_path , f"{ self .audio_file_base } _({ self .secondary_stem } ).wav" )
103
+ secondary_stem_output_path = os .path .join (self .export_path , f"{ self .audio_file_base } _({ self .secondary_stem } ).wav" )
121
104
self .logger .debug (f"Processing secondary stem: { self .secondary_stem } " )
122
105
if not isinstance (self .secondary_source , np .ndarray ):
123
106
self .secondary_source = self .spec_to_wav (v_spec ).T
@@ -126,7 +109,7 @@ def seperate(self):
126
109
self .secondary_source = librosa .resample (self .secondary_source .T , orig_sr = self .model_samplerate , target_sr = 44100 ).T
127
110
self .logger .debug ("Resampling secondary source to 44100Hz." )
128
111
129
- self .secondary_source_map = self .final_process (secondary_stem_path , self .secondary_source , self .secondary_source_secondary , self .secondary_stem , 44100 )
112
+ self .secondary_source_map = self .final_process (secondary_stem_output_path , self .secondary_source , self .secondary_source_secondary , self .secondary_stem , 44100 )
130
113
self .logger .debug ("Secondary stem processed." )
131
114
132
115
clear_gpu_cache ()
0 commit comments