Skip to content

Commit 165780a

Browse files
committed
Merge branch 'master' into beta
2 parents aa694f5 + 2995a24 commit 165780a

File tree

7 files changed

+227
-24
lines changed

7 files changed

+227
-24
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
4545
|---------------------------|--------------------------------------------------------------------------------------------------------------------|
4646
| Ctrl + Enter | Queue up current graph for generation |
4747
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
48+
| Ctrl + Z/Ctrl + Y | Undo/Redo |
4849
| Ctrl + S | Save workflow |
4950
| Ctrl + O | Load workflow |
5051
| Ctrl + A | Select all nodes |
@@ -100,6 +101,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
100101
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6```
101102

102103
This is the command to install the nightly with ROCm 5.7 that might have some performance improvements:
104+
103105
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7```
104106

105107
### NVIDIA
@@ -192,7 +194,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
192194

193195
Make sure you use the regular loaders/Load Checkpoint node to load checkpoints. It will auto pick the right settings depending on your GPU.
194196

195-
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers this option does not do anything.
197+
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything.
196198

197199
```--dont-upcast-attention```
198200

comfy/supported_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def model_type(self, state_dict, prefix=""):
7171
return model_base.ModelType.EPS
7272

7373
def process_clip_state_dict(self, state_dict):
74+
replace_prefix = {}
75+
replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
76+
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
77+
7478
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
7579
return state_dict
7680

tests-ui/tests/extensions.test.js

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
// @ts-check
2+
/// <reference path="../node_modules/@types/jest/index.d.ts" />
3+
const { start } = require("../utils");
4+
const lg = require("../utils/litegraph");
5+
6+
describe("extensions", () => {
7+
beforeEach(() => {
8+
lg.setup(global);
9+
});
10+
11+
afterEach(() => {
12+
lg.teardown(global);
13+
});
14+
15+
it("calls each extension hook", async () => {
16+
const mockExtension = {
17+
name: "TestExtension",
18+
init: jest.fn(),
19+
setup: jest.fn(),
20+
addCustomNodeDefs: jest.fn(),
21+
getCustomWidgets: jest.fn(),
22+
beforeRegisterNodeDef: jest.fn(),
23+
registerCustomNodes: jest.fn(),
24+
loadedGraphNode: jest.fn(),
25+
nodeCreated: jest.fn(),
26+
beforeConfigureGraph: jest.fn(),
27+
afterConfigureGraph: jest.fn(),
28+
};
29+
30+
const { app, ez, graph } = await start({
31+
async preSetup(app) {
32+
app.registerExtension(mockExtension);
33+
},
34+
});
35+
36+
// Basic initialisation hooks should be called once, with app
37+
expect(mockExtension.init).toHaveBeenCalledTimes(1);
38+
expect(mockExtension.init).toHaveBeenCalledWith(app);
39+
40+
// Adding custom node defs should be passed the full list of nodes
41+
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
42+
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
43+
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
44+
expect(defs).toHaveProperty("KSampler");
45+
expect(defs).toHaveProperty("LoadImage");
46+
47+
// Get custom widgets is called once and should return new widget types
48+
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
49+
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);
50+
51+
// Before register node def will be called once per node type
52+
const nodeNames = Object.keys(defs);
53+
const nodeCount = nodeNames.length;
54+
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
55+
for (let i = 0; i < nodeCount; i++) {
56+
// It should be send the JS class and the original JSON definition
57+
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
58+
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
59+
60+
expect(nodeClass.name).toBe("ComfyNode");
61+
expect(nodeClass.comfyClass).toBe(nodeNames[i]);
62+
expect(nodeDef.name).toBe(nodeNames[i]);
63+
expect(nodeDef).toHaveProperty("input");
64+
expect(nodeDef).toHaveProperty("output");
65+
}
66+
67+
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
68+
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
69+
70+
// Before configure graph will be called here as the default graph is being loaded
71+
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
72+
// it gets sent the graph data that is going to be loaded
73+
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];
74+
75+
// A node created is fired for each node constructor that is called
76+
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
77+
for (let i = 0; i < graphData.nodes.length; i++) {
78+
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
79+
}
80+
81+
// Each node then calls loadedGraphNode to allow them to be updated
82+
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
83+
for (let i = 0; i < graphData.nodes.length; i++) {
84+
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
85+
}
86+
87+
// After configure is then called once all the setup is done
88+
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);
89+
90+
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
91+
expect(mockExtension.setup).toHaveBeenCalledWith(app);
92+
93+
// Ensure hooks are called in the correct order
94+
const callOrder = [
95+
"init",
96+
"addCustomNodeDefs",
97+
"getCustomWidgets",
98+
"beforeRegisterNodeDef",
99+
"registerCustomNodes",
100+
"beforeConfigureGraph",
101+
"nodeCreated",
102+
"loadedGraphNode",
103+
"afterConfigureGraph",
104+
"setup",
105+
];
106+
for (let i = 1; i < callOrder.length; i++) {
107+
const fn1 = mockExtension[callOrder[i - 1]];
108+
const fn2 = mockExtension[callOrder[i]];
109+
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
110+
}
111+
112+
graph.clear();
113+
114+
// Ensure adding a new node calls the correct callback
115+
ez.LoadImage();
116+
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
117+
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
118+
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");
119+
120+
// Reload the graph to ensure correct hooks are fired
121+
await graph.reload();
122+
123+
// These hooks should not be fired again
124+
expect(mockExtension.init).toHaveBeenCalledTimes(1);
125+
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
126+
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
127+
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
128+
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
129+
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
130+
131+
// These should be called again
132+
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
133+
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
134+
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
135+
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
136+
});
137+
138+
it("allows custom nodeDefs and widgets to be registered", async () => {
139+
const widgetMock = jest.fn((node, inputName, inputData, app) => {
140+
expect(node.constructor.comfyClass).toBe("TestNode");
141+
expect(inputName).toBe("test_input");
142+
expect(inputData[0]).toBe("CUSTOMWIDGET");
143+
expect(inputData[1]?.hello).toBe("world");
144+
expect(app).toStrictEqual(app);
145+
146+
return {
147+
widget: node.addWidget("button", inputName, "hello", () => {}),
148+
};
149+
});
150+
151+
// Register our extension that adds a custom node + widget type
152+
const mockExtension = {
153+
name: "TestExtension",
154+
addCustomNodeDefs: (nodeDefs) => {
155+
nodeDefs["TestNode"] = {
156+
output: [],
157+
output_name: [],
158+
output_is_list: [],
159+
name: "TestNode",
160+
display_name: "TestNode",
161+
category: "Test",
162+
input: {
163+
required: {
164+
test_input: ["CUSTOMWIDGET", { hello: "world" }],
165+
},
166+
},
167+
};
168+
},
169+
getCustomWidgets: jest.fn(() => {
170+
return {
171+
CUSTOMWIDGET: widgetMock,
172+
};
173+
}),
174+
};
175+
176+
const { graph, ez } = await start({
177+
async preSetup(app) {
178+
app.registerExtension(mockExtension);
179+
},
180+
});
181+
182+
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1);
183+
184+
graph.clear();
185+
expect(widgetMock).toBeCalledTimes(0);
186+
const node = ez.TestNode();
187+
expect(widgetMock).toBeCalledTimes(1);
188+
189+
// Ensure our custom widget is created
190+
expect(node.inputs.length).toBe(0);
191+
expect(node.widgets.length).toBe(1);
192+
const w = node.widgets[0].widget;
193+
expect(w.name).toBe("test_input");
194+
expect(w.type).toBe("button");
195+
});
196+
});

tests-ui/utils/index.js

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@ const lg = require("./litegraph");
44

55
/**
66
*
7-
* @param { Parameters<mockApi>[0] & { resetEnv?: boolean } } config
7+
* @param { Parameters<mockApi>[0] & { resetEnv?: boolean, preSetup?(app): Promise<void> } } config
88
* @returns
99
*/
10-
export async function start(config = undefined) {
11-
if(config?.resetEnv) {
10+
export async function start(config = {}) {
11+
if(config.resetEnv) {
1212
jest.resetModules();
1313
jest.resetAllMocks();
1414
lg.setup(global);
1515
}
1616

1717
mockApi(config);
1818
const { app } = require("../../web/scripts/app");
19+
config.preSetup?.(app);
1920
await app.setup();
2021
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
2122
}

web/extensions/core/groupNode.js

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import { app } from "../../scripts/app.js";
22
import { api } from "../../scripts/api.js";
3-
import { getWidgetType } from "../../scripts/widgets.js";
43
import { mergeIfValid } from "./widgetInputs.js";
54

65
const GROUP = Symbol();
@@ -332,7 +331,7 @@ export class GroupNodeConfig {
332331
const converted = new Map();
333332
const widgetMap = (this.oldToNewWidgetMap[node.index] = {});
334333
for (const inputName of inputNames) {
335-
let widgetType = getWidgetType(inputs[inputName], inputName);
334+
let widgetType = app.getWidgetType(inputs[inputName], inputName);
336335
if (widgetType) {
337336
const convertedIndex = node.inputs?.findIndex(
338337
(inp) => inp.name === inputName && inp.widget?.name === inputName
@@ -1010,7 +1009,7 @@ function addConvertToGroupOptions() {
10101009
const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions;
10111010
LGraphCanvas.prototype.getCanvasMenuOptions = function () {
10121011
const options = getCanvasMenuOptions.apply(this, arguments);
1013-
const index = options.findIndex((o) => o?.content === "Add Group") + 1 || opts.length;
1012+
const index = options.findIndex((o) => o?.content === "Add Group") + 1 || options.length;
10141013
addOption(options, index);
10151014
return options;
10161015
};
@@ -1020,7 +1019,7 @@ function addConvertToGroupOptions() {
10201019
LGraphCanvas.prototype.getNodeMenuOptions = function (node) {
10211020
const options = getNodeMenuOptions.apply(this, arguments);
10221021
if (!GroupNodeHandler.isGroupNode(node)) {
1023-
const index = options.findIndex((o) => o?.content === "Outputs") + 1 || opts.length - 1;
1022+
const index = options.findIndex((o) => o?.content === "Outputs") + 1 || options.length - 1;
10241023
addOption(options, index);
10251024
}
10261025
return options;

web/scripts/app.js

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { ComfyLogging } from "./logging.js";
2-
import { ComfyWidgets, getWidgetType } from "./widgets.js";
2+
import { ComfyWidgets } from "./widgets.js";
33
import { ComfyUI, $el } from "./ui.js";
44
import { api } from "./api.js";
55
import { defaultGraph } from "./defaultGraph.js";
@@ -1381,6 +1381,20 @@ export class ComfyApp {
13811381
await this.#invokeExtensionsAsync("registerCustomNodes");
13821382
}
13831383

1384+
getWidgetType(inputData, inputName) {
1385+
const type = inputData[0];
1386+
1387+
if (Array.isArray(type)) {
1388+
return "COMBO";
1389+
} else if (`${type}:${inputName}` in this.widgets) {
1390+
return `${type}:${inputName}`;
1391+
} else if (type in this.widgets) {
1392+
return type;
1393+
} else {
1394+
return null;
1395+
}
1396+
}
1397+
13841398
async registerNodeDef(nodeId, nodeData) {
13851399
const self = this;
13861400
const node = Object.assign(
@@ -1396,7 +1410,7 @@ export class ComfyApp {
13961410
const extraInfo = {};
13971411

13981412
let widgetCreated = true;
1399-
const widgetType = getWidgetType(inputData, inputName);
1413+
const widgetType = self.getWidgetType(inputData, inputName);
14001414
if(widgetType) {
14011415
if(widgetType === "COMBO") {
14021416
Object.assign(config, self.widgets.COMBO(this, inputName, inputData, app) || {});
@@ -1649,6 +1663,7 @@ export class ComfyApp {
16491663
if (missingNodeTypes.length) {
16501664
this.showMissingNodesError(missingNodeTypes);
16511665
}
1666+
await this.#invokeExtensionsAsync("afterConfigureGraph", missingNodeTypes);
16521667
}
16531668

16541669
/**

web/scripts/widgets.js

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,6 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
2323
return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } };
2424
}
2525

26-
export function getWidgetType(inputData, inputName) {
27-
const type = inputData[0];
28-
29-
if (Array.isArray(type)) {
30-
return "COMBO";
31-
} else if (`${type}:${inputName}` in ComfyWidgets) {
32-
return `${type}:${inputName}`;
33-
} else if (type in ComfyWidgets) {
34-
return type;
35-
} else {
36-
return null;
37-
}
38-
}
39-
4026
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values, widgetName, inputData) {
4127
let name = inputData[1]?.control_after_generate;
4228
if(typeof name !== "string") {

0 commit comments

Comments
 (0)