Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type inference for enums and match expr #183

Merged
merged 12 commits into from
Aug 21, 2024
124 changes: 79 additions & 45 deletions src/main/grammars/MoveParser.bnf
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@
tokenTypeClass="org.move.lang.core.MvTokenType"

extends(".*Expr")=Expr
extends("(Struct|Binding|Tuple|Wild|EnumVariant)Pat")=Pat
// extends(".*Pat")=Pat
extends("Pat(Struct|Binding|Tuple|Wild|Const)")=Pat
extends("(Lambda|Ref|Path|Tuple|Unit|Parens)Type")=Type

elementType(".+BinExpr")=BinaryExpr
elementType(".+BinOp")=BinaryOp

name(".+BinOp")="operator"
name(".*Expr")="expression"
name(".*Pat")="pattern"
name("Pat*")="pattern"

extraRoot(".*CodeFragmentElement")=true

Expand Down Expand Up @@ -172,6 +171,7 @@ AttrItem ::= IDENTIFIER (AttrItemList | AttrItemInitializer)?
{
implements = [
"org.move.lang.core.resolve.ref.MvReferenceElement"
// "org.move.lang.core.types.infer.MvInferenceContextOwner"
]
mixin = "org.move.lang.core.psi.ext.MvAttrItemMixin"
}
Expand Down Expand Up @@ -369,9 +369,9 @@ fake Struct ::= Attr* native? STRUCT_KW IDENTIFIER? TypeParameterList? Abilities
(';' | BlockFields)?
{
implements = [
"org.move.lang.core.psi.MvQualNamedElement"
"org.move.lang.core.psi.MvTypeParametersOwner"
"org.move.lang.core.psi.ext.MvItemElement"
// "org.move.lang.core.psi.MvQualNamedElement"
// "org.move.lang.core.psi.MvTypeParametersOwner"
"org.move.lang.core.psi.ext.MvStructOrEnumItemElement"
"org.move.lang.core.psi.ext.MvFieldsOwner"
]
mixin = "org.move.lang.core.psi.ext.MvStructMixin"
Expand Down Expand Up @@ -400,8 +400,8 @@ Enum ::= Attr* enum IDENTIFIER TypeParameterList? AbilitiesList? EnumBody
{
pin = "enum"
implements = [
"org.move.lang.core.psi.MvTypeParametersOwner"
"org.move.lang.core.psi.ext.MvItemElement"
// "org.move.lang.core.psi.MvTypeParametersOwner"
"org.move.lang.core.psi.ext.MvStructOrEnumItemElement"
]
mixin = "org.move.lang.core.psi.ext.MvEnumMixin"
stubClass = "org.move.lang.core.stubs.MvEnumStub"
Expand Down Expand Up @@ -494,7 +494,7 @@ private FunctionParameter_with_recover ::= !(')' | '{' | ';') FunctionParameter
}
private FunctionParameter_recover ::= !(')' | '{' | ';' | IDENTIFIER)

FunctionParameter ::= BindingPat TypeAnnotation {
FunctionParameter ::= PatBinding TypeAnnotation {
pin = 1
implements = [
"org.move.lang.core.psi.MvTypeAnnotationOwner"
Expand Down Expand Up @@ -672,16 +672,29 @@ TypeArgument ::= Type
///////////////////////////////////////////////////////////////////////////////////////////////////
// Patterns (destructuring)
///////////////////////////////////////////////////////////////////////////////////////////////////
Pat ::= TuplePat
| StructPat
| WildPat
| BindingPat
//Pat ::= TuplePat
// | StructPat
// | WildPat
// | BindingPat
// | ConstPat

Pat ::= PatWild
| PatTuple
| PatBinding
| PatStruct
| PatConst

MatchPat ::= Pat | PathPat
PathPat ::= PathImpl
EnumVariantPat ::= PathImpl

PatWild ::= '_'
WildPat ::= '_'
EnumVariantPat ::= PathImpl

ConstPat ::= PathImpl
PatConst ::= PathExpr

PatIdent ::= PatBinding


// XXX(matklad): it is impossible to distinguish between nullary enum variants
Expand All @@ -691,45 +704,56 @@ EnumVariantPat ::= PathImpl
// None => { } // match enum variant
// Name => { } // bind Name to x
// }
BindingPat ::= IDENTIFIER !ForbiddenBindingPatLast {
//PatIdent ::= PatBinding
PatBinding ::= IDENTIFIER !ForbiddenPatBindingLast {
implements = [
"org.move.lang.core.psi.MvMandatoryNameIdentifierOwner"
"org.move.lang.core.resolve.ref.MvMandatoryReferenceElement"
]
mixin = "org.move.lang.core.psi.ext.MvBindingPatMixin"
mixin = "org.move.lang.core.psi.ext.MvPatBindingMixin"
}
BindingPat ::= IDENTIFIER !ForbiddenPatBindingLast

//private ForbiddenBindingPatLast ::= '...' | '::' | '..=' | '..' | '<' | '(' | '{' | '!' {
private ForbiddenBindingPatLast ::= '::' | '<' | '(' | '{' | '!' {
private ForbiddenPatBindingLast ::= '::' | '<' | '(' | '{' | '!' {
consumeTokenMethod = "consumeTokenFast"
}

TuplePat ::= '(' ParenListElemPat_with_recover* ')'
PatTuple ::= '(' ParenListElemPat_with_recover* ')'
private ParenListElemPat_with_recover ::= !')' Pat (',' | &')') {
pin = 1
recoverWhile = ParenListElemPat_recover
}
private ParenListElemPat_recover ::= !(')' | Pat_first)

StructPat ::= PathImpl '{' FieldPat_with_recover* '}'
StructPat ::= PathImpl '{' PatField_with_recover* '}'
PatStruct ::= PathImpl '{' PatField_with_recover* '}'
{
implements = [ "org.move.lang.core.resolve2.ref.InferenceCachedPathElement" ]
}
//StructPatFieldsBlock ::= '{' FieldPat_with_recover* '}'

private FieldPat_with_recover ::= !'}' FieldPat (',' | &'}')
private PatField_with_recover ::= !'}' PatField (',' | &'}')
{
pin = 1
recoverWhile = FieldPat_recover
recoverWhile = PatField_recover

}
private FieldPat_recover ::= !('}' | IDENTIFIER)
private PatField_recover ::= !('}' | IDENTIFIER)

FieldPatFull ::= IDENTIFIER ':' Pat
PatFieldFull ::= IDENTIFIER ':' Pat
{
implements = [
// "org.move.lang.core.resolve.ref.MvStructPatFieldReferenceElement"
"org.move.lang.core.resolve.ref.MvMandatoryReferenceElement"
]
mixin = "org.move.lang.core.psi.ext.MvFieldPatFullMixin"
mixin = "org.move.lang.core.psi.ext.MvPatFieldFullMixin"
}
FieldPat ::= FieldPatFull | BindingPat
FieldPatFull ::= IDENTIFIER ':' Pat

PatField ::= PatFieldFull | PatBinding
FieldPat ::= PatFieldFull | PatBinding

//FieldPat ::= (BindingPat !':') | (IDENTIFIER FieldPatBinding)
//{
Expand Down Expand Up @@ -878,8 +902,11 @@ private AtomExpr ::=
| VectorLitExpr
| DotExpr
| IndexExpr
| (CallExpr | AssertBangExpr)
| RefExpr
| PathExpr
| CallExpr
| AssertMacroExpr
// | (CallExpr | AssertExpr)
// | RefExpr
| LambdaExpr
| LitExpr
| CodeBlockExpr
Expand Down Expand Up @@ -962,15 +989,15 @@ AbortExpr ::= abort Expr
BreakExpr ::= break
ContinueExpr ::= continue

VectorLitExpr ::= <<VECTOR_IDENTIFIER>> ('<' TypeArgument '>')? VectorLitItems
VectorLitExpr ::= <<vectorIdent>> ('<' TypeArgument '>')? VectorLitItems
VectorLitItems ::= '[' <<non_empty_comma_sep_items Expr>>? ']'
{
pin = 1
}

StructLitExpr ::= <<includeStmtModeFalse>> PathImpl StructLitFieldsBlock
{
implements = ["org.move.lang.core.psi.PathExpr"]
implements = ["org.move.lang.core.resolve2.ref.InferenceCachedPathElement"]
}
StructLitFieldsBlock ::= '{' StructLitField_with_recover* '}' { pin = 1 }

Expand Down Expand Up @@ -1008,7 +1035,7 @@ upper TupleLitExprUpper ::= ',' [ Expr (',' Expr)* ','? ] ')' {
elementType = TupleLitExpr
}

LambdaExpr ::= '|' <<non_empty_comma_sep_items BindingPat>> '|' Expr { pin = 1 }
LambdaExpr ::= '|' <<non_empty_comma_sep_items PatBinding>> '|' Expr { pin = 1 }

RangeExpr ::= Expr dotdot Expr { pin = 2 }

Expand All @@ -1030,12 +1057,12 @@ private AnyLitToken_first ::= HEX_INTEGER_LITERAL

AddressLit ::= '@' AddressRef { pin = 1 }

CallExpr ::= (PathImpl &'(') ValueArgumentList {
pin = 1
CallExpr ::= PathImpl !'!' ValueArgumentList {
// pin = 2
implements = [
"org.move.lang.core.psi.PathExpr"
"org.move.lang.core.psi.ext.MvCallable"
"org.move.lang.core.psi.MvAcquireTypesOwner"
"org.move.lang.core.resolve2.ref.InferenceCachedPathElement"
]
}
ValueArgumentList ::= '(' ValueArgumentList_items? ')' {
Expand All @@ -1056,7 +1083,7 @@ MatchExpr ::= Attr* <<remapContextualKwOnRollback (match MatchArgument MatchBody
MatchArgument ::= '(' Expr ')'

MatchBody ::= '{' MatchArm_with_recover* '}' { pin = 1 }
MatchArm ::= Attr* (Pat | EnumVariantPat) MatchArmGuard? '=>' Expr ','?
MatchArm ::= Attr* Pat MatchArmGuard? '=>' Expr ','?
{
pin = 2
implements = [ "org.move.lang.core.psi.ext.MvDocAndAttributeOwner" ]
Expand Down Expand Up @@ -1090,7 +1117,7 @@ ForExpr ::= for ForIterCondition AnyBlock {
pin = 1
implements = [ "org.move.lang.core.psi.MvLoopLike" ]
}
ForIterCondition ::= '(' BindingPat in Expr ')' {
ForIterCondition ::= '(' PatBinding in Expr ')' {
pin = 1
}

Expand All @@ -1106,7 +1133,7 @@ private DotExpr_inner ::= '.' !('.' | VectorStart) (MethodCall | StructDotField)
consumeTokenMethod = "consumeTokenFast"
}

private VectorStart ::= (<<VECTOR_IDENTIFIER>> ('[' | '<'))
private VectorStart ::= (<<vectorIdent>> ('[' | '<'))

StructDotField ::= IDENTIFIER !('(' | '::' | '!' | '{')
{
Expand All @@ -1131,10 +1158,10 @@ IndexExpr ::= Expr IndexArg
// Do not inline this rule, it breaks expression parsing
private IndexArg ::= '[' Expr ']'

RefExpr ::= PathImpl !'{' {
//RefExpr ::= PathImpl !'{' {
//RefExpr ::= Path !'{' {
implements = ["org.move.lang.core.psi.PathExpr"]
}
// implements = ["org.move.lang.core.psi.PathExpr"]
//}

//Path3Impl ::= (ModulePathIdent | FQModulePathIdent | LocalPathIdent) TypeArgumentList?
//{
Expand All @@ -1145,6 +1172,12 @@ RefExpr ::= PathImpl !'{' {
// mixin = "org.move.lang.core.psi.ext.MvPathMixin"
//}

PathExpr ::= PathImpl !('(' | '!')
{
implements = [ "org.move.lang.core.resolve2.ref.InferenceCachedPathElement" ]
}
RefExpr ::= PathImpl

fake Path ::= (Path '::')? (PathIdent | PathAddress) TypeArgumentList?
{
implements = [
Expand Down Expand Up @@ -1200,10 +1233,11 @@ NamedAddress ::= IDENTIFIER
}

/// Macros
AssertBangExpr ::= <<ASSERT_IDENTIFIER>> '!' ValueArgumentList {
pin = 2
AssertMacroExpr ::= <<assertIdent>> '!' ValueArgumentList {
// pin = 2
implements = [ "org.move.lang.core.psi.ext.MvCallable" ]
}
AssertBangExpr ::= <<assertIdent>> '!' ValueArgumentList
//AssertBangExpr ::= MacroIdent ValueArgumentList { pin = 1 }
//MacroIdent ::= IDENTIFIER '!'

Expand Down Expand Up @@ -1378,7 +1412,7 @@ private SpecStmt ::= UseStmt
| LetMslStmt
| SpecExprStmt

fake SchemaFieldStmt ::= local? BindingPat TypeAnnotation ';'
fake SchemaFieldStmt ::= local? PatBinding TypeAnnotation ';'
{
implements = [
"org.move.lang.core.psi.MvTypeAnnotationOwner"
Expand All @@ -1388,11 +1422,11 @@ fake SchemaFieldStmt ::= local? BindingPat TypeAnnotation ';'

private SchemaFieldStmtImpl ::= SchemaFieldStmt_local | SchemaFieldStmt_simple

SchemaFieldStmt_simple ::= BindingPat TypeAnnotation ';' {
SchemaFieldStmt_simple ::= PatBinding TypeAnnotation ';' {
pin = 2
elementType = SchemaFieldStmt
}
SchemaFieldStmt_local ::= local BindingPat TypeAnnotation ';' {
SchemaFieldStmt_local ::= local PatBinding TypeAnnotation ';' {
pin = 1
elementType = SchemaFieldStmt
}
Expand Down Expand Up @@ -1652,11 +1686,11 @@ QuantBinding ::= RangeQuantBinding | TypeQuantBinding
implements = [ "org.move.lang.core.psi.MslOnlyElement" ]
}

RangeQuantBinding ::= BindingPat in Expr {
RangeQuantBinding ::= PatBinding in Expr {
pin = 2
extends = QuantBinding
}
TypeQuantBinding ::= BindingPat ':' Type {
TypeQuantBinding ::= PatBinding ':' Type {
pin = 2
extends = QuantBinding
}
Expand Down
4 changes: 1 addition & 3 deletions src/main/kotlin/org/move/cli/MoveProject.kt
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@ data class MoveProject(
val dirScope = GlobalSearchScopes.directoryScope(project, folder, true)
searchScope = searchScope.uniteWith(dirScope)
}
if (isUnitTestMode
&& searchScope == GlobalSearchScope.EMPTY_SCOPE
) {
if (isUnitTestMode && searchScope == GlobalSearchScope.EMPTY_SCOPE) {
// add current file to the search scope for the tests
val currentFile =
FileEditorManager.getInstance(project).selectedTextEditor?.virtualFile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ abstract class CommandConfigurationHandler {
transaction.typeParams[name] = value
}

val parameterBindings = getFunctionParameters(function).map { it.bindingPat }
val parameterBindings = getFunctionParameters(function).map { it.patBinding }
val inference = function.inference(false)
for ((binding, valueWithType) in parameterBindings.zip(callArgs.args)) {
val name = binding.name
val value = valueWithType.split(':')[1]
val ty = inference.getPatType(binding)
val ty = inference.getBindingType(binding)
transaction.valueParams[name] = FunctionCallParam(value, FunctionCallParam.tyTypeName(ty))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ class FunctionCallConfigurationEditor<T: FunctionCallConfigurationBase>(
functionParametersPanel.updateFromFunctionCall(FunctionCall.empty())

// refill completion variants
val completionVariants = commandHandler.getFunctionCompletionVariants(moveProject)
this.functionItemField.setVariants(completionVariants)
// val completionVariants = commandHandler.getFunctionCompletionVariants(moveProject)
// this.functionItemField.setVariants(completionVariants)

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ class FunctionParametersPanel(
val outerPanel = this
return panel {
val typeParameters = function?.typeParameters.orEmpty()
val parameters = function
?.let { commandHandler.getFunctionParameters(function).map { it.bindingPat } }
val parameterBindings = function
?.let { commandHandler.getFunctionParameters(function).map { it.patBinding } }
.orEmpty()

if (typeParameters.isNotEmpty()) {
Expand Down Expand Up @@ -141,12 +141,12 @@ class FunctionParametersPanel(
}
}
}
if (parameters.isNotEmpty()) {
if (parameterBindings.isNotEmpty()) {
val msl = false
val inference = function!!.inference(msl)
for (parameter in parameters) {
val paramName = parameter.name
val paramTy = inference.getPatType(parameter)
for (parameterBinding in parameterBindings) {
val paramName = parameterBinding.name
val paramTy = inference.getBindingType(parameterBinding)
val paramTyName = FunctionCallParam.tyTypeName(paramTy)
row(paramName) {
comment(": $paramTyName")
Expand Down
Loading
Loading