Skip to content

Commit

Permalink
feat:added a multi-modal part for audio (#3)
Browse files Browse the repository at this point in the history
* feat(audio part): added a new format part. Audio part

* Add audio file support

---------

Co-authored-by: haruiz <henryruiz22@gmail.com>
  • Loading branch information
jggomez and haruiz committed May 13, 2024
1 parent fce835e commit e879cb5
Show file tree
Hide file tree
Showing 27 changed files with 433 additions and 88 deletions.
49 changes: 49 additions & 0 deletions examples/chat_with_your_audios.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from rich import print

from geminiplayground.core import GeminiClient
from geminiplayground.parts import AudioFile
from geminiplayground.schemas import GenerateRequestParts, TextPart, GenerateRequest
from dotenv import load_dotenv, find_dotenv
from rich import print

load_dotenv(find_dotenv())


def chat_wit_your_audios():
"""
Get the content parts of an audio file and generate a request.
:return:
"""
gemini_client = GeminiClient()
model = "models/gemini-1.5-pro-latest"

audio_file_path = "<your audio file>.mp3"
audio_file = AudioFile(audio_file_path, gemini_client=gemini_client)
# audio_file.delete()
audio_files = audio_file.files
print("Audio files: ", audio_files)

audio_parts = audio_file.content_parts()
request_parts = GenerateRequestParts(
parts=[
TextPart(text="Listen this audio:"),
*audio_parts,
TextPart(text="Describe what you heard"),
]
)
request = GenerateRequest(
contents=[request_parts],
)
tokens_count = gemini_client.get_tokens_count(model, request)
print("Tokens count: ", tokens_count)
response = gemini_client.generate_response(model, request)

# Print the response
for candidate in response.candidates:
for part in candidate.content.parts:
if part.text:
print(part.text)


if __name__ == "__main__":
chat_wit_your_audios()
9 changes: 6 additions & 3 deletions src/geminiplayground/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def check_api_key():

@cli.command()
def ui(
host: str = "0.0.0.0",
host: str = "localhost",
port: int = 8081,
workers: int = os.cpu_count() * 2 + 1,
reload: Annotated[bool, typer.Option("--reload")] = True,
Expand All @@ -33,6 +33,7 @@ def ui(
Launch the web app
"""
check_api_key()

import uvicorn

uvicorn.run(
Expand All @@ -46,17 +47,19 @@ def ui(

@cli.command()
def api(
host: str = "0.0.0.0",
host: str = "localhost",
port: int = 8081,
workers: int = os.cpu_count() * 2 + 1,
reload: Annotated[bool, typer.Option("--reload")] = True

):
"""
Launch the API
"""
check_api_key()

import uvicorn

check_api_key()
uvicorn.run(
"geminiplayground.web.api:api",
host=host,
Expand Down
1 change: 1 addition & 0 deletions src/geminiplayground/parts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from .image_part import ImageFile
from .multimodal_part import MultimodalPart
from .video_part import VideoFile
from .audio_part import AudioFile
from .multimodal_part_factory import MultimodalPartFactory
129 changes: 129 additions & 0 deletions src/geminiplayground/parts/audio_part.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import logging
import typing
import validators
import urllib.request
from geminiplayground.catching import cache
from geminiplayground.core.gemini_client import GeminiClient
from geminiplayground.schemas import UploadFile
from .multimodal_part import MultimodalPart
from geminiplayground.utils import normalize_url
from geminiplayground.utils import get_file_name_from_path
from geminiplayground.utils import get_expire_time
from pathlib import Path
from urllib.error import HTTPError

logger = logging.getLogger("rich")


def get_audio_from_url(url: str) -> str:
"""
Create an audio from url and return it
"""
http_uri = normalize_url(url)
try:
assert validators.url(http_uri), "invalid url"
file_name, _ = urllib.request.urlretrieve(url)
logger.info(f"Temporary file was saved in {file_name}")
return file_name
except HTTPError as err:
if err.strerror == 404:
raise Exception("Audio not found")
elif err.code in [403, 406]:
raise Exception("Audio image, it can not be reached")
else:
raise


def get_audio_from_anywhere(uri_or_path: typing.Union[str, Path]) -> str:
"""
read an audio from an url or local file and return it
"""
uri_or_path = str(uri_or_path)
if validators.url(uri_or_path):
return get_audio_from_url(uri_or_path)
return uri_or_path


def upload_audio(
audio_path: typing.Union[str, Path],
gemini_client: GeminiClient = None):
"""
Upload an audio to Gemini
:param gemini_client: The Gemini client
:param audio_path: The path to the audio
:return:
"""
audio_path = get_audio_from_anywhere(audio_path)
audio_filename = get_file_name_from_path(audio_path)

if audio_path:
upload_file = UploadFile.from_path(audio_path,
body={"file": {"displayName": audio_filename}})
uploaded_file = gemini_client.upload_file(upload_file)
return uploaded_file


class AudioFile(MultimodalPart):
"""
Audio file part implementation
"""

def __init__(self, audio_path: typing.Union[str, Path], **kwargs):
self.audio_path = audio_path
self.audio_name = get_file_name_from_path(audio_path)
self.gemini_client = kwargs.get("gemini_client", GeminiClient())

def upload(self):
"""
Upload the audio to Gemini
:return:
"""
if cache.get(self.audio_name):
cached_file = cache.get(self.audio_name)
return [cached_file]

delta_t = get_expire_time()
uploaded_file = upload_audio(self.audio_path, self.gemini_client)
cache.set(self.audio_name, uploaded_file, expire=delta_t)
return [uploaded_file]

@property
def files(self):
"""
Get the files
:return:
"""
return self.upload()

def force_upload(self):
"""
Force the upload of the audio
:return:
"""
self.delete()
self.upload()

def delete(self):
"""
Delete the image from Gemini
:return:
"""
if cache.get(self.audio_name):
cached_file = cache.get(self.audio_name)
self.gemini_client.delete_file(cached_file.name)
# remove the cached file
cache.delete(self.audio_name)

def clear_cache(self):
"""
Clear the cache
:return:
"""
cache.delete(self.audio_name)

def content_parts(self):
"""
Get the content parts for the audio
:return:
"""
return list(map(lambda f: f.to_file_part(), self.files))
8 changes: 4 additions & 4 deletions src/geminiplayground/parts/image_part.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def upload_image(image_path: typing.Union[str, Path], gemini_client: GeminiClien


class ImageFile(MultimodalPart):
"""
Image File Part implementation
"""

def __init__(self, image_path: typing.Union[str, Path], **kwargs):
self.image_path = image_path
Expand All @@ -44,10 +47,7 @@ def upload(self):
cached_file = cache.get(self.image_name)
return [cached_file]

now = datetime.now()
future = now + timedelta(days=1)
delta_t = future - now
delta_t = delta_t.total_seconds()
delta_t = get_expire_time()
uploaded_file = upload_image(self.image_path, self.gemini_client)
cache.set(self.image_name, uploaded_file, expire=delta_t)
return [uploaded_file]
Expand Down
9 changes: 6 additions & 3 deletions src/geminiplayground/parts/multimodal_part_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .image_part import ImageFile
from .video_part import VideoFile
from .audio_part import AudioFile


class MultimodalPartFactory:
Expand All @@ -23,7 +24,9 @@ def from_path(path: typing.Union[str, Path], **kwargs):
if path.is_file():
mime_type = mimetypes.guess_type(path.as_posix())[0]
if mime_type.startswith("image"):
return ImageFile(image_path=path, **kwargs)
elif mime_type.startswith("video"):
return VideoFile(video_path=path, **kwargs)
return ImageFile(path, **kwargs)
if mime_type.startswith("video"):
return VideoFile(path, **kwargs)
if mime_type.startswith("audio"):
return AudioFile(path, **kwargs)
raise ValueError(f"Unsupported file type: {path}")
7 changes: 2 additions & 5 deletions src/geminiplayground/parts/video_part.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def upload_video(video_path: typing.Union[str, Path], gemini_client: GeminiClien

class VideoFile(MultimodalPart):
"""
Extract frames from a video and upload them to Gemini
Video part implementation
"""

def __init__(self, video_path: typing.Union[str, Path], **kwargs):
Expand All @@ -59,10 +59,7 @@ def upload(self):
if cache.get(self.video_name):
return cache.get(self.video_name)

now = datetime.now()
future = now + timedelta(days=1)
delta_t = future - now
delta_t = delta_t.total_seconds()
delta_t = get_expire_time()
uploaded_files = upload_video(self.video_path, self.gemini_client)
cache.set(self.video_name, uploaded_files, expire=delta_t)
return uploaded_files
Expand Down
12 changes: 12 additions & 0 deletions src/geminiplayground/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import math
from tqdm import tqdm
import validators
from datetime import datetime, timedelta


def rm_tree(pth: typing.Union[str, Path]):
Expand Down Expand Up @@ -466,3 +467,14 @@ def create_image_thumbnail(
background.paste(pil_image, mask=pil_image.split()[3])
pil_image = background
return pil_image


def get_expire_time():
"""
Get the expiration time for the cache
"""
now = datetime.now()
future = now + timedelta(days=1)
delta_t = future - now
delta_t = delta_t.total_seconds()
return delta_t
Loading

0 comments on commit e879cb5

Please sign in to comment.