Skip to content

Commit

Permalink
update expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
mscroggs committed Dec 20, 2024
1 parent 54ef0c4 commit d76950c
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 25 deletions.
2 changes: 1 addition & 1 deletion ffcx/codegeneration/C/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def generator(ir: ExpressionIR, options):
"""Generate UFC code for an expression."""
logger.info("Generating code for expression:")
assert len(ir.expression.integrand) == 1, "Expressions only support single quadrature rule"
points = next(iter(ir.expression.integrand)).points
points = next(iter(ir.expression.integrand))[1].points
logger.info(f"--- points: {points}")
factory_name = ir.expression.name
logger.info(f"--- name: {factory_name}")
Expand Down
2 changes: 1 addition & 1 deletion ffcx/codegeneration/C/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def generator(ir: FormIR, options):
f"static ufcx_integral* form_integrals_{ir.name}[{sizes}] = {{{values}}};"
)
d["form_integrals"] = f"form_integrals_{ir.name}"
values = ", ".join(f"{i}" for i in integral_ids for _ in integral_domains[i])
values = ", ".join(f"{i}" for i, domains in zip(integral_ids, integral_domains) for _ in domains)
d["form_integral_ids_init"] = f"int form_integral_ids_{ir.name}[{sizes}] = {{{values}}};"
d["form_integral_ids"] = f"form_integral_ids_{ir.name}"
else:
Expand Down
12 changes: 6 additions & 6 deletions ffcx/codegeneration/expression_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def generate_element_tables(self):
"""Generate tables of FE basis evaluated at specified points."""
parts = []

tables = self.ir.expression.unique_tables
tables = self.ir.expression.unique_tables[self.quadrature_rule[0]]
table_names = sorted(tables)

for name in table_names:
Expand Down Expand Up @@ -125,7 +125,7 @@ def generate_quadrature_loop(self):
# Generate varying partition
body = self.generate_varying_partition()
body = L.commented_code_list(
body, f"Points loop body setup quadrature loop {self.quadrature_rule.id()}"
body, f"Points loop body setup quadrature loop {self.quadrature_rule[1].id()}"
)

# Generate dofblock parts, some of this
Expand All @@ -139,7 +139,7 @@ def generate_quadrature_loop(self):
quadparts = []
else:
iq = self.backend.symbols.quadrature_loop_index
num_points = self.quadrature_rule.points.shape[0]
num_points = self.quadrature_rule[1].points.shape[0]
quadparts = [L.ForRange(iq, 0, num_points, body=body)]
return preparts, quadparts

Expand All @@ -148,11 +148,11 @@ def generate_varying_partition(self):
# Get annotated graph of factorisation
F = self.ir.expression.integrand[self.quadrature_rule]["factorization"]

arraysymbol = L.Symbol(f"sv_{self.quadrature_rule.id()}", dtype=L.DataType.SCALAR)
arraysymbol = L.Symbol(f"sv_{self.quadrature_rule[1].id()}", dtype=L.DataType.SCALAR)
parts = self.generate_partition(arraysymbol, F, "varying")
parts = L.commented_code_list(
parts,
f"Unstructured varying computations for quadrature rule {self.quadrature_rule.id()}",
f"Unstructured varying computations for quadrature rule {self.quadrature_rule[1].id()}",
)
return parts

Expand Down Expand Up @@ -216,7 +216,7 @@ def generate_block_parts(self, blockmap, blockdata):
assert not blockdata.transposed, "Not handled yet"
components = ufl.product(self.ir.expression.shape)

num_points = self.quadrature_rule.points.shape[0]
num_points = self.quadrature_rule[1].points.shape[0]
A_shape = [num_points, components] + self.ir.expression.tensor_shape
A = self.backend.symbols.element_tensor
iq = self.backend.symbols.quadrature_loop_index
Expand Down
4 changes: 2 additions & 2 deletions ffcx/ir/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _compute_integral_ir(
i: (
points[i],
weights[i],
None if tensor_factors is None else tensor_factors[i],
tensor_factors[i] if i in tensor_factors else None,
)
for i in points
}
Expand Down Expand Up @@ -580,7 +580,7 @@ def _compute_expression_ir(

weights = np.array([1.0] * points.shape[0])
rule = QuadratureRule(points, weights)
integrands = {rule: expression}
integrands = {"": {rule: expression}}

if cell is None:
assert (
Expand Down
28 changes: 13 additions & 15 deletions ffcx/ir/representationutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,31 +53,29 @@ def create_quadrature_points_and_weights(
integral_type, cell, degree, rule, elements, use_tensor_product=False
):
"""Create quadrature rule and return points and weights."""
pts = None
wts = None
tensor_factors = None

pts = {}
wts = {}
tensor_factors = {}
if integral_type == "cell":
if cell.cellname() in ["quadrilateral", "hexahedron"] and use_tensor_product:
if cell.cellname() == "quadrilateral":
tensor_factors = [
cell_name = cell.cellname()
if cell_name in ["quadrilateral", "hexahedron"] and use_tensor_product:
if cell_name == "quadrilateral":
tensor_factors[cell_name] = [
create_quadrature("interval", degree, rule, elements) for _ in range(2)
]
elif cell.cellname() == "hexahedron":
tensor_factors = [
elif cell_name == "hexahedron":
tensor_factors[cell_name] = [
create_quadrature("interval", degree, rule, elements) for _ in range(3)
]
pts["interval"] = np.array(
[tuple(i[0] for i in p) for p in itertools.product(*[f[0] for f in tensor_factors])]
pts[cell_name] = np.array(
[tuple(i[0] for i in p) for p in itertools.product(*[f[0] for f in tensor_factors[cell_name]])]
)
wts["interval"] = np.array(
[np.prod(p) for p in itertools.product(*[f[1] for f in tensor_factors])]
wts[cell_name] = np.array(
[np.prod(p) for p in itertools.product(*[f[1] for f in tensor_factors[cell_name]])]
)
else:
pts[cell.cellname()], wts[cell.cellname()] = create_quadrature(
cell.cellname(), degree, rule, elements
pts[cell_name], wts[cell_name] = create_quadrature(
cell_name, degree, rule, elements
)
elif integral_type in ufl.measure.facet_integral_types:
for ft in cell.facet_types():
Expand Down

0 comments on commit d76950c

Please sign in to comment.