Skip to content

Commit

Permalink
Infer non-narrowing predicates when the contextual signature is a typ…
Browse files Browse the repository at this point in the history
…e predicate
  • Loading branch information
Andarist committed Dec 20, 2024
1 parent 56a0825 commit bb1dd39
Show file tree
Hide file tree
Showing 5 changed files with 343 additions and 26 deletions.
47 changes: 21 additions & 26 deletions src/compiler/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38778,7 +38778,7 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
}
}

function getTypePredicateFromBody(func: FunctionLikeDeclaration): TypePredicate | undefined {
function getTypePredicateFromBody(func: FunctionLikeDeclaration, contextualTypePredicate?: IdentifierTypePredicate): TypePredicate | undefined {
switch (func.kind) {
case SyntaxKind.Constructor:
case SyntaxKind.GetAccessor:
Expand All @@ -38800,41 +38800,35 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
});
if (bailedEarly || !singleReturn || functionHasImplicitReturn(func)) return undefined;
}
return checkIfExpressionRefinesAnyParameter(func, singleReturn);
}

function checkIfExpressionRefinesAnyParameter(func: FunctionLikeDeclaration, expr: Expression): TypePredicate | undefined {
expr = skipParentheses(expr, /*excludeJSDocTypeAssertions*/ true);
const expr = skipParentheses(singleReturn, /*excludeJSDocTypeAssertions*/ true);
const returnType = checkExpressionCached(expr);
if (!(returnType.flags & TypeFlags.Boolean)) return undefined;

return forEach(func.parameters, (param, i) => {
const initType = getTypeOfSymbol(param.symbol);
if (!initType || initType.flags & TypeFlags.Boolean || !isIdentifier(param.name) || isSymbolAssigned(param.symbol) || isRestParameter(param)) {
// Refining "x: boolean" to "x is true" or "x is false" isn't useful.
return;
}
const trueType = checkIfExpressionRefinesParameter(func, expr, param, initType);
if (trueType) {
return createTypePredicate(TypePredicateKind.Identifier, unescapeLeadingUnderscores(param.name.escapedText), i, trueType);
}
});
return contextualTypePredicate ?
getTypePredicateIfRefinesParameterAtIndex(func, expr, contextualTypePredicate, contextualTypePredicate.parameterIndex) :
forEach(func.parameters, (_, i) => getTypePredicateIfRefinesParameterAtIndex(func, expr, contextualTypePredicate, i));
}

function checkIfExpressionRefinesParameter(func: FunctionLikeDeclaration, expr: Expression, param: ParameterDeclaration, initType: Type): Type | undefined {
function getTypePredicateIfRefinesParameterAtIndex(func: FunctionLikeDeclaration, expr: Expression, contextualTypePredicate: IdentifierTypePredicate | undefined, parameterIndex: number): TypePredicate | undefined {
const param = func.parameters[parameterIndex];
const initType = getTypeOfSymbol(param.symbol);
if (!initType || initType.flags & TypeFlags.Boolean || !isIdentifier(param.name) || isSymbolAssigned(param.symbol) || isRestParameter(param)) {
// Refining "x: boolean" to "x is true" or "x is false" isn't useful.
return;
}
const antecedent = canHaveFlowNode(expr) && expr.flowNode ||
expr.parent.kind === SyntaxKind.ReturnStatement && (expr.parent as ReturnStatement).flowNode ||
createFlowNode(FlowFlags.Start, /*node*/ undefined, /*antecedent*/ undefined);
const trueCondition = createFlowNode(FlowFlags.TrueCondition, expr, antecedent);

const trueType = getFlowTypeOfReference(param.name, initType, initType, func, trueCondition);
if (trueType === initType) return undefined;

if (!contextualTypePredicate && trueType === initType) {
return undefined;
}
// "x is T" means that x is T if and only if it returns true. If it returns false then x is not T.
// This means that if the function is called with an argument of type trueType, there can't be anything left in the `else` branch. It must reduce to `never`.
const falseCondition = createFlowNode(FlowFlags.FalseCondition, expr, antecedent);
const falseSubtype = getFlowTypeOfReference(param.name, initType, trueType, func, falseCondition);
return falseSubtype.flags & TypeFlags.Never ? trueType : undefined;
return falseSubtype.flags & TypeFlags.Never ? createTypePredicate(TypePredicateKind.Identifier, unescapeLeadingUnderscores(param.name.escapedText), parameterIndex, trueType) : undefined;
}

/**
Expand Down Expand Up @@ -38978,10 +38972,11 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
inferFromAnnotatedParameters(signature, contextualSignature, inferenceContext!);
}
}
if (contextualSignature && !getReturnTypeFromAnnotation(node) && !signature.resolvedReturnType) {
const returnType = getReturnTypeFromBody(node, checkMode);
if (!signature.resolvedReturnType) {
signature.resolvedReturnType = returnType;
if (contextualSignature && !getReturnTypeFromAnnotation(node)) {
const returnType = signature.resolvedReturnType ?? getReturnTypeFromBody(node, checkMode);
signature.resolvedReturnType ??= returnType;
if (signature.resolvedReturnType.flags && TypeFlags.BooleanLike && contextualSignature.resolvedTypePredicate && contextualSignature.resolvedTypePredicate !== noTypePredicate && contextualSignature.resolvedTypePredicate.kind === TypePredicateKind.Identifier) {
signature.resolvedTypePredicate ??= getTypePredicateFromBody(node, contextualSignature.resolvedTypePredicate) ?? noTypePredicate;
}
}
checkSignatureDeclaration(node);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
inferContextualTypePredicates1.ts(13,26): error TS2345: Argument of type '(item: Foo | Bar) => false' is not assignable to parameter of type '(a: Foo | Bar) => a is Foo | Bar'.
Signature '(item: Foo | Bar): false' must be a type predicate.
inferContextualTypePredicates1.ts(14,26): error TS2345: Argument of type '(item: Foo | Bar) => true' is not assignable to parameter of type '(a: Foo | Bar) => a is Foo | Bar'.
Signature '(item: Foo | Bar): true' must be a type predicate.
inferContextualTypePredicates1.ts(17,7): error TS2322: Type '(a: string | null, b: string | null) => boolean' is not assignable to type '(a: string | null, b: string | null) => b is string'.
Signature '(a: string | null, b: string | null): boolean' must be a type predicate.


==== inferContextualTypePredicates1.ts (3 errors) ====
type Foo = { type: "foo"; foo: number };
type Bar = { type: "bar"; bar: string };

declare function skipIf<A, B extends A>(
as: A[],
predicate: (a: A) => a is B,
): Exclude<A, B>[];

declare const items: (Foo | Bar)[];

const r1 = skipIf(items, (item) => item.type === "foo"); // ok
const r2 = skipIf(items, (item) => item.type === "foo" || item.type === "bar"); // ok
const r3 = skipIf(items, (item) => false); // error
~~~~~~~~~~~~~~~
!!! error TS2345: Argument of type '(item: Foo | Bar) => false' is not assignable to parameter of type '(a: Foo | Bar) => a is Foo | Bar'.
!!! error TS2345: Signature '(item: Foo | Bar): false' must be a type predicate.
const r4 = skipIf(items, (item) => true); // error
~~~~~~~~~~~~~~
!!! error TS2345: Argument of type '(item: Foo | Bar) => true' is not assignable to parameter of type '(a: Foo | Bar) => a is Foo | Bar'.
!!! error TS2345: Signature '(item: Foo | Bar): true' must be a type predicate.

const pred1: (a: string | null, b: string | null) => b is string = (a, b) => typeof b === 'string'; // ok
const pred2: (a: string | null, b: string | null) => b is string = (a, b) => typeof a === 'string'; // error
~~~~~
!!! error TS2322: Type '(a: string | null, b: string | null) => boolean' is not assignable to type '(a: string | null, b: string | null) => b is string'.
!!! error TS2322: Signature '(a: string | null, b: string | null): boolean' must be a type predicate.

91 changes: 91 additions & 0 deletions tests/baselines/reference/inferContextualTypePredicates1.symbols
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//// [tests/cases/compiler/inferContextualTypePredicates1.ts] ////

=== inferContextualTypePredicates1.ts ===
type Foo = { type: "foo"; foo: number };
>Foo : Symbol(Foo, Decl(inferContextualTypePredicates1.ts, 0, 0))
>type : Symbol(type, Decl(inferContextualTypePredicates1.ts, 0, 12))
>foo : Symbol(foo, Decl(inferContextualTypePredicates1.ts, 0, 25))

type Bar = { type: "bar"; bar: string };
>Bar : Symbol(Bar, Decl(inferContextualTypePredicates1.ts, 0, 40))
>type : Symbol(type, Decl(inferContextualTypePredicates1.ts, 1, 12))
>bar : Symbol(bar, Decl(inferContextualTypePredicates1.ts, 1, 25))

declare function skipIf<A, B extends A>(
>skipIf : Symbol(skipIf, Decl(inferContextualTypePredicates1.ts, 1, 40))
>A : Symbol(A, Decl(inferContextualTypePredicates1.ts, 3, 24))
>B : Symbol(B, Decl(inferContextualTypePredicates1.ts, 3, 26))
>A : Symbol(A, Decl(inferContextualTypePredicates1.ts, 3, 24))

as: A[],
>as : Symbol(as, Decl(inferContextualTypePredicates1.ts, 3, 40))
>A : Symbol(A, Decl(inferContextualTypePredicates1.ts, 3, 24))

predicate: (a: A) => a is B,
>predicate : Symbol(predicate, Decl(inferContextualTypePredicates1.ts, 4, 10))
>a : Symbol(a, Decl(inferContextualTypePredicates1.ts, 5, 14))
>A : Symbol(A, Decl(inferContextualTypePredicates1.ts, 3, 24))
>a : Symbol(a, Decl(inferContextualTypePredicates1.ts, 5, 14))
>B : Symbol(B, Decl(inferContextualTypePredicates1.ts, 3, 26))

): Exclude<A, B>[];
>Exclude : Symbol(Exclude, Decl(lib.es5.d.ts, --, --))
>A : Symbol(A, Decl(inferContextualTypePredicates1.ts, 3, 24))
>B : Symbol(B, Decl(inferContextualTypePredicates1.ts, 3, 26))

declare const items: (Foo | Bar)[];
>items : Symbol(items, Decl(inferContextualTypePredicates1.ts, 8, 13))
>Foo : Symbol(Foo, Decl(inferContextualTypePredicates1.ts, 0, 0))
>Bar : Symbol(Bar, Decl(inferContextualTypePredicates1.ts, 0, 40))

const r1 = skipIf(items, (item) => item.type === "foo"); // ok
>r1 : Symbol(r1, Decl(inferContextualTypePredicates1.ts, 10, 5))
>skipIf : Symbol(skipIf, Decl(inferContextualTypePredicates1.ts, 1, 40))
>items : Symbol(items, Decl(inferContextualTypePredicates1.ts, 8, 13))
>item : Symbol(item, Decl(inferContextualTypePredicates1.ts, 10, 26))
>item.type : Symbol(type, Decl(inferContextualTypePredicates1.ts, 0, 12), Decl(inferContextualTypePredicates1.ts, 1, 12))
>item : Symbol(item, Decl(inferContextualTypePredicates1.ts, 10, 26))
>type : Symbol(type, Decl(inferContextualTypePredicates1.ts, 0, 12), Decl(inferContextualTypePredicates1.ts, 1, 12))

const r2 = skipIf(items, (item) => item.type === "foo" || item.type === "bar"); // ok
>r2 : Symbol(r2, Decl(inferContextualTypePredicates1.ts, 11, 5))
>skipIf : Symbol(skipIf, Decl(inferContextualTypePredicates1.ts, 1, 40))
>items : Symbol(items, Decl(inferContextualTypePredicates1.ts, 8, 13))
>item : Symbol(item, Decl(inferContextualTypePredicates1.ts, 11, 26))
>item.type : Symbol(type, Decl(inferContextualTypePredicates1.ts, 0, 12), Decl(inferContextualTypePredicates1.ts, 1, 12))
>item : Symbol(item, Decl(inferContextualTypePredicates1.ts, 11, 26))
>type : Symbol(type, Decl(inferContextualTypePredicates1.ts, 0, 12), Decl(inferContextualTypePredicates1.ts, 1, 12))
>item.type : Symbol(type, Decl(inferContextualTypePredicates1.ts, 1, 12))
>item : Symbol(item, Decl(inferContextualTypePredicates1.ts, 11, 26))
>type : Symbol(type, Decl(inferContextualTypePredicates1.ts, 1, 12))

const r3 = skipIf(items, (item) => false); // error
>r3 : Symbol(r3, Decl(inferContextualTypePredicates1.ts, 12, 5))
>skipIf : Symbol(skipIf, Decl(inferContextualTypePredicates1.ts, 1, 40))
>items : Symbol(items, Decl(inferContextualTypePredicates1.ts, 8, 13))
>item : Symbol(item, Decl(inferContextualTypePredicates1.ts, 12, 26))

const r4 = skipIf(items, (item) => true); // error
>r4 : Symbol(r4, Decl(inferContextualTypePredicates1.ts, 13, 5))
>skipIf : Symbol(skipIf, Decl(inferContextualTypePredicates1.ts, 1, 40))
>items : Symbol(items, Decl(inferContextualTypePredicates1.ts, 8, 13))
>item : Symbol(item, Decl(inferContextualTypePredicates1.ts, 13, 26))

const pred1: (a: string | null, b: string | null) => b is string = (a, b) => typeof b === 'string'; // ok
>pred1 : Symbol(pred1, Decl(inferContextualTypePredicates1.ts, 15, 5))
>a : Symbol(a, Decl(inferContextualTypePredicates1.ts, 15, 14))
>b : Symbol(b, Decl(inferContextualTypePredicates1.ts, 15, 31))
>b : Symbol(b, Decl(inferContextualTypePredicates1.ts, 15, 31))
>a : Symbol(a, Decl(inferContextualTypePredicates1.ts, 15, 68))
>b : Symbol(b, Decl(inferContextualTypePredicates1.ts, 15, 70))
>b : Symbol(b, Decl(inferContextualTypePredicates1.ts, 15, 70))

const pred2: (a: string | null, b: string | null) => b is string = (a, b) => typeof a === 'string'; // error
>pred2 : Symbol(pred2, Decl(inferContextualTypePredicates1.ts, 16, 5))
>a : Symbol(a, Decl(inferContextualTypePredicates1.ts, 16, 14))
>b : Symbol(b, Decl(inferContextualTypePredicates1.ts, 16, 31))
>b : Symbol(b, Decl(inferContextualTypePredicates1.ts, 16, 31))
>a : Symbol(a, Decl(inferContextualTypePredicates1.ts, 16, 68))
>b : Symbol(b, Decl(inferContextualTypePredicates1.ts, 16, 70))
>a : Symbol(a, Decl(inferContextualTypePredicates1.ts, 16, 68))

175 changes: 175 additions & 0 deletions tests/baselines/reference/inferContextualTypePredicates1.types
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
//// [tests/cases/compiler/inferContextualTypePredicates1.ts] ////

=== inferContextualTypePredicates1.ts ===
type Foo = { type: "foo"; foo: number };
>Foo : Foo
> : ^^^
>type : "foo"
> : ^^^^^
>foo : number
> : ^^^^^^

type Bar = { type: "bar"; bar: string };
>Bar : Bar
> : ^^^
>type : "bar"
> : ^^^^^
>bar : string
> : ^^^^^^

declare function skipIf<A, B extends A>(
>skipIf : <A, B extends A>(as: A[], predicate: (a: A) => a is B) => Exclude<A, B>[]
> : ^ ^^ ^^^^^^^^^ ^^ ^^ ^^ ^^ ^^^^^

as: A[],
>as : A[]
> : ^^^

predicate: (a: A) => a is B,
>predicate : (a: A) => a is B
> : ^ ^^ ^^^^^
>a : A
> : ^

): Exclude<A, B>[];

declare const items: (Foo | Bar)[];
>items : (Foo | Bar)[]
> : ^^^^^^^^^^^^^

const r1 = skipIf(items, (item) => item.type === "foo"); // ok
>r1 : Bar[]
> : ^^^^^
>skipIf(items, (item) => item.type === "foo") : Bar[]
> : ^^^^^
>skipIf : <A, B extends A>(as: A[], predicate: (a: A) => a is B) => Exclude<A, B>[]
> : ^ ^^ ^^^^^^^^^ ^^ ^^ ^^ ^^ ^^^^^
>items : (Foo | Bar)[]
> : ^^^^^^^^^^^^^
>(item) => item.type === "foo" : (item: Foo | Bar) => item is Foo
> : ^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
>item : Foo | Bar
> : ^^^^^^^^^
>item.type === "foo" : boolean
> : ^^^^^^^
>item.type : "foo" | "bar"
> : ^^^^^^^^^^^^^
>item : Foo | Bar
> : ^^^^^^^^^
>type : "foo" | "bar"
> : ^^^^^^^^^^^^^
>"foo" : "foo"
> : ^^^^^

const r2 = skipIf(items, (item) => item.type === "foo" || item.type === "bar"); // ok
>r2 : never[]
> : ^^^^^^^
>skipIf(items, (item) => item.type === "foo" || item.type === "bar") : never[]
> : ^^^^^^^
>skipIf : <A, B extends A>(as: A[], predicate: (a: A) => a is B) => Exclude<A, B>[]
> : ^ ^^ ^^^^^^^^^ ^^ ^^ ^^ ^^ ^^^^^
>items : (Foo | Bar)[]
> : ^^^^^^^^^^^^^
>(item) => item.type === "foo" || item.type === "bar" : (item: Foo | Bar) => item is Foo | Bar
> : ^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
>item : Foo | Bar
> : ^^^^^^^^^
>item.type === "foo" || item.type === "bar" : boolean
> : ^^^^^^^
>item.type === "foo" : boolean
> : ^^^^^^^
>item.type : "foo" | "bar"
> : ^^^^^^^^^^^^^
>item : Foo | Bar
> : ^^^^^^^^^
>type : "foo" | "bar"
> : ^^^^^^^^^^^^^
>"foo" : "foo"
> : ^^^^^
>item.type === "bar" : boolean
> : ^^^^^^^
>item.type : "bar"
> : ^^^^^
>item : Bar
> : ^^^
>type : "bar"
> : ^^^^^
>"bar" : "bar"
> : ^^^^^

const r3 = skipIf(items, (item) => false); // error
>r3 : never[]
> : ^^^^^^^
>skipIf(items, (item) => false) : never[]
> : ^^^^^^^
>skipIf : <A, B extends A>(as: A[], predicate: (a: A) => a is B) => Exclude<A, B>[]
> : ^ ^^ ^^^^^^^^^ ^^ ^^ ^^ ^^ ^^^^^
>items : (Foo | Bar)[]
> : ^^^^^^^^^^^^^
>(item) => false : (item: Foo | Bar) => false
> : ^ ^^^^^^^^^^^^^^^^^^^^^
>item : Foo | Bar
> : ^^^^^^^^^
>false : false
> : ^^^^^

const r4 = skipIf(items, (item) => true); // error
>r4 : never[]
> : ^^^^^^^
>skipIf(items, (item) => true) : never[]
> : ^^^^^^^
>skipIf : <A, B extends A>(as: A[], predicate: (a: A) => a is B) => Exclude<A, B>[]
> : ^ ^^ ^^^^^^^^^ ^^ ^^ ^^ ^^ ^^^^^
>items : (Foo | Bar)[]
> : ^^^^^^^^^^^^^
>(item) => true : (item: Foo | Bar) => true
> : ^ ^^^^^^^^^^^^^^^^^^^^
>item : Foo | Bar
> : ^^^^^^^^^
>true : true
> : ^^^^

const pred1: (a: string | null, b: string | null) => b is string = (a, b) => typeof b === 'string'; // ok
>pred1 : (a: string | null, b: string | null) => b is string
> : ^ ^^ ^^ ^^ ^^^^^
>a : string | null
> : ^^^^^^^^^^^^^
>b : string | null
> : ^^^^^^^^^^^^^
>(a, b) => typeof b === 'string' : (a: string | null, b: string | null) => b is string
> : ^ ^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
>a : string | null
> : ^^^^^^^^^^^^^
>b : string | null
> : ^^^^^^^^^^^^^
>typeof b === 'string' : boolean
> : ^^^^^^^
>typeof b : "string" | "number" | "bigint" | "boolean" | "symbol" | "undefined" | "object" | "function"
> : ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
>b : string | null
> : ^^^^^^^^^^^^^
>'string' : "string"
> : ^^^^^^^^

const pred2: (a: string | null, b: string | null) => b is string = (a, b) => typeof a === 'string'; // error
>pred2 : (a: string | null, b: string | null) => b is string
> : ^ ^^ ^^ ^^ ^^^^^
>a : string | null
> : ^^^^^^^^^^^^^
>b : string | null
> : ^^^^^^^^^^^^^
>(a, b) => typeof a === 'string' : (a: string | null, b: string | null) => boolean
> : ^ ^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
>a : string | null
> : ^^^^^^^^^^^^^
>b : string | null
> : ^^^^^^^^^^^^^
>typeof a === 'string' : boolean
> : ^^^^^^^
>typeof a : "string" | "number" | "bigint" | "boolean" | "symbol" | "undefined" | "object" | "function"
> : ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
>a : string | null
> : ^^^^^^^^^^^^^
>'string' : "string"
> : ^^^^^^^^

Loading

0 comments on commit bb1dd39

Please sign in to comment.