Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Fix complex arguments and implement float16 lowering #2403

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2b28090
api: add support for complex dtype
mloubout Aug 2, 2023
aa353b4
api: fix printer for complex dtype
mloubout May 22, 2024
92dfd9a
compiler: fix alias dtype with complex numbers
mloubout May 22, 2024
4364524
api: move complex ctype to dtype lowering
mloubout May 22, 2024
470f4f5
compiler: generate std:complex for cpp compilers
mloubout May 28, 2024
7ffff0a
compiler: add std::complex arithmetic defs for unsupported types
mloubout May 30, 2024
d1dd24e
compiler: fix alias dtype with complex numbers
mloubout May 30, 2024
4f43f26
compiler: fix internal language specific types and cast
mloubout May 31, 2024
6b4f12d
compiler: rework dtype lowering
mloubout Jun 20, 2024
9abeea8
compiler: switch to c++14 for complex_literals
mloubout Jun 27, 2024
94d5571
compiler: subdtype numpy for dtype lowering
mloubout Jul 8, 2024
ecadec3
compiler: use structs to pass complex arguments
enwask Jul 9, 2024
27ff82a
compiler: add Dereference scalar case
enwask Jul 11, 2024
066279d
compiler: implement float16 support
enwask Jul 11, 2024
7af930b
symbolics: fix printer for half precision
enwask Jul 11, 2024
5b7efcc
misc: fix formatting
enwask Jul 11, 2024
a09350f
compiler: refactor float16 and lower_dtypes
enwask Jul 11, 2024
d4c9454
compiler: add dtype_alloc_ctype helper for allocation size
enwask Jul 11, 2024
d3169d0
misc: more float16 refactoring/formatting fixes
enwask Jul 15, 2024
493c1e8
Remove dtypes lowering from IET layer
enwask Jul 16, 2024
516b4ad
compiler: reimplement float16/complex lowering
enwask Jul 26, 2024
079facb
misc: cleanup, docs and typing for half support
enwask Jul 29, 2024
cc08bcd
compiler: FindSymbols 'scalars' -> 'abstractsymbols'
enwask Jul 29, 2024
f9c04ac
test: include scalar parameters in complex tests
enwask Jul 30, 2024
e519a9a
test: add test_dtypes with initial tests for float16 + complex
enwask Jul 30, 2024
6725c9a
misc: more lower_dtypes cleanup + type hints
enwask Jul 30, 2024
b57949a
api: use grid dtype for extent and origin, add test_grid
enwask Jul 31, 2024
ec98a02
test: clean up and add more half/complex tests
enwask Jul 31, 2024
6a2f8e6
test: fix test_grid_objs, add test_grid_dtypes
enwask Jul 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions devito/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def reinit_compiler(val):
"""
Re-initialize the Compiler.
"""
configuration['compiler'].__init__(suffix=configuration['compiler'].suffix,
configuration['compiler'].__init__(name=configuration['compiler'].name,
suffix=configuration['compiler'].suffix,
mpi=configuration['mpi'])
return val

Expand All @@ -65,7 +66,7 @@ def reinit_compiler(val):
configuration.add('platform', 'cpu64', list(platform_registry),
callback=lambda i: platform_registry[i]())
configuration.add('compiler', 'custom', list(compiler_registry),
callback=lambda i: compiler_registry[i]())
callback=lambda i: compiler_registry[i](name=i))

# Setup language for shared-memory parallelism
preprocessor = lambda i: {0: 'C', 1: 'openmp'}.get(i, i) # Handles DEVITO_OPENMP deprec
Expand Down
9 changes: 6 additions & 3 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def __init__(self):
_cpp = False

def __init__(self, **kwargs):
self._name = kwargs.pop('name', self.__class__.__name__)

super().__init__(**kwargs)

self.__lookup_cmds__()
Expand Down Expand Up @@ -223,13 +225,13 @@ def __new_with__(self, **kwargs):
Create a new Compiler from an existing one, inherenting from it
the flags that are not specified via ``kwargs``.
"""
return self.__class__(suffix=kwargs.pop('suffix', self.suffix),
return self.__class__(name=self.name, suffix=kwargs.pop('suffix', self.suffix),
mpi=kwargs.pop('mpi', configuration['mpi']),
**kwargs)

@property
def name(self):
return self.__class__.__name__
return self._name

@property
def version(self):
Expand Down Expand Up @@ -593,7 +595,7 @@ def __init_finalize__(self, **kwargs):
self.cflags.remove('-O3')
self.cflags.remove('-Wall')

self.cflags.append('-std=c++11')
self.cflags.append('-std=c++14')

language = kwargs.pop('language', configuration['language'])
platform = kwargs.pop('platform', configuration['platform'])
Expand Down Expand Up @@ -978,6 +980,7 @@ def __new_with__(self, **kwargs):
'nvc++': NvidiaCompiler,
'nvidia': NvidiaCompiler,
'cuda': CudaCompiler,
'nvcc': CudaCompiler,
'osx': ClangCompiler,
'intel': OneapiCompiler,
'icx': OneapiCompiler,
Expand Down
6 changes: 3 additions & 3 deletions devito/data/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from devito.logger import logger
from devito.parameters import configuration
from devito.tools import dtype_to_ctype, is_integer
from devito.tools import is_integer, dtype_alloc_ctype

__all__ = ['ALLOC_ALIGNED', 'ALLOC_NUMA_LOCAL', 'ALLOC_NUMA_ANY',
'ALLOC_KNL_MCDRAM', 'ALLOC_KNL_DRAM', 'ALLOC_GUARD',
Expand Down Expand Up @@ -92,8 +92,8 @@ def initialize(cls):
return

def alloc(self, shape, dtype, padding=0):
datasize = int(reduce(mul, shape))
ctype = dtype_to_ctype(dtype)
ctype, c_scale = dtype_alloc_ctype(dtype)
datasize = int(reduce(mul, shape) * c_scale)

# Add padding, if any
try:
Expand Down
5 changes: 2 additions & 3 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from devito.logger import warning
from devito.tools import (as_tuple, filter_ordered, flatten, frozendict,
infer_dtype, is_integer, split)
from devito.types import (Array, DimensionTuple, Evaluable, Indexed,
StencilDimension)
from devito.types import Array, DimensionTuple, Evaluable, StencilDimension

__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
'Weights']
Expand Down Expand Up @@ -68,7 +67,7 @@ def grid(self):

@cached_property
def dtype(self):
dtypes = {f.dtype for f in self.find(Indexed)} - {None}
dtypes = {f.dtype for f in self._functions} - {None}
return infer_dtype(dtypes)

@cached_property
Expand Down
18 changes: 13 additions & 5 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The Iteration/Expression Tree (IET) hierarchy."""

import abc
import ctypes
import inspect
from functools import cached_property
from collections import OrderedDict, namedtuple
Expand Down Expand Up @@ -1030,6 +1031,8 @@ class Dereference(ExprStmt, Node):
* `pointer` is a PointerArray or TempFunction, and `pointee` is an Array.
* `pointer` is an ArrayObject representing a pointer to a C struct, and
`pointee` is a field in `pointer`.
* `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and
`pointee` is a Symbol representing the dereferenced value.
georgebisbas marked this conversation as resolved.
Show resolved Hide resolved
"""

is_Dereference = True
Expand All @@ -1048,13 +1051,18 @@ def functions(self):

@property
def expr_symbols(self):
ret = [self.pointer.indexed]
if self.pointer.is_PointerArray or self.pointer.is_TempFunction:
ret.append(self.pointee.indexed)
ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:]))
ret = []
if self.pointer.is_Symbol:
assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \
"Scalar dereference must have a pointer ctype"
ret.extend([self.pointer._C_symbol, self.pointee._C_symbol])
elif self.pointer.is_PointerArray or self.pointer.is_TempFunction:
ret.extend([self.pointer.indexed, self.pointee.indexed])
ret.extend(flatten(i.free_symbols
for i in self.pointee.symbolic_shape[1:]))
ret.extend(self.pointer.free_symbols)
else:
ret.append(self.pointee._C_symbol)
ret.extend([self.pointer.indexed, self.pointee._C_symbol])
return tuple(filter_ordered(ret))

@property
Expand Down
48 changes: 32 additions & 16 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
from sympy import IndexedBase
from sympy.core.function import Application

from devito.parameters import configuration, switchconfig
from devito.exceptions import VisitorException
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
Call, Lambda, BlankLine, Section, ListMajor)
from devito.ir.support.space import Backward
from devito.symbolics import (FieldFromComposite, FieldFromPointer,
ListInitializer, ccode, uxreplace)
from devito.symbolics.extended_dtypes import NoDeclStruct
from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered,
filter_sorted, flatten, is_external_ctype,
c_restrict_void_p, sorted_priority)
from devito.types.basic import AbstractFunction, Basic
from devito.types.basic import AbstractFunction, AbstractSymbol, Basic
from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer,
IndexedData, DeviceMap)

Expand Down Expand Up @@ -176,7 +178,7 @@ class CGen(Visitor):

def __init__(self, *args, compiler=None, **kwargs):
super().__init__(*args, **kwargs)
self._compiler = compiler
self._compiler = compiler or configuration['compiler']

# The following mappers may be customized by subclasses (that is,
# backend-specific CGen-erators)
Expand All @@ -188,6 +190,16 @@ def __init__(self, *args, compiler=None, **kwargs):
}
_restrict_keyword = 'restrict'

@property
def compiler(self):
return self._compiler

def visit(self, o, *args, **kwargs):
# Make sure the visitor always is within the generating compiler
# in case the configuration is accessed
with switchconfig(compiler=self.compiler.name):
return super().visit(o, *args, **kwargs)

def _gen_struct_decl(self, obj, masked=()):
"""
Convert ctypes.Struct -> cgen.Structure.
Expand All @@ -197,7 +209,7 @@ def _gen_struct_decl(self, obj, masked=()):
while issubclass(ctype, ctypes._Pointer):
ctype = ctype._type_

if not issubclass(ctype, ctypes.Structure):
if not issubclass(ctype, ctypes.Structure) or issubclass(ctype, NoDeclStruct):
enwask marked this conversation as resolved.
Show resolved Hide resolved
return None
except TypeError:
# E.g., `ctype` is of type `dtypes_lowering.CustomDtype`
Expand Down Expand Up @@ -376,10 +388,11 @@ def visit_tuple(self, o):
def visit_PointerCast(self, o):
f = o.function
i = f.indexed
cstr = i._C_typedata

if f.is_PointerArray:
# lvalue
lvalue = c.Value(i._C_typedata, '**%s' % f.name)
lvalue = c.Value(cstr, '**%s' % f.name)

# rvalue
if isinstance(o.obj, ArrayObject):
Expand All @@ -388,7 +401,7 @@ def visit_PointerCast(self, o):
v = f._C_name
else:
assert False
rvalue = '(%s**) %s' % (i._C_typedata, v)
rvalue = '(%s**) %s' % (cstr, v)

else:
# lvalue
Expand All @@ -399,10 +412,10 @@ def visit_PointerCast(self, o):
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
rshape = '(*)%s' % shape
lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape))
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
else:
rshape = '*'
lvalue = c.Value(i._C_typedata, '*%s' % v)
lvalue = c.Value(cstr, '*%s' % v)
if o.alignment and f._data_alignment:
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)

Expand All @@ -415,34 +428,34 @@ def visit_PointerCast(self, o):
else:
assert False

rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v)
rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v)
else:
if isinstance(o.obj, Pointer):
v = o.obj.name
else:
v = f._C_name

rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v)
rvalue = '(%s %s) %s' % (cstr, rshape, v)

return c.Initializer(lvalue, rvalue)

def visit_Dereference(self, o):
a0, a1 = o.functions
if a1.is_PointerArray or a1.is_TempFunction:
i = a1.indexed
cstr = i._C_typedata
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name,
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
a1.dim.name)
lvalue = c.Value(i._C_typedata,
'(*restrict %s)%s' % (a0.name, shape))
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
else:
rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name)
lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name)
rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name)
lvalue = c.Value(cstr, '*restrict %s' % a0.name)
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
else:
rvalue = '%s->%s' % (a1.name, a0._C_name)
rvalue = '*%s' % a1.name if a1.is_Symbol else '%s->%s' % (a1.name, a0._C_name)
enwask marked this conversation as resolved.
Show resolved Hide resolved
lvalue = self._gen_value(a0, 0)
return c.Initializer(lvalue, rvalue)

Expand Down Expand Up @@ -590,7 +603,7 @@ def visit_MultiTraversable(self, o):
return c.Collection(body)

def visit_UsingNamespace(self, o):
return c.Statement('using namespace %s' % ccode(o.namespace))
return c.Statement('using namespace %s' % str(o.namespace))

def visit_Lambda(self, o):
body = []
Expand Down Expand Up @@ -945,6 +958,7 @@ def default_retval(cls):
Drive the search. Accepted:
- `symbolics`: Collect all AbstractFunction objects, default
- `basics`: Collect all Basic objects
- `abstractsymbols`: Collect all AbstractSymbol objects
- `dimensions`: Collect all Dimensions
- `indexeds`: Collect all Indexed objects
- `indexedbases`: Collect all IndexedBase objects
Expand All @@ -965,6 +979,8 @@ def _defines_aliases(n):
rules = {
'symbolics': lambda n: n.functions,
'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)],
'abstractsymbols': lambda n: [i for i in n.expr_symbols
if isinstance(i, AbstractSymbol)],
'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)],
'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed],
'indexedbases': lambda n: [i for i in n.expr_symbols
Expand Down
15 changes: 14 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate,
error_mapper, is_on_device)
from devito.passes.iet.langbase import LangBB
from devito.symbolics import estimate_cost, subs_op_args
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
flatten, filter_sorted, frozendict, is_integer,
split, timed_pass, timed_region, contains_val)
from devito.tools.dtypes_lowering import ctypes_vector_mapper
from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer,
disk_layer)

Expand Down Expand Up @@ -264,6 +266,9 @@ def _lower(cls, expressions, **kwargs):
# expression for which a partial or complete lowering is desired
kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs)

# Load language-specific types into the global dtype->ctype mapper
cls._load_dtype_mappings(**kwargs)

# [Eq] -> [LoweredEq]
expressions = cls._lower_exprs(expressions, **kwargs)

Expand All @@ -285,6 +290,11 @@ def _lower(cls, expressions, **kwargs):
def _rcompile_wrapper(cls, **kwargs0):
raise NotImplementedError

@classmethod
def _load_dtype_mappings(cls, **kwargs):
lang: type[LangBB] = cls._Target.DataManager.lang
ctypes_vector_mapper.update(lang.mapper.get('types', {}))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that's a bit tricky because this updates a global ctypes_vector_mapper which might lead to odd behavior building multiple operators with different languages.
Do you know where it's called and needs those types ? I.e can the mapper be "local" to the operator and passed there?


@classmethod
def _initialize_state(cls, **kwargs):
return {}
Expand Down Expand Up @@ -469,6 +479,8 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):

# Lower IET to a target-specific IET
graph = Graph(iet, **kwargs)

# Specialize
graph = cls._specialize_iet(graph, **kwargs)

# Instrument the IET for C-level profiling
Expand Down Expand Up @@ -1347,7 +1359,8 @@ def parse_kwargs(**kwargs):
raise InvalidOperator("Illegal `compiler=%s`" % str(compiler))
kwargs['compiler'] = compiler_registry[compiler](platform=kwargs['platform'],
language=kwargs['language'],
mpi=configuration['mpi'])
mpi=configuration['mpi'],
name=compiler)
elif any([platform, language]):
kwargs['compiler'] =\
configuration['compiler'].__new_with__(platform=kwargs['platform'],
Expand Down
1 change: 0 additions & 1 deletion devito/passes/clusters/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def _collect_nested(expr):
Recursion helper for `collect_nested`.
"""
# Return semantic (rebuilt expression, factorization candidates)

if expr.is_Number:
return expr, {'coeffs': expr}
elif expr.is_Function:
Expand Down
1 change: 1 addition & 0 deletions devito/passes/iet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .instrument import * # noqa
from .languages import * # noqa
from .errors import * # noqa
from .dtypes import * # noqa
Loading