Skip to content

Commit

Permalink
use helper for image gen tests (#7343)
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaan-jaff authored Dec 21, 2024
1 parent b90b98b commit 11e5960
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 41 deletions.
107 changes: 71 additions & 36 deletions litellm/cost_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,8 @@ def completion_cost( # noqa: PLR0915
region_name=None, # used for bedrock pricing
### IMAGE GEN ###
size: Optional[str] = None,
quality=None,
n=None, # number of images
quality: Optional[str] = None,
n: Optional[int] = None, # number of images
### CUSTOM PRICING ###
custom_cost_per_token: Optional[CostPerToken] = None,
custom_cost_per_second: Optional[float] = None,
Expand Down Expand Up @@ -640,41 +640,14 @@ def completion_cost( # noqa: PLR0915
raise TypeError(
"completion_response must be of type ImageResponse for bedrock image cost calculation"
)
if size is None:
size = "1024-x-1024" # openai default
# fix size to match naming convention
if "x" in size and "-x-" not in size:
size = size.replace("x", "-x-")
image_gen_model_name = f"{size}/{model}"
image_gen_model_name_with_quality = image_gen_model_name
if quality is not None:
image_gen_model_name_with_quality = f"{quality}/{image_gen_model_name}"
size_parts = size.split("-x-")
height = int(size_parts[0]) # if it's 1024-x-1024 vs. 1024x1024
width = int(size_parts[1])
verbose_logger.debug(f"image_gen_model_name: {image_gen_model_name}")
verbose_logger.debug(
f"image_gen_model_name_with_quality: {image_gen_model_name_with_quality}"
)
if image_gen_model_name in litellm.model_cost:
return (
litellm.model_cost[image_gen_model_name]["input_cost_per_pixel"]
* height
* width
* n
)
elif image_gen_model_name_with_quality in litellm.model_cost:
return (
litellm.model_cost[image_gen_model_name_with_quality][
"input_cost_per_pixel"
]
* height
* width
* n
)
else:
raise Exception(
f"Model={image_gen_model_name} not found in completion cost model map"
return default_image_cost_calculator(
model=model,
quality=quality,
custom_llm_provider=custom_llm_provider,
n=n,
size=size,
optional_params=optional_params,
)
elif (
call_type == CallTypes.speech.value or call_type == CallTypes.aspeech.value
Expand Down Expand Up @@ -869,3 +842,65 @@ def transcription_cost(
return openai_cost_per_second(
model=model, custom_llm_provider=custom_llm_provider, duration=duration
)


def default_image_cost_calculator(
model: str,
custom_llm_provider: Optional[str] = None,
quality: Optional[str] = None,
n: Optional[int] = 1, # Default to 1 image
size: Optional[str] = "1024-x-1024", # OpenAI default
optional_params: Optional[dict] = None,
) -> float:
"""
Default image cost calculator for image generation
Args:
model (str): Model name
image_response (ImageResponse): Response from image generation
quality (Optional[str]): Image quality setting
n (Optional[int]): Number of images generated
size (Optional[str]): Image size (e.g. "1024x1024" or "1024-x-1024")
Returns:
float: Cost in USD for the image generation
Raises:
Exception: If model pricing not found in cost map
"""
# Standardize size format to use "-x-"
size_str: str = size or "1024-x-1024"
size_str = (
size_str.replace("x", "-x-")
if "x" in size_str and "-x-" not in size_str
else size_str
)

# Parse dimensions
height, width = map(int, size_str.split("-x-"))

# Build model names for cost lookup
base_model_name = f"{size_str}/{model}"
if custom_llm_provider and model.startswith(custom_llm_provider):
base_model_name = (
f"{custom_llm_provider}/{size_str}/{model.replace(custom_llm_provider, '')}"
)
model_name_with_quality = (
f"{quality}/{base_model_name}" if quality else base_model_name
)

verbose_logger.debug(
f"Looking up cost for models: {model_name_with_quality}, {base_model_name}"
)

# Try model with quality first, fall back to base model name
if model_name_with_quality in litellm.model_cost:
cost_info = litellm.model_cost[model_name_with_quality]
elif base_model_name in litellm.model_cost:
cost_info = litellm.model_cost[base_model_name]
else:
raise Exception(
f"Model not found in cost map. Tried {model_name_with_quality} and {base_model_name}"
)

return cost_info["input_cost_per_pixel"] * height * width * n
18 changes: 13 additions & 5 deletions tests/image_gen_tests/test_image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import sys
import traceback


sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path

from dotenv import load_dotenv
from openai.types.image import Image
from litellm.caching import InMemoryCache
Expand All @@ -14,10 +19,6 @@
load_dotenv()
import asyncio
import os

sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest

import litellm
Expand Down Expand Up @@ -142,7 +143,7 @@ def get_base_image_generation_call_args(self) -> dict:
"api_version": "2023-09-01-preview",
"metadata": {
"model_info": {
"base_model": "dall-e-3",
"base_model": "azure/dall-e-3",
}
},
}
Expand All @@ -158,8 +159,15 @@ def test_image_generation_azure_dall_e_3():
api_version="2023-12-01-preview",
api_base=os.getenv("AZURE_SWEDEN_API_BASE"),
api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
metadata={
"model_info": {
"base_model": "azure/dall-e-3",
}
},
)
print(f"response: {response}")

print("response", response._hidden_params)
assert len(response.data) > 0
except litellm.InternalServerError as e:
pass
Expand Down

0 comments on commit 11e5960

Please sign in to comment.