From 2b280901ef039f48df21c0f01857246b6504e711 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 2 Aug 2023 10:15:54 -0400 Subject: [PATCH 01/29] api: add support for complex dtype --- devito/data/allocators.py | 8 +++++-- devito/finite_differences/differentiable.py | 2 +- devito/operator/operator.py | 6 +++++ devito/passes/clusters/factorization.py | 4 ++-- devito/passes/iet/misc.py | 26 ++++++++++++++++++++- devito/symbolics/inspection.py | 5 ++++ devito/tools/dtypes_lowering.py | 4 +++- devito/types/basic.py | 23 +++++++++++++++--- 8 files changed, 68 insertions(+), 10 deletions(-) diff --git a/devito/data/allocators.py b/devito/data/allocators.py index 72289c57bf..14f1b04fd1 100644 --- a/devito/data/allocators.py +++ b/devito/data/allocators.py @@ -92,8 +92,12 @@ def initialize(cls): return def alloc(self, shape, dtype, padding=0): - datasize = int(reduce(mul, shape)) - ctype = dtype_to_ctype(dtype) + # For complex number, allocate double the size of its real/imaginary part + alloc_dtype = dtype(0).real.__class__ + c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1 + + datasize = int(reduce(mul, shape) * c_scale) + ctype = dtype_to_ctype(alloc_dtype) # Add padding, if any try: diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 2e1fef6548..8b5a47207c 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -68,7 +68,7 @@ def grid(self): @cached_property def dtype(self): - dtypes = {f.dtype for f in self.find(Indexed)} - {None} + dtypes = {f.dtype for f in self._functions} - {None} return infer_dtype(dtypes) @cached_property diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 363c2507e3..e2fe8dd3a1 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -469,6 +469,12 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) + + # 
Complex header if needed. Needs to be done specialization + # as some specific cases requires complex to be loaded first + complex_include(graph) + + # Specialize graph = cls._specialize_iet(graph, **kwargs) # Instrument the IET for C-level profiling diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index 33253e245e..794e437e97 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -1,6 +1,7 @@ from collections import defaultdict from sympy import Add, Mul, S, collect +from sympy.core import NumberKind from devito.ir import cluster_pass from devito.symbolics import BasicWrapperMixin, estimate_cost, retrieve_symbols @@ -173,8 +174,7 @@ def _collect_nested(expr): Recursion helper for `collect_nested`. """ # Return semantic (rebuilt expression, factorization candidates) - - if expr.is_Number: + if expr.kind is NumberKind: return expr, {'coeffs': expr} elif expr.is_Function: return expr, {'funcs': expr} diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index f0b2b7f4f5..f8a99ead5f 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -3,7 +3,9 @@ import cgen import numpy as np import sympy +import numpy as np +from devito import configuration from devito.finite_differences import Max, Min from devito.ir import (Any, Forward, Iteration, List, Prodder, FindApplications, FindNodes, FindSymbols, Transformer, Uxreplace, @@ -16,7 +18,7 @@ from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', - 'generate_macros', 'minimize_symbols'] + 'generate_macros', 'minimize_symbols', 'complex_include'] @iet_pass @@ -240,6 +242,28 @@ def minimize_symbols(iet): return iet, {} +@iet_pass +def complex_include(iet): + """ + Add headers for complex arithmetic + """ + if configuration['language'] == 'cuda': + lib = 'cuComplex.h' + elif configuration['language'] == 'hip': + lib = 'hip/hip_complex.h' + else: + lib = 
'complex.h' + + functions = FindSymbols().visit(iet) + for f in functions: + try: + if np.issubdtype(f.dtype, np.complexfloating): + return iet, {'includes': (lib,)} + except TypeError: + pass + return iet, {} + + def remove_redundant_moddims(iet): key = lambda d: d.is_Modulo and d.origin is not None mds = [d for d in FindSymbols('dimensions').visit(iet) if key(d)] diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 437d48fff0..8339aabc2c 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -3,6 +3,8 @@ import numpy as np from sympy import (Function, Indexed, Integer, Mul, Number, Pow, S, Symbol, Tuple) +from sympy.core.operations import AssocOp +from sympy.core.numbers import ImaginaryUnit from devito.finite_differences import Derivative from devito.finite_differences.differentiable import IndexDerivative @@ -167,6 +169,7 @@ def _(expr, estimate, seen): return 0, True +@_estimate_cost.register(ImaginaryUnit) @_estimate_cost.register(Number) @_estimate_cost.register(ReservedWord) def _(expr, estimate, seen): @@ -189,6 +192,8 @@ def _(expr, estimate, seen): flops, flags = _estimate_cost.registry[object](expr, estimate, seen) if {S.One, S.NegativeOne}.intersection(expr.args): flops -= 1 + if ImaginaryUnit in expr.args: + flops *= 2 return flops, flags diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 4e7908a552..f2d0c6ad31 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -133,6 +133,9 @@ def dtype_to_cstr(dtype): def dtype_to_ctype(dtype): """Translate numpy.dtype into a ctypes type.""" + if isinstance(dtype, CustomDtype): + return dtype + try: return ctypes_vector_mapper[dtype] except KeyError: @@ -230,7 +233,6 @@ def ctypes_to_cstr(ctype, toarray=None): retval = '%s[%d]' % (ctypes_to_cstr(ctype._type_, toarray), ctype._length_) elif ctype.__name__.startswith('c_'): name = ctype.__name__[2:] - # A primitive datatype # FIXME: Is there 
a better way of extracting the C typename ? # Here, we're following the ctypes convention that each basic type has diff --git a/devito/types/basic.py b/devito/types/basic.py index e21bae6453..065efe0590 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -13,7 +13,8 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex) + frozendict, memoized_meth, sympy_mutex, dtype_to_cstr, + CustomDtype) from devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -432,7 +433,16 @@ def _C_name(self): @property def _C_ctype(self): - return dtype_to_ctype(self.dtype) + if isinstance(self.dtype, CustomDtype): + return self.dtype + elif np.issubdtype(self.dtype, np.complexfloating): + rtype = self.dtype(0).real.__class__ + ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return r + else: + return dtype_to_ctype(self.dtype) def _subs(self, old, new, **hints): """ @@ -1470,7 +1480,14 @@ def _C_name(self): @cached_property def _C_ctype(self): try: - return POINTER(dtype_to_ctype(self.dtype)) + if np.issubdtype(self.dtype, np.complexfloating): + rtype = self.dtype(0).real.__class__ + ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return POINTER(r) + else: + return POINTER(dtype_to_ctype(self.dtype)) except TypeError: # `dtype` is a ctypes-derived type! 
return self.dtype From aa353b4614957e4a833fed485f57754821a9df37 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 22 May 2024 08:01:27 -0400 Subject: [PATCH 02/29] api: fix printer for complex dtype --- devito/finite_differences/differentiable.py | 3 +-- devito/passes/iet/misc.py | 1 - devito/symbolics/inspection.py | 1 - devito/symbolics/printer.py | 10 ++++++++ devito/types/basic.py | 2 +- tests/test_operator.py | 28 +++++++++++++++++---- 6 files changed, 35 insertions(+), 10 deletions(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 8b5a47207c..a95e8be88d 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -14,8 +14,7 @@ from devito.logger import warning from devito.tools import (as_tuple, filter_ordered, flatten, frozendict, infer_dtype, is_integer, split) -from devito.types import (Array, DimensionTuple, Evaluable, Indexed, - StencilDimension) +from devito.types import Array, DimensionTuple, Evaluable, StencilDimension __all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative', 'Weights'] diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index f8a99ead5f..1f7bbbe881 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -3,7 +3,6 @@ import cgen import numpy as np import sympy -import numpy as np from devito import configuration from devito.finite_differences import Max, Min diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 8339aabc2c..3332bc68d6 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -3,7 +3,6 @@ import numpy as np from sympy import (Function, Indexed, Integer, Mul, Number, Pow, S, Symbol, Tuple) -from sympy.core.operations import AssocOp from sympy.core.numbers import ImaginaryUnit from devito.finite_differences import Derivative diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py 
index 2a25ef5c12..fddd133ac1 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -43,6 +43,10 @@ def single_prec(self, expr=None): dtype = sympy_dtype(expr) if expr is not None else self.dtype return dtype in [np.float32, np.float16] + def complex_prec(self, expr=None): + dtype = sympy_dtype(expr) if expr is not None else self.dtype + return np.issubdtype(dtype, np.complexfloating) + def parenthesize(self, item, level, strict=False): if isinstance(item, BooleanFunction): return "(%s)" % self._print(item) @@ -110,6 +114,8 @@ def _print_math_func(self, expr, nest=False, known=None): if self.single_prec(expr): cname = '%sf' % cname + if self.complex_prec(expr): + cname = 'c%s' % cname args = ', '.join((self._print(arg) for arg in expr.args)) @@ -255,8 +261,12 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): func_name = str(expr.func) + if self.single_prec(): func_name = '%sf' % func_name + if self.complex_prec(): + func_name = 'c%s' % func_name + return '%s(%s)' % (func_name, self._print(*expr.args)) def _print_DefFunction(self, expr): diff --git a/devito/types/basic.py b/devito/types/basic.py index 065efe0590..8ee4fefc7f 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1482,7 +1482,7 @@ def _C_ctype(self): try: if np.issubdtype(self.dtype, np.complexfloating): rtype = self.dtype(0).real.__class__ - ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctname = '%s complex' % dtype_to_cstr(rtype) ctype = dtype_to_ctype(rtype) r = type(ctname, (ctype,), {}) return POINTER(r) diff --git a/tests/test_operator.py b/tests/test_operator.py index d5759c1c92..db962b7d6b 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -9,7 +9,7 @@ SparseFunction, SparseTimeFunction, Dimension, error, SpaceDimension, NODE, CELL, dimensions, configuration, TensorFunction, TensorTimeFunction, VectorFunction, VectorTimeFunction, - div, grad, switchconfig) + div, grad, switchconfig, exp) from devito 
import Inc, Le, Lt, Ge, Gt # noqa from devito.exceptions import InvalidOperator from devito.finite_differences.differentiable import diff2sympy @@ -640,6 +640,24 @@ def test_tensor(self, func1): op2 = Operator([Eq(f, f.dx) for f in f1.values()]) assert str(op1.ccode) == str(op2.ccode) + def test_complex(self): + grid = Grid((5, 5)) + x, y = grid.dimensions + # Float32 complex is called complex64 in numpy + u = Function(name="u", grid=grid, dtype=np.complex64) + + eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + # Currently wrong alias type + op = Operator(eq, opt='noop') + op() + + # Check against numpy + dx = grid.spacing_map[x.spacing] + xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) + npres = xx + 1j*yy + np.exp(1j + dx) + + assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) + class TestAllocation: @@ -724,10 +742,10 @@ def verify_parameters(self, parameters, expected): """ boilerplate = ['timers'] parameters = [p.name for p in parameters] - for exp in expected: - if exp not in parameters + boilerplate: - error("Missing parameter: %s" % exp) - assert exp in parameters + boilerplate + for expi in expected: + if expi not in parameters + boilerplate: + error("Missing parameter: %s" % expi) + assert expi in parameters + boilerplate extra = [p for p in parameters if p not in expected and p not in boilerplate] if len(extra) > 0: error("Redundant parameters: %s" % str(extra)) From 92dfd9a1f308ec973a41cb91c275adefb9c4802d Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 22 May 2024 08:17:39 -0400 Subject: [PATCH 03/29] compiler: fix alias dtype with complex numbers --- devito/symbolics/inspection.py | 8 +++++++- devito/types/basic.py | 2 +- tests/test_gpu_common.py | 18 ++++++++++++++++++ tests/test_operator.py | 2 +- 4 files changed, 27 insertions(+), 3 deletions(-) diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 3332bc68d6..6649fa86bf 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py 
@@ -304,4 +304,10 @@ def sympy_dtype(expr, base=None): dtypes.add(i.dtype) except AttributeError: pass - return infer_dtype(dtypes) + dtype = infer_dtype(dtypes) + + # Promote if complex + if expr.has(ImaginaryUnit): + dtype = np.promote_types(dtype, np.complex64).type + + return dtype diff --git a/devito/types/basic.py b/devito/types/basic.py index 8ee4fefc7f..8e5d0d0455 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -437,7 +437,7 @@ def _C_ctype(self): return self.dtype elif np.issubdtype(self.dtype, np.complexfloating): rtype = self.dtype(0).real.__class__ - ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctname = '%s complex' % dtype_to_cstr(rtype) ctype = dtype_to_ctype(rtype) r = type(ctname, (ctype,), {}) return r diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 8f100a1082..450814bf74 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -66,6 +66,24 @@ def test_maxpar_option(self): assert trees[0][0] is trees[1][0] assert trees[0][1] is not trees[1][1] + def test_complex(self): + grid = Grid((5, 5)) + x, y = grid.dimensions + # Float32 complex is called complex64 in numpy + u = Function(name="u", grid=grid, dtype=np.complex64) + + eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + # Currently wrong alias type + op = Operator(eq) + op() + + # Check against numpy + dx = grid.spacing_map[x.spacing] + xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) + npres = xx + 1j*yy + np.exp(1j + dx) + + assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) + class TestPassesOptional: diff --git a/tests/test_operator.py b/tests/test_operator.py index db962b7d6b..5d975685ce 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -648,7 +648,7 @@ def test_complex(self): eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) # Currently wrong alias type - op = Operator(eq, opt='noop') + op = Operator(eq) op() # Check against numpy From 4364524b84e903f501375635c056f47db2734dfe Mon Sep 17 00:00:00 2001 From: 
mloubout Date: Wed, 22 May 2024 08:25:51 -0400 Subject: [PATCH 04/29] api: move complex ctype to dtype lowering --- devito/operator/operator.py | 2 +- devito/passes/clusters/factorization.py | 3 +-- devito/passes/iet/misc.py | 24 +++++++++++++----------- devito/symbolics/printer.py | 3 +++ devito/tools/dtypes_lowering.py | 8 ++++++++ devito/types/basic.py | 23 +++-------------------- tests/test_gpu_common.py | 2 +- tests/test_operator.py | 1 + 8 files changed, 31 insertions(+), 35 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index e2fe8dd3a1..266aacd81c 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -472,7 +472,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Complex header if needed. Needs to be done specialization # as some specific cases requires complex to be loaded first - complex_include(graph) + complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler']) # Specialize graph = cls._specialize_iet(graph, **kwargs) diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index 794e437e97..47222a33be 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -1,7 +1,6 @@ from collections import defaultdict from sympy import Add, Mul, S, collect -from sympy.core import NumberKind from devito.ir import cluster_pass from devito.symbolics import BasicWrapperMixin, estimate_cost, retrieve_symbols @@ -174,7 +173,7 @@ def _collect_nested(expr): Recursion helper for `collect_nested`. 
""" # Return semantic (rebuilt expression, factorization candidates) - if expr.kind is NumberKind: + if expr.is_Number: return expr, {'coeffs': expr} elif expr.is_Function: return expr, {'funcs': expr} diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 1f7bbbe881..7211c5877a 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -4,7 +4,6 @@ import numpy as np import sympy -from devito import configuration from devito.finite_differences import Max, Min from devito.ir import (Any, Forward, Iteration, List, Prodder, FindApplications, FindNodes, FindSymbols, Transformer, Uxreplace, @@ -241,25 +240,28 @@ def minimize_symbols(iet): return iet, {} +_complex_lib = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'} + + @iet_pass -def complex_include(iet): +def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ - if configuration['language'] == 'cuda': - lib = 'cuComplex.h' - elif configuration['language'] == 'hip': - lib = 'hip/hip_complex.h' - else: - lib = 'complex.h' + lib = _complex_lib.get(language, 'complex.h') - functions = FindSymbols().visit(iet) - for f in functions: + headers = {} + # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise + if compiler._cpp: + headers = {('_Complex_I', ('1.0fi'))} + + for f in FindSymbols().visit(iet): try: if np.issubdtype(f.dtype, np.complexfloating): - return iet, {'includes': (lib,)} + return iet, {'includes': (lib,), 'headers': headers} except TypeError: pass + return iet, {} diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index fddd133ac1..7de815549e 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -210,6 +210,9 @@ def _print_Float(self, expr): return rv + def _print_ImaginaryUnit(self, expr): + return '_Complex_I' + def _print_Differentiable(self, expr): return "(%s)" % self._print(expr._expr) diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py 
index f2d0c6ad31..ff40f6c7d6 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -136,6 +136,14 @@ def dtype_to_ctype(dtype): if isinstance(dtype, CustomDtype): return dtype + # Complex data + if np.issubdtype(dtype, np.complexfloating): + rtype = dtype(0).real.__class__ + ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return r + try: return ctypes_vector_mapper[dtype] except KeyError: diff --git a/devito/types/basic.py b/devito/types/basic.py index 8e5d0d0455..e21bae6453 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -13,8 +13,7 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex, dtype_to_cstr, - CustomDtype) + frozendict, memoized_meth, sympy_mutex) from devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -433,16 +432,7 @@ def _C_name(self): @property def _C_ctype(self): - if isinstance(self.dtype, CustomDtype): - return self.dtype - elif np.issubdtype(self.dtype, np.complexfloating): - rtype = self.dtype(0).real.__class__ - ctname = '%s complex' % dtype_to_cstr(rtype) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return r - else: - return dtype_to_ctype(self.dtype) + return dtype_to_ctype(self.dtype) def _subs(self, old, new, **hints): """ @@ -1480,14 +1470,7 @@ def _C_name(self): @cached_property def _C_ctype(self): try: - if np.issubdtype(self.dtype, np.complexfloating): - rtype = self.dtype(0).real.__class__ - ctname = '%s complex' % dtype_to_cstr(rtype) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return POINTER(r) - else: - return POINTER(dtype_to_ctype(self.dtype)) + return POINTER(dtype_to_ctype(self.dtype)) except TypeError: # `dtype` is a ctypes-derived type! 
return self.dtype diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 450814bf74..d1af179792 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -7,7 +7,7 @@ from conftest import assert_structure from devito import (Constant, Eq, Inc, Grid, Function, ConditionalDimension, Dimension, MatrixSparseTimeFunction, SparseTimeFunction, - SubDimension, SubDomain, SubDomainSet, TimeFunction, + SubDimension, SubDomain, SubDomainSet, TimeFunction, exp, Operator, configuration, switchconfig, TensorTimeFunction) from devito.arch import get_gpu_info from devito.exceptions import InvalidArgument diff --git a/tests/test_operator.py b/tests/test_operator.py index 5d975685ce..9cdf34e313 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -655,6 +655,7 @@ def test_complex(self): dx = grid.spacing_map[x.spacing] xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) npres = xx + 1j*yy + np.exp(1j + dx) + print(op) assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) From 470f4f50cff87f4091c0360e7b414c00e9c2574c Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 28 May 2024 13:00:56 -0400 Subject: [PATCH 05/29] compiler: generate std:complex for cpp compilers --- devito/ir/iet/visitors.py | 43 +++++++++++++++++++++++---------- devito/passes/iet/misc.py | 4 +-- devito/symbolics/printer.py | 8 ++++++ devito/tools/dtypes_lowering.py | 7 ++---- tests/test_gpu_common.py | 3 ++- tests/test_operator.py | 2 +- 6 files changed, 45 insertions(+), 22 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 54e9188e1a..69c99a5161 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -10,6 +10,7 @@ import ctypes import cgen as c +import numpy as np from sympy import IndexedBase from sympy.core.function import Application @@ -188,6 +189,21 @@ def __init__(self, *args, compiler=None, **kwargs): } _restrict_keyword = 'restrict' + def _complex_type(self, ctypestr, dtype): + # Not complex + 
try: + if not np.issubdtype(dtype, np.complexfloating): + return ctypestr + except TypeError: + return ctypestr + # Complex only supported for float and double + if ctypestr not in ('float', 'double'): + return ctypestr + if self._compiler._cpp: + return 'std::complex<%s>' % ctypestr + else: + return '%s _Complex' % ctypestr + def _gen_struct_decl(self, obj, masked=()): """ Convert ctypes.Struct -> cgen.Structure. @@ -243,10 +259,10 @@ def _gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = obj._C_typedata + strtype = self._complex_type(obj._C_typedata, obj.dtype) strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) else: - strtype = ctypes_to_cstr(obj._C_ctype) + strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype), obj.dtype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -376,10 +392,11 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed + cstr = self._complex_type(i._C_typedata, i.dtype) if f.is_PointerArray: # lvalue - lvalue = c.Value(i._C_typedata, '**%s' % f.name) + lvalue = c.Value(cstr, '**%s' % f.name) # rvalue if isinstance(o.obj, ArrayObject): @@ -388,7 +405,7 @@ def visit_PointerCast(self, o): v = f._C_name else: assert False - rvalue = '(%s**) %s' % (i._C_typedata, v) + rvalue = '(%s**) %s' % (cstr, v) else: # lvalue @@ -399,10 +416,10 @@ def visit_PointerCast(self, o): if o.flat is None: shape = ''.join("[%s]" % ccode(i) for i in o.castshape) rshape = '(*)%s' % shape - lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape)) + lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape)) else: rshape = '*' - lvalue = c.Value(i._C_typedata, '*%s' % v) + lvalue = c.Value(cstr, '*%s' % v) if o.alignment and f._data_alignment: lvalue = c.AlignedAttribute(f._data_alignment, lvalue) @@ -415,14 +432,14 @@ def 
visit_PointerCast(self, o): else: assert False - rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v) + rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v) else: if isinstance(o.obj, Pointer): v = o.obj.name else: v = f._C_name - rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v) + rvalue = '(%s %s) %s' % (cstr, rshape, v) return c.Initializer(lvalue, rvalue) @@ -430,15 +447,15 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed + cstr = self._complex_type(i._C_typedata, i.dtype) if o.flat is None: shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) - rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name, + rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, a1.dim.name) - lvalue = c.Value(i._C_typedata, - '(*restrict %s)%s' % (a0.name, shape)) + lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape)) else: - rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name) - lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name) + rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name) + lvalue = c.Value(cstr, '*restrict %s' % a0.name) if a0._data_alignment: lvalue = c.AlignedAttribute(a0._data_alignment, lvalue) else: diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 7211c5877a..1eac0664d4 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -248,12 +248,12 @@ def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ - lib = _complex_lib.get(language, 'complex.h') + lib = _complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h') headers = {} # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise if compiler._cpp: - headers = {('_Complex_I', ('1.0fi'))} + headers = {('_Complex_I', ('std::complex(0.0f, 1.0f)'))} for f in FindSymbols().visit(iet): try: diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 
7de815549e..fd15796bf8 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -264,8 +264,16 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): func_name = str(expr.func) +<<<<<<< HEAD if self.single_prec(): +======= + dtype = self.dtype + if np.issubdtype(dtype, np.complexfloating): + func_name = 'c%s' % func_name + dtype = self.dtype(0).real.dtype.type + if dtype == np.float32: +>>>>>>> 75d50a431 (compiler: generate std:complex for cpp compilers) func_name = '%sf' % func_name if self.complex_prec(): func_name = 'c%s' % func_name diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index ff40f6c7d6..6ca336e305 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -139,10 +139,7 @@ def dtype_to_ctype(dtype): # Complex data if np.issubdtype(dtype, np.complexfloating): rtype = dtype(0).real.__class__ - ctname = '%s _Complex' % dtype_to_cstr(rtype) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return r + return dtype_to_ctype(rtype) try: return ctypes_vector_mapper[dtype] @@ -217,7 +214,7 @@ class c_restrict_void_p(ctypes.c_void_p): # *** ctypes lowering -def ctypes_to_cstr(ctype, toarray=None): +def ctypes_to_cstr(ctype, toarray=None, cpp=False): """Translate ctypes types into C strings.""" if ctype in ctypes_vector_mapper.values(): retval = ctype.__name__ diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index d1af179792..c7bb0c0211 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -2,6 +2,7 @@ import pytest import numpy as np +import sympy import scipy.sparse from conftest import assert_structure @@ -72,7 +73,7 @@ def test_complex(self): # Float32 complex is called complex64 in numpy u = Function(name="u", grid=grid, dtype=np.complex64) - eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) op() diff 
--git a/tests/test_operator.py b/tests/test_operator.py index 9cdf34e313..61b117bcc6 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -646,7 +646,7 @@ def test_complex(self): # Float32 complex is called complex64 in numpy u = Function(name="u", grid=grid, dtype=np.complex64) - eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) op() From 7ffff0a2e8ad831d7d00aa244a8e037a47e3ffea Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 30 May 2024 12:33:30 -0400 Subject: [PATCH 06/29] compiler: add std::complex arithmetic defs for unsupported types --- devito/ir/iet/visitors.py | 3 ++- devito/passes/iet/misc.py | 33 +++++++++++++++++++++++++++++++-- devito/symbolics/printer.py | 10 +--------- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 69c99a5161..aed6eb1351 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -14,6 +14,7 @@ from sympy import IndexedBase from sympy.core.function import Application +from devito.parameters import configuration from devito.exceptions import VisitorException from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, Call, Lambda, BlankLine, Section, ListMajor) @@ -177,7 +178,7 @@ class CGen(Visitor): def __init__(self, *args, compiler=None, **kwargs): super().__init__(*args, **kwargs) - self._compiler = compiler + self._compiler = compiler or configuration['compiler'] # The following mappers may be customized by subclasses (that is, # backend-specific CGen-erators) diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 1eac0664d4..5b53c43796 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -248,17 +248,26 @@ def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ - lib = _complex_lib.get(language, 'complex' if compiler._cpp else 
'complex.h') + lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) headers = {} + # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise if compiler._cpp: + # Constant I headers = {('_Complex_I', ('std::complex(0.0f, 1.0f)'))} + # Mix arithmetic definitions + dest = compiler.get_jit_dir() + hfile = dest.joinpath('stdcomplex_arith.h') + if not hfile.is_file(): + with open(str(hfile), 'w') as ff: + ff.write(str(_stdcomplex_defs)) + lib += (str(hfile),) for f in FindSymbols().visit(iet): try: if np.issubdtype(f.dtype, np.complexfloating): - return iet, {'includes': (lib,), 'headers': headers} + return iet, {'includes': lib, 'headers': headers} except TypeError: pass @@ -343,3 +352,23 @@ def _rename_subdims(target, dimensions): return {d: d._rebuild(d.root.name) for d in dims if d.root not in dimensions and names.count(d.root.name) < 2} + + +_stdcomplex_defs = """ +#include + +template +std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template +std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() / a, b.imag() / a); +} + +template +std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} +""" diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index fd15796bf8..c7917b3ea1 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -41,7 +41,7 @@ def compiler(self): def single_prec(self, expr=None): dtype = sympy_dtype(expr) if expr is not None else self.dtype - return dtype in [np.float32, np.float16] + return dtype in [np.float32, np.float16, np.complex64] def complex_prec(self, expr=None): dtype = sympy_dtype(expr) if expr is not None else self.dtype @@ -264,16 +264,8 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): 
func_name = str(expr.func) -<<<<<<< HEAD if self.single_prec(): -======= - dtype = self.dtype - if np.issubdtype(dtype, np.complexfloating): - func_name = 'c%s' % func_name - dtype = self.dtype(0).real.dtype.type - if dtype == np.float32: ->>>>>>> 75d50a431 (compiler: generate std:complex for cpp compilers) func_name = '%sf' % func_name if self.complex_prec(): func_name = 'c%s' % func_name From d1dd24e3b50787318a20463d5e7b8b6259bd4cbe Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 30 May 2024 14:08:12 -0400 Subject: [PATCH 07/29] compiler: fix alias dtype with complex numbers --- devito/__init__.py | 5 +++-- devito/arch/compiler.py | 24 +++++++++++++++++++-- devito/ir/iet/visitors.py | 34 +++++++++++++----------------- devito/operator/operator.py | 7 ++++--- devito/passes/iet/misc.py | 37 +++++++++++++++++++++++---------- devito/symbolics/inspection.py | 6 ++++-- devito/tools/dtypes_lowering.py | 12 ++++++++--- tests/test_gpu_common.py | 7 ++++--- tests/test_operator.py | 8 +++---- 9 files changed, 90 insertions(+), 50 deletions(-) diff --git a/devito/__init__.py b/devito/__init__.py index b0a981dcfa..b407988eb8 100644 --- a/devito/__init__.py +++ b/devito/__init__.py @@ -56,7 +56,8 @@ def reinit_compiler(val): """ Re-initialize the Compiler. 
""" - configuration['compiler'].__init__(suffix=configuration['compiler'].suffix, + configuration['compiler'].__init__(name=configuration['compiler'].name, + suffix=configuration['compiler'].suffix, mpi=configuration['mpi']) return val @@ -65,7 +66,7 @@ def reinit_compiler(val): configuration.add('platform', 'cpu64', list(platform_registry), callback=lambda i: platform_registry[i]()) configuration.add('compiler', 'custom', list(compiler_registry), - callback=lambda i: compiler_registry[i]()) + callback=lambda i: compiler_registry[i](name=i)) # Setup language for shared-memory parallelism preprocessor = lambda i: {0: 'C', 1: 'openmp'}.get(i, i) # Handles DEVITO_OPENMP deprec diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 9cd94ed597..de4711d257 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -180,6 +180,8 @@ def __init__(self): _cpp = False def __init__(self, **kwargs): + self._name = kwargs.pop('name', self.__class__.__name__) + super().__init__(**kwargs) self.__lookup_cmds__() @@ -223,13 +225,13 @@ def __new_with__(self, **kwargs): Create a new Compiler from an existing one, inherenting from it the flags that are not specified via ``kwargs``. """ - return self.__class__(suffix=kwargs.pop('suffix', self.suffix), + return self.__class__(name=self.name, suffix=kwargs.pop('suffix', self.suffix), mpi=kwargs.pop('mpi', configuration['mpi']), **kwargs) @property def name(self): - return self.__class__.__name__ + return self._name @property def version(self): @@ -245,6 +247,20 @@ def version(self): return version + @property + def _complex_ctype(self): + """ + Type definition for complex numbers. 
THese two cases cover 99% of the cases since + - Hip is now using std::complex +https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings + - Sycl supports std::complex + - C's _Complex is part of C99 + """ + if self._cpp: + return lambda dtype: 'std::complex<%s>' % str(dtype) + else: + return lambda dtype: '%s _Complex' % str(dtype) + def get_version(self): result, stdout, stderr = call_capture_output((self.cc, "--version")) if result != 0: @@ -699,6 +715,10 @@ def __lookup_cmds__(self): self.MPICC = 'nvcc' self.MPICXX = 'nvcc' + @property + def _complex_ctype(self): + return lambda dtype: 'thrust::complex<%s>' % str(dtype) + class HipCompiler(Compiler): diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index aed6eb1351..1e21f1d8ba 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -10,11 +10,10 @@ import ctypes import cgen as c -import numpy as np from sympy import IndexedBase from sympy.core.function import Application -from devito.parameters import configuration +from devito.parameters import configuration, switchconfig from devito.exceptions import VisitorException from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, Call, Lambda, BlankLine, Section, ListMajor) @@ -190,20 +189,15 @@ def __init__(self, *args, compiler=None, **kwargs): } _restrict_keyword = 'restrict' - def _complex_type(self, ctypestr, dtype): - # Not complex - try: - if not np.issubdtype(dtype, np.complexfloating): - return ctypestr - except TypeError: - return ctypestr - # Complex only supported for float and double - if ctypestr not in ('float', 'double'): - return ctypestr - if self._compiler._cpp: - return 'std::complex<%s>' % ctypestr - else: - return '%s _Complex' % ctypestr + @property + def compiler(self): + return self._compiler + + def visit(self, o, *args, **kwargs): + # Make sure the visitor always is within the generating compiler + # in case the configuration is accessed + with 
switchconfig(compiler=self.compiler.name): + return super().visit(o, *args, **kwargs) def _gen_struct_decl(self, obj, masked=()): """ @@ -260,10 +254,10 @@ def _gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = self._complex_type(obj._C_typedata, obj.dtype) + strtype = obj._C_typedata strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) else: - strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype), obj.dtype) + strtype = ctypes_to_cstr(obj._C_ctype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -393,7 +387,7 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed - cstr = self._complex_type(i._C_typedata, i.dtype) + cstr = i._C_typedata if f.is_PointerArray: # lvalue @@ -448,7 +442,7 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed - cstr = self._complex_type(i._C_typedata, i.dtype) + cstr = i._C_typedata if o.flat is None: shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 266aacd81c..57ef1134ae 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -470,8 +470,8 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) - # Complex header if needed. Needs to be done specialization - # as some specific cases requires complex to be loaded first + # Complex header if needed. 
Needs to be done before specialization + # as some specific cases require complex to be loaded first complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler']) # Specialize @@ -1353,7 +1353,8 @@ def parse_kwargs(**kwargs): raise InvalidOperator("Illegal `compiler=%s`" % str(compiler)) kwargs['compiler'] = compiler_registry[compiler](platform=kwargs['platform'], language=kwargs['language'], - mpi=configuration['mpi']) + mpi=configuration['mpi'], + name=compiler) elif any([platform, language]): kwargs['compiler'] =\ configuration['compiler'].__new_with__(platform=kwargs['platform'], diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 5b53c43796..53ebe7d3e8 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -12,7 +12,7 @@ from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper, ccode) -from devito.tools import Bunch, as_mapper, filter_ordered, split +from devito.tools import Bunch, as_mapper, filter_ordered, split, dtype_to_cstr from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', @@ -240,7 +240,7 @@ def minimize_symbols(iet): return iet, {} -_complex_lib = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'} +_complex_lib = {'cuda': 'thrust/complex.h'} @iet_pass @@ -248,14 +248,20 @@ def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ + # Check if there is complex numbers that always take dtype precedence + max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)]) + if not np.issubdtype(max_dtype, np.complexfloating): + return iet, {} + lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) headers = {} # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise if compiler._cpp: + c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type) # Constant I - headers = 
{('_Complex_I', ('std::complex(0.0f, 1.0f)'))} + headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))} # Mix arithmetic definitions dest = compiler.get_jit_dir() hfile = dest.joinpath('stdcomplex_arith.h') @@ -264,14 +270,7 @@ def complex_include(iet, language, compiler): ff.write(str(_stdcomplex_defs)) lib += (str(hfile),) - for f in FindSymbols().visit(iet): - try: - if np.issubdtype(f.dtype, np.complexfloating): - return iet, {'includes': lib, 'headers': headers} - except TypeError: - pass - - return iet, {} + return iet, {'includes': lib, 'headers': headers} def remove_redundant_moddims(iet): @@ -362,8 +361,19 @@ def _rename_subdims(target, dimensions): return std::complex<_Tp>(b.real() * a, b.imag() * a); } +template +std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + template std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ + _Tp denom = b.real() * b.real () + b.imag() * b.imag() + return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); +} + +template +std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){ return std::complex<_Tp>(b.real() / a, b.imag() / a); } @@ -371,4 +381,9 @@ def _rename_subdims(target, dimensions): std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ return std::complex<_Tp>(b.real() + a, b.imag()); } + +template +std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} """ diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 6649fa86bf..53c7b07e39 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -304,10 +304,12 @@ def sympy_dtype(expr, base=None): dtypes.add(i.dtype) except AttributeError: pass + dtype = infer_dtype(dtypes) - # Promote if complex - if expr.has(ImaginaryUnit): + # Promote if we missed complex number, i.e f 
+ I + is_im = np.issubdtype(dtype, np.complexfloating) + if expr.has(ImaginaryUnit) and not is_im: dtype = np.promote_types(dtype, np.complex64).type return dtype diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 6ca336e305..8a30b04cc4 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -139,7 +139,12 @@ def dtype_to_ctype(dtype): # Complex data if np.issubdtype(dtype, np.complexfloating): rtype = dtype(0).real.__class__ - return dtype_to_ctype(rtype) + from devito import configuration + make = configuration['compiler']._complex_ctype + ctname = make(dtype_to_cstr(rtype)) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return r try: return ctypes_vector_mapper[dtype] @@ -214,7 +219,7 @@ class c_restrict_void_p(ctypes.c_void_p): # *** ctypes lowering -def ctypes_to_cstr(ctype, toarray=None, cpp=False): +def ctypes_to_cstr(ctype, toarray=None): """Translate ctypes types into C strings.""" if ctype in ctypes_vector_mapper.values(): retval = ctype.__name__ @@ -308,7 +313,8 @@ def infer_dtype(dtypes): # Resolve the vector types, if any dtypes = {dtypes_vector_mapper.get_base_dtype(i, i) for i in dtypes} - fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating)} + fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating) or + np.issubdtype(i, np.complexfloating)} if len(fdtypes) > 1: return max(fdtypes, key=lambda i: np.dtype(i).itemsize) elif len(fdtypes) == 1: diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index c7bb0c0211..79b6dccb08 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -67,15 +67,16 @@ def test_maxpar_option(self): assert trees[0][0] is trees[1][0] assert trees[0][1] is not trees[1][1] - def test_complex(self): + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) + def test_complex(self, dtype): grid = Grid((5, 5)) x, y = grid.dimensions - # Float32 complex is called complex64 in numpy - u = 
Function(name="u", grid=grid, dtype=np.complex64) + u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) + print(op) op() # Check against numpy diff --git a/tests/test_operator.py b/tests/test_operator.py index 61b117bcc6..c1a8809379 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -640,22 +640,22 @@ def test_tensor(self, func1): op2 = Operator([Eq(f, f.dx) for f in f1.values()]) assert str(op1.ccode) == str(op2.ccode) - def test_complex(self): + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) + def test_complex(self, dtype): grid = Grid((5, 5)) x, y = grid.dimensions - # Float32 complex is called complex64 in numpy - u = Function(name="u", grid=grid, dtype=np.complex64) + u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) + # print(op) op() # Check against numpy dx = grid.spacing_map[x.spacing] xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) npres = xx + 1j*yy + np.exp(1j + dx) - print(op) assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) From 4f43f26d270d3bf2dc5815f796a47f1dc4f2475c Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 31 May 2024 09:58:54 -0400 Subject: [PATCH 08/29] compiler: fix internal language specific types and cast wip --- devito/arch/compiler.py | 3 +- devito/operator/operator.py | 2 +- devito/passes/iet/__init__.py | 1 + devito/passes/iet/misc.py | 71 +----------------------------- devito/symbolics/extended_sympy.py | 29 +++++++++++- tests/test_gpu_common.py | 2 - tests/test_operator.py | 2 - 7 files changed, 33 insertions(+), 77 deletions(-) diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index de4711d257..a7d05259e5 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -250,7 +250,7 @@ def version(self): @property def _complex_ctype(self): """ - Type definition 
for complex numbers. THese two cases cover 99% of the cases since + Type definition for complex numbers. These two cases cover 99% of the cases since - Hip is now using std::complex https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings - Sycl supports std::complex @@ -998,6 +998,7 @@ def __new_with__(self, **kwargs): 'nvc++': NvidiaCompiler, 'nvidia': NvidiaCompiler, 'cuda': CudaCompiler, + 'nvcc': CudaCompiler, 'osx': ClangCompiler, 'intel': OneapiCompiler, 'icx': OneapiCompiler, diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 57ef1134ae..efb47640ce 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -472,7 +472,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Complex header if needed. Needs to be done before specialization # as some specific cases require complex to be loaded first - complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler']) + include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler']) # Specialize graph = cls._specialize_iet(graph, **kwargs) diff --git a/devito/passes/iet/__init__.py b/devito/passes/iet/__init__.py index c09db00c9b..6b4ada0b73 100644 --- a/devito/passes/iet/__init__.py +++ b/devito/passes/iet/__init__.py @@ -8,3 +8,4 @@ from .instrument import * # noqa from .languages import * # noqa from .errors import * # noqa +from .complex import * # noqa diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 53ebe7d3e8..50511b6005 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -16,7 +16,7 @@ from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', - 'generate_macros', 'minimize_symbols', 'complex_include'] + 'generate_macros', 'minimize_symbols'] @iet_pass @@ -240,39 +240,6 @@ def minimize_symbols(iet): return iet, {} -_complex_lib = {'cuda': 'thrust/complex.h'} - - -@iet_pass -def complex_include(iet, 
language, compiler): - """ - Add headers for complex arithmetic - """ - # Check if there is complex numbers that always take dtype precedence - max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)]) - if not np.issubdtype(max_dtype, np.complexfloating): - return iet, {} - - lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) - - headers = {} - - # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise - if compiler._cpp: - c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type) - # Constant I - headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))} - # Mix arithmetic definitions - dest = compiler.get_jit_dir() - hfile = dest.joinpath('stdcomplex_arith.h') - if not hfile.is_file(): - with open(str(hfile), 'w') as ff: - ff.write(str(_stdcomplex_defs)) - lib += (str(hfile),) - - return iet, {'includes': lib, 'headers': headers} - - def remove_redundant_moddims(iet): key = lambda d: d.is_Modulo and d.origin is not None mds = [d for d in FindSymbols('dimensions').visit(iet) if key(d)] @@ -351,39 +318,3 @@ def _rename_subdims(target, dimensions): return {d: d._rebuild(d.root.name) for d in dims if d.root not in dimensions and names.count(d.root.name) < 2} - - -_stdcomplex_defs = """ -#include - -template -std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ - return std::complex<_Tp>(b.real() * a, b.imag() * a); -} - -template -std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() * a, b.imag() * a); -} - -template -std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ - _Tp denom = b.real() * b.real () + b.imag() * b.imag() - return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); -} - -template -std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() / a, b.imag() / a); -} - -template -std::complex<_Tp> operator + 
(const _Ti & a, const std::complex<_Tp> & b){ - return std::complex<_Tp>(b.real() + a, b.imag()); -} - -template -std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() + a, b.imag()); -} -""" diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 7ed801d17a..03fec7438a 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -7,6 +7,7 @@ from sympy import Expr, Function, Number, Tuple, sympify from sympy.core.decorators import call_highest_priority +from devito import configuration from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, int2, int3, @@ -811,6 +812,20 @@ class VOID(Cast): _base_typ = 'void' +class CFLOAT(Cast): + + @property + def _base_typ(self): + return configuration['compiler']._complex_ctype('float') + + +class CDOUBLE(Cast): + + @property + def _base_typ(self): + return configuration['compiler']._complex_ctype('double') + + class CHARP(CastStar): base = CHAR @@ -827,6 +842,14 @@ class USHORTP(CastStar): base = USHORT +class CFLOATP(CastStar): + base = CFLOAT + + +class CDOUBLEP(CastStar): + base = CDOUBLE + + cast_mapper = { np.int8: CHAR, np.uint8: UCHAR, @@ -839,6 +862,8 @@ class USHORTP(CastStar): np.float32: FLOAT, # noqa float: DOUBLE, # noqa np.float64: DOUBLE, # noqa + np.complex64: CFLOAT, # noqa + np.complex128: CDOUBLE, # noqa (np.int8, '*'): CHARP, (np.uint8, '*'): UCHARP, @@ -849,7 +874,9 @@ class USHORTP(CastStar): (np.int64, '*'): INTP, # noqa (np.float32, '*'): FLOATP, # noqa (float, '*'): DOUBLEP, # noqa - (np.float64, '*'): DOUBLEP # noqa + (np.float64, '*'): DOUBLEP, # noqa + (np.complex64, '*'): CFLOATP, # noqa + (np.complex128, '*'): CDOUBLEP, # noqa } for base_name in ['int', 'float', 'double']: diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 
79b6dccb08..e229cbb98d 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -74,9 +74,7 @@ def test_complex(self, dtype): u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) - # Currently wrong alias type op = Operator(eq) - print(op) op() # Check against numpy diff --git a/tests/test_operator.py b/tests/test_operator.py index c1a8809379..283249aac1 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -647,9 +647,7 @@ def test_complex(self, dtype): u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) - # Currently wrong alias type op = Operator(eq) - # print(op) op() # Check against numpy From 6b4f12d838f847462e630b369887137a16772a5b Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 20 Jun 2024 10:28:38 -0400 Subject: [PATCH 09/29] compiler: rework dtype lowering --- devito/arch/compiler.py | 20 +--- devito/ir/iet/visitors.py | 2 +- devito/operator/operator.py | 4 - devito/passes/iet/__init__.py | 2 +- devito/passes/iet/definitions.py | 12 ++- devito/passes/iet/dtypes.py | 58 ++++++++++++ devito/passes/iet/langbase.py | 11 +++ devito/passes/iet/languages/C.py | 12 ++- devito/passes/iet/languages/CXX.py | 69 ++++++++++++++ devito/passes/iet/languages/openacc.py | 5 +- devito/passes/iet/misc.py | 2 +- devito/symbolics/__init__.py | 1 + devito/symbolics/extended_dtypes.py | 123 ++++++++++++++++++++++++ devito/symbolics/extended_sympy.py | 126 +------------------------ devito/symbolics/inspection.py | 3 +- devito/symbolics/printer.py | 12 ++- devito/tools/dtypes_lowering.py | 24 ++--- devito/types/basic.py | 33 +++++-- devito/types/misc.py | 2 +- 19 files changed, 344 insertions(+), 177 deletions(-) create mode 100644 devito/passes/iet/dtypes.py create mode 100644 devito/passes/iet/languages/CXX.py create mode 100644 devito/symbolics/extended_dtypes.py diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 
a7d05259e5..61cfa22b4c 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -247,20 +247,6 @@ def version(self): return version - @property - def _complex_ctype(self): - """ - Type definition for complex numbers. These two cases cover 99% of the cases since - - Hip is now using std::complex -https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings - - Sycl supports std::complex - - C's _Complex is part of C99 - """ - if self._cpp: - return lambda dtype: 'std::complex<%s>' % str(dtype) - else: - return lambda dtype: '%s _Complex' % str(dtype) - def get_version(self): result, stdout, stderr = call_capture_output((self.cc, "--version")) if result != 0: @@ -609,7 +595,7 @@ def __init_finalize__(self, **kwargs): self.cflags.remove('-O3') self.cflags.remove('-Wall') - self.cflags.append('-std=c++11') + self.cflags.append('-std=c++14') language = kwargs.pop('language', configuration['language']) platform = kwargs.pop('platform', configuration['platform']) @@ -715,10 +701,6 @@ def __lookup_cmds__(self): self.MPICC = 'nvcc' self.MPICXX = 'nvcc' - @property - def _complex_ctype(self): - return lambda dtype: 'thrust::complex<%s>' % str(dtype) - class HipCompiler(Compiler): diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 1e21f1d8ba..6e9879d873 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -602,7 +602,7 @@ def visit_MultiTraversable(self, o): return c.Collection(body) def visit_UsingNamespace(self, o): - return c.Statement('using namespace %s' % ccode(o.namespace)) + return c.Statement('using namespace %s' % str(o.namespace)) def visit_Lambda(self, o): body = [] diff --git a/devito/operator/operator.py b/devito/operator/operator.py index efb47640ce..ba411c0ea6 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -470,10 +470,6 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) - # 
Complex header if needed. Needs to be done before specialization - # as some specific cases require complex to be loaded first - include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler']) - # Specialize graph = cls._specialize_iet(graph, **kwargs) diff --git a/devito/passes/iet/__init__.py b/devito/passes/iet/__init__.py index 6b4ada0b73..1cdb97c794 100644 --- a/devito/passes/iet/__init__.py +++ b/devito/passes/iet/__init__.py @@ -8,4 +8,4 @@ from .instrument import * # noqa from .languages import * # noqa from .errors import * # noqa -from .complex import * # noqa +from .dtypes import * # noqa diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index ca4164d184..81a0168d58 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -12,6 +12,7 @@ from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create +from devito.passes.iet.dtypes import lower_complex from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, @@ -73,10 +74,12 @@ class DataManager: The language used to express data allocations, deletions, and host-device transfers. 
""" - def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs): + def __init__(self, rcompile=None, sregistry=None, platform=None, + compiler=None, **kwargs): self.rcompile = rcompile self.sregistry = sregistry self.platform = platform + self.compiler = compiler def _alloc_object_on_low_lat_mem(self, site, obj, storage): """ @@ -409,12 +412,18 @@ def place_casts(self, iet, **kwargs): return iet, {} + @iet_pass + def make_langtypes(self, iet): + iet, metadata = lower_complex(iet, self.lang, self.compiler) + return iet, metadata + def process(self, graph): """ Apply the `place_definitions` and `place_casts` passes. """ self.place_definitions(graph, globs=set()) self.place_casts(graph) + self.make_langtypes(graph) class DeviceAwareDataManager(DataManager): @@ -564,6 +573,7 @@ def process(self, graph): self.place_devptr(graph) self.place_bundling(graph, writes_input=graph.writes_input) self.place_casts(graph) + self.make_langtypes(graph) def make_zero_init(obj): diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py new file mode 100644 index 0000000000..912f707afd --- /dev/null +++ b/devito/passes/iet/dtypes.py @@ -0,0 +1,58 @@ +import numpy as np +import ctypes + +from devito.ir import FindSymbols, Uxreplace + +__all__ = ['lower_complex'] + + +def lower_complex(iet, lang, compiler): + """ + Add headers for complex arithmetic + """ + # Check if there is complex numbers that always take dtype precedence + types = {f.dtype for f in FindSymbols().visit(iet) + if not issubclass(f.dtype, ctypes._Pointer)} + + if not any(np.issubdtype(d, np.complexfloating) for d in types): + return iet, {} + + lib = (lang['header-complex'],) + + metadata = {} + if lang.get('complex-namespace') is not None: + metadata['namespaces'] = lang['complex-namespace'] + + # Some languges such as c++11 need some extra arithmetic definitions + if lang.get('def-complex'): + dest = compiler.get_jit_dir() + hfile = dest.joinpath('complex_arith.h') + with open(str(hfile), 
'w') as ff: + ff.write(str(lang['def-complex'])) + lib += (str(hfile),) + + iet = _complex_dtypes(iet, lang) + metadata['includes'] = lib + print(metadata) + return iet, metadata + + +def _complex_dtypes(iet, lang): + """ + Lower dtypes to language specific types + """ + mapper = {} + + for s in FindSymbols('indexeds').visit(iet): + if s.dtype in lang['types']: + mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) + + for s in FindSymbols().visit(iet): + if s.dtype in lang['types']: + mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) + + body = Uxreplace(mapper).visit(iet.body) + params = Uxreplace(mapper).visit(iet.parameters) + iet = iet._rebuild(body=body, parameters=params) + + return iet diff --git a/devito/passes/iet/langbase.py b/devito/passes/iet/langbase.py index d27674c419..e34aa2dac3 100644 --- a/devito/passes/iet/langbase.py +++ b/devito/passes/iet/langbase.py @@ -31,6 +31,9 @@ def __getitem__(self, k): raise NotImplementedError("Missing required mapping for `%s`" % k) return self.mapper[k] + def get(self, k): + return self.mapper.get(k) + class LangBB(metaclass=LangMeta): @@ -200,6 +203,14 @@ def initialize(self, iet, options=None): """ return iet, {} + @iet_pass + def make_langtypes(self, iet): + """ + An `iet_pass` which transforms an IET such that the target language + types are introduced. 
+ """ + return iet, {} + @property def Region(self): return self.lang.Region diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 4b3358798d..bd5e0e6413 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,11 +1,18 @@ +import numpy as np + from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB +from devito.tools import CustomNpType __all__ = ['CBB', 'CDataManager', 'COrchestrator'] +CCFloat = CustomNpType('_Complex float', np.complex64) +CCDouble = CustomNpType('_Complex double', np.complex128) + + class CBB(LangBB): mapper = { @@ -19,7 +26,10 @@ class CBB(LangBB): 'host-free-pin': lambda i: Call('free', (i,)), 'alloc-global-symbol': lambda i, j, k: - Call('memcpy', (i, j, k)) + Call('memcpy', (i, j, k)), + # Complex + 'header-complex': 'complex.h', + 'types': {np.complex128: CCDouble, np.complex64: CCFloat}, } diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py new file mode 100644 index 0000000000..9f833d630b --- /dev/null +++ b/devito/passes/iet/languages/CXX.py @@ -0,0 +1,69 @@ +import numpy as np + +from devito.ir import Call, UsingNamespace +from devito.passes.iet.langbase import LangBB +from devito.tools import CustomNpType + +__all__ = ['CXXBB'] + + +std_arith = """ +#include + +template +std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template +std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template +std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ + _Tp denom = b.real() * b.real () + b.imag() * b.imag() + return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); +} + +template +std::complex<_Tp> operator / 
(const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() / a, b.imag() / a); +} + +template +std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} + +template +std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} + +""" + +CXXCFloat = CustomNpType('std::complex', np.complex64, template='float') +CXXCDouble = CustomNpType('std::complex', np.complex128, template='double') + + +class CXXBB(LangBB): + + mapper = { + 'header-memcpy': 'string.h', + 'host-alloc': lambda i, j, k: + Call('posix_memalign', (i, j, k)), + 'host-alloc-pin': lambda i, j, k: + Call('posix_memalign', (i, j, k)), + 'host-free': lambda i: + Call('free', (i,)), + 'host-free-pin': lambda i: + Call('free', (i,)), + 'alloc-global-symbol': lambda i, j, k: + Call('memcpy', (i, j, k)), + # Complex + 'header-complex': 'complex', + 'complex-namespace': [UsingNamespace('std:complex_literals')], + 'def-complex': std_arith, + 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat}, + } diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index bcd2c8d006..bcf5660ac7 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -9,7 +9,7 @@ from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB, PragmaIteration, PragmaTransfer) -from devito.passes.iet.languages.C import CBB +from devito.passes.iet.languages.CXX import CXXBB from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration from devito.symbolics import FieldFromPointer, Macro, cast_mapper from devito.tools import filter_ordered, UnboundTuple @@ -122,7 +122,8 @@ class AccBB(PragmaLangBB): 'device-free': lambda i, *a: Call('acc_free', (i,)) } - mapper.update(CBB.mapper) + + mapper.update(CXXBB.mapper) 
Region = OmpRegion HostIteration = OmpIteration # Host parallelism still goes via OpenMP diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 50511b6005..f0b2b7f4f5 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -12,7 +12,7 @@ from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper, ccode) -from devito.tools import Bunch, as_mapper, filter_ordered, split, dtype_to_cstr +from devito.tools import Bunch, as_mapper, filter_ordered, split from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', diff --git a/devito/symbolics/__init__.py b/devito/symbolics/__init__.py index 0f5c261471..9d7bee01b8 100644 --- a/devito/symbolics/__init__.py +++ b/devito/symbolics/__init__.py @@ -1,4 +1,5 @@ from devito.symbolics.extended_sympy import * # noqa +from devito.symbolics.extended_dtypes import * # noqa from devito.symbolics.queries import * # noqa from devito.symbolics.search import * # noqa from devito.symbolics.printer import * # noqa diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py new file mode 100644 index 0000000000..c558eb4e18 --- /dev/null +++ b/devito/symbolics/extended_dtypes.py @@ -0,0 +1,123 @@ +import numpy as np + +from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit +from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa + int2, int3, int4) + +__all__ = ['cast_mapper', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa + + +limits_mapper = { + np.int32: Bunch(min=ValueLimit('INT_MIN'), max=ValueLimit('INT_MAX')), + np.int64: Bunch(min=ValueLimit('LONG_MIN'), max=ValueLimit('LONG_MAX')), + np.float32: Bunch(min=-ValueLimit('FLT_MAX'), max=ValueLimit('FLT_MAX')), + np.float64: Bunch(min=-ValueLimit('DBL_MAX'), max=ValueLimit('DBL_MAX')), +} + + +class 
CustomType(ReservedWord): + pass + + +# Dynamically create INT, INT2, .... INTP, INT2P, ... FLOAT, ... +for base_name in ['int', 'float', 'double']: + for i in ['', '2', '3', '4']: + v = '%s%s' % (base_name, i) + cls = type(v.upper(), (Cast,), {'_base_typ': v}) + globals()[cls.__name__] = cls + + clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls}) + globals()[clsp.__name__] = clsp + + +class CHAR(Cast): + _base_typ = 'char' + + +class SHORT(Cast): + _base_typ = 'short' + + +class USHORT(Cast): + _base_typ = 'unsigned short' + + +class UCHAR(Cast): + _base_typ = 'unsigned char' + + +class LONG(Cast): + _base_typ = 'long' + + +class ULONG(Cast): + _base_typ = 'unsigned long' + + +class CFLOAT(Cast): + _base_typ = 'float' + + +class CDOUBLE(Cast): + _base_typ = 'double' + + +class VOID(Cast): + _base_typ = 'void' + + +class CHARP(CastStar): + base = CHAR + + +class UCHARP(CastStar): + base = UCHAR + + +class SHORTP(CastStar): + base = SHORT + + +class USHORTP(CastStar): + base = USHORT + + +class CFLOATP(CastStar): + base = CFLOAT + + +class CDOUBLEP(CastStar): + base = CDOUBLE + + +cast_mapper = { + np.int8: CHAR, + np.uint8: UCHAR, + np.int16: SHORT, # noqa + np.uint16: USHORT, # noqa + int: INT, # noqa + np.int32: INT, # noqa + np.int64: LONG, + np.uint64: ULONG, + np.float32: FLOAT, # noqa + float: DOUBLE, # noqa + np.float64: DOUBLE, # noqa + + (np.int8, '*'): CHARP, + (np.uint8, '*'): UCHARP, + (int, '*'): INTP, # noqa + (np.uint16, '*'): USHORTP, # noqa + (np.int16, '*'): SHORTP, # noqa + (np.int32, '*'): INTP, # noqa + (np.int64, '*'): INTP, # noqa + (np.float32, '*'): FLOATP, # noqa + (float, '*'): DOUBLEP, # noqa + (np.float64, '*'): DOUBLEP, # noqa +} + +for base_name in ['int', 'float', 'double']: + for i in [2, 3, 4]: + v = '%s%d' % (base_name, i) + cls = locals()[v] + cast_mapper[cls] = locals()[v.upper()] + cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()] diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py 
index 03fec7438a..b386a68a79 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -7,7 +7,6 @@ from sympy import Expr, Function, Number, Tuple, sympify from sympy.core.decorators import call_highest_priority -from devito import configuration from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, int2, int3, @@ -20,8 +19,7 @@ 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', 'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref', 'Namespace', - 'Rvalue', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null', 'SizeOf', 'rfunc', - 'cast_mapper', 'BasicWrapperMixin', 'ValueLimit', 'limits_mapper'] + 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit'] class CondEq(sympy.Eq): @@ -548,14 +546,6 @@ class ValueLimit(ReservedWord, sympy.Expr): pass -limits_mapper = { - np.int32: Bunch(min=ValueLimit('INT_MIN'), max=ValueLimit('INT_MAX')), - np.int64: Bunch(min=ValueLimit('LONG_MIN'), max=ValueLimit('LONG_MAX')), - np.float32: Bunch(min=-ValueLimit('FLT_MAX'), max=ValueLimit('FLT_MAX')), - np.float64: Bunch(min=-ValueLimit('DBL_MAX'), max=ValueLimit('DBL_MAX')), -} - - class DefFunction(Function, Pickable): """ @@ -773,120 +763,6 @@ def __new__(cls, base=''): return cls.base(base, '*') -# Dynamically create INT, INT2, .... INTP, INT2P, ... FLOAT, ... 
-for base_name in ['int', 'float', 'double']: - for i in ['', '2', '3', '4']: - v = '%s%s' % (base_name, i) - cls = type(v.upper(), (Cast,), {'_base_typ': v}) - globals()[cls.__name__] = cls - - clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls}) - globals()[clsp.__name__] = clsp - - -class CHAR(Cast): - _base_typ = 'char' - - -class SHORT(Cast): - _base_typ = 'short' - - -class USHORT(Cast): - _base_typ = 'unsigned short' - - -class UCHAR(Cast): - _base_typ = 'unsigned char' - - -class LONG(Cast): - _base_typ = 'long' - - -class ULONG(Cast): - _base_typ = 'unsigned long' - - -class VOID(Cast): - _base_typ = 'void' - - -class CFLOAT(Cast): - - @property - def _base_typ(self): - return configuration['compiler']._complex_ctype('float') - - -class CDOUBLE(Cast): - - @property - def _base_typ(self): - return configuration['compiler']._complex_ctype('double') - - -class CHARP(CastStar): - base = CHAR - - -class UCHARP(CastStar): - base = UCHAR - - -class SHORTP(CastStar): - base = SHORT - - -class USHORTP(CastStar): - base = USHORT - - -class CFLOATP(CastStar): - base = CFLOAT - - -class CDOUBLEP(CastStar): - base = CDOUBLE - - -cast_mapper = { - np.int8: CHAR, - np.uint8: UCHAR, - np.int16: SHORT, # noqa - np.uint16: USHORT, # noqa - int: INT, # noqa - np.int32: INT, # noqa - np.int64: LONG, - np.uint64: ULONG, - np.float32: FLOAT, # noqa - float: DOUBLE, # noqa - np.float64: DOUBLE, # noqa - np.complex64: CFLOAT, # noqa - np.complex128: CDOUBLE, # noqa - - (np.int8, '*'): CHARP, - (np.uint8, '*'): UCHARP, - (int, '*'): INTP, # noqa - (np.uint16, '*'): USHORTP, # noqa - (np.int16, '*'): SHORTP, # noqa - (np.int32, '*'): INTP, # noqa - (np.int64, '*'): INTP, # noqa - (np.float32, '*'): FLOATP, # noqa - (float, '*'): DOUBLEP, # noqa - (np.float64, '*'): DOUBLEP, # noqa - (np.complex64, '*'): CFLOATP, # noqa - (np.complex128, '*'): CDOUBLEP, # noqa -} - -for base_name in ['int', 'float', 'double']: - for i in [2, 3, 4]: - v = '%s%d' % (base_name, i) - cls = 
locals()[v] - cast_mapper[cls] = locals()[v.upper()] - cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()] - - # Some other utility objects Null = Macro('NULL') diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 53c7b07e39..11b95a16d3 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -8,7 +8,8 @@ from devito.finite_differences import Derivative from devito.finite_differences.differentiable import IndexDerivative from devito.logger import warning -from devito.symbolics.extended_sympy import (INT, CallFromPointer, Cast, +from devito.symbolics.extended_dtypes import INT +from devito.symbolics.extended_sympy import (CallFromPointer, Cast, DefFunction, ReservedWord) from devito.symbolics.queries import q_routine from devito.tools import as_tuple, prod diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index c7917b3ea1..fc180300a3 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -11,6 +11,7 @@ from sympy.printing.precedence import PRECEDENCE_VALUES, precedence from sympy.printing.c import C99CodePrinter +from devito import configuration from devito.arch.compiler import AOMPCompiler from devito.symbolics.inspection import has_integer_args, sympy_dtype from devito.types.basic import AbstractFunction @@ -37,13 +38,17 @@ def dtype(self): @property def compiler(self): - return self._settings['compiler'] + return self._settings['compiler'] or configuration['compiler'] def single_prec(self, expr=None): + if self.compiler._cpp and expr is not None: + return False dtype = sympy_dtype(expr) if expr is not None else self.dtype return dtype in [np.float32, np.float16, np.complex64] def complex_prec(self, expr=None): + if self.compiler._cpp: + return False dtype = sympy_dtype(expr) if expr is not None else self.dtype return np.issubdtype(dtype, np.complexfloating) @@ -211,7 +216,10 @@ def _print_Float(self, expr): return rv def _print_ImaginaryUnit(self, expr): - 
return '_Complex_I' + if self.compiler._cpp: + return '1i' + else: + return '_Complex_I' def _print_Differentiable(self, expr): return "(%s)" % self._print(expr._expr) diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 8a30b04cc4..3d04f73e84 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -13,7 +13,7 @@ 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', - 'is_external_ctype', 'infer_dtype', 'CustomDtype'] + 'is_external_ctype', 'infer_dtype', 'CustomDtype', 'CustomNpType'] # *** Custom np.dtypes @@ -123,6 +123,18 @@ def __repr__(self): __str__ = __repr__ +class CustomNpType(CustomDtype): + """ + Custom dtype for underlying numpy type. + """ + + def __init__(self, name, nptype, template=None, modifier=None): + self.nptype = nptype + super().__init__(name, template, modifier) + + def __call__(self, val): + return self.nptype(val) + # *** np.dtypes lowering @@ -136,16 +148,6 @@ def dtype_to_ctype(dtype): if isinstance(dtype, CustomDtype): return dtype - # Complex data - if np.issubdtype(dtype, np.complexfloating): - rtype = dtype(0).real.__class__ - from devito import configuration - make = configuration['compiler']._complex_ctype - ctname = make(dtype_to_cstr(rtype)) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return r - try: return ctypes_vector_mapper[dtype] except KeyError: diff --git a/devito/types/basic.py b/devito/types/basic.py index e21bae6453..15cf7ab1b8 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -13,7 +13,8 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex) + frozendict, memoized_meth, sympy_mutex, CustomDtype, + Reconstructable) from 
devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -83,6 +84,9 @@ def _C_typedata(self): The type of the object in the generated code as a `str`. """ _type = self._C_ctype + if isinstance(_type, CustomDtype): + return _type + while issubclass(_type, _Pointer): _type = _type._type_ @@ -859,6 +863,7 @@ def __new__(cls, *args, **kwargs): name = kwargs.get('name') alias = kwargs.get('alias') function = kwargs.get('function') + dtype = kwargs.get('dtype') if alias or (function and function.name != name): function = kwargs['function'] = None @@ -866,7 +871,8 @@ def __new__(cls, *args, **kwargs): # definitely a reconstruction if function is not None and \ function.name == name and \ - function.indices == indices: + function.indices == indices and \ + function.dtype == dtype: # Special case: a syntactically identical alias of `function`, so # let's just return `function` itself return function @@ -1188,7 +1194,8 @@ def bound_symbols(self): @cached_property def indexed(self): """The wrapped IndexedData object.""" - return IndexedData(self.name, shape=self._shape, function=self.function) + return IndexedData(self.name, shape=self._shape, function=self.function, + dtype=self.dtype) @cached_property def dmap(self): @@ -1445,13 +1452,14 @@ class IndexedBase(sympy.IndexedBase, Basic, Pickable): __rargs__ = ('label', 'shape') __rkwargs__ = ('function',) - def __new__(cls, label, shape, function=None): + def __new__(cls, label, shape, function=None, dtype=None): # Make sure `label` is a devito.Symbol, not a sympy.Symbol if isinstance(label, str): label = Symbol(name=label, dtype=None) with sympy_mutex: obj = sympy.IndexedBase.__new__(cls, label, shape) obj.function = function + obj._dtype = dtype or function.dtype return obj func = Pickable._rebuild @@ -1485,7 +1493,7 @@ def indices(self): @property def dtype(self): - return self.function.dtype + return self._dtype @cached_property def 
free_symbols(self): @@ -1547,7 +1555,7 @@ def _C_ctype(self): return self.function._C_ctype -class Indexed(sympy.Indexed): +class Indexed(sympy.Indexed, Reconstructable): # The two type flags have changed in upstream sympy as of version 1.1, # but the below interpretation is used throughout the compiler to @@ -1559,6 +1567,17 @@ class Indexed(sympy.Indexed): is_Dimension = False + __rargs__ = ('base', 'indices') + __rkwargs__ = ('dtype',) + + def __new__(cls, base, *indices, dtype=None, **kwargs): + if len(indices) == 1: + indices = as_tuple(indices[0]) + newobj = sympy.Indexed.__new__(cls, base, *indices) + newobj._dtype = dtype or base.dtype + + return newobj + @memoized_meth def __str__(self): return super().__str__() @@ -1580,7 +1599,7 @@ def function(self): @property def dtype(self): - return self.function.dtype + return self._dtype @property def name(self): diff --git a/devito/types/misc.py b/devito/types/misc.py index 72f1ab895a..b8f68e39c1 100644 --- a/devito/types/misc.py +++ b/devito/types/misc.py @@ -79,7 +79,7 @@ class FIndexed(Indexed, Pickable): __rkwargs__ = ('strides_map', 'accessor') def __new__(cls, base, *args, strides_map=None, accessor=None): - obj = super().__new__(cls, base, *args) + obj = super().__new__(cls, base, args) obj.strides_map = frozendict(strides_map or {}) obj.accessor = accessor From 9abeea84145b24f3bd634b70e905ef1c3330a98c Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 27 Jun 2024 07:59:59 -0400 Subject: [PATCH 10/29] compiler: switch to c++14 for complex_literals --- devito/passes/iet/dtypes.py | 2 +- devito/passes/iet/languages/CXX.py | 2 +- devito/symbolics/extended_dtypes.py | 2 +- devito/symbolics/extended_sympy.py | 6 +----- devito/symbolics/printer.py | 15 +++++++++++---- tests/test_gpu_common.py | 2 +- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 912f707afd..1932b60f3a 100644 --- a/devito/passes/iet/dtypes.py +++ 
b/devito/passes/iet/dtypes.py @@ -33,7 +33,7 @@ def lower_complex(iet, lang, compiler): iet = _complex_dtypes(iet, lang) metadata['includes'] = lib - print(metadata) + return iet, metadata diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 9f833d630b..5f74070472 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -63,7 +63,7 @@ class CXXBB(LangBB): Call('memcpy', (i, j, k)), # Complex 'header-complex': 'complex', - 'complex-namespace': [UsingNamespace('std:complex_literals')], + 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat}, } diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index c558eb4e18..0e8ce0cc98 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -4,7 +4,7 @@ from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa int2, int3, int4) -__all__ = ['cast_mapper', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa +__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa limits_mapper = { diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index b386a68a79..19fcd83d4e 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -18,7 +18,7 @@ 'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', - 'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref', 'Namespace', + 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace', 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit'] @@ -508,10 +508,6 @@ class Keyword(ReservedWord): pass -class CustomType(ReservedWord): - pass - - class 
String(ReservedWord): pass diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index fc180300a3..c9c73ed0b4 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -34,14 +34,18 @@ class CodePrinter(C99CodePrinter): @property def dtype(self): - return self._settings['dtype'] + try: + return self._settings['dtype'].nptype + except AttributeError: + return self._settings['dtype'] @property def compiler(self): return self._settings['compiler'] or configuration['compiler'] - def single_prec(self, expr=None): - if self.compiler._cpp and expr is not None: + def single_prec(self, expr=None, with_f=False): + no_f = self.compiler._cpp and not with_f + if no_f and expr is not None: return False dtype = sympy_dtype(expr) if expr is not None else self.dtype return dtype in [np.float32, np.float16, np.complex64] @@ -217,7 +221,10 @@ def _print_Float(self, expr): def _print_ImaginaryUnit(self, expr): if self.compiler._cpp: - return '1i' + if self.single_prec(with_f=True): + return '1if' + else: + return '1i' else: return '_Complex_I' diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index e229cbb98d..2e26b78c22 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -82,7 +82,7 @@ def test_complex(self, dtype): xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) npres = xx + 1j*yy + np.exp(1j + dx) - assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) + assert np.allclose(u.data, npres.T, rtol=1e-6, atol=0) class TestPassesOptional: From 94d5571c3c100316c6444de91552e1b688834b7a Mon Sep 17 00:00:00 2001 From: mloubout Date: Mon, 8 Jul 2024 12:47:53 -0400 Subject: [PATCH 11/29] compiler: subdtype numpy for dtype lowering --- devito/passes/iet/dtypes.py | 6 +----- devito/passes/iet/languages/C.py | 19 ++++++++++++++++--- devito/passes/iet/languages/CXX.py | 20 +++++++++++++++++--- devito/symbolics/printer.py | 2 +- devito/tools/dtypes_lowering.py | 14 +------------- 5 files changed, 36 
insertions(+), 25 deletions(-) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 1932b60f3a..57eb10c4d8 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -43,11 +43,7 @@ def _complex_dtypes(iet, lang): """ mapper = {} - for s in FindSymbols('indexeds').visit(iet): - if s.dtype in lang['types']: - mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) - - for s in FindSymbols().visit(iet): + for s in FindSymbols('indexeds|basics|symbolics').visit(iet): if s.dtype in lang['types']: mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index bd5e0e6413..2cee279428 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,16 +1,29 @@ +import ctypes as ct import numpy as np from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.tools import CustomNpType +from devito.tools.dtypes_lowering import ctypes_vector_mapper + __all__ = ['CBB', 'CDataManager', 'COrchestrator'] -CCFloat = CustomNpType('_Complex float', np.complex64) -CCDouble = CustomNpType('_Complex double', np.complex128) +class CCFloat(np.complex64): + pass + + +class CCDouble(np.complex128): + pass + + +c_complex = type('_Complex float', (ct.c_double,), {}) +c_double_complex = type('_Complex double', (ct.c_longdouble,), {}) + +ctypes_vector_mapper[CCFloat] = c_complex +ctypes_vector_mapper[CCDouble] = c_double_complex class CBB(LangBB): diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 5f74070472..fb802acb8b 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,8 +1,9 @@ +import ctypes as ct import numpy as np from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.tools import 
CustomNpType +from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] @@ -43,8 +44,21 @@ """ -CXXCFloat = CustomNpType('std::complex', np.complex64, template='float') -CXXCDouble = CustomNpType('std::complex', np.complex128, template='double') + +class CXXCFloat(np.complex64): + pass + + +class CXXCDouble(np.complex128): + pass + + +cxx_complex = type('std::complex', (ct.c_double,), {}) +cxx_double_complex = type('std::complex', (ct.c_longdouble,), {}) + + +ctypes_vector_mapper[CXXCFloat] = cxx_complex +ctypes_vector_mapper[CXXCDouble] = cxx_double_complex class CXXBB(LangBB): diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index c9c73ed0b4..77bc407dd6 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -48,7 +48,7 @@ def single_prec(self, expr=None, with_f=False): if no_f and expr is not None: return False dtype = sympy_dtype(expr) if expr is not None else self.dtype - return dtype in [np.float32, np.float16, np.complex64] + return any(issubclass(dtype, d) for d in [np.float32, np.float16, np.complex64]) def complex_prec(self, expr=None): if self.compiler._cpp: diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 3d04f73e84..43def2d8cd 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -13,7 +13,7 @@ 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', - 'is_external_ctype', 'infer_dtype', 'CustomDtype', 'CustomNpType'] + 'is_external_ctype', 'infer_dtype', 'CustomDtype'] # *** Custom np.dtypes @@ -123,18 +123,6 @@ def __repr__(self): __str__ = __repr__ -class CustomNpType(CustomDtype): - """ - Custom dtype for underlying numpy type. 
- """ - - def __init__(self, name, nptype, template=None, modifier=None): - self.nptype = nptype - super().__init__(name, template, modifier) - - def __call__(self, val): - return self.nptype(val) - # *** np.dtypes lowering From ecadec3e85ea93fc27942b6a199bcece91a4bfaf Mon Sep 17 00:00:00 2001 From: enwask Date: Tue, 9 Jul 2024 19:23:37 +0100 Subject: [PATCH 12/29] compiler: use structs to pass complex arguments --- devito/ir/iet/visitors.py | 3 ++- devito/passes/iet/languages/C.py | 10 +++++----- devito/passes/iet/languages/CXX.py | 5 +++-- devito/symbolics/extended_dtypes.py | 26 +++++++++++++++++++++++++- 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 6e9879d873..23aab0f29a 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -20,6 +20,7 @@ from devito.ir.support.space import Backward from devito.symbolics import (FieldFromComposite, FieldFromPointer, ListInitializer, ccode, uxreplace) +from devito.symbolics.extended_dtypes import NoDeclStruct from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered, filter_sorted, flatten, is_external_ctype, c_restrict_void_p, sorted_priority) @@ -208,7 +209,7 @@ def _gen_struct_decl(self, obj, masked=()): while issubclass(ctype, ctypes._Pointer): ctype = ctype._type_ - if not issubclass(ctype, ctypes.Structure): + if not issubclass(ctype, ctypes.Structure) or issubclass(ctype, NoDeclStruct): return None except TypeError: # E.g., `ctype` is of type `dtypes_lowering.CustomDtype` diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 2cee279428..1c233b0ff8 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,10 +1,10 @@ -import ctypes as ct import numpy as np from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB 
+from devito.symbolics.extended_dtypes import c_complex, c_double_complex from devito.tools.dtypes_lowering import ctypes_vector_mapper @@ -19,11 +19,11 @@ class CCDouble(np.complex128): pass -c_complex = type('_Complex float', (ct.c_double,), {}) -c_double_complex = type('_Complex double', (ct.c_longdouble,), {}) +c99_complex = type('_Complex float', (c_complex,), {}) +c99_double_complex = type('_Complex double', (c_double_complex,), {}) -ctypes_vector_mapper[CCFloat] = c_complex -ctypes_vector_mapper[CCDouble] = c_double_complex +ctypes_vector_mapper[CCFloat] = c99_complex +ctypes_vector_mapper[CCDouble] = c99_double_complex class CBB(LangBB): diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index fb802acb8b..88ed923640 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -3,6 +3,7 @@ from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB +from devito.symbolics.extended_dtypes import c_complex, c_double_complex from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] @@ -53,9 +54,9 @@ class CXXCDouble(np.complex128): pass -cxx_complex = type('std::complex', (ct.c_double,), {}) -cxx_double_complex = type('std::complex', (ct.c_longdouble,), {}) +cxx_complex = type('std::complex', (c_complex,), {}) +cxx_double_complex = type('std::complex', (c_double_complex,), {}) ctypes_vector_mapper[CXXCFloat] = cxx_complex ctypes_vector_mapper[CXXCDouble] = cxx_double_complex diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 0e8ce0cc98..d63ca92bf5 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -1,10 +1,11 @@ +import ctypes as ct import numpy as np from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa int2, int3, int4) -__all__ = 
['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa +__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'c_complex', 'c_double_complex'] # noqa limits_mapper = { @@ -15,6 +16,29 @@ } +class NoDeclStruct(ct.Structure): + # ctypes.Structure that does not generate a struct definition + pass + + +class c_complex(NoDeclStruct): + # Structure for passing complex float to C/C++ + _fields_ = [('real', ct.c_float), ('imag', ct.c_float)] + + @classmethod + def from_param(cls, val): + return cls(val.real, val.imag) + + +class c_double_complex(NoDeclStruct): + # Structure for passing complex double to C/C++ + _fields_ = [('real', ct.c_double), ('imag', ct.c_double)] + + @classmethod + def from_param(cls, val): + return cls(val.real, val.imag) + + class CustomType(ReservedWord): pass From 27ff82aeba35578c3f69800cafe26529b76894f4 Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 13:05:11 +0100 Subject: [PATCH 13/29] compiler: add Dereference scalar case --- devito/ir/iet/nodes.py | 23 +++++++++++++++++------ devito/ir/iet/visitors.py | 3 +++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 9bcc3460f6..ee577e939f 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1,6 +1,7 @@ """The Iteration/Expression Tree (IET) hierarchy.""" import abc +import ctypes import inspect from functools import cached_property from collections import OrderedDict, namedtuple @@ -1030,6 +1031,9 @@ class Dereference(ExprStmt, Node): * `pointer` is a PointerArray or TempFunction, and `pointee` is an Array. * `pointer` is an ArrayObject representing a pointer to a C struct, and `pointee` is a field in `pointer`. + * `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and + `pointee` is a Symbol representing the dereferenced value. 
+ """ is_Dereference = True @@ -1048,13 +1052,20 @@ def functions(self): @property def expr_symbols(self): - ret = [self.pointer.indexed] - if self.pointer.is_PointerArray or self.pointer.is_TempFunction: - ret.append(self.pointee.indexed) - ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:])) - ret.extend(self.pointer.free_symbols) - else: + ret = [] + if self.pointer.is_Symbol: + assert (issubclass(self.pointer._C_ctype, ctypes._Pointer), + "Scalar dereference must have a pointer ctype") + ret.append(self.pointer._C_symbol) ret.append(self.pointee._C_symbol) + else: + ret.append(self.pointer.indexed) + if self.pointer.is_PointerArray or self.pointer.is_TempFunction: + ret.append(self.pointee.indexed) + ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:])) + ret.extend(self.pointer.free_symbols) + else: + ret.append(self.pointee._C_symbol) return tuple(filter_ordered(ret)) @property diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 23aab0f29a..fef876b817 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -454,6 +454,9 @@ def visit_Dereference(self, o): lvalue = c.Value(cstr, '*restrict %s' % a0.name) if a0._data_alignment: lvalue = c.AlignedAttribute(a0._data_alignment, lvalue) + elif a1.is_Symbol: + rvalue = '*%s' % a1.name + lvalue = self._gen_value(a0, 0) else: rvalue = '%s->%s' % (a1.name, a0._C_name) lvalue = self._gen_value(a0, 0) From 066279d4cf68595024a8b61c383bf31d6076927c Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 13:51:55 +0100 Subject: [PATCH 14/29] compiler: implement float16 support --- devito/ir/iet/visitors.py | 4 +- devito/passes/iet/definitions.py | 3 +- devito/passes/iet/dtypes.py | 70 +++++++++++++++++++++-------- devito/passes/iet/languages/C.py | 18 ++++++-- devito/passes/iet/languages/CXX.py | 17 +++++-- devito/symbolics/extended_dtypes.py | 31 +++++++++++-- devito/tools/dtypes_lowering.py | 3 ++ 7 files changed, 116 insertions(+), 
30 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index fef876b817..cfd0ee3892 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -24,7 +24,7 @@ from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered, filter_sorted, flatten, is_external_ctype, c_restrict_void_p, sorted_priority) -from devito.types.basic import AbstractFunction, Basic +from devito.types.basic import AbstractFunction, AbstractSymbol, Basic from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer, IndexedData, DeviceMap) @@ -961,6 +961,7 @@ def default_retval(cls): Drive the search. Accepted: - `symbolics`: Collect all AbstractFunction objects, default - `basics`: Collect all Basic objects + - `scalars`: Collect all AbstractSymbol objects - `dimensions`: Collect all Dimensions - `indexeds`: Collect all Indexed objects - `indexedbases`: Collect all IndexedBase objects @@ -981,6 +982,7 @@ def _defines_aliases(n): rules = { 'symbolics': lambda n: n.functions, 'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)], + 'scalars': lambda n: [i for i in n.expr_symbols if isinstance(i, AbstractSymbol)], 'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)], 'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed], 'indexedbases': lambda n: [i for i in n.expr_symbols diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 81a0168d58..120c1ca6e5 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -12,7 +12,7 @@ from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create -from devito.passes.iet.dtypes import lower_complex +from devito.passes.iet.dtypes import lower_complex, lower_scalar_half from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB 
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, @@ -414,6 +414,7 @@ def place_casts(self, iet, **kwargs): @iet_pass def make_langtypes(self, iet): + iet, _ = lower_scalar_half(iet, self.lang, self.sregistry) iet, metadata = lower_complex(iet, self.lang, self.compiler) return iet, metadata diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 57eb10c4d8..6f35883423 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -2,8 +2,44 @@ import ctypes from devito.ir import FindSymbols, Uxreplace +from devito.ir.iet.nodes import Dereference +from devito.tools.utils import as_tuple +from devito.types.basic import Symbol -__all__ = ['lower_complex'] +__all__ = ['lower_scalar_half', 'lower_complex'] + + +def lower_scalar_half(iet, lang, sregistry): + """ + Lower half float scalars to pointers (special case, since we can't + pass them directly for lack of a ctypes equivalent) + """ + if lang.get('half_types') is None: + return iet, {} + + # dtype mappings for float16 + half, half_p = lang['half_types'] + + body = [] # derefs to prepend to the body + body_mapper = {} + params_mapper = {} + + for s in FindSymbols('scalars').visit(iet): + if s.dtype != np.float16 or s not in iet.parameters: + continue + + ptr = s._rebuild(dtype=half_p) + val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half, is_const=True) + + params_mapper[s] = ptr + body_mapper[s] = val + body.append(Dereference(val, ptr)) # val = *ptr + + body.extend(as_tuple(Uxreplace(body_mapper).visit(iet.body))) + params = Uxreplace(params_mapper).visit(iet.parameters) + + iet = iet._rebuild(body=body, parameters=params) + return iet, {} def lower_complex(iet, lang, compiler): @@ -14,30 +50,28 @@ def lower_complex(iet, lang, compiler): types = {f.dtype for f in FindSymbols().visit(iet) if not issubclass(f.dtype, ctypes._Pointer)} - if not any(np.issubdtype(d, np.complexfloating) for d in types): - return iet, {} - - lib = 
(lang['header-complex'],) - metadata = {} - if lang.get('complex-namespace') is not None: - metadata['namespaces'] = lang['complex-namespace'] + if any(np.issubdtype(d, np.complexfloating) for d in types): + lib = (lang['header-complex'],) + + if lang.get('complex-namespace') is not None: + metadata['namespaces'] = lang['complex-namespace'] - # Some languges such as c++11 need some extra arithmetic definitions - if lang.get('def-complex'): - dest = compiler.get_jit_dir() - hfile = dest.joinpath('complex_arith.h') - with open(str(hfile), 'w') as ff: - ff.write(str(lang['def-complex'])) - lib += (str(hfile),) + # Some languges such as c++11 need some extra arithmetic definitions + if lang.get('def-complex'): + dest = compiler.get_jit_dir() + hfile = dest.joinpath('complex_arith.h') + with open(str(hfile), 'w') as ff: + ff.write(str(lang['def-complex'])) + lib += (str(hfile),) - iet = _complex_dtypes(iet, lang) - metadata['includes'] = lib + metadata['includes'] = lib + iet = _lower_dtypes(iet, lang) return iet, metadata -def _complex_dtypes(iet, lang): +def _lower_dtypes(iet, lang): """ Lower dtypes to language specific types """ diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 1c233b0ff8..57e7864c11 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,10 +1,11 @@ +from ctypes import c_float import numpy as np from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import c_complex, c_double_complex +from devito.symbolics.extended_dtypes import c_complex, c_double_complex, c_float16, c_float16_p from devito.tools.dtypes_lowering import ctypes_vector_mapper @@ -19,11 +20,21 @@ class CCDouble(np.complex128): pass +class CHalf(np.float16): + pass + + +class CHalfP(np.float16): + pass + + c99_complex = type('_Complex float', 
(c_complex,), {}) c99_double_complex = type('_Complex double', (c_double_complex,), {}) ctypes_vector_mapper[CCFloat] = c99_complex ctypes_vector_mapper[CCDouble] = c99_double_complex +ctypes_vector_mapper[CHalf] = c_float16 +ctypes_vector_mapper[CHalfP] = c_float16_p class CBB(LangBB): @@ -40,9 +51,10 @@ class CBB(LangBB): Call('free', (i,)), 'alloc-global-symbol': lambda i, j, k: Call('memcpy', (i, j, k)), - # Complex + # Complex and float16 'header-complex': 'complex.h', - 'types': {np.complex128: CCDouble, np.complex64: CCFloat}, + 'types': {np.complex128: CCDouble, np.complex64: CCFloat, np.float16: CHalf}, + 'half_types': (CHalf, CHalfP), } diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 88ed923640..c207b793c2 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,9 +1,8 @@ -import ctypes as ct import numpy as np from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import c_complex, c_double_complex +from devito.symbolics.extended_dtypes import c_complex, c_double_complex, c_float16, c_float16_p from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] @@ -54,12 +53,21 @@ class CXXCDouble(np.complex128): pass +class CXXHalf(np.float16): + pass + + +class CXXHalfP(np.float16): + pass + cxx_complex = type('std::complex', (c_complex,), {}) cxx_double_complex = type('std::complex', (c_double_complex,), {}) ctypes_vector_mapper[CXXCFloat] = cxx_complex ctypes_vector_mapper[CXXCDouble] = cxx_double_complex +ctypes_vector_mapper[CXXHalf] = c_float16 +ctypes_vector_mapper[CXXHalfP] = c_float16_p class CXXBB(LangBB): @@ -76,9 +84,10 @@ class CXXBB(LangBB): Call('free', (i,)), 'alloc-global-symbol': lambda i, j, k: Call('memcpy', (i, j, k)), - # Complex + # Complex and float16 'header-complex': 'complex', 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': 
std_arith, - 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat}, + 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat, np.float16: CXXHalf}, + 'half_types': (CXXHalf, CXXHalfP), } diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index d63ca92bf5..e7bb595ed9 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -5,7 +5,9 @@ from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa int2, int3, int4) -__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'c_complex', 'c_double_complex'] # noqa +__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', + 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', + 'c_float16', 'c_float16_p'] limits_mapper = { @@ -28,7 +30,7 @@ class c_complex(NoDeclStruct): @classmethod def from_param(cls, val): return cls(val.real, val.imag) - + class c_double_complex(NoDeclStruct): # Structure for passing complex double to C/C++ @@ -37,7 +39,30 @@ class c_double_complex(NoDeclStruct): @classmethod def from_param(cls, val): return cls(val.real, val.imag) - + + +class _c_half(ct.c_uint16): + # Ctype for non-scalar half floats + @classmethod + def from_param(cls, val): + return cls(np.float16(val).view(np.uint16)) + + +c_float16 = type('_Float16', (_c_half,), {}) + + +class _c_half_p(ct.POINTER(c_float16)): + # Ctype for half scalars; we can't directly pass _Float16 values so + # we use a pointer and dereference (see `passes.iet.dtypes`) + @classmethod + def from_param(cls, val): + arr = np.array(val, dtype=np.float16) + return arr.ctypes.data_as(cls) + + +# ctypes directly parses class dict; can't inherit the _type_ attribute +c_float16_p = type('_Float16 *', (_c_half_p,), {'_type_': c_float16}) + class CustomType(ReservedWord): pass diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 43def2d8cd..98f6a0e23d 100644 
--- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -145,6 +145,9 @@ def dtype_to_ctype(dtype): # Bypass np.ctypeslib's normalization rules such as # `np.ctypeslib.as_ctypes_type(ctypes.c_void_p) -> ctypes.c_ulong` return dtype + elif dtype == np.float16: + # Allocator wants a ctype before float16 gets mapped + return ctypes.c_uint16 else: return np.ctypeslib.as_ctypes_type(dtype) From 7af930b1446b6aaa4a0fce0ac55c9ff847509efb Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 14:26:08 +0100 Subject: [PATCH 15/29] symbolics: fix printer for half precision --- devito/symbolics/printer.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 77bc407dd6..f92f0b24a5 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -48,7 +48,14 @@ def single_prec(self, expr=None, with_f=False): if no_f and expr is not None: return False dtype = sympy_dtype(expr) if expr is not None else self.dtype - return any(issubclass(dtype, d) for d in [np.float32, np.float16, np.complex64]) + return any(issubclass(dtype, d) for d in [np.float32, np.complex64]) + + def half_prec(self, expr=None, with_f=False): + no_f = self.compiler._cpp and not with_f + if no_f and expr is not None: + return False + dtype = sympy_dtype(expr) if expr is not None else self.dtype + return issubclass(dtype, np.float16) def complex_prec(self, expr=None): if self.compiler._cpp: @@ -121,7 +128,7 @@ def _print_math_func(self, expr, nest=False, known=None): except KeyError: return super()._print_math_func(expr, nest=nest, known=known) - if self.single_prec(expr): + if self.single_prec(expr) or self.half_prec(expr): cname = '%sf' % cname if self.complex_prec(expr): cname = 'c%s' % cname @@ -137,6 +144,9 @@ def _print_Pow(self, expr): if expr.exp == -1 and self.single_prec(): PREC = precedence(expr) return '1.0F/%s' % self.parenthesize(expr.base, PREC) + if expr.exp 
== -1 and self.half_prec(): + PREC = precedence(expr) + return '1.0F16/%s' % self.parenthesize(expr.base, PREC) except AttributeError: pass return super()._print_Pow(expr) @@ -216,6 +226,8 @@ def _print_Float(self, expr): if self.single_prec(): rv = '%sF' % rv + elif self.half_prec(): + rv = '%sF16' % rv return rv @@ -223,6 +235,8 @@ def _print_ImaginaryUnit(self, expr): if self.compiler._cpp: if self.single_prec(with_f=True): return '1if' + elif self.half_prec(with_f=True): + return '1if16' else: return '1i' else: @@ -280,7 +294,7 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): func_name = str(expr.func) - if self.single_prec(): + if self.single_prec() or self.half_prec(): func_name = '%sf' % func_name if self.complex_prec(): func_name = 'c%s' % func_name From 5b7efcc1913219d963159c06267248d06cc49a2a Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 14:36:42 +0100 Subject: [PATCH 16/29] misc: fix formatting --- devito/ir/iet/nodes.py | 8 ++++---- devito/passes/iet/dtypes.py | 2 +- devito/passes/iet/languages/C.py | 4 ++-- devito/passes/iet/languages/CXX.py | 6 ++++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index ee577e939f..61eb4ba6b0 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1033,7 +1033,6 @@ class Dereference(ExprStmt, Node): `pointee` is a field in `pointer`. * `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and `pointee` is a Symbol representing the dereferenced value. 
- """ is_Dereference = True @@ -1054,15 +1053,16 @@ def functions(self): def expr_symbols(self): ret = [] if self.pointer.is_Symbol: - assert (issubclass(self.pointer._C_ctype, ctypes._Pointer), - "Scalar dereference must have a pointer ctype") + assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \ + "Scalar dereference must have a pointer ctype" ret.append(self.pointer._C_symbol) ret.append(self.pointee._C_symbol) else: ret.append(self.pointer.indexed) if self.pointer.is_PointerArray or self.pointer.is_TempFunction: ret.append(self.pointee.indexed) - ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:])) + ret.extend(flatten(i.free_symbols + for i in self.pointee.symbolic_shape[1:])) ret.extend(self.pointer.free_symbols) else: ret.append(self.pointee._C_symbol) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 6f35883423..1fe98edfd8 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -16,7 +16,7 @@ def lower_scalar_half(iet, lang, sregistry): """ if lang.get('half_types') is None: return iet, {} - + # dtype mappings for float16 half, half_p = lang['half_types'] diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 57e7864c11..572d4a86cd 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,11 +1,11 @@ -from ctypes import c_float import numpy as np from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import c_complex, c_double_complex, c_float16, c_float16_p +from devito.symbolics.extended_dtypes import (c_complex, c_double_complex, + c_float16, c_float16_p) from devito.tools.dtypes_lowering import ctypes_vector_mapper diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index c207b793c2..30e5ab689a 100644 --- 
a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -2,7 +2,8 @@ from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import c_complex, c_double_complex, c_float16, c_float16_p +from devito.symbolics.extended_dtypes import (c_complex, c_double_complex, + c_float16, c_float16_p) from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] @@ -88,6 +89,7 @@ class CXXBB(LangBB): 'header-complex': 'complex', 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, - 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat, np.float16: CXXHalf}, + 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat, + np.float16: CXXHalf}, 'half_types': (CXXHalf, CXXHalfP), } From a09350fbfe0b360f41b9d8c7cf085c6d5b74bbf8 Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 17:57:50 +0100 Subject: [PATCH 17/29] compiler: refactor float16 and lower_dtypes --- devito/passes/iet/definitions.py | 5 +- devito/passes/iet/dtypes.py | 96 +++++++++++++---------------- devito/passes/iet/languages/C.py | 7 ++- devito/passes/iet/languages/CXX.py | 4 +- devito/symbolics/extended_dtypes.py | 13 +--- devito/symbolics/printer.py | 4 +- 6 files changed, 56 insertions(+), 73 deletions(-) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 120c1ca6e5..a1596f20be 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -12,7 +12,7 @@ from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create -from devito.passes.iet.dtypes import lower_complex, lower_scalar_half +from devito.passes.iet.dtypes import lower_dtypes from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, 
FieldFromPointer, IndexedPointer, @@ -414,8 +414,7 @@ def place_casts(self, iet, **kwargs): @iet_pass def make_langtypes(self, iet): - iet, _ = lower_scalar_half(iet, self.lang, self.sregistry) - iet, metadata = lower_complex(iet, self.lang, self.compiler) + iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry) return iet, metadata def process(self, graph): diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 1fe98edfd8..a2d0899224 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -3,46 +3,52 @@ from devito.ir import FindSymbols, Uxreplace from devito.ir.iet.nodes import Dereference -from devito.tools.utils import as_tuple +from devito.tools.utils import as_list from devito.types.basic import Symbol -__all__ = ['lower_scalar_half', 'lower_complex'] +__all__ = ['lower_dtypes'] -def lower_scalar_half(iet, lang, sregistry): +def lower_dtypes(iet, lang, compiler, sregistry): """ - Lower half float scalars to pointers (special case, since we can't - pass them directly for lack of a ctypes equivalent) + Lower language-specific dtypes and add headers for complex arithmetic """ - if lang.get('half_types') is None: - return iet, {} + # Include complex headers if needed (before we replace complex dtypes) + metadata = _complex_includes(iet, lang, compiler) - # dtype mappings for float16 - half, half_p = lang['half_types'] - - body = [] # derefs to prepend to the body + body_prefix = [] # Derefs to prepend to the body body_mapper = {} params_mapper = {} - for s in FindSymbols('scalars').visit(iet): - if s.dtype != np.float16 or s not in iet.parameters: - continue + # Lower scalar float16s to pointers and dereference them + if lang.get('half_types') is not None: + half, half_p = lang['half_types'] # dtype mappings for half float + + for s in FindSymbols('scalars').visit(iet): + if s.dtype != np.float16 or s not in iet.parameters: + continue - ptr = s._rebuild(dtype=half_p) - val = 
Symbol(name=sregistry.make_name(prefix='hf'), dtype=half, is_const=True) + ptr = s._rebuild(dtype=half_p, is_const=True) + val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half, + is_const=s.is_const) - params_mapper[s] = ptr - body_mapper[s] = val - body.append(Dereference(val, ptr)) # val = *ptr + params_mapper[s], body_mapper[s] = ptr, val + body_prefix.append(Dereference(val, ptr)) # val = *ptr + + # Lower remaining language-specific dtypes + for s in FindSymbols('indexeds|basics|symbolics').visit(iet): + if s.dtype in lang['types'] and s not in params_mapper: + body_mapper[s] = params_mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) - body.extend(as_tuple(Uxreplace(body_mapper).visit(iet.body))) + # Apply the dtype replacements + body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body)) params = Uxreplace(params_mapper).visit(iet.parameters) iet = iet._rebuild(body=body, parameters=params) - return iet, {} + return iet, metadata -def lower_complex(iet, lang, compiler): +def _complex_includes(iet, lang, compiler): """ Add headers for complex arithmetic """ @@ -50,39 +56,23 @@ def lower_complex(iet, lang, compiler): types = {f.dtype for f in FindSymbols().visit(iet) if not issubclass(f.dtype, ctypes._Pointer)} - metadata = {} - if any(np.issubdtype(d, np.complexfloating) for d in types): - lib = (lang['header-complex'],) - - if lang.get('complex-namespace') is not None: - metadata['namespaces'] = lang['complex-namespace'] - - # Some languges such as c++11 need some extra arithmetic definitions - if lang.get('def-complex'): - dest = compiler.get_jit_dir() - hfile = dest.joinpath('complex_arith.h') - with open(str(hfile), 'w') as ff: - ff.write(str(lang['def-complex'])) - lib += (str(hfile),) - - metadata['includes'] = lib - - iet = _lower_dtypes(iet, lang) - return iet, metadata + if not any(np.issubdtype(d, np.complexfloating) for d in types): + return {} + metadata = {} + lib = (lang['header-complex'],) -def _lower_dtypes(iet, lang): - 
""" - Lower dtypes to language specific types - """ - mapper = {} + if lang.get('complex-namespace') is not None: + metadata['namespaces'] = lang['complex-namespace'] - for s in FindSymbols('indexeds|basics|symbolics').visit(iet): - if s.dtype in lang['types']: - mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) + # Some languges such as c++11 need some extra arithmetic definitions + if lang.get('def-complex'): + dest = compiler.get_jit_dir() + hfile = dest.joinpath('complex_arith.h') + with open(str(hfile), 'w') as ff: + ff.write(str(lang['def-complex'])) + lib += (str(hfile),) - body = Uxreplace(mapper).visit(iet.body) - params = Uxreplace(mapper).visit(iet.parameters) - iet = iet._rebuild(body=body, parameters=params) + metadata['includes'] = lib - return iet + return metadata diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 572d4a86cd..6112c3e895 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -5,11 +5,11 @@ from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB from devito.symbolics.extended_dtypes import (c_complex, c_double_complex, - c_float16, c_float16_p) + c_half, c_half_p) from devito.tools.dtypes_lowering import ctypes_vector_mapper -__all__ = ['CBB', 'CDataManager', 'COrchestrator'] +__all__ = ['CBB', 'CDataManager', 'COrchestrator', 'c_float16', 'c_float16_p'] class CCFloat(np.complex64): @@ -31,6 +31,9 @@ class CHalfP(np.float16): c99_complex = type('_Complex float', (c_complex,), {}) c99_double_complex = type('_Complex double', (c_double_complex,), {}) +c_float16 = type('_Float16', (c_half,), {}) +c_float16_p = type('_Float16 *', (c_half_p,), {'_type_': c_float16}) + ctypes_vector_mapper[CCFloat] = c99_complex ctypes_vector_mapper[CCDouble] = c99_double_complex ctypes_vector_mapper[CHalf] = c_float16 diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 30e5ab689a..48fdb4471b 100644 
--- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -2,8 +2,8 @@ from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import (c_complex, c_double_complex, - c_float16, c_float16_p) +from devito.passes.iet.languages.C import c_float16, c_float16_p +from devito.symbolics.extended_dtypes import c_complex, c_double_complex from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index e7bb595ed9..f6265e4938 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -7,7 +7,7 @@ __all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', - 'c_float16', 'c_float16_p'] + 'c_half', 'c_half_p'] limits_mapper = { @@ -41,17 +41,14 @@ def from_param(cls, val): return cls(val.real, val.imag) -class _c_half(ct.c_uint16): +class c_half(ct.c_uint16): # Ctype for non-scalar half floats @classmethod def from_param(cls, val): return cls(np.float16(val).view(np.uint16)) -c_float16 = type('_Float16', (_c_half,), {}) - - -class _c_half_p(ct.POINTER(c_float16)): +class c_half_p(ct.POINTER(c_half)): # Ctype for half scalars; we can't directly pass _Float16 values so # we use a pointer and dereference (see `passes.iet.dtypes`) @classmethod @@ -60,10 +57,6 @@ def from_param(cls, val): return arr.ctypes.data_as(cls) -# ctypes directly parses class dict; can't inherit the _type_ attribute -c_float16_p = type('_Float16 *', (_c_half_p,), {'_type_': c_float16}) - - class CustomType(ReservedWord): pass diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index f92f0b24a5..81ddf637f1 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -233,10 +233,8 @@ def _print_Float(self, expr): def _print_ImaginaryUnit(self, expr): if 
self.compiler._cpp: - if self.single_prec(with_f=True): + if self.single_prec(with_f=True) or self.half_prec(with_f=True): return '1if' - elif self.half_prec(with_f=True): - return '1if16' else: return '1i' else: From d4c94542a26d0f9a0543e9650381d9810b37f98d Mon Sep 17 00:00:00 2001 From: enwask Date: Thu, 11 Jul 2024 18:40:05 +0100 Subject: [PATCH 18/29] compiler: add dtype_alloc_ctype helper for allocation size --- devito/data/allocators.py | 9 +++------ devito/tools/dtypes_lowering.py | 33 ++++++++++++++++++++++++++++----- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/devito/data/allocators.py b/devito/data/allocators.py index 14f1b04fd1..ff5301b387 100644 --- a/devito/data/allocators.py +++ b/devito/data/allocators.py @@ -11,7 +11,8 @@ from devito.logger import logger from devito.parameters import configuration -from devito.tools import dtype_to_ctype, is_integer +from devito.tools import is_integer +from devito.tools.dtypes_lowering import dtype_alloc_ctype __all__ = ['ALLOC_ALIGNED', 'ALLOC_NUMA_LOCAL', 'ALLOC_NUMA_ANY', 'ALLOC_KNL_MCDRAM', 'ALLOC_KNL_DRAM', 'ALLOC_GUARD', @@ -92,12 +93,8 @@ def initialize(cls): return def alloc(self, shape, dtype, padding=0): - # For complex number, allocate double the size of its real/imaginary part - alloc_dtype = dtype(0).real.__class__ - c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1 - + ctype, c_scale = dtype_alloc_ctype(dtype) datasize = int(reduce(mul, shape) * c_scale) - ctype = dtype_to_ctype(alloc_dtype) # Add padding, if any try: diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 98f6a0e23d..17ed550ebb 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -11,8 +11,8 @@ __all__ = ['int2', 'int3', 'int4', 'float2', 'float3', 'float4', 'double2', # noqa 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', - 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len', - 'ctypes_to_cstr', 
'c_restrict_void_p', 'ctypes_vector_mapper', + 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_alloc_ctype', 'dtype_to_mpitype', + 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', 'is_external_ctype', 'infer_dtype', 'CustomDtype'] @@ -145,13 +145,36 @@ def dtype_to_ctype(dtype): # Bypass np.ctypeslib's normalization rules such as # `np.ctypeslib.as_ctypes_type(ctypes.c_void_p) -> ctypes.c_ulong` return dtype - elif dtype == np.float16: - # Allocator wants a ctype before float16 gets mapped - return ctypes.c_uint16 else: return np.ctypeslib.as_ctypes_type(dtype) +def dtype_alloc_ctype(dtype): + """ + Translate numpy.dtype to (ctype, int): type and scale for correct C allocation size. + """ + if isinstance(dtype, CustomDtype): + return dtype, 1 + + try: + return ctypes_vector_mapper[dtype], 1 + except KeyError: + pass + + if issubclass(dtype, ctypes._SimpleCData): + return dtype, 1 + + if dtype == np.float16: + # Allocate half float as unsigned short + return ctypes.c_uint16, 1 + + if np.issubdtype(dtype, np.complexfloating): + # For complex float, allocate twice the size of real/imaginary part + return np.ctypeslib.as_ctypes_type(dtype(0).real.__class__), 2 + + return np.ctypeslib.as_ctypes_type(dtype), 1 + + def dtype_to_mpitype(dtype): """Map numpy types to MPI datatypes.""" From d3169d089bc8539024d1d1a13788cee8094e204d Mon Sep 17 00:00:00 2001 From: enwask Date: Mon, 15 Jul 2024 13:53:44 +0100 Subject: [PATCH 19/29] misc: more float16 refactoring/formatting fixes --- devito/data/allocators.py | 3 +-- devito/ir/iet/nodes.py | 3 +-- devito/symbolics/extended_dtypes.py | 31 ++++++++++++++++++----------- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/devito/data/allocators.py b/devito/data/allocators.py index ff5301b387..afac7b7b4f 100644 --- a/devito/data/allocators.py +++ b/devito/data/allocators.py @@ -11,8 +11,7 @@ from devito.logger import logger from devito.parameters import configuration -from devito.tools import 
is_integer -from devito.tools.dtypes_lowering import dtype_alloc_ctype +from devito.tools import is_integer, dtype_alloc_ctype __all__ = ['ALLOC_ALIGNED', 'ALLOC_NUMA_LOCAL', 'ALLOC_NUMA_ANY', 'ALLOC_KNL_MCDRAM', 'ALLOC_KNL_DRAM', 'ALLOC_GUARD', diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 61eb4ba6b0..a4341f3f35 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1055,8 +1055,7 @@ def expr_symbols(self): if self.pointer.is_Symbol: assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \ "Scalar dereference must have a pointer ctype" - ret.append(self.pointer._C_symbol) - ret.append(self.pointee._C_symbol) + ret.extend([self.pointer._C_symbol, self.pointee._C_symbol]) else: ret.append(self.pointer.indexed) if self.pointer.is_PointerArray or self.pointer.is_TempFunction: diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index f6265e4938..85256a3f94 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -1,4 +1,4 @@ -import ctypes as ct +import ctypes import numpy as np from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit @@ -18,14 +18,16 @@ } -class NoDeclStruct(ct.Structure): - # ctypes.Structure that does not generate a struct definition +class NoDeclStruct(ctypes.Structure): + """A ctypes.Structure that does not generate a struct definition""" + pass class c_complex(NoDeclStruct): - # Structure for passing complex float to C/C++ - _fields_ = [('real', ct.c_float), ('imag', ct.c_float)] + """Structure for passing complex float to C/C++""" + + _fields_ = [('real', ctypes.c_float), ('imag', ctypes.c_float)] @classmethod def from_param(cls, val): @@ -33,24 +35,29 @@ def from_param(cls, val): class c_double_complex(NoDeclStruct): - # Structure for passing complex double to C/C++ - _fields_ = [('real', ct.c_double), ('imag', ct.c_double)] + """Structure for passing complex double to C/C++""" + + _fields_ = [('real', 
ctypes.c_double), ('imag', ctypes.c_double)] @classmethod def from_param(cls, val): return cls(val.real, val.imag) -class c_half(ct.c_uint16): - # Ctype for non-scalar half floats +class c_half(ctypes.c_uint16): + """Ctype for non-scalar half floats""" + @classmethod def from_param(cls, val): return cls(np.float16(val).view(np.uint16)) -class c_half_p(ct.POINTER(c_half)): - # Ctype for half scalars; we can't directly pass _Float16 values so - # we use a pointer and dereference (see `passes.iet.dtypes`) +class c_half_p(ctypes.POINTER(c_half)): + """ + Ctype for half scalars; we can't directly pass _Float16 values so + we use a pointer and dereference (see `passes.iet.dtypes`) + """ + @classmethod def from_param(cls, val): arr = np.array(val, dtype=np.float16) From 493c1e878c8783cc5c416f1ea16df35323fe6e4b Mon Sep 17 00:00:00 2001 From: enwask Date: Tue, 16 Jul 2024 13:55:28 +0100 Subject: [PATCH 20/29] Remove dtypes lowering from IET layer --- devito/passes/iet/definitions.py | 10 +++--- devito/passes/iet/dtypes.py | 54 ++++---------------------------- 2 files changed, 11 insertions(+), 53 deletions(-) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index a1596f20be..bf03a8a71d 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -12,7 +12,7 @@ from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create -from devito.passes.iet.dtypes import lower_dtypes +from devito.passes.iet.dtypes import include_complex from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, @@ -413,8 +413,8 @@ def place_casts(self, iet, **kwargs): return iet, {} @iet_pass - def make_langtypes(self, iet): - iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry) + def 
include_complex(self, iet): + iet, metadata = include_complex(iet, self.lang, self.compiler) return iet, metadata def process(self, graph): @@ -423,7 +423,7 @@ def process(self, graph): """ self.place_definitions(graph, globs=set()) self.place_casts(graph) - self.make_langtypes(graph) + self.include_complex(graph) class DeviceAwareDataManager(DataManager): @@ -573,7 +573,7 @@ def process(self, graph): self.place_devptr(graph) self.place_bundling(graph, writes_input=graph.writes_input) self.place_casts(graph) - self.make_langtypes(graph) + self.include_complex(graph) def make_zero_init(obj): diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index a2d0899224..789f49d5b4 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -1,63 +1,21 @@ import numpy as np import ctypes -from devito.ir import FindSymbols, Uxreplace -from devito.ir.iet.nodes import Dereference -from devito.tools.utils import as_list -from devito.types.basic import Symbol +from devito.ir import FindSymbols -__all__ = ['lower_dtypes'] +__all__ = ['include_complex'] -def lower_dtypes(iet, lang, compiler, sregistry): +def include_complex(iet, lang, compiler): """ - Lower language-specific dtypes and add headers for complex arithmetic - """ - # Include complex headers if needed (before we replace complex dtypes) - metadata = _complex_includes(iet, lang, compiler) - - body_prefix = [] # Derefs to prepend to the body - body_mapper = {} - params_mapper = {} - - # Lower scalar float16s to pointers and dereference them - if lang.get('half_types') is not None: - half, half_p = lang['half_types'] # dtype mappings for half float - - for s in FindSymbols('scalars').visit(iet): - if s.dtype != np.float16 or s not in iet.parameters: - continue - - ptr = s._rebuild(dtype=half_p, is_const=True) - val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half, - is_const=s.is_const) - - params_mapper[s], body_mapper[s] = ptr, val - body_prefix.append(Dereference(val, ptr)) # 
val = *ptr - - # Lower remaining language-specific dtypes - for s in FindSymbols('indexeds|basics|symbolics').visit(iet): - if s.dtype in lang['types'] and s not in params_mapper: - body_mapper[s] = params_mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) - - # Apply the dtype replacements - body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body)) - params = Uxreplace(params_mapper).visit(iet.parameters) - - iet = iet._rebuild(body=body, parameters=params) - return iet, metadata - - -def _complex_includes(iet, lang, compiler): - """ - Add headers for complex arithmetic + Include complex arithmetic headers for the given language, if needed. """ # Check if there is complex numbers that always take dtype precedence types = {f.dtype for f in FindSymbols().visit(iet) if not issubclass(f.dtype, ctypes._Pointer)} if not any(np.issubdtype(d, np.complexfloating) for d in types): - return {} + return iet, {} metadata = {} lib = (lang['header-complex'],) @@ -75,4 +33,4 @@ def _complex_includes(iet, lang, compiler): metadata['includes'] = lib - return metadata + return iet, metadata From 516b4ad0adc0f9eca18f5bd072a80051557cccf4 Mon Sep 17 00:00:00 2001 From: enwask Date: Fri, 26 Jul 2024 16:53:11 +0100 Subject: [PATCH 21/29] compiler: reimplement float16/complex lowering --- devito/operator/operator.py | 10 +++++++ devito/passes/iet/definitions.py | 10 +++---- devito/passes/iet/dtypes.py | 44 +++++++++++++++++++++++++++-- devito/passes/iet/languages/C.py | 30 ++++---------------- devito/passes/iet/languages/CXX.py | 31 ++++---------------- devito/symbolics/extended_dtypes.py | 12 +++++++- 6 files changed, 78 insertions(+), 59 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index ba411c0ea6..97507f4693 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -25,10 +25,12 @@ from devito.passes import (Graph, lower_index_derivatives, generate_implicit, generate_macros, minimize_symbols, unevaluate, 
error_mapper, is_on_device) +from devito.passes.iet.langbase import LangBB from devito.symbolics import estimate_cost, subs_op_args from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple, flatten, filter_sorted, frozendict, is_integer, split, timed_pass, timed_region, contains_val) +from devito.tools.dtypes_lowering import ctypes_vector_mapper from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer, disk_layer) @@ -264,6 +266,9 @@ def _lower(cls, expressions, **kwargs): # expression for which a partial or complete lowering is desired kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs) + # Load language-specific types into the global dtype->ctype mapper + cls._load_dtype_mappings(**kwargs) + # [Eq] -> [LoweredEq] expressions = cls._lower_exprs(expressions, **kwargs) @@ -285,6 +290,11 @@ def _lower(cls, expressions, **kwargs): def _rcompile_wrapper(cls, **kwargs0): raise NotImplementedError + @classmethod + def _load_dtype_mappings(cls, **kwargs): + lang: type[LangBB] = cls._Target.DataManager.lang + ctypes_vector_mapper.update(lang.mapper.get('types', {})) + @classmethod def _initialize_state(cls, **kwargs): return {} diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index bf03a8a71d..807d485699 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -12,7 +12,7 @@ from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create -from devito.passes.iet.dtypes import include_complex +from devito.passes.iet.dtypes import lower_dtypes from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, @@ -413,8 +413,8 @@ def place_casts(self, iet, **kwargs): return iet, {} @iet_pass - def include_complex(self, iet): - iet, metadata = 
include_complex(iet, self.lang, self.compiler) + def lower_dtypes(self, iet): + iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry) return iet, metadata def process(self, graph): @@ -423,7 +423,7 @@ def process(self, graph): """ self.place_definitions(graph, globs=set()) self.place_casts(graph) - self.include_complex(graph) + self.lower_dtypes(graph) class DeviceAwareDataManager(DataManager): @@ -573,7 +573,7 @@ def process(self, graph): self.place_devptr(graph) self.place_bundling(graph, writes_input=graph.writes_input) self.place_casts(graph) - self.include_complex(graph) + self.lower_dtypes(graph) def make_zero_init(obj): diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 789f49d5b4..f4f73e7663 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -2,11 +2,51 @@ import ctypes from devito.ir import FindSymbols +from devito.ir.iet.nodes import Dereference +from devito.ir.iet.visitors import Uxreplace +from devito.symbolics.extended_dtypes import Float16P +from devito.tools.utils import as_list +from devito.types.basic import Symbol -__all__ = ['include_complex'] +__all__ = ['lower_dtypes'] -def include_complex(iet, lang, compiler): +def lower_dtypes(iet, lang, compiler, sregistry): + """ + Lowers float16 scalar types to pointers since we can't directly pass their + value. Also includes headers for complex arithmetic if needed. 
+ """ + + iet, metadata = _complex_includes(iet, lang, compiler) + + # Lower float16 parameters to pointers and dereference + body_prefix = [] + body_mapper = {} + params_mapper = {} + + # Lower scalar float16s to pointers and dereference them + for s in FindSymbols('scalars').visit(iet): + if not np.issubdtype(s.dtype, np.float16) or s not in iet.parameters: + continue + + # Replace the parameter with a pointer; replace occurences in the IET + # body with a dereference (using the original symbol's dtype) + ptr = s._rebuild(dtype=Float16P, is_const=True) + val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=s.dtype, + is_const=s.is_const) + + params_mapper[s], body_mapper[s] = ptr, val + body_prefix.append(Dereference(val, ptr)) # val = *ptr + + # Apply the replacements + body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body)) + params = Uxreplace(params_mapper).visit(iet.parameters) + + iet = iet._rebuild(body=body, parameters=params) + return iet, metadata + + +def _complex_includes(iet, lang, compiler): """ Include complex arithmetic headers for the given language, if needed. 
""" diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 6112c3e895..069aa10320 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -4,41 +4,19 @@ from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import (c_complex, c_double_complex, +from devito.symbolics.extended_dtypes import (Float16P, c_complex, c_double_complex, c_half, c_half_p) -from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CBB', 'CDataManager', 'COrchestrator', 'c_float16', 'c_float16_p'] -class CCFloat(np.complex64): - pass - - -class CCDouble(np.complex128): - pass - - -class CHalf(np.float16): - pass - - -class CHalfP(np.float16): - pass - - c99_complex = type('_Complex float', (c_complex,), {}) c99_double_complex = type('_Complex double', (c_double_complex,), {}) c_float16 = type('_Float16', (c_half,), {}) c_float16_p = type('_Float16 *', (c_half_p,), {'_type_': c_float16}) -ctypes_vector_mapper[CCFloat] = c99_complex -ctypes_vector_mapper[CCDouble] = c99_double_complex -ctypes_vector_mapper[CHalf] = c_float16 -ctypes_vector_mapper[CHalfP] = c_float16_p - class CBB(LangBB): @@ -56,8 +34,10 @@ class CBB(LangBB): Call('memcpy', (i, j, k)), # Complex and float16 'header-complex': 'complex.h', - 'types': {np.complex128: CCDouble, np.complex64: CCFloat, np.float16: CHalf}, - 'half_types': (CHalf, CHalfP), + 'types': {np.complex128: c99_double_complex, + np.complex64: c99_complex, + np.float16: c_float16, + Float16P: c_float16_p} } diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 48fdb4471b..1174a27f8d 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -3,8 +3,7 @@ from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB from devito.passes.iet.languages.C 
import c_float16, c_float16_p -from devito.symbolics.extended_dtypes import c_complex, c_double_complex -from devito.tools.dtypes_lowering import ctypes_vector_mapper +from devito.symbolics.extended_dtypes import Float16P, c_complex, c_double_complex __all__ = ['CXXBB'] @@ -46,30 +45,9 @@ """ -class CXXCFloat(np.complex64): - pass - - -class CXXCDouble(np.complex128): - pass - - -class CXXHalf(np.float16): - pass - - -class CXXHalfP(np.float16): - pass - - cxx_complex = type('std::complex', (c_complex,), {}) cxx_double_complex = type('std::complex', (c_double_complex,), {}) -ctypes_vector_mapper[CXXCFloat] = cxx_complex -ctypes_vector_mapper[CXXCDouble] = cxx_double_complex -ctypes_vector_mapper[CXXHalf] = c_float16 -ctypes_vector_mapper[CXXHalfP] = c_float16_p - class CXXBB(LangBB): @@ -89,7 +67,8 @@ class CXXBB(LangBB): 'header-complex': 'complex', 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, - 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat, - np.float16: CXXHalf}, - 'half_types': (CXXHalf, CXXHalfP), + "types": {np.complex128: cxx_double_complex, + np.complex64: cxx_complex, + np.float16: c_float16, + Float16P: c_float16_p} } diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 85256a3f94..0b8b1bcad1 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -7,7 +7,7 @@ __all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', - 'c_half', 'c_half_p'] + 'c_half', 'c_half_p', 'Float16P'] limits_mapper = { @@ -64,6 +64,16 @@ def from_param(cls, val): return arr.ctypes.data_as(cls) +class Float16P(np.float16): + """ + Dummy dtype for a scalar float16 value that's been mapped to a pointer. + This is needed because we can't directly pass in the values; we map to + pointers and dereference in the kernel; see `passes.iet.dtypes`. 
+ """ + + pass + + class CustomType(ReservedWord): pass From 079facbd7193b1a1879c79ded613e616b7384331 Mon Sep 17 00:00:00 2001 From: enwask Date: Mon, 29 Jul 2024 12:43:36 +0100 Subject: [PATCH 22/29] misc: cleanup, docs and typing for half support --- devito/ir/iet/nodes.py | 14 ++++++-------- devito/ir/iet/visitors.py | 5 +---- devito/passes/iet/dtypes.py | 29 ++++++++++++++++------------- devito/symbolics/extended_dtypes.py | 8 +++++++- devito/symbolics/printer.py | 3 +-- 5 files changed, 31 insertions(+), 28 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index a4341f3f35..85c40b5c40 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1056,15 +1056,13 @@ def expr_symbols(self): assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \ "Scalar dereference must have a pointer ctype" ret.extend([self.pointer._C_symbol, self.pointee._C_symbol]) + elif self.pointer.is_PointerArray or self.pointer.is_TempFunction: + ret.extend([self.pointer.indexed, self.pointee.indexed]) + ret.extend(flatten(i.free_symbols + for i in self.pointee.symbolic_shape[1:])) + ret.extend(self.pointer.free_symbols) else: - ret.append(self.pointer.indexed) - if self.pointer.is_PointerArray or self.pointer.is_TempFunction: - ret.append(self.pointee.indexed) - ret.extend(flatten(i.free_symbols - for i in self.pointee.symbolic_shape[1:])) - ret.extend(self.pointer.free_symbols) - else: - ret.append(self.pointee._C_symbol) + ret.extend([self.pointer.indexed, self.pointee._C_symbol]) return tuple(filter_ordered(ret)) @property diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index cfd0ee3892..31e5fb4e90 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -454,11 +454,8 @@ def visit_Dereference(self, o): lvalue = c.Value(cstr, '*restrict %s' % a0.name) if a0._data_alignment: lvalue = c.AlignedAttribute(a0._data_alignment, lvalue) - elif a1.is_Symbol: - rvalue = '*%s' % a1.name - lvalue = self._gen_value(a0, 0) 
else: - rvalue = '%s->%s' % (a1.name, a0._C_name) + rvalue = '*%s' % a1.name if a1.is_Symbol else '%s->%s' % (a1.name, a0._C_name) lvalue = self._gen_value(a0, 0) return c.Initializer(lvalue, rvalue) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index f4f73e7663..03093c18a1 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -1,17 +1,18 @@ -import numpy as np import ctypes +import numpy as np -from devito.ir import FindSymbols -from devito.ir.iet.nodes import Dereference -from devito.ir.iet.visitors import Uxreplace +from devito.arch.compiler import Compiler +from devito.ir import Callable, Dereference, FindSymbols, SymbolRegistry, Uxreplace +from devito.passes.iet.langbase import LangBB from devito.symbolics.extended_dtypes import Float16P -from devito.tools.utils import as_list -from devito.types.basic import Symbol +from devito.tools import as_list +from devito.types import Symbol __all__ = ['lower_dtypes'] -def lower_dtypes(iet, lang, compiler, sregistry): +def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, + sregistry: SymbolRegistry) -> tuple[Callable, dict]: """ Lowers float16 scalar types to pointers since we can't directly pass their value. Also includes headers for complex arithmetic if needed. 
@@ -20,13 +21,14 @@ def lower_dtypes(iet, lang, compiler, sregistry): iet, metadata = _complex_includes(iet, lang, compiler) # Lower float16 parameters to pointers and dereference - body_prefix = [] + prefix = [] body_mapper = {} params_mapper = {} # Lower scalar float16s to pointers and dereference them + params = set(iet.parameters) for s in FindSymbols('scalars').visit(iet): - if not np.issubdtype(s.dtype, np.float16) or s not in iet.parameters: + if s.dtype != np.float16 or s not in params: continue # Replace the parameter with a pointer; replace occurences in the IET @@ -36,17 +38,18 @@ def lower_dtypes(iet, lang, compiler, sregistry): is_const=s.is_const) params_mapper[s], body_mapper[s] = ptr, val - body_prefix.append(Dereference(val, ptr)) # val = *ptr + prefix.append(Dereference(val, ptr)) # val = *ptr # Apply the replacements - body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body)) + prefix.extend(as_list(Uxreplace(body_mapper).visit(iet.body))) params = Uxreplace(params_mapper).visit(iet.parameters) - iet = iet._rebuild(body=body, parameters=params) + iet = iet._rebuild(body=prefix, parameters=params) return iet, metadata -def _complex_includes(iet, lang, compiler): +def _complex_includes(iet: Callable, lang: type[LangBB], + compiler: Compiler) -> tuple[Callable, dict]: """ Include complex arithmetic headers for the given language, if needed. """ diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 0b8b1bcad1..af2da5d353 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -19,7 +19,13 @@ class NoDeclStruct(ctypes.Structure): - """A ctypes.Structure that does not generate a struct definition""" + """ + A ctypes.Structure that does not generate a struct definition. + + Some foreign types (e.g. 
complex) need to be passed to C/C++ as a struct + that mimics an existing type, but the struct types themselves don't show + up in the kernel, so we don't need to generate their definitions. + """ pass diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 81ddf637f1..54ce2c0a78 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -141,11 +141,10 @@ def _print_Pow(self, expr): # Need to override because of issue #1627 # E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x try: + PREC = precedence(expr) if expr.exp == -1 and self.single_prec(): - PREC = precedence(expr) return '1.0F/%s' % self.parenthesize(expr.base, PREC) if expr.exp == -1 and self.half_prec(): - PREC = precedence(expr) return '1.0F16/%s' % self.parenthesize(expr.base, PREC) except AttributeError: pass From cc08bcd03edc1f8fd91bd66346d31f8ca3475ba2 Mon Sep 17 00:00:00 2001 From: enwask Date: Mon, 29 Jul 2024 12:46:33 +0100 Subject: [PATCH 23/29] compiler: FindSymbols 'scalars' -> 'abstractsymbols' --- devito/ir/iet/visitors.py | 5 +++-- devito/passes/iet/dtypes.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 31e5fb4e90..f90c1f6724 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -958,7 +958,7 @@ def default_retval(cls): Drive the search. 
Accepted: - `symbolics`: Collect all AbstractFunction objects, default - `basics`: Collect all Basic objects - - `scalars`: Collect all AbstractSymbol objects + - `abstractsymbols`: Collect all AbstractSymbol objects - `dimensions`: Collect all Dimensions - `indexeds`: Collect all Indexed objects - `indexedbases`: Collect all IndexedBase objects @@ -979,7 +979,8 @@ def _defines_aliases(n): rules = { 'symbolics': lambda n: n.functions, 'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)], - 'scalars': lambda n: [i for i in n.expr_symbols if isinstance(i, AbstractSymbol)], + 'abstractsymbols': lambda n: [i for i in n.expr_symbols + if isinstance(i, AbstractSymbol)], 'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)], 'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed], 'indexedbases': lambda n: [i for i in n.expr_symbols diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 03093c18a1..216f989ad6 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -27,7 +27,7 @@ def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, # Lower scalar float16s to pointers and dereference them params = set(iet.parameters) - for s in FindSymbols('scalars').visit(iet): + for s in FindSymbols('abstractsymbols').visit(iet): if s.dtype != np.float16 or s not in params: continue From f9c04ac2e2dc5e1033dd6cc2e0c91d72fa3cab24 Mon Sep 17 00:00:00 2001 From: enwask Date: Tue, 30 Jul 2024 11:22:57 +0100 Subject: [PATCH 24/29] test: include scalar parameters in complex tests --- tests/test_gpu_common.py | 10 ++++++---- tests/test_operator.py | 8 +++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 2e26b78c22..e9d3b1d9aa 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -71,18 +71,20 @@ def test_maxpar_option(self): def test_complex(self, dtype): grid = Grid((5, 5)) x, y = 
grid.dimensions + + c = Constant(name='c', dtype=dtype) u = Function(name="u", grid=grid, dtype=dtype) - eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) + eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing) * c) op = Operator(eq) - op() + op(c=1.0 + 2.0j) # Check against numpy dx = grid.spacing_map[x.spacing] xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) - npres = xx + 1j*yy + np.exp(1j + dx) + npres = xx + 1j*yy + np.exp(1j + dx) * (1.0 + 2.0j) - assert np.allclose(u.data, npres.T, rtol=1e-6, atol=0) + assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) class TestPassesOptional: diff --git a/tests/test_operator.py b/tests/test_operator.py index 283249aac1..4282165314 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -644,16 +644,18 @@ def test_tensor(self, func1): def test_complex(self, dtype): grid = Grid((5, 5)) x, y = grid.dimensions + + c = Constant(name='c', dtype=dtype) u = Function(name="u", grid=grid, dtype=dtype) - eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) + eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing) * c) op = Operator(eq) - op() + op(c=1.0 + 2.0j) # Check against numpy dx = grid.spacing_map[x.spacing] xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) - npres = xx + 1j*yy + np.exp(1j + dx) + npres = xx + 1j*yy + np.exp(1j + dx) * (1.0 + 2.0j) assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) From e519a9ae2be8fb0524b48fb4cc2d034395f2da50 Mon Sep 17 00:00:00 2001 From: enwask Date: Tue, 30 Jul 2024 11:23:49 +0100 Subject: [PATCH 25/29] test: add test_dtypes with initial tests for float16 + complex --- tests/test_dtypes.py | 202 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 tests/test_dtypes.py diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py new file mode 100644 index 0000000000..32f87449c7 --- /dev/null +++ b/tests/test_dtypes.py @@ -0,0 +1,202 @@ +import numpy as np +import pytest +import sympy + +from devito 
import Constant, Eq, Function, Grid, Operator +from devito.ir.iet.nodes import Dereference +from devito.passes.iet.langbase import LangBB +from devito.passes.iet.languages.C import CBB +from devito.passes.iet.languages.openacc import AccBB +from devito.passes.iet.languages.openmp import OmpBB +from devito.symbolics.extended_dtypes import Float16P +from devito.types.basic import Basic, Scalar, Symbol +from devito.types.dimension import Dimension, Spacing + +# Mappers for language-specific types and headers +_languages: dict[str, type[LangBB]] = { + 'C': CBB, + 'openmp': OmpBB, + 'openacc': AccBB +} + + +def _get_language(language: str, **_) -> type[LangBB]: + """ + Gets the language building block type from parametrized kwargs. + """ + + return _languages[language] + + +def _config_kwargs(platform: str, language: str, compiler: str) -> dict[str, str]: + """ + Generates kwargs for Operator to test language-specific behavior. + """ + + return { + 'platform': platform, + 'language': language, + 'compiler': compiler + } + + +# List of parametrized operator kwargs for testing language-specific behavior +_configs: list[dict[str, str]] = [ + _config_kwargs(*cfg) for cfg in [ + ('cpu64', 'C', 'gcc'), + ('cpu64', 'openmp', 'gcc'), + ('nvidiaX', 'openmp', 'nvc'), + ('nvidiaX', 'openacc', 'nvc') + ] +] + + +@pytest.mark.parametrize('dtype', [np.float16, np.complex64, np.complex128]) +@pytest.mark.parametrize('kwargs', _configs) +def test_dtype_mapping(dtype: np.dtype, kwargs: dict[str, str]) -> None: + """ + Tests that half and complex floats' dtypes result in the correct type + strings in generated code. 
+ """ + + # Retrieve the language-specific type mapping + lang_types: dict[np.dtype, type] = _get_language(**kwargs).get('types') + + # Set up an operator + grid = Grid(shape=(3, 3)) + x: Dimension + y: Dimension + x, y = grid.dimensions + + c = Constant(name='c', dtype=dtype) + u = Function(name='u', grid=grid, dtype=dtype) + eq = Eq(u, c * x * y) + op = Operator(eq, **kwargs) + + # Check ctypes of the mapped parameters + params: dict[str, Basic] = {p.name: p for p in op.parameters} + _u: Function = params['u'] + _c: Constant = params['c'] + assert _u.indexed._C_ctype._type_ == lang_types[_u.dtype] + assert _c._C_ctype == lang_types[_c.dtype] + + +def test_half_params() -> None: + """ + Tests float16 input parameters: scalars should be lowered to pointers + and dereferenced; other parameters should keep the original dtype. + """ + + grid = Grid(shape=(5, 5), dtype=np.float16) + x: Dimension + y: Dimension + x, y = grid.dimensions + + c = Constant(name='c', dtype=np.float16) + u = Function(name='u', grid=grid) + eq = Eq(u, x * x.spacing + c * y * y.spacing) + op = Operator(eq) + + # Check that lowered parameters have the correct dtypes + params: dict[str, Basic] = {p.name: p for p in op.parameters} + _u: Function = params['u'] + _c: Constant = params['c'] + _dx: Spacing = params['h_x'] + _dy: Spacing = params['h_y'] + + assert _u.dtype == np.float16 + assert _c.dtype == Float16P + assert _dx.dtype == Float16P + assert _dy.dtype == Float16P + + # Ensure the mapped pointer-to-half symbols are dereferenced + derefs: set[Symbol] = {n.pointer for n in op.body.body + if isinstance(n, Dereference)} + assert _c in derefs + assert _dx in derefs + assert _dy in derefs + + +@pytest.mark.parametrize('dtype', [np.float16, np.float32, + np.complex64, np.complex128]) +@pytest.mark.parametrize('kwargs', _configs) +def test_complex_headers(dtype: np.dtype, kwargs: dict[str, str]) -> None: + """ + Tests that the correct complex headers are included when complex dtypes + are present 
in the operator, and omitted otherwise. + """ + + # Set up an operator + grid = Grid(shape=(3, 3)) + x: Dimension + y: Dimension + x, y = grid.dimensions + + c = Constant(name='c', dtype=dtype) + u = Function(name='u', grid=grid, dtype=dtype) + eq = Eq(u, c * x * y) + op = Operator(eq, **kwargs) + + # Check that the complex header is included <=> complex dtypes are present + header: str = _get_language(**kwargs).get('header-complex') + if np.issubdtype(dtype, np.complexfloating): + assert header in op._includes + else: + assert header not in op._includes + + +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +@pytest.mark.parametrize('kwargs', _configs) +def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None: + """ + Tests that the correct literal is used for the imaginary unit. + """ + + # Determine the expected imaginary unit string + unit_str: str + if kwargs['compiler'] == 'gcc': + # In C we multiply by the _Complex_I macro constant + unit_str = '_Complex_I' + else: + # C++ provides an imaginary literal + unit_str = '1if' if dtype == np.complex64 else '1i' + + # Set up an operator + s = Symbol(name='s', dtype=dtype) + eq = Eq(s, 2.0 + 3.0j) + op = Operator(eq, **kwargs) + + # Check that the correct imaginary unit is used + assert unit_str in str(op) + + +@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64, + np.complex64, np.complex128]) +@pytest.mark.parametrize(['sym', 'fun'], [(sympy.exp, np.exp), + (sympy.log, np.log), + (sympy.sin, np.sin)]) +def test_math_functions(dtype: np.dtype, sym: sympy.Function, fun: np.ufunc) -> None: + """ + Tests that the correct math functions are used, and their results cast + and assigned appropriately for different float precisions and for + complex floats/doubles. 
+ """ + + # Get the expected function call string + call_str = str(sym) + if np.issubdtype(dtype, np.complexfloating): + # Complex functions have a 'c' prefix + call_str = 'c%s' % call_str + if dtype(0).real.itemsize <= 4: + # Single precision have an 'f' suffix (half is promoted to single) + call_str = '%sf' % call_str + + # Operator setup + a = Symbol(name='a', dtype=dtype) + b = Scalar(name='b', dtype=dtype) + + eq = Eq(a, sym(b)) + op = Operator(eq) + + # Ensure the generated function call has the correct form + assert call_str in str(op) From 6725c9acdd09130bc37ecd9d218b7e56886ac4bd Mon Sep 17 00:00:00 2001 From: enwask Date: Tue, 30 Jul 2024 11:55:41 +0100 Subject: [PATCH 26/29] misc: more lower_dtypes cleanup + type hints --- devito/passes/iet/dtypes.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 216f989ad6..6f698a659d 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -2,11 +2,11 @@ import numpy as np from devito.arch.compiler import Compiler -from devito.ir import Callable, Dereference, FindSymbols, SymbolRegistry, Uxreplace +from devito.ir import Callable, Dereference, FindSymbols, Node, SymbolRegistry, Uxreplace from devito.passes.iet.langbase import LangBB from devito.symbolics.extended_dtypes import Float16P from devito.tools import as_list -from devito.types import Symbol +from devito.types.basic import AbstractSymbol, Basic, Symbol __all__ = ['lower_dtypes'] @@ -21,19 +21,19 @@ def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, iet, metadata = _complex_includes(iet, lang, compiler) # Lower float16 parameters to pointers and dereference - prefix = [] - body_mapper = {} - params_mapper = {} + prefix: list[Node] = [] + params_mapper: dict[AbstractSymbol, AbstractSymbol] = {} + body_mapper: dict[AbstractSymbol, Symbol] = {} - # Lower scalar float16s to pointers and dereference them - params = 
set(iet.parameters) + params_set = set(iet.parameters) + s: AbstractSymbol for s in FindSymbols('abstractsymbols').visit(iet): - if s.dtype != np.float16 or s not in params: + if s.dtype != np.float16 or s not in params_set: continue # Replace the parameter with a pointer; replace occurences in the IET - # body with a dereference (using the original symbol's dtype) - ptr = s._rebuild(dtype=Float16P, is_const=True) + # body with dereferenced symbol (using the original symbol's dtype) + ptr: AbstractSymbol = s._rebuild(dtype=Float16P, is_const=True) val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=s.dtype, is_const=s.is_const) @@ -42,7 +42,7 @@ def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, # Apply the replacements prefix.extend(as_list(Uxreplace(body_mapper).visit(iet.body))) - params = Uxreplace(params_mapper).visit(iet.parameters) + params: tuple[Basic] = Uxreplace(params_mapper).visit(iet.parameters) iet = iet._rebuild(body=prefix, parameters=params) return iet, metadata @@ -51,9 +51,10 @@ def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, def _complex_includes(iet: Callable, lang: type[LangBB], compiler: Compiler) -> tuple[Callable, dict]: """ - Include complex arithmetic headers for the given language, if needed. + Includes complex arithmetic headers for the given language, if needed. 
""" - # Check if there is complex numbers that always take dtype precedence + + # Check if there are complex numbers that always take dtype precedence types = {f.dtype for f in FindSymbols().visit(iet) if not issubclass(f.dtype, ctypes._Pointer)} From b57949a50bd3820c99b4e3bcc1ad72268629ea74 Mon Sep 17 00:00:00 2001 From: enwask Date: Wed, 31 Jul 2024 15:33:59 +0100 Subject: [PATCH 27/29] api: use grid dtype for extent and origin, add test_grid --- devito/types/grid.py | 11 +++++++---- tests/test_grid.py | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 tests/test_grid.py diff --git a/devito/types/grid.py b/devito/types/grid.py index 523d044e2b..9282f54889 100644 --- a/devito/types/grid.py +++ b/devito/types/grid.py @@ -69,9 +69,10 @@ class Grid(CartesianDiscretization, ArgProvider): ---------- shape : tuple of ints Shape of the computational domain in grid points. - extent : tuple of floats, default=unit box of extent 1m in all dimensions + extent : tuple of values interpretable as dtype, default=unit box of extent 1m + in all dimensions Physical extent of the domain in m. - origin : tuple of floats, default=0.0 in all dimensions + origin : tuple of values interpretable as dtype, default=0.0 in all dimensions Physical coordinate of the origin of the domain. dimensions : tuple of SpaceDimension, optional The dimensions of the computational domain encapsulated by this Grid. @@ -168,7 +169,8 @@ def __init__(self, shape, extent=None, origin=None, dimensions=None, self._distributor = Distributor(shape, dimensions, comm, topology) # The physical extent - self._extent = as_tuple(extent or tuple(1. for _ in self.shape)) + extent = as_tuple(extent or tuple(1. 
for _ in self.shape)) + self._extent = tuple(dtype(e) for e in extent) # Initialize SubDomains subdomains = tuple(i for i in (Domain(), Interior(), *as_tuple(subdomains))) @@ -176,7 +178,8 @@ def __init__(self, shape, extent=None, origin=None, dimensions=None, i.__subdomain_finalize__(self, counter=counter) self._subdomains = subdomains - self._origin = as_tuple(origin or tuple(0. for _ in self.shape)) + origin = as_tuple(origin or tuple(0. for _ in self.shape)) + self._origin = tuple(dtype(o) for o in origin) self._origin_symbols = tuple(Scalar(name='o_%s' % d.name, dtype=dtype, is_const=True) for d in self.dimensions) diff --git a/tests/test_grid.py b/tests/test_grid.py new file mode 100644 index 0000000000..5753e17d7b --- /dev/null +++ b/tests/test_grid.py @@ -0,0 +1,27 @@ +import numpy as np +import pytest + +from devito import Grid + + +# Unsigned ints are unreasonable but not necessarily invalid +@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64, np.longdouble, + np.complex64, np.complex128, np.int8, np.int16, + np.int32, np.int64, np.uint8, np.uint16, np.uint32, + np.uint64]) +def test_extent_dtypes(dtype: np.dtype[np.number]) -> None: + """ + Test that grid spacings are correctly computed for different dtypes. 
+ """ + + # Construct a grid with the dtype and retrieve the spacing values + extent = (1, 1j) if np.issubdtype(dtype, np.complexfloating) else (2, 4) + grid = Grid(shape=(5, 5), extent=extent, dtype=dtype) + dx, dy = grid.spacing_map.values() + + # Check that the spacings have the correct dtype + assert dx.dtype == dy.dtype == dtype + + # Check that the spacings have the correct values + assert dx == dtype(extent[0] / 4) + assert dy == dtype(extent[1] / 4) From ec98a02fa7d97c3df4481500e14c0695d74ab3d1 Mon Sep 17 00:00:00 2001 From: enwask Date: Wed, 31 Jul 2024 15:57:07 +0100 Subject: [PATCH 28/29] test: clean up and add more half/complex tests --- devito/symbolics/extended_dtypes.py | 4 +- tests/test_dtypes.py | 164 ++++++++++++++++++++++++---- 2 files changed, 147 insertions(+), 21 deletions(-) diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index af2da5d353..8c90e48986 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -72,9 +72,9 @@ def from_param(cls, val): class Float16P(np.float16): """ - Dummy dtype for a scalar float16 value that's been mapped to a pointer. + Dummy dtype for a scalar half value that has been mapped to a pointer. This is needed because we can't directly pass in the values; we map to - pointers and dereference in the kernel; see `passes.iet.dtypes`. + pointers and dereference in the kernel. See `passes.iet.dtypes`. 
""" pass diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 32f87449c7..a330096b36 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import re import sympy from devito import Constant, Eq, Function, Grid, Operator @@ -9,8 +10,9 @@ from devito.passes.iet.languages.openacc import AccBB from devito.passes.iet.languages.openmp import OmpBB from devito.symbolics.extended_dtypes import Float16P +from devito.tools import ctypes_to_cstr from devito.types.basic import Basic, Scalar, Symbol -from devito.types.dimension import Dimension, Spacing +from devito.types.dense import TimeFunction # Mappers for language-specific types and headers _languages: dict[str, type[LangBB]] = { @@ -45,7 +47,6 @@ def _config_kwargs(platform: str, language: str, compiler: str) -> dict[str, str _config_kwargs(*cfg) for cfg in [ ('cpu64', 'C', 'gcc'), ('cpu64', 'openmp', 'gcc'), - ('nvidiaX', 'openmp', 'nvc'), ('nvidiaX', 'openacc', 'nvc') ] ] @@ -53,7 +54,7 @@ def _config_kwargs(platform: str, language: str, compiler: str) -> dict[str, str @pytest.mark.parametrize('dtype', [np.float16, np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) -def test_dtype_mapping(dtype: np.dtype, kwargs: dict[str, str]) -> None: +def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: """ Tests that half and complex floats' dtypes result in the correct type strings in generated code. 
@@ -64,8 +65,6 @@ def test_dtype_mapping(dtype: np.dtype, kwargs: dict[str, str]) -> None: # Set up an operator grid = Grid(shape=(3, 3)) - x: Dimension - y: Dimension x, y = grid.dimensions c = Constant(name='c', dtype=dtype) @@ -75,12 +74,38 @@ def test_dtype_mapping(dtype: np.dtype, kwargs: dict[str, str]) -> None: # Check ctypes of the mapped parameters params: dict[str, Basic] = {p.name: p for p in op.parameters} - _u: Function = params['u'] - _c: Constant = params['c'] + _u, _c = params['u'], params['c'] assert _u.indexed._C_ctype._type_ == lang_types[_u.dtype] assert _c._C_ctype == lang_types[_c.dtype] +@pytest.mark.parametrize('dtype', [np.float16, np.complex64, np.complex128]) +@pytest.mark.parametrize('kwargs', _configs) +def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: + """ + Tests that variables introduced by CSE have the correct type strings in + the generated code. + """ + + # Retrieve the language-specific type mapping + lang_types: dict[np.dtype, type] = _get_language(**kwargs).get('types') + + # Set up an operator + grid = Grid(shape=(3, 3)) + x, y = grid.dimensions + + c = Constant(name='c', dtype=dtype) + u = Function(name='u', grid=grid, dtype=dtype) + # sin(c) should be CSE'd + eq = Eq(u, x * x.spacing + y * y.spacing * sympy.sin(c)) + op = Operator(eq, **kwargs) + + # Ensure the CSE'd variable has the correct type + match = re.search(r'[^\S\n\r]*(.*\S)\sr0 = ', str(op)) + assert match is not None + assert match.group(1) == ctypes_to_cstr(lang_types[dtype]) + + def test_half_params() -> None: """ Tests float16 input parameters: scalars should be lowered to pointers @@ -88,8 +113,6 @@ def test_half_params() -> None: """ grid = Grid(shape=(5, 5), dtype=np.float16) - x: Dimension - y: Dimension x, y = grid.dimensions c = Constant(name='c', dtype=np.float16) @@ -99,10 +122,7 @@ def test_half_params() -> None: # Check that lowered parameters have the correct dtypes params: dict[str, Basic] = {p.name: p for p in 
op.parameters} - _u: Function = params['u'] - _c: Constant = params['c'] - _dx: Spacing = params['h_x'] - _dy: Spacing = params['h_y'] + _u, _c, _dx, _dy = params['u'], params['c'], params['h_x'], params['h_y'] assert _u.dtype == np.float16 assert _c.dtype == Float16P @@ -120,7 +140,8 @@ def test_half_params() -> None: @pytest.mark.parametrize('dtype', [np.float16, np.float32, np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) -def test_complex_headers(dtype: np.dtype, kwargs: dict[str, str]) -> None: +def test_complex_headers(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: + np.dtype """ Tests that the correct complex headers are included when complex dtypes are present in the operator, and omitted otherwise. @@ -128,8 +149,6 @@ def test_complex_headers(dtype: np.dtype, kwargs: dict[str, str]) -> None: # Set up an operator grid = Grid(shape=(3, 3)) - x: Dimension - y: Dimension x, y = grid.dimensions c = Constant(name='c', dtype=dtype) @@ -158,8 +177,11 @@ def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None: # In C we multiply by the _Complex_I macro constant unit_str = '_Complex_I' else: - # C++ provides an imaginary literal - unit_str = '1if' if dtype == np.complex64 else '1i' + # C++ provides imaginary literals + if dtype == np.complex64: + unit_str = '1if' + else: + unit_str = '1i' # Set up an operator s = Symbol(name='s', dtype=dtype) @@ -175,7 +197,8 @@ def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None: @pytest.mark.parametrize(['sym', 'fun'], [(sympy.exp, np.exp), (sympy.log, np.log), (sympy.sin, np.sin)]) -def test_math_functions(dtype: np.dtype, sym: sympy.Function, fun: np.ufunc) -> None: +def test_math_functions(dtype: np.dtype[np.inexact], + sym: sympy.Function, fun: np.ufunc) -> None: """ Tests that the correct math functions are used, and their results cast and assigned appropriately for different float precisions and for @@ -200,3 +223,106 @@ def 
test_math_functions(dtype: np.dtype, sym: sympy.Function, fun: np.ufunc) -> # Ensure the generated function call has the correct form assert call_str in str(op) + + +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_complex_override(dtype: np.dtype[np.complexfloating]) -> None: + """ + Tests overriding complex values in op.apply(). + """ + + grid = Grid(shape=(5, 5)) + x, y = grid.dimensions + + c = Constant(name='c', dtype=dtype, value=1.0 + 0.0j) + u = Function(name='u', grid=grid, dtype=dtype) + eq = Eq(u, x * x.spacing + c * y * y.spacing) + op = Operator(eq) + op.apply(c=2.0 + 1.0j) + + # Check against numpy result + dx, dy = grid.spacing_map.values() + xx, yy = np.meshgrid(np.linspace(0, 4, 5, dtype=dtype), + np.linspace(0, 4, 5, dtype=dtype)) + expected = xx * dx + yy * dy * dtype(2.0 + 1.0j) + assert np.allclose(u.data.T, expected) + + +def test_half_time_deriv() -> None: + """ + Tests taking the time derivative of a float16 function. + """ + + grid = Grid(shape=(5, 5)) + x, y = grid.dimensions + t = grid.time_dim + + f = TimeFunction(name='f', grid=grid, space_order=2, dtype=np.float16) + g = Function(name='g', grid=grid, dtype=np.float16) + eqns = [Eq(f.forward, t * x * x.spacing + + y * y.spacing), + Eq(g, f.dt)] + op = Operator(eqns) + op.apply(time=10, dt=1.0) + + # Check against expected result + dx = grid.spacing_map[x.spacing] + xx = np.repeat(np.linspace(0, 4, 5, dtype=np.float16)[np.newaxis, :], 5, axis=0) + expected = xx * np.float16(dx) + assert np.allclose(g.data.T, expected) + + +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_complex_time_deriv(dtype: np.dtype[np.complexfloating]) -> None: + """ + Tests taking the time derivative of a complex-valued function. 
+ """ + + grid = Grid(shape=(5, 5)) + x, y = grid.dimensions + t = grid.time_dim + + f = TimeFunction(name='f', grid=grid, space_order=2, dtype=dtype) + g = Function(name='g', grid=grid, dtype=dtype) + eqns = [Eq(f.forward, t * x * x.spacing * (1.0 + 0.0j) + + t * y * y.spacing * (0.0 + 1.0j)), + Eq(g, f.dt)] + op = Operator(eqns) + op.apply(time=10, dt=1.0) + + # Check against expected result + dx, dy = grid.spacing_map.values() + xx, yy = np.meshgrid(np.linspace(0, 4, 5, dtype=dtype), + np.linspace(0, 4, 5, dtype=dtype)) + expected = xx * dx + yy * dy * dtype(0.0 + 1.0j) + assert np.allclose(g.data.T, expected) + + +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_complex_space_deriv(dtype: np.dtype[np.complexfloating]) -> None: + """ + Tests taking the space derivative of a complex-valued function, with + respect to the real and imaginary axes. + """ + + grid = Grid(shape=(7, 7), dtype=dtype) + x, y = grid.dimensions + + # Operator setup + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid) + h = Function(name='h', grid=grid) + eqns = [Eq(f, x * x.spacing + y * y.spacing), + Eq(g, f.dx, subdomain=grid.interior), + Eq(h, f.dy, subdomain=grid.interior)] + op = Operator(eqns) + + dx = 1.0 + 0.0j + dy = 0.0 + 1.0j + op.apply(h_x=dx, h_y=dy) + + # Check against expected result (1 within the interior) + dfdx = g.data.T[1:-1, 1:-1] + dfdy = h.data.T[1:-1, 1:-1] + assert np.allclose(dfdx, np.ones((5, 5), dtype=dtype)) + assert np.allclose(dfdy, np.ones((5, 5), dtype=dtype)) From 6a2f8e6e98cc714db0ce5d79c629dbe8bb7d712f Mon Sep 17 00:00:00 2001 From: enwask Date: Wed, 31 Jul 2024 18:47:55 +0100 Subject: [PATCH 29/29] test: fix test_grid_objs, add test_grid_dtypes --- tests/test_caching.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_caching.py b/tests/test_caching.py index f4346706ea..0767d6323a 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -442,8 
+442,16 @@ def test_grid_objs(self): assert y0 is y1 assert x0.spacing is x1.spacing assert y0.spacing is y1.spacing - assert ox0 is ox1 - assert oy0 is oy1 + + def test_grid_dtypes(self): + """ + Test that two grids with different dtypes have different hash values. + """ + + grid0 = Grid(shape=(4, 4), dtype=np.float32) + grid1 = Grid(shape=(4, 4), dtype=np.float64) + + assert hash(grid0) != hash(grid1) def test_special_symbols(self): """