Skip to content

Commit

Permalink
add search function for prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
gongda authored and gongda committed Apr 18, 2023
1 parent 5f00f00 commit 4b550bf
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 33 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ An Web UI with intelligent prompts of Stable Diffusion with Core ML on Apple Sil

![Main Screen](./images/main_screen.png)

![Search](./images/search_prompt.png)

![Main Screen 1](./images/main_screen1.png)

![History Screen](./images/history.png)
Expand All @@ -15,7 +17,7 @@ An Web UI with intelligent prompts of Stable Diffusion with Core ML on Apple Sil
2. One submit could generate multiple images. Improve your prompt writing speed.
3. Support preserve options of medium and style and artist and resolution.
4. Analysis your usage habits. Help you discover best prompt words.
5. Contains 1048 prompts.
5. Contains 1048 prompts. Support quick search by keyword.


We support [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and [stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) now.
Expand Down
Binary file added images/search_prompt.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions stable_diffusion_webui/stable_diffusion_webui/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@
]

MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
#'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
#'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
#'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

ROOT_URLCONF = 'stable_diffusion_webui.urls'
Expand Down
90 changes: 88 additions & 2 deletions stable_diffusion_webui/stable_diffusion_webui/static/js/default.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,91 @@ let selectAllOptions = function(input_name) {


let openPromptSearchBox = function(input_id) {
let modal = "";
};
let modalId = input_id + "OptionsSearchBox";
let modal = bootstrap.Modal.getOrCreateInstance("#" + modalId);
modal.show();
};


let searchPrompts = function(inputId, q) {
console.debug("modal shown ", inputId);
let modalId = inputId + "OptionsSearchBox";
let table = $("#" + modalId).find("table");
let modalTips = $("#" + modalId).find("#modal-tips");
let checkboxes = $("input[name='" + inputId + "']:checkbox:checked");
console.debug("checkboxes ", checkboxes);
let promptChecked = {};
for(const cb of checkboxes) {
const prompt = $(cb).prop("value");
const checked = $(cb).prop("checked");
promptChecked[prompt] = checked;
}
console.debug("promptChecked ", promptChecked);

$.ajax({
url: "/search/",
type: "POST",
data: JSON.stringify({category: inputId, q: q}),
contentType: "application/json; charset=utf-8",
dataType: "json",
success: function(resp) {
console.debug("search done ", resp);
//append header
let header = [
"<tr>",
" <th></th>",
" <th>Prompt</th>",
" <th>Info</th>",
" <th>Category</th>",
"</tr>"
].join("");
table.html(header);

for(const element of resp.data) {
let checked = "checked=checked";
if (promptChecked[element.name] == null) {
checked = "";
}
const node = [
"<tr>",
" <td>",
//" <div class='form-check form-check-inline'>",
" <input class='form-check-input' type='checkbox' name='" + inputId + "' value='" + element.name + "' " + checked + " />",
//" </div>",
" </td>",
" <td>",
" <label class='form-check-label'>" + element.name + "</label>",
" </td>",
" <td>" + element.info + "</td>",
" <td>" + element.category + "</td>",
"</tr>"
].join("");
table.append(node);
}

modalTips.removeClass();
modalTips.addClass("text-success text-end");
modalTips.html("Found " + resp.n);
},
error: function(err) {
console.error(err);
modalTips.removeClass();
modalTips.addClass("text-danger text-end");
modalTips.html(err);
}
});
};


let showPromptResult = function(inputId) {
let checkboxes = $("input[name='" + inputId + "']:checkbox:checked");
let prompts = [];

for(const cb of checkboxes) {
const prompt = $(cb).prop("value");
prompts.push(prompt);
}

$("#" + inputId + "Result").html(prompts.join(","));
};

Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ <h4>Text 2 Image</h4>
{% multiple_check item.input_id item.name item.options %}
{% endfor %}


<p id="tips"></p>
<button class="btn btn-primary mb-5" id="generateBtn" type="submit">Generate</button>
</form>

</div>
<div class="col-8">
<div class="row row-cols-lg-auto g-3 align-items-center p-3">
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,51 @@
<div class="mb-2">
<button class="btn btn-outline-dark btn-sm" type="button" data-bs-toggle="collapse" data-bs-target="#{{input_id}}Panel" aria-expanded="false" aria-controls="{{input_id}}Panel">
{{ label }}
</button>

<button type="button" class="btn btn-light btn-sm" id="{{ input_id }}Btn" onclick="selectAllOptions('{{ input_id }}')">
Check All
</button>
<button type="button" class="btn btn-outline-dark btn-sm" data-bs-toggle="modal" data-bs-target="#{{input_id}}OptionsSearchBox">Select {{ label }} Prompt</button>
<span class="text-secondary" id="{{input_id}}Result"></span>
</div>

<div id="{{ input_id }}Panel" class="collapse">
{% for opt in options %}
<div class="form-check form-check-inline">
<input class="form-check-input" type="checkbox" name="{{ input_id }}" value="{{ opt.value }}" id="{{ input_id }}">
<label class="form-check-label" for="{{ input_id }}" title="{{ opt.info }}">
{{ opt.name }}
{% if opt.hit %}
<small class="text-secondary"> Use: {{opt.percentage}}%</small>
{% endif %}
</label>

<!-- Modal -->
<div class="modal fade" id="{{input_id}}OptionsSearchBox" data-input-id="{{ input_id }}" tabindex="-1" aria-labelledby="{{input_id}}OptionsSearchBoxModalLabel" aria-hidden="true">
<div class="modal-dialog">
<div class="modal-content">
<div class="modal-header">
<h2 class="modal-title fs-5" id="{{input_id}}OptionsSearchBoxModalLabel">Find {{input_id}} Prompts</h2>
<button type="button" class="btn-close btn-sm" data-bs-dismiss="modal" aria-label="Close"></button>
</div>
<div class="modal-body">
<div class="mb-2">
<input class="form-control" type="text" id="{{input_id}}QueryInput" placeholder="type your prompt" />
</div>
<div class="mb-1">
<button type="button" class="btn btn-outline-dark btn-sm mb-1 text-end" id="{{ input_id }}Btn" onclick="selectAllOptions('{{ input_id }}')">
Check All
</button>
<span id="modal-tips" class="text-end"></span>
</div>
<table class="table table-bordered table-sm"></table>
</div>
<div class="modal-footer">
<button type="button" class="btn btn-secondary btn-sm" data-bs-dismiss="modal">Close</button>
</div>
</div>
</div>
{% endfor %}
</div>
</div>
<!-- options modal end -->


<script>
$(document).ready(function(e) {
const inputId = "{{ input_id }}";

document.getElementById("{{ input_id }}OptionsSearchBox").addEventListener("shown.bs.modal", function(ev){
searchPrompts(inputId, "");
});

$("#{{ input_id }}OptionsSearchBox").find("#" + inputId + "QueryInput").bind("input", function(ev){
searchPrompts(inputId, ev.target.value);
});

document.getElementById("{{ input_id }}OptionsSearchBox").addEventListener("hidden.bs.modal", function(ev){
showPromptResult(inputId);
});
});
</script>
22 changes: 18 additions & 4 deletions stable_diffusion_webui/stable_diffusion_webui/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from django.test import TestCase, Client
import json

from stable_diffusion_webui.views import generate_combinations
from stable_diffusion_webui.views import _generate_combinations


class TestViews(TestCase):
Expand All @@ -11,17 +11,18 @@ def setUp(self) -> None:
self.client = Client()

def test_generate_two_combinations(self):
combs = generate_combinations("hi", ['printmaking', "painting"])
combs = _generate_combinations("hi", ['printmaking', "painting"])
self.assertEqual(len(combs), 2)

def test_generate_three_combinations(self):
combs = generate_combinations("hi", ['printmaking', "painting"], ['cartoon', 'pop'])
combs = _generate_combinations("hi", ['printmaking', "painting"], ['cartoon', 'pop'])
print(combs)
self.assertEqual(len(combs), 4)

def test_generate_image(self):
params = {
"subject": "A girl",
"exclude": "",
"medium": [],
"style": [],
"artist": [],
Expand All @@ -30,12 +31,25 @@ def test_generate_image(self):
"color": [],
"lighting": []
}
"""
r = self.client.post("/generate_image/", data=params, content_type="application/json")
self.assertEqual(r.status_code, 200)
data = r.json()
print(data)
self.assertEqual(data['n'], 1)
self.assertIsNotNone(data['id'])
"""

def test_search(self):
body = {
"category": "style",
"q": "2d",
}
resp = self.client.post("/search/", data=body, content_type="application/json")
self.assertEqual(resp.status_code, 200)
resp_data = resp.json()
print(resp_data)
self.assertGreater(len(resp_data['data']), 0)



1 change: 1 addition & 0 deletions stable_diffusion_webui/stable_diffusion_webui/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
path("get_generate_request/", views.get_generate_request, name="get_generate_request"),
path("history/", views.history, name="history"),
path("history/<int:page>", views.history, name="history"),
path("search/", views.search, name="search"),

path('admin/', admin.site.urls),
] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
23 changes: 21 additions & 2 deletions stable_diffusion_webui/stable_diffusion_webui/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
logger = logging.getLogger(__name__)


prompt_data = os.path.join(os.path.dirname(__file__), "prompt_data")
prompt_data_dir = os.path.join(os.path.dirname(__file__), "prompt_data")


def get_prompt_options(file_name):
options = []
path = os.path.join(prompt_data, file_name)
path = os.path.join(prompt_data_dir, file_name)
df = pd.read_csv(path)
for idx, row in df.iterrows():
options.append({
Expand All @@ -39,6 +39,25 @@ def get_prompt_options(file_name):
print("There are {} prompts".format(count_of_options))


def load_all_prompts():
files = os.listdir(prompt_data_dir)
df_list = []
for filename in files:
if filename.find(".csv") < 0:
continue
path = os.path.join(prompt_data_dir, filename)
df = pd.read_csv(path)
if 'category' not in df:
df['category'] = filename.split(".")[0]
df_list.append(df)

return pd.concat(df_list)


all_prompts_df = load_all_prompts()



def list_to_matrix(items, col):
"""Transform list to matrix with specify column number and unlimited rows
Expand Down
40 changes: 38 additions & 2 deletions stable_diffusion_webui/stable_diffusion_webui/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from .sd_model import default_sd_model
from .models import Prompt, GenerateRequest, PromptWordStat
from .utils import medium_options, style_options, artist_options, resolution_options, light_options, \
color_options, website_options, list_to_matrix, do_paginator, translate_chinese_to_english
color_options, website_options, list_to_matrix, do_paginator, translate_chinese_to_english, \
all_prompts_df



Expand Down Expand Up @@ -224,5 +225,40 @@ def history(request, page=1):

return render(request, "history.html", {'pager': pager, 'prefix': '/history/'})



def search(request):
"""Search prompts.
Request is json format.
{
"category": str,
"q": str
}
Returns:
{
n: int,
data: [
{"prompt": str, "category": str}
]
}
"""
body = json.loads(request.body)
category = body.get("category")
q = body.get("q")

df = all_prompts_df
if category:
df = df[df['category'] == category]

if q:
df = df[df['Name'].str.contains(q, case=False)]

data = []
for idx, row in df.iterrows():
data.append({
"name": row['Name'],
"info": row['Info'] if not pd.isna(row['Info']) else "",
'category': row['category']
})

return JsonResponse({"n": len(data), "data": data})

0 comments on commit 4b550bf

Please sign in to comment.