Skip to content

Commit

Permalink
Merge pull request #10 from Zingzy/9-feature-request-separate-configu…
Browse files Browse the repository at this point in the history
…rations-into-a-toml-file

Separate Configurations into a TOML File
  • Loading branch information
Zingzy authored Feb 2, 2025
2 parents c97fb3f + c644dee commit 8da002b
Show file tree
Hide file tree
Showing 15 changed files with 440 additions and 179 deletions.
3 changes: 1 addition & 2 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
TOKEN=
MONGODB_URI=
TOKEN=
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,7 @@ cython_debug/
# PyPI configuration file
.pypirc

# Configurations
config.toml

todo.md
64 changes: 37 additions & 27 deletions cogs/imagine_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import discord
from discord import app_commands
from discord.ext import commands
import traceback

from constants import MODELS
from config import config
from utils.image_gen_utils import generate_image, validate_dimensions, validate_prompt
from utils.embed_utils import generate_pollinate_embed, generate_error_message
from utils.pollinate_utils import parse_url
from utils.error_handler import send_error_embed
from exceptions import DimensionTooSmallError, PromptTooLongError, APIError
import traceback


class ImagineButtonView(discord.ui.View):
Expand All @@ -20,7 +20,7 @@ def __init__(self) -> None:
label="Regenerate",
style=discord.ButtonStyle.secondary,
custom_id="regenerate-button",
emoji="<:redo:1187101382101180456>",
emoji=f"<:redo:{config.bot.emojis['redo_emoji_id']}>",
)
async def regenerate(
self, interaction: discord.Interaction, button: discord.ui.Button
Expand All @@ -29,7 +29,7 @@ async def regenerate(
embed=discord.Embed(
title="Regenerating Your Image",
description="Please wait while we generate your image",
color=discord.Color.blurple(),
color=int(config.ui.colors.success, 16),
),
ephemeral=True,
)
Expand All @@ -49,7 +49,7 @@ async def regenerate(
embed=discord.Embed(
title="Couldn't Generate the Requested Image 😔",
description=f"```\n{e.message}\n```",
color=discord.Color.red(),
color=int(config.ui.colors.error, 16),
),
ephemeral=True,
)
Expand All @@ -60,7 +60,7 @@ async def regenerate(
embed=discord.Embed(
title="Error",
description=f"Error generating image : {e}",
color=discord.Color.red(),
color=int(config.ui.colors.error, 16),
),
ephemeral=True,
)
Expand All @@ -86,7 +86,7 @@ async def regenerate(
@discord.ui.button(
style=discord.ButtonStyle.red,
custom_id="delete-button",
emoji="<:delete:1187102382312652800>",
emoji=f"<:delete:{config.bot.emojis['delete_emoji_id']}>",
)
async def delete(self, interaction: discord.Interaction, button: discord.ui.Button):
try:
Expand All @@ -99,8 +99,8 @@ async def delete(self, interaction: discord.Interaction, button: discord.ui.Butt
await interaction.response.send_message(
embed=discord.Embed(
title="Error",
description="You can only delete the images prompted by you",
color=discord.Color.red(),
description=config.ui.error_messages["delete_unauthorized"],
color=int(config.ui.colors.error, 16),
),
ephemeral=True,
)
Expand All @@ -114,7 +114,7 @@ async def delete(self, interaction: discord.Interaction, button: discord.ui.Butt
embed=discord.Embed(
title="Error Deleting the Image",
description=f"{e}",
color=discord.Color.red(),
color=int(config.ui.colors.error, 16),
),
ephemeral=True,
)
Expand All @@ -124,7 +124,7 @@ async def delete(self, interaction: discord.Interaction, button: discord.ui.Butt
label="Bookmark",
style=discord.ButtonStyle.secondary,
custom_id="bookmark-button",
emoji="<:save:1187101389822902344>",
emoji=f"<:save:{config.bot.emojis['save_emoji_id']}>",
)
async def bookmark(
self, interaction: discord.Interaction, button: discord.ui.Button
Expand All @@ -137,7 +137,7 @@ async def bookmark(

embed: discord.Embed = discord.Embed(
description=f"**Prompt : {prompt}**",
color=discord.Color.og_blurple(),
color=int(config.ui.colors.success, 16),
)
embed.add_field(
name="",
Expand All @@ -152,7 +152,7 @@ async def bookmark(
embed=discord.Embed(
title="Image Bookmarked",
description="The image has been bookmarked and sent to your DMs",
color=discord.Color.blurple(),
color=int(config.ui.colors.success, 16),
),
ephemeral=True,
)
Expand All @@ -164,7 +164,7 @@ async def bookmark(
embed=discord.Embed(
title="Error Bookmarking the Image",
description=f"{e}",
color=discord.Color.red(),
color=int(config.ui.colors.error, 16),
),
ephemeral=True,
)
Expand All @@ -174,17 +174,23 @@ async def bookmark(
class Imagine(commands.Cog):
def __init__(self, bot) -> None:
self.bot = bot
self.command_config = config.commands["pollinate"]

async def cog_load(self) -> None:
await self.bot.wait_until_ready()
self.bot.add_view(ImagineButtonView())

@app_commands.command(name="pollinate", description="Generate AI Images")
@app_commands.choices(
model=[app_commands.Choice(name=choice, value=choice) for choice in MODELS],
model=[
app_commands.Choice(name=choice, value=choice) for choice in config.MODELS
],
)
@app_commands.guild_only()
@app_commands.checks.cooldown(1, 10)
@app_commands.checks.cooldown(
config.commands["pollinate"].cooldown.rate,
config.commands["pollinate"].cooldown.seconds,
)
@app_commands.describe(
prompt="Prompt of the Image you want want to generate",
height="Height of the Image",
Expand All @@ -200,22 +206,22 @@ async def imagine_command(
self,
interaction: discord.Interaction,
prompt: str,
width: int = 1000,
height: int = 1000,
model: app_commands.Choice[str] = MODELS[0],
enhance: bool | None = None,
safe: bool = False,
cached: bool = False,
nologo: bool = False,
private: bool = False,
width: int = config.commands["pollinate"].default_width,
height: int = config.commands["pollinate"].default_height,
model: app_commands.Choice[str] = config.MODELS[0],
enhance: bool | None = config.image_generation.defaults.enhance,
safe: bool = config.image_generation.defaults.safe,
cached: bool = config.image_generation.defaults.cached,
nologo: bool = config.image_generation.defaults.nologo,
private: bool = config.image_generation.defaults.private,
) -> None:
validate_dimensions(width, height)
validate_prompt(prompt)

await interaction.response.defer(thinking=True, ephemeral=private)

try:
model = model.value
model = model.value if model else None
except Exception:
pass

Expand Down Expand Up @@ -250,7 +256,9 @@ async def imagine_command_error(
embed: discord.Embed = await generate_error_message(
interaction,
error,
cooldown_configuration=["- 1 time every 10 seconds"],
cooldown_configuration=[
f"- {self.command_config.cooldown.rate} time every {self.command_config.cooldown.seconds} seconds",
],
)
return await interaction.response.send_message(embed=embed, ephemeral=True)

Expand All @@ -277,7 +285,9 @@ async def imagine_command_error(

else:
await send_error_embed(
interaction, "An unexprected error occurred", f"```\n{str(error)}\n```"
interaction,
"An unexpected error occurred",
f"```\n{str(error)}\n```",
)


Expand Down
52 changes: 31 additions & 21 deletions cogs/multi_pollinate_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import traceback
import asyncio

from config import config
from utils.embed_utils import generate_error_message
from utils.image_gen_utils import generate_image, validate_dimensions, validate_prompt
from utils.error_handler import send_error_embed
Expand All @@ -15,7 +16,6 @@
DimensionTooSmallError,
APIError,
)
from constants import MODELS


class multiImagineButtonView(discord.ui.View):
Expand All @@ -39,7 +39,7 @@ def create_buttons(self) -> None:
label="",
style=discord.ButtonStyle.danger,
custom_id="multiimagine_delete",
emoji="<:delete:1187102382312652800>",
emoji=f"<:delete:{config.bot.emojis['delete_emoji_id']}>",
)
)

Expand Down Expand Up @@ -76,8 +76,8 @@ async def delete_image(self, interaction: discord.Interaction):
await interaction.response.send_message(
embed=discord.Embed(
title="Error",
description="You can only delete your own images",
color=discord.Color.red(),
description=config.ui.error_messages["delete_unauthorized"],
color=int(config.ui.colors.error, 16),
),
ephemeral=True,
)
Expand All @@ -91,7 +91,7 @@ async def delete_image(self, interaction: discord.Interaction):
embed=discord.Embed(
title="Error Deleting the Image",
description=f"{e}",
color=discord.Color.red(),
color=int(config.ui.colors.error, 16),
),
ephemeral=True,
)
Expand All @@ -101,6 +101,7 @@ async def delete_image(self, interaction: discord.Interaction):
class Multi_pollinate(commands.Cog):
def __init__(self, bot) -> None:
self.bot = bot
self.command_config = config.commands["multi_pollinate"]

async def cog_load(self) -> None:
await self.bot.wait_until_ready()
Expand All @@ -112,7 +113,10 @@ async def get_info(interaction: discord.Interaction, index: int) -> None:
@app_commands.command(
name="multi-pollinate", description="Imagine multiple prompts"
)
@app_commands.checks.cooldown(1, 20)
@app_commands.checks.cooldown(
config.commands["multi_pollinate"].cooldown.rate,
config.commands["multi_pollinate"].cooldown.seconds,
)
@app_commands.guild_only()
@app_commands.describe(
prompt="Prompt of the Image you want want to generate",
Expand All @@ -128,25 +132,25 @@ async def multiimagine_command(
self,
interaction: discord.Interaction,
prompt: str,
width: int = 1000,
height: int = 1000,
enhance: bool | None = None,
width: int = config.commands["multi_pollinate"].default_width,
height: int = config.commands["multi_pollinate"].default_height,
enhance: bool | None = config.image_generation.defaults.enhance,
negative: str | None = None,
cached: bool = False,
nologo: bool = False,
private: bool = False,
cached: bool = config.image_generation.defaults.cached,
nologo: bool = config.image_generation.defaults.nologo,
private: bool = config.image_generation.defaults.private,
) -> None:
validate_dimensions(width, height)
validate_prompt(prompt)

total_models: int = len(MODELS)
total_models: int = len(config.MODELS)

await interaction.response.send_message(
embed=discord.Embed(
title="Generating Image",
description=f"Generating images across {total_models} models...\n"
f"Completed: 0/{total_models} 0%",
color=discord.Color.blurple(),
color=int(config.ui.colors.success, 16),
),
ephemeral=private,
)
Expand Down Expand Up @@ -178,12 +182,11 @@ async def update_progress() -> None:
description=f"Generating images across {total_models} models...\n"
f"Completed: {completed_count}/{total_models} "
f"({(completed_count / total_models * 100):.2f}%)",
color=discord.Color.blurple(),
color=int(config.ui.colors.success, 16),
)
)

async def generate_for_model(i, model):
"""Asynchronous function to generate an image for a specific model."""
try:
sub_start_time: datetime.datetime = datetime.datetime.now()
dic, image = await generate_image(model=model, **command_args)
Expand All @@ -206,10 +209,13 @@ async def generate_for_model(i, model):
try:
results = await asyncio.wait_for(
asyncio.gather(
*[generate_for_model(i, model) for i, model in enumerate(MODELS)],
*[
generate_for_model(i, model)
for i, model in enumerate(config.MODELS)
],
return_exceptions=True,
),
timeout=180,
timeout=self.command_config.timeout_seconds,
)
except asyncio.TimeoutError:
raise asyncio.TimeoutError
Expand Down Expand Up @@ -265,15 +271,17 @@ async def multiimagine_command_error(
embed: discord.Embed = await generate_error_message(
interaction,
error,
cooldown_configuration=["- 1 time every 20 seconds"],
cooldown_configuration=[
f"- {self.command_config.cooldown.rate} time every {self.command_config.cooldown.seconds} seconds",
],
)
await interaction.response.send_message(embed=embed, ephemeral=True)

elif isinstance(error, asyncio.TimeoutError):
await send_error_embed(
interaction,
"Timeout Error",
"Image generation took too long and timed out. Please try again.",
config.ui.error_messages["timeout"],
)

elif isinstance(error, NoImagesGeneratedError):
Expand Down Expand Up @@ -306,7 +314,9 @@ async def multiimagine_command_error(

else:
await send_error_embed(
interaction, "An unexprected error occurred", f"```\n{str(error)}\n```"
interaction,
config.ui.error_messages["unknown"],
f"```\n{str(error)}\n```",
)


Expand Down
Loading

0 comments on commit 8da002b

Please sign in to comment.