Skip to content

Commit

Permalink
refactor: combine subs phase with replace placeholder phase
Browse files Browse the repository at this point in the history
  • Loading branch information
mgreminger committed Jan 19, 2025
1 parent 2739b21 commit 99ef5b7
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,28 +1593,35 @@ def replace_sympy_funcs_with_placeholder_funcs(expression: Expr) -> Expr:
return expression


def replace_placeholder_funcs(expr: Expr,
def replace_placeholder_funcs(expr: Expr,
parameter_subs: dict[Symbol, Expr],
func_key: Literal["dim_func"] | Literal["sympy_func"],
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function],
dim_values_dict: dict[tuple[Basic,...], DimValues],
function_parents: list[Basic],
data_table_subs: DataTableSubs | None) -> Expr:

while (not is_matrix(expr)) and expr.func == function_id_wrapper:
function_parents.append(expr.args[0])
expr = cast(Expr, expr.args[1])

if (not is_matrix(expr)) and isinstance(expr, Symbol) and expr.name == "_zero_delayed_substitution":
return S.Zero
if (not is_matrix(expr)) and isinstance(expr, Symbol):
if expr.name == "_zero_delayed_substitution":
return S.Zero
elif expr in parameter_subs:
sub = parameter_subs[expr]
if isinstance(sub, Symbol) and sub.name == "_zero_delayed_substitution":
sub = S.Zero
return sub

if is_matrix(expr):
rows = []
for i in range(expr.rows):
row = []
rows.append(row)
for j in range(expr.cols):
row.append(replace_placeholder_funcs(cast(Expr, expr[i,j]), func_key,
row.append(replace_placeholder_funcs(cast(Expr, expr[i,j]), parameter_subs, func_key,
placeholder_map, placeholder_set,
dim_values_dict, function_parents.copy(),
data_table_subs) )
Expand All @@ -1629,7 +1636,7 @@ def replace_placeholder_funcs(expr: Expr,
if expr.func == dim_needs_values_wrapper:
if func_key == "sympy_func":
child_expr = expr.args[1]
dim_args = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs) for arg in child_expr.args]
dim_args = [replace_placeholder_funcs(cast(Expr, arg), parameter_subs, func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs) for arg in child_expr.args]
result = cast(Expr, cast(Callable, placeholder_map[cast(Function, child_expr.func)][func_key])(*dim_args))
if data_table_subs is not None and len(data_table_subs.subs_stack) > 0:
dim_args_snapshot = list(dim_args)
Expand All @@ -1645,21 +1652,21 @@ def replace_placeholder_funcs(expr: Expr,
dim_values = dim_values_dict.get((expr.args[0],*function_parents), None)
if dim_values is None:
raise KeyError('Dim values lookup error, this is likely a bug, please report to support@engineeringpaper.xyz')
child_processed_args = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs) for arg in child_expr.args]
child_processed_args = [replace_placeholder_funcs(cast(Expr, arg), parameter_subs, func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs) for arg in child_expr.args]
return cast(Expr, cast(Callable, placeholder_map[cast(Function, child_expr.func)][func_key])(dim_values, *child_processed_args))
elif expr.func in dummy_var_placeholder_set and func_key == "dim_func":
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs) if index > 0 else arg for index, arg in enumerate(expr.args))))
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), parameter_subs, func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs) if index > 0 else arg for index, arg in enumerate(expr.args))))
elif expr.func in placeholder_set:
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs) for arg in expr.args)))
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), parameter_subs, func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs) for arg in expr.args)))

elif data_table_subs is not None and expr.func == data_table_calc_wrapper:
if len(expr.args[0].atoms(data_table_id_wrapper)) == 0:
return replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs)
return replace_placeholder_funcs(cast(Expr, expr.args[0]), parameter_subs, func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs)

data_table_subs.subs_stack.append({})
data_table_subs.shortest_col_stack.append(None)

sub_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs)
sub_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), parameter_subs, func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs)

subs = data_table_subs.subs_stack.pop()
shortest_col = data_table_subs.shortest_col_stack.pop()
Expand All @@ -1680,7 +1687,7 @@ def replace_placeholder_funcs(expr: Expr,
return cast(Expr, Matrix([sub_expr,]*shortest_col))

elif data_table_subs is not None and expr.func == data_table_id_wrapper:
current_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs)
current_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), parameter_subs, func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs)
new_var = Symbol(f"_data_table_var_{data_table_subs.get_next_id()}")

if not is_matrix(current_expr):
Expand All @@ -1698,21 +1705,19 @@ def replace_placeholder_funcs(expr: Expr,
return cast(Expr, current_expr[0,0])

else:
return cast(Expr, expr.func(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs) for arg in expr.args)))
return cast(Expr, expr.func(*(replace_placeholder_funcs(cast(Expr, arg), parameter_subs, func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents.copy(), data_table_subs) for arg in expr.args)))

def get_dimensional_analysis_expression(parameter_subs: dict[Symbol, Expr],
expression: Expr,
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function],
dim_values_dict: dict[tuple[Basic,...], DimValues]) -> tuple[Expr | None, Exception | None]:

expression_with_parameter_subs = cast(Expr, expression.xreplace(parameter_subs))

error = None
final_expression = None

try:
final_expression = replace_placeholder_funcs(expression_with_parameter_subs,
final_expression = replace_placeholder_funcs(expression, parameter_subs,
"dim_func", placeholder_map, placeholder_set,
dim_values_dict, [], DataTableSubs())
except Exception as e:
Expand Down Expand Up @@ -1901,7 +1906,7 @@ def remove_implicit(input_set: set[str]) -> set[str]:


def solve_system(statements: list[EqualityStatement], variables: list[str],
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function], convert_floats_to_fractions: bool):
parameters = get_all_implicit_parameters(statements)
parameter_subs = get_parameter_subs(parameters, convert_floats_to_fractions)
Expand All @@ -1921,7 +1926,7 @@ def solve_system(statements: list[EqualityStatement], variables: list[str],
system_variables.update(statement["params"])
system_implicit_params.extend(statement["implicitParams"])

equality = replace_placeholder_funcs(cast(Expr, statement["expression"]),
equality = replace_placeholder_funcs(cast(Expr, statement["expression"]), {},
"sympy_func",
placeholder_map, placeholder_set, {}, [], None)

Expand Down Expand Up @@ -2004,8 +2009,7 @@ def solve_system_numerical(statements: list[EqualityStatement], variables: list[
for statement in statements:
system_variables.update(statement["params"])

equality = cast(Expr, statement["expression"]).subs(parameter_subs)
equality = replace_placeholder_funcs(cast(Expr, equality),
equality = replace_placeholder_funcs(cast(Expr, statement["expression"]), parameter_subs,
"sympy_func",
placeholder_map, placeholder_set, {}, [], None)
system.append(cast(Expr, equality.doit()))
Expand Down Expand Up @@ -2439,9 +2443,8 @@ def get_evaluated_expression(expression: Expr,
simplify_symbolic_expressions: bool,
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function]) -> tuple[ExprWithAssumptions, str | list[list[str]], dict[tuple[Basic,...],DimValues]]:
expression = cast(Expr, expression.xreplace(parameter_subs))
dim_values_dict: dict[tuple[Basic,...], DimValues] = {}
expression = replace_placeholder_funcs(expression,
expression = replace_placeholder_funcs(expression, parameter_subs,
"sympy_func",
placeholder_map,
placeholder_set, dim_values_dict, [],
Expand Down

0 comments on commit 99ef5b7

Please sign in to comment.