Skip to content

Commit

Permalink
Merge pull request #111 from boutproject/upgrader-script
Browse files Browse the repository at this point in the history
Add `bout-upgrader` tool
  • Loading branch information
ZedThree authored Sep 30, 2024
2 parents 025b48f + cbe7da9 commit bfd3ca0
Show file tree
Hide file tree
Showing 11 changed files with 2,257 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ docs = [

[project.scripts]
bout-squashoutput = "boutdata.scripts.bout_squashoutput:main"
bout-upgrader = "boutupgrader:main"

[tool.setuptools.dynamic]
version = { attr = "setuptools_scm.get_version" }
Expand Down
70 changes: 70 additions & 0 deletions src/boutupgrader/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import argparse
from importlib.metadata import PackageNotFoundError, version

from .bout_3to4 import add_parser as add_3to4_parser
from .bout_v5_factory_upgrader import add_parser as add_factory_parser
from .bout_v5_format_upgrader import add_parser as add_format_parser
from .bout_v5_header_upgrader import add_parser as add_header_parser
from .bout_v5_input_file_upgrader import add_parser as add_input_parser
from .bout_v5_macro_upgrader import add_parser as add_macro_parser
from .bout_v5_physics_model_upgrader import add_parser as add_model_parser
from .bout_v5_xzinterpolation_upgrader import add_parser as add_xzinterp_parser

try:
# This gives the version if the boututils package was installed
__version__ = version("boutdata")
except PackageNotFoundError:
# This branch handles the case when boututils is used from the git repo
try:
from setuptools_scm import get_version

__version__ = get_version(root="..", relative_to=__file__)
except (ModuleNotFoundError, LookupError):
__version__ = "dev"


def main():
# Parent parser that has arguments common to all subcommands
common_args = argparse.ArgumentParser(add_help=False)
common_args.add_argument(
"--quiet", "-q", action="store_true", help="Don't print patches"
)
force_or_patch_group = common_args.add_mutually_exclusive_group()
force_or_patch_group.add_argument(
"--force", "-f", action="store_true", help="Make changes without asking"
)
force_or_patch_group.add_argument(
"--patch-only", "-p", action="store_true", help="Print the patches and exit"
)

# Parent parser for commands that always take a list of files
files_args = argparse.ArgumentParser(add_help=False)
files_args.add_argument("files", action="store", nargs="+", help="Input files")

parser = argparse.ArgumentParser(
description="Upgrade BOUT++ source and input files to newer versions"
)
parser.add_argument(
"--version", action="version", version=f"%(prog)s {__version__}"
)
subcommand = parser.add_subparsers(title="subcommands", required=True)

v4_subcommand = subcommand.add_parser(
"v4", help="BOUT++ v4 upgrades"
).add_subparsers(title="v4 subcommands", required=True)
add_3to4_parser(v4_subcommand, common_args, files_args)

v5_subcommand = subcommand.add_parser(
"v5", help="BOUT++ v5 upgrades"
).add_subparsers(title="v5 subcommands", required=True)

add_factory_parser(v5_subcommand, common_args, files_args)
add_format_parser(v5_subcommand, common_args, files_args)
add_header_parser(v5_subcommand, common_args)
add_input_parser(v5_subcommand, common_args, files_args)
add_macro_parser(v5_subcommand, common_args, files_args)
add_model_parser(v5_subcommand, common_args, files_args)
add_xzinterp_parser(v5_subcommand, common_args, files_args)

args = parser.parse_args()
args.func(args)
230 changes: 230 additions & 0 deletions src/boutupgrader/bout_3to4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
#!/usr/bin/env python3

import fileinput
import re
import sys

nonmembers = {
"DC": ["DC", 1],
"slice": ["sliceXZ", 2],
}

coordinates = [
"outputVars",
"dx",
"dy",
"dz",
"non_uniform",
"d1_dx",
"d1_dy",
"J",
"Bxy",
"g11",
"g22",
"g33",
"g12",
"g13",
"g23",
"g_11",
"g_22",
"g_33",
"g_12",
"g_13",
"g_23",
"G1_11",
"G1_22",
"G1_33",
"G1_12",
"G1_13",
"G2_11",
"G2_22",
"G2_33",
"G2_12",
"G2_23",
"G3_11",
"G3_22",
"G3_33",
"G3_13",
"G3_23",
"G1",
"G2",
"G3",
"ShiftTorsion",
"IntShiftTorsion",
"geometry",
"calcCovariant",
"calcContravariant",
"jacobian",
]

local_mesh = [
("ngx", "LocalNx"),
("ngy", "LocalNy"),
]

warnings = [
(r"\^", "Use pow(a,b) instead of a^b"),
(r"\.max\(", "Use max(a) instead of a.max()"),
(
"ngz",
(
"ngz is changed to LocalNz in v4."
" The extra point in z has been removed."
" Change ngz -> LocalNz, and ensure that"
" the number of points are correct"
),
),
]


def fix_nonmembers(line_text, filename, line_num, replace=False):
"""Replace member functions with nonmembers"""

old_line_text = line_text

for old, (new, num_args) in nonmembers.items():
pattern = re.compile(rf"(\w*)\.{old}\(")
matches = re.findall(pattern, line_text)
for match in matches:
replacement = f"{new}({match}"
if num_args > 1:
replacement += ", "
line_text = re.sub(pattern, replacement, line_text)
if not replace:
name_num = f"{filename}:{line_num}:"
print(f"{name_num}{old_line_text}", end="")
print(" " * len(name_num) + line_text)
if replace:
return line_text


def fix_subscripts(line_text, filename, line_num, replace=False):
"""Replace triple square brackets with round brackets
Should also check that the variable is a Field3D/Field2D - but doesn't
"""

old_line_text = line_text
# Catch both 2D and 3D arrays
pattern = re.compile(r"\[([^[]*)\]\[([^[]*)\](?:\[([^[]*)\])?")
matches = re.findall(pattern, line_text)
for match in matches:
# If the last group is non-empty, then it was a 3D array
if len(match[2]):
replacement = r"(\1, \2, \3)"
else:
replacement = r"(\1, \2)"
line_text = re.sub(pattern, replacement, line_text)
if not replace:
name_num = f"{filename}:{line_num}:"
print(f"{name_num}{old_line_text}", end="")
print(" " * len(name_num) + line_text)
if replace:
return line_text


def fix_coordinates(line_text, filename, line_num, replace=False):
"""Fix variables that have moved from mesh to coordinates"""

old_line_text = line_text

for var in coordinates:
pattern = re.compile(f"mesh->{var}")
matches = re.findall(pattern, line_text)
for match in matches:
line_text = re.sub(pattern, f"mesh->coordinates()->{var}", line_text)
if not replace:
name_num = f"{filename}:{line_num}:"
print(
f"{name_num}{old_line_text}",
end="",
)
print(" " * len(name_num) + line_text)
if replace:
return line_text


def fix_local_mesh_size(line_text, filename, line_num, replace=False):
"""Replaces ng@ with LocalNg@, where @ is in {x,y,z}"""

old_line_text = line_text

for lm in local_mesh:
pattern = re.compile(lm[0])
matches = re.findall(pattern, line_text)
for match in matches:
line_text = re.sub(pattern, lm[1], line_text)
if not replace:
name_num = f"{filename}:{line_num}:"
print(
f"{name_num}{old_line_text}",
end="",
)
print(" " * len(name_num) + line_text)
if replace:
return line_text


def throw_warnings(line_text, filename, line_num):
"""Throws a warning for ^, .max() and ngz"""

for warn in warnings:
pattern = re.compile(warn[0])
matches = re.findall(pattern, line_text)
for match in matches:
name_num = f"{filename}:{line_num}:"
# stdout is redirected to the file if --replace is given,
# therefore use stderr
sys.stderr.write(f"{name_num}{line_text}")
# Coloring with \033[91m, end coloring with \033[0m\n
sys.stderr.write(
" " * len(name_num) + f"\033[91m!!!WARNING: {warn[1]}\033[0m\n\n"
)


def add_parser(subcommand, default_args, files_args):
epilog = """
Currently bout_3to4 can detect the following transformations are needed:
- Triple square brackets instead of round brackets for subscripts
- Field member functions that are now non-members
- Variables/functions that have moved from Mesh to Coordinates
Note that in the latter case of transformations, you will still need to manually add
Coordinates *coords = mesh->coordinates();
to the correct scopes
"""

parser = subcommand.add_parser(
"3to4",
help="A little helper for upgrading from BOUT++ version 3 to version 4",
description="A little helper for upgrading from BOUT++ version 3 to version 4",
parents=[default_args, files_args],
epilog=epilog,
)
parser.set_defaults(func=run)


def run(args):
# Loops over all lines across all files
for line in fileinput.input(files=args.files, inplace=args.replace):
filename = fileinput.filename()
line_num = fileinput.filelineno()

# Apply the transformations and then update the line if we're doing a replacement
new_line = fix_nonmembers(line, filename, line_num, args.replace)
line = new_line if args.replace else line

new_line = fix_subscripts(line, filename, line_num, args.replace)
line = new_line if args.replace else line

new_line = fix_coordinates(line, filename, line_num, args.replace)
line = new_line if args.replace else line

new_line = fix_local_mesh_size(line, filename, line_num, args.replace)
line = new_line if args.replace else line

new_line = throw_warnings(line, filename, line_num)

# If we're doing a replacement, then we need to print all lines, without a newline
if args.replace:
print(line, end="")
Loading

0 comments on commit bfd3ca0

Please sign in to comment.