diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index f090c05a73..5f09651671 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -63,7 +63,7 @@ def _process_remote_workflow(self, id: str): def _process_file(self, file: Path): with file.open("r") as f: - self._process(CustomWorkflow(file.stem, WorkflowSource.local, json.load(f))) + self._process(CustomWorkflow(file.stem, WorkflowSource.local, json.load(f), file)) def _process(self, workflow: CustomWorkflow): idx = self.find_index(workflow.id) @@ -94,6 +94,16 @@ def append(self, item: CustomWorkflow): self._workflows.append(item) self.endInsertRows() + def remove(self, id: str): + idx = self.find_index(id) + if idx.isValid(): + wf = self._workflows[idx.row()] + if wf.source is WorkflowSource.local and wf.path is not None: + wf.path.unlink() + self.beginRemoveRows(QModelIndex(), idx.row(), idx.row()) + self._workflows.pop(idx.row()) + self.endRemoveRows() + def set_graph(self, index: QModelIndex, graph: dict): self._workflows[index.row()].graph = graph self.dataChanged.emit(index, index) @@ -253,6 +263,14 @@ def save_as(self, id: str): assert self._graph, "Save as: no workflow selected" self.workflow_id = self._workflows.save_as(id, self._graph.root) + def remove_workflow(self): + if id := self.workflow_id: + self._workflow_id = "" + self._workflow = None + self._graph = None + self._metadata = [] + self._workflows.remove(id) + @property def workflow(self): return self._workflow diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 4ecec0a6a7..d0206f9774 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -370,7 +370,7 @@ def generate_custom(self): sampling=SamplingInput("custom", "custom", 1, 1000, seed=seed), custom_workflow=CustomWorkflowInput(wf.root, {}), ) - job_params = JobParams(bounds, self.custom.graph_id) + job_params = JobParams(bounds, self.custom.workflow_id) except Exception as e: self.report_error(util.log_error(e)) return diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index a54baf5f3a..f6a92c0174 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -25,7 +25,7 @@ from ..image import Bounds, Extent, Image from ..jobs import Job, JobQueue, JobState, JobKind, JobParams from ..model import Model, InpaintContext, RootRegion, ProgressKind -from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows +from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows, WorkflowSource from ..style import Styles from ..root import root from ..workflow import InpaintMode, FillMode @@ -1005,6 +1005,12 @@ def __init__(self): _("Save workflow to file"), self._save_workflow, ) + self._delete_workflow_button = _create_tool_button( + self._workflow_select_widgets, + theme.icon("discard"), + _("Delete the currently selected workflow"), + self._delete_workflow, + ) self._open_webui_button = _create_tool_button( self._workflow_select_widgets, theme.icon("comfyui"), @@ -1044,6 +1050,7 @@ def __init__(self): select_layout.addWidget(self._workflow_select) select_layout.addWidget(self._import_workflow_button) select_layout.addWidget(self._save_workflow_button) + select_layout.addWidget(self._delete_workflow_button) select_layout.addWidget(self._open_webui_button) self._workflow_select_widgets.setLayout(select_layout) edit_layout = QHBoxLayout() @@ -1071,8 +1078,12 @@ def __init__(self): def _update_current_workflow(self): if not self.model.custom.workflow: self._save_workflow_button.setEnabled(False) + self._delete_workflow_button.setEnabled(False) return self._save_workflow_button.setEnabled(True) + self._delete_workflow_button.setEnabled( + self.model.custom.workflow.source is WorkflowSource.local + ) self._params_widget.deleteLater() self._params_widget = WorkflowParamsWidget(self.model.custom.metadata, self) @@ -1127,6 +1138,18 @@ def _import_workflow(self, *args): def _save_workflow(self): self.is_edit_mode = True + def _delete_workflow(self): + filepath = ensure(self.model.custom.workflow).path + q = QMessageBox.question( + self, + _("Delete Workflow"), + _("Are you sure you want to delete the current workflow?") + f"\n{filepath}", + QMessageBox.Yes | QMessageBox.No, + QMessageBox.StandardButton.No, + ) + if q == QMessageBox.StandardButton.Yes: + self.model.custom.remove_workflow() + def _open_webui(self): if client := root.connection.client_if_connected: QDesktopServices.openUrl(QUrl(client.url)) diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 689aaa8386..844e3176b4 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -26,8 +26,12 @@ def test_collection(tmp_path: Path): collection = WorkflowCollection(connection, tmp_path) assert len(collection) == 3 - assert collection.find("file1") == CustomWorkflow("file1", WorkflowSource.local, {"file": 1}) - assert collection.find("file2") == CustomWorkflow("file2", WorkflowSource.local, {"file": 2}) + assert collection.find("file1") == CustomWorkflow( + "file1", WorkflowSource.local, {"file": 1}, file1 + ) + assert collection.find("file2") == CustomWorkflow( + "file2", WorkflowSource.local, {"file": 2}, file2 + ) assert collection.find("connection1") == CustomWorkflow( "connection1", WorkflowSource.remote, {"connection": 1} ) @@ -56,8 +60,9 @@ def on_data_changed(start, end): ) collection.set_graph(collection.index(0), {"file": 3}) - assert collection.find("file1") == CustomWorkflow("file1", WorkflowSource.local, {"file": 3}) - + assert collection.find("file1") == CustomWorkflow( + "file1", WorkflowSource.local, {"file": 3}, file1 + ) assert events == [("begin_insert", 3), "end_insert", ("data_changed", 0)] collection.append(CustomWorkflow("doc1", WorkflowSource.document, {"doc": 1})) @@ -104,6 +109,10 @@ def test_files(tmp_path: Path): ] assert all(f.exists() for f in files) + collection.remove("file1 (1)") + assert collection.find("file1 (1)") is None + assert not (collection_folder / "file1 (1).json").exists() + bad_file = tmp_path / "bad.json" bad_file.write_text("bad json") with pytest.raises(RuntimeError):