-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy path__init__.py
41 lines (27 loc) · 873 Bytes
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from .normalization import Normalization
from .recenter import Recenter, RecenterXL, disable_recenter
from functools import wraps
import execution
# Registry of the node classes this package provides, keyed by node identifier
# string. The classes come from the sibling `normalization` and `recenter`
# modules imported above.
# NOTE(review): presumably consumed by the host application's custom-node
# loader when this package is imported — confirm against the loader.
NODE_CLASS_MAPPINGS = {
    "Normalization": Normalization,
    "Recenter": Recenter,
    "Recenter XL": RecenterXL,
}

# Display label for each identifier registered in NODE_CLASS_MAPPINGS.
# NOTE(review): the "Recenter XL" identifier is labelled "RecenterXL" (no
# space) — confirm this is intentional before changing either string.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Normalization": "Normalization",
    "Recenter": "Recenter",
    "Recenter XL": "RecenterXL",
}
def find_node(prompt: dict) -> bool:
    """Report whether the prompt contains a Recenter or Recenter XL node.

    Args:
        prompt: Mapping whose values are node entries; each entry is queried
            with ``.get("class_type")``.

    Returns:
        True if at least one entry's ``class_type`` is ``"Recenter"`` or
        ``"Recenter XL"``, otherwise False.
    """
    # A tuple (not a set) keeps the membership test safe for unhashable values.
    recenter_types = ("Recenter", "Recenter XL")
    return any(
        entry.get("class_type") in recenter_types for entry in prompt.values()
    )
# Keep a handle on the host's validator so the wrapper below can delegate to it.
original_validate = execution.validate_prompt


@wraps(original_validate)
def hijack_validate(*args, **kwargs):
    """Wrap ``execution.validate_prompt`` to switch Recenter off when unused.

    Before delegating to the original validator, the prompt mapping is
    inspected with ``find_node``; if it contains no Recenter / Recenter XL
    node, ``disable_recenter()`` is called.

    All positional and keyword arguments are forwarded untouched, so the
    wrapper does not depend on the exact signature of the wrapped validator.
    NOTE(review): some host versions appear to pass extra arguments (e.g. a
    prompt id ahead of the prompt) and may define the validator as a
    coroutine function — confirm against the installed ``execution`` module.
    Because the wrapper returns whatever the original returns, an awaitable
    result is handed back to the caller unchanged.

    Args:
        *args: Positional arguments for the original validator. The prompt is
            taken to be the first ``dict`` among them.
        **kwargs: Keyword arguments for the original validator. A ``prompt``
            keyword, when present, takes precedence over positional lookup.

    Returns:
        Whatever the original ``validate_prompt`` returns.
    """
    prompt = kwargs.get("prompt")
    if prompt is None:
        # Locate the prompt by type rather than position so extra leading
        # arguments (non-dict values) do not break the lookup.
        prompt = next((arg for arg in args if isinstance(arg, dict)), None)

    # No usable prompt is treated the same as "no Recenter node present".
    if not (isinstance(prompt, dict) and find_node(prompt)):
        disable_recenter()

    return original_validate(*args, **kwargs)


# Install the wrapper in place of the host's validator (module-level side effect).
execution.validate_prompt = hijack_validate