Skip to content

Commit

Permalink
Merge branch 'extension-device' into beta
Browse files Browse the repository at this point in the history
  • Loading branch information
jn-jairo committed Nov 6, 2023
2 parents b53d76b + 0ae1c22 commit d93cab1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
8 changes: 5 additions & 3 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,11 @@ def get_torch_device():
global cpu_state
global extensions_devices

extension = comfy.utils.get_extension_calling()
if extension is not None and extension in extensions_devices:
return torch.device(extensions_devices[extension])
extension_stack = comfy.utils.get_extension_calling()
if extension_stack is not None:
for extension in extension_stack:
if extension in extensions_devices:
return torch.device(extensions_devices[extension])

if directml_enabled:
global directml_device
Expand Down
10 changes: 9 additions & 1 deletion comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@ def get_extension_calling():
if os.sep + "custom_nodes" + os.sep in frame.filename:
stack_module = inspect.getmodule(frame[0])
if stack_module:
return re.sub(r".*\.?custom_nodes\.([^\.]+).*", r"\1", stack_module.__name__.replace("\\", ".").replace("/", ".")).split(".")[0]
stack = []

parts = re.sub(r".*\.?custom_nodes\.([^\.]+).*", r"\1", stack_module.__name__.replace(os.sep, ".")).split(".")

while len(parts) > 0:
stack.append(".".join(parts))
parts.pop()

return stack

return None

Expand Down

0 comments on commit d93cab1

Please sign in to comment.