diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 8feb600c09..fd633029e5 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -215,7 +215,7 @@ ttype | Array(ttype type, dimension* dims, array_physical_type physical_type) | FunctionType(ttype* arg_types, ttype? return_var_type, abi abi, deftype deftype, string? bindc_name, bool elemental, bool pure, bool module, bool inline, bool static, symbol* restrictions, bool is_restriction) -cast_kind = RealToInteger | IntegerToReal | LogicalToReal | RealToReal | IntegerToInteger | RealToComplex | IntegerToComplex | IntegerToLogical | RealToLogical | CharacterToLogical | CharacterToInteger | CharacterToList | ComplexToLogical | ComplexToComplex | ComplexToReal | ComplexToInteger | LogicalToInteger | RealToCharacter | IntegerToCharacter | LogicalToCharacter | UnsignedIntegerToInteger | UnsignedIntegerToUnsignedInteger | UnsignedIntegerToReal | UnsignedIntegerToLogical | IntegerToUnsignedInteger | RealToUnsignedInteger | CPtrToUnsignedInteger | UnsignedIntegerToCPtr | IntegerToSymbolicExpression | ListToArray +cast_kind = RealToInteger | IntegerToReal | LogicalToReal | RealToReal | IntegerToInteger | RealToComplex | IntegerToComplex | IntegerToLogical | RealToLogical | CharacterToLogical | CharacterToInteger | CharacterToList | ComplexToLogical | ComplexToComplex | ComplexToReal | ComplexToInteger | LogicalToInteger | RealToCharacter | IntegerToCharacter | LogicalToCharacter | UnsignedIntegerToInteger | UnsignedIntegerToUnsignedInteger | UnsignedIntegerToReal | UnsignedIntegerToLogical | IntegerToUnsignedInteger | RealToUnsignedInteger | CPtrToUnsignedInteger | UnsignedIntegerToCPtr | IntegerToSymbolicExpression | ListToArray | DerivedToBase storage_type = Default | Save | Parameter access = Public | Private intent = Local | In | Out | InOut | ReturnVar | Unspecified diff --git a/src/libasr/casting_utils.cpp b/src/libasr/casting_utils.cpp index 45ab744304..68e3971839 100644 --- a/src/libasr/casting_utils.cpp +++ b/src/libasr/casting_utils.cpp @@ -41,7 +41,8 @@ namespace LCompilers::CastingUtil { {ASR::ttypeType::Complex, ASR::cast_kindType::ComplexToComplex}, {ASR::ttypeType::Real, ASR::cast_kindType::RealToReal}, {ASR::ttypeType::Integer, ASR::cast_kindType::IntegerToInteger}, - {ASR::ttypeType::UnsignedInteger, ASR::cast_kindType::UnsignedIntegerToUnsignedInteger} + {ASR::ttypeType::UnsignedInteger, ASR::cast_kindType::UnsignedIntegerToUnsignedInteger}, + {ASR::ttypeType::StructType, ASR::cast_kindType::DerivedToBase} }; int get_type_priority(ASR::ttypeType type) { diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 4aad34d197..8a5e494b28 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -7725,6 +7725,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = LLVM::CreateLoad(*builder, list_api->get_pointer_to_list_data(tmp)); break; } + case (ASR::cast_kindType::DerivedToBase) : { + this->visit_expr(*x.m_arg); + tmp = llvm_utils->create_gep(tmp, 0); + break; + } default : throw CodeGenError("Cast kind not implemented"); } } diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index fad8029a00..70c1aab7e0 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -784,9 +784,26 @@ class CommonVisitor : public AST::BaseVisitor { ASR::call_arg_t c_arg; c_arg.loc = args[i].loc; c_arg.m_value = args[i].m_value; - cast_helper(m_args[i], c_arg.m_value, true); ASR::ttype_t* left_type = ASRUtils::expr_type(m_args[i]); ASR::ttype_t* right_type = ASRUtils::expr_type(c_arg.m_value); + if ( ASR::is_a(*left_type) && ASR::is_a(*right_type) ) { + ASR::StructType_t *l_type = ASR::down_cast(left_type); + ASR::StructType_t *r_type = ASR::down_cast(right_type); + ASR::Struct_t *l2_type = ASR::down_cast( + ASRUtils::symbol_get_past_external( + l_type->m_derived_type)); + ASR::Struct_t *r2_type = ASR::down_cast( + ASRUtils::symbol_get_past_external( + r_type->m_derived_type)); + if ( ASRUtils::is_derived_type_similar(l2_type, r2_type) ) { + cast_helper(m_args[i], c_arg.m_value, true, true); + check_type_equality = false; + } else { + cast_helper(m_args[i], c_arg.m_value, true); + } + } else { + cast_helper(m_args[i], c_arg.m_value, true); + } if( check_type_equality && !ASRUtils::check_equal_type(left_type, right_type) ) { std::string ltype = ASRUtils::type_to_str_python(left_type); std::string rtype = ASRUtils::type_to_str_python(right_type); @@ -2962,9 +2979,8 @@ class CommonVisitor : public AST::BaseVisitor { std::string obj_name = x.m_args.m_args->m_arg; for(size_t i = 0; i < x.n_body; i++) { std::string var_name; - if (! AST::is_a(*x.m_body[i]) ){ - throw SemanticError("Only AnnAssign implemented in __init__ ", - x.m_body[i]->base.loc); + if ( !AST::is_a(*x.m_body[i]) ){ + continue; } AST::AnnAssign_t ann_assign = *AST::down_cast(x.m_body[i]); if(AST::is_a(*ann_assign.m_target)){ @@ -3301,10 +3317,21 @@ class CommonVisitor : public AST::BaseVisitor { current_scope->add_symbol(x_m_name, class_type); } } else { - if( x.n_bases > 0 ) { - throw SemanticError("Inheritance in classes isn't supported yet.", + ASR::symbol_t* parent = nullptr; + if( x.n_bases > 1 ) { + throw SemanticError("Multiple inheritance in classes isn't supported yet.", x.base.base.loc); } + else if (x.n_bases == 1) { + std::string b_name = ""; + if ( AST::is_a(*x.m_bases[0]) ) { + b_name = AST::down_cast(x.m_bases[0])->m_id; + } else { + throw SemanticError("Expected a Name here", x.base.base.loc); + } + parent = current_scope->resolve_symbol(b_name); + LCOMPILERS_ASSERT(ASR::is_a(*parent)); + } SymbolTable *parent_scope = current_scope; if( ASR::symbol_t* sym = current_scope->resolve_symbol(x_m_name) ) { LCOMPILERS_ASSERT(ASR::is_a(*sym)); @@ -3316,7 +3343,7 @@ class CommonVisitor : public AST::BaseVisitor { f = AST::down_cast(x.m_body[i]); init_self_type(*f, sym, x.base.base.loc); if ( std::string(f->m_name) == std::string("__init__") ) { - this->visit_init_body(*f); + this->visit_init_body(*f, st->m_parent, x.m_body[i]->base.loc); } else { this->visit_stmt(*x.m_body[i]); } @@ -3344,7 +3371,7 @@ class CommonVisitor : public AST::BaseVisitor { member_names.p, member_names.size(), member_fn_names.p, member_fn_names.size(), class_abi, ASR::accessType::Public, false, false, member_init.p, member_init.size(), - nullptr, nullptr)); + nullptr, parent)); parent_scope->add_symbol(x.m_name, class_sym); visit_ClassMembers(x, member_names, member_fn_names, struct_dependencies, member_init, false, class_abi, true); @@ -3387,7 +3414,7 @@ class CommonVisitor : public AST::BaseVisitor { current_scope = parent_scope; } - virtual void visit_init_body (const AST::FunctionDef_t &/*x*/) = 0; + virtual void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/) = 0; void add_name(const Location &loc) { std::string var_name = "__name__"; @@ -4421,7 +4448,7 @@ class SymbolTableVisitor : public CommonVisitor { // Implement visit_Global for Symbol Table visitor. void visit_Global(const AST::Global_t &/*x*/) {} - void visit_init_body (const AST::FunctionDef_t &/*x*/) { + void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/) { //Implemented in BodyVisitor } @@ -5153,7 +5180,7 @@ class BodyVisitor : public CommonVisitor { tmp = asr; } - void visit_init_body (const AST::FunctionDef_t &x) { + void visit_init_body (const AST::FunctionDef_t &x, ASR::symbol_t* parent_sym, const Location loc) { SymbolTable *old_scope = current_scope; ASR::symbol_t *t = current_scope->get_symbol("__init__"); if ( t==nullptr ) { @@ -5163,31 +5190,82 @@ class BodyVisitor : public CommonVisitor { throw SemanticError("__init__ is not a function", x.base.base.loc); } ASR::Function_t *f = ASR::down_cast(t); + current_scope = f->m_symtab; //Transform statements into correct format - Vec new_body; - new_body.reserve(al, 1); + Vec body; + body.reserve(al, 1); + ASR::stmt_t* super_call_stmt = nullptr; for (size_t i=0; i(x.m_body[i]); - if ( ann_assign.m_value != nullptr ) { - Vectarget; - target.reserve(al, 1); - target.push_back(al, ann_assign.m_target); - AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc, - target.p, 1, ann_assign.m_value, nullptr); - AST::stmt_t* assgn = AST::down_cast(assgn_ast); - new_body.push_back(al, assgn); + if (AST::is_a(*x.m_body[i])) { + AST::AnnAssign_t ann_assign = *AST::down_cast(x.m_body[i]); + if ( ann_assign.m_value != nullptr ) { + Vectarget; + target.reserve(al, 1); + target.push_back(al, ann_assign.m_target); + AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc, + target.p, 1, ann_assign.m_value, nullptr); + AST::stmt_t* assgn = AST::down_cast(assgn_ast); + body.push_back(al, assgn); + } + } else if (AST::is_a(*x.m_body[i]) && + AST::is_a(*(AST::down_cast(x.m_body[i])->m_value))) { + AST::Call_t* c = AST::down_cast(AST::down_cast(x.m_body[i])->m_value); + + if ( !AST::is_a(*(c->m_func)) + || !AST::is_a(*(AST::down_cast(c->m_func)->m_value)) ) { + body.push_back(al, x.m_body[i]); + continue; + } + AST::Call_t* super_call = AST::down_cast(AST::down_cast(c->m_func)->m_value); + std::string attr = AST::down_cast(c->m_func)->m_attr; + if ( AST::is_a(*(super_call->m_func)) && + std::string(AST::down_cast(super_call->m_func)->m_id)=="super" && + attr == "__init__") { + if (parent_sym == nullptr) { + throw SemanticError("The class doesn't have a base class",loc); + } + Vec args; + args.reserve(al, 1); + parse_args(*super_call,args); + ASR::call_arg_t first_arg; + first_arg.loc = loc; + ASR::symbol_t* self_sym = current_scope->get_symbol("self"); + first_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al,loc,self_sym)); + ASR::ttype_t* target_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al,loc,parent_sym)); + cast_helper(target_type, first_arg.m_value, x.base.base.loc, true); + Vec args_w_first; args_w_first.reserve(al,1); + args_w_first.push_back(al, first_arg); + for( size_t i = 0; i < args.size(); i++ ) { + args_w_first.push_back(al,args[i]); + } + std::string call_name = "__init__"; + ASR::symbol_t* call_sym = get_struct_member(parent_sym,call_name,loc); + super_call_stmt = ASRUtils::STMT( + ASR::make_SubroutineCall_t(al, loc, call_sym, call_sym, args_w_first.p, + args_w_first.size(), nullptr)); + } + } else { + body.push_back(al, x.m_body[i]); } } current_scope = f->m_symtab; - Vec body; - body.reserve(al, x.n_body); + Vec body_asr; + body_asr.reserve(al, x.n_body); + Vec new_body_asr; + new_body_asr.reserve(al,1); + if ( super_call_stmt ) { + new_body_asr.push_back(al, super_call_stmt); + } Vec rts; rts.reserve(al, 4); dependencies.clear(al); - transform_stmts(body, new_body.n, new_body.p); + transform_stmts(body_asr, body.n, body.p); + for (size_t i=0; im_body = body.p; - f->n_body = body.size(); + f->m_body = new_body_asr.p; + f->n_body = new_body_asr.size(); ASR::FunctionType_t* func_type = ASR::down_cast( f->m_function_signature); func_type->m_restrictions = rts.p; @@ -6239,10 +6317,14 @@ class BodyVisitor : public CommonVisitor { for( size_t i = 0; i < der_type->n_members && !member_found; i++ ) { member_found = std::string(der_type->m_members[i]) == member_name; } - if( !member_found ) { + if( !member_found && !der_type->m_parent ) { throw SemanticError("No member " + member_name + " found in " + std::string(der_type->m_name), loc); + } else if ( !member_found && der_type->m_parent ) { + ASR::ttype_t* parent_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc,der_type->m_parent)); + visit_AttributeUtil(parent_type,attr_char,t,loc); + return; } ASR::expr_t *val = ASR::down_cast(ASR::make_Var_t(al, loc, t)); ASR::symbol_t* member_sym = der_type->m_symtab->resolve_symbol(member_name); @@ -8064,7 +8146,8 @@ we will have to use something else. //TODO: Correct Class and ClassType // call to struct member function // modifying args to pass the object as self - ASR::symbol_t* der = ASR::down_cast(var->m_type)->m_derived_type; + ASR::symbol_t* der_sym = ASR::down_cast(var->m_type)->m_derived_type; + ASR::Struct_t* der = ASR::down_cast(der_sym); Vec new_args; new_args.reserve(al, args.n + 1); ASR::call_arg_t self_arg; self_arg.loc = args[0].loc; @@ -8073,7 +8156,20 @@ we will have to use something else. for (size_t i=0; im_symtab->get_symbol(call_name) ) { + st = get_struct_member(der_sym, call_name, loc); + } else if ( der->m_parent ) { + ASR::Struct_t* parent = ASR::down_cast(der->m_parent); + if ( !parent->m_symtab->get_symbol(call_name) ) { + throw SemanticError("Method not found in the class "+ std::string(der->m_name) + + " or it's parents",loc); + } else { + st = get_struct_member(der->m_parent, call_name, loc); + } + } else { + throw SemanticError("Method not found in the class "+std::string(der->m_name)+ + " or it's parents",loc); + } tmp = make_call_helper(al, st, current_scope, new_args, call_name, loc); return; } else {