Skip to content

Commit

Permalink
Add Extension tests (comfyanonymous#2125)
Browse files Browse the repository at this point in the history
* Add test for extension hooks
Add afterConfigureGraph callback

* fix comment
  • Loading branch information
pythongosssss authored Dec 1, 2023
1 parent ec7a00a commit 8491280
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 3 deletions.
196 changes: 196 additions & 0 deletions tests-ui/tests/extensions.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start } = require("../utils");
const lg = require("../utils/litegraph");

describe("extensions", () => {
beforeEach(() => {
lg.setup(global);
});

afterEach(() => {
lg.teardown(global);
});

it("calls each extension hook", async () => {
const mockExtension = {
name: "TestExtension",
init: jest.fn(),
setup: jest.fn(),
addCustomNodeDefs: jest.fn(),
getCustomWidgets: jest.fn(),
beforeRegisterNodeDef: jest.fn(),
registerCustomNodes: jest.fn(),
loadedGraphNode: jest.fn(),
nodeCreated: jest.fn(),
beforeConfigureGraph: jest.fn(),
afterConfigureGraph: jest.fn(),
};

const { app, ez, graph } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});

// Basic initialisation hooks should be called once, with app
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.init).toHaveBeenCalledWith(app);

// Adding custom node defs should be passed the full list of nodes
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
expect(defs).toHaveProperty("KSampler");
expect(defs).toHaveProperty("LoadImage");

// Get custom widgets is called once and should return new widget types
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);

// Before register node def will be called once per node type
const nodeNames = Object.keys(defs);
const nodeCount = nodeNames.length;
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
for (let i = 0; i < nodeCount; i++) {
// It should be send the JS class and the original JSON definition
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];

expect(nodeClass.name).toBe("ComfyNode");
expect(nodeClass.comfyClass).toBe(nodeNames[i]);
expect(nodeDef.name).toBe(nodeNames[i]);
expect(nodeDef).toHaveProperty("input");
expect(nodeDef).toHaveProperty("output");
}

// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);

// Before configure graph will be called here as the default graph is being loaded
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
// it gets sent the graph data that is going to be loaded
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];

// A node created is fired for each node constructor that is called
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}

// Each node then calls loadedGraphNode to allow them to be updated
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}

// After configure is then called once all the setup is done
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);

expect(mockExtension.setup).toHaveBeenCalledTimes(1);
expect(mockExtension.setup).toHaveBeenCalledWith(app);

// Ensure hooks are called in the correct order
const callOrder = [
"init",
"addCustomNodeDefs",
"getCustomWidgets",
"beforeRegisterNodeDef",
"registerCustomNodes",
"beforeConfigureGraph",
"nodeCreated",
"loadedGraphNode",
"afterConfigureGraph",
"setup",
];
for (let i = 1; i < callOrder.length; i++) {
const fn1 = mockExtension[callOrder[i - 1]];
const fn2 = mockExtension[callOrder[i]];
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
}

graph.clear();

// Ensure adding a new node calls the correct callback
ez.LoadImage();
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");

// Reload the graph to ensure correct hooks are fired
await graph.reload();

// These hooks should not be fired again
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
expect(mockExtension.setup).toHaveBeenCalledTimes(1);

// These should be called again
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
});

it("allows custom nodeDefs and widgets to be registered", async () => {
const widgetMock = jest.fn((node, inputName, inputData, app) => {
expect(node.constructor.comfyClass).toBe("TestNode");
expect(inputName).toBe("test_input");
expect(inputData[0]).toBe("CUSTOMWIDGET");
expect(inputData[1]?.hello).toBe("world");
expect(app).toStrictEqual(app);

return {
widget: node.addWidget("button", inputName, "hello", () => {}),
};
});

// Register our extension that adds a custom node + widget type
const mockExtension = {
name: "TestExtension",
addCustomNodeDefs: (nodeDefs) => {
nodeDefs["TestNode"] = {
output: [],
output_name: [],
output_is_list: [],
name: "TestNode",
display_name: "TestNode",
category: "Test",
input: {
required: {
test_input: ["CUSTOMWIDGET", { hello: "world" }],
},
},
};
},
getCustomWidgets: jest.fn(() => {
return {
CUSTOMWIDGET: widgetMock,
};
}),
};

const { graph, ez } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});

expect(mockExtension.getCustomWidgets).toBeCalledTimes(1);

graph.clear();
expect(widgetMock).toBeCalledTimes(0);
const node = ez.TestNode();
expect(widgetMock).toBeCalledTimes(1);

// Ensure our custom widget is created
expect(node.inputs.length).toBe(0);
expect(node.widgets.length).toBe(1);
const w = node.widgets[0].widget;
expect(w.name).toBe("test_input");
expect(w.type).toBe("button");
});
});
7 changes: 4 additions & 3 deletions tests-ui/utils/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@ const lg = require("./litegraph");

/**
*
* @param { Parameters<mockApi>[0] & { resetEnv?: boolean } } config
* @param { Parameters<mockApi>[0] & { resetEnv?: boolean, preSetup?(app): Promise<void> } } config
* @returns
*/
export async function start(config = undefined) {
if(config?.resetEnv) {
export async function start(config = {}) {
if(config.resetEnv) {
jest.resetModules();
jest.resetAllMocks();
lg.setup(global);
}

mockApi(config);
const { app } = require("../../web/scripts/app");
config.preSetup?.(app);
await app.setup();
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
}
Expand Down
1 change: 1 addition & 0 deletions web/scripts/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,7 @@ export class ComfyApp {
if (missingNodeTypes.length) {
this.showMissingNodesError(missingNodeTypes);
}
await this.#invokeExtensionsAsync("afterConfigureGraph", missingNodeTypes);
}

/**
Expand Down

0 comments on commit 8491280

Please sign in to comment.