Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions examples/audio_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""
Example: Multimodal RLM - Audio Support (TTS and Transcription)

This demonstrates the audio capabilities of RLM:
- speak() for text-to-speech
- audio_query() for audio transcription/analysis
"""

import os

from dotenv import load_dotenv

from rlm import RLM
from rlm.logger import RLMLogger

load_dotenv()

# Get the directory where this script is located
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

logger = RLMLogger(log_dir="./logs")

# Use Gemini which supports audio
rlm = RLM(
backend="gemini",
backend_kwargs={
"model_name": "gemini-2.5-flash",
"api_key": os.getenv("GEMINI_API_KEY"),
},
environment="local",
environment_kwargs={},
max_depth=1,
logger=logger,
verbose=True,
enable_multimodal=True, # Enable multimodal functions (vision_query, audio_query, speak)
)

# Example 1: Text-to-Speech
# Ask RLM to generate speech
context = {
"task": "Generate a spoken greeting",
"message": "Hello! This is a test of the RLM text-to-speech capability.",
"output_path": os.path.join(SCRIPT_DIR, "generated_speech.aiff"),
}

result = rlm.completion(
prompt=context,
root_prompt="Use speak(text, output_path) to convert context['message'] to audio and save it to context['output_path']. Return the path.",
)

print("\n" + "=" * 50)
print("TTS RESULT:")
print("=" * 50)
print(result.response)


# Example 2: Audio Analysis (if you have an audio file)
# Uncomment this section if you have an audio file to analyze:
"""
audio_context = {
"task": "Transcribe the audio",
"audio_file": "/path/to/your/audio.mp3",
}

result = rlm.completion(
prompt=audio_context,
root_prompt="Use audio_query to transcribe the audio file.",
)
print(result.response)
"""
52 changes: 52 additions & 0 deletions examples/multimodal_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Example: Multimodal RLM - Analyzing Images with Vision

This demonstrates the multimodal capabilities of RLM using the
vision_query() function to analyze images.
"""

import os

from dotenv import load_dotenv

from rlm import RLM
from rlm.logger import RLMLogger

load_dotenv()

# Get the directory where this script is located
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_IMAGE = os.path.join(SCRIPT_DIR, "test_image.png")

logger = RLMLogger(log_dir="./logs")

# Use Gemini which supports vision
rlm = RLM(
backend="gemini",
backend_kwargs={
"model_name": "gemini-2.5-flash",
"api_key": os.getenv("GEMINI_API_KEY"),
},
environment="local",
environment_kwargs={},
max_depth=1,
logger=logger,
verbose=True,
enable_multimodal=True, # Enable multimodal functions (vision_query, audio_query, speak)
)

# Create a context that includes references to images
context = {
"query": "Analyze the image and tell me what fruits are visible.",
"images": [TEST_IMAGE],
}

result = rlm.completion(
prompt=context,
root_prompt="What fruits are in the image? Use vision_query to analyze the image.",
)

print("\n" + "=" * 50)
print("FINAL RESULT:")
print("=" * 50)
print(result.response)
202 changes: 194 additions & 8 deletions rlm/clients/gemini.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64
import os
from collections import defaultdict
from pathlib import Path
from typing import Any

from dotenv import load_dotenv
Expand All @@ -11,6 +13,119 @@

load_dotenv()


def _load_image_as_part(image_source: str | dict) -> types.Part:
    """Load an image and return a Gemini Part object.

    Args:
        image_source: Either a file path (str), URL (str starting with http),
            or a dict describing the image:
            - {"type": "base64", "media_type": "image/png", "data": "..."}
            - {"type": "image_url", "image_url": {"url": "..."}} (OpenAI style;
              the url may be a plain URL or a ``data:`` URL)

    Returns:
        A Gemini Part object containing the image.

    Raises:
        FileNotFoundError: If a local file path does not exist.
        ValueError: If the source is neither a supported dict nor a string.
    """
    from urllib.parse import urlsplit  # local import: URL-extension MIME guessing

    if isinstance(image_source, dict):
        # Base64 encoded image: {"type": "base64", "media_type": "image/png", "data": "..."}
        if image_source.get("type") == "base64":
            image_bytes = base64.b64decode(image_source["data"])
            mime_type = image_source.get("media_type", "image/png")
            return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
        # URL format from OpenAI-style: {"type": "image_url", "image_url": {"url": "..."}}
        elif image_source.get("type") == "image_url":
            url = image_source["image_url"]["url"]
            if url.startswith("data:"):
                # Data URL: data:image/png;base64,...
                header, data = url.split(",", 1)
                mime_type = header.split(":")[1].split(";")[0]
                image_bytes = base64.b64decode(data)
                return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
            # Guess the MIME type from the URL's path extension instead of
            # always assuming JPEG (the old behavior mislabeled e.g. PNG URLs).
            guessed = _get_mime_type(Path(urlsplit(url).path))
            mime_type = guessed if guessed != "application/octet-stream" else "image/jpeg"
            return types.Part.from_uri(file_uri=url, mime_type=mime_type)
    elif isinstance(image_source, str):
        if image_source.startswith(("http://", "https://")):
            # URL: guess MIME from the extension, defaulting to JPEG.
            guessed = _get_mime_type(Path(urlsplit(image_source).path))
            mime_type = guessed if guessed != "application/octet-stream" else "image/jpeg"
            return types.Part.from_uri(file_uri=image_source, mime_type=mime_type)
        # Local file path
        path = Path(image_source)
        if not path.exists():
            raise FileNotFoundError(f"Image file not found: {image_source}")
        mime_type = _get_mime_type(path)
        return types.Part.from_bytes(data=path.read_bytes(), mime_type=mime_type)
    raise ValueError(f"Unsupported image source type: {type(image_source)}")


def _get_mime_type(path: Path) -> str:
"""Get MIME type from file extension."""
suffix = path.suffix.lower()
mime_types = {
# Images
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".gif": "image/gif",
".webp": "image/webp",
".bmp": "image/bmp",
# Audio
".mp3": "audio/mpeg",
".wav": "audio/wav",
".ogg": "audio/ogg",
".flac": "audio/flac",
".m4a": "audio/mp4",
".aac": "audio/aac",
".webm": "audio/webm",
# Video
".mp4": "video/mp4",
".mpeg": "video/mpeg",
".mov": "video/quicktime",
".avi": "video/x-msvideo",
".mkv": "video/x-matroska",
}
return mime_types.get(suffix, "application/octet-stream")


def _load_audio_as_part(audio_source: str | dict) -> types.Part:
    """Load an audio file and return a Gemini Part object.

    Args:
        audio_source: Either a file path (str), URL (str starting with http),
            or a dict describing the audio:
            - {"type": "base64", "media_type": "audio/mpeg", "data": "..."}
            - {"type": "audio_path", "path": "/path/to/audio.mp3"}

    Returns:
        A Gemini Part object containing the audio.

    Raises:
        FileNotFoundError: If a local file path does not exist.
        ValueError: If the source is neither a supported dict nor a string.
    """
    from urllib.parse import urlsplit  # local import: URL-extension MIME guessing

    if isinstance(audio_source, dict):
        # Base64 encoded audio
        if audio_source.get("type") == "base64":
            audio_bytes = base64.b64decode(audio_source["data"])
            mime_type = audio_source.get("media_type", "audio/mpeg")
            return types.Part.from_bytes(data=audio_bytes, mime_type=mime_type)
        # Path format
        elif audio_source.get("type") == "audio_path":
            path = Path(audio_source.get("path", ""))
            if not path.exists():
                raise FileNotFoundError(f"Audio file not found: {audio_source.get('path')}")
            return types.Part.from_bytes(data=path.read_bytes(), mime_type=_get_mime_type(path))
    elif isinstance(audio_source, str):
        if audio_source.startswith(("http://", "https://")):
            # URL - let Gemini fetch it. Guess the MIME type from the URL's
            # path extension instead of always assuming MP3 (the old behavior
            # mislabeled e.g. .wav or .flac URLs).
            guessed = _get_mime_type(Path(urlsplit(audio_source).path))
            mime_type = guessed if guessed != "application/octet-stream" else "audio/mpeg"
            return types.Part.from_uri(file_uri=audio_source, mime_type=mime_type)
        # Local file path
        path = Path(audio_source)
        if not path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_source}")
        return types.Part.from_bytes(data=path.read_bytes(), mime_type=_get_mime_type(path))
    raise ValueError(f"Unsupported audio source type: {type(audio_source)}")

DEFAULT_GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")


Expand Down Expand Up @@ -95,7 +210,18 @@ async def acompletion(
def _prepare_contents(
self, prompt: str | list[dict[str, Any]]
) -> tuple[list[types.Content] | str, str | None]:
"""Prepare contents and extract system instruction for Gemini API."""
"""Prepare contents and extract system instruction for Gemini API.

Supports multimodal content where message content can be:
- A string (text only)
- A list of content items (text and images mixed)

Image items can be:
- {"type": "text", "text": "..."}
- {"type": "image_url", "image_url": {"url": "..."}}
- {"type": "image_path", "path": "/path/to/image.png"}
- {"type": "base64", "media_type": "image/png", "data": "..."}
"""
system_instruction = None

if isinstance(prompt, str):
Expand All @@ -110,20 +236,80 @@ def _prepare_contents(

if role == "system":
# Gemini handles system instruction separately
system_instruction = content
elif role == "user":
contents.append(types.Content(role="user", parts=[types.Part(text=content)]))
elif role == "assistant":
# Gemini uses "model" instead of "assistant"
contents.append(types.Content(role="model", parts=[types.Part(text=content)]))
if isinstance(content, str):
system_instruction = content
elif isinstance(content, list):
# Extract text from system message list
system_parts = []
for item in content:
if isinstance(item, str):
system_parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
system_parts.append(item.get("text", ""))
system_instruction = "\n".join(system_parts)
elif role in ("user", "assistant"):
gemini_role = "user" if role == "user" else "model"
parts = self._content_to_parts(content)
if parts:
contents.append(types.Content(role=gemini_role, parts=parts))
else:
# Default to user role for unknown roles
contents.append(types.Content(role="user", parts=[types.Part(text=content)]))
parts = self._content_to_parts(content)
if parts:
contents.append(types.Content(role="user", parts=parts))

return contents, system_instruction

raise ValueError(f"Invalid prompt type: {type(prompt)}")

def _content_to_parts(self, content: str | list) -> list[types.Part]:
    """Convert message content to Gemini Parts.

    Args:
        content: Either a string or a list of content items. List items may
            be plain strings or dicts with a "type" key ("text", "image_url",
            "image_path", "base64", "audio_path", "audio_url").

    Returns:
        List of Gemini Part objects. Unrecognized dict types and non-str,
        non-dict list items are skipped; media that fails to load is replaced
        by a text Part describing the error.
    """
    if isinstance(content, str):
        return [types.Part(text=content)]
    if not isinstance(content, list):
        # Anything else: coerce to text.
        return [types.Part(text=str(content))]

    parts: list[types.Part] = []
    for item in content:
        if isinstance(item, str):
            parts.append(types.Part(text=item))
            continue
        if not isinstance(item, dict):
            continue  # silently skip unsupported item shapes
        item_type = item.get("type", "text")
        if item_type == "text":
            parts.append(types.Part(text=item.get("text", "")))
        elif item_type in ("image_url", "image_path", "base64"):
            try:
                # image_path carries a local file path; the other forms pass
                # the whole dict through to the loader.
                source = item.get("path", "") if item_type == "image_path" else item
                parts.append(_load_image_as_part(source))
            except Exception as e:
                # Degrade gracefully: surface the failure as text.
                parts.append(types.Part(text=f"[Image load error: {e}]"))
        elif item_type in ("audio_path", "audio_url"):
            key = "path" if item_type == "audio_path" else "url"
            try:
                parts.append(_load_audio_as_part(item.get(key, "")))
            except Exception as e:
                parts.append(types.Part(text=f"[Audio load error: {e}]"))
    return parts

def _track_cost(self, response: types.GenerateContentResponse, model: str):
self.model_call_counts[model] += 1

Expand Down
Loading