1
1
import numpy as np
2
2
import ctypes
3
3
4
- from devito .ir import FindSymbols , Uxreplace
5
- from devito .ir .iet .nodes import Dereference
6
- from devito .tools .utils import as_list
7
- from devito .types .basic import Symbol
4
+ from devito .ir import FindSymbols
8
5
9
- __all__ = ['lower_dtypes ' ]
6
+ __all__ = ['include_complex ' ]
10
7
11
8
12
- def lower_dtypes (iet , lang , compiler , sregistry ):
9
+ def include_complex (iet , lang , compiler ):
13
10
"""
14
- Lower language-specific dtypes and add headers for complex arithmetic
15
- """
16
- # Include complex headers if needed (before we replace complex dtypes)
17
- metadata = _complex_includes (iet , lang , compiler )
18
-
19
- body_prefix = [] # Derefs to prepend to the body
20
- body_mapper = {}
21
- params_mapper = {}
22
-
23
- # Lower scalar float16s to pointers and dereference them
24
- if lang .get ('half_types' ) is not None :
25
- half , half_p = lang ['half_types' ] # dtype mappings for half float
26
-
27
- for s in FindSymbols ('scalars' ).visit (iet ):
28
- if s .dtype != np .float16 or s not in iet .parameters :
29
- continue
30
-
31
- ptr = s ._rebuild (dtype = half_p , is_const = True )
32
- val = Symbol (name = sregistry .make_name (prefix = 'hf' ), dtype = half ,
33
- is_const = s .is_const )
34
-
35
- params_mapper [s ], body_mapper [s ] = ptr , val
36
- body_prefix .append (Dereference (val , ptr )) # val = *ptr
37
-
38
- # Lower remaining language-specific dtypes
39
- for s in FindSymbols ('indexeds|basics|symbolics' ).visit (iet ):
40
- if s .dtype in lang ['types' ] and s not in params_mapper :
41
- body_mapper [s ] = params_mapper [s ] = s ._rebuild (dtype = lang ['types' ][s .dtype ])
42
-
43
- # Apply the dtype replacements
44
- body = body_prefix + as_list (Uxreplace (body_mapper ).visit (iet .body ))
45
- params = Uxreplace (params_mapper ).visit (iet .parameters )
46
-
47
- iet = iet ._rebuild (body = body , parameters = params )
48
- return iet , metadata
49
-
50
-
51
- def _complex_includes (iet , lang , compiler ):
52
- """
53
- Add headers for complex arithmetic
11
+ Include complex arithmetic headers for the given language, if needed.
54
12
"""
55
13
# Check if there is complex numbers that always take dtype precedence
56
14
types = {f .dtype for f in FindSymbols ().visit (iet )
57
15
if not issubclass (f .dtype , ctypes ._Pointer )}
58
16
59
17
if not any (np .issubdtype (d , np .complexfloating ) for d in types ):
60
- return {}
18
+ return iet , {}
61
19
62
20
metadata = {}
63
21
lib = (lang ['header-complex' ],)
@@ -75,4 +33,4 @@ def _complex_includes(iet, lang, compiler):
75
33
76
34
metadata ['includes' ] = lib
77
35
78
- return metadata
36
+ return iet , metadata
0 commit comments