From cff3ffbe99ce2e062831e4377f8cef587d0f38f9 Mon Sep 17 00:00:00 2001 From: Danil Boldyrev Date: Sat, 19 Oct 2024 02:40:58 +0300 Subject: [PATCH] 0.1.5 1. Fixed cli mode 2. In API you can upload your file without base64 --- pyproject.toml | 2 +- rvc_python/__main__.py | 47 +++++++++++++------ rvc_python/api.py | 101 +++++++++++++++++++++++++++++++++++------ rvc_python/infer.py | 83 ++++++++++++++++++++++++++------- 4 files changed, 188 insertions(+), 45 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 86798d6..c0d89d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "rvc-python" -version = "0.1.4" +version = "0.1.5" authors = [ { name="daswer123", email="daswerq123@gmail.com" }, ] diff --git a/rvc_python/__main__.py b/rvc_python/__main__.py index e5dabd0..dc3c431 100644 --- a/rvc_python/__main__.py +++ b/rvc_python/__main__.py @@ -22,11 +22,11 @@ def main(): api_parser = subparsers.add_parser("api", help="Start API server") api_parser.add_argument("-p", "--port", type=int, default=5050, help="Port number for the API server") api_parser.add_argument("-l", "--listen", action="store_true", help="Listen to external connections") + api_parser.add_argument("-md", "--models_dir", type=str, default="rvc_models", help="Directory to store models") api_parser.add_argument("-pm", "--preload-model", type=str, help="Preload model on startup (optional)") # Common arguments for both CLI and API for subparser in [cli_parser, api_parser]: - subparser.add_argument("-md", "--models_dir", type=str, default="rvc_models", help="Directory to store models") subparser.add_argument("-ip", "--index", type=str, default="", help="Path to index file (optional)") subparser.add_argument("-de", "--device", type=str, default="cpu:0", help="Device to use (e.g., cpu:0, cuda:0)") subparser.add_argument("-me", "--method", type=str, default="rmvpe", choices=['harvest', "crepe", "rmvpe", 'pm'], help="Pitch extraction algorithm") @@ -40,21 +40,25 @@ def main(): args = parser.parse_args() - # Initialize RVCInference - rvc = RVCInference(models_dir=args.models_dir, device=args.device) - rvc.set_params( - f0method=args.method, - f0up_key=args.pitch, - index_rate=args.index_rate, - filter_radius=args.filter_radius, - resample_sr=args.resample_sr, - rms_mix_rate=args.rms_mix_rate, - protect=args.protect - ) - # Handle CLI command if args.command == "cli": - rvc.load_model(args.model) + # Initialize RVCInference with model path + rvc = RVCInference( + device=args.device, + model_path=args.model, + index_path=args.index, + version=args.version + ) + rvc.set_params( + f0method=args.method, + f0up_key=args.pitch, + index_rate=args.index_rate, + filter_radius=args.filter_radius, + resample_sr=args.resample_sr, + rms_mix_rate=args.rms_mix_rate, + protect=args.protect + ) + if args.input: # Process single file rvc.infer_file(args.input, args.output) @@ -69,6 +73,21 @@ def main(): # Handle API command elif args.command == "api": + # Initialize RVCInference without a model (will be loaded on demand) + rvc = RVCInference( + models_dir=args.models_dir, + device=args.device + ) + rvc.set_params( + f0method=args.method, + f0up_key=args.pitch, + index_rate=args.index_rate, + filter_radius=args.filter_radius, + resample_sr=args.resample_sr, + rms_mix_rate=args.rms_mix_rate, + protect=args.protect + ) + # Create and configure FastAPI app app = create_app() app.state.rvc = rvc diff --git a/rvc_python/api.py b/rvc_python/api.py index 60bd4c0..15e6ad9 100644 --- a/rvc_python/api.py +++ b/rvc_python/api.py @@ -1,3 +1,5 @@ +# api.py + from fastapi import FastAPI, HTTPException, UploadFile, File from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response, JSONResponse @@ -13,7 +15,7 @@ class SetDeviceRequest(BaseModel): device: str class ConvertAudioRequest(BaseModel): - audio_data: str + audio_data: str # Base64 encoded audio data class SetParamsRequest(BaseModel): params: dict @@ -23,38 +25,91 @@ class SetModelsDirRequest(BaseModel): def setup_routes(app: FastAPI): @app.post("/convert") - def rvc_convert(request: ConvertAudioRequest): + async def rvc_convert(request: ConvertAudioRequest): + """ + Converts audio data using the currently loaded model. + Accepts a base64 encoded audio data in WAV format. + Returns the converted audio as WAV data. + """ if not app.state.rvc.current_model: raise HTTPException(status_code=400, detail="No model loaded. Please load a model first.") - tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") - tmp_output = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + with tempfile.NamedTemporaryFile(delete=False) as tmp_input: + input_path = tmp_input.name + try: + logger.info("Received request to convert audio") + audio_data = base64.b64decode(request.audio_data) + tmp_input.write(audio_data) + except Exception as e: + logger.error(f"Error decoding audio data: {e}") + raise HTTPException(status_code=400, detail="Invalid audio data") + + with tempfile.NamedTemporaryFile(delete=False) as tmp_output: + output_path = tmp_output.name + try: - logger.info("Received request to convert audio") - audio_data = base64.b64decode(request.audio_data) - tmp_input.write(audio_data) + app.state.rvc.infer_file(input_path, output_path) + + with open(output_path, "rb") as f: + output_data = f.read() + return Response(content=output_data, media_type="audio/wav") + except Exception as e: + logger.error(e) + raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") + finally: + os.unlink(input_path) + os.unlink(output_path) + + @app.post("/convert_file") + async def rvc_convert_file(file: UploadFile = File(...)): + """ + Converts an uploaded audio file using the currently loaded model. + Accepts an audio file in WAV format. + Returns the converted audio as WAV data. + """ + if not app.state.rvc.current_model: + raise HTTPException(status_code=400, detail="No model loaded. Please load a model first.") + + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_input: input_path = tmp_input.name + try: + logger.info("Received file to convert") + contents = await file.read() + tmp_input.write(contents) + except Exception as e: + logger.error(f"Error reading uploaded file: {e}") + raise HTTPException(status_code=400, detail="Invalid file") + + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_output: output_path = tmp_output.name + try: app.state.rvc.infer_file(input_path, output_path) - output_data = tmp_output.read() + with open(output_path, "rb") as f: + output_data = f.read() return Response(content=output_data, media_type="audio/wav") except Exception as e: logger.error(e) raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") finally: - tmp_input.close() - tmp_output.close() - os.unlink(tmp_input.name) - os.unlink(tmp_output.name) + os.unlink(input_path) + os.unlink(output_path) @app.get("/models") def list_models(): + """ + Lists available models. + Returns a JSON response with the list of model names. + """ return JSONResponse(content={"models": app.state.rvc.list_models()}) @app.post("/models/{model_name}") def load_model(model_name: str): + """ + Loads a model by name. + The model must be available in the models directory. + """ try: app.state.rvc.load_model(model_name) return JSONResponse(content={"message": f"Model {model_name} loaded successfully"}) @@ -63,6 +118,10 @@ def load_model(model_name: str): @app.get("/params") def get_params(): + """ + Retrieves current parameters used for inference. + Returns a JSON response with the parameters. + """ return JSONResponse(content={ "f0method": app.state.rvc.f0method, "f0up_key": app.state.rvc.f0up_key, @@ -75,6 +134,10 @@ def get_params(): @app.post("/params") def set_params(request: SetParamsRequest): + """ + Sets parameters for inference. + Accepts a JSON object with parameter names and values. + """ try: app.state.rvc.set_params(**request.params) return JSONResponse(content={"message": "Parameters updated successfully"}) @@ -83,9 +146,14 @@ def set_params(request: SetParamsRequest): @app.post("/upload_model") async def upload_models(file: UploadFile = File(...)): + """ + Uploads and extracts a ZIP file containing models. + The models are extracted to the models directory. + """ try: with tempfile.NamedTemporaryFile(delete=False) as tmp_file: - shutil.copyfileobj(file.file, tmp_file) + contents = await file.read() + tmp_file.write(contents) with zipfile.ZipFile(tmp_file.name, 'r') as zip_ref: zip_ref.extractall(app.state.rvc.models_dir) @@ -101,6 +169,9 @@ async def upload_models(file: UploadFile = File(...)): @app.post("/set_device") def set_device(request: SetDeviceRequest): + """ + Sets the device for inference (e.g., 'cpu:0' or 'cuda:0'). + """ try: device = request.device app.state.rvc.set_device(device) @@ -110,6 +181,10 @@ def set_device(request: SetDeviceRequest): @app.post("/set_models_dir") def set_models_dir(request: SetModelsDirRequest): + """ + Sets a new directory for models. + The directory must exist and contain valid models. + """ try: new_models_dir = request.models_dir app.state.rvc.set_models_dir(new_models_dir) diff --git a/rvc_python/infer.py b/rvc_python/infer.py index 3491303..78a3aa0 100644 --- a/rvc_python/infer.py +++ b/rvc_python/infer.py @@ -1,3 +1,5 @@ +# infer.py + import os from glob import glob import soundfile as sf @@ -7,14 +9,14 @@ from rvc_python.download_model import download_rvc_models class RVCInference: - def __init__(self, models_dir="rvc_models", device="cpu:0"): + def __init__(self, models_dir="rvc_models", device="cpu:0", model_path=None, index_path="", version="v2"): self.models_dir = models_dir self.device = device self.lib_dir = os.path.dirname(os.path.abspath(__file__)) self.config = Config(self.lib_dir, self.device) self.vc = VC(self.lib_dir, self.config) self.current_model = None - self.models = self._load_available_models() + self.models = {} # Default parameters self.f0method = "harvest" @@ -25,9 +27,16 @@ def __init__(self, models_dir="rvc_models", device="cpu:0"): self.rms_mix_rate = 1 self.protect = 0.33 - # Download Models + # Download Models (if necessary) download_rvc_models(self.lib_dir) + # Load available models + self.models = self._load_available_models() + + # Load model if model_path is provided + if model_path: + self.load_model(model_path, version=version, index_path=index_path) + def _load_available_models(self): """Loads a list of available models from the directory.""" models = {} @@ -44,21 +53,42 @@ def _load_available_models(self): return models def set_models_dir(self, new_models_dir): + """Sets a new directory for models and reloads available models.""" if not os.path.isdir(new_models_dir): raise ValueError(f"Directory {new_models_dir} does not exist") self.models_dir = new_models_dir self.models = self._load_available_models() - + def list_models(self): """Returns a list of available models.""" return list(self.models.keys()) - def load_model(self, model_name, version="v2"): - """Loads a model into memory.""" - if model_name not in self.models: - raise ValueError(f"Model {model_name} not found.") + def load_model(self, model_path_or_name, version="v2", index_path=""): + """Loads a model into memory. + + Args: + model_path_or_name (str): Path to the model file or model name if in models_dir. + version (str): Version of the model ('v1' or 'v2'). + index_path (str): Path to the index file (optional). + """ + # If model_path_or_name is a name in self.models, load from models_dir + if model_path_or_name in self.models: + model_info = self.models[model_path_or_name] + model_path = model_info["pth"] + index_path = model_info.get("index", "") + model_name = model_path_or_name + else: + # Else, assume it's a direct path + model_path = model_path_or_name + model_name = os.path.basename(model_path) + if index_path and not os.path.isfile(index_path): + raise ValueError(f"Index file {index_path} not found.") + # Update models dict + self.models[model_name] = {"pth": model_path, "index": index_path} + + if not os.path.isfile(model_path): + raise ValueError(f"Model file {model_path} not found.") - model_path = self.models[model_name]["pth"] self.vc.get_vc(model_path, version) self.current_model = model_name print(f"Model {model_name} loaded.") @@ -85,16 +115,24 @@ def set_params(self, **kwargs): print(f"Warning: parameter {key} not recognized and will be ignored.") def infer_file(self, input_path, output_path): - """Processes a single file.""" + """Processes a single file. + + Args: + input_path (str): Path to the input audio file. + output_path (str): Path to save the output audio file. + """ if not self.current_model: raise ValueError("Please load a model first.") + model_info = self.models[self.current_model] + file_index = model_info.get("index", "") + wav_opt = self.vc.vc_single( sid=0, input_audio_path=input_path, f0_up_key=self.f0up_key, f0_method=self.f0method, - file_index=self.models[self.current_model].get("index", ""), + file_index=file_index, index_rate=self.index_rate, filter_radius=self.filter_radius, resample_sr=self.resample_sr, @@ -108,7 +146,12 @@ def infer_file(self, input_path, output_path): return output_path def infer_dir(self, input_dir, output_dir): - """Processes all files in a directory.""" + """Processes all files in a directory. + + Args: + input_dir (str): Path to the input directory containing audio files. + output_dir (str): Path to the output directory to save processed files. + """ if not self.current_model: raise ValueError("Please load a model first.") @@ -125,17 +168,23 @@ def infer_dir(self, input_dir, output_dir): return processed_files def set_device(self, device): - """Sets the device for computations.""" + """Sets the device for computations. + + Args: + device (str): Device identifier (e.g., 'cpu:0', 'cuda:0'). + """ self.device = device self.config.device = device self.vc.device = device # Usage example: if __name__ == "__main__": - rvc = RVCInference(device="cuda:0") - print("Available models:", rvc.list_models()) - - rvc.load_model("example_model") + rvc = RVCInference( + device="cuda:0", + model_path="path/to/model.pth", + index_path="path/to/index.index", + version="v2" + ) rvc.set_params(f0up_key=2, protect=0.5) rvc.infer_file("input.wav", "output.wav")