Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions PyFlow/Core/PackageBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import importlib.util
import inspect
from typing import Optional

from PyFlow.Core import PinBase
from PyFlow.Core import NodeBase
Expand All @@ -41,12 +42,13 @@ def __init__(self):
self._TOOLS = OrderedDict()
self._PREFS_WIDGETS = OrderedDict()
self._EXPORTERS = OrderedDict()
self._CUSTOM_PLUGIN_CLASSES = {}

self._PinsInputWidgetFactory = None
self._UINodesFactory = None
self._UIPinsFactory = None

def analyzePackage(self, packagePath):
def analyzePackage(self, packagePath, custom_types: Optional[list[tuple[str, type]]] = None):
def import_subclasses(directory, base_class):
subclasses = []
for filename in os.listdir(directory):
Expand Down Expand Up @@ -75,13 +77,21 @@ def loadPackageElements(packagePath, element, elementDict,classType):
else:
elementDict[subclass.__name__] = subclass

# initiate custom element store
if custom_types is None:
custom_types = []
for typ in custom_types:
if typ[0] not in self._CUSTOM_PLUGIN_CLASSES:
self._CUSTOM_PLUGIN_CLASSES[typ[0]] = {}
custom_elements = [(typ[0], self._CUSTOM_PLUGIN_CLASSES[typ[0]], typ[1]) for typ in custom_types]
# Load all elements from the package
for element in [("FunctionLibraries", self._FOO_LIBS, FunctionLibraryBase),
("Nodes", self._NODES, NodeBase),
("Pins", self._PINS, PinBase),
("Tools", self._TOOLS, ToolBase),
("Exporters", self._EXPORTERS, IDataExporter),
("PrefsWidgets", self._PREFS_WIDGETS, CategoryWidgetBase)]:
("PrefsWidgets", self._PREFS_WIDGETS, CategoryWidgetBase)] + \
custom_elements:
loadPackageElements(packagePath, element[0], element[1], element[2])
if os.path.exists(os.path.join(packagePath, "Factories")):
modPrefix = "PyFlow.Packages."+self.__class__.__name__+".Factories."
Expand Down Expand Up @@ -164,3 +174,12 @@ def PinsInputWidgetFactory(self):
:rtype: function
"""
return self._PinsInputWidgetFactory

def GetCustomClasses(self, custom_type):
"""Registered custom plugins by their type name.

:rtype: dict[str, class]
"""
if custom_type not in self._CUSTOM_PLUGIN_CLASSES:
return {}
return self._CUSTOM_PLUGIN_CLASSES[custom_type]