Skip to content

Commit

Permalink
0.1.5
Browse files Browse the repository at this point in the history
1. Fixed cli mode
2. In API you can upload your file without base64
  • Loading branch information
daswer123 committed Oct 18, 2024
1 parent ff061a0 commit cff3ffb
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 45 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "rvc-python"
version = "0.1.4"
version = "0.1.5"
authors = [
{ name="daswer123", email="daswerq123@gmail.com" },
]
Expand Down
47 changes: 33 additions & 14 deletions rvc_python/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def main():
api_parser = subparsers.add_parser("api", help="Start API server")
api_parser.add_argument("-p", "--port", type=int, default=5050, help="Port number for the API server")
api_parser.add_argument("-l", "--listen", action="store_true", help="Listen to external connections")
api_parser.add_argument("-md", "--models_dir", type=str, default="rvc_models", help="Directory to store models")
api_parser.add_argument("-pm", "--preload-model", type=str, help="Preload model on startup (optional)")

# Common arguments for both CLI and API
for subparser in [cli_parser, api_parser]:
subparser.add_argument("-md", "--models_dir", type=str, default="rvc_models", help="Directory to store models")
subparser.add_argument("-ip", "--index", type=str, default="", help="Path to index file (optional)")
subparser.add_argument("-de", "--device", type=str, default="cpu:0", help="Device to use (e.g., cpu:0, cuda:0)")
subparser.add_argument("-me", "--method", type=str, default="rmvpe", choices=['harvest', "crepe", "rmvpe", 'pm'], help="Pitch extraction algorithm")
Expand All @@ -40,21 +40,25 @@ def main():

args = parser.parse_args()

# Initialize RVCInference
rvc = RVCInference(models_dir=args.models_dir, device=args.device)
rvc.set_params(
f0method=args.method,
f0up_key=args.pitch,
index_rate=args.index_rate,
filter_radius=args.filter_radius,
resample_sr=args.resample_sr,
rms_mix_rate=args.rms_mix_rate,
protect=args.protect
)

# Handle CLI command
if args.command == "cli":
rvc.load_model(args.model)
# Initialize RVCInference with model path
rvc = RVCInference(
device=args.device,
model_path=args.model,
index_path=args.index,
version=args.version
)
rvc.set_params(
f0method=args.method,
f0up_key=args.pitch,
index_rate=args.index_rate,
filter_radius=args.filter_radius,
resample_sr=args.resample_sr,
rms_mix_rate=args.rms_mix_rate,
protect=args.protect
)

if args.input:
# Process single file
rvc.infer_file(args.input, args.output)
Expand All @@ -69,6 +73,21 @@ def main():

# Handle API command
elif args.command == "api":
# Initialize RVCInference without a model (will be loaded on demand)
rvc = RVCInference(
models_dir=args.models_dir,
device=args.device
)
rvc.set_params(
f0method=args.method,
f0up_key=args.pitch,
index_rate=args.index_rate,
filter_radius=args.filter_radius,
resample_sr=args.resample_sr,
rms_mix_rate=args.rms_mix_rate,
protect=args.protect
)

# Create and configure FastAPI app
app = create_app()
app.state.rvc = rvc
Expand Down
101 changes: 88 additions & 13 deletions rvc_python/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# api.py

from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, JSONResponse
Expand All @@ -13,7 +15,7 @@ class SetDeviceRequest(BaseModel):
device: str

class ConvertAudioRequest(BaseModel):
audio_data: str
audio_data: str # Base64 encoded audio data

class SetParamsRequest(BaseModel):
params: dict
Expand All @@ -23,38 +25,91 @@ class SetModelsDirRequest(BaseModel):

def setup_routes(app: FastAPI):
@app.post("/convert")
def rvc_convert(request: ConvertAudioRequest):
async def rvc_convert(request: ConvertAudioRequest):
"""
Converts audio data using the currently loaded model.
Accepts a base64 encoded audio data in WAV format.
Returns the converted audio as WAV data.
"""
if not app.state.rvc.current_model:
raise HTTPException(status_code=400, detail="No model loaded. Please load a model first.")

tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
tmp_output = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
with tempfile.NamedTemporaryFile(delete=False) as tmp_input:
input_path = tmp_input.name
try:
logger.info("Received request to convert audio")
audio_data = base64.b64decode(request.audio_data)
tmp_input.write(audio_data)
except Exception as e:
logger.error(f"Error decoding audio data: {e}")
raise HTTPException(status_code=400, detail="Invalid audio data")

with tempfile.NamedTemporaryFile(delete=False) as tmp_output:
output_path = tmp_output.name

try:
logger.info("Received request to convert audio")
audio_data = base64.b64decode(request.audio_data)
tmp_input.write(audio_data)
app.state.rvc.infer_file(input_path, output_path)

with open(output_path, "rb") as f:
output_data = f.read()
return Response(content=output_data, media_type="audio/wav")
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
finally:
os.unlink(input_path)
os.unlink(output_path)

@app.post("/convert_file")
async def rvc_convert_file(file: UploadFile = File(...)):
"""
Converts an uploaded audio file using the currently loaded model.
Accepts an audio file in WAV format.
Returns the converted audio as WAV data.
"""
if not app.state.rvc.current_model:
raise HTTPException(status_code=400, detail="No model loaded. Please load a model first.")

with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_input:
input_path = tmp_input.name
try:
logger.info("Received file to convert")
contents = await file.read()
tmp_input.write(contents)
except Exception as e:
logger.error(f"Error reading uploaded file: {e}")
raise HTTPException(status_code=400, detail="Invalid file")

with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_output:
output_path = tmp_output.name

try:
app.state.rvc.infer_file(input_path, output_path)

output_data = tmp_output.read()
with open(output_path, "rb") as f:
output_data = f.read()
return Response(content=output_data, media_type="audio/wav")
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
finally:
tmp_input.close()
tmp_output.close()
os.unlink(tmp_input.name)
os.unlink(tmp_output.name)
os.unlink(input_path)
os.unlink(output_path)

@app.get("/models")
def list_models():
"""
Lists available models.
Returns a JSON response with the list of model names.
"""
return JSONResponse(content={"models": app.state.rvc.list_models()})

@app.post("/models/{model_name}")
def load_model(model_name: str):
"""
Loads a model by name.
The model must be available in the models directory.
"""
try:
app.state.rvc.load_model(model_name)
return JSONResponse(content={"message": f"Model {model_name} loaded successfully"})
Expand All @@ -63,6 +118,10 @@ def load_model(model_name: str):

@app.get("/params")
def get_params():
"""
Retrieves current parameters used for inference.
Returns a JSON response with the parameters.
"""
return JSONResponse(content={
"f0method": app.state.rvc.f0method,
"f0up_key": app.state.rvc.f0up_key,
Expand All @@ -75,6 +134,10 @@ def get_params():

@app.post("/params")
def set_params(request: SetParamsRequest):
"""
Sets parameters for inference.
Accepts a JSON object with parameter names and values.
"""
try:
app.state.rvc.set_params(**request.params)
return JSONResponse(content={"message": "Parameters updated successfully"})
Expand All @@ -83,9 +146,14 @@ def set_params(request: SetParamsRequest):

@app.post("/upload_model")
async def upload_models(file: UploadFile = File(...)):
"""
Uploads and extracts a ZIP file containing models.
The models are extracted to the models directory.
"""
try:
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
shutil.copyfileobj(file.file, tmp_file)
contents = await file.read()
tmp_file.write(contents)

with zipfile.ZipFile(tmp_file.name, 'r') as zip_ref:
zip_ref.extractall(app.state.rvc.models_dir)
Expand All @@ -101,6 +169,9 @@ async def upload_models(file: UploadFile = File(...)):

@app.post("/set_device")
def set_device(request: SetDeviceRequest):
"""
Sets the device for inference (e.g., 'cpu:0' or 'cuda:0').
"""
try:
device = request.device
app.state.rvc.set_device(device)
Expand All @@ -110,6 +181,10 @@ def set_device(request: SetDeviceRequest):

@app.post("/set_models_dir")
def set_models_dir(request: SetModelsDirRequest):
"""
Sets a new directory for models.
The directory must exist and contain valid models.
"""
try:
new_models_dir = request.models_dir
app.state.rvc.set_models_dir(new_models_dir)
Expand Down
Loading

0 comments on commit cff3ffb

Please sign in to comment.