Skip to content

Commit

Permalink
Merge pull request #22 from Sunbird-ALL/all-1.1-dev-fastapi
Browse files Browse the repository at this point in the history
Pushing Error handling changes to staging from dev
  • Loading branch information
sudeeppr1998 authored May 31, 2024
2 parents 3607d01 + c245659 commit 9370ec2
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 72 deletions.
182 changes: 129 additions & 53 deletions routes.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,148 @@
import base64
import io
import logging
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from utils import denoise_with_rnnoise, get_error_arrays, get_pause_count, split_into_phonemes, processLP
from schemas import TextData,audioData,PhonemesRequest, PhonemesResponse, ErrorArraysResponse
from utils import denoise_with_rnnoise, get_error_arrays, get_pause_count, split_into_phonemes, processLP
from schemas import TextData, audioData, PhonemesRequest, PhonemesResponse, ErrorArraysResponse
from typing import List
import jiwer
import eng_to_ipa as p

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter()

@router.post('/getTextMatrices')
async def compute_errors(data: TextData):
reference = data.reference
hypothesis = data.hypothesis
language = data.language

charOut = jiwer.process_characters(reference, hypothesis)
wer = jiwer.wer(reference, hypothesis)

confidence_char_list =[]
missing_char_list =[]
construct_text=""

if language == "en":
confidence_char_list, missing_char_list,construct_text = processLP(reference,hypothesis)

# Extract error arrays
error_arrays = get_error_arrays(
charOut.alignments, reference, hypothesis)

return {
"wer": wer,
"cer": charOut.cer,
"insertion": error_arrays['insertion'],
"insertion_count": len(error_arrays['insertion']),
"deletion": error_arrays['deletion'],
"deletion_count": len(error_arrays['deletion']),
"substitution": error_arrays['substitution'],
"substitution_count": len(error_arrays['substitution']),
"confidence_char_list":confidence_char_list,
"missing_char_list":missing_char_list,
"construct_text":construct_text
}
try:
# Validate input data
if not data.reference or not data.hypothesis:
raise HTTPException(status_code=400, detail="Reference and hypothesis texts must be provided.")

reference = data.reference
hypothesis = data.hypothesis
language = data.language

# Validate language
allowed_languages = {"en", "ta", "te", "kn", "hi"}
if language not in allowed_languages:
raise HTTPException(status_code=400, detail=f"Unsupported language: {language}. Supported languages are: {', '.join(allowed_languages)}")

# Process character-level differences
try:
charOut = jiwer.process_characters(reference, hypothesis)
except Exception as e:
logger.error(f"Error processing characters: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error processing characters: {str(e)}")

# Compute WER
try:
wer = jiwer.wer(reference, hypothesis)
except Exception as e:
logger.error(f"Error computing WER: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error computing WER: {str(e)}")

confidence_char_list = []
missing_char_list = []
construct_text = ""

if language == "en":
try:
confidence_char_list, missing_char_list, construct_text = processLP(reference, hypothesis)
except Exception as e:
logger.error(f"Error processing LP: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error processing LP: {str(e)}")

# Extract error arrays
try:
error_arrays = get_error_arrays(charOut.alignments, reference, hypothesis)
except Exception as e:
logger.error(f"Error extracting error arrays: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error extracting error arrays: {str(e)}")

return {
"wer": wer,
"cer": charOut.cer,
"insertion": error_arrays['insertion'],
"insertion_count": len(error_arrays['insertion']),
"deletion": error_arrays['deletion'],
"deletion_count": len(error_arrays['deletion']),
"substitution": error_arrays['substitution'],
"substitution_count": len(error_arrays['substitution']),
"confidence_char_list": confidence_char_list,
"missing_char_list": missing_char_list,
"construct_text": construct_text
}
except HTTPException as e:
raise e
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")

@router.post("/getPhonemes", response_model=dict)
async def get_phonemes(data: PhonemesRequest):
phonemesList = split_into_phonemes(p.convert(data.text))
return {"phonemes": phonemesList}

try:
phonemesList = split_into_phonemes(p.convert(data.text))
return {"phonemes": phonemesList}
except Exception as e:
logger.error(f"Error getting phonemes: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error getting phonemes: {str(e)}")

@router.post('/audio_processing')
async def audio_processing(data: audioData):
audio_data = data.base64_string
audio_bytes = base64.b64decode(audio_data)
audio_io = io.BytesIO(audio_bytes)

pause_count = 0
denoised_audio_base64 = ""

if data.enablePauseCount:
pause_count = get_pause_count(audio_io)
if data.enableDenoiser:
denoised_audio_base64 = denoise_with_rnnoise(audio_data, data.contentType)
if denoised_audio_base64 is None:
raise HTTPException(status_code=500, detail="Error during audio denoising")
return {
"denoised_audio_base64": denoised_audio_base64,
"pause_count": pause_count
}
try:
# Validate input data
if not data.base64_string:
raise HTTPException(status_code=400, detail="Base64 string of audio must be provided.")
if not data.contentType:
raise HTTPException(status_code=400, detail="Content type must be specified.")

try:
audio_data = data.base64_string
audio_bytes = base64.b64decode(audio_data)
audio_io = io.BytesIO(audio_bytes)
except Exception as e:
logger.error(f"Invalid base64 string: {str(e)}")
raise HTTPException(status_code=400, detail=f"Invalid base64 string: {str(e)}")

pause_count = 0
denoised_audio_base64 = ""

if data.enablePauseCount:
try:
pause_count = get_pause_count(audio_io)
if pause_count is None:
logger.error("Error during pause count detection")
raise HTTPException(status_code=500, detail="Error during pause count detection")
except Exception as e:
logger.error(f"Error during pause count detection: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error during pause count detection: {str(e)}")

if data.enableDenoiser:
try:
denoised_audio_base64 = denoise_with_rnnoise(audio_data, data.contentType)
if denoised_audio_base64 is None:
logger.error("Error during audio denoising")
raise HTTPException(status_code=500, detail="Error during audio denoising")
except ValueError as e:
logger.error(f"Value error in denoise_with_rnnoise: {str(e)}")
raise HTTPException(status_code=400, detail=f"Value error in denoise_with_rnnoise: {str(e)}")
except RuntimeError as e:
logger.error(f"Runtime error in denoise_with_rnnoise: {str(e)}")
raise HTTPException(status_code=500, detail=f"Runtime error in denoise_with_rnnoise: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error in denoise_with_rnnoise: {str(e)}")
raise HTTPException(status_code=500, detail=f"Unexpected error in denoise_with_rnnoise: {str(e)}")

return {
"denoised_audio_base64": denoised_audio_base64,
"pause_count": pause_count
}
except HTTPException as e:
raise e
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
66 changes: 47 additions & 19 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
def denoise_with_rnnoise(audio_base64, content_type, padding_duration=0.1, time_stretch_factor=0.75):
try:
# Decode base64 to get the audio data
audio_data = base64.b64decode(audio_base64)
try:
audio_data = base64.b64decode(audio_base64)
except base64.binascii.Error as e:
raise ValueError(f"Invalid base64 string: {str(e)}")

audio_io = io.BytesIO(audio_data)
input_audio = audio_io.read()

Expand All @@ -28,40 +32,64 @@ def denoise_with_rnnoise(audio_base64, content_type, padding_duration=0.1, time_

# Create the ffmpeg filter chain
filter_chain = []
if content_type == 'Word' or content_type == 'word':
if content_type.lower() == 'word':
filter_chain.append(f'apad=pad_dur={padding_duration}')
filter_chain.append(f'apad=pad_dur={padding_duration}')
filter_chain.append(f'atempo={time_stretch_factor}')
filter_chain_str = ','.join(filter_chain)

# Apply the filters and denoise
output, _ = (
ffmpeg
.input('pipe:')
.output('pipe:', format='wav', af=f'{filter_chain_str},arnndn=m={model_path}')
.run(input=input_audio, capture_stdout=True, capture_stderr=True)
)
try:
output, _ = (
ffmpeg
.input('pipe:', format='wav')
.output('pipe:', format='wav', af=f'{filter_chain_str},arnndn=m={model_path}')
.run(input=input_audio, capture_stdout=True, capture_stderr=True)
)
except ffmpeg.Error as e:
raise RuntimeError(f"Error during noise reduction with FFmpeg: {e.stderr.decode()}")

# Convert the processed output back to base64
denoised_audio_base64 = base64.b64encode(output).decode('utf-8')
try:
denoised_audio_base64 = base64.b64encode(output).decode('utf-8')
except Exception as e:
raise RuntimeError(f"Error encoding output to base64: {str(e)}")

# Clear cache to free memory
del audio_data
del audio_io

return denoised_audio_base64

except ffmpeg.Error as e:
print(f"Error during noise reduction: {e.stderr.decode()}")
return None


except ValueError as e:
print(f"Value error in denoise_with_rnnoise: {str(e)}")
raise
except RuntimeError as e:
print(f"Runtime error in denoise_with_rnnoise: {str(e)}")
raise
except Exception as e:
print(f"Unexpected error in denoise_with_rnnoise: {str(e)}")
raise

def convert_to_base64(audio_data, sample_rate):
buffer = io.BytesIO()
sf.write(buffer, audio_data, sample_rate, format='wav')
buffer.seek(0)
base64_audio = base64.b64encode(buffer.read()).decode('utf-8')
return base64_audio
try:
buffer = io.BytesIO()
try:
sf.write(buffer, audio_data, sample_rate, format='wav')
except Exception as e:
raise RuntimeError(f"Error writing audio data to buffer: {str(e)}")

buffer.seek(0)
try:
base64_audio = base64.b64encode(buffer.read()).decode('utf-8')
except Exception as e:
raise RuntimeError(f"Error encoding buffer to base64: {str(e)}")

return base64_audio
except Exception as e:
print(f"Error in convert_to_base64: {str(e)}")
return {"error": str(e)}

def get_error_arrays(alignments, reference, hypothesis):
insertion = []
deletion = []
Expand Down

0 comments on commit 9370ec2

Please sign in to comment.