diff --git a/README.md b/README.md
index 11fdf60..ba3b8b8 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,10 @@
 This extension extends the built-in Composable Diffusion.
 It lets you assign each of your subprompts to its own region of the latent space.
 
 ## How to use
+
+
 ### Enabled
 The effect of Latent Couple appears only when Enabled is checked.
@@ -41,11 +43,34 @@
 outputs
 - end at step=4 https://imgur.com/a1kyvhX
 - end at step=0 https://imgur.com/yhGF7g8
 
-## Old prerequisite
-This extension needs cfg_denoised_callback-ea9bd9fc.patch applied (as of the Feb 5, 2023 origin/HEAD commit ea9bd9fc).
+
+## ~~Prerequisite for prompt pasting~~
+## ~~Prerequisite for gradio Image and Sketch component bug fix~~
+This fix no longer applies to the latest webui commit (22bcc7be), which upgraded the gradio dependency to 3.23.
+
+I'll keep the fix here for people still using older versions of webui.
+
+Activate your venv in the webui root directory.
+
+For Windows, in cmd:
+```
+venv\Scripts\activate.bat
 ```
-git apply --ignore-whitespace extensions/stable-diffusion-webui-two-shot/cfg_denoised_callback-ea9bd9fc.patch
+For Linux:
 ```
+source venv/bin/activate
+```
+Then install the wheel distribution with the bugfix applied:
+```
+pip install --force-reinstall --no-deps extensions/stable-diffusion-webui-two-shot/gradio-3.16.2-py3-none-any.whl
+```
+For the bugfix-related modifications, see https://github.com/ashen-sensored/gradio/tree/3.16.2
+
+
+## Issues
+- ~~The extension's mask color sketching function does not work well with Chrome (extreme stuttering) due to gradio's Image component bug.~~ Keep the browser zoom at 100% while creating the blank canvas to avoid the bug.
+See the prerequisite above. The fix no longer applies to the latest webui version (22bcc7be), which upgraded the gradio dependency to 3.23.
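To confirm the reinstall took effect before launching webui, a minimal sanity check run inside the same venv (a sketch; the version string is the one pinned by the wheel above):

```python
# Minimal check that the patched wheel is the active gradio (run inside the venv)
import gradio

print(gradio.__version__)             # expect "3.16.2" from the wheel above
assert gradio.__version__ == "3.16.2"
```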
+
 
 ## Credits
 - two shot diffusion.ipynb https://colab.research.google.com/drive/1UdElpQfKFjY5luch9v_LlmSdH7AmeiDe?usp=sharing
diff --git a/cfg_denoised_callback-ea9bd9fc.patch b/cfg_denoised_callback-ea9bd9fc.patch
deleted file mode 100644
index c62aa24..0000000
--- a/cfg_denoised_callback-ea9bd9fc.patch
+++ /dev/null
@@ -1,83 +0,0 @@
-diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
-index 4bb45ec7..edd0e2a7 100644
---- a/modules/script_callbacks.py
-+++ b/modules/script_callbacks.py
-@@ -46,6 +46,18 @@ class CFGDenoiserParams:
-         """Total number of sampling steps planned"""
- 
- 
-+class CFGDenoisedParams:
-+    def __init__(self, x, sampling_step, total_sampling_steps):
-+        self.x = x
-+        """Latent image representation in the process of being denoised"""
-+
-+        self.sampling_step = sampling_step
-+        """Current Sampling step number"""
-+
-+        self.total_sampling_steps = total_sampling_steps
-+        """Total number of sampling steps planned"""
-+
-+
- class UiTrainTabParams:
-     def __init__(self, txt2img_preview_params):
-         self.txt2img_preview_params = txt2img_preview_params
-@@ -68,6 +80,7 @@ callback_map = dict(
-     callbacks_before_image_saved=[],
-     callbacks_image_saved=[],
-     callbacks_cfg_denoiser=[],
-+    callbacks_cfg_denoised=[],
-     callbacks_before_component=[],
-     callbacks_after_component=[],
-     callbacks_image_grid=[],
-@@ -150,6 +163,14 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
-             report_exception(c, 'cfg_denoiser_callback')
- 
- 
-+def cfg_denoised_callback(params: CFGDenoisedParams):
-+    for c in callback_map['callbacks_cfg_denoised']:
-+        try:
-+            c.callback(params)
-+        except Exception:
-+            report_exception(c, 'cfg_denoised_callback')
-+
-+
- def before_component_callback(component, **kwargs):
-     for c in callback_map['callbacks_before_component']:
-         try:
-@@ -283,6 +304,14 @@ def on_cfg_denoiser(callback):
-     add_callback(callback_map['callbacks_cfg_denoiser'], callback)
- 
- 
-+def on_cfg_denoised(callback):
-+    """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
-+    The callback is called with one argument:
-+        - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
-+    """
-+    add_callback(callback_map['callbacks_cfg_denoised'], callback)
-+
-+
- def on_before_component(callback):
-     """register a function to be called before a component is created.
-     The callback is called with arguments:
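The callback API this patch added (and which later webui versions ship natively) is consumed like so; a minimal sketch, where the handler name `log_denoised` is illustrative and the imports mirror the patch:

```python
# Sketch of an extension subscribing to the denoised-latent hook added above
from modules.script_callbacks import CFGDenoisedParams, on_cfg_denoised


def log_denoised(params: CFGDenoisedParams):
    # params.x is the batched latent after the inner model ran for this step
    print(f"step {params.sampling_step + 1}/{params.total_sampling_steps}: "
          f"latent shape {tuple(params.x.shape)}")


on_cfg_denoised(log_denoised)
```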
-diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
-index f076fc55..28847397 100644
---- a/modules/sd_samplers_kdiffusion.py
-+++ b/modules/sd_samplers_kdiffusion.py
-@@ -8,6 +8,7 @@ from modules import prompt_parser, devices, sd_samplers_common
- from modules.shared import opts, state
- import modules.shared as shared
- from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
-+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
- 
- samplers_k_diffusion = [
-     ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
-@@ -136,6 +137,9 @@ class CFGDenoiser(torch.nn.Module):
- 
-         x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
- 
-+        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
-+        cfg_denoised_callback(denoised_params)
-+
-         devices.test_for_nans(x_out, "unet")
- 
-         if opts.live_preview_content == "Prompt":
diff --git a/gradio-3.16.2-py3-none-any.whl b/gradio-3.16.2-py3-none-any.whl
new file mode 100644
index 0000000..aa91f5c
Binary files /dev/null and b/gradio-3.16.2-py3-none-any.whl differ
diff --git a/screenshots/20230303.png b/screenshots/20230303.png
new file mode 100644
index 0000000..4669e66
Binary files /dev/null and b/screenshots/20230303.png differ
diff --git a/scripts/sketch_helper.py b/scripts/sketch_helper.py
new file mode 100644
index 0000000..b58582d
--- /dev/null
+++ b/scripts/sketch_helper.py
@@ -0,0 +1,99 @@
+import numpy as np
+import cv2
+import base64
+
+
+def count_high_freq_colors(image):
+    # `image` is a PIL image; keep (count, color) pairs whose count is well
+    # above the mean frequency (at least 2 pixels and 1.25x the mean)
+    im = image.getcolors(maxcolors=1024 * 1024)
+    sorted_colors = sorted(im, key=lambda x: x[0], reverse=True)
+
+    freqs = [c[0] for c in sorted_colors]
+    mean_freq = sum(freqs) / len(freqs)
+
+    high_freq_colors = [c for c in sorted_colors if c[0] > max(2, mean_freq * 1.25)]
+    return high_freq_colors
+
+
+def get_high_freq_colors(image, similarity_threshold=30):
+    image_copy = image.copy()
+    high_freq_colors = count_high_freq_colors(image)
+    # Check for similar colors and replace the lower-frequency color with the
+    # higher-frequency color in the image
+    for i, (freq1, color1) in enumerate(high_freq_colors):
+        for j, (freq2, color2) in enumerate(high_freq_colors):
+            if (color_distance(color1, color2) < similarity_threshold) or (color_distance(color1, opaque_color_on_white(color2, 0.5)) < 5):
+                if freq2 > freq1:
+                    replace_color(image_copy, color1, color2)
+
+    high_freq_colors = count_high_freq_colors(image_copy)
+    print(high_freq_colors)
+    return [high_freq_colors, image_copy]
+
+
+def color_quantization(image, color_frequency_list):
+    # Convert the color frequency list to a set of unique colors
+    unique_colors = set([color for _, color in color_frequency_list])
+
+    # Mask is True where the pixel's color is in the unique colors set
+    mask = np.any(np.all(image.reshape(-1, 1, 3) == np.array(list(unique_colors)), axis=2), axis=1).reshape(image.shape[:2])
+
+    # Create a new image with all pixels set to white
+    new_image = np.full_like(image, 255)
+
+    # Copy the pixels from the original image whose color is in the color frequency list
+    new_image[mask] = image[mask]
+    return new_image
+
+
+def create_binary_mask(img_arr, target_color):
+    # Create a mask of pixels with the target color
+    mask = np.all(img_arr == target_color, axis=-1)
+
+    # Convert the mask to a binary matrix
+    binary_matrix = mask.astype(int)
+
+    return binary_matrix
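A quick illustration of `create_binary_mask` just above (the 2x2 image is synthetic; assumes the webui root is on `sys.path`, as it is when the extension itself loads):

```python
# Tiny demonstration of create_binary_mask on a synthetic 2x2 RGB image
import numpy as np
from scripts.sketch_helper import create_binary_mask

img = np.array([[[255, 0, 0], [255, 255, 255]],
                [[255, 0, 0], [  0,   0, 255]]], dtype=np.uint8)
print(create_binary_mask(img, (255, 0, 0)))
# [[1 0]
#  [1 0]]
```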
+
+
+def create_binary_matrix_base64(img_arr, target_color):
+    # Create a mask of pixels with the target color
+    mask = np.all(img_arr == target_color, axis=-1)
+
+    # Convert the mask to a binary matrix and encode it as a base64 PNG
+    binary_matrix = mask.astype(int)
+    _, im_arr = cv2.imencode('.png', binary_matrix * 255)
+
+    im_bytes = im_arr.tobytes()
+    im_b64 = base64.b64encode(im_bytes).decode('ascii')
+    return mask, im_b64
+
+
+def create_binary_matrix_img(img_arr, target_color):
+    # Create a mask of pixels with the target color
+    mask = np.all(img_arr == target_color, axis=-1)
+
+    # Convert the mask to a binary matrix and write it out as a black-and-white PNG
+    binary_matrix = mask.astype(int)
+    from datetime import datetime
+    binary_file_name = f'mask-{datetime.now().timestamp()}.png'
+    cv2.imwrite(binary_file_name, binary_matrix * 255)
+
+    return binary_file_name
+
+
+def color_distance(color1, color2):
+    # Euclidean distance between two RGB colors
+    return sum((a - b) ** 2 for a, b in zip(color1, color2)) ** 0.5
+
+
+def replace_color(image, old_color, new_color):
+    # In-place replacement of one exact color with another in a PIL image
+    pixels = image.load()
+    width, height = image.size
+    for x in range(width):
+        for y in range(height):
+            if pixels[x, y] == old_color:
+                pixels[x, y] = new_color
+
+
+def opaque_color_on_white(color, a):
+    # The color produced by drawing `color` at opacity `a` over a white background
+    r, g, b = color
+    opaque_red = int((1 - a) * 255 + a * r)
+    opaque_green = int((1 - a) * 255 + a * g)
+    opaque_blue = int((1 - a) * 255 + a * b)
+    return (opaque_red, opaque_green, opaque_blue)
\ No newline at end of file
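The similarity test in `get_high_freq_colors` combines the last two helpers; the values below are illustrative:

```python
# Illustrative values for the two color helpers defined above
from scripts.sketch_helper import color_distance, opaque_color_on_white

print(color_distance((255, 0, 0), (250, 10, 0)))  # ~11.18, "similar" at threshold 30
print(opaque_color_on_white((255, 0, 0), 0.5))    # (255, 127, 127): 50% red over white
```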
diff --git a/scripts/two_shot.py b/scripts/two_shot.py
index 0727bd6..f15d814 100644
--- a/scripts/two_shot.py
+++ b/scripts/two_shot.py
@@ -1,17 +1,49 @@
+import base64
 from typing import List, Dict, Optional, Tuple
 from dataclasses import dataclass
 
 import torch
 
-from modules import devices
+from scripts.sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix_base64, create_binary_mask
+import numpy as np
+import cv2
+
+from modules import devices, script_callbacks
 import modules.scripts as scripts
 import gradio as gr
 
-# todo:
+
 from modules.script_callbacks import CFGDenoisedParams, on_cfg_denoised
 
 from modules.processing import StableDiffusionProcessing
 
+MAX_COLORS = 12
+switch_values_symbol = '\U000021C5'  # ⇅
+
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+    """Small button with a single emoji as text, fits inside gradio forms"""
+
+    def __init__(self, **kwargs):
+        super().__init__(variant="tool", **kwargs)
+
+    def get_block_name(self):
+        return "button"
+
+
+# abstract base class for filters
+from abc import ABC, abstractmethod
+
+
+class Filter(ABC):
+
+    @abstractmethod
+    def create_tensor(self):
+        pass
+
 
 @dataclass
 class Division:
@@ -27,8 +59,8 @@ class Position:
     ex: float
 
 
-class Filter:
+class RectFilter(Filter):
     def __init__(self, division: Division, position: Position, weight: float):
         self.division = division
         self.position = position
@@ -50,13 +82,86 @@ def create_tensor(self, num_channels: int, height_b: int, width_b: int) -> torch.Tensor:
         return x
 
 
+class MaskFilter(Filter):
+    def __init__(self, binary_mask: np.ndarray = None, weight: float = None, float_mask: np.ndarray = None):
+        if binary_mask is not None and weight is not None:
+            self.mask = binary_mask.astype(np.float32) * weight
+        elif float_mask is not None:
+            self.mask = float_mask
+        else:
+            raise ValueError('Either float_mask or binary_mask and weight must be provided')
+        self.tensor_mask = torch.tensor(self.mask).to(devices.device)
+
+    def create_tensor(self, num_channels: int, height_b: int, width_b: int) -> torch.Tensor:
+        # Downsample the drawn mask to the latent resolution, then repeat it
+        # across channels; nearest-exact preserves hard region borders better
+        # than the bicubic, nearest or area modes
+        mask = torch.nn.functional.interpolate(self.tensor_mask.unsqueeze(0).unsqueeze(0), size=(height_b, width_b), mode='nearest-exact').squeeze(0).squeeze(0)
+        mask = mask.unsqueeze(0).repeat(num_channels, 1, 1)
+        return mask
+
+
+class PastePromptTextboxTracker:
+    def __init__(self):
+        self.scripts = []
+
+    def set_script(self, script):
+        self.scripts.append(script)
+
+    def on_after_component_callback(self, component, **_kwargs):
+        if not self.scripts:
+            return
+        if type(component) is gr.State:
+            return
+
+        script = None
+        if type(component) is gr.Textbox and component.elem_id == 'txt2img_prompt':
+            # select the corresponding script
+            script = next(x for x in self.scripts if x.is_txt2img)
+            self.scripts.remove(script)
+
+        if type(component) is gr.Textbox and component.elem_id == 'img2img_prompt':
+            # select the corresponding script
+            script = next(x for x in self.scripts if x.is_img2img)
+            self.scripts.remove(script)
+
+        if script is None:
+            return
+
+        script.target_paste_prompt = component
+
+
+prompt_textbox_tracker = PastePromptTextboxTracker()
+
+
 class Script(scripts.Script):
     def __init__(self):
+        self.ui_root = None
         self.num_batches: int = 0
         self.end_at_step: int = 20
         self.filters: List[Filter] = []
         self.debug: bool = False
+        self.selected_twoshot_tab = 0
+        self.ndmasks = []
+        self.area_colors = []
+        self.mask_denoise = False
+        prompt_textbox_tracker.set_script(self)
+        self.target_paste_prompt = None
+
     def title(self):
         return "Latent Couple extension"
 
@@ -64,7 +169,36 @@ def title(self):
     def show(self, is_img2img):
         return scripts.AlwaysVisible
 
-    def create_filters_from_ui_params(self, raw_divisions: str, raw_positions: str, raw_weights: str):
+    def create_rect_filters_from_ui_params(self, raw_divisions: str, raw_positions: str, raw_weights: str):
+
+        divisions = []
+        for division in raw_divisions.split(','):
+            y, x = division.split(':')
+            divisions.append(Division(float(y), float(x)))
+
+        def start_and_end_position(raw: str):
+            nums = [float(num) for num in raw.split('-')]
+            if len(nums) == 1:
+                return nums[0], nums[0] + 1.0
+            else:
+                return nums[0], nums[1]
+
+        positions = []
+        for position in raw_positions.split(','):
+            y, x = position.split(':')
+            y1, y2 = start_and_end_position(y)
+            x1, x2 = start_and_end_position(x)
+            positions.append(Position(y1, x1, y2, x2))
+
+        weights = []
+        for w in raw_weights.split(','):
+            weights.append(float(w))
+
+        # todo: assert that divisions, positions and weights have the same length
+
+        return [RectFilter(division, position, weight) for division, position, weight in zip(divisions, positions, weights)]
+
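What `MaskFilter.create_tensor` does to a drawn mask, sketched standalone (shapes are illustrative; the real code runs on `devices.device`):

```python
# Standalone sketch of the nearest-exact downsampling used by MaskFilter
import torch

mask = torch.zeros(512, 512)
mask[:, :256] = 0.8                                     # left half, weight 0.8
latent_mask = torch.nn.functional.interpolate(
    mask[None, None], size=(64, 64), mode='nearest-exact')[0, 0]
latent_mask = latent_mask.unsqueeze(0).repeat(4, 1, 1)  # one copy per latent channel
print(latent_mask.shape)                                # torch.Size([4, 64, 64])
```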
+    def create_mask_filters_from_ui_params(self, raw_divisions: str, raw_positions: str, raw_weights: str):
 
         divisions = []
         for division in raw_divisions.split(','):
@@ -95,7 +229,7 @@ def start_and_end_position(raw: str):
 
     def do_visualize(self, raw_divisions: str, raw_positions: str, raw_weights: str):
-        self.filters = self.create_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
+        self.filters = self.create_rect_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
         return [f.create_tensor(1, 128, 128).squeeze(dim=0).cpu().numpy() for f in self.filters]
 
@@ -114,32 +248,258 @@ def do_apply(self, extra_generation_params: str):
         return raw_params.get('divisions', '1:1,1:2,1:2'), raw_params.get('positions', '0:0,0:0,0:1'), raw_params.get('weights', '0.2,0.8,0.8'), int(raw_params.get('step', '20'))
 
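For reference, the default strings that `do_apply` falls back to parse as follows; a standalone restatement of `start_and_end_position` from the parser above:

```python
# How the default positions string is interpreted (mirrors start_and_end_position)
def start_and_end(raw: str):
    nums = [float(n) for n in raw.split('-')]
    return (nums[0], nums[0] + 1.0) if len(nums) == 1 else (nums[0], nums[1])

for pos in '0:0,0:0,0:1'.split(','):
    y, x = pos.split(':')
    print(start_and_end(y), start_and_end(x))
# (0.0, 1.0) (0.0, 1.0)   -> the whole cell at row 0, column 0
# (0.0, 1.0) (0.0, 1.0)
# (0.0, 1.0) (1.0, 2.0)   -> the second column
```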
" + # get_js_colors = """ + # async (canvasData) => { + # const canvasEl = document.getElementById("canvas-root"); + # return [canvasEl._data] + # } + # """ + + def create_canvas(h, w): + return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255 + + def process_sketch(img_arr, input_binary_matrixes): + input_binary_matrixes.clear() + # base64_img = canvas_data['image'] + # image_data = base64.b64decode(base64_img.split(',')[1]) + # image = Image.open(BytesIO(image_data)).convert("RGB") + im2arr = img_arr + # colors = [tuple(map(int, rgb[4:-1].split(','))) for rgb in + # ['colors']] + sketch_colors, color_counts = np.unique(im2arr.reshape(-1, im2arr.shape[2]), axis=0, return_counts=True) + colors_fixed = [] + # if color count is less than 0.001 of total pixel count, collect it for edge color correction + edge_color_correction_arr = [] + for sketch_color_idx, color in enumerate(sketch_colors[:-1]): # exclude white + if color_counts[sketch_color_idx] < im2arr.shape[0] * im2arr.shape[1] * 0.002: + edge_color_correction_arr.append(sketch_color_idx) + + edge_fix_dict = {} + # TODO:for every non area color pixel in img_arr, find the nearest area color pixel and replace it with that color + + area_colors = np.delete(sketch_colors, edge_color_correction_arr, axis=0) + if self.mask_denoise: + for edge_color_idx in edge_color_correction_arr: + edge_color = sketch_colors[edge_color_idx] + # find the nearest area_color + + color_distances = np.linalg.norm(area_colors - edge_color, axis=1) + nearest_index = np.argmin(color_distances) + nearest_color = area_colors[nearest_index] + edge_fix_dict[edge_color_idx] = nearest_color + # replace edge color with the nearest area_color + cur_color_mask = np.all(im2arr == edge_color, axis=2) + im2arr[cur_color_mask] = nearest_color + + # recalculate area colors + sketch_colors, color_counts = np.unique(im2arr.reshape(-1, im2arr.shape[2]), axis=0, return_counts=True) + area_colors = sketch_colors + + # create binary matrix for each area_color + area_color_maps = [] + self.ndmasks = [] + self.area_colors = area_colors + for color in area_colors: + r, g, b = color + mask, binary_matrix = create_binary_matrix_base64(im2arr, color) + self.ndmasks.append(mask) + input_binary_matrixes.append(binary_matrix) + colors_fixed.append(gr.update( + value=f'