Skip to content

Commit

Permalink
Introduce possibility to pass class arguments to PythonFMU builder sc…
Browse files Browse the repository at this point in the history
…ript (#225)
  • Loading branch information
Jorgelmh authored Jan 10, 2025
1 parent 6e4ac1e commit 3c3b013
Showing 1 changed file with 144 additions and 15 deletions.
159 changes: 144 additions & 15 deletions pythonfmu/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import shutil
import sys
import tempfile
from types import FunctionType
import zipfile
import inspect
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union
from typing import Iterable, Literal, Optional, Tuple, Union
from xml.dom.minidom import parseString
from xml.etree.ElementTree import Element, SubElement, tostring
from .osutil import get_lib_extension, get_platform
Expand All @@ -20,24 +21,140 @@

logger = logging.getLogger(__name__)

def match_par(txt: str, left: str = "(", right: str = ")") -> tuple[int, Literal[-1]] | tuple[int, int]:
"""
Finds the position of the matching closing parenthesis for the first opening parenthesis in a given string.
Args:
txt (str): The input string to search within.
left (str, optional): The character representing the opening parenthesis. Defaults to "(".
right (str, optional): The character representing the closing parenthesis. Defaults to ")".
Returns:
tuple[int, Literal[-1]] | tuple[int, int]: A tuple containing the position of the first opening parenthesis
and the position of the matching closing parenthesis. If no matching closing parenthesis is found,
returns the position of the first opening parenthesis and -1.
Raises:
AssertionError: If the first opening parenthesis is not found in the input string.
"""

def get_class_name(interface) -> str:
"""Returns the name of the class derived from Fmi2Slave in the given interface module.
pos0 = txt.find(left, 0)
assert pos0 >= 0, f"First {left} not found"
stack = [pos0]
i = pos0
while True:
i += 1
if len(txt) <= i:
return (pos0, -1)
elif txt[i] == "#": # comment
i = txt.find("\n", i)
elif txt[i:].startswith(left):
stack.append(i)
elif txt[i:].startswith(right):
if len(stack) > 1:
stack.pop(-1)
else:
return (pos0, i)

def get_model_class(src: Path) -> Fmi2Slave:
"""
Given a source file path, dynamically import the module and find the class that
inherits from Fmi2Slave with the longest hierarchy.
Args:
interface: The module containing the classes to be inspected.
src (Path): The path to the source file containing the module.
Returns:
Fmi2Slave: The class that inherits from Fmi2Slave with the longest hierarchy.
Raises:
ValueError: If no class inheriting from Fmi2Slave is found in the module.
ValueError: If multiple classes with the same hierarchy length are found.
"""

modulename = src.stem
module = importlib.import_module(modulename)

assert inspect.ismodule(module)

modelclasses: dict[Fmi2Slave, int] = {}

# get all classes in the module and store them in a dict with their hierarchy length
for _, obj in inspect.getmembers(module):
if inspect.isclass(obj):
mro = inspect.getmro(obj)
if Fmi2Slave in mro and not inspect.isabstract(obj):
# store the class and its hierarchy length
modelclasses.update({obj: len(mro)})

if not len(modelclasses):
raise ValueError(f"No child class of Fmi2Slave found in module {src}") from None
else:
# Return the class with the longest hierarchy, if no unique class is found, raise errors
maxlen = max(n for n in modelclasses.values())
classes = [c for c, n in modelclasses.items() if n == maxlen]
if not len(classes):
raise ValueError(f"No child class of Fmi2Slave found in module {src}") from None
elif len(classes) > 1:
raise ValueError(f"Non-unique Fmi2Slave-derived class in module {src}. Found {classes}.") from None
else:
return classes[0]

def update_model_parameters(src: Path, model: Fmi2Slave, newargs: dict) -> str:
"""
Update the model parameters in the __init__ function of a given module.
This function modifies the default values of the parameters in the __init__
function of the specified model with the new values provided in the newargs
dictionary. It returns the updated module code as a string.
Args:
src (Path): The path to the source file containing the module.
model (Fmi2Slave): The model object whose __init__ function parameters
need to be updated.
newargs (dict): A dictionary containing the new parameter values. The
keys should be the parameter names and the values should
be the new default values.
Returns:
str: The name of the class derived from Fmi2Slave, or None if no such class is found.
str: The updated module code as a string.
Raises:
AssertionError: If the __init__ function is not found in the module.
"""
candidate, mro = None, []
for cl in [x for x in dir(interface) if inspect.isclass(getattr(interface, x))]: # get all classes in module and go through them
if any(m.__name__ == 'Fmi2Slave' for m in inspect.getmro(getattr(interface, cl))): # inspect the class hierarchy and return if 'Fmi2Slave' found
if getattr(interface, cl) not in mro: # must be a sub-class of the already registered (or first)
candidate, mro = cl, inspect.getmro(getattr(interface, cl))

return candidate
init: FunctionType = None
modulename = src.stem
module = importlib.import_module(modulename)

# Find the __init__ function in the module
for name, obj in inspect.getmembers(model):
if inspect.isfunction(obj) and name == "__init__":
init = obj
break

module_lines = inspect.getsourcelines(module)

assert init is not None, f"__init__() function not found in module {src}, model {model}"
sig = inspect.signature(init)
pars = sig.parameters
newpars = []

# Replace the default values of the parameters with the new provided values
for p in pars:
if p in newargs:
par = pars[p].replace(default=newargs[p])
else:
par = pars[p]
newpars.append(par)
signew = inspect.Signature(parameters=newpars)

# Replace the signature of the __init__ function
init_line = inspect.getsourcelines(init)[1]
from_init = "".join(line for line in module_lines[0][init_line - 1 :])
init_pos = from_init.find("__init__")
start, end = (match_par(from_init[init_pos - 1 :])[i] + init_pos for i in range(2))

from_init = from_init.replace(from_init[start - 1 : end], str(signew), 1)
module_code = "".join(line for line in module_lines[0][: init_line - 1]) + from_init

return module_code

def get_model_description(filepath: Path, module_name: str) -> Tuple[str, Element]:
def get_model_description(filepath: Path, module_name: str, class_name: str) -> Tuple[str, Element]:
"""Extract the FMU model description as XML.
Args:
Expand All @@ -55,7 +172,6 @@ def get_model_description(filepath: Path, module_name: str) -> Tuple[str, Elemen
fmu_interface = importlib.util.module_from_spec(spec)
spec.loader.exec_module(fmu_interface)
# Instantiate the interface
class_name = get_class_name(fmu_interface)
instance = getattr(fmu_interface, class_name)(instance_name="dummyInstance", resources=str(filepath.parent))
finally:
sys.path.remove(str(filepath.parent)) # remove inserted temporary path
Expand All @@ -76,6 +192,7 @@ def build_FMU(
dest: FilePath = ".",
project_files: Iterable[FilePath] = set(),
documentation_folder: Optional[FilePath] = None,
newargs: dict | None = None,
**options,
) -> Path:
script_file = Path(script_file)
Expand All @@ -96,11 +213,23 @@ def build_FMU(
f"The documentation folder does not exists {documentation_folder!s}"
)

if script_file.parent not in sys.path:
sys.path.insert(0, str(script_file.parent))

module_name = script_file.stem
model_class = get_model_class(script_file)

with tempfile.TemporaryDirectory(prefix="pythonfmu_") as tempd:
temp_dir = Path(tempd)
shutil.copy2(script_file, temp_dir)

if newargs:
model_file = temp_dir / f"{module_name}.py"
updated_code = update_model_parameters(script_file, model_class, newargs)

# Write the updated code to a new file
model_file.write_text(updated_code)
else:
shutil.copy2(script_file, temp_dir)

# Embed pythonfmu in the FMU so it does not need to be included
dep_folder = temp_dir / "pythonfmu"
Expand Down Expand Up @@ -129,7 +258,7 @@ def build_FMU(
shutil.copy2(file_, temp_dir)

model_identifier, xml = get_model_description(
temp_dir.absolute() / script_file.name, module_name
temp_dir.absolute() / script_file.name, module_name, model_class.__name__
)

dest_file = dest / f"{model_identifier}.fmu"
Expand Down

0 comments on commit 3c3b013

Please sign in to comment.