Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/advanced/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ dependencies = [
"redis>=5.0.0",
"rq>=1.16.0",
"soundfile>=0.12.1",
"google-api-python-client>=2.0.0",
"google-auth-oauthlib>=1.0.0",
"google-auth-httplib2>=0.2.0",
"websockets>=12.0",
]

Expand Down
3 changes: 3 additions & 0 deletions backends/advanced/src/advanced_omi_backend/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def __init__(self):
# Memory service configuration
self.memory_service_supports_threshold = self.memory_provider == "friend_lite"

self.gdrive_credentials_path = "data/gdrive_service_account.json"
self.gdrive_scopes = ["https://www.googleapis.com/auth/drive.readonly"]


# Global configuration instance
app_config = AppConfig()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
from google.oauth2.service_account import Credentials
from googleapiclient.discovery import build
from advanced_omi_backend.app_config import get_app_config

# Module-level cache holding the Drive client once it has been built.
_drive_client_cache = None


def get_google_drive_client():
    """Return the shared Google Drive v3 client, building it on first use.

    The client is created from the service-account JSON file at
    ``config.gdrive_credentials_path`` with ``config.gdrive_scopes`` and is
    then cached in ``_drive_client_cache`` for subsequent calls.

    Returns:
        The object returned by ``build("drive", "v3", credentials=...)``.

    Raises:
        FileNotFoundError: If no regular file exists at the configured
            credentials path.
    """
    global _drive_client_cache

    # Identity check instead of truthiness: "has the client been built?" must
    # not depend on how the cached object happens to evaluate as a boolean.
    if _drive_client_cache is not None:
        return _drive_client_cache

    config = get_app_config()
    credentials_path = config.gdrive_credentials_path

    # isfile() rather than exists(): a directory at this path is just as
    # unusable as a missing file and should produce the same clear error.
    if not os.path.isfile(credentials_path):
        raise FileNotFoundError(
            f"Missing Google Drive credentials at {credentials_path}"
        )

    creds = Credentials.from_service_account_file(
        credentials_path,
        scopes=config.gdrive_scopes,
    )

    _drive_client_cache = build("drive", "v3", credentials=creds)
    return _drive_client_cache
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ async def upload_and_process_audio_files(
device_name: str = "upload",
auto_generate_client: bool = True,
folder: str = None,
source: str = "upload"
) -> dict:
"""
Upload audio files and process them directly.
Expand Down Expand Up @@ -81,8 +82,15 @@ async def upload_and_process_audio_files(
# Read file content
content = await file.read()


# Generate audio UUID and timestamp
audio_uuid = str(uuid.uuid4())
if source == "gdrive":
audio_uuid = getattr(file, "audio_uuid", None)
if not audio_uuid:
audio_logger.error(f"Missing audio_uuid for gdrive file: {file.filename}")
audio_uuid = str(uuid.uuid4())
else:
audio_uuid = str(uuid.uuid4())
timestamp = int(time.time() * 1000)

# Determine output directory (with optional subfolder)
Expand All @@ -98,12 +106,13 @@ async def upload_and_process_audio_files(
relative_audio_path, file_path, duration = await write_audio_file(
raw_audio_data=content,
audio_uuid=audio_uuid,
source=source,
client_id=client_id,
user_id=user.user_id,
user_email=user.email,
timestamp=timestamp,
chunk_dir=chunk_dir,
validate=True # Validate WAV format, convert stereo→mono
validate=True, # Validate WAV format, convert stereo→mono
)
except AudioValidationError as e:
processed_files.append({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ class AudioFile(Document):

# Core identifiers
audio_uuid: Indexed(str, unique=True) = Field(description="Unique audio identifier")
source: Indexed(str) = Field(
default="upload",
description="Source of the audio (upload, gdrive, etc.)"
)
audio_path: str = Field(description="Path to raw audio file")
client_id: Indexed(str) = Field(description="Client device identifier")
timestamp: Indexed(int) = Field(description="Unix timestamp in milliseconds")
Expand All @@ -51,11 +55,13 @@ class AudioFile(Document):
description="Speech detection results"
)



class Settings:
name = "audio_chunks"
indexes = [
"audio_uuid",
"client_id",
"user_id",
"timestamp"
"timestamp",
]
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,29 @@
from advanced_omi_backend.auth import current_superuser, current_active_user_optional, get_user_from_token_param
from advanced_omi_backend.controllers import audio_controller
from advanced_omi_backend.models.user import User
from advanced_omi_backend.app_config import get_audio_chunk_dir
from advanced_omi_backend.utils.gdrive_audio_utils import download_audio_files_from_drive, AudioValidationError

router = APIRouter(prefix="/audio", tags=["audio"])


@router.post("/upload_audio_from_gdrive")
async def upload_audio_from_drive_folder(
    gdrive_folder_id: str = Query(..., description="Google Drive Folder ID containing audio files (e.g., the string after /folders/ in the URL)"),
    current_user: User = Depends(current_superuser),
    device_name: str = Query(default="upload"),
    auto_generate_client: bool = Query(default=True),
):
    """Import audio files from a Google Drive folder and process them.

    Downloads the folder's audio files via
    ``download_audio_files_from_drive`` and hands them to
    ``audio_controller.upload_and_process_audio_files`` with
    ``source="gdrive"``. Requires a superuser (``current_superuser``).

    Raises:
        HTTPException: 400 carrying the message of any
            ``AudioValidationError`` raised while downloading.
    """
    try:
        drive_files = await download_audio_files_from_drive(gdrive_folder_id)
    except AudioValidationError as exc:
        raise HTTPException(status_code=400, detail=str(exc))

    result = await audio_controller.upload_and_process_audio_files(
        current_user,
        drive_files,
        device_name,
        auto_generate_client,
        source="gdrive",
    )
    return result


@router.get("/get_audio/{conversation_id}")
async def get_conversation_audio(
conversation_id: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ async def validate_and_prepare_audio(
async def write_audio_file(
raw_audio_data: bytes,
audio_uuid: str,
source: str,
client_id: str,
user_id: str,
user_email: str,
timestamp: int,
chunk_dir: Optional[Path] = None,
validate: bool = True
validate: bool = True,
) -> tuple[str, str, float]:
"""
Validate, write audio data to WAV file, and create AudioSession database entry.
Expand Down Expand Up @@ -197,13 +198,14 @@ async def write_audio_file(
# Create AudioFile database entry using Beanie model
audio_file = AudioFile(
audio_uuid=audio_uuid,
source=source,
audio_path=wav_filename,
client_id=client_id,
timestamp=timestamp,
user_id=user_id,
user_email=user_email,
has_speech=False, # Will be updated by transcription
speech_analysis={}
speech_analysis={},
)
await audio_file.insert()

Expand Down
119 changes: 119 additions & 0 deletions backends/advanced/src/advanced_omi_backend/utils/gdrive_audio_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import io
import tempfile
from typing import List
import logging
from starlette.datastructures import UploadFile as StarletteUploadFile
from googleapiclient.http import MediaIoBaseDownload
from advanced_omi_backend.clients.gdrive_audio_client import get_google_drive_client
from advanced_omi_backend.models.audio_file import AudioFile
from advanced_omi_backend.utils.audio_utils import AudioValidationError


logger = logging.getLogger(__name__)
audio_logger = logging.getLogger("audio_processing")

AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac", ".ogg", ".m4a")
FOLDER_MIMETYPE = "application/vnd.google-apps.folder"



async def download_and_wrap_drive_file(service, file_item):
    """Download one Google Drive file and wrap it as a Starlette ``UploadFile``.

    Args:
        service: Google Drive API client; only ``service.files().get_media``
            is used here.
        file_item: Drive file metadata mapping providing ``"id"`` and
            ``"name"``.

    Returns:
        A ``StarletteUploadFile`` whose ``filename`` is the Drive file name
        and whose underlying file is a ``SpooledTemporaryFile`` rewound to
        the start (kept in memory up to 10 MB, spilled to disk beyond that).

    Raises:
        AudioValidationError: If the download yielded zero bytes.
    """
    file_id = file_item["id"]
    name = file_item["name"]

    request = service.files().get_media(fileId=file_id)

    # Download straight into the spooled file instead of staging the whole
    # payload in a BytesIO first; otherwise large files sit in memory in full
    # and the 10 MB spill threshold never gets a chance to help.
    tmp_file = tempfile.SpooledTemporaryFile(max_size=10 * 1024 * 1024)  # 10 MB
    try:
        downloader = MediaIoBaseDownload(tmp_file, request)

        done = False
        while not done:
            # NOTE(review): next_chunk() is a synchronous call, so the event
            # loop is blocked while each chunk downloads.
            _status, done = downloader.next_chunk()

        # The downloader only appends, so the write position equals the
        # number of bytes received.
        if tmp_file.tell() == 0:
            raise AudioValidationError(
                f"Downloaded Google Drive file '{name}' was empty"
            )

        tmp_file.seek(0)
    except BaseException:
        # Do not leak the temp file when the download or validation fails.
        tmp_file.close()
        raise

    # UploadFile.close() is intentionally left untouched: it is a coroutine,
    # so replacing it with a plain function would break `await file.close()`
    # and leave the temp file open. Closing the SpooledTemporaryFile is all
    # the cleanup that is needed.
    return StarletteUploadFile(filename=name, file=tmp_file)

# -------------------------------------------------------------
# LIST + DOWNLOAD FILES IN FOLDER (OAUTH)
# -------------------------------------------------------------
async def download_audio_files_from_drive(folder_id: str) -> List[StarletteUploadFile]:
    """Download the not-yet-imported audio files of a Google Drive folder.

    Lists the direct, non-trashed children of ``folder_id`` and keeps those
    whose name ends with one of ``AUDIO_EXTENSIONS``. Files that already have
    an ``AudioFile`` record with ``source == "gdrive"`` and
    ``audio_uuid == <Drive file id>`` are skipped. Every other file is
    downloaded and returned with its Drive file id attached as
    ``audio_uuid``.

    Args:
        folder_id: Google Drive folder ID.

    Returns:
        The downloaded files wrapped as ``StarletteUploadFile`` objects.

    Raises:
        AudioValidationError: If ``folder_id`` is empty, the folder contains
            no audio files, every audio file was already processed, or any
            other error occurs while listing or downloading (wrapped as
            "Google Drive API Error").
    """
    if not folder_id:
        raise AudioValidationError("Google Drive folder ID is required.")

    service = get_google_drive_client()

    wrapped_files = []
    try:
        # Escape backslashes and single quotes so the ID cannot break out of
        # the quoted literal in the Drive query string.
        escaped_folder_id = folder_id.replace("\\", "\\\\").replace("'", "\\'")
        query = f"'{escaped_folder_id}' in parents and trashed = false"

        # files.list is paginated; follow nextPageToken via list_next() so
        # folders larger than a single page are not silently truncated.
        # NOTE(review): execute() is a blocking call inside an async function.
        all_files = []
        request = service.files().list(
            q=query,
            fields="nextPageToken, files(id, name, mimeType)",
            includeItemsFromAllDrives=False,
            supportsAllDrives=False,
        )
        while request is not None:
            response = request.execute()
            all_files.extend(response.get("files", []))
            request = service.files().list_next(request, response)

        # Sub-folders cannot be downloaded as media, so exclude them even if
        # their name happens to look like an audio file.
        audio_files_metadata = [
            f for f in all_files
            if f.get("mimeType") != FOLDER_MIMETYPE
            and f["name"].lower().endswith(AUDIO_EXTENSIONS)
        ]

        if not audio_files_metadata:
            raise AudioValidationError("No audio files found in folder.")

        skipped_count = 0

        for item in audio_files_metadata:
            file_id = item["id"]  # Drive file ID doubles as the audio_uuid

            # Skip files that were already imported from Drive.
            existing = await AudioFile.find_one({
                "audio_uuid": file_id,
                "source": "gdrive"
            })

            if existing:
                audio_logger.info("Skipping already processed file: %s", item["name"])
                skipped_count += 1
                continue

            wrapped_file = await download_and_wrap_drive_file(service, item)
            # Attach the Drive file ID so later processing can reuse it.
            wrapped_file.audio_uuid = file_id
            wrapped_files.append(wrapped_file)

        if not wrapped_files and skipped_count > 0:
            raise AudioValidationError(f"All {skipped_count} files in the folder have already been processed.")

        return wrapped_files

    except Exception as e:
        # The caller never receives the files downloaded so far, so release
        # their temporary storage before reporting the failure.
        for partial in wrapped_files:
            partial.file.close()
        if isinstance(e, AudioValidationError):
            raise
        raise AudioValidationError(f"Google Drive API Error: {e}") from e


Loading
Loading