"""Command-line entry point for upgrading BOUT++ source and input files.

Builds an ``argparse`` tree with ``v4``/``v5`` sub-commands, each of which
delegates to one of the ``bout_*_upgrader`` modules in this package.
"""

import argparse
from importlib.metadata import PackageNotFoundError, version

from .bout_3to4 import add_parser as add_3to4_parser
from .bout_v5_factory_upgrader import add_parser as add_factory_parser
from .bout_v5_format_upgrader import add_parser as add_format_parser
from .bout_v5_header_upgrader import add_parser as add_header_parser
from .bout_v5_input_file_upgrader import add_parser as add_input_parser
from .bout_v5_macro_upgrader import add_parser as add_macro_parser
from .bout_v5_physics_model_upgrader import add_parser as add_model_parser
from .bout_v5_xzinterpolation_upgrader import add_parser as add_xzinterp_parser

try:
    # This gives the version if the boutdata package was installed
    # (bout-upgrader is distributed as part of boutdata)
    __version__ = version("boutdata")
except PackageNotFoundError:
    # This branch handles the case when boutdata is used from the git repo
    try:
        from setuptools_scm import get_version

        __version__ = get_version(root="..", relative_to=__file__)
    except (ModuleNotFoundError, LookupError):
        __version__ = "dev"


def main():
    """Parse command-line arguments and dispatch to the chosen upgrader.

    Each sub-command registers a ``func`` default via ``set_defaults``;
    after parsing we simply call ``args.func(args)``.
    """
    # Parent parser that has arguments common to all subcommands
    common_args = argparse.ArgumentParser(add_help=False)
    common_args.add_argument(
        "--quiet", "-q", action="store_true", help="Don't print patches"
    )
    # --force and --patch-only are mutually exclusive: either apply without
    # asking, or only display what would change
    force_or_patch_group = common_args.add_mutually_exclusive_group()
    force_or_patch_group.add_argument(
        "--force", "-f", action="store_true", help="Make changes without asking"
    )
    force_or_patch_group.add_argument(
        "--patch-only", "-p", action="store_true", help="Print the patches and exit"
    )

    # Parent parser for commands that always take a list of files
    files_args = argparse.ArgumentParser(add_help=False)
    files_args.add_argument("files", action="store", nargs="+", help="Input files")

    parser = argparse.ArgumentParser(
        description="Upgrade BOUT++ source and input files to newer versions"
    )
    parser.add_argument(
        "--version", action="version", version=f"%(prog)s {__version__}"
    )
    subcommand = parser.add_subparsers(title="subcommands", required=True)

    v4_subcommand = subcommand.add_parser(
        "v4", help="BOUT++ v4 upgrades"
    ).add_subparsers(title="v4 subcommands", required=True)
    add_3to4_parser(v4_subcommand, common_args, files_args)

    v5_subcommand = subcommand.add_parser(
        "v5", help="BOUT++ v5 upgrades"
    ).add_subparsers(title="v5 subcommands", required=True)

    add_factory_parser(v5_subcommand, common_args, files_args)
    add_format_parser(v5_subcommand, common_args, files_args)
    # The header upgrader manages its own file arguments (mutually exclusive
    # with --move-deprecated-headers), so files_args is not passed
    add_header_parser(v5_subcommand, common_args)
    add_input_parser(v5_subcommand, common_args, files_args)
    add_macro_parser(v5_subcommand, common_args, files_args)
    add_model_parser(v5_subcommand, common_args, files_args)
    add_xzinterp_parser(v5_subcommand, common_args, files_args)

    args = parser.parse_args()
    args.func(args)
#!/usr/bin/env python3
"""Helpers for upgrading BOUT++ v3 source files to v4.

Each ``fix_*`` function takes one source line and either prints a suggested
replacement (``replace=False``) or returns the fixed line (``replace=True``).
"""

import fileinput
import re
import sys

# Member functions that became free functions in v4: name -> [new name, arity]
nonmembers = {
    "DC": ["DC", 1],
    "slice": ["sliceXZ", 2],
}

# Members that moved from Mesh to Coordinates in v4
coordinates = [
    "outputVars",
    "dx",
    "dy",
    "dz",
    "non_uniform",
    "d1_dx",
    "d1_dy",
    "J",
    "Bxy",
    "g11",
    "g22",
    "g33",
    "g12",
    "g13",
    "g23",
    "g_11",
    "g_22",
    "g_33",
    "g_12",
    "g_13",
    "g_23",
    "G1_11",
    "G1_22",
    "G1_33",
    "G1_12",
    "G1_13",
    "G2_11",
    "G2_22",
    "G2_33",
    "G2_12",
    "G2_23",
    "G3_11",
    "G3_22",
    "G3_33",
    "G3_13",
    "G3_23",
    "G1",
    "G2",
    "G3",
    "ShiftTorsion",
    "IntShiftTorsion",
    "geometry",
    "calcCovariant",
    "calcContravariant",
    "jacobian",
]

# Mesh size members renamed in v4: (old, new)
local_mesh = [
    ("ngx", "LocalNx"),
    ("ngy", "LocalNy"),
]

# Patterns that cannot be fixed automatically: (regex, explanation)
warnings = [
    (r"\^", "Use pow(a,b) instead of a^b"),
    (r"\.max\(", "Use max(a) instead of a.max()"),
    (
        "ngz",
        (
            "ngz is changed to LocalNz in v4."
            " The extra point in z has been removed."
            " Change ngz -> LocalNz, and ensure that"
            " the number of points are correct"
        ),
    ),
]


def fix_nonmembers(line_text, filename, line_num, replace=False):
    """Replace member functions with nonmembers"""

    old_line_text = line_text

    for old, (new, num_args) in nonmembers.items():
        pattern = re.compile(rf"(\w*)\.{old}\(")
        matches = re.findall(pattern, line_text)
        for match in matches:
            replacement = f"{new}({match}"
            if num_args > 1:
                replacement += ", "
            line_text = re.sub(pattern, replacement, line_text)
            if not replace:
                # BUG FIX: previously printed "(unknown)" instead of the
                # actual filename passed in
                name_num = f"{filename}:{line_num}:"
                print(f"{name_num}{old_line_text}", end="")
                print(" " * len(name_num) + line_text)
    if replace:
        return line_text


def fix_subscripts(line_text, filename, line_num, replace=False):
    """Replace triple square brackets with round brackets

    Should also check that the variable is a Field3D/Field2D - but doesn't
    """

    old_line_text = line_text
    # Catch both 2D and 3D arrays
    pattern = re.compile(r"\[([^[]*)\]\[([^[]*)\](?:\[([^[]*)\])?")
    matches = re.findall(pattern, line_text)
    for match in matches:
        # If the last group is non-empty, then it was a 3D array
        if len(match[2]):
            replacement = r"(\1, \2, \3)"
        else:
            replacement = r"(\1, \2)"
        line_text = re.sub(pattern, replacement, line_text)
        if not replace:
            name_num = f"{filename}:{line_num}:"
            print(f"{name_num}{old_line_text}", end="")
            print(" " * len(name_num) + line_text)
    if replace:
        return line_text


def fix_coordinates(line_text, filename, line_num, replace=False):
    """Fix variables that have moved from mesh to coordinates"""

    old_line_text = line_text

    for var in coordinates:
        pattern = re.compile(f"mesh->{var}")
        matches = re.findall(pattern, line_text)
        for match in matches:
            line_text = re.sub(pattern, f"mesh->coordinates()->{var}", line_text)
            if not replace:
                name_num = f"{filename}:{line_num}:"
                print(f"{name_num}{old_line_text}", end="")
                print(" " * len(name_num) + line_text)
    if replace:
        return line_text


def fix_local_mesh_size(line_text, filename, line_num, replace=False):
    """Replaces ng@ with LocalNg@, where @ is in {x,y,z}"""

    old_line_text = line_text

    for old, new in local_mesh:
        pattern = re.compile(old)
        matches = re.findall(pattern, line_text)
        for match in matches:
            line_text = re.sub(pattern, new, line_text)
            if not replace:
                name_num = f"{filename}:{line_num}:"
                print(f"{name_num}{old_line_text}", end="")
                print(" " * len(name_num) + line_text)
    if replace:
        return line_text


def throw_warnings(line_text, filename, line_num):
    """Throws a warning for ^, .max() and ngz"""

    for pattern_text, message in warnings:
        pattern = re.compile(pattern_text)
        matches = re.findall(pattern, line_text)
        for match in matches:
            name_num = f"{filename}:{line_num}:"
            # stdout is redirected to the file if --replace is given,
            # therefore use stderr
            sys.stderr.write(f"{name_num}{line_text}")
            # Coloring with \033[91m, end coloring with \033[0m\n
            sys.stderr.write(
                " " * len(name_num) + f"\033[91m!!!WARNING: {message}\033[0m\n\n"
            )


def add_parser(subcommand, default_args, files_args):
    """Register the ``3to4`` sub-command on `subcommand`."""
    epilog = """
    Currently bout_3to4 can detect the following transformations are needed:
    - Triple square brackets instead of round brackets for subscripts
    - Field member functions that are now non-members
    - Variables/functions that have moved from Mesh to Coordinates

    Note that in the latter case of transformations, you will still need to manually add
    Coordinates *coords = mesh->coordinates();
    to the correct scopes
    """

    parser = subcommand.add_parser(
        "3to4",
        help="A little helper for upgrading from BOUT++ version 3 to version 4",
        description="A little helper for upgrading from BOUT++ version 3 to version 4",
        parents=[default_args, files_args],
        epilog=epilog,
    )
    # BUG FIX: run() reads args.replace, but no such argument was defined,
    # giving an AttributeError at runtime
    parser.add_argument(
        "--replace",
        "-r",
        action="store_true",
        help="Modify the files in place instead of printing the patches",
    )
    parser.set_defaults(func=run)


def run(args):
    """Apply (or display) all v3->v4 fixes to every line of every input file."""
    # Loops over all lines across all files; with --replace, stdout is
    # redirected into the file being edited by fileinput
    for line in fileinput.input(files=args.files, inplace=args.replace):
        filename = fileinput.filename()
        line_num = fileinput.filelineno()

        # Apply the transformations and then update the line if we're doing a replacement
        new_line = fix_nonmembers(line, filename, line_num, args.replace)
        line = new_line if args.replace else line

        new_line = fix_subscripts(line, filename, line_num, args.replace)
        line = new_line if args.replace else line

        new_line = fix_coordinates(line, filename, line_num, args.replace)
        line = new_line if args.replace else line

        new_line = fix_local_mesh_size(line, filename, line_num, args.replace)
        line = new_line if args.replace else line

        # throw_warnings only reports (returns None), so nothing to assign
        throw_warnings(line, filename, line_num)

        # If we're doing a replacement, then we need to print all lines, without a newline
        if args.replace:
            print(line, end="")
"""Upgrade BOUT++ v4 factory usage (raw pointers, old ``Create`` names) to v5."""

import re

# Dictionary of factory methods that may need updating
factories = {
    "Interpolation": {
        "factory_name": "InterpolationFactory",
        "type_name": "Interpolation",
        "create_method": "create",
    },
    "InvertPar": {
        "factory_name": "InvertPar",
        "type_name": "InvertPar",
        "create_method": "create",
        "old_create_method": "Create",
        "arguments_changed": True,
    },
    "Mesh": {"factory_name": "Mesh", "type_name": "Mesh", "create_method": "Create"},
    "Laplacian": {
        "factory_name": "Laplacian",
        "type_name": "Laplacian",
        "create_method": "create",
    },
    "LaplaceXZ": {
        "factory_name": "LaplaceXZ",
        "type_name": "LaplaceXZ",
        "create_method": "create",
    },
    "SolverFactory": {
        "factory_name": "SolverFactory",
        "type_name": "Solver",
        "create_method": "createSolver",
    },
    "Solver": {
        "factory_name": "Solver",
        "type_name": "Solver",
        "create_method": "create",
    },
}


def find_factory_calls(factory, source):
    """Find all the places where the factory creation method is called,
    and return a list of the variable names

    Parameters
    ----------
    factory
        Dictionary containing 'factory_name' and 'create_method'
    source
        Text to search

    """
    return re.findall(
        r"""
        \s*([\w_]+)      # variable
        \s*=\s*
        {factory_name}::
        .*{create_method}.*
        """.format(
            **factory
        ),
        source,
        re.VERBOSE,
    )


def find_type_pointers(factory, source):
    """Return names of variables declared as raw pointers to the factory type."""
    return re.findall(
        r"""
        \b{type_name}\s*\*\s*    # Type name and pointer
        ([\w_]+)\s*;             # Variable name
        """.format(
            **factory
        ),
        source,
        re.VERBOSE,
    )


def fix_declarations(factory, variables, source):
    """Fix the declaration of variables in source. Returns modified source

    Replaces `Type*` with either `std::unique_ptr` for
    declarations, or with `auto` for initialisations.

    Parameters
    ----------
    factory
        Dictionary of factory information
    variables
        List of variable names
    source
        Text to update

    """

    for variable in variables:
        # Declarations
        source = re.sub(
            r"""
            (.*?)(class\s*)?         # optional "class" keyword
            \b({type_name})\s*\*\s*  # Type-pointer
            ({variable_name})\s*;    # Variable
            """.format(
                type_name=factory["type_name"], variable_name=variable
            ),
            r"\1std::unique_ptr<\3> \4{nullptr};",
            source,
            flags=re.VERBOSE,
        )

        # Declarations with initialisation from factory
        source = re.sub(
            r"""
            (.*?)(class\s*)?       # optional "class" keyword
            ({type_name})\s*\*\s*  # Type-pointer
            ({variable_name})\s*   # Variable
            =\s*                   # Assignment from factory
            ({factory_name}::.*{create_method}.*);
            """.format(
                variable_name=variable, **factory
            ),
            r"\1auto \4 = \5;",
            source,
            flags=re.VERBOSE,
        )

        # Declarations with zero initialisation
        source = re.sub(
            r"""
            (.*?)(?:class\s*)?     # optional "class" keyword
            ({type_name})\s*\*\s*  # Type-pointer
            ({variable_name})\s*   # Variable
            =\s*                   # Assignment
            (0|nullptr|NULL);
            """.format(
                variable_name=variable, **factory
            ),
            r"\1std::unique_ptr<\2> \3{nullptr};",
            source,
            flags=re.VERBOSE,
        )

    return source


def fix_deletions(variables, source):
    """Remove `delete` statements of variables. Returns modified source

    Parameters
    ----------
    variables
        List of variable names
    source
        Text to update

    """

    for variable in variables:
        source = re.sub(
            rf"(.*;?)\s*(delete\s*{variable})\s*;",
            r"\1",
            source,
        )

    return source


def fix_create_method(factory, source):
    """Fix change of name of factory `create` method"""

    if "old_create_method" not in factory:
        return source
    old_create_pattern = re.compile(
        r"({factory_name})\s*::\s*{old_create_method}\b".format(**factory)
    )
    if not old_create_pattern.findall(source):
        return source

    if factory.get("arguments_changed", False):
        print(
            "**WARNING** Arguments of {factory_name}::{create_method} have changed, and your current arguments may not work."
            " Please consult the documentation for the new arguments.".format(**factory)
        )

    # Reuse the compiled pattern rather than recompiling the identical regex
    return old_create_pattern.sub(
        r"\1::{create_method}".format(**factory),
        source,
    )


def apply_fixes(factories, source, all_declarations=False):
    """Apply the various fixes for each factory to source. Returns
    modified source

    Parameters
    ----------
    factories
        Dictionary of factory properties
    source
        Text to update
    """

    modified = source

    for factory in factories.values():
        modified = fix_create_method(factory, modified)
        variables = find_factory_calls(factory, modified)
        if all_declarations:
            variables = variables + find_type_pointers(factory, modified)
        modified = fix_declarations(factory, variables, modified)
        modified = fix_deletions(variables, modified)

    return modified


def add_parser(subcommand, default_args, files_args):
    """Register the ``factory`` sub-command on `subcommand`."""
    factory_help = "Fix types of factory-created objects"
    parser = subcommand.add_parser(
        "factory",
        help=factory_help,
        description=factory_help,
        parents=[default_args, files_args],
    )
    parser.add_argument(
        "--all-declarations",
        "-a",
        action="store_true",
        help="Fix all declarations of factory types, not just variables created from factories",
    )
    parser.set_defaults(func=run)


def run(args):
    """Apply the factory fixes to each input file, then show/apply the patch."""
    # Imported lazily so the pure fix functions above can be used without
    # the rest of the package
    from .common import apply_or_display_patch

    for filename in args.files:
        with open(filename) as f:
            contents = f.read()
        # str is immutable, so a reference is enough to keep the original
        original = contents

        modified = apply_fixes(
            factories, contents, all_declarations=args.all_declarations
        )

        apply_or_display_patch(
            filename, original, modified, args.patch_only, args.quiet, args.force
        )
"""Upgrade printf-style format strings in BOUT++ source to fmt-style ``{}``."""

import re

# printf conversion specifier -> fmt presentation type
format_replacements = {
    "c": "c",
    "d": "d",
    "e": "e",
    "f": "f",
    "g": "g",
    "i": "d",
    "ld": "d",
    "le": "e",
    "lu": "d",
    "p": "p",
    "s": "s",
    "zu": "d",
}


def fix_format_replacement(format_replacement, source):
    """Replace printf format with fmt format"""
    # Keep any width/precision (e.g. "%5.2f" -> "{:5.2f}")
    return re.sub(
        rf"%([0-9]*\.?[0-9]*){format_replacement[0]}",
        rf"{{:\1{format_replacement[1]}}}",
        source,
    )


def fix_trivial_format(source):
    """Reduce trivial formatting of strings to just the string"""

    def trivial_replace(match):
        if match.group(2):
            return f"{match.group(1)}{match.group(2)}{match.group(4)}"
        if match.group(3):
            return f"{match.group(1)}{match.group(3)}{match.group(4)}"
        raise ValueError(f"Found an unexpected match: {match}")

    return re.sub(
        r"""
        (.*)?
        "{:s}",\s*                 # Entire format is just a string
        (?:([\w_]+)\.c_str\(\)     # And replacement is std::string::c_str
        |(".*?"))
        (.*)?
        """,
        trivial_replace,
        source,
        flags=re.VERBOSE,
    )


def fix_string_c_str(source):
    """Fix formats that use {:s} where the replacement is using std::string::c_str"""
    return re.sub(
        r"""
        (".*{:s}[^;]*?",)       # A format string containing {:s}
        \s*([^;]+?)\.c_str\(\)  # Replacement of std::string::c_str
        """,
        r"\1 \2",
        source,
        flags=re.VERBOSE,
    )


def fix_trace(source):
    """Fix TRACE macros where fix_string_c_str has failed for some reason"""
    return re.sub(
        r"""
        (TRACE\(".*{:s}.*",)
        \s*([\w_]+)\.c_str\(\)\);  # Replacement of std::string::c_str
        """,
        r"\1 \2);",
        source,
        flags=re.VERBOSE,
    )


def fix_toString_c_str(source):
    """Fix formats that call toString where the replacement is using std::string::c_str

    NOTE(review): this helper is currently not called from `apply_fixes`;
    `fix_string_c_str` handles most toString cases. Confirm whether it
    should be added to the pipeline.
    """
    return re.sub(
        r"""
        (".*{:s}[^;]*?",.*?)          # A format string containing {:s}
        (toString\(.*?\))\.c_str\(\)  # Replacement of std::string::c_str
        """,
        r"\1\2",
        source,
        flags=re.VERBOSE,
    )


def apply_fixes(format_replacements, source):
    """Apply the format-string fixes to source. Returns modified source

    Parameters
    ----------
    format_replacements
        Mapping of printf conversion specifiers to fmt presentation types
    source
        Text to update
    """

    modified = source

    for format_replacement in format_replacements.items():
        modified = fix_format_replacement(format_replacement, modified)

    modified = fix_trivial_format(modified)
    modified = fix_string_c_str(modified)
    modified = fix_trace(modified)

    return modified


def add_parser(subcommand, default_args, files_args):
    """Register the ``format`` sub-command on `subcommand`."""
    format_help = "Fix format specifiers"
    parser = subcommand.add_parser(
        "format",
        description=format_help,
        help=format_help,
        parents=[default_args, files_args],
    )
    parser.set_defaults(func=run)


def run(args):
    """Apply the format fixes to each input file, then show/apply the patch."""
    # Imported lazily so the pure fix functions above can be used without
    # the rest of the package
    from .common import apply_or_display_patch

    for filename in args.files:
        with open(filename) as f:
            contents = f.read()
        # str is immutable, so a reference is enough to keep the original
        original = contents

        modified = apply_fixes(format_replacements, contents)

        apply_or_display_patch(
            filename, original, modified, args.patch_only, args.quiet, args.force
        )
"""Fix deprecated BOUT++ header locations (``include/*.hxx`` -> ``include/bout/*.hxx``)."""

import argparse
import re
import subprocess
import textwrap
from pathlib import Path
from typing import List

header_shim_sentinel = "// BOUT++ header shim"

header_warning = f"""\
#pragma once
{header_shim_sentinel}
#warning Header "{{0}}" has moved to "bout/{{0}}". Run `bout-upgrader header` to fix
#include "bout/{{0}}"
"""


def header_needs_moving(header: Path) -> bool:
    """Check if `header` has not yet been moved"""
    with open(header) as f:
        return header_shim_sentinel not in f.read()


def deprecated_header_list(include_path: Path = Path("./include")):
    """List of deprecated header paths (that is, those in bare
    ``include`` directory)

    """
    return include_path.glob("*.hxx")


def write_header_shim(header: Path):
    """Write 'shim' for header, that ``include``s new location"""
    with open(header, "w") as f:
        f.write(header_warning.format(header.name))


def fix_library_header_locations(
    include_path: Path = Path("./include"), quiet: bool = False
):
    """Move deprecated headers under ``include/bout`` (via ``git mv`` to
    preserve history) and leave compatibility shims in the old locations.

    Creates two git commits; raises RuntimeError if the index is dirty.
    """
    unmoved_headers = list(
        filter(header_needs_moving, deprecated_header_list(include_path))
    )
    include_bout_path = include_path / "bout"

    if unmoved_headers == []:
        print("No headers to move!")
        return

    out = subprocess.run("git diff-index --cached HEAD --quiet", shell=True)
    if out.returncode:
        raise RuntimeError(
            "git index not clean! Please commit or discard any changes before continuing"
        )

    # First we move any existing headers and commit this change, so
    # that history is preserved
    for header in unmoved_headers:
        new_path = include_bout_path / header.name
        if not quiet:
            print(f"Moving '{header}' to '{new_path}'")
        # BUG FIX: these four calls used a bare `run(...)`, which is not
        # defined anywhere -- they must be subprocess.run
        subprocess.run(f"git mv {header} {new_path}", shell=True, check=True)

    subprocess.run(
        r"git commit -m 'Move headers to `include/bout`'", shell=True, check=True
    )

    # Now we can write the compatibility shim
    for header in unmoved_headers:
        write_header_shim(header)

    subprocess.run(
        f"git add {' '.join(map(str, unmoved_headers))}", shell=True, check=True
    )
    subprocess.run(
        r"git commit -m 'Add compatibility shims for old header locations'",
        shell=True,
        check=True,
    )


def make_header_regex(deprecated_headers: List[Path]) -> re.Pattern:
    """Create a regular expression to match deprecated header locations"""
    deprecated_header_joined = "|".join(header.name for header in deprecated_headers)
    # Matches `#include <header>` and `#include "header"`, with an optional
    # leading "../"
    return re.compile(rf'(#include\s+<|")(?:\.\./)?({deprecated_header_joined})(>|")')


def apply_fixes(header_regex, source):
    """Apply all fixes in this module"""
    # re.sub does not mutate its input, so no defensive copy is needed
    return header_regex.sub(r"\1bout/\2\3", source)


def add_parser(subcommand, default_args):
    """Register the ``header`` sub-command on `subcommand`."""
    parser = subcommand.add_parser(
        "header",
        help="Fix deprecated header locations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=textwrap.dedent(
            """\
            Fix deprecated header locations for BOUT++ v4 -> v5

            All BOUT++ headers are now under ``include/bout`` and
            should be included as ``#include ``. This
            tool will fix such includes.

            For developers: the option ``--move-deprecated-headers``
            will move the headers from ``include`` to
            ``include/bout``, and add a compatibility shim in the old
            location. This option is mutually exclusive with
            ``--files``, and should be used after running this tool
            over the library files.

            WARNING: If any files do need moving, this will create a
            git commit in order to preserve history across file moves.
            If you have staged changes, this tool will not work, to
            avoid committing undesired or unrelated changes.

            """
        ),
        parents=[default_args],
    )

    parser.add_argument(
        "--include-path",
        "-i",
        help="Path to `include` directory",
        default="./include",
        type=Path,
    )

    header_group = parser.add_mutually_exclusive_group(required=True)
    header_group.add_argument("--files", nargs="*", action="store", help="Input files")
    header_group.add_argument(
        "--move-deprecated-headers",
        action="store_true",
        help="Move the deprecated headers",
    )
    parser.set_defaults(func=run)


def run(args):
    """Either move the deprecated headers, or fix includes in the given files."""
    # Imported lazily so the pure fix functions above can be used without
    # the rest of the package
    from .common import apply_or_display_patch

    deprecated_headers = deprecated_header_list(args.include_path)

    if args.move_deprecated_headers:
        fix_library_header_locations(args.include_path, args.quiet)
        return

    header_regex = make_header_regex(deprecated_headers)

    for filename in args.files:
        with open(filename) as f:
            contents = f.read()
        # str is immutable, so a reference is enough to keep the original
        original = contents

        modified = apply_fixes(header_regex, contents)

        apply_or_display_patch(
            filename, original, modified, args.patch_only, args.quiet, args.force
        )
"""Upgrade BOUT++ input (BOUT.inp) files to v5 option names and values."""

import argparse
import copy
import itertools
import textwrap
import warnings


def case_sensitive_init(self, name="root", parent=None):
    """Replacement ``BoutOptions.__init__`` using plain (case-sensitive) dicts."""
    self._sections = dict()
    self._keys = dict()
    self._name = name
    self._parent = parent
    self.comments = dict()
    self.inline_comments = dict()
    self._comment_whitespace = dict()


# This should be a list of dicts, each containing "old", "new" and optionally
# "new_values". The values of "old"/"new" keys should be the old/new names of
# input file values or sections. The value of "new_values" is a dict containing
# replacements for values of the option. "old_type" optionally specifies the
# type of the old value of the option; for example this is needed for special
# handling of boolean values.
REPLACEMENTS = [
    {"old": "mesh:paralleltransform", "new": "mesh:paralleltransform:type"},
    {"old": "fci", "new": "mesh:paralleltransform"},
    {"old": "interpolation", "new": "mesh:paralleltransform:xzinterpolation"},
    {
        "old": "fft:fft_measure",
        "new": "fft:fft_measurement_flag",
        "old_type": bool,
        "new_values": {False: "estimate", True: "measure"},
    },
    {"old": "TIMESTEP", "new": "timestep"},
    {"old": "NOUT", "new": "nout"},
    {"old": "ddx", "new": "mesh:ddx"},
    {"old": "ddy", "new": "mesh:ddy"},
    {"old": "ddz", "new": "mesh:ddz"},
    {"old": "laplace:laplace_nonuniform", "new": "laplace:nonuniform"},
    {"old": "mesh:dump_format", "new": "dump_format"},
    {"old": "solver:ATOL", "new": "solver:atol"},
    {"old": "solver:RTOL", "new": "solver:rtol"},
    # This was inconsistent in the library
    {"old": "All", "new": "all"},
    # The following haven't been changed, but are frequently spelt with the wrong case
    {"old": "mxg", "new": "MXG"},
    {"old": "myg", "new": "MYG"},
    {"old": "nxpe", "new": "NXPE"},
    {"old": "nype", "new": "NYPE"},
    {"old": "mesh:NX", "new": "mesh:nx"},
    {"old": "mesh:NY", "new": "mesh:ny"},
    {"old": "mesh:shiftangle", "new": "mesh:ShiftAngle"},
    {"old": "mesh:shiftAngle", "new": "mesh:ShiftAngle"},
    {"old": "mesh:zshift", "new": "mesh:zShift"},
    {"old": "mesh:StaggerGrids", "new": "mesh:staggergrids"},
    {"old": "output:shiftOutput", "new": "output:shiftoutput"},
    {"old": "output:ShiftOutput", "new": "output:shiftoutput"},
    {"old": "output:shiftInput", "new": "output:shiftinput"},
    {"old": "output:ShiftInput", "new": "output:shiftinput"},
    {"old": "output:flushFrequency", "new": "output:flushfrequency"},
    {"old": "output:FlushFrequency", "new": "output:flushfrequency"},
    {"old": "TwistShift", "new": "twistshift"},
    {"old": "zmin", "new": "ZMIN"},
    {"old": "zmax", "new": "ZMAX"},
    {"old": "ZPERIOD", "new": "zperiod"},
    # 'restart' can be either a section or a value, so move all the
    # section:values instead
    {"old": "restart:parallel", "new": "restart_files:parallel"},
    {"old": "restart:flush", "new": "restart_files:flush"},
    {"old": "restart:guards", "new": "restart_files:guards"},
    {"old": "restart:floats", "new": "restart_files:floats"},
    {"old": "restart:openclose", "new": "restart_files:openclose"},
    {"old": "restart:enabled", "new": "restart_files:enabled"},
    {"old": "restart:init_missing", "new": "restart_files:init_missing"},
    {"old": "restart:shiftOutput", "new": "restart_files:shiftOutput"},
    {"old": "restart:shiftInput", "new": "restart_files:shiftInput"},
    {"old": "restart:flushFrequency", "new": "restart_files:flushFrequency"},
]

# Derivative scheme names were lower-cased in v5
for section, derivative in itertools.product(
    ["ddx", "ddy", "ddz", "diff"], ["First", "Second", "Fourth", "Flux", "Upwind"]
):
    REPLACEMENTS.append(
        {
            "old": f"mesh:{section}:{derivative}",
            "new": f"mesh:{section}:{derivative.lower()}",
        }
    )

DELETED = ["dump_format"]

for section, value in itertools.product(
    ["output", "restart"],
    [
        "floats",
        # Following are not yet implemented in OptionsNetCDF. Not yet
        # clear if they need to be, or can be safely removed
        # "shiftoutput",
        # "shiftinput",
        # "flushfrequency",
        # "parallel",
        # "guards",
        # "openclose",
        # "init_missing",
    ],
):
    DELETED.append(f"{section}:{value}")


def parse_bool(bool_expression):
    """Parse a BOUT++-style boolean option value to a Python bool.

    Raises RuntimeError if the value is not recognisably boolean.
    """
    try:
        bool_expression_lower = bool_expression.lower()
    except AttributeError:
        # bool_expression was not a str: no need to lower
        bool_expression_lower = bool_expression

    if bool_expression_lower in ["true", "y", "t", 1, True]:
        return True
    elif bool_expression_lower in ["false", "n", "f", 0, False]:
        return False
    else:
        raise RuntimeError(
            f"Expected boolean option. Could not parse {bool_expression}"
        )


def already_fixed(replacement, options_file):
    """Check if the options_file already has already had this particular fix applied"""
    # The old key is there and the new one isn't, then it's definitely not fixed
    if replacement["old"] in options_file and replacement["new"] not in options_file:
        return False
    # If the new isn't there, there's nothing to fix
    if replacement["new"] not in options_file:
        return True
    # If we don't need to fix values, we're done
    if "new_values" not in replacement:
        return True
    # Check if the current value is acceptable
    return options_file[replacement["new"]] in replacement["new_values"].values()


def fix_replacements(replacements, options_file):
    """Change the names of options in options_file according to the list
    of dicts replacements

    """
    for replacement in replacements:
        try:
            if already_fixed(replacement, options_file):
                continue
            options_file.rename(replacement["old"], replacement["new"])
        except KeyError:
            pass
        except TypeError as e:
            raise RuntimeError(
                "Could not apply transformation: '{old}' -> '{new}' to file '{0}', due to error:"
                "\n\t{1}".format(options_file.filename, e.args[0], **replacement)
            ) from e
        else:
            if "old_type" in replacement:
                # Special handling for certain types, replicating what BOUT++ does
                if replacement["old_type"] is bool:
                    # The original value must be something that BOUT++ recognises as a
                    # bool.
                    # replacement["new_values"] must contain both True and False keys.
                    old_value = parse_bool(options_file[replacement["new"]])
                    options_file[replacement["new"]] = replacement["new_values"][
                        old_value
                    ]
                else:
                    # BUG FIX: the key is "old_type", not "type" -- the old
                    # message itself raised KeyError
                    raise ValueError(
                        f"Error in REPLACEMENTS: type {replacement['old_type']} is not handled"
                    )
            else:
                # Option values are just a string
                if "new_values" in replacement:
                    old_value = options_file[replacement["new"]]
                    try:
                        old_value = old_value.lower()
                    except AttributeError:
                        # old_value was not a str, so no need to convert to lower case
                        pass

                    try:
                        options_file[replacement["new"]] = replacement["new_values"][
                            old_value
                        ]
                    except KeyError:
                        # No replacement given for this value: keep the old one
                        pass


def remove_deleted(deleted, options_file):
    """Remove each key that appears in 'deleted' from 'options_file'"""

    for key in deleted:
        # Better would be options_file.pop(key, None), but there's a
        # bug in current implementation
        if key in options_file:
            del options_file[key]


def apply_fixes(replacements, deleted, options_file):
    """Apply all fixes in this module"""

    modified = copy.deepcopy(options_file)

    fix_replacements(replacements, modified)

    remove_deleted(deleted, modified)

    return modified


def possibly_apply_patch(patch, options_file, quiet=False, force=False):
    """Possibly apply patch to options_file. If force is True, applies the
    patch without asking, overwriting any existing file. Otherwise,
    ask for confirmation from stdin

    """
    # Imported lazily so the pure fix functions above can be used without
    # the rest of the package
    from .common import yes_or_no

    if not quiet:
        print("\n******************************************")
        print(f"Changes to {options_file.filename}\n")
        print(patch)
        print("\n******************************************")

    if force:
        make_change = True
    else:
        make_change = yes_or_no(f"Make changes to {options_file.filename}?")
    if make_change:
        options_file.write(overwrite=True)
    return make_change


def add_parser(subcommand, default_args, files_args):
    """Register the ``input`` sub-command on `subcommand`."""
    parser = subcommand.add_parser(
        "input",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        help="Fix input files",
        description=textwrap.dedent(
            """\
            Fix input files for BOUT++ v5+

            Please note that this will only fix input options in sections with
            standard or default names. You may also need to fix options in custom
            sections.

            Warning! Even with no fixes, there may still be changes as this script
            will "canonicalise" the input files:

            * nested sections are moved to be under their parent section, while
              preserving relative order

            * empty sections are removed

            * floating point numbers may have their format changed, although the
              value will not change

            * consecutive blank lines will be reduced to a single blank line

            * whitespace around equals signs will be changed to exactly one space

            * trailing whitespace will be removed

            * comments will always use '#'

            Files that change in this way will have the "canonicalisation" patch
            presented first. If you choose not to apply this patch, the "upgrade
            fixer" patch will still include it."""
        ),
        parents=[default_args, files_args],
    )

    parser.add_argument(
        "--accept-canonical",
        "-c",
        action="store_true",
        help="Automatically accept the canonical patch",
    )
    parser.add_argument(
        "--canonical-only",
        "-k",
        action="store_true",
        help="Only check/fix canonicalisation",
    )
    parser.set_defaults(func=run)


def run(args):
    """Canonicalise and upgrade each input file, showing/applying patches."""
    # Imported lazily: these pull in the wider BOUT++ Python stack
    from boutdata.data import BoutOptions, BoutOptionsFile
    from boututils.boutwarnings import AlwaysWarning

    from .common import create_patch

    # Monkey-patch BoutOptions to make sure it's case sensitive
    BoutOptions.__init__ = case_sensitive_init

    warnings.simplefilter("ignore", AlwaysWarning)

    for filename in args.files:
        with open(filename) as f:
            original_source = f.read()

        try:
            original = BoutOptionsFile(filename)
        except ValueError as e:
            # BUG FIX: previously this was `pass`, which fell through with
            # `original` unbound and raised NameError below; skip the file
            print(f"Could not parse '{filename}', skipping: {e}")
            continue

        canonicalised_patch = create_patch(filename, original_source, str(original))
        if canonicalised_patch and not args.patch_only:
            # BUG FIX: message previously printed the literal "(unknown)"
            # instead of the file being processed
            print(f"WARNING: original input file '{filename}' not in canonical form!")
            applied_patch = possibly_apply_patch(
                canonicalised_patch,
                original,
                args.quiet,
                args.force or args.accept_canonical,
            )
            # Re-read input file
            if applied_patch:
                original_source = str(original)

        if args.canonical_only:
            continue

        try:
            modified = apply_fixes(REPLACEMENTS, DELETED, original)
        except RuntimeError as e:
            print(e)
            continue
        patch = create_patch(filename, original_source, str(modified))

        if args.patch_only:
            print(patch)
            continue

        if not patch:
            if not args.quiet:
                # BUG FIX: message previously printed the literal "(unknown)"
                print(f"No changes to make to {filename}")
            continue

        possibly_apply_patch(patch, modified, args.quiet, args.force)
# List of macros, their replacements and what header to find them in.
# Each element is a dict with keys:
#   "old":            name to replace
#   "new":            replacement (None if the macro was removed entirely;
#                     a warning is printed if such a macro is still used)
#   "headers":        header(s) declaring the replacement (str or list of str)
#   "macro":          True if the old name was a preprocessor macro
#   "always_defined": True if the new name is always defined, so that
#                     `#ifdef OLD` must become `#if NEW`
MACRO_REPLACEMENTS = [
    {"old": "REVISION", "new": "bout::version::revision", "headers": ["bout/revision.hxx"], "macro": False, "always_defined": True},
    {"old": "BOUT_VERSION_DOUBLE", "new": "bout::version::as_double", "headers": ["bout/version.hxx", "bout.hxx"], "macro": False, "always_defined": True},
    {"old": "BOUT_VERSION_STRING", "new": "bout::version::full", "headers": ["bout/version.hxx", "bout.hxx"], "macro": False, "always_defined": True},
    # Next one is not technically a macro, but near enough
    {"old": "BOUT_VERSION", "new": "bout::version::full", "headers": ["bout/version.hxx", "bout.hxx"], "macro": False, "always_defined": True},
    {"old": "BACKTRACE", "new": "BOUT_USE_BACKTRACE", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "HAS_ARKODE", "new": "BOUT_HAS_ARKODE", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "HAS_CVODE", "new": "BOUT_HAS_CVODE", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "HAS_HDF5", "new": None, "headers": [], "macro": True, "always_defined": True},
    {"old": "HAS_IDA", "new": "BOUT_HAS_IDA", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "HAS_LAPACK", "new": "BOUT_HAS_LAPACK", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "LAPACK", "new": "BOUT_HAS_LAPACK", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "HAS_NETCDF", "new": "BOUT_HAS_NETCDF", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "HAS_PETSC", "new": "BOUT_HAS_PETSC", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "HAS_PRETTY_FUNCTION", "new": "BOUT_HAS_PRETTY_FUNCTION", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "HAS_PVODE", "new": "BOUT_HAS_PVODE", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "TRACK", "new": "BOUT_USE_TRACK", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "NCDF4", "new": "BOUT_HAS_NETCDF", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "NCDF", "new": "BOUT_HAS_LEGACY_NETCDF", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "HDF5", "new": None, "headers": [], "macro": True, "always_defined": True},
    {"old": "DEBUG_ENABLED", "new": "BOUT_USE_OUTPUT_DEBUG", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "BOUT_FPE", "new": "BOUT_USE_SIGFPE", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "LOGCOLOR", "new": "BOUT_USE_COLOR", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
    {"old": "OPENMP_SCHEDULE", "new": "BOUT_OPENMP_SCHEDULE", "headers": "bout/build_config.hxx", "macro": True, "always_defined": True},
]


def fix_include_version_header(old, headers, source):
    """Ensure one of ``headers`` is included if ``source`` uses ``old``.

    Returns ``source`` unchanged if a suitable header is already
    included or if ``old`` does not appear at all; otherwise inserts an
    include of ``headers[0]`` after the last existing BOUT++ include
    (or at the top of the file if there are none).
    """
    if not isinstance(headers, list):
        headers = [headers]

    # If header is already included, we can skip this fix
    for header in headers:
        if (
            re.search(rf'^#\s*include.*(<|"){header}(>|")', source, flags=re.MULTILINE)
            is not None
        ):
            return source

    # If the old macro isn't in the file, we can skip this fix
    if re.search(rf"\b{old}\b", source) is None:
        return source

    # Now we want to find a suitable place to stick the new include.
    # Good candidates are includes of BOUT++ headers
    includes = []
    source_lines = source.splitlines()
    for linenumber, line in enumerate(source_lines):
        if re.match(r"^#\s*include.*bout/", line):
            includes.append(linenumber)
        if re.match(r"^#\s*include.*physicsmodel", line):
            includes.append(linenumber)

    if includes:
        last_include = includes[-1] + 1
    else:
        # No suitable includes, so just stick at the top of the file
        last_include = 0
    source_lines.insert(last_include, f'#include "{headers[0]}"')

    return "\n".join(source_lines)


def fix_ifdefs(old, source):
    """Remove any code inside #ifdef/#ifndef blocks on ``old`` that would
    now not be compiled, keeping the still-live branch.
    """
    source_lines = source.splitlines()

    # Nesting depth inside an interesting #if(n)def, or None when outside
    in_ifdef = None
    # One dict per interesting block: start/else/end line numbers plus
    # whether it was an #ifdef or an #ifndef
    macro_blocks = []
    for linenumber, line in enumerate(source_lines):
        if_def = re.match(r"#\s*(ifn?def)\s*(.*)", line)
        else_block = re.match(r"#\s*else", line)
        endif = re.match(r"#\s*endif", line)
        if not (if_def or else_block or endif):
            continue
        # Keep track of whether we're inside an interesting
        # #ifdef/#ifndef, as they might be nested, and we want to find
        # the matching #endif and #else
        if endif:
            if in_ifdef is not None:
                in_ifdef -= 1
            if in_ifdef == 0:
                in_ifdef = None
                macro_blocks[-1]["end"] = linenumber
            continue
        if else_block:
            if in_ifdef == 1:
                macro_blocks[-1]["else"] = linenumber
            continue
        if if_def.group(2) == old:
            in_ifdef = 1
            macro_blocks.append({"start": linenumber, "if_def_type": if_def.group(1)})
        elif in_ifdef is not None:
            in_ifdef += 1

    if macro_blocks == []:
        return source

    # Get all of the lines to be removed. Note block.values() also
    # contributes the "if_def_type" string; it's harmless in the set as
    # line numbers are ints and never compare equal to it
    lines_to_remove = set()
    for block in macro_blocks:
        lines_to_remove |= set(block.values())
        if block["if_def_type"] == "ifdef":
            if "else" in block:
                # Delete the #else branch for #ifdef
                lines_to_remove |= set(range(block["else"], block["end"]))
        else:
            # Keep the #else branch for #ifndef if there is one,
            # otherwise remove the whole block
            lines_to_remove |= set(
                range(block["start"], block.get("else", block["end"]))
            )

    # Apparently this is actually the best way of removing a bunch of
    # (possibly) non-contiguous indices
    modified_lines = [
        line for num, line in enumerate(source_lines) if num not in lines_to_remove
    ]

    return "\n".join(modified_lines)


def fix_always_defined_macros(old, new, source):
    """Fix '#ifdef's that should become plain '#if'."""
    new_source = re.sub(rf"#ifdef\s+{old}\b", rf"#if {new}", source)
    return re.sub(rf"#ifndef\s+{old}\b", rf"#if !{new}", new_source)


def fix_replacement(old, new, source):
    """Replace whole-word occurrences of ``old`` with ``new``.

    Occurrences immediately preceded or followed by a double quote or
    an underscore are left alone (to avoid quoted names and longer
    identifiers).

    BUG FIX: the previous pattern ``([^"_])\\b{old}\\b([^"_])`` consumed
    the neighbouring characters, so occurrences at the very start/end of
    the source, and adjacent occurrences, were silently skipped. Using
    zero-width lookarounds replaces all of them.
    """
    return re.sub(rf'(?<!["_])\b{old}\b(?!["_])', new, source)


def apply_fixes(replacements, source):
    """Apply all fixes in this module to ``source`` and return the result."""
    modified = source  # str is immutable: no copy needed

    for replacement in replacements:
        if replacement["new"] is None:
            # BUG FIX: only warn when the removed macro is actually used;
            # previously this printed for every file regardless
            if re.search(rf"\b{replacement['old']}\b", modified):
                print(
                    f"{replacement['old']} has been removed, please delete from your code"
                )
            continue

        modified = fix_include_version_header(
            replacement["old"], replacement["headers"], modified
        )
        if replacement["macro"] and replacement["always_defined"]:
            modified = fix_always_defined_macros(
                replacement["old"], replacement["new"], modified
            )
        elif replacement["always_defined"]:
            modified = fix_ifdefs(replacement["old"], modified)
        modified = fix_replacement(replacement["old"], replacement["new"], modified)

    return modified


def add_parser(subcommand, default_args, files_args):
    """Add the 'macro' subcommand for fixing macro defines."""
    parser = subcommand.add_parser(
        "macro",
        help="Fix macro defines",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=textwrap.dedent(
            """\
            Fix macro defines for BOUT++ v4 -> v5

            Please note that this is only slightly better than dumb text replacement. It
            will fix the following:

            * replacement of macros with variables or new names
            * inclusion of correct headers for new variables
            * removal of #if(n)def/#endif blocks that do simple checks for the old
              macro, keeping the appropriate part, if replaced by a variable
            * change '#if(n)def' for '#if (!)' if the replacement is always defined

            It will try not to replace quoted macro names, but may
            still replace them in strings or comments.

            Please check the diff output carefully!
            """
        ),
        parents=[default_args, files_args],
    )
    parser.set_defaults(func=run)


def run(args):
    """Apply the macro fixes to each file in ``args.files``."""
    for filename in args.files:
        with open(filename) as f:
            contents = f.read()
        # Strings are immutable, so keeping a reference preserves the
        # original for the diff (deepcopy was redundant)
        original = contents

        modified = apply_fixes(MACRO_REPLACEMENTS, contents)
        apply_or_display_patch(
            filename, original, modified, args.patch_only, args.quiet, args.force
        )
# Template for a single (possibly absent) argument in a legacy function
# signature, e.g. `BoutReal t` or `BoutReal UNUSED(t)`. Expanded once per
# argument via .format(arg_type=..., arg_num=...) and compiled with re.VERBOSE.
FUNCTION_SIGNATURE_ARGUMENT_RE = r"""({arg_type}
    \s+ # Require spaces only if the argument is named
    (?P<unused{arg_num}>UNUSED\()? # Possible UNUSED macro
    [a-zA-Z_0-9]* # Argument name
    (?(unused{arg_num})\)) # If UNUSED macro was present, we need an extra closing bracket
    )?
"""


def create_function_signature_re(function_name, argument_types):
    """Create a regular expression for a legacy physics model function

    The result must be compiled with re.VERBOSE. It captures two groups
    per argument: the whole argument text, and the `unused<N>` marker
    for the UNUSED macro -- fix_model_operator relies on this layout.
    """

    if not isinstance(argument_types, list):
        argument_types = [argument_types]

    arguments = r",\s*".join(
        [
            FUNCTION_SIGNATURE_ARGUMENT_RE.format(arg_type=argument, arg_num=num)
            for num, argument in enumerate(argument_types)
        ]
    )

    return rf"int\s+{function_name}\s*\({arguments}\)"


# '#include <boutmain.hxx>' (or quoted) marks a legacy physics model
LEGACY_MODEL_INCLUDE_RE = re.compile(
    r'^#\s*include.*(<|")boutmain.hxx(>|")', re.MULTILINE
)

# Legacy 'bout_solve(variable, name[, description])' calls; group 1 is the
# full argument list
BOUT_SOLVE_RE = re.compile(
    r"bout_solve\(([^,)]+,\s*[^,)]+(,\s*[^,)]+)?)\)", re.MULTILINE
)

# Old Solver setup calls; group 1 captures the free-function name passed in
RHS_RE = re.compile(r"solver\s*->\s*setRHS\(\s*([a-zA-Z0-9_]+)\s*\)")

PRECON_RE = re.compile(r"solver\s*->\s*setPrecon\(\s*([a-zA-Z0-9_]+)\s*\)")

JACOBIAN_RE = re.compile(r"solver\s*->\s*setJacobian\(\s*([a-zA-Z0-9_]+)\s*\)")

# setSplitOperator(convective, diffusive); groups capture both function names
SPLIT_OPERATOR_RE = re.compile(
    r"solver\s*->\s*setSplitOperator\(\s*([a-zA-Z0-9_]+),\s*([a-zA-Z0-9_]+)\s*\)"
)


def has_split_operator(source):
    """Return the names of the split operator functions if set, otherwise False"""

    match = SPLIT_OPERATOR_RE.search(source)
    if not match:
        return False

    # (convective, diffusive) free-function names
    return match.group(1), match.group(2)


def is_legacy_model(source):
    """Return true if the source is a legacy physics model"""
    return LEGACY_MODEL_INCLUDE_RE.search(source) is not None


def find_last_include(source_lines):
    """Return the line number after the last #include (or 0 if no includes)"""
    # Scan from the end so we stop at the last include
    for number, line in enumerate(reversed(source_lines)):
        if line.startswith("#include"):
            return len(source_lines) - number
    return 0


def fix_model_operator(
    source, model_name, operator_name, operator_type, new_name, override
):
    """Fix any definitions of the operator, and return the new declaration

    May modify source

    Parameters
    ----------
    source: str
        Source code to fix
    model_name: str
        Name of the PhysicsModel class to create or add methods to
    operator_name: str
        Name of the free function to fix
    operator_type: str, [str]
        Function argument types
    new_name: str
        Name of the PhysicsModel method
    override: bool
        Is `new_name` overriding a virtual?

    Returns
    -------
    tuple
        ``(modified_source, declaration_str)`` on success, or
        ``(source, False)`` if the function could not be found --
        callers truth-test the second element.
    """

    # Make sure we have a list of types
    if not isinstance(operator_type, list):
        operator_type = [operator_type]

    # Get a regex for the function signature
    operator_re = re.compile(
        create_function_signature_re(operator_name, operator_type),
        re.VERBOSE | re.MULTILINE,
    )

    # Find any declarations of the free function
    matches = list(operator_re.finditer(source))
    if matches == []:
        warnings.warn(
            f"Could not find {operator_name}; is it defined in another file? If so, you will need to fix it manually"
        )
        return source, False

    # If we found more than one, remove the first one as it's probably
    # a declaration and not a definition
    if len(matches) > 1:
        source = re.sub(
            create_function_signature_re(operator_name, operator_type) + r"\s*;",
            "",
            source,
            flags=re.VERBOSE | re.MULTILINE,
        )

        # Get the names of the function arguments. Every other group
        # from the regex, as the other groups match the `UNUSED` macro.
        # NOTE(review): relies on each argument contributing exactly two
        # regex groups; an unnamed/omitted argument leaves its group as
        # None -- presumably arguments are always named in practice, as
        # the join() below would fail on None. TODO confirm
        arg_names = operator_re.search(source).groups()[::2]
    else:
        arg_names = matches[0].groups()[::2]

    # Fix definition and any existing declarations
    arguments = ", ".join(arg_names)

    # Modify the definition: it's out-of-line so we need the qualified name
    modified = operator_re.sub(rf"int {model_name}::{new_name}({arguments})", source)

    # Create the declaration
    return (
        modified,
        PHYSICS_MODEL_RHS_SKELETON.format(
            function=new_name,
            arguments=arguments,
            override=" override" if override else "",
        ),
    )


def fix_bout_constrain(source, error_on_warning):
    """Fix uses of bout_constrain. This is complicated because it might be
    in a conditional, and Solver::constraint returns void

    Raises RuntimeError (when `error_on_warning` is True) or prints a
    message if some uses could not be converted automatically.
    """

    if "bout_constrain" not in source:
        return source

    # The bout_constrain free function returns False if the Solver
    # doesn't have constraints, but the Solver::constraint method does
    # the checking itself, so we don't need to repeat it
    modified = re.sub(
        r"""if\s*\(\s*(?:!|not)\s* # in a conditional, checking for false
        bout_constrain\(([^;]+,[^;]+,[^;]+)\) # actual function call
        \s*\) # end of conditional
        (?P<brace>\s*{\s*)? # possible open brace
        (?:\s*\n)? # possible newline
        \s*throw\s+BoutException\(.*\);(?:\s*\n)? # throwing an exception
        (?(brace)\s*})? # consume matching closing brace
        """,
        r"solver->constraint(\1);\n",
        source,
        flags=re.VERBOSE | re.MULTILINE,
    )

    # The above might not fix everything, so best check if there are any uses left
    remaining_matches = list(re.finditer("bout_constrain", modified))
    if remaining_matches == []:
        # We fixed everything!
        return modified

    # Construct a useful error message with a line of context around
    # each remaining use.
    # NOTE(review): the match offsets come from `modified` but are used
    # to index into `source`; if the substitution above changed anything
    # the reported line numbers/context may be off -- verify
    source_lines = source.splitlines()
    lines_context = []
    for match in remaining_matches:
        bad_line = source[: match.end()].count("\n")
        line_range = range(max(0, bad_line - 1), min(len(source_lines), bad_line + 2))
        lines_context.append(
            "\n  ".join([f"{i}:{source_lines[i]}" for i in line_range])
        )

    message = textwrap.dedent(
        """\
        Some uses of `bout_constrain` remain, but we could not automatically
        convert them to use `Solver::constraint`. Please fix them before
        continuing:
        """
    )
    message += "  " + "\n  ".join(lines_context)

    if error_on_warning:
        raise RuntimeError(message)
    print(message)
    return modified
def convert_old_solver_api(source, name):
    """Fix or remove old Solver API calls

    Parameters
    ----------
    source: str
        The source code to modify
    name: str
        The PhysicsModel class name

    Returns
    -------
    str
        The modified source (always a string)
    """

    # Fixing `bout_solve` is a straight replacement, easy
    source = BOUT_SOLVE_RE.sub(r"solver->add(\1)", source)

    # Completely remove calls to Solver::setRHS
    source = RHS_RE.sub("", source)

    # List of old free functions that now need declarations inside the
    # class definition
    method_decls = []

    # Fix uses of solver->setPrecon
    # Get the name of any free functions passed as arguments to setPrecon
    precons = PRECON_RE.findall(source)
    for precon in precons:
        source, decl = fix_model_operator(
            source,
            name,
            precon,
            ["BoutReal", "BoutReal", "BoutReal"],
            precon,
            override=False,
        )
        if decl:
            method_decls.append(decl)
    # Almost a straight replacement, but it's now a member-function pointer
    source = PRECON_RE.sub(rf"setPrecon(&{name}::\1)", source)

    # Fix uses of solver->setJacobian, basically the same as for setPrecon
    jacobians = JACOBIAN_RE.findall(source)
    for jacobian in jacobians:
        source, decl = fix_model_operator(
            source, name, jacobian, "BoutReal", jacobian, override=False
        )
        if decl:
            method_decls.append(decl)
    source = JACOBIAN_RE.sub(rf"setJacobian(&{name}::\1)", source)

    # If we didn't find any free functions that need to be made into
    # methods, we're done
    if not method_decls:
        return source

    # We need to find the class definition
    class_def = PHYSICS_MODEL_RE.search(source)
    if class_def is None:
        # BUG FIX: the implicitly-concatenated message lacked spaces
        # between fragments ("to addpreconditioner", "definedin")
        warnings.warn(
            f"Could not find the '{name}' class to add "
            "preconditioner and/or Jacobian declarations; is it defined "
            "in another file? If so, you will need to fix it manually"
        )
        # BUG FIX: previously `return source, False` -- a tuple, while
        # every caller (and every other path) expects a plain string,
        # which made the subsequent create_patch call blow up
        return source

    # The easiest place to stick the method declaration is on the line
    # immediately following the open brace of the class def, and the
    # easiest way to insert it is to split the source into a list,
    # insert in the list, then join the list back into a string.
    # The regex from above finds the offset in the source which we
    # need to turn into a line number
    first_line_of_class = source[: class_def.end() + 1].count("\n")
    methods = "\n  ".join(method_decls)
    source_lines = source.splitlines()
    source_lines.insert(first_line_of_class, f"  {methods}")

    return "\n".join(source_lines)


def convert_legacy_model(source, name, error_on_warning):
    """Convert a legacy physics model to a PhysicsModel

    Parameters
    ----------
    source: str
        Source of the legacy model
    name: str
        Name for the new PhysicsModel class
    error_on_warning: bool
        If True, raise instead of printing when a fix could not be applied

    Returns
    -------
    str
        The converted source (unchanged if not a legacy model)
    """

    if not is_legacy_model(source):
        return source

    source = fix_bout_constrain(source, error_on_warning)

    # Replace legacy header
    source = LEGACY_MODEL_INCLUDE_RE.sub(r"#include \1bout/physicsmodel.hxx\2", source)

    method_decls = []

    source, decl = fix_model_operator(
        source, name, "physics_init", "bool", "init", override=True
    )
    if decl:
        method_decls.append(decl)

    split_operators = has_split_operator(source)
    if split_operators:
        source = SPLIT_OPERATOR_RE.sub(r"setSplitOperator(true)", source)

        convective, diffusive = split_operators
        # Fix the free functions
        source, decl = fix_model_operator(
            source, name, convective, "BoutReal", "convective", override=True
        )
        if decl:
            method_decls.append(decl)
        source, decl = fix_model_operator(
            source, name, diffusive, "BoutReal", "diffusive", override=True
        )
        if decl:
            method_decls.append(decl)
    else:
        # Fix the rhs free function
        source, decl = fix_model_operator(
            source, name, "physics_run", "BoutReal", "rhs", override=True
        )
        if decl:
            method_decls.append(decl)

    source_lines = source.splitlines()
    last_include = find_last_include(source_lines)

    methods = "\n  ".join(method_decls)
    physics_model = PHYSICS_MODEL_SKELETON.format(methods=methods, name=name)

    # Insert the new class skeleton just after the last #include, and
    # the BOUTMAIN macro at the very end
    source_lines.insert(last_include, physics_model)
    source_lines.append(BOUTMAIN.format(name))

    return "\n".join(source_lines)


def add_parser(subcommand, default_args, files_args):
    """Add the 'model' subcommand for upgrading legacy physics models."""
    parser = subcommand.add_parser(
        "model",
        help="Upgrade legacy physics models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=textwrap.dedent(
            """\
            Upgrade legacy physics models to use the PhysicsModel class

            This will do the bare minimum required to compile, and
            won't make global objects (like Field3Ds) members of the
            new class, or free functions (other than
            `physics_init`/`physics_run`, preconditioners, and
            Jacobians) methods of the new class. Comments may also be
            left behind.

            By default, this will use the file name stripped of file
            extensions as the name of the new class. Use '--name=<new name>'
            to give a different name.
            """
        ),
        parents=[default_args, files_args],
    )
    parser.add_argument(
        "--name",
        "-n",
        action="store",
        nargs="?",
        type=str,
        help="Name for new PhysicsModel class, default is from filename",
    )
    parser.set_defaults(func=run)


def run(args):
    """Upgrade each legacy physics model in ``args.files``."""
    for filename in args.files:
        with open(filename) as f:
            contents = f.read()

        # Strings are immutable: keeping a reference preserves the
        # original for the diff (deepcopy was redundant)
        original = contents

        # Prefer the name of an existing PhysicsModel class, if any
        match = PHYSICS_MODEL_RE.search(original)
        if match is not None:
            new_name = match.group("name")
        else:
            new_name = args.name or pathlib.Path(filename).stem.capitalize().replace(
                " ", "_"
            )

        try:
            # C++ identifiers can't start with a digit
            if re.match(r"^[0-9]+.*", new_name) and not args.force:
                raise ValueError(
                    f"Invalid name: '{new_name}'. Use --name to specify a valid C++ identifier"
                )

            modified = convert_legacy_model(
                original, new_name, not (args.force or args.patch_only)
            )

            modified = convert_old_solver_api(modified, new_name)
        except (RuntimeError, ValueError) as e:
            error_message = textwrap.indent(f"{e}", "  ")
            # BUG FIX: message was an f-string with no placeholder; name the file
            print(
                f"There was a problem applying automatic fixes to {filename}:\n\n{error_message}"
            )
            continue

        apply_or_display_patch(
            filename, original, modified, args.patch_only, args.quiet, args.force
        )
{"old": "Bilinear", "new": "XZBilinear"}, + "Lagrange4pt": {"old": "Lagrange4pt", "new": "XZLagrange4pt"}, +} + +factories = { + "InterpolationFactory": { + "old": "InterpolationFactory", + "new": "XZInterpolationFactory", + } +} + + +def fix_header_includes(old_header, new_header, source): + """Replace old_header with new_header in source + + Parameters + ---------- + old_header: str + Name of header to be replaced + new_header: str + Name of replacement header + source: str + Text to search + + """ + return re.sub( + rf""" + (\s*\#\s*include\s*) # Preprocessor include + (<|") + ({old_header}) # Header name + (>|") + """, + rf"\1\2{new_header}\4", + source, + flags=re.VERBOSE, + ) + + +def fix_interpolations(old_interpolation, new_interpolation, source): + return re.sub( + rf""" + \b{old_interpolation}\b + """, + rf"{new_interpolation}", + source, + flags=re.VERBOSE, + ) + + +def clang_parse(filename, source): + index = clang.cindex.Index.create() + return index.parse(filename, unsaved_files=[(filename, source)]) + + +def clang_find_interpolations(node, typename, nodes=None): + if nodes is None: + nodes = [] + if node.kind == clang.cindex.CursorKind.TYPE_REF: + if node.type.spelling == typename: + nodes.append(node) + for child in node.get_children(): + clang_find_interpolations(child, typename, nodes) + return nodes + + +def clang_fix_single_interpolation( + old_interpolation, new_interpolation, source, location +): + modified = source + line = modified[location.line - 1] + new_line = ( + line[: location.column - 1] + + new_interpolation + + line[location.column + len(old_interpolation) - 1 :] + ) + modified[location.line - 1] = new_line + return modified + + +def clang_fix_interpolation(old_interpolation, new_interpolation, node, source): + nodes = clang_find_interpolations(node, old_interpolation) + modified = source + for node in nodes: + modified = clang_fix_single_interpolation( + old_interpolation, new_interpolation, modified, node.location + ) + return 
modified + + +def fix_factories(old_factory, new_factory, source): + return re.sub( + rf""" + \b{old_factory}\b + """, + new_factory, + source, + flags=re.VERBOSE, + ) + + +def apply_fixes(headers, interpolations, factories, source): + """Apply all Interpolation fixes to source + + Parameters + ---------- + headers + Dictionary of old/new headers + interpolations + Dictionary of old/new Interpolation types + source + Text to update + + """ + + modified = copy.deepcopy(source) + + for header in headers.values(): + modified = fix_header_includes(header["old"], header["new"], modified) + for interpolation in interpolations.values(): + modified = fix_interpolations( + interpolation["old"], interpolation["new"], modified + ) + for factory in factories.values(): + modified = fix_factories(factory["old"], factory["new"], modified) + + return modified + + +def clang_apply_fixes(headers, interpolations, factories, filename, source): + # translation unit + tu = clang_parse(filename, source) + + modified = source + + for header in headers.values(): + modified = fix_header_includes(header["old"], header["new"], modified) + + modified = modified.split("\n") + for interpolation in interpolations.values(): + modified = clang_fix_interpolation( + interpolation["old"], interpolation["new"], tu.cursor, modified + ) + modified = "\n".join(modified) + for factory in factories.values(): + modified = fix_factories(factory["old"], factory["new"], modified) + + return modified + + +def add_parser(subcommand, default_args, files_args): + parser = subcommand.add_parser( + "xzinterp", + help="Fix types of Interpolation objects", + description="Fix types of Interpolation objects", + parents=[default_args, files_args], + ) + parser.add_argument( + "--clang", action="store_true", help="Use libclang if available" + ) + parser.set_defaults(func=run) + + +def run(args): + if args.clang and not has_clang: + raise RuntimeError( + "libclang is not available. 
def yes_or_no(question: str) -> bool:
    """Ask ``question`` on stdin; return True for 'y...' answers, False for
    'n...' or an empty answer (the default), repeating until answered."""
    while True:
        reply = input(f"{question} [y/N] ").lower().strip()
        if not reply or reply[0] == "n":
            return False
        if reply[0] == "y":
            return True


def create_patch(filename: str, original: str, modified: str) -> str:
    """Create a unified diff between original and modified.

    Returns an empty string when the two texts are identical.
    """

    patch = "\n".join(
        difflib.unified_diff(
            original.splitlines(),
            modified.splitlines(),
            fromfile=filename,
            tofile=filename,
            lineterm="",
        )
    )

    return patch


def apply_or_display_patch(
    filename: str,
    original: str,
    modified: str,
    patch_only: bool,
    quiet: bool,
    force: bool,
):
    """Given the original and modified versions of a file, display and/or apply it

    Parameters
    ----------
    filename : str
        Name of file
    original : str
        Original text of file
    modified : str
        Modified text of file
    patch_only : bool
        If ``True``, only print the patch
    quiet : bool
        If ``True``, don't print to screen, unless ``patch_only`` is
        ``True``
    force : bool
        If ``True``, always apply modifications to file

    """

    patch = create_patch(filename, original, modified)

    if patch_only:
        print(patch)
        return

    if not patch:
        if not quiet:
            # BUG FIX: these three messages were f-strings with no
            # placeholder, printing a literal "(unknown)" instead of the
            # file actually being processed
            print(f"No changes to make to {filename}")
        return

    if not quiet:
        print("\n******************************************")
        print(f"Changes to {filename}\n{patch}")
        print("\n******************************************")

    make_change = force or yes_or_no(f"Make changes to {filename}?")

    if make_change:
        with open(filename, "w") as f:
            f.write(modified)