From 3c3b0131817f0dce0b3b465d299f6f1d8bc3b759 Mon Sep 17 00:00:00 2001 From: Jorge Mendez <42736565+Jorgelmh@users.noreply.github.com> Date: Fri, 10 Jan 2025 19:13:54 +0100 Subject: [PATCH] Introduce possibility to pass class arguments to PythonFMU builder script (#225) --- pythonfmu/builder.py | 159 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 144 insertions(+), 15 deletions(-) diff --git a/pythonfmu/builder.py b/pythonfmu/builder.py index da9275d..5a3288c 100644 --- a/pythonfmu/builder.py +++ b/pythonfmu/builder.py @@ -6,10 +6,11 @@ import shutil import sys import tempfile +from types import FunctionType import zipfile import inspect from pathlib import Path -from typing import Iterable, Optional, Tuple, Union +from typing import Iterable, Literal, Optional, Tuple, Union from xml.dom.minidom import parseString from xml.etree.ElementTree import Element, SubElement, tostring from .osutil import get_lib_extension, get_platform @@ -20,24 +21,140 @@ logger = logging.getLogger(__name__) +def match_par(txt: str, left: str = "(", right: str = ")") -> tuple[int, Literal[-1]] | tuple[int, int]: + """ + Finds the position of the matching closing parenthesis for the first opening parenthesis in a given string. + + Args: + txt (str): The input string to search within. + left (str, optional): The character representing the opening parenthesis. Defaults to "(". + right (str, optional): The character representing the closing parenthesis. Defaults to ")". + Returns: + tuple[int, Literal[-1]] | tuple[int, int]: A tuple containing the position of the first opening parenthesis + and the position of the matching closing parenthesis. If no matching closing parenthesis is found, + returns the position of the first opening parenthesis and -1. + Raises: + AssertionError: If the first opening parenthesis is not found in the input string. + """ -def get_class_name(interface) -> str: - """Returns the name of the class derived from Fmi2Slave in the given interface module. + pos0 = txt.find(left, 0) + assert pos0 >= 0, f"First {left} not found" + stack = [pos0] + i = pos0 + while True: + i += 1 + if len(txt) <= i: + return (pos0, -1) + elif txt[i] == "#": # comment + i = txt.find("\n", i) + elif txt[i:].startswith(left): + stack.append(i) + elif txt[i:].startswith(right): + if len(stack) > 1: + stack.pop(-1) + else: + return (pos0, i) + +def get_model_class(src: Path) -> Fmi2Slave: + """ + Given a source file path, dynamically import the module and find the class that + inherits from Fmi2Slave with the longest hierarchy. Args: - interface: The module containing the classes to be inspected. + src (Path): The path to the source file containing the module. + Returns: + Fmi2Slave: The class that inherits from Fmi2Slave with the longest hierarchy. + Raises: + ValueError: If no class inheriting from Fmi2Slave is found in the module. + ValueError: If multiple classes with the same hierarchy length are found. + """ + + modulename = src.stem + module = importlib.import_module(modulename) + + assert inspect.ismodule(module) + + modelclasses: dict[Fmi2Slave, int] = {} + + # get all classes in the module and store them in a dict with their hierarchy length + for _, obj in inspect.getmembers(module): + if inspect.isclass(obj): + mro = inspect.getmro(obj) + if Fmi2Slave in mro and not inspect.isabstract(obj): + # store the class and its hierarchy length + modelclasses.update({obj: len(mro)}) + + if not len(modelclasses): + raise ValueError(f"No child class of Fmi2Slave found in module {src}") from None + else: + # Return the class with the longest hierarchy, if no unique class is found, raise errors + maxlen = max(n for n in modelclasses.values()) + classes = [c for c, n in modelclasses.items() if n == maxlen] + if not len(classes): + raise ValueError(f"No child class of Fmi2Slave found in module {src}") from None + elif len(classes) > 1: + raise ValueError(f"Non-unique Fmi2Slave-derived class in module {src}. Found {classes}.") from None + else: + return classes[0] + +def update_model_parameters(src: Path, model: Fmi2Slave, newargs: dict) -> str: + """ + Update the model parameters in the __init__ function of a given module. + This function modifies the default values of the parameters in the __init__ + function of the specified model with the new values provided in the newargs + dictionary. It returns the updated module code as a string. + + Args: + src (Path): The path to the source file containing the module. + model (Fmi2Slave): The model object whose __init__ function parameters + need to be updated. + newargs (dict): A dictionary containing the new parameter values. The + keys should be the parameter names and the values should + be the new default values. Returns: - str: The name of the class derived from Fmi2Slave, or None if no such class is found. + str: The updated module code as a string. + Raises: + AssertionError: If the __init__ function is not found in the module. """ - candidate, mro = None, [] - for cl in [x for x in dir(interface) if inspect.isclass(getattr(interface, x))]: # get all classes in module and go through them - if any(m.__name__ == 'Fmi2Slave' for m in inspect.getmro(getattr(interface, cl))): # inspect the class hierarchy and return if 'Fmi2Slave' found - if getattr(interface, cl) not in mro: # must be a sub-class of the already registered (or first) - candidate, mro = cl, inspect.getmro(getattr(interface, cl)) - return candidate + init: FunctionType = None + modulename = src.stem + module = importlib.import_module(modulename) + + # Find the __init__ function in the module + for name, obj in inspect.getmembers(model): + if inspect.isfunction(obj) and name == "__init__": + init = obj + break + + module_lines = inspect.getsourcelines(module) + + assert init is not None, f"__init__() function not found in module {src}, model {model}" + sig = inspect.signature(init) + pars = sig.parameters + newpars = [] + + # Replace the default values of the parameters with the new provided values + for p in pars: + if p in newargs: + par = pars[p].replace(default=newargs[p]) + else: + par = pars[p] + newpars.append(par) + signew = inspect.Signature(parameters=newpars) + + # Replace the signature of the __init__ function + init_line = inspect.getsourcelines(init)[1] + from_init = "".join(line for line in module_lines[0][init_line - 1 :]) + init_pos = from_init.find("__init__") + start, end = (match_par(from_init[init_pos - 1 :])[i] + init_pos for i in range(2)) + + from_init = from_init.replace(from_init[start - 1 : end], str(signew), 1) + module_code = "".join(line for line in module_lines[0][: init_line - 1]) + from_init + + return module_code -def get_model_description(filepath: Path, module_name: str) -> Tuple[str, Element]: +def get_model_description(filepath: Path, module_name: str, class_name: str) -> Tuple[str, Element]: """Extract the FMU model description as XML. Args: @@ -55,7 +172,6 @@ def get_model_description(filepath: Path, module_name: str) -> Tuple[str, Elemen fmu_interface = importlib.util.module_from_spec(spec) spec.loader.exec_module(fmu_interface) # Instantiate the interface - class_name = get_class_name(fmu_interface) instance = getattr(fmu_interface, class_name)(instance_name="dummyInstance", resources=str(filepath.parent)) finally: sys.path.remove(str(filepath.parent)) # remove inserted temporary path @@ -76,6 +192,7 @@ def build_FMU( dest: FilePath = ".", project_files: Iterable[FilePath] = set(), documentation_folder: Optional[FilePath] = None, + newargs: dict | None = None, **options, ) -> Path: script_file = Path(script_file) @@ -96,11 +213,23 @@ def build_FMU( f"The documentation folder does not exists {documentation_folder!s}" ) + if script_file.parent not in sys.path: + sys.path.insert(0, str(script_file.parent)) + module_name = script_file.stem + model_class = get_model_class(script_file) with tempfile.TemporaryDirectory(prefix="pythonfmu_") as tempd: temp_dir = Path(tempd) - shutil.copy2(script_file, temp_dir) + + if newargs: + model_file = temp_dir / f"{module_name}.py" + updated_code = update_model_parameters(script_file, model_class, newargs) + + # Write the updated code to a new file + model_file.write_text(updated_code) + else: + shutil.copy2(script_file, temp_dir) # Embed pythonfmu in the FMU so it does not need to be included dep_folder = temp_dir / "pythonfmu" @@ -129,7 +258,7 @@ def build_FMU( shutil.copy2(file_, temp_dir) model_identifier, xml = get_model_description( - temp_dir.absolute() / script_file.name, module_name + temp_dir.absolute() / script_file.name, module_name, model_class.__name__ ) dest_file = dest / f"{model_identifier}.fmu"