Skip to content

Commit 493c1e8

Browse files
committed
Remove dtypes lowering from IET layer
1 parent d3169d0 commit 493c1e8

File tree

2 files changed

+11
-53
lines changed

2 files changed

+11
-53
lines changed

devito/passes/iet/definitions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
1313
FindSymbols, MapExprStmts, Transformer, make_callable)
1414
from devito.passes import is_gpu_create
15-
from devito.passes.iet.dtypes import lower_dtypes
15+
from devito.passes.iet.dtypes import include_complex
1616
from devito.passes.iet.engine import iet_pass
1717
from devito.passes.iet.langbase import LangBB
1818
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
@@ -413,8 +413,8 @@ def place_casts(self, iet, **kwargs):
413413
return iet, {}
414414

415415
@iet_pass
416-
def make_langtypes(self, iet):
417-
iet, metadata = lower_dtypes(iet, self.lang, self.compiler, self.sregistry)
416+
def include_complex(self, iet):
417+
iet, metadata = include_complex(iet, self.lang, self.compiler)
418418
return iet, metadata
419419

420420
def process(self, graph):
@@ -423,7 +423,7 @@ def process(self, graph):
423423
"""
424424
self.place_definitions(graph, globs=set())
425425
self.place_casts(graph)
426-
self.make_langtypes(graph)
426+
self.include_complex(graph)
427427

428428

429429
class DeviceAwareDataManager(DataManager):
@@ -573,7 +573,7 @@ def process(self, graph):
573573
self.place_devptr(graph)
574574
self.place_bundling(graph, writes_input=graph.writes_input)
575575
self.place_casts(graph)
576-
self.make_langtypes(graph)
576+
self.include_complex(graph)
577577

578578

579579
def make_zero_init(obj):

devito/passes/iet/dtypes.py

Lines changed: 6 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,21 @@
11
import numpy as np
22
import ctypes
33

4-
from devito.ir import FindSymbols, Uxreplace
5-
from devito.ir.iet.nodes import Dereference
6-
from devito.tools.utils import as_list
7-
from devito.types.basic import Symbol
4+
from devito.ir import FindSymbols
85

9-
__all__ = ['lower_dtypes']
6+
__all__ = ['include_complex']
107

118

12-
def lower_dtypes(iet, lang, compiler, sregistry):
9+
def include_complex(iet, lang, compiler):
1310
"""
14-
Lower language-specific dtypes and add headers for complex arithmetic
15-
"""
16-
# Include complex headers if needed (before we replace complex dtypes)
17-
metadata = _complex_includes(iet, lang, compiler)
18-
19-
body_prefix = [] # Derefs to prepend to the body
20-
body_mapper = {}
21-
params_mapper = {}
22-
23-
# Lower scalar float16s to pointers and dereference them
24-
if lang.get('half_types') is not None:
25-
half, half_p = lang['half_types'] # dtype mappings for half float
26-
27-
for s in FindSymbols('scalars').visit(iet):
28-
if s.dtype != np.float16 or s not in iet.parameters:
29-
continue
30-
31-
ptr = s._rebuild(dtype=half_p, is_const=True)
32-
val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=half,
33-
is_const=s.is_const)
34-
35-
params_mapper[s], body_mapper[s] = ptr, val
36-
body_prefix.append(Dereference(val, ptr)) # val = *ptr
37-
38-
# Lower remaining language-specific dtypes
39-
for s in FindSymbols('indexeds|basics|symbolics').visit(iet):
40-
if s.dtype in lang['types'] and s not in params_mapper:
41-
body_mapper[s] = params_mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])
42-
43-
# Apply the dtype replacements
44-
body = body_prefix + as_list(Uxreplace(body_mapper).visit(iet.body))
45-
params = Uxreplace(params_mapper).visit(iet.parameters)
46-
47-
iet = iet._rebuild(body=body, parameters=params)
48-
return iet, metadata
49-
50-
51-
def _complex_includes(iet, lang, compiler):
52-
"""
53-
Add headers for complex arithmetic
11+
Include complex arithmetic headers for the given language, if needed.
5412
"""
5513
# Check if there is complex numbers that always take dtype precedence
5614
types = {f.dtype for f in FindSymbols().visit(iet)
5715
if not issubclass(f.dtype, ctypes._Pointer)}
5816

5917
if not any(np.issubdtype(d, np.complexfloating) for d in types):
60-
return {}
18+
return iet, {}
6119

6220
metadata = {}
6321
lib = (lang['header-complex'],)
@@ -75,4 +33,4 @@ def _complex_includes(iet, lang, compiler):
7533

7634
metadata['includes'] = lib
7735

78-
return metadata
36+
return iet, metadata

0 commit comments

Comments
 (0)