13
13
from sympy import IndexedBase
14
14
from sympy .core .function import Application
15
15
16
- from devito .parameters import configuration , switchconfig
17
16
from devito .exceptions import CompilationError
18
17
from devito .ir .iet .nodes import (Node , Iteration , Expression , ExpressionBundle ,
19
18
Call , Lambda , BlankLine , Section , ListMajor )
20
19
from devito .ir .support .space import Backward
21
20
from devito .symbolics import (FieldFromComposite , FieldFromPointer ,
22
21
ListInitializer , uxreplace )
23
- from devito .symbolics .printer import _DevitoPrinterBase
22
+ from devito .symbolics .printer import ccode
24
23
from devito .symbolics .extended_dtypes import NoDeclStruct
25
24
from devito .tools import (GenericVisitor , as_tuple , filter_ordered ,
26
25
filter_sorted , flatten , is_external_ctype ,
@@ -177,10 +176,8 @@ class CGen(Visitor):
177
176
Return a representation of the Iteration/Expression tree as a :module:`cgen` tree.
178
177
"""
179
178
180
- def __init__ (self , * args , compiler = None , printer = None , ** kwargs ):
179
+ def __init__ (self , * args , ** kwargs ):
181
180
super ().__init__ (* args , ** kwargs )
182
- self ._compiler = compiler or configuration ['compiler' ]
183
- self ._printer = printer or _DevitoPrinterBase
184
181
185
182
# The following mappers may be customized by subclasses (that is,
186
183
# backend-specific CGen-erators)
@@ -192,19 +189,6 @@ def __init__(self, *args, compiler=None, printer=None, **kwargs):
192
189
}
193
190
_restrict_keyword = 'restrict'
194
191
195
- @property
196
- def compiler (self ):
197
- return self ._compiler
198
-
199
- def ccode (self , expr , ** settings ):
200
- return self ._printer (settings = settings ).doprint (expr , None )
201
-
202
- def visit (self , o , * args , ** kwargs ):
203
- # Make sure the visitor always is within the generating compiler
204
- # in case the configuration is accessed
205
- with switchconfig (compiler = self .compiler .name ):
206
- return super ().visit (o , * args , ** kwargs )
207
-
208
192
def _gen_struct_decl (self , obj , masked = ()):
209
193
"""
210
194
Convert ctypes.Struct -> cgen.Structure.
@@ -238,7 +222,7 @@ def _gen_struct_decl(self, obj, masked=()):
238
222
try :
239
223
entries .append (self ._gen_value (i , 0 , masked = ('const' ,)))
240
224
except AttributeError :
241
- cstr = self . ccode (ct )
225
+ cstr = ccode (ct )
242
226
if ct is c_restrict_void_p :
243
227
cstr = '%srestrict' % cstr
244
228
entries .append (c .Value (cstr , n ))
@@ -260,10 +244,10 @@ def _gen_value(self, obj, mode=1, masked=()):
260
244
if getattr (obj .function , k , False ) and v not in masked ]
261
245
262
246
if (obj ._mem_stack or obj ._mem_constant ) and mode == 1 :
263
- strtype = self . ccode (obj ._C_typedata )
264
- strshape = '' .join ('[%s]' % self . ccode (i ) for i in obj .symbolic_shape )
247
+ strtype = ccode (obj ._C_typedata )
248
+ strshape = '' .join ('[%s]' % ccode (i ) for i in obj .symbolic_shape )
265
249
else :
266
- strtype = self . ccode (obj ._C_ctype )
250
+ strtype = ccode (obj ._C_ctype )
267
251
strshape = ''
268
252
if isinstance (obj , (AbstractFunction , IndexedData )) and mode >= 1 :
269
253
if not obj ._mem_stack :
@@ -277,7 +261,7 @@ def _gen_value(self, obj, mode=1, masked=()):
277
261
strobj = '%s%s' % (strname , strshape )
278
262
279
263
if obj .is_LocalObject and obj .cargs and mode == 1 :
280
- arguments = [self . ccode (i ) for i in obj .cargs ]
264
+ arguments = [ccode (i ) for i in obj .cargs ]
281
265
strobj = MultilineCall (strobj , arguments , True )
282
266
283
267
value = c .Value (strtype , strobj )
@@ -291,9 +275,9 @@ def _gen_value(self, obj, mode=1, masked=()):
291
275
if obj .is_Array and obj .initvalue is not None and mode == 1 :
292
276
init = ListInitializer (obj .initvalue )
293
277
if not obj ._mem_constant or init .is_numeric :
294
- value = c .Initializer (value , self . ccode (init ))
278
+ value = c .Initializer (value , ccode (init ))
295
279
elif obj .is_LocalObject and obj .initvalue is not None and mode == 1 :
296
- value = c .Initializer (value , self . ccode (obj .initvalue ))
280
+ value = c .Initializer (value , ccode (obj .initvalue ))
297
281
298
282
return value
299
283
@@ -327,7 +311,7 @@ def _args_call(self, args):
327
311
else :
328
312
ret .append (i ._C_name )
329
313
except AttributeError :
330
- ret .append (self . ccode (i ))
314
+ ret .append (ccode (i ))
331
315
return ret
332
316
333
317
def _gen_signature (self , o , is_declaration = False ):
@@ -393,7 +377,7 @@ def visit_tuple(self, o):
393
377
def visit_PointerCast (self , o ):
394
378
f = o .function
395
379
i = f .indexed
396
- cstr = self . ccode (i ._C_typedata )
380
+ cstr = ccode (i ._C_typedata )
397
381
398
382
if f .is_PointerArray :
399
383
# lvalue
@@ -415,7 +399,7 @@ def visit_PointerCast(self, o):
415
399
else :
416
400
v = f .name
417
401
if o .flat is None :
418
- shape = '' .join ("[%s]" % self . ccode (i ) for i in o .castshape )
402
+ shape = '' .join ("[%s]" % ccode (i ) for i in o .castshape )
419
403
rshape = '(*)%s' % shape
420
404
lvalue = c .Value (cstr , '(*restrict %s)%s' % (v , shape ))
421
405
else :
@@ -448,9 +432,9 @@ def visit_Dereference(self, o):
448
432
a0 , a1 = o .functions
449
433
if a1 .is_PointerArray or a1 .is_TempFunction :
450
434
i = a1 .indexed
451
- cstr = self . ccode (i ._C_typedata )
435
+ cstr = ccode (i ._C_typedata )
452
436
if o .flat is None :
453
- shape = '' .join ("[%s]" % self . ccode (i ) for i in a0 .symbolic_shape [1 :])
437
+ shape = '' .join ("[%s]" % ccode (i ) for i in a0 .symbolic_shape [1 :])
454
438
rvalue = '(%s (*)%s) %s[%s]' % (cstr , shape , a1 .name ,
455
439
a1 .dim .name )
456
440
lvalue = c .Value (cstr , '(*restrict %s)%s' % (a0 .name , shape ))
@@ -489,8 +473,8 @@ def visit_Definition(self, o):
489
473
return self ._gen_value (o .function )
490
474
491
475
def visit_Expression (self , o ):
492
- lhs = self . ccode (o .expr .lhs , dtype = o .dtype , compiler = self . _compiler )
493
- rhs = self . ccode (o .expr .rhs , dtype = o .dtype , compiler = self . _compiler )
476
+ lhs = ccode (o .expr .lhs , dtype = o .dtype )
477
+ rhs = ccode (o .expr .rhs , dtype = o .dtype )
494
478
495
479
if o .init :
496
480
code = c .Initializer (self ._gen_value (o .expr .lhs , 0 ), rhs )
@@ -503,8 +487,8 @@ def visit_Expression(self, o):
503
487
return code
504
488
505
489
def visit_AugmentedExpression (self , o ):
506
- c_lhs = self . ccode (o .expr .lhs , dtype = o .dtype , compiler = self . _compiler )
507
- c_rhs = self . ccode (o .expr .rhs , dtype = o .dtype , compiler = self . _compiler )
490
+ c_lhs = ccode (o .expr .lhs , dtype = o .dtype )
491
+ c_rhs = ccode (o .expr .rhs , dtype = o .dtype )
508
492
code = c .Statement ("%s %s= %s" % (c_lhs , o .op , c_rhs ))
509
493
if o .pragmas :
510
494
code = c .Module (self ._visit (o .pragmas ) + (code ,))
@@ -523,7 +507,7 @@ def visit_Call(self, o, nested_call=False):
523
507
o .templates )
524
508
if retobj .is_Indexed or \
525
509
isinstance (retobj , (FieldFromComposite , FieldFromPointer )):
526
- return c .Assign (self . ccode (retobj ), call )
510
+ return c .Assign (ccode (retobj ), call )
527
511
else :
528
512
return c .Initializer (c .Value (rettype , retobj ._C_name ), call )
529
513
@@ -537,9 +521,9 @@ def visit_Conditional(self, o):
537
521
then_body = c .Block (self ._visit (then_body ))
538
522
if else_body :
539
523
else_body = c .Block (self ._visit (else_body ))
540
- return c .If (self . ccode (o .condition ), then_body , else_body )
524
+ return c .If (ccode (o .condition ), then_body , else_body )
541
525
else :
542
- return c .If (self . ccode (o .condition ), then_body )
526
+ return c .If (ccode (o .condition ), then_body )
543
527
544
528
def visit_Iteration (self , o ):
545
529
body = flatten (self ._visit (i ) for i in self ._blankline_logic (o .children ))
@@ -549,23 +533,23 @@ def visit_Iteration(self, o):
549
533
550
534
# For backward direction flip loop bounds
551
535
if o .direction == Backward :
552
- loop_init = 'int %s = %s' % (o .index , self . ccode (_max ))
553
- loop_cond = '%s >= %s' % (o .index , self . ccode (_min ))
536
+ loop_init = 'int %s = %s' % (o .index , ccode (_max ))
537
+ loop_cond = '%s >= %s' % (o .index , ccode (_min ))
554
538
loop_inc = '%s -= %s' % (o .index , o .limits [2 ])
555
539
else :
556
- loop_init = 'int %s = %s' % (o .index , self . ccode (_min ))
557
- loop_cond = '%s <= %s' % (o .index , self . ccode (_max ))
540
+ loop_init = 'int %s = %s' % (o .index , ccode (_min ))
541
+ loop_cond = '%s <= %s' % (o .index , ccode (_max ))
558
542
loop_inc = '%s += %s' % (o .index , o .limits [2 ])
559
543
560
544
# Append unbounded indices, if any
561
545
if o .uindices :
562
- uinit = ['%s = %s' % (i .name , self . ccode (i .symbolic_min )) for i in o .uindices ]
546
+ uinit = ['%s = %s' % (i .name , ccode (i .symbolic_min )) for i in o .uindices ]
563
547
loop_init = c .Line (', ' .join ([loop_init ] + uinit ))
564
548
565
549
ustep = []
566
550
for i in o .uindices :
567
551
op = '=' if i .is_Modulo else '+='
568
- ustep .append ('%s %s %s' % (i .name , op , self . ccode (i .symbolic_incr )))
552
+ ustep .append ('%s %s %s' % (i .name , op , ccode (i .symbolic_incr )))
569
553
loop_inc = c .Line (', ' .join ([loop_inc ] + ustep ))
570
554
571
555
# Create For header+body
@@ -582,7 +566,7 @@ def visit_Pragma(self, o):
582
566
return c .Pragma (o ._generate )
583
567
584
568
def visit_While (self , o ):
585
- condition = self . ccode (o .condition )
569
+ condition = ccode (o .condition )
586
570
if o .body :
587
571
body = flatten (self ._visit (i ) for i in o .children )
588
572
return c .While (condition , c .Block (body ))
0 commit comments