Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Fix complex arguments and implement float16 lowering #2403

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2b28090
api: add support for complex dtype
mloubout Aug 2, 2023
aa353b4
api: fix printer for complex dtype
mloubout May 22, 2024
92dfd9a
compiler: fix alias dtype with complex numbers
mloubout May 22, 2024
4364524
api: move complex ctype to dtype lowering
mloubout May 22, 2024
470f4f5
compiler: generate std:complex for cpp compilers
mloubout May 28, 2024
7ffff0a
compiler: add std::complex arithmetic defs for unsupported types
mloubout May 30, 2024
d1dd24e
compiler: fix alias dtype with complex numbers
mloubout May 30, 2024
4f43f26
compiler: fix internal language specific types and cast
mloubout May 31, 2024
6b4f12d
compiler: rework dtype lowering
mloubout Jun 20, 2024
9abeea8
compiler: switch to c++14 for complex_literals
mloubout Jun 27, 2024
94d5571
compiler: subdtype numpy for dtype lowering
mloubout Jul 8, 2024
ecadec3
compiler: use structs to pass complex arguments
enwask Jul 9, 2024
27ff82a
compiler: add Dereference scalar case
enwask Jul 11, 2024
066279d
compiler: implement float16 support
enwask Jul 11, 2024
7af930b
symbolics: fix printer for half precision
enwask Jul 11, 2024
5b7efcc
misc: fix formatting
enwask Jul 11, 2024
a09350f
compiler: refactor float16 and lower_dtypes
enwask Jul 11, 2024
d4c9454
compiler: add dtype_alloc_ctype helper for allocation size
enwask Jul 11, 2024
d3169d0
misc: more float16 refactoring/formatting fixes
enwask Jul 15, 2024
493c1e8
Remove dtypes lowering from IET layer
enwask Jul 16, 2024
516b4ad
compiler: reimplement float16/complex lowering
enwask Jul 26, 2024
079facb
misc: cleanup, docs and typing for half support
enwask Jul 29, 2024
cc08bcd
compiler: FindSymbols 'scalars' -> 'abstractsymbols'
enwask Jul 29, 2024
f9c04ac
test: include scalar parameters in complex tests
enwask Jul 30, 2024
e519a9a
test: add test_dtypes with initial tests for float16 + complex
enwask Jul 30, 2024
6725c9a
misc: more lower_dtypes cleanup + type hints
enwask Jul 30, 2024
b57949a
api: use grid dtype for extent and origin, add test_grid
enwask Jul 31, 2024
ec98a02
test: clean up and add more half/complex tests
enwask Jul 31, 2024
6a2f8e6
test: fix test_grid_objs, add test_grid_dtypes
enwask Jul 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions devito/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def reinit_compiler(val):
"""
Re-initialize the Compiler.
"""
configuration['compiler'].__init__(suffix=configuration['compiler'].suffix,
configuration['compiler'].__init__(name=configuration['compiler'].name,
suffix=configuration['compiler'].suffix,
mpi=configuration['mpi'])
return val

Expand All @@ -65,7 +66,7 @@ def reinit_compiler(val):
configuration.add('platform', 'cpu64', list(platform_registry),
callback=lambda i: platform_registry[i]())
configuration.add('compiler', 'custom', list(compiler_registry),
callback=lambda i: compiler_registry[i]())
callback=lambda i: compiler_registry[i](name=i))

# Setup language for shared-memory parallelism
preprocessor = lambda i: {0: 'C', 1: 'openmp'}.get(i, i) # Handles DEVITO_OPENMP deprecation
Expand Down
9 changes: 6 additions & 3 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def __init__(self):
_cpp = False

def __init__(self, **kwargs):
self._name = kwargs.pop('name', self.__class__.__name__)

super().__init__(**kwargs)

self.__lookup_cmds__()
Expand Down Expand Up @@ -223,13 +225,13 @@ def __new_with__(self, **kwargs):
Create a new Compiler from an existing one, inheriting from it
the flags that are not specified via ``kwargs``.
"""
return self.__class__(suffix=kwargs.pop('suffix', self.suffix),
return self.__class__(name=self.name, suffix=kwargs.pop('suffix', self.suffix),
mpi=kwargs.pop('mpi', configuration['mpi']),
**kwargs)

@property
def name(self):
return self.__class__.__name__
return self._name

@property
def version(self):
Expand Down Expand Up @@ -593,7 +595,7 @@ def __init_finalize__(self, **kwargs):
self.cflags.remove('-O3')
self.cflags.remove('-Wall')

self.cflags.append('-std=c++11')
self.cflags.append('-std=c++14')

language = kwargs.pop('language', configuration['language'])
platform = kwargs.pop('platform', configuration['platform'])
Expand Down Expand Up @@ -978,6 +980,7 @@ def __new_with__(self, **kwargs):
'nvc++': NvidiaCompiler,
'nvidia': NvidiaCompiler,
'cuda': CudaCompiler,
'nvcc': CudaCompiler,
'osx': ClangCompiler,
'intel': OneapiCompiler,
'icx': OneapiCompiler,
Expand Down
8 changes: 6 additions & 2 deletions devito/data/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ def initialize(cls):
return

def alloc(self, shape, dtype, padding=0):
datasize = int(reduce(mul, shape))
ctype = dtype_to_ctype(dtype)
# For complex numbers, allocate twice as many elements of the real/imaginary part's type
alloc_dtype = dtype(0).real.__class__
c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1

datasize = int(reduce(mul, shape) * c_scale)
ctype = dtype_to_ctype(alloc_dtype)

# Add padding, if any
try:
Expand Down
5 changes: 2 additions & 3 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from devito.logger import warning
from devito.tools import (as_tuple, filter_ordered, flatten, frozendict,
infer_dtype, is_integer, split)
from devito.types import (Array, DimensionTuple, Evaluable, Indexed,
StencilDimension)
from devito.types import Array, DimensionTuple, Evaluable, StencilDimension

__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
'Weights']
Expand Down Expand Up @@ -68,7 +67,7 @@ def grid(self):

@cached_property
def dtype(self):
dtypes = {f.dtype for f in self.find(Indexed)} - {None}
dtypes = {f.dtype for f in self._functions} - {None}
return infer_dtype(dtypes)

@cached_property
Expand Down
23 changes: 17 additions & 6 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The Iteration/Expression Tree (IET) hierarchy."""

import abc
import ctypes
import inspect
from functools import cached_property
from collections import OrderedDict, namedtuple
Expand Down Expand Up @@ -1030,6 +1031,8 @@ class Dereference(ExprStmt, Node):
* `pointer` is a PointerArray or TempFunction, and `pointee` is an Array.
* `pointer` is an ArrayObject representing a pointer to a C struct, and
`pointee` is a field in `pointer`.
* `pointer` is a Symbol with its _C_ctype deriving from ctypes._Pointer, and
`pointee` is a Symbol representing the dereferenced value.
georgebisbas marked this conversation as resolved.
Show resolved Hide resolved
"""

is_Dereference = True
Expand All @@ -1048,13 +1051,21 @@ def functions(self):

@property
def expr_symbols(self):
ret = [self.pointer.indexed]
if self.pointer.is_PointerArray or self.pointer.is_TempFunction:
ret.append(self.pointee.indexed)
ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:]))
ret.extend(self.pointer.free_symbols)
else:
ret = []
if self.pointer.is_Symbol:
assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \
"Scalar dereference must have a pointer ctype"
ret.append(self.pointer._C_symbol)
enwask marked this conversation as resolved.
Show resolved Hide resolved
ret.append(self.pointee._C_symbol)
else:
ret.append(self.pointer.indexed)
if self.pointer.is_PointerArray or self.pointer.is_TempFunction:
enwask marked this conversation as resolved.
Show resolved Hide resolved
ret.append(self.pointee.indexed)
ret.extend(flatten(i.free_symbols
for i in self.pointee.symbolic_shape[1:]))
ret.extend(self.pointer.free_symbols)
else:
ret.append(self.pointee._C_symbol)
return tuple(filter_ordered(ret))

@property
Expand Down
48 changes: 33 additions & 15 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
from sympy import IndexedBase
from sympy.core.function import Application

from devito.parameters import configuration, switchconfig
from devito.exceptions import VisitorException
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
Call, Lambda, BlankLine, Section, ListMajor)
from devito.ir.support.space import Backward
from devito.symbolics import (FieldFromComposite, FieldFromPointer,
ListInitializer, ccode, uxreplace)
from devito.symbolics.extended_dtypes import NoDeclStruct
from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered,
filter_sorted, flatten, is_external_ctype,
c_restrict_void_p, sorted_priority)
from devito.types.basic import AbstractFunction, Basic
from devito.types.basic import AbstractFunction, AbstractSymbol, Basic
from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer,
IndexedData, DeviceMap)

Expand Down Expand Up @@ -176,7 +178,7 @@ class CGen(Visitor):

def __init__(self, *args, compiler=None, **kwargs):
super().__init__(*args, **kwargs)
self._compiler = compiler
self._compiler = compiler or configuration['compiler']

# The following mappers may be customized by subclasses (that is,
# backend-specific CGen-erators)
Expand All @@ -188,6 +190,16 @@ def __init__(self, *args, compiler=None, **kwargs):
}
_restrict_keyword = 'restrict'

@property
def compiler(self):
return self._compiler

def visit(self, o, *args, **kwargs):
# Make sure the visitor always runs with the generating compiler set in the
# configuration, in case the configuration is accessed during code generation
with switchconfig(compiler=self.compiler.name):
return super().visit(o, *args, **kwargs)

def _gen_struct_decl(self, obj, masked=()):
"""
Convert ctypes.Struct -> cgen.Structure.
Expand All @@ -197,7 +209,7 @@ def _gen_struct_decl(self, obj, masked=()):
while issubclass(ctype, ctypes._Pointer):
ctype = ctype._type_

if not issubclass(ctype, ctypes.Structure):
if not issubclass(ctype, ctypes.Structure) or issubclass(ctype, NoDeclStruct):
enwask marked this conversation as resolved.
Show resolved Hide resolved
return None
except TypeError:
# E.g., `ctype` is of type `dtypes_lowering.CustomDtype`
Expand Down Expand Up @@ -376,10 +388,11 @@ def visit_tuple(self, o):
def visit_PointerCast(self, o):
f = o.function
i = f.indexed
cstr = i._C_typedata

if f.is_PointerArray:
# lvalue
lvalue = c.Value(i._C_typedata, '**%s' % f.name)
lvalue = c.Value(cstr, '**%s' % f.name)

# rvalue
if isinstance(o.obj, ArrayObject):
Expand All @@ -388,7 +401,7 @@ def visit_PointerCast(self, o):
v = f._C_name
else:
assert False
rvalue = '(%s**) %s' % (i._C_typedata, v)
rvalue = '(%s**) %s' % (cstr, v)

else:
# lvalue
Expand All @@ -399,10 +412,10 @@ def visit_PointerCast(self, o):
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
rshape = '(*)%s' % shape
lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape))
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
else:
rshape = '*'
lvalue = c.Value(i._C_typedata, '*%s' % v)
lvalue = c.Value(cstr, '*%s' % v)
if o.alignment and f._data_alignment:
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)

Expand All @@ -415,32 +428,35 @@ def visit_PointerCast(self, o):
else:
assert False

rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v)
rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v)
else:
if isinstance(o.obj, Pointer):
v = o.obj.name
else:
v = f._C_name

rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v)
rvalue = '(%s %s) %s' % (cstr, rshape, v)

return c.Initializer(lvalue, rvalue)

def visit_Dereference(self, o):
a0, a1 = o.functions
if a1.is_PointerArray or a1.is_TempFunction:
i = a1.indexed
cstr = i._C_typedata
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name,
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
a1.dim.name)
lvalue = c.Value(i._C_typedata,
'(*restrict %s)%s' % (a0.name, shape))
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
else:
rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name)
lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name)
rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name)
lvalue = c.Value(cstr, '*restrict %s' % a0.name)
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
elif a1.is_Symbol:
enwask marked this conversation as resolved.
Show resolved Hide resolved
rvalue = '*%s' % a1.name
lvalue = self._gen_value(a0, 0)
else:
rvalue = '%s->%s' % (a1.name, a0._C_name)
lvalue = self._gen_value(a0, 0)
Expand Down Expand Up @@ -590,7 +606,7 @@ def visit_MultiTraversable(self, o):
return c.Collection(body)

def visit_UsingNamespace(self, o):
return c.Statement('using namespace %s' % ccode(o.namespace))
return c.Statement('using namespace %s' % str(o.namespace))

def visit_Lambda(self, o):
body = []
Expand Down Expand Up @@ -945,6 +961,7 @@ def default_retval(cls):
Drive the search. Accepted:
- `symbolics`: Collect all AbstractFunction objects, default
- `basics`: Collect all Basic objects
- `scalars`: Collect all AbstractSymbol objects
enwask marked this conversation as resolved.
Show resolved Hide resolved
- `dimensions`: Collect all Dimensions
- `indexeds`: Collect all Indexed objects
- `indexedbases`: Collect all IndexedBase objects
Expand All @@ -965,6 +982,7 @@ def _defines_aliases(n):
rules = {
'symbolics': lambda n: n.functions,
'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)],
'scalars': lambda n: [i for i in n.expr_symbols if isinstance(i, AbstractSymbol)],
enwask marked this conversation as resolved.
Show resolved Hide resolved
'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)],
'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed],
'indexedbases': lambda n: [i for i in n.expr_symbols
Expand Down
5 changes: 4 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):

# Lower IET to a target-specific IET
graph = Graph(iet, **kwargs)

# Specialize
graph = cls._specialize_iet(graph, **kwargs)

# Instrument the IET for C-level profiling
Expand Down Expand Up @@ -1347,7 +1349,8 @@ def parse_kwargs(**kwargs):
raise InvalidOperator("Illegal `compiler=%s`" % str(compiler))
kwargs['compiler'] = compiler_registry[compiler](platform=kwargs['platform'],
language=kwargs['language'],
mpi=configuration['mpi'])
mpi=configuration['mpi'],
name=compiler)
elif any([platform, language]):
kwargs['compiler'] =\
configuration['compiler'].__new_with__(platform=kwargs['platform'],
Expand Down
1 change: 0 additions & 1 deletion devito/passes/clusters/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def _collect_nested(expr):
Recursion helper for `collect_nested`.
"""
# Return semantic (rebuilt expression, factorization candidates)

if expr.is_Number:
return expr, {'coeffs': expr}
elif expr.is_Function:
Expand Down
1 change: 1 addition & 0 deletions devito/passes/iet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .instrument import * # noqa
from .languages import * # noqa
from .errors import * # noqa
from .dtypes import * # noqa
13 changes: 12 additions & 1 deletion devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
FindSymbols, MapExprStmts, Transformer, make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.dtypes import lower_complex, lower_scalar_half
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
Expand Down Expand Up @@ -73,10 +74,12 @@ class DataManager:
The language used to express data allocations, deletions, and host-device transfers.
"""

def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs):
def __init__(self, rcompile=None, sregistry=None, platform=None,
compiler=None, **kwargs):
self.rcompile = rcompile
self.sregistry = sregistry
self.platform = platform
self.compiler = compiler

def _alloc_object_on_low_lat_mem(self, site, obj, storage):
"""
Expand Down Expand Up @@ -409,12 +412,19 @@ def place_casts(self, iet, **kwargs):

return iet, {}

@iet_pass
def make_langtypes(self, iet):
iet, _ = lower_scalar_half(iet, self.lang, self.sregistry)
enwask marked this conversation as resolved.
Show resolved Hide resolved
iet, metadata = lower_complex(iet, self.lang, self.compiler)
return iet, metadata

def process(self, graph):
"""
Apply the `place_definitions` and `place_casts` passes.
"""
self.place_definitions(graph, globs=set())
self.place_casts(graph)
self.make_langtypes(graph)


class DeviceAwareDataManager(DataManager):
Expand Down Expand Up @@ -564,6 +574,7 @@ def process(self, graph):
self.place_devptr(graph)
self.place_bundling(graph, writes_input=graph.writes_input)
self.place_casts(graph)
self.make_langtypes(graph)


def make_zero_init(obj):
Expand Down
Loading