Skip to content

Commit 2ecb172

Browse files
committed
symbolics: move printers rogether through registry
1 parent 62f2deb commit 2ecb172

File tree

11 files changed

+117
-144
lines changed

11 files changed

+117
-144
lines changed

devito/ir/iet/visitors.py

Lines changed: 28 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,13 @@
1313
from sympy import IndexedBase
1414
from sympy.core.function import Application
1515

16-
from devito.parameters import configuration, switchconfig
1716
from devito.exceptions import CompilationError
1817
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
1918
Call, Lambda, BlankLine, Section, ListMajor)
2019
from devito.ir.support.space import Backward
2120
from devito.symbolics import (FieldFromComposite, FieldFromPointer,
2221
ListInitializer, uxreplace)
23-
from devito.symbolics.printer import _DevitoPrinterBase
22+
from devito.symbolics.printer import ccode
2423
from devito.symbolics.extended_dtypes import NoDeclStruct
2524
from devito.tools import (GenericVisitor, as_tuple, filter_ordered,
2625
filter_sorted, flatten, is_external_ctype,
@@ -177,10 +176,8 @@ class CGen(Visitor):
177176
Return a representation of the Iteration/Expression tree as a :module:`cgen` tree.
178177
"""
179178

180-
def __init__(self, *args, compiler=None, printer=None, **kwargs):
179+
def __init__(self, *args, **kwargs):
181180
super().__init__(*args, **kwargs)
182-
self._compiler = compiler or configuration['compiler']
183-
self._printer = printer or _DevitoPrinterBase
184181

185182
# The following mappers may be customized by subclasses (that is,
186183
# backend-specific CGen-erators)
@@ -192,19 +189,6 @@ def __init__(self, *args, compiler=None, printer=None, **kwargs):
192189
}
193190
_restrict_keyword = 'restrict'
194191

195-
@property
196-
def compiler(self):
197-
return self._compiler
198-
199-
def ccode(self, expr, **settings):
200-
return self._printer(settings=settings).doprint(expr, None)
201-
202-
def visit(self, o, *args, **kwargs):
203-
# Make sure the visitor always is within the generating compiler
204-
# in case the configuration is accessed
205-
with switchconfig(compiler=self.compiler.name):
206-
return super().visit(o, *args, **kwargs)
207-
208192
def _gen_struct_decl(self, obj, masked=()):
209193
"""
210194
Convert ctypes.Struct -> cgen.Structure.
@@ -238,7 +222,7 @@ def _gen_struct_decl(self, obj, masked=()):
238222
try:
239223
entries.append(self._gen_value(i, 0, masked=('const',)))
240224
except AttributeError:
241-
cstr = self.ccode(ct)
225+
cstr = ccode(ct)
242226
if ct is c_restrict_void_p:
243227
cstr = '%srestrict' % cstr
244228
entries.append(c.Value(cstr, n))
@@ -260,10 +244,10 @@ def _gen_value(self, obj, mode=1, masked=()):
260244
if getattr(obj.function, k, False) and v not in masked]
261245

262246
if (obj._mem_stack or obj._mem_constant) and mode == 1:
263-
strtype = self.ccode(obj._C_typedata)
264-
strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape)
247+
strtype = ccode(obj._C_typedata)
248+
strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape)
265249
else:
266-
strtype = self.ccode(obj._C_ctype)
250+
strtype = ccode(obj._C_ctype)
267251
strshape = ''
268252
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
269253
if not obj._mem_stack:
@@ -277,7 +261,7 @@ def _gen_value(self, obj, mode=1, masked=()):
277261
strobj = '%s%s' % (strname, strshape)
278262

279263
if obj.is_LocalObject and obj.cargs and mode == 1:
280-
arguments = [self.ccode(i) for i in obj.cargs]
264+
arguments = [ccode(i) for i in obj.cargs]
281265
strobj = MultilineCall(strobj, arguments, True)
282266

283267
value = c.Value(strtype, strobj)
@@ -291,9 +275,9 @@ def _gen_value(self, obj, mode=1, masked=()):
291275
if obj.is_Array and obj.initvalue is not None and mode == 1:
292276
init = ListInitializer(obj.initvalue)
293277
if not obj._mem_constant or init.is_numeric:
294-
value = c.Initializer(value, self.ccode(init))
278+
value = c.Initializer(value, ccode(init))
295279
elif obj.is_LocalObject and obj.initvalue is not None and mode == 1:
296-
value = c.Initializer(value, self.ccode(obj.initvalue))
280+
value = c.Initializer(value, ccode(obj.initvalue))
297281

298282
return value
299283

@@ -327,7 +311,7 @@ def _args_call(self, args):
327311
else:
328312
ret.append(i._C_name)
329313
except AttributeError:
330-
ret.append(self.ccode(i))
314+
ret.append(ccode(i))
331315
return ret
332316

333317
def _gen_signature(self, o, is_declaration=False):
@@ -393,7 +377,7 @@ def visit_tuple(self, o):
393377
def visit_PointerCast(self, o):
394378
f = o.function
395379
i = f.indexed
396-
cstr = self.ccode(i._C_typedata)
380+
cstr = ccode(i._C_typedata)
397381

398382
if f.is_PointerArray:
399383
# lvalue
@@ -415,7 +399,7 @@ def visit_PointerCast(self, o):
415399
else:
416400
v = f.name
417401
if o.flat is None:
418-
shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape)
402+
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
419403
rshape = '(*)%s' % shape
420404
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
421405
else:
@@ -448,9 +432,9 @@ def visit_Dereference(self, o):
448432
a0, a1 = o.functions
449433
if a1.is_PointerArray or a1.is_TempFunction:
450434
i = a1.indexed
451-
cstr = self.ccode(i._C_typedata)
435+
cstr = ccode(i._C_typedata)
452436
if o.flat is None:
453-
shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:])
437+
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
454438
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
455439
a1.dim.name)
456440
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
@@ -489,8 +473,8 @@ def visit_Definition(self, o):
489473
return self._gen_value(o.function)
490474

491475
def visit_Expression(self, o):
492-
lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler)
493-
rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler)
476+
lhs = ccode(o.expr.lhs, dtype=o.dtype)
477+
rhs = ccode(o.expr.rhs, dtype=o.dtype)
494478

495479
if o.init:
496480
code = c.Initializer(self._gen_value(o.expr.lhs, 0), rhs)
@@ -503,8 +487,8 @@ def visit_Expression(self, o):
503487
return code
504488

505489
def visit_AugmentedExpression(self, o):
506-
c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler)
507-
c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler)
490+
c_lhs = ccode(o.expr.lhs, dtype=o.dtype)
491+
c_rhs = ccode(o.expr.rhs, dtype=o.dtype)
508492
code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs))
509493
if o.pragmas:
510494
code = c.Module(self._visit(o.pragmas) + (code,))
@@ -523,7 +507,7 @@ def visit_Call(self, o, nested_call=False):
523507
o.templates)
524508
if retobj.is_Indexed or \
525509
isinstance(retobj, (FieldFromComposite, FieldFromPointer)):
526-
return c.Assign(self.ccode(retobj), call)
510+
return c.Assign(ccode(retobj), call)
527511
else:
528512
return c.Initializer(c.Value(rettype, retobj._C_name), call)
529513

@@ -537,9 +521,9 @@ def visit_Conditional(self, o):
537521
then_body = c.Block(self._visit(then_body))
538522
if else_body:
539523
else_body = c.Block(self._visit(else_body))
540-
return c.If(self.ccode(o.condition), then_body, else_body)
524+
return c.If(ccode(o.condition), then_body, else_body)
541525
else:
542-
return c.If(self.ccode(o.condition), then_body)
526+
return c.If(ccode(o.condition), then_body)
543527

544528
def visit_Iteration(self, o):
545529
body = flatten(self._visit(i) for i in self._blankline_logic(o.children))
@@ -549,23 +533,23 @@ def visit_Iteration(self, o):
549533

550534
# For backward direction flip loop bounds
551535
if o.direction == Backward:
552-
loop_init = 'int %s = %s' % (o.index, self.ccode(_max))
553-
loop_cond = '%s >= %s' % (o.index, self.ccode(_min))
536+
loop_init = 'int %s = %s' % (o.index, ccode(_max))
537+
loop_cond = '%s >= %s' % (o.index, ccode(_min))
554538
loop_inc = '%s -= %s' % (o.index, o.limits[2])
555539
else:
556-
loop_init = 'int %s = %s' % (o.index, self.ccode(_min))
557-
loop_cond = '%s <= %s' % (o.index, self.ccode(_max))
540+
loop_init = 'int %s = %s' % (o.index, ccode(_min))
541+
loop_cond = '%s <= %s' % (o.index, ccode(_max))
558542
loop_inc = '%s += %s' % (o.index, o.limits[2])
559543

560544
# Append unbounded indices, if any
561545
if o.uindices:
562-
uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices]
546+
uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices]
563547
loop_init = c.Line(', '.join([loop_init] + uinit))
564548

565549
ustep = []
566550
for i in o.uindices:
567551
op = '=' if i.is_Modulo else '+='
568-
ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr)))
552+
ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr)))
569553
loop_inc = c.Line(', '.join([loop_inc] + ustep))
570554

571555
# Create For header+body
@@ -582,7 +566,7 @@ def visit_Pragma(self, o):
582566
return c.Pragma(o._generate)
583567

584568
def visit_While(self, o):
585-
condition = self.ccode(o.condition)
569+
condition = ccode(o.condition)
586570
if o.body:
587571
body = flatten(self._visit(i) for i in o.children)
588572
return c.While(condition, c.Block(body))

devito/operator/operator.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from devito.operator.profiling import create_profile
2424
from devito.operator.registry import operator_selector
2525
from devito.mpi import MPI
26-
from devito.parameters import configuration
26+
from devito.parameters import configuration, switchconfig
2727
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
2828
generate_macros, minimize_symbols, unevaluate,
2929
error_mapper, is_on_device)
@@ -758,19 +758,14 @@ def _soname(self):
758758
"""A unique name for the shared object resulting from JIT compilation."""
759759
return Signer._digest(self, configuration)
760760

761-
@property
762-
def printer(self):
763-
return self._Target.Printer
764-
765761
@cached_property
766762
def ccode(self):
767-
try:
768-
return self._ccode_handler(compiler=self._compiler,
769-
printer=self.printer).visit(self)
770-
except (AttributeError, TypeError):
771-
from devito.ir.iet.visitors import CGen
772-
return CGen(compiler=self._compiler,
773-
printer=self.printer).visit(self)
763+
with switchconfig(compiler=self._compiler, language=self._language):
764+
try:
765+
return self._ccode_handler().visit(self)
766+
except (AttributeError, TypeError):
767+
from devito.ir.iet.visitors import CGen
768+
return CGen().visit(self)
774769

775770
def _jit_compile(self):
776771
"""

devito/passes/iet/definitions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,10 @@ class DataManager:
7575
The language used to express data allocations, deletions, and host-device transfers.
7676
"""
7777

78-
def __init__(self, rcompile=None, sregistry=None, platform=None,
79-
compiler=None, **kwargs):
78+
def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs):
8079
self.rcompile = rcompile
8180
self.sregistry = sregistry
8281
self.platform = platform
83-
self.compiler = compiler
8482

8583
def _alloc_object_on_low_lat_mem(self, site, obj, storage):
8684
"""

devito/passes/iet/languages/C.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import numpy as np
2-
31
from devito.ir import Call
42
from devito.passes.iet.definitions import DataManager
53
from devito.passes.iet.orchestration import Orchestrator
64
from devito.passes.iet.langbase import LangBB
7-
from devito.symbolics.extended_dtypes import c_complex, c_double_complex
8-
from devito.symbolics.printer import _DevitoPrinterBase
95

106
__all__ = ['CBB', 'CDataManager', 'COrchestrator']
117

@@ -35,18 +31,3 @@ class CDataManager(DataManager):
3531

3632
class COrchestrator(Orchestrator):
3733
lang = CBB
38-
39-
40-
class CDevitoPrinter(_DevitoPrinterBase):
41-
42-
# These cannot go through _print_xxx because they are classes not
43-
# instances
44-
type_mappings = {**_DevitoPrinterBase.type_mappings,
45-
c_complex: 'float _Complex',
46-
c_double_complex: 'double _Complex'}
47-
48-
_func_prefix = {**_DevitoPrinterBase._func_prefix, np.complex64: 'c',
49-
np.complex128: 'c'}
50-
51-
def _print_ImaginaryUnit(self, expr):
52-
return '_Complex_I'

devito/passes/iet/languages/CXX.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
from sympy.printing.cxx import CXX11CodePrinter
2-
31
from devito.ir import Call, UsingNamespace
42
from devito.passes.iet.langbase import LangBB
5-
from devito.symbolics.printer import _DevitoPrinterBase
6-
from devito.symbolics.extended_dtypes import c_complex, c_double_complex
73

84
__all__ = ['CXXBB']
95

@@ -64,21 +60,3 @@ class CXXBB(LangBB):
6460
'complex-namespace': [UsingNamespace('std::complex_literals')],
6561
'def-complex': std_arith,
6662
}
67-
68-
69-
class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter):
70-
71-
_default_settings = {**_DevitoPrinterBase._default_settings,
72-
**CXX11CodePrinter._default_settings}
73-
_ns = "std::"
74-
_func_litterals = {}
75-
76-
# These cannot go through _print_xxx because they are classes not
77-
# instances
78-
type_mappings = {**_DevitoPrinterBase.type_mappings,
79-
c_complex: 'std::complex<float>',
80-
c_double_complex: 'std::complex<double>',
81-
**CXX11CodePrinter.type_mappings}
82-
83-
def _print_ImaginaryUnit(self, expr):
84-
return f'1i{self.prec_literal(expr).lower()}'

devito/passes/iet/languages/openacc.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from devito.passes.iet.orchestration import Orchestrator
1010
from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB,
1111
PragmaIteration, PragmaTransfer)
12-
from devito.passes.iet.languages.CXX import CXXBB, CXXDevitoPrinter
12+
from devito.passes.iet.languages.CXX import CXXBB
1313
from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration
1414
from devito.symbolics import FieldFromPointer, Macro, cast_mapper
1515
from devito.tools import filter_ordered, UnboundTuple
@@ -263,8 +263,3 @@ def place_devptr(self, iet, **kwargs):
263263

264264
class AccOrchestrator(Orchestrator):
265265
lang = AccBB
266-
267-
268-
class AccDevitoPrinter(CXXDevitoPrinter):
269-
270-
pass

devito/passes/iet/languages/targets.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from devito.passes.iet.languages.C import CDataManager, COrchestrator, CDevitoPrinter
1+
from devito.passes.iet.languages.C import CDataManager, COrchestrator
22
from devito.passes.iet.languages.openmp import (SimdOmpizer, Ompizer, DeviceOmpizer,
33
OmpDataManager, DeviceOmpDataManager,
44
OmpOrchestrator, DeviceOmpOrchestrator)
55
from devito.passes.iet.languages.openacc import (DeviceAccizer, DeviceAccDataManager,
6-
AccOrchestrator, AccDevitoPrinter)
6+
AccOrchestrator)
77
from devito.passes.iet.instrument import instrument
88

99
__all__ = ['CTarget', 'OmpTarget', 'DeviceOmpTarget', 'DeviceAccTarget']
@@ -13,7 +13,6 @@ class Target:
1313
Parizer = None
1414
DataManager = None
1515
Orchestrator = None
16-
Printer = None
1716

1817
@classmethod
1918
def lang(cls):
@@ -28,25 +27,21 @@ class CTarget(Target):
2827
Parizer = SimdOmpizer
2928
DataManager = CDataManager
3029
Orchestrator = COrchestrator
31-
Printer = CDevitoPrinter
3230

3331

3432
class OmpTarget(Target):
3533
Parizer = Ompizer
3634
DataManager = OmpDataManager
3735
Orchestrator = OmpOrchestrator
38-
Printer = CDevitoPrinter
3936

4037

4138
class DeviceOmpTarget(Target):
4239
Parizer = DeviceOmpizer
4340
DataManager = DeviceOmpDataManager
4441
Orchestrator = DeviceOmpOrchestrator
45-
Printer = CDevitoPrinter
4642

4743

4844
class DeviceAccTarget(Target):
4945
Parizer = DeviceAccizer
5046
DataManager = DeviceAccDataManager
5147
Orchestrator = AccOrchestrator
52-
Printer = AccDevitoPrinter

0 commit comments

Comments
 (0)