diff --git a/packages/zenscript/src/module.ts b/packages/zenscript/src/module.ts index f9ccc6a1..5610e228 100644 --- a/packages/zenscript/src/module.ts +++ b/packages/zenscript/src/module.ts @@ -14,6 +14,7 @@ import { ZenScriptMemberProvider } from './reference/member-provider' import { ZenScriptNameProvider } from './reference/name-provider' import { ZenScriptScopeComputation } from './reference/scope-computation' import { ZenScriptScopeProvider } from './reference/scope-provider' +import { ZenScriptOverloadResolver } from './typing/overload-resolver' import { ZenScriptTypeComputer } from './typing/type-computer' import { ZenScriptTypeFeatures } from './typing/type-features' import { registerValidationChecks, ZenScriptValidator } from './validation/validator' @@ -38,6 +39,7 @@ export interface ZenScriptAddedServices { typing: { TypeComputer: ZenScriptTypeComputer TypeFeatures: ZenScriptTypeFeatures + OverloadResolver: ZenScriptOverloadResolver } workspace: { PackageManager: ZenScriptPackageManager @@ -90,6 +92,7 @@ export const ZenScriptModule: Module new ZenScriptTypeComputer(services), TypeFeatures: services => new ZenScriptTypeFeatures(services), + OverloadResolver: services => new ZenScriptOverloadResolver(services), }, lsp: { CompletionProvider: services => new ZenScriptCompletionProvider(services), diff --git a/packages/zenscript/src/reference/member-provider.ts b/packages/zenscript/src/reference/member-provider.ts index 5da1e2fa..80fc6cde 100644 --- a/packages/zenscript/src/reference/member-provider.ts +++ b/packages/zenscript/src/reference/member-provider.ts @@ -4,7 +4,7 @@ import type { ZenScriptServices } from '../module' import type { TypeComputer } from '../typing/type-computer' import type { ZenScriptSyntheticAstType } from './synthetic' import { EMPTY_STREAM, stream } from 'langium' -import { isClassDeclaration, isOperatorFunctionDeclaration, isVariableDeclaration } from '../generated/ast' +import { isClassDeclaration, isOperatorFunctionDeclaration, isScript, isVariableDeclaration } from '../generated/ast' import { ClassType, isAnyType, isClassType, isFunctionType, type Type, type ZenScriptType } from '../typing/type-description' import { isStatic, streamClassChain, streamDeclaredMembers } from '../utils/ast' import { defineRules } from '../utils/rule' @@ -80,7 +80,7 @@ export class ZenScriptMemberProvider implements MemberProvider { return EMPTY_STREAM } - if (isSyntheticAstNode(target)) { + if (isSyntheticAstNode(target) || isScript(target)) { return this.streamMembers(target) } @@ -96,6 +96,21 @@ export class ZenScriptMemberProvider implements MemberProvider { return this.streamMembers(type) }, + ParenthesizedExpression: (source) => { + const type = this.typeComputer.inferType(source) + return this.streamMembers(type) + }, + + PrefixExpression: (source) => { + const type = this.typeComputer.inferType(source) + return this.streamMembers(type) + }, + + InfixExpression: (source) => { + const type = this.typeComputer.inferType(source) + return this.streamMembers(type) + }, + IndexingExpression: (source) => { const type = this.typeComputer.inferType(source) return this.streamMembers(type) diff --git a/packages/zenscript/src/reference/name-provider.ts b/packages/zenscript/src/reference/name-provider.ts index 928cfc1e..dc88e7ae 100644 --- a/packages/zenscript/src/reference/name-provider.ts +++ b/packages/zenscript/src/reference/name-provider.ts @@ -46,13 +46,13 @@ export class ZenScriptNameProvider extends DefaultNameProvider { Script: source => source.$document ? getName(source.$document) : undefined, ImportDeclaration: source => source.alias || source.path.at(-1)?.$refText, FunctionDeclaration: source => source.name || 'lambda function', - ConstructorDeclaration: _ => 'zenConstructor', + ConstructorDeclaration: source => source.$container.name, OperatorFunctionDeclaration: source => source.op, }) private readonly nameNodeRules = defineRules({ ImportDeclaration: source => GrammarUtils.findNodeForProperty(source.$cstNode, 'alias'), - ConstructorDeclaration: source => GrammarUtils.findNodeForProperty(source.$cstNode, 'zenConstructor'), + ConstructorDeclaration: source => GrammarUtils.findNodeForKeyword(source.$cstNode, 'zenConstructor'), OperatorFunctionDeclaration: source => GrammarUtils.findNodeForProperty(source.$cstNode, 'op'), }) } diff --git a/packages/zenscript/src/reference/scope-provider.ts b/packages/zenscript/src/reference/scope-provider.ts index 6e0e13ce..577c6fc8 100644 --- a/packages/zenscript/src/reference/scope-provider.ts +++ b/packages/zenscript/src/reference/scope-provider.ts @@ -1,13 +1,14 @@ -import type { AstNode, AstNodeDescription, ReferenceInfo, Scope, ScopeOptions, Stream } from 'langium' +import type { AstNode, AstNodeDescription, ReferenceInfo, Scope, ScopeOptions } from 'langium' import type { ZenScriptAstType } from '../generated/ast' import type { ZenScriptServices } from '../module' +import type { OverloadResolver } from '../typing/overload-resolver' import type { ZenScriptDescriptionIndex } from '../workspace/description-index' import type { PackageManager } from '../workspace/package-manager' import type { DynamicProvider } from './dynamic-provider' import type { MemberProvider } from './member-provider' import { substringBeforeLast } from '@intellizen/shared' import { AstUtils, DefaultScopeProvider, EMPTY_SCOPE, stream, StreamScope } from 'langium' -import { ClassDeclaration, ImportDeclaration, isClassDeclaration, TypeParameter } from '../generated/ast' +import { ClassDeclaration, ImportDeclaration, isCallExpression, isClassDeclaration, isConstructorDeclaration, isScript, TypeParameter } from '../generated/ast' import { getPathAsString } from '../utils/ast' import { defineRules } from '../utils/rule' import { generateStream } from '../utils/stream' @@ -20,6 +21,7 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { private readonly memberProvider: MemberProvider private readonly dynamicProvider: DynamicProvider private readonly descriptionIndex: ZenScriptDescriptionIndex + private readonly overloadResolver: OverloadResolver constructor(services: ZenScriptServices) { super(services) @@ -27,6 +29,7 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { this.memberProvider = services.references.MemberProvider this.dynamicProvider = services.references.DynamicProvider this.descriptionIndex = services.workspace.DescriptionIndex + this.overloadResolver = services.typing.OverloadResolver } override getScope(context: ReferenceInfo): Scope { @@ -68,8 +71,34 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { return this.createScopeForNodes(classes, outside) } - override createScopeForNodes(nodes: Stream, outerScope?: Scope, options?: ScopeOptions): Scope { - return new StreamScope(nodes.map(it => this.descriptionIndex.getDescription(it)), outerScope, options) + private importedScope(source: ReferenceInfo, outside?: Scope) { + const script = AstUtils.findRootNode(source.container) + if (!isScript(script)) { + return EMPTY_SCOPE + } + + const refText = source.reference.$refText + const imports = stream(script.imports) + .flatMap(it => this.descriptionIndex.createImportedDescriptions(it)) + + if (refText === '' || !isCallExpression(source.container.$container) || source.container.$containerProperty !== 'receiver') { + return this.createScope(imports, outside) + } + + // TODO: Workaround for function overloading, may rework after langium supports multi-target references + const maybeCandidates = imports + .filter(it => it.name === refText) + .map(it => it.node) + .nonNullable() + .toArray() + + const overloads = this.overloadResolver.resolveOverloads(source.container.$container, maybeCandidates) + const descriptions = overloads.map(it => this.descriptionIndex.createDynamicDescription(it, refText)) + return this.createScope(descriptions, outside) + } + + override createScopeForNodes(nodes: Iterable, outerScope?: Scope, options?: ScopeOptions): Scope { + return new StreamScope(stream(nodes).map(it => this.descriptionIndex.getDescription(it)), outerScope, options) } private readonly rules = defineRules({ @@ -97,6 +126,7 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { let outer: Scope outer = this.packageScope() outer = this.globalScope(outer) + outer = this.importedScope(source, outer) outer = this.dynamicScope(source.container, outer) const processor = (desc: AstNodeDescription) => { @@ -104,7 +134,19 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { case TypeParameter: return case ImportDeclaration: { - return this.descriptionIndex.createImportedDescription(desc.node as ImportDeclaration) + return + } + case ClassDeclaration: { + const classDecl = desc.node as ClassDeclaration + const callExpr = source.container.$container + if (isCallExpression(callExpr) && source.container.$containerProperty === 'receiver') { + const constructors = classDecl.members.filter(isConstructorDeclaration) + const overloads = this.overloadResolver.resolveOverloads(callExpr, constructors) + if (overloads[0]) { + return this.descriptionIndex.getDescription(overloads[0]) + } + } + return desc } default: return desc @@ -116,19 +158,28 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { MemberAccess: (source) => { const outer = this.dynamicScope(source.container) const members = this.memberProvider.streamMembers(source.container.receiver) - return this.createScopeForNodes(members, outer) + + if (source.reference.$refText && isCallExpression(source.container.$container) && source.container.$containerProperty === 'receiver') { + const maybeCandidates = members.filter(it => this.nameProvider.getName(it) === source.reference.$refText).toArray() + const overloads = this.overloadResolver.resolveOverloads(source.container.$container, maybeCandidates) + return this.createScopeForNodes(overloads, outer) + } + else { + return this.createScopeForNodes(members, outer) + } }, NamedTypeReference: (source) => { if (!source.index) { - const outer = this.classScope() + let outer = this.packageScope() + outer = this.classScope(outer) const processor = (desc: AstNodeDescription) => { switch (desc.type) { case TypeParameter: case ClassDeclaration: return desc case ImportDeclaration: { - return this.descriptionIndex.createImportedDescription(desc.node as ImportDeclaration) + return this.descriptionIndex.createImportedDescriptions(desc.node as ImportDeclaration)[0] } } } diff --git a/packages/zenscript/src/typing/overload-resolver.ts b/packages/zenscript/src/typing/overload-resolver.ts new file mode 100644 index 00000000..cc6c195c --- /dev/null +++ b/packages/zenscript/src/typing/overload-resolver.ts @@ -0,0 +1,183 @@ +import type { CallableDeclaration, CallExpression, Expression, ValueParameter } from '../generated/ast' +import type { ZenScriptServices } from '../module' +import type { TypeComputer } from './type-computer' +import type { TypeFeatures } from './type-features' +import { type AstNode, MultiMap } from 'langium' +import { isClassDeclaration, isConstructorDeclaration, isFunctionDeclaration } from '../generated/ast' + +export interface OverloadResolver { + resolveOverloads: (callExpr: CallExpression, maybeCandidates: AstNode[]) => AstNode[] +} + +export enum OverloadMatch { + ExactMatch, + VarargMatch, + OptionalArgMatch, + SubtypeMatch, + ImplicitCastMatch, + NotMatch, +} + +function worstMatch(matchSet: Set): OverloadMatch { + return Array.from(matchSet).sort((a, b) => a - b).at(-1) ?? OverloadMatch.NotMatch +} + +export class ZenScriptOverloadResolver implements OverloadResolver { + private readonly typeComputer: TypeComputer + private readonly typeFeatures: TypeFeatures + + constructor(services: ZenScriptServices) { + this.typeComputer = services.typing.TypeComputer + this.typeFeatures = services.typing.TypeFeatures + } + + public resolveOverloads(callExpr: CallExpression, maybeCandidates: AstNode[]): AstNode[] { + if (!maybeCandidates.length) { + return [] + } + + let candidates: CallableDeclaration[] + if (maybeCandidates.find(isClassDeclaration)) { + candidates = maybeCandidates.find(isClassDeclaration)!.members.filter(isConstructorDeclaration) + } + else if (maybeCandidates.find(isFunctionDeclaration)) { + candidates = maybeCandidates.filter(isFunctionDeclaration) + } + else if (maybeCandidates.find(isConstructorDeclaration)) { + candidates = maybeCandidates.filter(isConstructorDeclaration) + } + else { + console.error(`Invalid overload candidates for call expression: ${callExpr.$cstNode?.text}`) + return [] + } + + if (candidates.length === 1) { + return candidates + } + + const groupedCandidates = candidates.reduce>((map, it) => map.add(it.$container, it), new MultiMap()) + for (const container of groupedCandidates.keys()) { + const overloads = this.analyzeOverloads(new Set(groupedCandidates.get(container)), callExpr.arguments) + if (overloads.length) { + return overloads + } + else { + // FIXME: overloading error + // For debugging, consider adding a breakpoint here + console.error(`Could not resolve overloads for call expression: ${callExpr.$cstNode?.text}`) + } + } + + return candidates + } + + private analyzeOverloads(candidates: Set, args: Expression[]): CallableDeclaration[] { + const possibles = candidates.values() + .map(it => ({ candidate: it, match: this.matchSignature(it, args) })) + .filter(it => it.match !== OverloadMatch.NotMatch) + .toArray() + .sort((a, b) => a.match - b.match) + const groupedPossibles = Object.groupBy(possibles, it => it.match) + const bestMatches = Object.values(groupedPossibles).at(0) ?? [] + + if (bestMatches.length > 1) { + this.logAmbiguousOverloads(possibles, args) + } + + return bestMatches.map(it => it.candidate) + } + + private logAmbiguousOverloads(possibles: { candidate: CallableDeclaration, match: OverloadMatch }[], args: Expression[]) { + const first = possibles[0].candidate + const name = isConstructorDeclaration(first) ? first.$container.name : first.name + const types = args.map(it => this.typeComputer.inferType(it)?.toString()).join(', ') + console.warn(`ambiguous overload for ${name}(${types})`) + for (const { candidate, match } of possibles) { + const params = candidate.parameters + .map((it) => { + const str = this.typeComputer.inferType(it)?.toString() ?? 'undefined' + if (it.varargs) { + return `...${str}` + } + else if (it.defaultValue) { + return `${str}?` + } + else { + return str + } + }).join(', ') + console.warn(`----- ${OverloadMatch[match]} ${name}(${params})`) + } + } + + private createParamToArgsMap(params: ValueParameter[], args: Expression[]): MultiMap { + const map = new MultiMap() + for (let a = 0, p = 0, arg = args[a], param = params[p]; a < args.length && p < params.length;) { + if (arg) { + map.add(param, arg) + arg = args[++a] + } + if (!param.varargs) { + param = params[++p] + } + } + return map + } + + private matchSignature(callable: CallableDeclaration, args: Expression[]): OverloadMatch { + const params = [...callable.parameters] + const map = this.createParamToArgsMap(params, args) + + const matchSet = new Set([OverloadMatch.ExactMatch]) + if (args.length > map.size) { + matchSet.add(OverloadMatch.NotMatch) + } + else { + for (const param of params) { + const arg = map.get(param).at(0) + // special checking + if (param.varargs) { + matchSet.add(OverloadMatch.VarargMatch) + if (!arg) { + continue + } + } + else if (param.defaultValue) { + matchSet.add(OverloadMatch.OptionalArgMatch) + if (!arg) { + continue + } + } + else { + if (!arg) { + matchSet.add(OverloadMatch.NotMatch) + break + } + } + + // type checking + const paramType = this.typeComputer.inferType(param) + const argType = this.typeComputer.inferType(arg) + if (!paramType || !argType) { + matchSet.add(OverloadMatch.ImplicitCastMatch) + continue + } + + if (this.typeFeatures.areTypesEqual(paramType, argType)) { + matchSet.add(OverloadMatch.ExactMatch) + } + else if (this.typeFeatures.isSubType(argType, paramType)) { + matchSet.add(OverloadMatch.SubtypeMatch) + } + else if (this.typeFeatures.isConvertible(argType, paramType)) { + matchSet.add(OverloadMatch.ImplicitCastMatch) + } + else { + matchSet.add(OverloadMatch.NotMatch) + break + } + } + } + return worstMatch(matchSet) + } +} diff --git a/packages/zenscript/src/typing/type-computer.ts b/packages/zenscript/src/typing/type-computer.ts index 8d5325c6..d2673517 100644 --- a/packages/zenscript/src/typing/type-computer.ts +++ b/packages/zenscript/src/typing/type-computer.ts @@ -6,7 +6,7 @@ import type { BracketManager } from '../workspace/bracket-manager' import type { PackageManager } from '../workspace/package-manager' import type { BuiltinTypes, Type, TypeParameterSubstitutions } from './type-description' import { type AstNode, stream } from 'langium' -import { isAssignment, isCallExpression, isClassDeclaration, isExpression, isFunctionDeclaration, isFunctionExpression, isIndexingExpression, isOperatorFunctionDeclaration, isTypeParameter, isVariableDeclaration } from '../generated/ast' +import { isAssignment, isCallExpression, isClassDeclaration, isConstructorDeclaration, isExpression, isFunctionDeclaration, isFunctionExpression, isIndexingExpression, isMemberAccess, isOperatorFunctionDeclaration, isReferenceExpression, isTypeParameter, isVariableDeclaration } from '../generated/ast' import { defineRules } from '../utils/rule' import { ClassType, CompoundType, FunctionType, IntersectionType, isAnyType, isClassType, isFunctionType, TypeVariable, UnionType } from './type-description' @@ -182,10 +182,13 @@ export class ZenScriptTypeComputer implements TypeComputer { else if (isCallExpression(funcExpr.$container)) { const callArgIndex = funcExpr.$containerIndex! const receiverType = this.inferType(funcExpr.$container.receiver) - expectingType = isFunctionType(receiverType) ? receiverType.paramTypes.at(callArgIndex) : undefined + expectingType = isFunctionType(receiverType) ? receiverType.paramTypes.at(callArgIndex) : receiverType } - if (isFunctionType(expectingType)) { + if (isAnyType(expectingType)) { + return expectingType + } + else if (isFunctionType(expectingType)) { return expectingType.paramTypes.at(index) } else if (isClassType(expectingType)) { @@ -357,6 +360,12 @@ export class ZenScriptTypeComputer implements TypeComputer { MemberAccess: (source) => { const receiverType = this.inferType(source.receiver) + // Recursive Guard + const _ref = (source.target as any)._ref + if (typeof _ref === 'symbol' && _ref.description === 'ref_resolving') { + return this.classTypeOf('any') + } + const targetContainer = source.target.ref?.$container if (isOperatorFunctionDeclaration(targetContainer) && targetContainer.op === '.') { let dynamicTargetType = this.inferType(targetContainer.returnTypeRef) @@ -391,6 +400,15 @@ export class ZenScriptTypeComputer implements TypeComputer { }, CallExpression: (source) => { + if (isReferenceExpression(source.receiver) || isMemberAccess(source.receiver)) { + const receiverRef = source.receiver.target.ref + if (!receiverRef) { + return + } + if (isConstructorDeclaration(receiverRef)) { + return new ClassType(receiverRef.$container, new Map()) + } + } const receiverType = this.inferType(source.receiver) if (isFunctionType(receiverType)) { return receiverType.returnType diff --git a/packages/zenscript/src/typing/type-features.ts b/packages/zenscript/src/typing/type-features.ts index 1d94a21c..4955faba 100644 --- a/packages/zenscript/src/typing/type-features.ts +++ b/packages/zenscript/src/typing/type-features.ts @@ -2,7 +2,7 @@ import type { ZenScriptServices } from '../module' import type { MemberProvider } from '../reference/member-provider' import type { TypeComputer } from './type-computer' import type { Type, ZenScriptType } from './type-description' -import { isOperatorFunctionDeclaration } from '../generated/ast' +import { isFunctionDeclaration, isOperatorFunctionDeclaration } from '../generated/ast' import { streamClassChain } from '../utils/ast' import { defineRules } from '../utils/rule' import { isAnyType, isClassType, isCompoundType, isFunctionType, isIntersectionType, isTypeVariable, isUnionType } from './type-description' @@ -38,7 +38,11 @@ export class ZenScriptTypeFeatures implements TypeFeatures { this.memberProvider = services.references.MemberProvider } - isAssignable(target: Type, source: Type): boolean { + isAssignable(target: Type | undefined, source: Type | undefined): boolean { + if (target === undefined || source === undefined) { + return false + } + // 1. are both types equal? if (this.areTypesEqual(source, target)) { return true @@ -57,7 +61,11 @@ export class ZenScriptTypeFeatures implements TypeFeatures { return false } - areTypesEqual(first: Type, second: Type): boolean { + areTypesEqual(first: Type | undefined, second: Type | undefined): boolean { + if (first === undefined || second === undefined) { + return false + } + if (first === second) { return true } @@ -77,7 +85,21 @@ export class ZenScriptTypeFeatures implements TypeFeatures { private readonly typeEqualityRules = defineRules({ ClassType: (self, other) => { - return isClassType(other) && self.declaration === other.declaration + if (!isClassType(other)) { + return false + } + + if (self.declaration !== other.declaration) { + return false + } + + if (self.declaration.typeParameters.length !== other.declaration.typeParameters.length) { + return false + } + + const selfSubstitutions = self.declaration.typeParameters.map(it => self.substitutions.get(it)).filter(it => !!it) + const otherSubstitutions = other.declaration.typeParameters.map(it => other.substitutions.get(it)).filter(it => !!it) + return selfSubstitutions.every((type, index) => this.areTypesEqual(type, otherSubstitutions[index])) }, FunctionType: (self, other) => { @@ -109,7 +131,10 @@ export class ZenScriptTypeFeatures implements TypeFeatures { }, }) - isConvertible(from: Type, to: Type): boolean { + isConvertible(from: Type | undefined, to: Type | undefined): boolean { + if (from === undefined || to === undefined) { + return false + } return this.typeConversionRules(from.$type)?.call(this, from, to) ?? false } @@ -124,7 +149,33 @@ export class ZenScriptTypeFeatures implements TypeFeatures { .filter(it => it.op === 'as') .map(it => this.typeComputer.inferType(it.returnTypeRef)) .nonNullable() - .some(it => this.isAssignable(to, it)) + .some(it => this.isSubType(to, it)) + }, + + FunctionType: (from, to) => { + if (isAnyType(to)) { + return true + } + + let toFuncType: Type | undefined + if (isFunctionType(to)) { + toFuncType = to + } + else if (isClassType(to)) { + const lambdaDecl = this.memberProvider.streamMembers(to) + .filter(isFunctionDeclaration) + .filter(it => it.prefix === 'lambda') + .head() + toFuncType = this.typeComputer.inferType(lambdaDecl) + } + + if (!isFunctionType(toFuncType)) { + return false + } + + return from.paramTypes.length === toFuncType.paramTypes.length + && this.isConvertible(from.returnType, toFuncType.returnType) + && from.paramTypes.every((param, index) => this.isConvertible(param, toFuncType.paramTypes[index])) }, CompoundType: (from, to) => { @@ -132,7 +183,11 @@ export class ZenScriptTypeFeatures implements TypeFeatures { }, }) - isSubType(subType: Type, superType: Type): boolean { + isSubType(subType: Type | undefined, superType: Type | undefined): boolean { + if (subType === undefined || superType === undefined) { + return false + } + // ask the subtype if (this.subTypeRules(subType.$type)?.call(this, subType, superType)) { return true @@ -160,5 +215,9 @@ export class ZenScriptTypeFeatures implements TypeFeatures { IntersectionType: (superType, subType) => { return superType.types.some(it => this.isSubType(subType, it)) }, + + CompoundType: (superType, subType) => { + return superType.types.some(it => this.isSubType(subType, it)) + }, }) } diff --git a/packages/zenscript/src/utils/ast.ts b/packages/zenscript/src/utils/ast.ts index 2bb2ddf1..fa792628 100644 --- a/packages/zenscript/src/utils/ast.ts +++ b/packages/zenscript/src/utils/ast.ts @@ -65,9 +65,9 @@ export function toAstNode(item: AstNode | AstNodeDescription): AstNode | undefin } export function streamClassChain(classDecl: ClassDeclaration): Stream { - const visited = new Set() - return stream(function* () { + const generator = function* () { const deque = [classDecl] + const visited = new Set() while (deque.length) { const head = deque.shift()! if (!visited.has(head)) { @@ -79,7 +79,13 @@ export function streamClassChain(classDecl: ClassDeclaration): Stream deque.push(it)) } } - }()) + } + + return stream({ + [Symbol.iterator]() { + return generator()[Symbol.iterator]() + }, + }) } export function streamDeclaredMembers(classDecl: ClassDeclaration): Stream { diff --git a/packages/zenscript/src/workspace/description-index.ts b/packages/zenscript/src/workspace/description-index.ts index b1914c7e..5cf88445 100644 --- a/packages/zenscript/src/workspace/description-index.ts +++ b/packages/zenscript/src/workspace/description-index.ts @@ -1,14 +1,14 @@ -import type { AstNode, AstNodeDescription, NameProvider } from 'langium' -import type { ClassDeclaration, ImportDeclaration } from '../generated/ast' import type { ZenScriptServices } from '../module' import type { DescriptionCreator } from './description-creator' -import { getDocumentUri } from '../utils/ast' +import { type AstNode, type AstNodeDescription, type NameProvider, stream } from 'langium' +import { type ClassDeclaration, type ImportDeclaration, isClassDeclaration, isFunctionDeclaration } from '../generated/ast' +import { getDocumentUri, isStatic } from '../utils/ast' export interface DescriptionIndex { getDescription: (astNode: AstNode) => AstNodeDescription getThisDescription: (classDecl: ClassDeclaration) => AstNodeDescription createDynamicDescription: (astNode: AstNode, name: string) => AstNodeDescription - createImportedDescription: (importDecl: ImportDeclaration) => AstNodeDescription + createImportedDescriptions: (importDecl: ImportDeclaration) => AstNodeDescription[] } export class ZenScriptDescriptionIndex implements DescriptionIndex { @@ -44,28 +44,47 @@ export class ZenScriptDescriptionIndex implements DescriptionIndex { } createDynamicDescription(astNode: AstNode, name: string): AstNodeDescription { - const originalUri = this.astDescriptions.get(astNode)?.documentUri + const existing = this.astDescriptions.get(astNode) + if (existing?.name === name) { + return existing + } + const originalUri = existing?.documentUri return this.creator.createDescriptionWithUri(astNode, originalUri, name) } - createImportedDescription(importDecl: ImportDeclaration): AstNodeDescription { + createImportedDescriptions(importDecl: ImportDeclaration): AstNodeDescription[] { const targetRef = importDecl.path.at(-1) if (!targetRef) { - return this.getDescription(importDecl) + return [this.getDescription(importDecl)] } const target = targetRef.ref if (!target) { - return this.getDescription(importDecl) + return [this.getDescription(importDecl)] + } + + // TODO: Workaround for function overloading, may rework after langium supports multi-target references + if (isFunctionDeclaration(target)) { + const classDecl = importDecl.path.at(-2)?.ref + if (!isClassDeclaration(classDecl)) { + return [] + } + + return stream(classDecl.members) + .filter(isFunctionDeclaration) + .filter(isStatic) + .filter(it => it.name === target.name) + .map(it => this.createDynamicDescription(it, it.name)) + .toArray() } const targetDescription = targetRef.$nodeDescription if (!importDecl.alias && targetDescription) { - return targetDescription + return [targetDescription] } const targetUri = targetDescription?.documentUri const alias = this.nameProvider.getName(importDecl) - return this.creator.createDescriptionWithUri(target, targetUri, alias) + return [this.creator.createDescriptionWithUri(target, targetUri, alias)] } } diff --git a/packages/zenscript/src/zenscript.langium b/packages/zenscript/src/zenscript.langium index 2c092705..7d234316 100644 --- a/packages/zenscript/src/zenscript.langium +++ b/packages/zenscript/src/zenscript.langium @@ -28,7 +28,7 @@ interface ClassDeclaration extends Declaration { members: ClassMemberDeclaration[]; } -type NamedElement = Script | ClassDeclaration | FunctionDeclaration | ExpandFunctionDeclaration | FieldDeclaration | ValueParameter| VariableDeclaration | LoopParameter | MapEntry | ImportDeclaration; +type NamedElement = Script | ClassDeclaration | FunctionDeclaration | ExpandFunctionDeclaration | FieldDeclaration | ValueParameter| VariableDeclaration | LoopParameter | MapEntry | ImportDeclaration | ConstructorDeclaration; type ClassMemberDeclaration = FunctionDeclaration | FieldDeclaration | ConstructorDeclaration | OperatorFunctionDeclaration; diff --git a/packages/zenscript/test/typing/overloading/intellizen.json b/packages/zenscript/test/typing/overloading/intellizen.json new file mode 100644 index 00000000..1e293b03 --- /dev/null +++ b/packages/zenscript/test/typing/overloading/intellizen.json @@ -0,0 +1,5 @@ +{ + "srcRoots": [ + "./scripts" + ] +} diff --git a/packages/zenscript/test/typing/overloading/overloading.test.ts b/packages/zenscript/test/typing/overloading/overloading.test.ts new file mode 100644 index 00000000..8387a77e --- /dev/null +++ b/packages/zenscript/test/typing/overloading/overloading.test.ts @@ -0,0 +1,180 @@ +import type { AstNode } from 'langium' +import type { MemberAccess, ReferenceExpression, Statement } from '../../../src/generated/ast' +import path from 'node:path' +import { AstUtils } from 'langium' +import { assert, describe, expect, it, suite } from 'vitest' +import { isCallExpression, isConstructorDeclaration, isFunctionDeclaration } from '../../../src/generated/ast' +import { assertNoErrors, createTestServices, getDocument } from '../../utils' + +const services = await createTestServices(__dirname) + +function findOverloadForCall(call: Statement): AstNode { + const callExpr = AstUtils.streamAst(call).find(isCallExpression) + expect(callExpr).toBeDefined() + expect(callExpr?.receiver?.$type).toMatch(/ReferenceExpression|MemberAccess/) + const receiver = callExpr!.receiver as ReferenceExpression | MemberAccess + const target = receiver.target.ref + expect(target).toBeDefined() + return target! +} + +describe('check overload', async () => { + const document_overload_zs = await getDocument(services, path.resolve(__dirname, 'scripts', 'overload.zs')) + const script_overload_zs = document_overload_zs.parseResult.value + + it('syntax', () => { + assertNoErrors(document_overload_zs) + expect(script_overload_zs.statements[0].$type).toBe('VariableDeclaration') + }) + + suite('overload_basic', () => { + it('normal', () => { + const foo = findOverloadForCall(script_overload_zs.statements[1]) + assert(isFunctionDeclaration(foo)) + expect(foo.name).toBe('foo') + expect(foo.parameters.length).toBe(0) + + const foo1 = findOverloadForCall(script_overload_zs.statements[2]) + assert(isFunctionDeclaration(foo1)) + expect(foo1.name).toBe('foo') + expect(foo1.parameters.length).toBe(1) + + const foo2 = findOverloadForCall(script_overload_zs.statements[3]) + assert(isFunctionDeclaration(foo2)) + expect(foo2.name).toBe('foo') + expect(foo2.parameters.length).toBe(2) + expect(foo2.parameters[1].typeRef?.$cstNode?.text).toBe('int') + + const foo3 = findOverloadForCall(script_overload_zs.statements[4]) + assert(isFunctionDeclaration(foo3)) + expect(foo3.name).toBe('foo') + expect(foo3.parameters.length).toBe(2) + expect(foo3.parameters[1].typeRef?.$cstNode?.text).toBe('double') + }) + + it('varargs', () => { + const varargs_low_priorty = findOverloadForCall(script_overload_zs.statements[5]) + assert(isFunctionDeclaration(varargs_low_priorty)) + expect(varargs_low_priorty.name).toBe('varargs') + expect(varargs_low_priorty.parameters.length).toBe(1) + + const varargs_full = findOverloadForCall(script_overload_zs.statements[6]) + assert(isFunctionDeclaration(varargs_full)) + expect(varargs_full.name).toBe('varargs') + expect(varargs_full.parameters.length).toBe(2) + expect(varargs_full.parameters[1].varargs).toBe(true) + + const varargs_more = findOverloadForCall(script_overload_zs.statements[7]) + assert(isFunctionDeclaration(varargs_more)) + expect(varargs_more.name).toBe('varargs') + expect(varargs_more.parameters.length).toBe(2) + expect(varargs_more.parameters[1].varargs).toBe(true) + + const varargs_less = findOverloadForCall(script_overload_zs.statements[8]) + assert(isFunctionDeclaration(varargs_less)) + expect(varargs_less.name).toBe('varargs_miss') + expect(varargs_less.parameters.length).toBe(1) + expect(varargs_less.parameters[0].varargs).toBe(true) + }) + + it('optional', () => { + const optional_eq = findOverloadForCall(script_overload_zs.statements[9]) + assert(isFunctionDeclaration(optional_eq)) + expect(optional_eq.name).toBe('optional') + expect(optional_eq.parameters.length).toBe(2) + expect(optional_eq.parameters[1].defaultValue).toBeDefined() + + const optional_low_priorty = findOverloadForCall(script_overload_zs.statements[10]) + assert(isFunctionDeclaration(optional_low_priorty)) + expect(optional_low_priorty.name).toBe('optional') + expect(optional_low_priorty.parameters.length).toBe(1) + + const optional_less = findOverloadForCall(script_overload_zs.statements[11]) + assert(isFunctionDeclaration(optional_less)) + expect(optional_less.name).toBe('optional_miss') + expect(optional_less.parameters.length).toBe(1) + expect(optional_less.parameters[0].defaultValue).toBeDefined() + + const optional_convert_1 = findOverloadForCall(script_overload_zs.statements[12]) + assert(isFunctionDeclaration(optional_convert_1)) + expect(optional_convert_1.name).toBe('optional_convert') + expect(optional_convert_1.parameters.length).toBe(2) + expect(optional_convert_1.parameters[1].defaultValue).toBeDefined() + + const optional_convert_2 = findOverloadForCall(script_overload_zs.statements[13]) + assert(isFunctionDeclaration(optional_convert_2)) + expect(optional_convert_2.name).toBe('optional_convert') + expect(optional_convert_2.parameters.length).toBe(2) + expect(optional_convert_2.parameters[1].defaultValue).toBeDefined() + + const optional_vs_varargs = findOverloadForCall(script_overload_zs.statements[14]) + assert(isFunctionDeclaration(optional_vs_varargs)) + expect(optional_vs_varargs.name).toBe('varargs_vs_optional') + expect(optional_vs_varargs.parameters.length).toBe(1) + expect(optional_vs_varargs.parameters[0].varargs).toBeTruthy() + expect(optional_vs_varargs.parameters[0].defaultValue).toBeFalsy() + }) + }) +}) + +describe('check static overload', async () => { + const document_overload_static_zs = await getDocument(services, path.resolve(__dirname, 'scripts', 'overload_static.zs')) + const script_overload_zs = document_overload_static_zs.parseResult.value + + it('syntax', () => { + assertNoErrors(document_overload_static_zs) + }) + + it('import overload', () => { + const foo1 = findOverloadForCall(script_overload_zs.statements[0]) + assert(isFunctionDeclaration(foo1)) + expect(foo1.name).toBe('foo') + expect(foo1.parameters.length).toBe(0) + + const foo2 = findOverloadForCall(script_overload_zs.statements[1]) + assert(isFunctionDeclaration(foo2)) + expect(foo2.name).toBe('foo') + expect(foo2.parameters.length).toBe(1) + }) + + it('member access', () => { + const foo1 = findOverloadForCall(script_overload_zs.statements[2]) + assert(isFunctionDeclaration(foo1)) + expect(foo1.name).toBe('foo') + expect(foo1.parameters.length).toBe(0) + + const foo2 = findOverloadForCall(script_overload_zs.statements[3]) + assert(isFunctionDeclaration(foo2)) + expect(foo2.name).toBe('foo') + expect(foo2.parameters.length).toBe(1) + }) +}) + +describe('check ctor overload', async () => { + const document_overload_ctor_zs = await getDocument(services, path.resolve(__dirname, 'scripts', 'overload_ctor.zs')) + const script_overload_zs = document_overload_ctor_zs.parseResult.value + + it('syntax', () => { + assertNoErrors(document_overload_ctor_zs) + }) + + it('ctor overload import', () => { + const foo1 = findOverloadForCall(script_overload_zs.statements[0]) + assert(isConstructorDeclaration(foo1)) + expect(foo1.parameters.length).toBe(0) + + const foo2 = findOverloadForCall(script_overload_zs.statements[1]) + assert(isConstructorDeclaration(foo2)) + expect(foo2.parameters.length).toBe(1) + }) + + it('ctor overload member access', () => { + const foo1 = findOverloadForCall(script_overload_zs.statements[2]) + assert(isConstructorDeclaration(foo1)) + expect(foo1.parameters.length).toBe(0) + + const foo2 = findOverloadForCall(script_overload_zs.statements[3]) + assert(isConstructorDeclaration(foo2)) + expect(foo2.parameters.length).toBe(1) + }) +}) diff --git a/packages/zenscript/test/typing/overloading/scripts/overload.zs b/packages/zenscript/test/typing/overloading/scripts/overload.zs new file mode 100644 index 00000000..96389cec --- /dev/null +++ b/packages/zenscript/test/typing/overloading/scripts/overload.zs @@ -0,0 +1,25 @@ + +val obj as intellizen.test.Overload; + +obj.foo(); +obj.foo(1); + +obj.foo(1, 1); +obj.foo(1, 1.0); + +obj.varargs(1); +obj.varargs(1, 2); +obj.varargs(1, 2, 3); + +obj.varargs_miss(); + +obj.optional(1, 2); +obj.optional(1); + +obj.optional_miss(); + +obj.optional_convert(1); +obj.optional_convert(1, 2.0); + + +obj.varargs_vs_optional(1); diff --git a/packages/zenscript/test/typing/overloading/scripts/overload_ctor.zs b/packages/zenscript/test/typing/overloading/scripts/overload_ctor.zs new file mode 100644 index 00000000..916ef237 --- /dev/null +++ b/packages/zenscript/test/typing/overloading/scripts/overload_ctor.zs @@ -0,0 +1,7 @@ +import intellizen.test.Overload + +val obj1 = Overload(); +val obj2 = Overload(1); + +val obj3 = intellizen.test.Overload(); +val obj4 = intellizen.test.Overload(1); diff --git a/packages/zenscript/test/typing/overloading/scripts/overload_static.zs b/packages/zenscript/test/typing/overloading/scripts/overload_static.zs new file mode 100644 index 00000000..a17fc004 --- /dev/null +++ b/packages/zenscript/test/typing/overloading/scripts/overload_static.zs @@ -0,0 +1,7 @@ +import intellizen.test.StaticOverload.foo; + +foo(); +foo(1); + +intellizen.test.StaticOverload.foo(); +intellizen.test.StaticOverload.foo(1); diff --git a/packages/zenscript/test/typing/overloading/scripts/overload_type.dzs b/packages/zenscript/test/typing/overloading/scripts/overload_type.dzs new file mode 100644 index 00000000..a9f7db1a --- /dev/null +++ b/packages/zenscript/test/typing/overloading/scripts/overload_type.dzs @@ -0,0 +1,30 @@ +package intellizen.test; + +zenClass Overload { + function foo(); + function foo(intVal as int); + + function foo(intVal as int, doubleVal as double); + function foo(intVal as int, intVal as int); + + function varargs(); + function varargs(intVal as int, ...rest as int); + function varargs(intVal as int); + + function varargs_miss(...rest as int); + function varargs_miss(...rest as int); + + function optional(intVal as int, optionalInt as int = 1); + function optional(intVal as int); + function optional_miss(optionalInt as int = 1); + + function optional_convert(intVal as int, optionalInt as int = 1); + function optional_convert(doubleVal as double); + + function varargs_vs_optional(optionalInt as int = 1); + function varargs_vs_optional(...intVarargs as int); + + + zenConstructor (); + zenConstructor (intVal as int); +} diff --git a/packages/zenscript/test/typing/overloading/scripts/static_functions.dzs b/packages/zenscript/test/typing/overloading/scripts/static_functions.dzs new file mode 100644 index 00000000..a0b071d2 --- /dev/null +++ b/packages/zenscript/test/typing/overloading/scripts/static_functions.dzs @@ -0,0 +1,6 @@ +package intellizen.test; + +zenClass StaticOverload { + static function foo(); + static function foo(intVal as int); +} diff --git a/packages/zenscript/test/utils.ts b/packages/zenscript/test/utils.ts index f83d2dc2..15d8123f 100644 --- a/packages/zenscript/test/utils.ts +++ b/packages/zenscript/test/utils.ts @@ -29,6 +29,7 @@ export async function getDocument(services: ZenScriptServices, docPath: string) } export async function assertNoErrors(model: LangiumDocument