Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add-parameters #6

Merged
merged 2 commits into from
Nov 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,40 @@
from griptape.drivers import BaseImageGenerationDriver


def steps_validator(instance, attribute, value):
if value and (value < 1 or value > 50):
raise ValueError("steps must be between 1 and 50")


def size_validator(instance, attribute, value):
if value and value % 32 != 0:
raise ValueError(f"{attribute} must be a multiple of 32")
if value and value < 256 or value > 1440:
raise ValueError(f"{attribute} must be between 256 and 1440")


def safety_validator(instance, attribute, value):
if value and (value < 0 or value > 6):
raise ValueError("safety_tolerance must be between 0 and 6")


def aspect_ratio_validator(instance, attribute, value):
if value:
width, height = value.split(":")
if width < 9 or width > 21 or height < 9 or height > 21:
raise ValueError("aspect_ratio must be between 9:21 and 21:9")


def guidance_validator(instance, attribute, value):
if value and (value < 1.5 or value > 5):
raise ValueError("guidance must be between 1.5 and 5")


def interval_validator(instance, attribute, value):
if value and (value < 1 or value > 4):
raise ValueError("interval must be between 1 and 4")


@define
class BlackForestImageGenerationDriver(BaseImageGenerationDriver):
"""Driver for the Black Forest Labs image generation API.
Expand All @@ -22,11 +56,12 @@ class BlackForestImageGenerationDriver(BaseImageGenerationDriver):
height: Height of the generated image. Valid for 'flux-pro-1.1', 'flux-pro', 'flux-dev' models only. Integer range from 256 to 1440. Must be a multiple of 32. Default is 1024.
aspect_ratio: Aspect ratio of the generated image between 21:9 and 9:21. Valid for 'flux-pro-1.1-ultra' model only. Default is 16:9.
prompt_upsampling: Optional flag to perform upsampling on the prompt. Valid for `flux-pro-1.1', 'flux-pro', 'flux-dev' models only. If active, automatically modifies the prompt for more creative generation.
safety_tolerance: Optional tolerance level for input and output moderation. Valid for 'flux-pro-1.1', 'flux-pro', 'flux-dev' models only. Between 0 and 6, 0 being most strict, 6 being least strict.
safety_tolerance: Optional tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.
seed: Optional seed for reproducing results. Default is None.
steps: Optional number of steps for the image generation process. Valid for 'flux-dev' model only. Default is None.


steps: Optional number of steps for the image generation process. Valid for 'flux-dev' and `flux-pro` models only. Can be a value between 1 and 50. Default is None.
guidance: Optional guidance scale for image generation. High guidance scales improve prompt adherence at the cost of reduced realism. Min: 1.5, max: 5. Valid for 'flux-dev' and 'flux-pro' models only.
interval: Optional interval parameter for guidance control. Valid for 'flux-pro' model only. Value is an integer between 1 and 4. Default is None.
raw: Optional flag to generate less processed, more natural-looking images. Valid for 'flux-pro-1.1-ultra' model only. Default is False.
"""

base_url: str = field(
Expand All @@ -39,45 +74,58 @@ class BlackForestImageGenerationDriver(BaseImageGenerationDriver):
kw_only=True,
metadata={"serializable": False},
)
width: int = field(default=1024, kw_only=True)
height: int = field(default=768, kw_only=True)
width: int = field(default=1024, kw_only=True, validator=size_validator)
height: int = field(default=768, kw_only=True, validator=size_validator)
sleep_interval: float = field(default=0.5, kw_only=True)
safety_tolerance: int | None = field(default=None, kw_only=True)
aspect_ratio: str = field(default=None, kw_only=True)
safety_tolerance: int | None = field(
default=None, kw_only=True, validator=safety_validator
)
aspect_ratio: str = field(
default=None, kw_only=True, validator=aspect_ratio_validator
)
seed: int | None = field(default=None, kw_only=True)
prompt_upsampling: bool = field(default=False, kw_only=True)
steps: int | None = field(default=None, kw_only=True)
steps: int | None = field(default=None, kw_only=True, validator=steps_validator)
guidance: float | None = field(
default=None, kw_only=True, validator=guidance_validator
)
interval: int | None = field(
default=None, kw_only=True, validator=interval_validator
)
raw: bool = field(default=False, kw_only=True)

def try_text_to_image(
self, prompts: list[str], negative_prompts: list[str] | None = None
) -> ImageArtifact:
prompt = " ".join(prompts)

if self.width % 32 != 0 or self.height % 32 != 0:
msg = "width and height must be multiples of 32"
raise ValueError(msg)
if self.width < 256 or self.width > 1440:
raise ValueError("width must be between 256 and 1440")
if self.safety_tolerance and (
self.safety_tolerance < 0 or self.safety_tolerance > 6
):
raise ValueError("safety_tolerance must be between 0 and 6")

data: dict[str, Any] = {
"prompt": prompt,
}

if self.seed:
data["seed"] = self.seed
if self.safety_tolerance:
data["safety_tolerance"] = self.safety_tolerance

if self.model == "flux-pro-1.1-ultra" and self.aspect_ratio:
if self.aspect_ratio and self.model == "flux-pro-1.1-ultra":
data["aspect_ratio"] = self.aspect_ratio

if self.raw and self.model == "flux-pro-1.1-ultra":
data["raw"] = self.raw

if self.guidance and self.model in ["flux-dev", "flux-pro"]:
data["guidance"] = float(self.guidance)

if self.steps and self.model in ["flux-dev", "flux-pro"]:
data["steps"] = int(self.steps)

if self.interval and self.model == "flux-pro":
data["interval"] = int(self.interval)

if self.model in ["flux-pro-1.1", "flux-pro", "flux-dev"]:
data["width"] = self.width
data["height"] = self.height
if self.safety_tolerance:
data["safety_tolerance"] = self.safety_tolerance
if self.prompt_upsampling:
data["prompt_upsampling"] = self.prompt_upsampling

Expand Down