From edcb415306ea826e1737e77bb64dbec18ca66892 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Zahradn=C3=ADk?= Date: Sat, 11 Nov 2023 16:19:15 +0100 Subject: [PATCH] Make global setter for graphviz --- neuralogic/__init__.py | 26 +++++++++++++++++++++++++- neuralogic/utils/visualize/__init__.py | 15 +++++++++++++-- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/neuralogic/__init__.py b/neuralogic/__init__.py index d70c45c1..48f1ad57 100644 --- a/neuralogic/__init__.py +++ b/neuralogic/__init__.py @@ -11,6 +11,7 @@ _initial_seed = _seed _rnd_generator = None _max_memory_size = None +_graphviz_path = None jvm_params = { "classpath": os.path.join(os.path.abspath(os.path.dirname(__file__)), "jar", "NeuraLogic.jar"), @@ -106,8 +107,31 @@ def is_initialized() -> bool: return _is_initialized +def set_graphviz_path(path: Optional[str]): + """ + Set the default path to Graphviz + + Parameters + ---------- + path : Optional[str] + The Graphviz path + """ + global _graphviz_path + _graphviz_path = path + + +def get_default_graphviz_path() -> Optional[str]: + """ + Get the default path to Graphviz + """ + return _graphviz_path + + def initialize( - debug_mode: bool = False, debug_port: int = 12999, is_debug_server: bool = True, debug_suspend: bool = True + debug_mode: bool = False, + debug_port: int = 12999, + is_debug_server: bool = True, + debug_suspend: bool = True, ): """ Initialize the NeuraLogic backend. This function is called implicitly when needed and should be called diff --git a/neuralogic/utils/visualize/__init__.py b/neuralogic/utils/visualize/__init__.py index 01bcef35..bf6aa97a 100644 --- a/neuralogic/utils/visualize/__init__.py +++ b/neuralogic/utils/visualize/__init__.py @@ -4,9 +4,19 @@ import jpype +from neuralogic import get_default_graphviz_path from neuralogic.core.settings import Settings, SettingsProxy +def get_graphviz_path(path: Optional[str] = None) -> str: + """ + Get the path to the Graphviz executable + """ + if path is not None: + return path + return get_default_graphviz_path() + + def get_drawing_settings( img_type: str = "png", value_detail: int = 0, graphviz_path: Optional[str] = None ) -> SettingsProxy: @@ -19,8 +29,9 @@ def get_drawing_settings( """ settings = Settings().create_proxy() - if graphviz_path is not None: - settings.settings.graphvizPath = graphviz_path + graphviz = get_graphviz_path(graphviz_path) + if graphviz is not None: + settings.settings.graphvizPath = graphviz settings.settings.drawing = False settings.settings.storeNotShow = True