-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.py
61 lines (48 loc) · 2.05 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import asyncio
import os
import shutil
import tempfile

from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, BackgroundTasks
from fastapi.responses import JSONResponse
from google.oauth2 import service_account
from pydantic import BaseModel

from constants import KEY_PATH
from speech_to_text import SpeechToText
from text_to_text import TextToText
from text_to_speech import GoogleCloudTTS, ElevenLabsTTS
from logger import logger
# Application instance that the route decorators in this module attach to.
app = FastAPI()
class AudioInput(BaseModel):
    """Per-request switches read by ``run_processing``."""

    # When true, the generated reply is also passed to the TTS engine.
    want_sound: bool
    # When true, ElevenLabsTTS is used instead of GoogleCloudTTS.
    use_eleven_labs: bool
    # When true, processing returns a fixed placeholder string instead of
    # transcribing the audio.
    debug: bool
@app.post("/process-audio")
async def process_audio(
    input_data: AudioInput,
    audio_file: UploadFile = File(...),
    *,
    background_tasks: BackgroundTasks,
):
    """Run an uploaded audio clip through ``run_processing`` and return its text.

    The upload is copied to a uniquely named temporary file, processed on a
    worker thread so the event loop is not blocked, and the result is returned
    as ``{"response": <text>}``.

    Args:
        input_data: Processing flags forwarded unchanged to ``run_processing``.
        audio_file: The uploaded audio.
        background_tasks: Task queue used to delete the temporary file once
            the response has been sent. It is keyword-only because a required
            parameter may not follow a defaulted one positionally.

    Raises:
        HTTPException: Status 500, with the error text as ``detail``, when
            saving or processing the upload fails.
    """
    # NOTE(review): FastAPI's docs state that a JSON body model cannot be
    # combined with File/Form fields in the same request. Confirm how clients
    # are expected to send ``input_data`` (the otherwise unused ``Depends``
    # import suggests ``Depends()`` may have been intended).
    temp_file = None
    try:
        # The client-supplied filename is untrusted: reuse only its extension
        # and let tempfile pick a unique path, so uploads can neither escape
        # the temp directory nor overwrite each other.
        suffix = os.path.splitext(os.path.basename(audio_file.filename or ""))[1]
        with tempfile.NamedTemporaryFile(
            prefix="temp_", suffix=suffix, delete=False
        ) as buffer:
            temp_file = buffer.name
            shutil.copyfileobj(audio_file.file, buffer)

        # Process the audio file off the event loop and get the response.
        text_response = await asyncio.to_thread(run_processing, temp_file, input_data)

        # Success path: remove the temporary file after the response is sent.
        background_tasks.add_task(os.remove, temp_file)
        return JSONResponse(content={"response": text_response})
    except Exception as e:
        logger.error(f"Error processing audio: {e}")
        # Background tasks do not run when the handler raises, so the
        # temporary file has to be removed here to avoid leaking it.
        if temp_file is not None:
            try:
                os.remove(temp_file)
            except OSError as cleanup_error:
                logger.error(
                    f"Could not remove temporary file {temp_file}: {cleanup_error}"
                )
        raise HTTPException(status_code=500, detail=str(e)) from e
def run_processing(filename: str, input_data: AudioInput):
    """Transcribe an audio file, generate a reply and optionally synthesize it.

    Args:
        filename: Path of the audio file handed to the transcriber.
        input_data: Flags selecting the TTS engine, debug mode and whether the
            reply is synthesized.

    Returns:
        The fixed debug string when ``input_data.debug`` is set, otherwise
        whatever ``TextToText.generate_response`` returns for the transcript.
    """
    creds = service_account.Credentials.from_service_account_file(KEY_PATH)
    transcriber = SpeechToText(creds)
    responder = TextToText(
        messages=[{"role": "system", "content": "You are Vivy, the AI songstress..."}]
    )
    if input_data.use_eleven_labs:
        voice = ElevenLabsTTS()
    else:
        voice = GoogleCloudTTS(creds)

    # NOTE(review): the clients above are built before this check, so debug
    # mode still needs a readable key file. Confirm whether that is intended.
    if input_data.debug:
        return "Checking if debug mode works or not."

    print("Transcribing...")
    transcript = transcriber.transcribe_audio(filename)
    print(transcript)

    reply = responder.generate_response(transcript)
    print(reply)

    if not input_data.want_sound:
        return reply
    voice.synthesize(reply)
    return reply