Skip to content

Commit

Permalink
Merge pull request #6 from jahangir091/feature/textToImageAPI
Browse files Browse the repository at this point in the history
text to image api implemented
  • Loading branch information
jahangir091 authored Nov 28, 2023
2 parents bbb9808 + e1f34ae commit 196cf66
Show file tree
Hide file tree
Showing 7 changed files with 693 additions and 1 deletion.
108 changes: 108 additions & 0 deletions modules/StyleSelectorXL.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import contextlib

import gradio as gr
# from modules import scripts, shared, script_callbacks
# from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton
import json
import os
import random
# Path of the styles JSON file that createPositive() and createNegative() read.
# This relative default is replaced by getStyles(), which stores a path built
# from scripts.basedir() in this module-level variable.
stylespath = "sdxl_styles.json"


def get_json_content(file_path):
    """Read a UTF-8 JSON file and return the decoded content.

    Args:
        file_path: Path of the JSON file to load.

    Returns:
        The decoded JSON value, or None when the file cannot be opened or
        parsed (the problem is printed instead of being raised).
    """
    try:
        with open(file_path, mode='rt', encoding="utf-8") as stream:
            parsed = json.load(stream)
    except Exception as err:
        # Best effort: report the failure and let the caller see None.
        print(f"A Problem occurred: {str(err)}")
        return None
    else:
        return parsed


def read_sdxl_styles(json_data):
    """Collect the style names found in decoded styles data.

    Args:
        json_data: Decoded JSON; expected to be a list of dictionaries.

    Returns:
        A sorted list with the 'name' value of every dictionary entry that
        has one, or None (after printing an error) when json_data is not a
        list. Entries that are not dictionaries or lack 'name' are ignored.
    """
    if isinstance(json_data, list):
        return sorted(
            entry['name']
            for entry in json_data
            if isinstance(entry, dict) and 'name' in entry
        )

    print("Error: input data must be a list")
    return None


def getStyles():
    """Load the style names from sdxl_styles.json in the scripts base directory.

    Side effect: the module-level `stylespath` is updated to the resolved
    path so that createPositive()/createNegative() read the same file.

    Returns:
        The sorted list of style names, or None when the file could not be
        read or does not contain a list.
    """
    global stylespath
    # The module-level import of `scripts` is commented out in this file, so
    # it is imported here; without it the basedir() call raises NameError.
    from modules import scripts

    json_path = os.path.join(scripts.basedir(), 'sdxl_styles.json')
    stylespath = json_path
    json_data = get_json_content(json_path)
    styles = read_sdxl_styles(json_data)
    return styles


def createPositive(style, positive):
    """Apply the named style template to a positive prompt.

    The styles file at `stylespath` is read on every call. The first template
    whose 'name' equals `style` is used: every '{prompt}' placeholder in its
    'prompt' text is replaced with `positive`.

    Args:
        style: Name of the style template to apply.
        positive: The user's positive prompt text.

    Returns:
        The styled prompt. When the styles cannot be loaded, a template
        visited before the match lacks 'name' or 'prompt', or no template
        matches, the problem is printed and `positive` is returned unchanged
        so callers never receive None as a prompt.
    """
    json_data = get_json_content(stylespath)
    try:
        if not isinstance(json_data, list):
            raise ValueError(
                "Invalid JSON data. Expected a list of templates.")

        for template in json_data:
            if 'name' not in template or 'prompt' not in template:
                raise ValueError(
                    "Invalid template. Missing 'name' or 'prompt' field.")

            if template['name'] == style:
                return template['prompt'].replace('{prompt}', positive)

        # The loop finished without returning: no template matched.
        raise ValueError(f"No template found with name '{style}'.")

    except Exception as e:
        print(f"An error occurred: {str(e)}")
        # Fall back to the unstyled prompt instead of implicitly returning None.
        return positive


def createNegative(style, negative):
    """Combine the named style's negative prompt with the caller's one.

    The styles file at `stylespath` is read on every call. For the first
    template whose 'name' equals `style`, its optional 'negative_prompt' is
    placed in front of `negative`, separated by ", ".

    Args:
        style: Name of the style template to apply.
        negative: The user's negative prompt text (may be empty).

    Returns:
        The combined negative prompt: the template's text alone when
        `negative` is empty, `negative` alone when the template has no
        negative text. When the styles cannot be loaded, a template visited
        before the match lacks 'name' or 'prompt', or no template matches,
        the problem is printed and `negative` is returned unchanged so
        callers never receive an implicit None.
    """
    json_data = get_json_content(stylespath)
    try:
        if not isinstance(json_data, list):
            raise ValueError(
                "Invalid JSON data. Expected a list of templates.")

        for template in json_data:
            if 'name' not in template or 'prompt' not in template:
                raise ValueError(
                    "Invalid template. Missing 'name' or 'prompt' field.")

            if template['name'] == style:
                json_negative_prompt = template.get('negative_prompt', "")
                if not negative:
                    return json_negative_prompt
                if json_negative_prompt:
                    return f"{json_negative_prompt}, {negative}"
                return negative

        # The loop finished without returning: no template matched.
        raise ValueError(f"No template found with name '{style}'.")

    except Exception as e:
        print(f"An error occurred: {str(e)}")
        # Fall back to the caller's text instead of implicitly returning None.
        return negative


56 changes: 55 additions & 1 deletion modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
import piexif.helper
from contextlib import closing

# Needed for txt2img enhance api
from modules.txt2img import txt2img, txt2img_process
from modules.json_helper import get_text2img_data
from modules.api.models import TextToImageJsonModel
from modules import StyleSelectorXL


def script_name_to_index(name, scripts):
try:
Expand Down Expand Up @@ -209,6 +215,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.app = app
self.queue_lock = queue_lock
api_middleware(self.app)
self.add_api_route("/sdapi/ai/v1/txt2img/generate", self.text2imggenerateapi, methods=["POST"], response_model=models.TextToImageResponseAPI)
self.add_api_route("/sdapi/ai/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
self.add_api_route("/sdapi/ai/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
self.add_api_route("/sdapi/ai/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
Expand Down Expand Up @@ -252,6 +259,54 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []

def text2imggenerateapi(self, prompt: str, model_id: str, style: str):
    """Generate images from a prompt using per-model presets and an optional style.

    Args:
        prompt: The user's prompt text.
        model_id: Identifier used to look up presets via get_text2img_data();
            it is also stored in shared.opts.sd_model_checkpoint.
        style: Style template name; "base" skips the style templates.

    Returns:
        models.TextToImageResponseAPI holding the generated images encoded
        with encode_pil_to_base64.
    """
    # load model
    # NOTE(review): only the option is set here; the weight reload below is
    # commented out, so confirm the checkpoint is actually switched elsewhere.
    shared.opts.sd_model_checkpoint = model_id
    # reload_model_weights()

    data = get_text2img_data(model_id=model_id)

    if data is None:
        # No preset for this model: fall back to built-in defaults.
        data = TextToImageJsonModel(
            model_id="stabilityai/stable-diffusion-xl-refiner-1.0",
            sampeller_method="Euler",
            step=40,
            cfg=9,
            prompt="",
            negative_prompt="",
        )

    global_pos_prompt = "high res, 4k render, uhd, high quality, best quality, (highest quality, award winning, masterpiece:1.3)"
    global_neg_prompt = "EasyNegative, FastNegativeV2, ugly, tiling, poorly drawn face, out of frame, extra limbs, disfigured, deformed, cut off, low contrast, distorted face, jpeg artifacts"

    # Join the non-empty fragments with ", " so adjacent words do not get
    # fused together (e.g. "a cat" + "high res" -> "a cathigh res").
    positive_prompt = ", ".join(
        part for part in (prompt, data.prompt, global_pos_prompt) if part)
    negative_prompt = ", ".join(
        part for part in (data.negative_prompt, global_neg_prompt) if part)

    if style != "base":
        # The style helpers take the text as their second parameter
        # (`positive` / `negative`), so it is passed positionally; the
        # negative prompt must go through createNegative.
        styled_positive = StyleSelectorXL.createPositive(
            style,
            ", ".join(part for part in (prompt, global_pos_prompt) if part),
        )
        styled_negative = StyleSelectorXL.createNegative(
            style, global_neg_prompt)
        # Keep the unstyled prompts if a helper reports failure with None.
        if styled_positive is not None:
            positive_prompt = styled_positive
        if styled_negative is not None:
            negative_prompt = styled_negative

    txt2img_process_result = txt2img_process(
        id_task="",
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        seed=1000,  # fixed seed for every request
        prompt_styles=[],
        steps=data.step,
        sampler_name=data.sampeller_method,
        n_iter=1,
        batch_size=1,
        cfg_scale=data.cfg,
        height=512,
        width=512,
        enable_hr=False,
        denoising_strength=0.7,
        hr_scale=2.0,
        hr_upscaler="Latent",
        hr_second_pass_steps=0,
        hr_resize_x=0,
        hr_resize_y=0,
        hr_checkpoint_name="",
        hr_sampler_name="",
        hr_prompt="",
        hr_negative_prompt="",
        override_settings_texts="",
    )

    # unload_model_weights()
    b64images = list(map(encode_pil_to_base64, txt2img_process_result))
    return models.TextToImageResponseAPI(images=b64images)

def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
Expand Down Expand Up @@ -665,7 +720,6 @@ def create_embedding(self, args: dict):
finally:
shared.state.end()


def create_hypernetwork(self, args: dict):
try:
shared.state.begin(job="create_hypernetwork")
Expand Down
12 changes: 12 additions & 0 deletions modules/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,23 @@ class TextToImageResponse(BaseModel):
parameters: dict
info: str

# Response model of the POST /sdapi/ai/v1/txt2img/generate route: it carries
# only the generated images (no `parameters` or `info` fields).
class TextToImageResponseAPI(BaseModel):
    # Generated images as base64 strings; the declared default is None.
    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")

# Response model of the POST /sdapi/ai/v1/img2img route.
class ImageToImageResponse(BaseModel):
    # Generated images as base64 strings; the declared default is None.
    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    # Required dict field; presumably the request parameters used - TODO confirm.
    parameters: dict
    # Required string field; presumably generation info text - TODO confirm.
    info: str

# Per-model generation preset. Instances are built from matching entries of
# text2img.json by get_text2img_data(), or from hard-coded defaults in the
# txt2img generate endpoint when no entry matches.
class TextToImageJsonModel(BaseModel):
    # Identifier the preset is looked up by.
    model_id: str
    # Sampler name; passed to txt2img_process as `sampler_name`.
    # The spelling matches the key used in text2img.json, so it must stay.
    sampeller_method: str
    # Passed to txt2img_process as `steps`.
    step: int
    # Passed to txt2img_process as `cfg_scale`.
    cfg: float
    # Model-specific text combined with the user's positive prompt.
    prompt: str
    # Model-specific text combined with the global negative prompt.
    negative_prompt: str


class ExtrasBaseRequest(BaseModel):
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
Expand Down
43 changes: 43 additions & 0 deletions modules/json_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import json
import os
from modules import scripts
from modules.api.models import TextToImageJsonModel

def get_json_content(file_path):
    """Read a UTF-8 JSON file and return the decoded content.

    Args:
        file_path: Path of the JSON file to load.

    Returns:
        The decoded JSON value, or an empty list when the file cannot be
        opened or parsed (the problem is printed instead of being raised).
    """
    try:
        with open(file_path, mode='rt', encoding="utf-8") as stream:
            parsed = json.load(stream)
    except Exception as err:
        # Best effort: report the failure and hand back an empty list.
        print(f"A Problem occurred: {str(err)}")
        return []
    else:
        return parsed

def get_text2img_data(model_id: str):
    """Look up the generation preset for a model in text2img.json.

    The file is read from the directory returned by scripts.basedir().

    Args:
        model_id: Value to match against each entry's 'model_id'.

    Returns:
        A TextToImageJsonModel built from the first dictionary entry whose
        'model_id' equals `model_id` and which has all of the other required
        keys. Matching entries with missing keys are passed over. Returns
        None (after printing an error) when the data is not a list or no
        entry qualifies.
    """
    entries = get_json_content(
        os.path.join(scripts.basedir(), 'text2img.json'))
    if not isinstance(entries, list):
        print("Error: input data must be a list")
        return None

    # Keys (besides 'model_id') an entry must carry to be usable.
    required_fields = (
        'sampeller_method',
        'step',
        'cfg',
        'prompt',
        'negative_prompt',
    )

    for entry in entries:
        if not isinstance(entry, dict):
            continue
        if 'model_id' not in entry or entry['model_id'] != model_id:
            continue
        if all(field in entry for field in required_fields):
            preset = {field: entry[field] for field in required_fields}
            return TextToImageJsonModel(model_id=model_id, **preset)

    print("Error: model not found")
    return None
59 changes: 59 additions & 0 deletions modules/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,62 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
processed.images = []

return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")


def txt2img_process(id_task: str, prompt: str, negative_prompt: str, seed: int, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
    """Run one txt2img job with an explicit seed and return only the images.

    The arguments are forwarded to
    processing.StableDiffusionProcessingTxt2Img; extra positional `args`
    become the script arguments of the job.

    Returns:
        processed.images from processing.process_images(), or an empty list
        when opts.do_not_show_images is set.
    """
    override_settings = create_override_settings_dict(override_settings_texts)

    # The UI sentinel strings mean "no separate hires checkpoint/sampler".
    hr_checkpoint = hr_checkpoint_name
    if hr_checkpoint_name == 'Use same checkpoint':
        hr_checkpoint = None
    hr_sampler = hr_sampler_name
    if hr_sampler_name == 'Use same sampler':
        hr_sampler = None

    job_kwargs = dict(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
        prompt=prompt,
        styles=prompt_styles,
        negative_prompt=negative_prompt,
        sampler_name=sampler_name,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        enable_hr=enable_hr,
        # Denoising strength is only passed on when hires fix is enabled.
        denoising_strength=denoising_strength if enable_hr else None,
        hr_scale=hr_scale,
        hr_upscaler=hr_upscaler,
        hr_second_pass_steps=hr_second_pass_steps,
        hr_resize_x=hr_resize_x,
        hr_resize_y=hr_resize_y,
        hr_checkpoint_name=hr_checkpoint,
        hr_sampler_name=hr_sampler,
        hr_prompt=hr_prompt,
        hr_negative_prompt=hr_negative_prompt,
        override_settings=override_settings,
    )
    p = processing.StableDiffusionProcessingTxt2Img(**job_kwargs)

    p.scripts = modules.scripts.scripts_txt2img
    p.script_args = args

    p.user = id_task
    p.seed = seed

    if cmd_opts.enable_console_prompts:
        print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)

    # closing() calls p.close() even if processing raises.
    with closing(p):
        processed = processing.process_images(p)

    shared.total_tqdm.clear()

    info_js = processed.js()
    if opts.samples_log_stdout:
        print(info_js)

    if opts.do_not_show_images:
        processed.images = []

    return processed.images

Loading

0 comments on commit 196cf66

Please sign in to comment.