Skip to content

Commit

Permalink
Button to delete local (file based) custom workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Oct 2, 2024
1 parent 7df76e0 commit a7e3408
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 7 deletions.
20 changes: 19 additions & 1 deletion ai_diffusion/custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _process_remote_workflow(self, id: str):

def _process_file(self, file: Path):
with file.open("r") as f:
self._process(CustomWorkflow(file.stem, WorkflowSource.local, json.load(f)))
self._process(CustomWorkflow(file.stem, WorkflowSource.local, json.load(f), file))

def _process(self, workflow: CustomWorkflow):
idx = self.find_index(workflow.id)
Expand Down Expand Up @@ -94,6 +94,16 @@ def append(self, item: CustomWorkflow):
self._workflows.append(item)
self.endInsertRows()

def remove(self, id: str):
idx = self.find_index(id)
if idx.isValid():
wf = self._workflows[idx.row()]
if wf.source is WorkflowSource.local and wf.path is not None:
wf.path.unlink()
self.beginRemoveRows(QModelIndex(), idx.row(), idx.row())
self._workflows.pop(idx.row())
self.endRemoveRows()

def set_graph(self, index: QModelIndex, graph: dict):
self._workflows[index.row()].graph = graph
self.dataChanged.emit(index, index)
Expand Down Expand Up @@ -253,6 +263,14 @@ def save_as(self, id: str):
assert self._graph, "Save as: no workflow selected"
self.workflow_id = self._workflows.save_as(id, self._graph.root)

def remove_workflow(self):
if id := self.workflow_id:
self._workflow_id = ""
self._workflow = None
self._graph = None
self._metadata = []
self._workflows.remove(id)

@property
def workflow(self):
return self._workflow
Expand Down
2 changes: 1 addition & 1 deletion ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def generate_custom(self):
sampling=SamplingInput("custom", "custom", 1, 1000, seed=seed),
custom_workflow=CustomWorkflowInput(wf.root, {}),
)
job_params = JobParams(bounds, self.custom.graph_id)
job_params = JobParams(bounds, self.custom.workflow_id)
except Exception as e:
self.report_error(util.log_error(e))
return
Expand Down
25 changes: 24 additions & 1 deletion ai_diffusion/ui/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ..image import Bounds, Extent, Image
from ..jobs import Job, JobQueue, JobState, JobKind, JobParams
from ..model import Model, InpaintContext, RootRegion, ProgressKind
from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows
from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows, WorkflowSource
from ..style import Styles
from ..root import root
from ..workflow import InpaintMode, FillMode
Expand Down Expand Up @@ -1005,6 +1005,12 @@ def __init__(self):
_("Save workflow to file"),
self._save_workflow,
)
self._delete_workflow_button = _create_tool_button(
self._workflow_select_widgets,
theme.icon("discard"),
_("Delete the currently selected workflow"),
self._delete_workflow,
)
self._open_webui_button = _create_tool_button(
self._workflow_select_widgets,
theme.icon("comfyui"),
Expand Down Expand Up @@ -1044,6 +1050,7 @@ def __init__(self):
select_layout.addWidget(self._workflow_select)
select_layout.addWidget(self._import_workflow_button)
select_layout.addWidget(self._save_workflow_button)
select_layout.addWidget(self._delete_workflow_button)
select_layout.addWidget(self._open_webui_button)
self._workflow_select_widgets.setLayout(select_layout)
edit_layout = QHBoxLayout()
Expand Down Expand Up @@ -1071,8 +1078,12 @@ def __init__(self):
def _update_current_workflow(self):
if not self.model.custom.workflow:
self._save_workflow_button.setEnabled(False)
self._delete_workflow_button.setEnabled(False)
return
self._save_workflow_button.setEnabled(True)
self._delete_workflow_button.setEnabled(
self.model.custom.workflow.source is WorkflowSource.local
)

self._params_widget.deleteLater()
self._params_widget = WorkflowParamsWidget(self.model.custom.metadata, self)
Expand Down Expand Up @@ -1127,6 +1138,18 @@ def _import_workflow(self, *args):
def _save_workflow(self):
self.is_edit_mode = True

def _delete_workflow(self):
filepath = ensure(self.model.custom.workflow).path
q = QMessageBox.question(
self,
_("Delete Workflow"),
_("Are you sure you want to delete the current workflow?") + f"\n{filepath}",
QMessageBox.Yes | QMessageBox.No,
QMessageBox.StandardButton.No,
)
if q == QMessageBox.StandardButton.Yes:
self.model.custom.remove_workflow()

def _open_webui(self):
if client := root.connection.client_if_connected:
QDesktopServices.openUrl(QUrl(client.url))
Expand Down
17 changes: 13 additions & 4 deletions tests/test_custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ def test_collection(tmp_path: Path):

collection = WorkflowCollection(connection, tmp_path)
assert len(collection) == 3
assert collection.find("file1") == CustomWorkflow("file1", WorkflowSource.local, {"file": 1})
assert collection.find("file2") == CustomWorkflow("file2", WorkflowSource.local, {"file": 2})
assert collection.find("file1") == CustomWorkflow(
"file1", WorkflowSource.local, {"file": 1}, file1
)
assert collection.find("file2") == CustomWorkflow(
"file2", WorkflowSource.local, {"file": 2}, file2
)
assert collection.find("connection1") == CustomWorkflow(
"connection1", WorkflowSource.remote, {"connection": 1}
)
Expand Down Expand Up @@ -56,8 +60,9 @@ def on_data_changed(start, end):
)

collection.set_graph(collection.index(0), {"file": 3})
assert collection.find("file1") == CustomWorkflow("file1", WorkflowSource.local, {"file": 3})

assert collection.find("file1") == CustomWorkflow(
"file1", WorkflowSource.local, {"file": 3}, file1
)
assert events == [("begin_insert", 3), "end_insert", ("data_changed", 0)]

collection.append(CustomWorkflow("doc1", WorkflowSource.document, {"doc": 1}))
Expand Down Expand Up @@ -104,6 +109,10 @@ def test_files(tmp_path: Path):
]
assert all(f.exists() for f in files)

collection.remove("file1 (1)")
assert collection.find("file1 (1)") is None
assert not (collection_folder / "file1 (1).json").exists()

bad_file = tmp_path / "bad.json"
bad_file.write_text("bad json")
with pytest.raises(RuntimeError):
Expand Down

0 comments on commit a7e3408

Please sign in to comment.