Skip to content

Commit

Permalink
Implemented accessing attrs and fn of base class from der type objs a…
Browse files Browse the repository at this point in the history
…nd polymorphic fn calls
  • Loading branch information
tanay-man authored and Thirumalai-Shaktivel committed Aug 18, 2024
1 parent a104c30 commit 009bbe6
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ ttype
| Array(ttype type, dimension* dims, array_physical_type physical_type)
| FunctionType(ttype* arg_types, ttype? return_var_type, abi abi, deftype deftype, string? bindc_name, bool elemental, bool pure, bool module, bool inline, bool static, symbol* restrictions, bool is_restriction)

cast_kind = RealToInteger | IntegerToReal | LogicalToReal | RealToReal | IntegerToInteger | RealToComplex | IntegerToComplex | IntegerToLogical | RealToLogical | CharacterToLogical | CharacterToInteger | CharacterToList | ComplexToLogical | ComplexToComplex | ComplexToReal | ComplexToInteger | LogicalToInteger | RealToCharacter | IntegerToCharacter | LogicalToCharacter | UnsignedIntegerToInteger | UnsignedIntegerToUnsignedInteger | UnsignedIntegerToReal | UnsignedIntegerToLogical | IntegerToUnsignedInteger | RealToUnsignedInteger | CPtrToUnsignedInteger | UnsignedIntegerToCPtr | IntegerToSymbolicExpression | ListToArray
cast_kind = RealToInteger | IntegerToReal | LogicalToReal | RealToReal | IntegerToInteger | RealToComplex | IntegerToComplex | IntegerToLogical | RealToLogical | CharacterToLogical | CharacterToInteger | CharacterToList | ComplexToLogical | ComplexToComplex | ComplexToReal | ComplexToInteger | LogicalToInteger | RealToCharacter | IntegerToCharacter | LogicalToCharacter | UnsignedIntegerToInteger | UnsignedIntegerToUnsignedInteger | UnsignedIntegerToReal | UnsignedIntegerToLogical | IntegerToUnsignedInteger | RealToUnsignedInteger | CPtrToUnsignedInteger | UnsignedIntegerToCPtr | IntegerToSymbolicExpression | ListToArray | DerivedToBase
storage_type = Default | Save | Parameter
access = Public | Private
intent = Local | In | Out | InOut | ReturnVar | Unspecified
Expand Down
3 changes: 2 additions & 1 deletion src/libasr/casting_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ namespace LCompilers::CastingUtil {
{ASR::ttypeType::Complex, ASR::cast_kindType::ComplexToComplex},
{ASR::ttypeType::Real, ASR::cast_kindType::RealToReal},
{ASR::ttypeType::Integer, ASR::cast_kindType::IntegerToInteger},
{ASR::ttypeType::UnsignedInteger, ASR::cast_kindType::UnsignedIntegerToUnsignedInteger}
{ASR::ttypeType::UnsignedInteger, ASR::cast_kindType::UnsignedIntegerToUnsignedInteger},
{ASR::ttypeType::StructType, ASR::cast_kindType::DerivedToBase}
};

int get_type_priority(ASR::ttypeType type) {
Expand Down
5 changes: 5 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7725,6 +7725,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = LLVM::CreateLoad(*builder, list_api->get_pointer_to_list_data(tmp));
break;
}
case (ASR::cast_kindType::DerivedToBase) : {
this->visit_expr(*x.m_arg);
tmp = llvm_utils->create_gep(tmp, 0);
break;
}
default : throw CodeGenError("Cast kind not implemented");
}
}
Expand Down
151 changes: 121 additions & 30 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,9 +784,26 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
ASR::call_arg_t c_arg;
c_arg.loc = args[i].loc;
c_arg.m_value = args[i].m_value;
cast_helper(m_args[i], c_arg.m_value, true);
ASR::ttype_t* left_type = ASRUtils::expr_type(m_args[i]);
ASR::ttype_t* right_type = ASRUtils::expr_type(c_arg.m_value);
if ( ASR::is_a<ASR::StructType_t>(*left_type) && ASR::is_a<ASR::StructType_t>(*right_type) ) {
ASR::StructType_t *l_type = ASR::down_cast<ASR::StructType_t>(left_type);
ASR::StructType_t *r_type = ASR::down_cast<ASR::StructType_t>(right_type);
ASR::Struct_t *l2_type = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(
l_type->m_derived_type));
ASR::Struct_t *r2_type = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(
r_type->m_derived_type));
if ( ASRUtils::is_derived_type_similar(l2_type, r2_type) ) {
cast_helper(m_args[i], c_arg.m_value, true, true);
check_type_equality = false;
} else {
cast_helper(m_args[i], c_arg.m_value, true);
}
} else {
cast_helper(m_args[i], c_arg.m_value, true);
}
if( check_type_equality && !ASRUtils::check_equal_type(left_type, right_type) ) {
std::string ltype = ASRUtils::type_to_str_python(left_type);
std::string rtype = ASRUtils::type_to_str_python(right_type);
Expand Down Expand Up @@ -2962,9 +2979,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
std::string obj_name = x.m_args.m_args->m_arg;
for(size_t i = 0; i < x.n_body; i++) {
std::string var_name;
if (! AST::is_a<AST::AnnAssign_t>(*x.m_body[i]) ){
throw SemanticError("Only AnnAssign implemented in __init__ ",
x.m_body[i]->base.loc);
if ( !AST::is_a<AST::AnnAssign_t>(*x.m_body[i]) ){
continue;
}
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
if(AST::is_a<AST::Attribute_t>(*ann_assign.m_target)){
Expand Down Expand Up @@ -3301,10 +3317,21 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
current_scope->add_symbol(x_m_name, class_type);
}
} else {
if( x.n_bases > 0 ) {
throw SemanticError("Inheritance in classes isn't supported yet.",
ASR::symbol_t* parent = nullptr;
if( x.n_bases > 1 ) {
throw SemanticError("Multiple inheritance in classes isn't supported yet.",
x.base.base.loc);
}
else if (x.n_bases == 1) {
std::string b_name = "";
if ( AST::is_a<AST::Name_t>(*x.m_bases[0]) ) {
b_name = AST::down_cast<AST::Name_t>(x.m_bases[0])->m_id;
} else {
throw SemanticError("Expected a Name here", x.base.base.loc);
}
parent = current_scope->resolve_symbol(b_name);
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*parent));
}
SymbolTable *parent_scope = current_scope;
if( ASR::symbol_t* sym = current_scope->resolve_symbol(x_m_name) ) {
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*sym));
Expand All @@ -3316,7 +3343,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
f = AST::down_cast<AST::FunctionDef_t>(x.m_body[i]);
init_self_type(*f, sym, x.base.base.loc);
if ( std::string(f->m_name) == std::string("__init__") ) {
this->visit_init_body(*f);
this->visit_init_body(*f, st->m_parent, x.m_body[i]->base.loc);
} else {
this->visit_stmt(*x.m_body[i]);
}
Expand Down Expand Up @@ -3344,7 +3371,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
member_names.p, member_names.size(), member_fn_names.p,
member_fn_names.size(), class_abi, ASR::accessType::Public,
false, false, member_init.p, member_init.size(),
nullptr, nullptr));
nullptr, parent));
parent_scope->add_symbol(x.m_name, class_sym);
visit_ClassMembers(x, member_names, member_fn_names,
struct_dependencies, member_init, false, class_abi, true);
Expand Down Expand Up @@ -3387,7 +3414,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
current_scope = parent_scope;
}

virtual void visit_init_body (const AST::FunctionDef_t &/*x*/) = 0;
virtual void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/) = 0;

void add_name(const Location &loc) {
std::string var_name = "__name__";
Expand Down Expand Up @@ -4421,7 +4448,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
// Implement visit_Global for Symbol Table visitor.
void visit_Global(const AST::Global_t &/*x*/) {}

void visit_init_body (const AST::FunctionDef_t &/*x*/) {
void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/) {
//Implemented in BodyVisitor
}

Expand Down Expand Up @@ -5153,7 +5180,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = asr;
}

void visit_init_body (const AST::FunctionDef_t &x) {
void visit_init_body (const AST::FunctionDef_t &x, ASR::symbol_t* parent_sym, const Location loc) {
SymbolTable *old_scope = current_scope;
ASR::symbol_t *t = current_scope->get_symbol("__init__");
if ( t==nullptr ) {
Expand All @@ -5163,31 +5190,77 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
throw SemanticError("__init__ is not a function", x.base.base.loc);
}
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
current_scope = f->m_symtab;
//Transform statements into correct format
Vec<AST::stmt_t*> new_body;
new_body.reserve(al, 1);
Vec<AST::stmt_t*> body;
body.reserve(al, 1);
ASR::stmt_t* super_call_stmt = nullptr;
for (size_t i=0; i<x.n_body; i++) {
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
if ( ann_assign.m_value != nullptr ) {
Vec<AST::expr_t*>target;
target.reserve(al, 1);
target.push_back(al, ann_assign.m_target);
AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc,
target.p, 1, ann_assign.m_value, nullptr);
AST::stmt_t* assgn = AST::down_cast<AST::stmt_t>(assgn_ast);
new_body.push_back(al, assgn);
if (AST::is_a<AST::AnnAssign_t>(*x.m_body[i])) {
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
if ( ann_assign.m_value != nullptr ) {
Vec<AST::expr_t*>target;
target.reserve(al, 1);
target.push_back(al, ann_assign.m_target);
AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc,
target.p, 1, ann_assign.m_value, nullptr);
AST::stmt_t* assgn = AST::down_cast<AST::stmt_t>(assgn_ast);
body.push_back(al, assgn);
}
} else if (AST::is_a<AST::Expr_t>(*x.m_body[i]) &&
AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Expr_t>(x.m_body[i])->m_value))) {
AST::Call_t* c = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Expr_t>(x.m_body[i])->m_value);

if ( !AST::is_a<AST::Attribute_t>(*(c->m_func))
|| !AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Attribute_t>(c->m_func)->m_value)) ) {
body.push_back(al, x.m_body[i]);
continue;
}
AST::Call_t* super_call = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Attribute_t>(c->m_func)->m_value);
std::string attr = AST::down_cast<AST::Attribute_t>(c->m_func)->m_attr;
if ( AST::is_a<AST::Name_t>(*(super_call->m_func)) &&
std::string(AST::down_cast<AST::Name_t>(super_call->m_func)->m_id)=="super" &&
attr == "__init__") {
if (parent_sym == nullptr) {
throw SemanticError("The class doesn't have a base class",loc);
}
Vec<ASR::call_arg_t> args;
args.reserve(al, 1);
parse_args(*super_call,args);
ASR::call_arg_t first_arg;
first_arg.loc = loc;
ASR::symbol_t* self_sym = current_scope->get_symbol("self");
first_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al,loc,self_sym));
ASR::ttype_t* target_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al,loc,parent_sym));
cast_helper(target_type, first_arg.m_value, x.base.base.loc, true);
Vec<ASR::call_arg_t> args_w_first; args_w_first.reserve(al,1);
args_w_first.push_back(al, first_arg);
for( size_t i = 0; i < args.size(); i++ ) {
args_w_first.push_back(al,args[i]);
}
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(parent_sym,call_name,loc);
super_call_stmt = ASRUtils::STMT(
ASR::make_SubroutineCall_t(al, loc, call_sym, call_sym, args_w_first.p,
args_w_first.size(), nullptr));
}
} else {
body.push_back(al, x.m_body[i]);
}
}
current_scope = f->m_symtab;
Vec<ASR::stmt_t*> body;
body.reserve(al, x.n_body);
Vec<ASR::stmt_t*> body_asr;
body_asr.reserve(al, x.n_body);
if ( super_call_stmt ) {
body_asr.push_back(al, super_call_stmt);
}
Vec<ASR::symbol_t*> rts;
rts.reserve(al, 4);
dependencies.clear(al);
transform_stmts(body, new_body.n, new_body.p);
transform_stmts(body_asr, body.n, body.p);
for (const auto &rt: rt_vec) { rts.push_back(al, rt); }
f->m_body = body.p;
f->n_body = body.size();
f->m_body = body_asr.p;
f->n_body = body_asr.size();
ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(
f->m_function_signature);
func_type->m_restrictions = rts.p;
Expand Down Expand Up @@ -6239,10 +6312,14 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
for( size_t i = 0; i < der_type->n_members && !member_found; i++ ) {
member_found = std::string(der_type->m_members[i]) == member_name;
}
if( !member_found ) {
if( !member_found && !der_type->m_parent ) {
throw SemanticError("No member " + member_name +
" found in " + std::string(der_type->m_name),
loc);
} else if ( !member_found && der_type->m_parent ) {
ASR::ttype_t* parent_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc,der_type->m_parent));
visit_AttributeUtil(parent_type,attr_char,t,loc);
return;
}
ASR::expr_t *val = ASR::down_cast<ASR::expr_t>(ASR::make_Var_t(al, loc, t));
ASR::symbol_t* member_sym = der_type->m_symtab->resolve_symbol(member_name);
Expand Down Expand Up @@ -8064,7 +8141,8 @@ we will have to use something else.
//TODO: Correct Class and ClassType
// call to struct member function
// modifying args to pass the object as self
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
ASR::symbol_t* der_sym = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
ASR::Struct_t* der = ASR::down_cast<ASR::Struct_t>(der_sym);
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, args.n + 1);
ASR::call_arg_t self_arg;
self_arg.loc = args[0].loc;
Expand All @@ -8073,7 +8151,20 @@ we will have to use something else.
for (size_t i=0; i<args.n; i++) {
new_args.push_back(al, args[i]);
}
st = get_struct_member(der, call_name, loc);
if ( der->m_symtab->get_symbol(call_name) ) {
st = get_struct_member(der_sym, call_name, loc);
} else if ( der->m_parent ) {
ASR::Struct_t* parent = ASR::down_cast<ASR::Struct_t>(der->m_parent);
if ( !parent->m_symtab->get_symbol(call_name) ) {
throw SemanticError("Method not found in the class "+ std::string(der->m_name) +
" or it's parents",loc);
} else {
st = get_struct_member(der->m_parent, call_name, loc);
}
} else {
throw SemanticError("Method not found in the class "+std::string(der->m_name)+
" or it's parents",loc);
}
tmp = make_call_helper(al, st, current_scope, new_args, call_name, loc);
return;
} else {
Expand Down

0 comments on commit 009bbe6

Please sign in to comment.