Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:added a multi-modal part for audio #3

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions examples/chat_with_your_audios.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from rich import print

from geminiplayground.core import GeminiClient
from geminiplayground.parts import AudioFile
from geminiplayground.schemas import GenerateRequestParts, TextPart, GenerateRequest
from dotenv import load_dotenv, find_dotenv
from rich import print

load_dotenv(find_dotenv())


def chat_wit_your_audios():
    """
    Send an audio file to Gemini together with a short text prompt and print
    every text part of the answer.
    :return:
    """
    model = "models/gemini-1.5-pro-latest"
    gemini_client = GeminiClient()

    audio_file_path = "<your audio file>.mp3"
    audio_file = AudioFile(audio_file_path, gemini_client=gemini_client)
    # audio_file.delete()
    print("Audio files: ", audio_file.files)

    # The prompt is: intro text, the audio parts, then the actual question.
    audio_parts = audio_file.content_parts()
    prompt_parts = [
        TextPart(text="Listen this audio:"),
        *audio_parts,
        TextPart(text="Describe what you heard"),
    ]
    request = GenerateRequest(
        contents=[GenerateRequestParts(parts=prompt_parts)],
    )

    print("Tokens count: ", gemini_client.get_tokens_count(model, request))
    response = gemini_client.generate_response(model, request)

    # Print the response, skipping parts that carry no text
    answers = (
        part.text
        for candidate in response.candidates
        for part in candidate.content.parts
        if part.text
    )
    for answer in answers:
        print(answer)


if __name__ == "__main__":
    chat_wit_your_audios()
9 changes: 6 additions & 3 deletions src/geminiplayground/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def check_api_key():

@cli.command()
def ui(
host: str = "0.0.0.0",
host: str = "localhost",
port: int = 8081,
workers: int = os.cpu_count() * 2 + 1,
reload: Annotated[bool, typer.Option("--reload")] = True,
Expand All @@ -33,6 +33,7 @@ def ui(
Launch the web app
"""
check_api_key()

import uvicorn

uvicorn.run(
Expand All @@ -46,17 +47,19 @@ def ui(

@cli.command()
def api(
host: str = "0.0.0.0",
host: str = "localhost",
port: int = 8081,
workers: int = os.cpu_count() * 2 + 1,
reload: Annotated[bool, typer.Option("--reload")] = True

):
"""
Launch the API
"""
check_api_key()

import uvicorn

check_api_key()
uvicorn.run(
"geminiplayground.web.api:api",
host=host,
Expand Down
1 change: 1 addition & 0 deletions src/geminiplayground/parts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from .image_part import ImageFile
from .multimodal_part import MultimodalPart
from .video_part import VideoFile
from .audio_part import AudioFile
from .multimodal_part_factory import MultimodalPartFactory
129 changes: 129 additions & 0 deletions src/geminiplayground/parts/audio_part.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import logging
import typing
import validators
import urllib.request
from geminiplayground.catching import cache
from geminiplayground.core.gemini_client import GeminiClient
from geminiplayground.schemas import UploadFile
from .multimodal_part import MultimodalPart
from geminiplayground.utils import normalize_url
from geminiplayground.utils import get_file_name_from_path
from geminiplayground.utils import get_expire_time
from pathlib import Path
from urllib.error import HTTPError

logger = logging.getLogger("rich")


def get_audio_from_url(url: str) -> str:
    """
    Download the audio behind an url into a temporary file and return its path.

    :param url: The url of the audio; it is normalized with ``normalize_url``
        before being validated
    :return: The local file name reported by ``urllib.request.urlretrieve``
    :raises AssertionError: If the normalized url is not a valid url
    :raises Exception: If the server answers with 404, 403 or 406
    :raises HTTPError: Any other HTTP error is re-raised unchanged
    """
    http_uri = normalize_url(url)
    # Explicit check instead of ``assert`` so the validation is not stripped
    # under ``python -O``; AssertionError is kept so callers see the same type.
    if not validators.url(http_uri):
        raise AssertionError("invalid url")
    try:
        # NOTE(review): the raw ``url`` is fetched while the normalized one is
        # validated - confirm whether ``http_uri`` should be downloaded instead.
        file_name, _ = urllib.request.urlretrieve(url)
    except HTTPError as err:
        # The HTTP status is exposed as ``err.code``
        if err.code == 404:
            raise Exception("Audio not found") from err
        if err.code in (403, 406):
            raise Exception("Audio can not be reached") from err
        raise
    logger.info(f"Temporary file was saved in {file_name}")
    return file_name


def get_audio_from_anywhere(uri_or_path: typing.Union[str, Path]) -> str:
    """
    Resolve an audio location to a local path.

    A value that validates as an url is downloaded through
    ``get_audio_from_url``; anything else is returned as a string, untouched.
    """
    location = str(uri_or_path)
    if not validators.url(location):
        # Not an url: hand the value back as-is
        return location
    return get_audio_from_url(location)


def upload_audio(
        audio_path: typing.Union[str, Path],
        gemini_client: typing.Optional[GeminiClient] = None):
    """
    Upload an audio to Gemini
    :param audio_path: The path (or url) to the audio
    :param gemini_client: The Gemini client; a new ``GeminiClient()`` is
        created when none is given
    :return: The value returned by ``gemini_client.upload_file``, or ``None``
        when the resolved audio path is empty
    """
    audio_path = get_audio_from_anywhere(audio_path)
    if not audio_path:
        # Nothing to upload
        return None

    # The default is ``None``, so build a client instead of failing with an
    # AttributeError on ``None.upload_file``.
    if gemini_client is None:
        gemini_client = GeminiClient()

    audio_filename = get_file_name_from_path(audio_path)
    upload_file = UploadFile.from_path(
        audio_path,
        body={"file": {"displayName": audio_filename}},
    )
    return gemini_client.upload_file(upload_file)


class AudioFile(MultimodalPart):
    """
    Audio file part implementation
    """

    def __init__(self, audio_path: typing.Union[str, Path], **kwargs):
        """
        :param audio_path: The path (or url) to the audio
        :param kwargs: Optional ``gemini_client`` to use for uploads and deletions
        """
        self.audio_path = audio_path
        self.audio_name = get_file_name_from_path(audio_path)
        # Only build a client when the caller did not supply one;
        # ``kwargs.get(key, GeminiClient())`` would construct it every time.
        gemini_client = kwargs.get("gemini_client")
        if gemini_client is None:
            gemini_client = GeminiClient()
        self.gemini_client = gemini_client

    def upload(self):
        """
        Upload the audio to Gemini, reusing the cached upload when there is one
        :return: A one-element list with the cached or freshly uploaded file
        """
        # Single cache read so the check and the returned value cannot disagree
        cached_file = cache.get(self.audio_name)
        if cached_file:
            return [cached_file]

        delta_t = get_expire_time()
        uploaded_file = upload_audio(self.audio_path, self.gemini_client)
        cache.set(self.audio_name, uploaded_file, expire=delta_t)
        return [uploaded_file]

    @property
    def files(self):
        """
        Get the files
        :return: The result of ``upload()``
        """
        return self.upload()

    def force_upload(self):
        """
        Force the upload of the audio by deleting any cached upload first
        :return:
        """
        self.delete()
        self.upload()

    def delete(self):
        """
        Delete the audio from Gemini and drop its cache entry
        :return:
        """
        cached_file = cache.get(self.audio_name)
        if cached_file:
            self.gemini_client.delete_file(cached_file.name)
            # remove the cached file
            cache.delete(self.audio_name)

    def clear_cache(self):
        """
        Clear the cache entry of this audio
        :return:
        """
        cache.delete(self.audio_name)

    def content_parts(self):
        """
        Get the content parts for the audio
        :return: A list with ``to_file_part()`` of every file in ``files``
        """
        return [file.to_file_part() for file in self.files]
8 changes: 4 additions & 4 deletions src/geminiplayground/parts/image_part.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def upload_image(image_path: typing.Union[str, Path], gemini_client: GeminiClien


class ImageFile(MultimodalPart):
"""
Image File Part implementation
"""

def __init__(self, image_path: typing.Union[str, Path], **kwargs):
self.image_path = image_path
Expand All @@ -44,10 +47,7 @@ def upload(self):
cached_file = cache.get(self.image_name)
return [cached_file]

now = datetime.now()
future = now + timedelta(days=1)
delta_t = future - now
delta_t = delta_t.total_seconds()
delta_t = get_expire_time()
uploaded_file = upload_image(self.image_path, self.gemini_client)
cache.set(self.image_name, uploaded_file, expire=delta_t)
return [uploaded_file]
Expand Down
9 changes: 6 additions & 3 deletions src/geminiplayground/parts/multimodal_part_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .image_part import ImageFile
from .video_part import VideoFile
from .audio_part import AudioFile


class MultimodalPartFactory:
Expand All @@ -23,7 +24,9 @@ def from_path(path: typing.Union[str, Path], **kwargs):
if path.is_file():
mime_type = mimetypes.guess_type(path.as_posix())[0]
if mime_type.startswith("image"):
return ImageFile(image_path=path, **kwargs)
elif mime_type.startswith("video"):
return VideoFile(video_path=path, **kwargs)
return ImageFile(path, **kwargs)
if mime_type.startswith("video"):
return VideoFile(path, **kwargs)
if mime_type.startswith("audio"):
return AudioFile(path, **kwargs)
raise ValueError(f"Unsupported file type: {path}")
7 changes: 2 additions & 5 deletions src/geminiplayground/parts/video_part.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def upload_video(video_path: typing.Union[str, Path], gemini_client: GeminiClien

class VideoFile(MultimodalPart):
"""
Extract frames from a video and upload them to Gemini
Video part implementation
"""

def __init__(self, video_path: typing.Union[str, Path], **kwargs):
Expand All @@ -59,10 +59,7 @@ def upload(self):
if cache.get(self.video_name):
return cache.get(self.video_name)

now = datetime.now()
future = now + timedelta(days=1)
delta_t = future - now
delta_t = delta_t.total_seconds()
delta_t = get_expire_time()
uploaded_files = upload_video(self.video_path, self.gemini_client)
cache.set(self.video_name, uploaded_files, expire=delta_t)
return uploaded_files
Expand Down
12 changes: 12 additions & 0 deletions src/geminiplayground/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import math
from tqdm import tqdm
import validators
from datetime import datetime, timedelta


def rm_tree(pth: typing.Union[str, Path]):
Expand Down Expand Up @@ -466,3 +467,14 @@ def create_image_thumbnail(
background.paste(pil_image, mask=pil_image.split()[3])
pil_image = background
return pil_image


def get_expire_time():
    """
    Get the expiration time for the cache

    :return: The length of one day in seconds, as a float
    """
    one_day = timedelta(days=1)
    return one_day.total_seconds()
Loading
Loading