Skip to content

Commit

Permalink
Dyno: Improve initialization of fields from non-matching types (#26219)
Browse files Browse the repository at this point in the history
Prior to this PR InitResolver would always use the RHS type of field
initialization statements to set the type of the field. For example, in
"this.foo = RHS;", the field 'foo' would always use the type of RHS when
resolving any relevant initialization calls (e.g. assignment). A
real-world example is initializing a nilable class from ``nil``.

This PR uses ``Resolver::getTypeForDecl`` to check that the types are
compatible, and then uses the computed type for 'foo'.

This PR also includes various other improvements needed to get tests
working:
- correctly identifying field initialization during call-init-deinit
(thanks to @brandon-neth)
- fix faulty "invalid class type construction" error when not dealing
with managed classes
- Enable comparisons between types and ``?``
- Fix for faulty ``isCoercible`` error (mistakenly attempted to resolve
as a method)
- Update ``isRecordLike`` in call-init-deinit to return false for
any-managed
- Improve syntactic detection of genericity for managed and nilable
classes (thanks to @DanilaFe)

Thusly, new tests are added to demonstrate correct initialization of
generic fields when initialized from a type that does not match.

[reviewed-by @DanilaFe]
  • Loading branch information
benharsh authored Nov 6, 2024
2 parents 4fd9b8d + cdafcb8 commit 513a56b
Show file tree
Hide file tree
Showing 9 changed files with 335 additions and 52 deletions.
5 changes: 5 additions & 0 deletions frontend/include/chpl/parsing/parsing-queries.h
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,11 @@ uast::Module::Kind idToModuleKind(Context* context, ID id);
*/
bool isSpecialMethodName(UniqueString name);

/*
Given a function call, determine if it is a call to a class manager.
*/
bool isCallToClassManager(const uast::FnCall* call);

} // end namespace parsing
} // end namespace chpl
#endif
12 changes: 12 additions & 0 deletions frontend/lib/parsing/parsing-queries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2065,5 +2065,17 @@ bool isSpecialMethodName(UniqueString name) {
}
}

bool isCallToClassManager(const uast::FnCall* call) {
if (auto ident = call->calledExpression()->toIdentifier()) {
auto name = ident->name();
return name == USTR("owned") || name == USTR("_owned") ||
name == USTR("shared") || name == USTR("_shared") ||
name == USTR("unmanaged") ||
name == USTR("borrowed");
}

return false;
}

} // end namespace parsing
} // end namespace chpl
30 changes: 25 additions & 5 deletions frontend/lib/resolution/InitResolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,14 +821,30 @@ bool InitResolver::handleAssignmentToField(const OpCall* node) {

// TODO: Anything to do if the opposite is true?
if (!isAlreadyInitialized) {
// Recompute field type in case it depends on a recently-instantiated
// field. For example, ``var curField : typeField;``.
auto rf = resolveFieldDecl(ctx_, currentRecvType_->getCompositeType(), fieldId, DefaultsPolicy::IGNORE_DEFAULTS);
QualifiedType initialFieldType;
for (int i = 0; i < rf.numFields(); i++) {
auto id = rf.fieldDeclId(i);
if (id == fieldId) {
initialFieldType = rf.fieldType(i);
}
}

auto rhsType = initResolver_.byPostorder.byAst(rhs).type();
auto adjusted = QualifiedType(QualifiedType::TYPE, initialFieldType.type());
// TODO: prevent 'getTypeForDecl' from issuing the error message, and
// instead do something field-specific.
auto computed = initResolver_.getTypeForDecl(node,
lhs, rhs, state->qt.kind(),
adjusted, rhsType);

auto param = state->qt.kind() == QualifiedType::PARAM ? rhsType.param() : nullptr;
auto qt = QualifiedType(state->qt.kind(), rhsType.type(), param);
state->qt = qt;
state->qt = computed;

state->initPointId = node->id();
state->isInitialized = true;
initPoints.insert(node);

// We could probably get away with running this less, but it's easier
// to just attempt updating the receiver type for each field even if the
Expand All @@ -838,9 +854,9 @@ bool InitResolver::handleAssignmentToField(const OpCall* node) {
auto lhsKind = state->qt.kind();
if (lhsKind != QualifiedType::TYPE && lhsKind != QualifiedType::PARAM) {
// Regardless of the field's intent, it is mutable in this expression.
lhsKind = QualifiedType::REF;
lhsKind = QualifiedType::VAR;
}
auto lhsType = QualifiedType(lhsKind, state->qt.type(), state->qt.param());
auto lhsType = QualifiedType(lhsKind, computed.type(), computed.param());
initResolver_.byPostorder.byAst(lhs).setType(lhsType);

} else {
Expand Down Expand Up @@ -956,5 +972,9 @@ void InitResolver::checkEarlyReturn(const Return* ret) {
}
}

bool InitResolver::isInitPoint(const uast::AstNode* node) {
return initPoints.find(node) != initPoints.end();
}

} // end namespace resolution
} // end namespace chpl
6 changes: 6 additions & 0 deletions frontend/lib/resolution/InitResolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ class InitResolver {
// Stores field ID and ID of the uAST referencing the field.
std::vector<std::pair<ID, ID>> useOfSuperFields_;

//initialization points to guide handling `=` operators
std::set<const uast::AstNode*> initPoints;

InitResolver(Context* ctx, Resolver& visitor,
const uast::Function* fn,
const types::Type* recvType)
Expand Down Expand Up @@ -155,6 +158,9 @@ class InitResolver {
const TypedFnSignature* finalize(void);

void checkEarlyReturn(const uast::Return* ret);

// Returns true if the AST node is an initialization point
bool isInitPoint(const uast::AstNode* node);
};

} // end namespace resolution
Expand Down
11 changes: 1 addition & 10 deletions frontend/lib/resolution/Resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -951,15 +951,6 @@ static void varArgTypeQueryError(Context* context,
result.setType(errType);
}

static bool isCallToClassManager(const FnCall* call) {
auto ident = call->calledExpression()->toIdentifier();
if (!ident) return false;
auto name = ident->name();
return name == USTR("owned") || name == USTR("shared") ||
name == USTR("_owned") || name == USTR("_shared") ||
name == USTR("unmanaged") || name == USTR("borrowed");
}

static std::vector<const TypeQuery*>
collectTypeQueriesIn(const AstNode* ast, bool recurse=true) {
std::vector<const TypeQuery*> ret;
Expand Down Expand Up @@ -1048,7 +1039,7 @@ void Resolver::resolveTypeQueries(const AstNode* formalTypeExpr,
}
}
}
} else if (isCallToClassManager(call) &&
} else if (parsing::isCallToClassManager(call) &&
call->numActuals() == 1 &&
actualTypePtr->isClassType()) {
// Strip the owned/shared/etc. for both the formal and the type
Expand Down
8 changes: 6 additions & 2 deletions frontend/lib/resolution/call-init-deinit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ static bool isRecordLike(const Type* t) {
// no action needed for 'borrowed' or 'unmanaged'
// (these should just default initialized to 'nil',
// so nothing else needs to be resolved)
if (! (decorator.isBorrowed() || decorator.isUnmanaged())) {
if (! (decorator.isBorrowed() || decorator.isUnmanaged() ||
decorator.isUnknownManagement())) {
return true;
}
} else if (t->isRecordType() || t->isUnionType()) {
Expand Down Expand Up @@ -900,10 +901,13 @@ void CallInitDeinit::handleAssign(const OpCall* ast, RV& rv) {
// check for use of deinited variables
processMentions(ast, rv);

bool isInit = splitInited;
isInit |= resolver.initResolver && resolver.initResolver->isInitPoint(ast);

if (lhsType.isType() || lhsType.isParam()) {
// these are basically 'move' initialization
resolveMoveInit(ast, rhsAst, lhsType, rhsType, rv);
} else if (splitInited) {
} else if (isInit) {
processInit(frame, ast, lhsType, rhsType, rv);
} else {
// it is assignment, so resolve the '=' call
Expand Down
110 changes: 77 additions & 33 deletions frontend/lib/resolution/resolution-queries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,38 @@ Type::Genericity getTypeGenericity(Context* context, QualifiedType qt) {
return getTypeGenericityViaQualifiedTypeQuery(context, qt);
}

static bool callHasQuestionMark(const FnCall* call) {
for (auto actual : call->actuals()) {
if (auto ident = actual->toIdentifier()) {
if (ident->name() == "?") {
return true;
}
}
}

return false;
}

static const FnCall* unwrapClassCall(const FnCall* call) {
const Call* unwrapped = call;

if (parsing::isCallToClassManager(call)) {
if (call->numActuals() == 1) {
unwrapped = call->actual(0)->toCall();
}
}

if (unwrapped) {
if (auto opCall = unwrapped->toOpCall()) {
if (opCall->numActuals() == 1 && opCall->op() == "?") {
unwrapped = opCall->actual(0)->toFnCall();
}
}
}

return unwrapped ? unwrapped->toFnCall() : nullptr;
}

/**
Written primarily to support multi-decls, though the logic is the same
as for single declarations. Sets 'outIsGeneric' with the genericity of the
Expand Down Expand Up @@ -1396,13 +1428,10 @@ static bool isVariableDeclWithClearGenericity(Context* context,
outIsGeneric = isNameBuiltinGenericType(context, ident->name());
return true;
} else if (auto call = var->typeExpression()->toFnCall()) {
for (auto actual : call->actuals()) {
if (auto ident = actual->toIdentifier()) {
if (ident->name() == "?") {
outIsGeneric = true;
return true;
}
}
auto unwrapped = unwrapClassCall(call);
if (unwrapped && callHasQuestionMark(unwrapped)) {
outIsGeneric = true;
return true;
}
}

Expand Down Expand Up @@ -3225,21 +3254,27 @@ static const Type* getManagedClassType(Context* context,
UniqueString name = ci.name();

if (ci.hasQuestionArg()) {
if (ci.numActuals() != 0) {
context->error(astForErr, "invalid class type construction");
return ErroneousType::get(context);
} else if (name == USTR("owned")) {
return AnyOwnedType::get(context);

const Type* ret = nullptr;
if (name == USTR("owned")) {
ret = AnyOwnedType::get(context);
} else if (name == USTR("shared")) {
return AnySharedType::get(context);
ret = AnySharedType::get(context);
} else if (name == USTR("unmanaged")) {
return ClassType::get(context, AnyClassType::get(context), nullptr, ClassTypeDecorator(ClassTypeDecorator::UNMANAGED));
ret = ClassType::get(context, AnyClassType::get(context), nullptr, ClassTypeDecorator(ClassTypeDecorator::UNMANAGED));
} else if (name == USTR("borrowed")) {
return ClassType::get(context, AnyClassType::get(context), nullptr, ClassTypeDecorator(ClassTypeDecorator::BORROWED));
ret = ClassType::get(context, AnyClassType::get(context), nullptr, ClassTypeDecorator(ClassTypeDecorator::BORROWED));
} else {
// case not handled in here
return nullptr;
}

if (ret != nullptr && ci.numActuals() != 0) {
context->error(astForErr, "invalid class type construction");
return ErroneousType::get(context);
} else {
return ret;
}
}

ClassTypeDecorator::ClassTypeDecoratorEnum de;
Expand Down Expand Up @@ -3717,21 +3752,26 @@ static bool resolveFnCallSpecial(Context* context,
}
}

if ((ci.name() == USTR("==") || ci.name() == USTR("!=")) &&
ci.numActuals() == 2) {
auto lhs = ci.actual(0).type();
auto rhs = ci.actual(1).type();
if ((ci.name() == USTR("==") || ci.name() == USTR("!="))) {
if (ci.numActuals() == 2 || ci.hasQuestionArg()) {
auto lhs = ci.actual(0).type();

bool bothType = lhs.kind() == QualifiedType::TYPE &&
rhs.kind() == QualifiedType::TYPE;
bool bothParam = lhs.kind() == QualifiedType::PARAM &&
rhs.kind() == QualifiedType::PARAM;
if (bothType || bothParam) {
bool result = lhs == rhs;
result = ci.name() == USTR("==") ? result : !result;
exprTypeOut = QualifiedType(QualifiedType::PARAM, BoolType::get(context),
BoolParam::get(context, result));
return true;
// support comparisions with '?'
auto rhs = ci.hasQuestionArg() ?
QualifiedType(QualifiedType::TYPE, AnyType::get(context)) :
ci.actual(1).type();

bool bothType = lhs.kind() == QualifiedType::TYPE &&
rhs.kind() == QualifiedType::TYPE;
bool bothParam = lhs.kind() == QualifiedType::PARAM &&
rhs.kind() == QualifiedType::PARAM;
if (bothType || bothParam) {
bool result = lhs == rhs;
result = ci.name() == USTR("==") ? result : !result;
exprTypeOut = QualifiedType(QualifiedType::PARAM, BoolType::get(context),
BoolParam::get(context, result));
return true;
}
}
}

Expand All @@ -3748,10 +3788,14 @@ static bool resolveFnCallSpecial(Context* context,

if (ci.name() == USTR("isCoercible")) {
if (ci.numActuals() != 2) {
context->error(astForErr, "bad call to %s", ci.name().c_str());
exprTypeOut = QualifiedType(QualifiedType::UNKNOWN,
ErroneousType::get(context));
return true;
if (!ci.isMethodCall()) {
context->error(astForErr, "bad call to %s", ci.name().c_str());
exprTypeOut = QualifiedType(QualifiedType::UNKNOWN,
ErroneousType::get(context));
return true;
} else {
return false;
}
}
auto got = canPass(context, ci.actual(0).type(), ci.actual(1).type());
bool result = got.passes();
Expand Down
Loading

0 comments on commit 513a56b

Please sign in to comment.