-
-
Notifications
You must be signed in to change notification settings - Fork 71
Remove component tensors #339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
eae33f2
a5fb11c
768f403
83609f8
36e0e5d
438c594
4713c06
0fa0eec
3e6cc92
ec00649
de91178
35472a4
13e68c8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
|
||
from ufl.action import Action | ||
from ufl.algorithms.analysis import extract_arguments | ||
from ufl.algorithms.estimate_degrees import SumDegreeEstimator | ||
from ufl.algorithms.map_integrands import map_integrand_dags | ||
from ufl.algorithms.replace_derivative_nodes import replace_derivative_nodes | ||
from ufl.argument import BaseArgument | ||
|
@@ -562,6 +563,14 @@ def __init__(self, geometric_dimension): | |
"""Initialise.""" | ||
GenericDerivativeRuleset.__init__(self, var_shape=(geometric_dimension,)) | ||
self._Id = Identity(geometric_dimension) | ||
self.degree_estimator = SumDegreeEstimator(1, {}) | ||
|
||
def is_cellwise_constant(self, o): | ||
"""More precise checks for cellwise constants.""" | ||
if is_cellwise_constant(o): | ||
return True | ||
degree = map_expr_dag(self.degree_estimator, o) | ||
return degree == 0 | ||
|
||
# --- Specialized rules for geometric quantities | ||
|
||
|
@@ -572,7 +581,7 @@ def geometric_quantity(self, o): | |
otherwise transform derivatives to reference derivatives. | ||
Override for specific types if other behaviour is needed. | ||
""" | ||
if is_cellwise_constant(o): | ||
if self.is_cellwise_constant(o): | ||
return self.independent_terminal(o) | ||
else: | ||
domain = extract_unique_domain(o) | ||
|
@@ -583,7 +592,7 @@ def geometric_quantity(self, o): | |
def jacobian_inverse(self, o): | ||
"""Differentiate a jacobian_inverse.""" | ||
# grad(K) == K_ji rgrad(K)_rj | ||
if is_cellwise_constant(o): | ||
if self.is_cellwise_constant(o): | ||
return self.independent_terminal(o) | ||
if not o._ufl_is_terminal_: | ||
raise ValueError("ReferenceValue can only wrap a terminal") | ||
|
@@ -653,9 +662,10 @@ def reference_value(self, o): | |
|
||
def reference_grad(self, o): | ||
"""Differentiate a reference_grad.""" | ||
if self.is_cellwise_constant(o): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this correct? Looks like one derivative is lost. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here we are simplifying There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, thanks. I tried following simple UFL code import ufl
import ufl.algorithms
import ufl.algorithms.apply_algebra_lowering
import ufl.algorithms.apply_function_pullbacks
import ufl.utils.formatting
mesh_el = ufl.finiteelement.FiniteElement("P", ufl.Cell("triangle"), 1, (2, ), ufl.pullback.IdentityPullback(), ufl.sobolevspace.H1)
mesh = ufl.Mesh(mesh_el)
V_el = ufl.finiteelement.FiniteElement("JM", ufl.Cell("triangle"), 1, (2, 2), ufl.pullback.DoubleCovariantPiola(), ufl.sobolevspace.L2)
V = ufl.FunctionSpace(mesh, V_el)
u = ufl.Coefficient(V)
x = ufl.SpatialCoordinate(mesh)
a = ufl.div(u)
a = ufl.algorithms.apply_algebra_lowering.apply_algebra_lowering(a)
a = ufl.algorithms.apply_derivatives.apply_derivatives(a)
a = ufl.algorithms.apply_function_pullbacks.apply_function_pullbacks(a)
a = ufl.algorithms.apply_derivatives.apply_derivatives(a)
print(ufl.utils.formatting.tree_format(a)) which mimics the order in But lets assume that such node will exist and remain unsimplified even if it shouldn't. Would There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is always confusing is the order of preprocessing operations in
In the code above I am seeing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In Firedrake we are getting lots of |
||
return self.independent_terminal(o) | ||
# grad(o) == grad(rgrad(rv(f))) -> K_ji*rgrad(rgrad(rv(f)))_rj | ||
f = o.ufl_operands[0] | ||
|
||
valid_operand = f._ufl_is_in_reference_frame_ or isinstance( | ||
f, (JacobianInverse, SpatialCoordinate, Jacobian, JacobianDeterminant) | ||
) | ||
|
@@ -676,7 +686,6 @@ def grad(self, o): | |
# Check that o is a "differential terminal" | ||
if not isinstance(o.ufl_operands[0], (Grad, Terminal)): | ||
raise ValueError("Expecting only grads applied to a terminal.") | ||
|
||
return Grad(o) | ||
|
||
def _grad(self, o): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
"""Remove component tensors. | ||
|
||
This module contains classes and functions to remove component tensors. | ||
""" | ||
# Copyright (C) 2008-2016 Martin Sandve Alnæs | ||
pbrubeck marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# | ||
# This file is part of UFL (https://www.fenicsproject.org) | ||
# | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
from ufl.classes import ComponentTensor, Form, Index, MultiIndex, Zero | ||
from ufl.corealg.map_dag import map_expr_dag | ||
from ufl.corealg.multifunction import MultiFunction, memoized_handler | ||
|
||
|
||
class IndexReplacer(MultiFunction): | ||
"""Replace Indices.""" | ||
|
||
def __init__(self, fimap: dict): | ||
"""Initialise. | ||
|
||
Args: | ||
fimap: map for index replacements. | ||
|
||
""" | ||
MultiFunction.__init__(self) | ||
self.fimap = fimap | ||
self._object_cache = {} | ||
|
||
expr = MultiFunction.reuse_if_untouched | ||
|
||
@memoized_handler | ||
def zero(self, o): | ||
"""Handle Zero.""" | ||
free_indices = [] | ||
index_dimensions = [] | ||
for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions): | ||
k = Index(i) | ||
j = self.fimap.get(k, k) | ||
if isinstance(j, Index): | ||
free_indices.append(j.count()) | ||
index_dimensions.append(d) | ||
return Zero( | ||
shape=o.ufl_shape, | ||
free_indices=tuple(free_indices), | ||
index_dimensions=tuple(index_dimensions), | ||
) | ||
|
||
@memoized_handler | ||
def multi_index(self, o): | ||
"""Handle MultiIndex.""" | ||
return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices())) | ||
|
||
|
||
class IndexRemover(MultiFunction): | ||
"""Remove Indexed.""" | ||
|
||
def __init__(self): | ||
"""Initialise.""" | ||
MultiFunction.__init__(self) | ||
self._object_cache = {} | ||
|
||
expr = MultiFunction.reuse_if_untouched | ||
|
||
@memoized_handler | ||
def _unary_operator(self, o): | ||
"""Simplify UnaryOperator(Zero).""" | ||
(operand,) = o.ufl_operands | ||
f = map_expr_dag(self, operand) | ||
if isinstance(f, Zero): | ||
return Zero( | ||
shape=o.ufl_shape, | ||
free_indices=o.ufl_free_indices, | ||
index_dimensions=o.ufl_index_dimensions, | ||
) | ||
if f is operand: | ||
# Reuse if untouched | ||
return o | ||
return o._ufl_expr_reconstruct_(f) | ||
|
||
@memoized_handler | ||
def indexed(self, o): | ||
"""Simplify Indexed.""" | ||
o1, i1 = o.ufl_operands | ||
if isinstance(o1, ComponentTensor): | ||
# Simplify Indexed ComponentTensor | ||
o2, i2 = o1.ufl_operands | ||
# Replace inner indices first | ||
v = map_expr_dag(self, o2) | ||
# Replace outer indices | ||
assert len(i2) == len(i1) | ||
fimap = dict(zip(i2, i1)) | ||
pbrubeck marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rule = IndexReplacer(fimap) | ||
return map_expr_dag(rule, v) | ||
pbrubeck marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
expr = map_expr_dag(self, o1) | ||
if expr is o1: | ||
# Reuse if untouched | ||
return o | ||
return o._ufl_expr_reconstruct_(expr, i1) | ||
|
||
reference_grad = _unary_operator | ||
reference_value = _unary_operator | ||
|
||
|
||
def remove_component_tensors(o): | ||
"""Remove component tensors.""" | ||
pbrubeck marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(o, Form): | ||
integrals = [] | ||
for integral in o.integrals(): | ||
integrand = remove_component_tensors(integral.integrand()) | ||
if not isinstance(integrand, Zero): | ||
integrals.append(integral.reconstruct(integrand=integrand)) | ||
return o._ufl_expr_reconstruct_(integrals) | ||
else: | ||
rule = IndexRemover() | ||
return map_expr_dag(rule, o) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You've introduced here that
is_cellwise_constant
logic is based on degree estimation, which is the only place in the code. How does it work, since the degree estimation for geometric quantities callsis_cellwise_constant
itself?ufl/ufl/algorithms/estimate_degrees.py
Lines 46 to 56 in 1ab30c2
If there is no infinite loop, could this be used in the generic cell-wise constant check? Or is it that only
GradRuleset
does not cause infinite loop?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I first tried to add this logic to the generic
is_cellwise_constant
in #334, but simplifyinggrad(cell_wise_constant)
on instantiation turns out to be problematic for nestedgrad
expression simply becausegrad(f)
expects f to have a domain andZero
does not have a domain, sograd(grad(grad(linear)))
would simplify tograd(Zero)
.