diff --git a/src/pycropml/cyml.py b/src/pycropml/cyml.py index 478a5c4..fb22348 100644 --- a/src/pycropml/cyml.py +++ b/src/pycropml/cyml.py @@ -164,6 +164,9 @@ def transpile_package(package, language): name = os.path.split(file)[1].split(".")[0] for model in models: # in the case we haven't the same order if name.lower() == model.name.lower() and prefix(model) != "function": + #getattr(getattr(pycropml.transpiler.generators, f'{NAMES[language]}Generator'), + # f'header_mu_cpp2')(model, tg_rep, mc_name) + test = Main(file, language, model, T.model.name) test.parse() test.to_ast(source) diff --git a/src/pycropml/transpiler/generators/cpp2Generator.py b/src/pycropml/transpiler/generators/cpp2Generator.py index 66c78fb..2481823 100644 --- a/src/pycropml/transpiler/generators/cpp2Generator.py +++ b/src/pycropml/transpiler/generators/cpp2Generator.py @@ -56,6 +56,11 @@ def visit_comparison(self, node): # self.write(')') def visit_binary_op(self, node): + # ignore checks against none (nullptr) because we use value types for lists and arrays + if node.right.type == "none" and node.left.pseudo_type[0] in ["list", "array"]: + self.write("true") + return + op = node.op prec = self.binop_precedence.get(op, 0) self.operator_enter(prec) @@ -86,13 +91,14 @@ def visit_unary_op(self, node): def visit_breakstatnode(self, node): self.newline(node) - self.write('break;') + self.write("break;") def visit_import(self, node): pass def visit_none(self, node): - pass + self.write("null") + #pass def visit_cond_expr_node(self, node): self.visit(node.test) @@ -103,20 +109,19 @@ def visit_cond_expr_node(self, node): def visit_if_statement(self, node): self.newline(node) - self.write('if (') + self.write("if (") self.visit(node.test) - self.write(')') - self.newline(node) - self.write('{') + self.write(") {") self.newline(node) self.body(node.block) self.newline(node) - self.write('}') + self.write("}") + self.newline(node) while True: else_ = node.otherwise if len(else_) == 0: break - elif len(else_) == 1 and else_[0].type == 'elseif_statement': + elif len(else_) == 1 and else_[0].type == "elseif_statement": self.visit(else_[0]) else: self.visit(else_) @@ -125,23 +130,19 @@ def visit_if_statement(self, node): def visit_elseif_statement(self, node): self.newline() - self.write('else if ( ') + self.write("else if (") self.visit(node.test) - self.write(')') - self.newline(node) - self.write('{') + self.write(") {") self.body(node.block) self.newline(node) - self.write('}') + self.write("}") def visit_else_statement(self, node): self.newline() - self.write('else') - self.newline(node) - self.write('{') + self.write("else {") self.body(node.block) self.newline(node) - self.write('}') + self.write("}") def visit_print(self, node): pass @@ -177,7 +178,10 @@ def visit_dict(self, node): self.write(u'}') def visit_bool(self, node): - self.write("true") if node.value == True else self.write("false") + if isinstance(node.value, str): + self.write(node.value) + elif isinstance(node.value, bool): + self.write("true") if node.value == True else self.write("false") def visit_standard_method_call(self, node): l = node.receiver.pseudo_type @@ -247,6 +251,8 @@ def visit_sliceindex(self, node): self.write(u"]") def visit_assignment(self, node): + dv = dir(node.value) + dt = dir(node.target) if node.value.type == "binary_op" and node.value.left.type == "list": self.visit(node.target) self.write(".assign(") @@ -255,7 +261,6 @@ def visit_assignment(self, node): self.visit(node.value.left.elements[0]) self.write(");") - elif "function" in dir(node.value) and node.value.function.split('_')[0] == "model": name = node.value.function.split('model_')[1] for m in self.model.model: @@ -323,6 +328,457 @@ def visit_assignment(self, node): elif node.value.type == "none": pass + # all combinations of copies between slices + # assumes right now that the slices have the same size + # so no contraction or expansion of the target slice/array + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex" + and node.value.type == "sliceindex" and node.value.message == "sliceindex"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=") + self.visit(node.value.args[0]) + self.write("; _i < (") + self.visit(node.target.args[1]) + self.write(") && _k < (") + self.visit(node.value.args[1]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_from"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=") + self.visit(node.value.args[0]) + self.write("; _i < (") + self.visit(node.target.args[1]) + self.write(") && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_to"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < (") + self.visit(node.target.args[1]) + self.write(") && _k < (") + self.visit(node.value.args[0]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex" + and node.value.type == "sliceindex" and node.value.message == "slice_"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < (") + self.visit(node.target.args[1]) + self.write(") && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex" and + "pseudo_type" in dv and node.value.pseudo_type[0] == "array"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < (") + self.visit(node.target.args[1]) + self.write(") && _k < ") + self.visit(node.value) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_from" and + node.value.type == "sliceindex" and node.value.message == "sliceindex"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < (") + self.visit(node.value.args[1]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_from" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_from"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_from" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_to"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < (") + self.visit(node.target.receiver) + self.write(".size() && _k < (") + self.visit(node.value.args[0]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_from" and + node.value.type == "sliceindex" and node.value.message == "slice_"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_from" and + "pseudo_type" in dv and node.value.pseudo_type[0] == "array"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_to" and + node.value.type == "sliceindex" and node.value.message == "sliceindex"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < (") + self.visit(node.target.args[0]) + self.write(") && _k < (") + self.visit(node.value.args[1]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_to" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_from"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < (") + self.visit(node.target.args[0]) + self.write(") && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_to" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_to"): + self.write("for (int _i=0, _k=0; _i < (") + self.visit(node.target.args[0]) + self.write(") && _k < (") + self.visit(node.value.args[0]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_to" and + node.value.type == "sliceindex" and node.value.message == "slice_"): + self.write("for (int _i=0, _k=0; _i < (") + self.visit(node.target.args[0]) + self.write(") && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_to" and + "pseudo_type" in dv and node.value.pseudo_type[0] == "array"): + self.write("for (int _i=0, _k=0; _i < (") + self.visit(node.target.args[0]) + self.write(") && _k < ") + self.visit(node.value) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "slice_" + and node.value.type == "sliceindex" and node.value.message == "sliceindex"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < (") + self.visit(node.value.args[1]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "slice_" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_from"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "slice_" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_to"): + self.write("for (int _i=0, _k=0; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < (") + self.visit(node.value.args[0]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "slice_" + and node.value.type == "sliceindex" and node.value.message == "slice_"): + self.write("for (int _i=0, _k=0; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "slice_" and + "pseudo_type" in dv and node.value.pseudo_type[0] == "array"): + self.write("for (int _i=0, _k=0; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif ("pseudo_type" in dt and node.target.pseudo_type[0] == "array" and + node.value.type == "sliceindex" and node.value.message == "sliceindex"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target) + self.write(".size() && _k < (") + self.visit(node.value.args[1]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif ("pseudo_type" in dt and node.target.pseudo_type[0] == "array" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_from"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif ("pseudo_type" in dt and node.target.pseudo_type[0] == "array" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_to"): + self.write("for (int _i=0, _k=0; _i < ") + self.visit(node.target) + self.write(".size() && _k < (") + self.visit(node.value.args[0]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif ("pseudo_type" in dt and node.target.pseudo_type[0] == "array" and + node.value.type == "sliceindex" and node.value.message == "slice_"): + self.write("for (int _i=0, _k=0; _i < ") + self.visit(node.target) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + else: self.newline(node) self.visit(node.target) @@ -465,6 +921,7 @@ def visit_module(self, node): '#include \n' '#include \n' '#include \n' + '#include \n' '#include \n') if self.model: self.write(f'#include "{self.model.name}.h"\n') @@ -697,8 +1154,15 @@ def visit_declaration(self, node): self.write(f"std::vector<{self.types[n.pseudo_type[1]]}> {n.name};") else: self.write(f"std::vector<{self.types[n.pseudo_type[1]]}> {n.name}") - self.write(f"({n.elts[0].name if 'name' in dir(n.elts[0]) else n.elts[0].value});") - elif 'value' in dn and n.type in ("int", "float", "str", "bool"): + if "name" in dir(n.elts[0]): + self.write(f"({n.elts[0].name});") + elif "value" in dir(n.elts[0]): + self.write(f"({n.elts[0].value});") + else: + self.write(f"(") + self.visit(n.elts[0]) + self.write(");") + elif "value" in dn and n.type in ("int", "float", "str", "bool"): if "feat" in dn and n.feat in ("OUT", "INOUT") and is_init_or_model_func: self.write(f"{self.struct_name_for(n.name)}.{n.name} = ") else: @@ -708,7 +1172,7 @@ def visit_declaration(self, node): else: self.visit(n) self.write(";") - elif n.type == 'datetime': + elif n.type == "datetime": if "feat" in dn and n.feat in ("OUT", "INOUT") and is_init_or_model_func: self.write(f"{self.struct_name_for(n.name)}.{n.name}") else: @@ -717,40 +1181,40 @@ def visit_declaration(self, node): self.write(" = ") self.visit(n.elts) self.write(";") - elif 'elements' in dn and n.type in ("list", "tuple"): + elif "elements" in dn and n.type in ("list", "tuple"): if "feat" in dn and n.feat in ("OUT", "INOUT") and is_init_or_model_func: if n.type == "list": self.write(f"{self.struct_name_for(n.name)}.{n.name} = ") - self.write(u'{') + self.write("{") self.comma_separated_list(n.elements) - self.write(u'};') + self.write("};") else: if n.type == "list": self.visit_decl(n.pseudo_type) self.write(n.name) self.write(" = ") - self.write(u'{') + self.write("{") self.comma_separated_list(n.elements) - self.write(u'};') - if n.type == 'tuple': + self.write("};") + if n.type == "tuple": pass - elif 'pairs' in dn and n.type == "dict": + elif "pairs" in dn and n.type == "dict": self.visit_decl(n.pseudo_type) self.write(n.name) - self.write(u' = {') + self.write(" = {") self.comma_separated_list(n.pairs) - self.write(u'};') + self.write("};") self.newline(node) def visit_list_decl(self, node, pa=None): if not isinstance(node[1], list): self.write(self.types[node[1]]) - self.write('>') + self.write(">") else: node = node[1] self.visit_decl(node, pa) - self.write('>') + self.write(">") if pa and "name" in dir(pa): self.write(f" {pa.name}") @@ -759,11 +1223,11 @@ def visit_dict_decl(self, node): self.write(",") if not isinstance(node[2], list): self.write(self.types[node[2]]) - self.write('>') + self.write(">") else: node = node[2] self.visit_decl(node) - self.write('>') + self.write(">") def visit_tuple_decl(self, node): self.visit_decl(node[0]) @@ -771,7 +1235,7 @@ def visit_tuple_decl(self, node): self.visit_decl(n) self.write(",") self.visit_decl(node[-1]) - self.write('>') + self.write(">") def visit_float_decl(self, node, pa=None): self.write(self.types[node]) @@ -830,23 +1294,23 @@ def visit_array_decl(self, node, pa=None): else:""" node = node[1] self.visit_decl(node) - self.write('> ') + self.write("> ") if pa: self.write(pa.name) def visit_decl(self, node, pa=None): if isinstance(node, list): if node[0] == "list": - self.write('std::vector<') + self.write("std::vector<") self.visit_list_decl(node, pa) elif node[0] == "dict": self.write("std::map<") self.visit_dict_decl(node) elif node[0] == "tuple": - self.write('std::tuple<') + self.write("std::tuple<") self.visit_tuple_decl(node) elif node[0] == "array": - self.write('std::vector<') + self.write("std::vector<") self.visit_array_decl(node, pa) else: if node == "float": @@ -861,11 +1325,11 @@ def visit_decl(self, node, pa=None): self.visit_datetime_decl(node) def visit_pair(self, node): - self.write(u'{') + self.write("{") self.visit(node.key) - self.write(u", ") + self.write(", ") self.visit(node.value) - self.write(u'}') + self.write("}") def visit_call(self, node): want_comma = [] @@ -883,17 +1347,29 @@ def write_comma(): self.visit(node.function(node)) else: self.write(node.function) - self.write('(') + self.write("(") if isinstance(node.args, list): for arg in node.args: write_comma() self.visit(arg) else: self.visit(node.args) - self.write(')') + self.write(")") def visit_standard_call(self, node): - node.function = self.functions[node.namespace][node.function] + ns = self.functions[node.namespace] + fn_name = node.function + node.function = ns[node.function] if fn_name in ns else None + if not node.function: # search other namespaces + for ns_name, ns2 in self.functions.items(): + if ns_name == node.namespace: + continue + if fn_name in ns2: + node.function = ns2[fn_name] + break + if not node.function: + print(f"Couldn't find function {fn_name} in namespace {node.namespace}") + return self.visit_call(node) def visit_importfrom(self, node): @@ -901,14 +1377,12 @@ def visit_importfrom(self, node): def visit_for_statement(self, node): self.newline(node) - self.write("for(") + self.write("for (") if "iterators" in dir(node): self.visit(node.iterators) if "sequences" in dir(node): self.visit(node.sequences) - self.write(')') - self.newline(node) - self.write('{') + self.write(") {") if "iterators" in dir(node): self.newline(node) self.indentation += 1 @@ -917,10 +1391,11 @@ def visit_for_statement(self, node): self.body(node.block) self.newline(node) self.write('}') + self.newline(node) def visit_for_iterator_with_index(self, node): self.visit(node.index) - self.write(' , ') + self.write(', ') self.visit(node.iterator) def visit_for_sequence_with_index(self, node): @@ -940,34 +1415,31 @@ def visit_for_range_statement(self, node): self.visit(node.index) self.write("=") self.visit(node.start) - self.write(' ; ') + self.write("; ") self.visit(node.index) self.write("!=") self.visit(node.end) - self.write(' ; ') + self.write("; ") self.visit(node.index) self.write("+=") if "value" in dir(node.step) and node.step.value == 1: self.write("1") else: self.visit(node.step) - self.write(')') - self.newline(node) - self.write('{') + self.write(") {") self.body(node.block) self.newline(node) - self.write('}') + self.write("}") + self.newline(node) def visit_while_statement(self, node): self.newline(node) - self.write('while ( ') + self.write("while ( ") self.visit(node.test) - self.write(')') - self.newline(node) - self.write('{') + self.write(") {") self.body_or_else(node) self.newline(node) - self.write('}') + self.write("}") class Cpp2Trans(Cpp2Generator): @@ -1565,6 +2037,14 @@ def visit_assignment(self, node): break self.write(f"_{name}.Calculate_Model(s, s1, r, a, ex);") self.newline(node) + elif "function" in dir(node.value) and node.value.function.split('_')[0]=="init": + name = node.value.function.split('init_')[1] + for m in self.modelt.model: + if name.lower() == signature2(m).lower(): + name = signature2(m) + break + self.write(f"_{name}.Init(s, s1, r, a, ex);") + self.newline(node) else: self.newline(node) if node.value.name not in self.modeltparams: diff --git a/src/pycropml/transpiler/generators/cppGenerator.py b/src/pycropml/transpiler/generators/cppGenerator.py index d16da31..e9d98f0 100644 --- a/src/pycropml/transpiler/generators/cppGenerator.py +++ b/src/pycropml/transpiler/generators/cppGenerator.py @@ -56,6 +56,11 @@ def visit_comparison(self, node): # self.write(')') def visit_binary_op(self, node): + # ignore checks against none (nullptr) because we use value types for lists and arrays + if node.right.type == "none" and node.left.pseudo_type[0] in ["list", "array"]: + self.write("true") + return + op = node.op prec = self.binop_precedence.get(op, 0) self.operator_enter(prec) @@ -92,7 +97,8 @@ def visit_import(self, node): pass def visit_none(self, node): - pass + self.write("null") + #pass def visit_cond_expr_node(self, node): self.visit(node.test) @@ -177,7 +183,10 @@ def visit_dict(self, node): self.write(u'}') def visit_bool(self, node): - self.write(node.value) #if node.value == True else self.write("false") + if isinstance(node.value, str): + self.write(node.value) + elif isinstance(node.value, bool): + self.write("true") if node.value == True else self.write("false") def visit_standard_method_call(self, node): l = node.receiver.pseudo_type @@ -237,6 +246,8 @@ def visit_sliceindex(self, node): self.write(u"]") def visit_assignment(self, node): + dv = dir(node.value) + dt = dir(node.target) if node.value.type == "binary_op" and node.value.left.type == "list": self.visit(node.target) self.write(".assign(") @@ -245,7 +256,6 @@ def visit_assignment(self, node): self.visit(node.value.left.elements[0]) self.write(");") - elif "function" in dir(node.value) and node.value.function.split('_')[0] == "model": name = node.value.function.split('model_')[1] for m in self.model.model: @@ -313,6 +323,458 @@ def visit_assignment(self, node): elif node.value.type == "none": pass + # all combinations of copies between slices + # assumes right now that the slices have the same size + # so no contraction or expansion of the target slice/array + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex" + and node.value.type == "sliceindex" and node.value.message == "sliceindex"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=") + self.visit(node.value.args[0]) + self.write("; _i < (") + self.visit(node.target.args[1]) + self.write(") && _k < (") + self.visit(node.value.args[1]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_from"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=") + self.visit(node.value.args[0]) + self.write("; _i < (") + self.visit(node.target.args[1]) + self.write(") && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_to"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < (") + self.visit(node.target.args[1]) + self.write(") && _k < (") + self.visit(node.value.args[0]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex" + and node.value.type == "sliceindex" and node.value.message == "slice_"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < (") + self.visit(node.target.args[1]) + self.write(") && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex" and + "pseudo_type" in dv and node.value.pseudo_type[0] == "array"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < (") + self.visit(node.target.args[1]) + self.write(") && _k < ") + self.visit(node.value) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_from" and + node.value.type == "sliceindex" and node.value.message == "sliceindex"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < (") + self.visit(node.value.args[1]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_from" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_from"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_from" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_to"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < (") + self.visit(node.target.receiver) + self.write(".size() && _k < (") + self.visit(node.value.args[0]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_from" and + node.value.type == "sliceindex" and node.value.message == "slice_"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_from" and + "pseudo_type" in dv and node.value.pseudo_type[0] == "array"): + self.write("for (int _i=") + self.visit(node.target.args[0]) + self.write(", _k=0; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_to" and + node.value.type == "sliceindex" and node.value.message == "sliceindex"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < (") + self.visit(node.target.args[0]) + self.write(") && _k < (") + self.visit(node.value.args[1]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_to" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_from"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < (") + self.visit(node.target.args[0]) + self.write(") && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_to" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_to"): + self.write("for (int _i=0, _k=0; _i < (") + self.visit(node.target.args[0]) + self.write(") && _k < (") + self.visit(node.value.args[0]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_to" and + node.value.type == "sliceindex" and node.value.message == "slice_"): + self.write("for (int _i=0, _k=0; _i < (") + self.visit(node.target.args[0]) + self.write(") && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "sliceindex_to" and + "pseudo_type" in dv and node.value.pseudo_type[0] == "array"): + self.write("for (int _i=0, _k=0; _i < (") + self.visit(node.target.args[0]) + self.write(") && _k < ") + self.visit(node.value) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "slice_" + and node.value.type == "sliceindex" and node.value.message == "sliceindex"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < (") + self.visit(node.value.args[1]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "slice_" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_from"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "slice_" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_to"): + self.write("for (int _i=0, _k=0; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < (") + self.visit(node.value.args[0]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "slice_" + and node.value.type == "sliceindex" and node.value.message == "slice_"): + self.write("for (int _i=0, _k=0; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif (node.target.type == "sliceindex" and node.target.message == "slice_" and + "pseudo_type" in dv and node.value.pseudo_type[0] == "array"): + self.write("for (int _i=0, _k=0; _i < ") + self.visit(node.target.receiver) + self.write(".size() && _k < ") + self.visit(node.value) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target.receiver) + self.write("[_i] = ") + self.visit(node.value) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif ("pseudo_type" in dt and node.target.pseudo_type[0] == "array" and + node.value.type == "sliceindex" and node.value.message == "sliceindex"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target) + self.write(".size() && _k < (") + self.visit(node.value.args[1]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif ("pseudo_type" in dt and node.target.pseudo_type[0] == "array" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_from"): + self.write("for (int _i=0, _k=") + self.visit(node.value.args[0]) + self.write("; _i < ") + self.visit(node.target) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif ("pseudo_type" in dt and node.target.pseudo_type[0] == "array" and + node.value.type == "sliceindex" and node.value.message == "sliceindex_to"): + self.write("for (int _i=0, _k=0; _i < ") + self.visit(node.target) + self.write(".size() && _k < (") + self.visit(node.value.args[0]) + self.write("); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + + elif ("pseudo_type" in dt and node.target.pseudo_type[0] == "array" and + node.value.type == "sliceindex" and node.value.message == "slice_"): + self.write("for (int _i=0, _k=0; _i < ") + self.visit(node.target) + self.write(".size() && _k < ") + self.visit(node.value.receiver) + self.write(".size(); _i++, _k++) {") + self.newline(node) + self.write(" ") + self.visit(node.target) + self.write("[_i] = ") + self.visit(node.value.receiver) + self.write("[_k];") + self.newline(node) + self.write("}") + self.newline(node) + else: self.newline(node) self.visit(node.target) @@ -439,6 +901,7 @@ def visit_module(self, node): '#include \n' '#include \n' '#include \n' + '#include \n' '#include \n') if self.model: self.write(f'#include "{self.model.name}.h"\n') @@ -657,7 +1120,14 @@ def visit_declaration(self, node): self.write(f"std::vector<{self.types[n.pseudo_type[1]]}> {n.name};") else: self.write(f"std::vector<{self.types[n.pseudo_type[1]]}> {n.name}") - self.write(f"({n.elts[0].name if 'name' in dir(n.elts[0]) else n.elts[0].value});") + if "name" in dir(n.elts[0]): + self.write(f"({n.elts[0].name});") + elif "value" in dir(n.elts[0]): + self.write(f"({n.elts[0].value});") + else: + self.write(f"(") + self.visit(n.elts[0]) + self.write(");") elif 'value' in dir(n) and n.type in ("int", "float", "str", "bool"): self.write(f"{self.types[n.type]} {n.name}") self.write(" = ") @@ -821,7 +1291,19 @@ def write_comma(): self.write(')') def visit_standard_call(self, node): - node.function = self.functions[node.namespace][node.function] + ns = self.functions[node.namespace] + fn_name = node.function + node.function = ns[node.function] if fn_name in ns else None + if not node.function: # search other namespaces + for ns_name, ns2 in self.functions.items(): + if ns_name == node.namespace: + continue + if fn_name in ns2: + node.function = ns2[fn_name] + break + if not node.function: + print(f"Couldn't find function {fn_name} in namespace {node.namespace}") + return self.visit_call(node) def visit_importfrom(self, node): @@ -1313,7 +1795,7 @@ def header_mu_cpp(models, rep, name): file_func = mf.filename path_func = Path(os.path.join(m.path, "crop2ml", file_func)) func_tree = parser(Path(path_func)) - newtree = AstTransformer(func_tree, path_func) + newtree = AstTransformer(func_tree, path_func, m) # print(newtree) dict_ast = newtree.transformer() node_ast = transform_to_syntax_tree(dict_ast) @@ -1346,6 +1828,50 @@ def header_mu_cpp(models, rep, name): h = [] +def header_mu_cpp2(model, rep, name): + #mc = models[0].name + mc = model.name + m = model + h = [] + init = False + #for m in models[0].model: + if m.function: + for mf in m.function: + file_func = mf.filename + path_func = Path(os.path.join(m.path, "crop2ml", file_func)) + func_tree = parser(Path(path_func)) + newtree = AstTransformer(func_tree, path_func, m) + # print(newtree) + dict_ast = newtree.transformer() + node_ast = transform_to_syntax_tree(dict_ast) + z = {} + for f in filter(lambda x: x.type == "function_definition", node_ast.body): + z[f.name] = [f.return_type, f.params] + h.append(z) + if m.initialization: + init = True + generator = CppTrans([m]) + generator.result = [f''' +#pragma once +#define _USE_MATH_DEFINES +#include +#include +#include +#include +#include "{name}State.h" +#include "{name}Rate.h" +#include "{name}Auxiliary.h" +#include "{name}Exogenous.h" +'''] + generator.model2Node() + param = generator.node_param + generator.generate_hpp(param, f"{m.name}", mc=mc, h=h, init=init, ns=rep.name) + z = ''.join(generator.result) + filename = Path(os.path.join(rep, f"{m.name}.h")) + with open(filename, "wb") as tg_file: + tg_file.write(z.encode('utf-8')) + h = [] + def headerCompo(models, rep, name): """ Header file of model composite""" mc = models[0].name @@ -1456,6 +1982,14 @@ def visit_assignment(self, node): break self.write(f"_{name}.Calculate_Model(s, s1, r, a, ex);") self.newline(node) + elif "function" in dir(node.value) and node.value.function.split('_')[0]=="init": + name = node.value.function.split('init_')[1] + for m in self.modelt.model: + if name.lower() == signature2(m).lower(): + name = signature2(m) + break + self.write(f"_{name}.Init(s, s1, r, a, ex);") + self.newline(node) else: self.newline(node) if node.value.name not in self.modeltparams: diff --git a/src/pycropml/transpiler/rules/cpp2Rules.py b/src/pycropml/transpiler/rules/cpp2Rules.py index 7fcc658..5322ab3 100644 --- a/src/pycropml/transpiler/rules/cpp2Rules.py +++ b/src/pycropml/transpiler/rules/cpp2Rules.py @@ -1,20 +1,20 @@ from pycropml.transpiler.rules.generalRule import GeneralRule from pycropml.transpiler.pseudo_tree import Node +def translate_len_list(node): + return Node("call", function="int", args=[ + Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type)]) -def translateLenList(node): - return Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type) - -def translateLenStr(node): +def translate_len_str(node): return Node("method_call", receiver=node.receiver, message=".length()", args=[], pseudo_type=node.pseudo_type) -def translateLog(node): +def translate_log(node): return Node("call", function="std::log10", args=[node.args[0]], pseudo_type=node.pseudo_type) -def translateSum(node): +def translate_sum(node): if "name" in dir(node.receiver): #print(node.receiver.y) if "cpp_struct_name" in dir(node.receiver) and node.receiver.cpp_struct_name is not None: @@ -28,64 +28,74 @@ def translateSum(node): pseudo_type=node.pseudo_type) -def translateNotContains(node): +def translate_not_contains(node): return Node("call", function="!", args=[Node("standard_method_call", receiver=node.receiver, message="contains?", args=node.args, pseudo_type=node.pseudo_type)]) -def translateLenDict(node): - return Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type) +def translate_len_dict(node): + return Node("call", function="int", args=[ + Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type)]) -def translateLenArray(node): - return Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type) +def translate_len_array(node): + return Node("call", function="int", args=[ + Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type)]) -def translatekeyDict(node): +def translate_key_dict(node): return Node("method_call", receiver=node.receiver, message=".Keys", args=[], pseudo_type=node.pseudo_type) -def translateget(node): +def translate_get(node): if "value" in dir(node.args[0]): - return Node('index', sequence=Node('local', name=node.receiver.name, + return Node("index", sequence=Node("local", name=node.receiver.name, pseudo_type=node.receiver.pseudo_type), index=Node(node.args[0].type, value=node.args[0].value, pseudo_type=node.args[0].pseudo_type), pseudo_type="Void") elif "name" in dir(node.args[0]): - return Node('index', sequence=Node('local', name=node.receiver.name, + return Node("index", sequence=Node("local", name=node.receiver.name, pseudo_type=node.receiver.pseudo_type), index=Node(node.args[0].type, name=node.args[0].name, pseudo_type=node.args[0].pseudo_type), pseudo_type="Void") -def translatePop(node): - return Node("custom_call", receiver=node.receiver, function="%s.erase" % node.receiver.name, args=[ - Node(type="binary_op", left=Node(type="local", name=f"{node.receiver.name}.begin()"), op="+", - right=node.args)], pseudo_type=node.pseudo_type) +def translate_pop(node): + return Node("custom_call", receiver=node.receiver, function=f"{node.receiver.name}.erase", + args=[Node(type="binary_op", left=Node(type="local", name=f"{node.receiver.name}.begin()"), op="+", + right=node.args)], pseudo_type=node.pseudo_type) -def translateInsert(node): +def translate_insert(node): return Node("custom_call", receiver=node.receiver, function=f"{node.receiver.name}.insert", args=[Node(type="binary_op", left=Node(type="local", name=f"{node.receiver.name}.begin()"), op="+", right=node.args[0]), node.args[1]], pseudo_type=node.pseudo_type) -def translateContains(node): - return Node(type="binary_op", op="!=", - left=Node("custom_call", receiver=node.receiver, function="find", - args=[Node(type="local", name=f"{node.receiver.name}.begin()"), - Node(type="local", name=f"{node.receiver.name}.end()"), - node.args[0]]), - right=Node(type="local", name="%s.end()" % node.receiver.name)) - - -def translateIndex(node): +def translate_contains(node): + if "name" in dir(node.receiver): + return Node(type="binary_op", op="!=", + left=Node("custom_call", receiver=node.receiver, function="find", + args=[Node(type="local", name=f"{node.receiver.name}.begin()"), + Node(type="local", name=f"{node.receiver.name}.end()"), + node.args[0]]), + right=Node(type="local", name="%s.end()" % node.receiver.name)) + else: + if "elements" in dir(node.receiver): + args_str = ", ".join([str(arg.value) for arg in node.receiver.elements]) + return Node(type="binary_op", op=">", + left=Node("custom_call", receiver=node.receiver, + function=f"std::set<{CppRules.types[node.receiver.pseudo_type[1]]}>({{\"{args_str}\"}}).count", + args=[node.args[0]]), + right=Node(type="int", value="0")) + +def translate_index(node): return Node(type="binary_op", op="-", left=Node("custom_call", receiver=node.receiver, function="find", args=[Node(type="local", name=f"{node.receiver.name}.begin()"), Node(type="local", name=f"{node.receiver.name}.end()"), node.args[0]]), - right=Node(type="local", name="%s.begin()" % node.receiver.name)) + right=Node(type="local", name=f"{node.receiver.name}.begin()")) def translate_min_max(node, f): @@ -102,14 +112,13 @@ def translate_min_max(node, f): return node -def translateMIN(node): +def translate_min(node): return translate_min_max(node, "std::min") -def translateMAX(node): +def translate_max(node): return translate_min_max(node, "std::max") - class CppRules(GeneralRule): def __init__(self): GeneralRule.__init__(self) @@ -117,6 +126,7 @@ def __init__(self): binary_op = {"and": "&&", "or": "||", "not": "!", + "is_not": "!=", "<": "<", ">": ">", "==": "==", @@ -129,10 +139,10 @@ def __init__(self): "!=": "!="} unary_op = { - 'not': '!', - '+': '+', - '-': '-', - '~': '~' + "not": "!", + "+": "+", + "-": "-", + "~": "~" } types = { @@ -162,42 +172,43 @@ def __init__(self): } functions = { - 'math': { - 'ln': 'std::log', - 'log': translateLog, - 'tan': 'std::tan', - 'sin': 'std::sin', - 'cos': 'std::cos', - 'asin': 'std::asin', - 'acos': 'std::acos', - 'atan': 'std::atan', - 'sqrt': 'std::sqrt', - 'ceil': '(int) std::ceil', - 'round': 'std::round', - 'exp': 'std::exp', - 'pow': 'std::pow', - 'floor': 'std::floor' - + "math": { + "ln": "std::log", + "log": translate_log, + "tan": "std::tan", + "sin": "std::sin", + "cos": "std::cos", + "asin": "std::asin", + "acos": "std::acos", + "atan": "std::atan", + "sqrt": "std::sqrt", + "ceil": "(int) std::ceil", + "round": "std::round", + "exp": "std::exp", + "pow": "std::pow", + "floor": "std::floor", + "isnan": "std::isnan", + }, + "io": { + "print": "std::cout << ", + "read": "Console.ReadLine", + "read_file": "File.ReadAllText", + "write_file": "File.WriteAllText" }, - 'io': { - 'print': "std::cout << ", - 'read': 'Console.ReadLine', - 'read_file': 'File.ReadAllText', - 'write_file': 'File.WriteAllText' + "system": { + "min": translate_min, + "max": translate_max, + "abs": "std::abs", + "pow": "std::pow" }, - 'system': { - 'min': translateMIN, - 'max': translateMAX, - 'abs': 'std::abs', - 'pow': 'std::pow'}, - 'datetime': { - 'datetime': lambda node: Node(type="str", value=argsToStr(node.args)) + "datetime": { + "datetime": lambda node: Node(type="str", value=args_to_str(node.args)) } } constant = { - 'math': { - 'pi': 'M_PI' + "math": { + "pi": "M_PI" } } @@ -211,29 +222,29 @@ def __init__(self): 'str': { 'int': 'int', 'find': '.find', - 'len': translateLenStr + 'len': translate_len_str }, 'list': { - 'len': translateLenList, + 'len': translate_len_list, 'append': '.push_back', - 'sum': translateSum, - 'pop': translatePop, - 'insert_at': translateInsert, - 'contains?': translateContains, - 'not contains?': translateNotContains, - 'index': translateIndex, + 'sum': translate_sum, + 'pop': translate_pop, + 'insert_at': translate_insert, + 'contains?': translate_contains, + 'not contains?': translate_not_contains, + 'index': translate_index, "allocate": lambda node: Node("assignment", target=node.receiver, value=Node("list", elements=node.args, pseudo_type=node.receiver.pseudo_type)) }, 'dict': { - 'len': translateLenDict, - 'keys': translatekeyDict, - "get": translateget + 'len': translate_len_dict, + 'keys': translate_key_dict, + "get": translate_get }, 'array': { - 'len': translateLenArray, - 'sum': translateSum, + 'len': translate_len_array, + 'sum': translate_sum, 'append': '.Add', "allocate": lambda node: Node("assignment", target=node.receiver, value=Node("list", elements=node.args, @@ -301,7 +312,7 @@ def __init__(self): ''' -def argsToStr(args): +def args_to_str(args): t = [] for arg in args: t.append(arg.value) diff --git a/src/pycropml/transpiler/rules/cppRules.py b/src/pycropml/transpiler/rules/cppRules.py index f353101..0a5cc11 100644 --- a/src/pycropml/transpiler/rules/cppRules.py +++ b/src/pycropml/transpiler/rules/cppRules.py @@ -2,19 +2,20 @@ from pycropml.transpiler.pseudo_tree import Node -def translateLenList(node): - return Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type) +def translate_len_list(node): + return Node("call", function="int", args=[ + Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type)]) -def translateLenStr(node): +def translate_len_str(node): return Node("method_call", receiver=node.receiver, message=".length()", args=[], pseudo_type=node.pseudo_type) -def translateLog(node): +def translate_log(node): return Node("call", function="std::log10", args=[node.args[0]], pseudo_type=node.pseudo_type) -def translateSum(node): +def translate_sum(node): if "name" in dir(node.receiver): print(node.receiver.y) return Node("call", function="accumulate", @@ -24,24 +25,26 @@ def translateSum(node): pseudo_type=node.pseudo_type) -def translateNotContains(node): +def translate_not_contains(node): return Node("call", function="!", args=[Node("standard_method_call", receiver=node.receiver, message="contains?", args=node.args, pseudo_type=node.pseudo_type)]) -def translateLenDict(node): - return Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type) +def translate_len_dict(node): + return Node("call", function="int", args=[ + Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type)]) -def translateLenArray(node): - return Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type) +def translate_len_array(node): + return Node("call", function="int", args=[ + Node("method_call", receiver=node.receiver, message=".size()", args=[], pseudo_type=node.pseudo_type)]) -def translatekeyDict(node): +def translate_key_dict(node): return Node("method_call", receiver=node.receiver, message=".Keys", args=[], pseudo_type=node.pseudo_type) -def translateget(node): +def translate_get(node): if "value" in dir(node.args[0]): return Node('index', sequence=Node('local', name=node.receiver.name, pseudo_type=node.receiver.pseudo_type), @@ -54,34 +57,42 @@ def translateget(node): pseudo_type="Void") -def translatePop(node): +def translate_pop(node): return Node("custom_call", receiver=node.receiver, function="%s.erase" % node.receiver.name, args=[ Node(type="binary_op", left=Node(type="local", name=f"{node.receiver.name}.begin()"), op="+", right=node.args)], pseudo_type=node.pseudo_type) -def translateInsert(node): +def translate_insert(node): return Node("custom_call", receiver=node.receiver, function=f"{node.receiver.name}.insert", args=[Node(type="binary_op", left=Node(type="local", name=f"{node.receiver.name}.begin()"), op="+", right=node.args[0]), node.args[1]], pseudo_type=node.pseudo_type) -def translateContains(node): - return Node(type="binary_op", op="!=", - left=Node("custom_call", receiver=node.receiver, function="find", - args=[Node(type="local", name=f"{node.receiver.name}.begin()"), - Node(type="local", name=f"{node.receiver.name}.end()"), - node.args[0]]), - right=Node(type="local", name="%s.end()" % node.receiver.name)) - - -def translateIndex(node): +def translate_contains(node): + if "name" in dir(node.receiver): + return Node(type="binary_op", op="!=", + left=Node("custom_call", receiver=node.receiver, function="find", + args=[Node(type="local", name=f"{node.receiver.name}.begin()"), + Node(type="local", name=f"{node.receiver.name}.end()"), + node.args[0]]), + right=Node(type="local", name="%s.end()" % node.receiver.name)) + else: + if "elements" in dir(node.receiver): + args_str = ", ".join([str(arg.value) for arg in node.receiver.elements]) + return Node(type="binary_op", op=">", + left=Node("custom_call", receiver=node.receiver, + function=f"std::set<{CppRules.types[node.receiver.pseudo_type[1]]}>({{\"{args_str}\"}}).count", + args=[node.args[0]]), + right=Node(type="int", value="0")) + +def translate_index(node): return Node(type="binary_op", op="-", left=Node("custom_call", receiver=node.receiver, function="find", args=[Node(type="local", name=f"{node.receiver.name}.begin()"), Node(type="local", name=f"{node.receiver.name}.end()"), node.args[0]]), - right=Node(type="local", name="%s.begin()" % node.receiver.name)) + right=Node(type="local", name=f"{node.receiver.name}.begin()")) def translate_min_max(node, f): @@ -98,14 +109,13 @@ def translate_min_max(node, f): return node -def translateMIN(node): +def translate_min(node): return translate_min_max(node, "std::min") -def translateMAX(node): +def translate_max(node): return translate_min_max(node, "std::max") - class CppRules(GeneralRule): def __init__(self): GeneralRule.__init__(self) @@ -113,6 +123,7 @@ def __init__(self): binary_op = {"and": "&&", "or": "||", "not": "!", + "is_not": "!=", "<": "<", ">": ">", "==": "==", @@ -125,10 +136,10 @@ def __init__(self): "!=": "!="} unary_op = { - 'not': '!', - '+': '+', - '-': '-', - '~': '~' + "not": "!", + "+": "+", + "-": "-", + "~": "~" } types = { @@ -158,42 +169,43 @@ def __init__(self): } functions = { - 'math': { - 'ln': 'std::log', - 'log': translateLog, - 'tan': 'std::tan', - 'sin': 'std::sin', - 'cos': 'std::cos', - 'asin': 'std::asin', - 'acos': 'std::acos', - 'atan': 'std::atan', - 'sqrt': 'std::sqrt', - 'ceil': '(int) std::ceil', - 'round': 'std::round', - 'exp': 'std::exp', - 'pow': 'std::pow', - 'floor': 'std::floor' - + "math": { + "ln": "std::log", + "log": translate_log, + "tan": "std::tan", + "sin": "std::sin", + "cos": "std::cos", + "asin": "std::asin", + "acos": "std::acos", + "atan": "std::atan", + "sqrt": "std::sqrt", + "ceil": "(int) std::ceil", + "round": "std::round", + "exp": "std::exp", + "pow": "std::pow", + "floor": "std::floor", + "isnan": "std::isnan", + }, + "io": { + "print": "std::cout << ", + "read": "Console.ReadLine", + "read_file": "File.ReadAllText", + "write_file": "File.WriteAllText" }, - 'io': { - 'print': "std::cout << ", - 'read': 'Console.ReadLine', - 'read_file': 'File.ReadAllText', - 'write_file': 'File.WriteAllText' + "system": { + "min": translate_min, + "max": translate_max, + "abs": "std::abs", + "pow": "std::pow" }, - 'system': { - 'min': translateMIN, - 'max': translateMAX, - 'abs': 'std::abs', - 'pow': 'std::pow'}, - 'datetime': { - 'datetime': lambda node: Node(type="str", value=argsToStr(node.args)) + "datetime": { + "datetime": lambda node: Node(type="str", value=args_to_str(node.args)) } } constant = { - 'math': { - 'pi': 'M_PI' + "math": { + "pi": "M_PI" } } @@ -207,29 +219,29 @@ def __init__(self): 'str': { 'int': 'int', 'find': '.find', - 'len': translateLenStr + 'len': translate_len_str }, 'list': { - 'len': translateLenList, + 'len': translate_len_list, 'append': '.push_back', - 'sum': translateSum, - 'pop': translatePop, - 'insert_at': translateInsert, - 'contains?': translateContains, - 'not contains?': translateNotContains, - 'index': translateIndex, + 'sum': translate_sum, + 'pop': translate_pop, + 'insert_at': translate_insert, + 'contains?': translate_contains, + 'not contains?': translate_not_contains, + 'index': translate_index, "allocate": lambda node: Node("assignment", target=node.receiver, value=Node("list", elements=node.args, pseudo_type=node.receiver.pseudo_type)) }, 'dict': { - 'len': translateLenDict, - 'keys': translatekeyDict, - "get": translateget + 'len': translate_len_dict, + 'keys': translate_key_dict, + "get": translate_get }, 'array': { - 'len': translateLenArray, - 'sum': translateSum, + 'len': translate_len_array, + 'sum': translate_sum, 'append': '.Add', "allocate": lambda node: Node("assignment", target=node.receiver, value=Node("list", elements=node.args, @@ -297,7 +309,7 @@ def __init__(self): ''' -def argsToStr(args): +def args_to_str(args): t = [] for arg in args: t.append(arg.value)