diff --git a/README.md b/README.md
index 11fdf60..ba3b8b8 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,10 @@
 This extension extends the built-in Composable Diffusion.
 This allows you to determine the region of the latent space that reflects your subprompts.
 
 ## How to use
+![20230303.png](./screenshots/20230303.png)
 ![20230213.png](./screenshots/20230213.png)
+
 ### Enabled
 
 The effect of Latent Couple appears only when Enabled is checked.
@@ -41,11 +43,34 @@
 outputs
 - end at step=4 https://imgur.com/a1kyvhX
 - end at step=0 https://imgur.com/yhGF7g8
-## Old prerequisite
-This extension need to apply cfg_denoised_callback-ea9bd9fc.patch (as of Feb 5, 2023 origin/HEAD commit ea9bd9fc).
+
+## ~~Prerequisite for prompt pasting~~
+## ~~Prerequisite for gradio Image and Sketch component bug fix~~
+This fix is no longer suitable for the latest webui commit (22bcc7be), whose gradio dependency has been upgraded to 3.23.
+
+I'll keep the fix here for people still using older versions of webui.
+
+Activate your venv in the webui root directory.
+
+For Windows, in cmd:
+```
+venv\Scripts\activate.bat
 ```
-git apply --ignore-whitespace extensions/stable-diffusion-webui-two-shot/cfg_denoised_callback-ea9bd9fc.patch
+For Linux:
 ```
+source venv/bin/activate
+```
+Then install the wheel distribution with the bugfix applied:
+```
+pip install --force-reinstall --no-deps extensions/stable-diffusion-webui-two-shot/gradio-3.16.2-py3-none-any.whl
+```
+For the bugfix-related modifications, see https://github.com/ashen-sensored/gradio/tree/3.16.2
+
+
+## Issues
+- ~~The extension's mask color sketching function does not work well with Chrome (extreme stuttering) due to gradio's Image component bug.~~ Please keep the browser scaling at 100% while creating a blank canvas to avoid the bug.
+See the prerequisite above. The fix is no longer suitable for the latest webui version (22bcc7be), whose gradio dependency has been upgraded to 3.23.
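For older webui installs where the wheel above applies, a quick way to confirm the patched gradio is active (a minimal sketch; 3.16.2 is the version the wheel provides):

```python
# Run inside the activated venv after installing the wheel above.
import gradio

print(gradio.__version__)  # expect "3.16.2" if the patched wheel is in use
```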
+
 ## Credits
 - two shot diffusion.ipynb https://colab.research.google.com/drive/1UdElpQfKFjY5luch9v_LlmSdH7AmeiDe?usp=sharing
diff --git a/cfg_denoised_callback-ea9bd9fc.patch b/cfg_denoised_callback-ea9bd9fc.patch
deleted file mode 100644
index c62aa24..0000000
--- a/cfg_denoised_callback-ea9bd9fc.patch
+++ /dev/null
@@ -1,83 +0,0 @@
-diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
-index 4bb45ec7..edd0e2a7 100644
---- a/modules/script_callbacks.py
-+++ b/modules/script_callbacks.py
-@@ -46,6 +46,18 @@ class CFGDenoiserParams:
-     """Total number of sampling steps planned"""
-
-
-+class CFGDenoisedParams:
-+    def __init__(self, x, sampling_step, total_sampling_steps):
-+        self.x = x
-+        """Latent image representation in the process of being denoised"""
-+
-+        self.sampling_step = sampling_step
-+        """Current Sampling step number"""
-+
-+        self.total_sampling_steps = total_sampling_steps
-+        """Total number of sampling steps planned"""
-+
-+
- class UiTrainTabParams:
-     def __init__(self, txt2img_preview_params):
-         self.txt2img_preview_params = txt2img_preview_params
-@@ -68,6 +80,7 @@ callback_map = dict(
-     callbacks_before_image_saved=[],
-     callbacks_image_saved=[],
-     callbacks_cfg_denoiser=[],
-+    callbacks_cfg_denoised=[],
-     callbacks_before_component=[],
-     callbacks_after_component=[],
-     callbacks_image_grid=[],
-@@ -150,6 +163,14 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
-             report_exception(c, 'cfg_denoiser_callback')
-
-
-+def cfg_denoised_callback(params: CFGDenoisedParams):
-+    for c in callback_map['callbacks_cfg_denoised']:
-+        try:
-+            c.callback(params)
-+        except Exception:
-+            report_exception(c, 'cfg_denoised_callback')
-+
-+
- def before_component_callback(component, **kwargs):
-     for c in callback_map['callbacks_before_component']:
-         try:
-@@ -283,6 +304,14 @@ def on_cfg_denoiser(callback):
-     add_callback(callback_map['callbacks_cfg_denoiser'], callback)
-
-
-+def on_cfg_denoised(callback):
-+    """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
-+    The callback is called with one argument:
-+        - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
-+    """
-+    add_callback(callback_map['callbacks_cfg_denoised'], callback)
-+
-+
- def on_before_component(callback):
-     """register a function to be called before a component is created.
-     The callback is called with arguments:
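The deleted patch above is the change that originally added the `cfg_denoised` hook to webui; it is removed here presumably because webui now ships the hook itself (the reworked `two_shot.py` below imports `on_cfg_denoised` directly). For orientation, a minimal sketch of how an extension consumes this hook; the callback body is illustrative, not this extension's actual logic:

```python
# Only meaningful inside a webui extension, where modules.script_callbacks exists.
from modules.script_callbacks import CFGDenoisedParams, on_cfg_denoised

def log_denoised(params: CFGDenoisedParams):
    # params.x holds the batched latents just produced by the inner model;
    # sampling_step / total_sampling_steps track progress through sampling.
    print(f"step {params.sampling_step}/{params.total_sampling_steps}: {tuple(params.x.shape)}")

on_cfg_denoised(log_denoised)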
-diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
-index f076fc55..28847397 100644
---- a/modules/sd_samplers_kdiffusion.py
-+++ b/modules/sd_samplers_kdiffusion.py
-@@ -8,6 +8,7 @@ from modules import prompt_parser, devices, sd_samplers_common
- from modules.shared import opts, state
- import modules.shared as shared
- from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
-+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
-
- samplers_k_diffusion = [
-     ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
-@@ -136,6 +137,9 @@ class CFGDenoiser(torch.nn.Module):
-
-         x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
-
-+        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
-+        cfg_denoised_callback(denoised_params)
-+
-         devices.test_for_nans(x_out, "unet")
-
-         if opts.live_preview_content == "Prompt":
diff --git a/gradio-3.16.2-py3-none-any.whl b/gradio-3.16.2-py3-none-any.whl
new file mode 100644
index 0000000..aa91f5c
Binary files /dev/null and b/gradio-3.16.2-py3-none-any.whl differ
diff --git a/screenshots/20230303.png b/screenshots/20230303.png
new file mode 100644
index 0000000..4669e66
Binary files /dev/null and b/screenshots/20230303.png differ
diff --git a/scripts/sketch_helper.py b/scripts/sketch_helper.py
new file mode 100644
index 0000000..b58582d
--- /dev/null
+++ b/scripts/sketch_helper.py
@@ -0,0 +1,99 @@
+import numpy as np
+import cv2
+import base64
+
+
+def count_high_freq_colors(image):
+    im = image.getcolors(maxcolors=1024*1024)
+    sorted_colors = sorted(im, key=lambda x: x[0], reverse=True)
+
+    freqs = [c[0] for c in sorted_colors]
+    mean_freq = sum(freqs) / len(freqs)
+
+    high_freq_colors = [c for c in sorted_colors if c[0] > max(2, mean_freq * 1.25)]
+    return high_freq_colors
+
+
+def get_high_freq_colors(image, similarity_threshold=30):
+    image_copy = image.copy()
+    high_freq_colors = count_high_freq_colors(image)
+    # Check for similar colors and replace the lower frequency color with the higher frequency color in the image
+    for i, (freq1, color1) in enumerate(high_freq_colors):
+        for j, (freq2, color2) in enumerate(high_freq_colors):
+            if (color_distance(color1, color2) < similarity_threshold) or (color_distance(color1, opaque_color_on_white(color2, 0.5)) < 5):
+                if freq2 > freq1:
+                    replace_color(image_copy, color1, color2)
+
+    high_freq_colors = count_high_freq_colors(image_copy)
+    print(high_freq_colors)
+    return [high_freq_colors, image_copy]
+
+
+def color_quantization(image, color_frequency_list):
+    # Convert the color frequency list to a set of unique colors
+    unique_colors = set([color for _, color in color_frequency_list])
+
+    # Create a mask for the image with True where the color is in the unique colors set
+    mask = np.any(np.all(image.reshape(-1, 1, 3) == np.array(list(unique_colors)), axis=2), axis=1).reshape(image.shape[:2])
+
+    # Create a new image with all pixels set to white
+    new_image = np.full_like(image, 255)
+
+    # Copy the pixels from the original image that have a color in the color frequency list
+    new_image[mask] = image[mask]
+    return new_image
+
+
+def create_binary_mask(img_arr, target_color):
+    # Create mask of pixels with target color
+    mask = np.all(img_arr == target_color, axis=-1)
+
+    # Convert mask to binary matrix
+    binary_matrix = mask.astype(int)
+
+    return binary_matrix
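A tiny worked example of what `create_binary_mask` computes (values chosen for illustration):

```python
import numpy as np

# A 2x2 image with two red pixels; masking for red gives:
img = np.array([[[255, 0, 0], [255, 255, 255]],
                [[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)
mask = np.all(img == (255, 0, 0), axis=-1).astype(int)
# mask == [[1, 0],
#          [1, 0]]
```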
+
+
+def create_binary_matrix_base64(img_arr, target_color):
+    # Create mask of pixels with target color
+    mask = np.all(img_arr == target_color, axis=-1)
+
+    # Convert mask to binary matrix
+    binary_matrix = mask.astype(int)
+    from datetime import datetime
+    binary_file_name = f'mask-{datetime.now().timestamp()}.png'
+    _, im_arr = cv2.imencode('.png', binary_matrix * 255)
+
+    im_bytes = im_arr.tobytes()
+    im_b64 = base64.b64encode(im_bytes).decode('ascii')
+    # binary_matrix = torch.from_numpy(binary_matrix).unsqueeze(0).unsqueeze(0)
+    return mask, im_b64
+
+
+def create_binary_matrix_img(img_arr, target_color):
+    # Create mask of pixels with target color
+    mask = np.all(img_arr == target_color, axis=-1)
+
+    # Convert mask to binary matrix
+    binary_matrix = mask.astype(int)
+    from datetime import datetime
+    binary_file_name = f'mask-{datetime.now().timestamp()}.png'
+    cv2.imwrite(binary_file_name, binary_matrix * 255)
+
+    # binary_matrix = torch.from_numpy(binary_matrix).unsqueeze(0).unsqueeze(0)
+    return binary_file_name
+
+
+def color_distance(color1, color2):
+    return sum((a - b) ** 2 for a, b in zip(color1, color2)) ** 0.5
+
+
+def replace_color(image, old_color, new_color):
+    pixels = image.load()
+    width, height = image.size
+    for x in range(width):
+        for y in range(height):
+            if pixels[x, y] == old_color:
+                pixels[x, y] = new_color
+
+
+def opaque_color_on_white(color, a):
+    r, g, b = color
+    opaque_red = int((1 - a) * 255 + a * r)
+    opaque_green = int((1 - a) * 255 + a * g)
+    opaque_blue = int((1 - a) * 255 + a * b)
+    return (opaque_red, opaque_green, opaque_blue)
\ No newline at end of file
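The color helpers are plain Euclidean distance and alpha-over-white compositing, `channel' = (1 - a) * 255 + a * channel`. For example:

```python
# Run from the extension root so scripts/sketch_helper.py is importable.
from scripts.sketch_helper import color_distance, opaque_color_on_white

assert color_distance((0, 0, 0), (3, 4, 0)) == 5.0                 # Euclidean RGB distance
assert opaque_color_on_white((255, 0, 0), 0.5) == (255, 127, 127)  # 50%-opacity red over white
```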
diff --git a/scripts/two_shot.py b/scripts/two_shot.py
index 0727bd6..f15d814 100644
--- a/scripts/two_shot.py
+++ b/scripts/two_shot.py
@@ -1,17 +1,49 @@
+import base64
 from typing import List, Dict, Optional, Tuple
 from dataclasses import dataclass
 
 import torch
 
-from modules import devices
+from scripts.sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix_base64, create_binary_mask
+import numpy as np
+import cv2
+
+from modules import devices, script_callbacks
 import modules.scripts as scripts
 import gradio as gr
 
-# todo:
+
 from modules.script_callbacks import CFGDenoisedParams, on_cfg_denoised
 
 from modules.processing import StableDiffusionProcessing
 
+MAX_COLORS = 12
+switch_values_symbol = '\U000021C5'  # ⇅
+
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+    """Small button with single emoji as text, fits inside gradio forms"""
+
+    def __init__(self, **kwargs):
+        super().__init__(variant="tool", **kwargs)
+
+    def get_block_name(self):
+        return "button"
+
+
+# abstract base class for filters
+from abc import ABC, abstractmethod
+
+
+class Filter(ABC):
+
+    @abstractmethod
+    def create_tensor(self):
+        pass
+
 
 @dataclass
 class Division:
@@ -27,8 +59,8 @@ class Position:
     ex: float
 
 
-class Filter:
+class RectFilter(Filter):
 
     def __init__(self, division: Division, position: Position, weight: float):
         self.division = division
         self.position = position
@@ -50,13 +82,86 @@ def create_tensor(self, num_channels: int, height_b: int, width_b: int) -> torch.Tensor:
         return x
 
 
+class MaskFilter:
+    def __init__(self, binary_mask: np.array = None, weight: float = None, float_mask: np.array = None):
+        if float_mask is None:
+            self.mask = binary_mask.astype(np.float32) * weight
+        elif binary_mask is None and weight is None:
+            self.mask = float_mask
+        else:
+            raise ValueError('Either float_mask or binary_mask and weight must be provided')
+
+        self.tensor_mask = torch.tensor(self.mask).to(devices.device)
+
+    def create_tensor(self, num_channels: int, height_b: int, width_b: int) -> torch.Tensor:
+        # x = torch.zeros(num_channels, height_b, width_b).to(devices.device)
+        # mask = torch.tensor(self.mask).to(devices.device)
+        # downsample mask to x size; other interpolation modes tried and left commented out:
+        # mask_bicubic = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(height_b, width_b), mode='bicubic').squeeze(0).squeeze(0).cpu().numpy()
+        # mask_nearest_exact = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(height_b, width_b), mode='nearest-exact').squeeze(0).squeeze(0).cpu().numpy()
+        # mask_nearest = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(height_b, width_b), mode='nearest').squeeze(0).squeeze(0).cpu().numpy()
+        # mask_area = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(height_b, width_b), mode='area').squeeze(0).squeeze(0).cpu().numpy()
+
+        mask = torch.nn.functional.interpolate(self.tensor_mask.unsqueeze(0).unsqueeze(0), size=(height_b, width_b), mode='nearest-exact').squeeze(0).squeeze(0)
+        mask = mask.unsqueeze(0).repeat(num_channels, 1, 1)
+        return mask
+
+
+class PastePromptTextboxTracker:
+    def __init__(self):
+        self.scripts = []
+        return
+
+    def set_script(self, script):
+        self.scripts.append(script)
+
+    def on_after_component_callback(self, component, **_kwargs):
+        if not self.scripts:
+            return
+        if type(component) is gr.State:
+            return
+
+        script = None
+        if type(component) is gr.Textbox and component.elem_id == 'txt2img_prompt':
+            # select corresponding script
+            script = next(x for x in self.scripts if x.is_txt2img)
+            self.scripts.remove(script)
+
+        if type(component) is gr.Textbox and component.elem_id == 'img2img_prompt':
+            # select corresponding script
+            script = next(x for x in self.scripts if x.is_img2img)
+            self.scripts.remove(script)
+
+        if script is None:
+            return
+
+        script.target_paste_prompt = component
+
+
+prompt_textbox_tracker = PastePromptTextboxTracker()
+
+
 class Script(scripts.Script):
 
     def __init__(self):
+        self.ui_root = None
         self.num_batches: int = 0
         self.end_at_step: int = 20
         self.filters: List[Filter] = []
         self.debug: bool = False
+        self.selected_twoshot_tab = 0
+        self.ndmasks = []
+        self.area_colors = []
+        self.mask_denoise = False
+        prompt_textbox_tracker.set_script(self)
+        self.target_paste_prompt = None
 
     def title(self):
         return "Latent Couple extension"
@@ -64,7 +169,36 @@ def title(self):
     def show(self, is_img2img):
         return scripts.AlwaysVisible
 
-    def create_filters_from_ui_params(self, raw_divisions: str, raw_positions: str, raw_weights: str):
+    def create_rect_filters_from_ui_params(self, raw_divisions: str, raw_positions: str, raw_weights: str):
+
+        divisions = []
+        for division in raw_divisions.split(','):
+            y, x = division.split(':')
+            divisions.append(Division(float(y), float(x)))
+
+        def start_and_end_position(raw: str):
+            nums = [float(num) for num in raw.split('-')]
+            if len(nums) == 1:
+                return nums[0], nums[0] + 1.0
+            else:
+                return nums[0], nums[1]
+
+        positions = []
+        for position in raw_positions.split(','):
+            y, x = position.split(':')
+            y1, y2 = start_and_end_position(y)
+            x1, x2 = start_and_end_position(x)
+            positions.append(Position(y1, x1, y2, x2))
+
+        weights = []
+        for w in raw_weights.split(','):
+            weights.append(float(w))
+
+        # todo: assert len
+
+        return [RectFilter(division, position, weight) for division, position, weight in zip(divisions, positions, weights)]
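For reference, the position grammar parsed above: a bare index selects a single cell, and `a-b` selects a span of cells (a standalone restatement of the helper, for illustration):

```python
def start_and_end_position(raw: str):
    nums = [float(num) for num in raw.split('-')]
    return (nums[0], nums[0] + 1.0) if len(nums) == 1 else (nums[0], nums[1])

assert start_and_end_position("0") == (0.0, 1.0)    # single cell
assert start_and_end_position("0-2") == (0.0, 2.0)  # span of cells
# So divisions "1:2" with positions "0:1" selects row 0, column 1: the right half.
```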
+    def create_mask_filters_from_ui_params(self, raw_divisions: str, raw_positions: str, raw_weights: str):
 
         divisions = []
         for division in raw_divisions.split(','):
@@ -95,7 +229,7 @@ def start_and_end_position(raw: str):
 
     def do_visualize(self, raw_divisions: str, raw_positions: str, raw_weights: str):
 
-        self.filters = self.create_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
+        self.filters = self.create_rect_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
 
         return [f.create_tensor(1, 128, 128).squeeze(dim=0).cpu().numpy() for f in self.filters]
 
@@ -114,32 +248,258 @@ def do_apply(self, extra_generation_params: str):
 
         return raw_params.get('divisions', '1:1,1:2,1:2'), raw_params.get('positions', '0:0,0:0,0:1'), raw_params.get('weights', '0.2,0.8,0.8'), int(raw_params.get('step', '20'))
 
     def ui(self, is_img2img):
+        process_script_params = []
         id_part = "img2img" if is_img2img else "txt2img"
-
-        with gr.Group():
+        canvas_html = ""
" + # get_js_colors = """ + # async (canvasData) => { + # const canvasEl = document.getElementById("canvas-root"); + # return [canvasEl._data] + # } + # """ + + def create_canvas(h, w): + return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255 + + def process_sketch(img_arr, input_binary_matrixes): + input_binary_matrixes.clear() + # base64_img = canvas_data['image'] + # image_data = base64.b64decode(base64_img.split(',')[1]) + # image = Image.open(BytesIO(image_data)).convert("RGB") + im2arr = img_arr + # colors = [tuple(map(int, rgb[4:-1].split(','))) for rgb in + # ['colors']] + sketch_colors, color_counts = np.unique(im2arr.reshape(-1, im2arr.shape[2]), axis=0, return_counts=True) + colors_fixed = [] + # if color count is less than 0.001 of total pixel count, collect it for edge color correction + edge_color_correction_arr = [] + for sketch_color_idx, color in enumerate(sketch_colors[:-1]): # exclude white + if color_counts[sketch_color_idx] < im2arr.shape[0] * im2arr.shape[1] * 0.002: + edge_color_correction_arr.append(sketch_color_idx) + + edge_fix_dict = {} + # TODO:for every non area color pixel in img_arr, find the nearest area color pixel and replace it with that color + + area_colors = np.delete(sketch_colors, edge_color_correction_arr, axis=0) + if self.mask_denoise: + for edge_color_idx in edge_color_correction_arr: + edge_color = sketch_colors[edge_color_idx] + # find the nearest area_color + + color_distances = np.linalg.norm(area_colors - edge_color, axis=1) + nearest_index = np.argmin(color_distances) + nearest_color = area_colors[nearest_index] + edge_fix_dict[edge_color_idx] = nearest_color + # replace edge color with the nearest area_color + cur_color_mask = np.all(im2arr == edge_color, axis=2) + im2arr[cur_color_mask] = nearest_color + + # recalculate area colors + sketch_colors, color_counts = np.unique(im2arr.reshape(-1, im2arr.shape[2]), axis=0, return_counts=True) + area_colors = sketch_colors + + # create binary matrix for each area_color + area_color_maps = [] + self.ndmasks = [] + self.area_colors = area_colors + for color in area_colors: + r, g, b = color + mask, binary_matrix = create_binary_matrix_base64(im2arr, color) + self.ndmasks.append(mask) + input_binary_matrixes.append(binary_matrix) + colors_fixed.append(gr.update( + value=f'
+        def update_mask_filters(alpha_blend_val, general_prompt_str, *cur_weights_and_prompts):
+            cur_weight_slider_vals = cur_weights_and_prompts[:MAX_COLORS]
+            cur_prompts = cur_weights_and_prompts[MAX_COLORS:]
+            general_mask = self.ndmasks[-1]
+            final_filter_list = []
+            for m in range(len(self.ndmasks) - 1):
+                cur_float_mask = self.ndmasks[m].astype(np.float32) * float(cur_weight_slider_vals[m]) * float(1.0 - alpha_blend_val)
+                mask_filter = MaskFilter(float_mask=cur_float_mask)
+                final_filter_list.append(mask_filter)
+            # subtract the sum of all masks from the general mask to get the alpha blend mask
+            initial_general_mask = np.ones(shape=general_mask.shape, dtype=np.float32)
+            alpha_blend_mask = initial_general_mask.astype(np.float32) - np.sum([f.mask for f in final_filter_list], axis=0)
+            alpha_blend_filter = MaskFilter(float_mask=alpha_blend_mask)
+            final_filter_list.insert(0, alpha_blend_filter)
+            self.filters = final_filter_list
+            sketch_colors = []
+            colors_fixed = []
+            for area_idx, color in enumerate(self.area_colors):
+                r, g, b = color
+                final_list_idx = area_idx + 1
+                if final_list_idx == len(final_filter_list):
+                    final_list_idx = 0
+                # get shape of current mask
+                height_b, width_b = final_filter_list[final_list_idx].mask.shape
+                current_mask = torch.nn.functional.interpolate(final_filter_list[final_list_idx].tensor_mask.unsqueeze(0).unsqueeze(0),
+                                                               size=(int(height_b / 8), int(width_b / 8)), mode='nearest-exact').squeeze(0).squeeze(0).cpu().numpy()
+                adjusted_mask = current_mask * 255
+                _, adjusted_mask_arr = cv2.imencode('.png', adjusted_mask)
+                adjusted_mask_b64 = base64.b64encode(adjusted_mask_arr.tobytes()).decode('ascii')
+                colors_fixed.append(gr.update(value=f''))
+            for sketch_color_idx in range(MAX_COLORS):
+                sketch_colors.append(gr.update(value=f''))
+            for j in range(len(colors_fixed) - 1):
+                sketch_colors[j] = colors_fixed[j]
+            alpha_mask_visibility = gr.update(visible=True)
+            alpha_mask_html = colors_fixed[-1]
+            final_prompt_update = gr.update(value='\nAND '.join([general_prompt_str, *cur_prompts[:len(colors_fixed) - 1]]))
+            return [final_prompt_update, alpha_mask_visibility, alpha_mask_html, *sketch_colors]
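The weighting scheme in `update_mask_filters`: each region mask is scaled by its slider weight and by `1 - alpha_blend`, and the background filter is one minus the sum of all region masks, so the general prompt keeps `alpha_blend` influence inside regions and full influence elsewhere. A worked example with a single region:

```python
import numpy as np

alpha_blend, weight = 0.2, 1.0
region_mask = np.array([[1.0, 0.0]])  # left pixel inside the region, right pixel outside
region_filter = region_mask * weight * (1.0 - alpha_blend)  # [[0.8, 0.0]]
background = np.ones_like(region_mask) - region_filter      # [[0.2, 1.0]]
# Inside the region: 0.8 sub-prompt + 0.2 general; outside: general prompt only.
```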
+        cur_weight_sliders = []
+
+        with gr.Group() as group_two_shot_root:
+            binary_matrixes = gr.State([])
             with gr.Accordion("Latent Couple", open=False):
                 enabled = gr.Checkbox(value=False, label="Enabled")
-                with gr.Row():
-                    divisions = gr.Textbox(label="Divisions", elem_id=f"cd_{id_part}_divisions", value="1:1,1:2,1:2")
-                    positions = gr.Textbox(label="Positions", elem_id=f"cd_{id_part}_positions", value="0:0,0:0,0:1")
-                with gr.Row():
-                    weights = gr.Textbox(label="Weights", elem_id=f"cd_{id_part}_weights", value="0.2,0.8,0.8")
-                    end_at_step = gr.Slider(minimum=0, maximum=150, step=1, label="end at this step", elem_id=f"cd_{id_part}_end_at_this_step", value=20)
-
-                visualize_button = gr.Button(value="Visualize")
-                visual_regions = gr.Gallery(label="Regions").style(grid=(4, 4, 4, 8), height="auto")
-
-                visualize_button.click(fn=self.do_visualize, inputs=[divisions, positions, weights], outputs=[visual_regions])
-
-                extra_generation_params = gr.Textbox(label="Extra generation params")
-                apply_button = gr.Button(value="Apply")
-
-                apply_button.click(fn=self.do_apply, inputs=[extra_generation_params], outputs=[divisions, positions, weights, end_at_step])
+                with gr.Tabs(elem_id="script_twoshot_tabs") as twoshot_tabs:
+                    with gr.TabItem("Mask", elem_id="tab_twoshot_mask") as twoshot_tab_mask:
+                        canvas_data = gr.JSON(value={}, visible=False)
+                        # model = gr.Textbox(label="The id of any Hugging Face model in the diffusers format",
+                        #                    value="stabilityai/stable-diffusion-2-1-base",
+                        #                    visible=False if is_shared_ui else True)
+                        mask_denoise_checkbox = gr.Checkbox(value=False, label="Denoise Mask")
+
+                        def update_mask_denoise_flag(flag):
+                            self.mask_denoise = flag
+
+                        mask_denoise_checkbox.change(fn=update_mask_denoise_flag, inputs=[mask_denoise_checkbox], outputs=None)
+                        canvas_image = gr.Image(source='upload', mirror_webcam=False, type='numpy', tool='color-sketch',
+                                                elem_id='twoshot_canvas_sketch', interactive=True).style(height=480)
+                        # aspect = gr.Radio(["square", "horizontal", "vertical"], value="square", label="Aspect Ratio",
+                        #                   visible=False if is_shared_ui else True)
+                        button_run = gr.Button("I've finished my sketch", elem_id="main_button", interactive=True)
+
+                        prompts = []
+                        colors = []
+                        color_row = [None] * MAX_COLORS
+                        with gr.Column(visible=False) as post_sketch:
+                            with gr.Row(visible=False) as alpha_mask_row:
+                                # general_mask_label_span = gr.HTML('General Mask', elem_id='general_mask_label_span')
+                                with gr.Box(elem_id="alpha_mask"):
+                                    alpha_color = gr.HTML('')
+                            general_prompt = gr.Textbox(label="General Prompt")
+                            alpha_blend = gr.Slider(label="Alpha Blend", minimum=0.0, maximum=1.0, value=0.2, step=0.01, interactive=True)
+
+                            for n in range(MAX_COLORS):
+                                with gr.Row(visible=False) as color_row[n]:
+                                    with gr.Box(elem_id="color-bg"):
+                                        colors.append(gr.HTML(''))
+                                    with gr.Column():
+                                        with gr.Row():
+                                            prompts.append(gr.Textbox(label="Prompt for this mask"))
+                                        with gr.Row():
+                                            weight_slider = gr.Slider(label=f"Area {n+1} Weight", minimum=0.0, maximum=1.0,
+                                                                      value=1.0, step=0.01, interactive=True, elem_id=f"weight_{n+1}_slider")
+                                            cur_weight_sliders.append(weight_slider)
+
+                            button_update = gr.Button("Prompt Info Update", elem_id="update_button", interactive=True)
+                            final_prompt = gr.Textbox(label="Final Prompt", interactive=False)
+
+                        button_run.click(process_sketch, inputs=[canvas_image, binary_matrixes],
+                                         outputs=[post_sketch, binary_matrixes, alpha_mask_row, alpha_color, *color_row, *colors],
+                                         queue=False)
+
+                        button_update.click(fn=update_mask_filters, inputs=[alpha_blend, general_prompt, *cur_weight_sliders, *prompts],
+                                            outputs=[final_prompt, alpha_mask_row, alpha_color, *colors])
+
+                        def paste_prompt(*input_prompts):
+                            final_prompts = input_prompts[:len(self.area_colors)]
+                            final_prompt_str = '\nAND '.join(final_prompts)
+                            return final_prompt_str
+
+                        source_prompts = [general_prompt, *prompts]
+                        button_update.click(fn=paste_prompt, inputs=source_prompts,
+                                            outputs=self.target_paste_prompt)
+
+                        with gr.Column():
+                            canvas_width = gr.Slider(label="Canvas Width", minimum=256, maximum=1024, value=512, step=64)
+                            canvas_height = gr.Slider(label="Canvas Height", minimum=256, maximum=1024, value=512, step=64)
+
+                            canvas_swap_res = ToolButton(value=switch_values_symbol)
+                            canvas_swap_res.click(lambda w, h: (h, w), inputs=[canvas_width, canvas_height],
+                                                  outputs=[canvas_width, canvas_height])
+                            create_button = gr.Button(value="Create blank canvas")
+                            create_button.click(fn=create_canvas, inputs=[canvas_height, canvas_width], outputs=[canvas_image])
+
+                    with gr.TabItem("Rectangular", elem_id="tab_twoshot_rect") as twoshot_tab_rect:
+                        with gr.Row():
+                            divisions = gr.Textbox(label="Divisions", elem_id=f"cd_{id_part}_divisions", value="1:1,1:2,1:2")
+                            positions = gr.Textbox(label="Positions", elem_id=f"cd_{id_part}_positions", value="0:0,0:0,0:1")
+                        with gr.Row():
+                            weights = gr.Textbox(label="Weights", elem_id=f"cd_{id_part}_weights", value="0.2,0.8,0.8")
+                            end_at_step = gr.Slider(minimum=0, maximum=150, step=1, label="end at this step", elem_id=f"cd_{id_part}_end_at_this_step", value=150)
+
+                        visualize_button = gr.Button(value="Visualize")
+                        visual_regions = gr.Gallery(label="Regions").style(grid=(4, 4, 4, 8), height="auto")
+
+                        visualize_button.click(fn=self.do_visualize, inputs=[divisions, positions, weights], outputs=[visual_regions])
+
+                        extra_generation_params = gr.Textbox(label="Extra generation params")
+                        apply_button = gr.Button(value="Apply")
+
+                        apply_button.click(fn=self.do_apply, inputs=[extra_generation_params], outputs=[divisions, positions, weights, end_at_step])
+
+                def select_twoshot_tab(tab_id):
+                    self.selected_twoshot_tab = tab_id
+
+                for i, elem in enumerate([twoshot_tab_mask, twoshot_tab_rect]):
+                    elem.select(
+                        fn=lambda tab=i: select_twoshot_tab(tab),
+                        inputs=[],
+                        outputs=[],
+                    )
+
+        self.ui_root = group_two_shot_root
 
         self.infotext_fields = [
             (extra_generation_params, "Latent Couple")
         ]
 
-        return enabled, divisions, positions, weights, end_at_step
+        process_script_params.append(enabled)
+        process_script_params.append(divisions)
+        process_script_params.append(positions)
+        process_script_params.append(weights)
+        process_script_params.append(end_at_step)
+        process_script_params.append(alpha_blend)
+        process_script_params.extend(cur_weight_sliders)
+        return process_script_params
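`paste_prompt` (and the final-prompt preview in `update_mask_filters`) assemble the Composable Diffusion prompt by joining the general prompt and the per-region prompts with `AND`; for example (prompts are illustrative):

```python
general_prompt = "two women in a park"
region_prompts = ["blonde hair, red dress", "black hair, blue dress"]
final_prompt = '\nAND '.join([general_prompt, *region_prompts])
# "two women in a park
# AND blonde hair, red dress
# AND black hair, blue dress"
```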
     def denoised_callback(self, params: CFGDenoisedParams):
@@ -192,7 +552,9 @@ def denoised_callback(self, params: CFGDenoisedParams):
 
                 uncond_off += 1
 
-    def process(self, p: StableDiffusionProcessing, enabled: bool, raw_divisions: str, raw_positions: str, raw_weights: str, raw_end_at_step: int):
+    def process(self, p: StableDiffusionProcessing, *args, **kwargs):
+
+        enabled, raw_divisions, raw_positions, raw_weights, raw_end_at_step, alpha_blend, *cur_weight_sliders = args
 
         self.enabled = enabled
 
@@ -201,15 +563,19 @@ def process(self, p: StableDiffusionProcessing, enabled: bool, raw_divisions: st
 
         self.num_batches = p.batch_size
 
-        self.filters = self.create_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
+        if self.selected_twoshot_tab == 0:
+            pass
+        elif self.selected_twoshot_tab == 1:
+            self.filters = self.create_rect_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
+        else:
+            raise ValueError("Unknown filter mode")
 
         self.end_at_step = raw_end_at_step
 
-        #
+        # TODO: handle different cases for generation info: 'mask' and 'rect'
+        # if self.end_at_step != 0:
+        #     p.extra_generation_params["Latent Couple"] = f"divisions={raw_divisions} positions={raw_positions} weights={raw_weights} end at step={raw_end_at_step}"
 
-        if self.end_at_step != 0:
-            p.extra_generation_params["Latent Couple"] = f"divisions={raw_divisions} positions={raw_positions} weights={raw_weights} end at step={raw_end_at_step}"
-
         # save params into the output file as PNG textual data.
         if self.debug:
             print("### Latent couple ###")
@@ -225,3 +591,4 @@ def postprocess(self, *args):
 
         return
 
+script_callbacks.on_after_component(prompt_textbox_tracker.on_after_component_callback)
\ No newline at end of file
diff --git a/style.css b/style.css
new file mode 100644
index 0000000..dedeb79
--- /dev/null
+++ b/style.css
@@ -0,0 +1,36 @@
+
+#twoshot_canvas_sketch, #twoshot_canvas_sketch > .h-60, #twoshot_canvas_sketch > .h-60 > div, #twoshot_canvas_sketch > .h-60 > div > img
+{
+    height: 480px !important;
+    max-height: 480px !important;
+    /*min-height: 480px !important;*/
+}
+
+#color-bg > .gr-box, #color-bg > .gr-box > .transition, #color-bg > .gr-box > .transition > .output-html
+{
+    height: 100%;
+}
+
+#color-bg{
+    max-height: 110px;
+}
+
+#alpha_mask{
+    max-height: 110px;
+}
+
+#general_mask_label_span{
+    position: absolute;
+}
+
+#script_twoshot_tabs .output-html{
+    display: flex;
+    justify-content: center;
+    align-items: center;
+}
+
+#script_twoshot_tabs .gradio-image > div.fixed-height, #script_twoshot_tabs .gradio-image > div.fixed-height img {
+    height: 480px !important;
+    max-height: 480px !important;
+    min-height: 480px !important;
+}
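The body of `denoised_callback` is largely untouched by this diff (only context lines appear above). Conceptually, it blends each sub-prompt's denoised latent using the per-region weight tensors that `create_tensor` produces; roughly, as a sketch of the idea rather than the extension's literal code:

```python
import torch

# Illustrative shapes: SD latents are (channels, height/8, width/8).
num_channels, h, w = 4, 64, 64
latents = [torch.randn(num_channels, h, w) for _ in range(3)]  # one per sub-prompt
weights = [torch.full((num_channels, h, w), 0.2),              # general prompt everywhere
           torch.zeros(num_channels, h, w),
           torch.zeros(num_channels, h, w)]
weights[1][..., : w // 2] = 0.8   # first sub-prompt: left half of the canvas
weights[2][..., w // 2 :] = 0.8   # second sub-prompt: right half of the canvas

blended = sum(m * x for m, x in zip(weights, latents))  # region-weighted latent
```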