diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index c89e8f0092..06fee8a548 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -687,13 +687,34 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) -> Type { // Based on https://typing.readthedocs.io/en/latest/spec/constructors.html. let vs = if let Some(hint) = hint { - let vs = self - .solver() - .freshen_class_targs(cls.targs_mut(), self.uniques); - - self.is_subset_eq(&self.heap.mk_class_type(cls.clone()), hint.ty()); - self.solver().generalize_class_targs(cls.targs_mut()); - vs + // Constructor hints may be unions that contain non-instance branches + // (for example `T | Box[T]`). Constraining against the full union can + // bind unrelated type variables and over-specialize this constructor. + // Only pre-specialize from concrete instance branches of the same class. + // If multiple concrete branches of the same class are present (for + // example `Box[int] | Box[str]`), skip pre-specialization entirely so + // constructor inference does not depend on union member ordering. + let mut matching_class_hints = + hint.ty().clone().into_unions().into_iter().filter(|ty| { + matches!(ty, Type::ClassType(_)) + && ty.qname() == Some(cls.qname()) + && !ty.contains_type_variable() + && !ty.may_contain_quantified_var() + }); + let class_hint = match (matching_class_hints.next(), matching_class_hints.next()) { + (Some(class_hint), None) => Some(class_hint), + _ => None, + }; + if let Some(class_hint) = class_hint { + let vs = self + .solver() + .freshen_class_targs(cls.targs_mut(), self.uniques); + self.is_subset_eq(&self.heap.mk_class_type(cls.clone()), &class_hint); + self.solver().generalize_class_targs(cls.targs_mut()); + vs + } else { + QuantifiedHandle::empty() + } } else { QuantifiedHandle::empty() }; diff --git a/pyrefly/lib/test/constructors.rs b/pyrefly/lib/test/constructors.rs index 1d3fcbaf92..eaa74a1559 100644 --- a/pyrefly/lib/test/constructors.rs +++ b/pyrefly/lib/test/constructors.rs @@ -793,6 +793,51 @@ B([A("oops")]) # E: `str` is not assignable to upper bound `A | int` of type va "#, ); +testcase!( + test_init_overload_inline_constructor_with_union_hint, + r#" +from collections.abc import Iterable, Mapping +from typing import Generic, Never, SupportsInt, TypeVar, overload + +TCo = TypeVar("TCo", covariant=True) +T = TypeVar("T") + +class Box(Generic[TCo]): + @overload + def __init__(self: "Box[Never]", val: Mapping[Never, SupportsInt], /) -> None: ... + @overload + def __init__(self: "Box[T]", val: Mapping[T, SupportsInt], /) -> None: ... + def __init__(self, val: object, /) -> None: + pass + +def process(items: Iterable[tuple[T | Box[T], int]]) -> Box[T]: ... + +process(((Box({1: 1}), 1),)) + "#, +); + +testcase!( + test_init_overload_inline_constructor_with_multiple_concrete_union_hints, + r#" +from collections.abc import Iterable +from typing import Generic, TypeVar, overload + +T = TypeVar("T") + +class Box(Generic[T]): + @overload + def __init__(self: "Box[int]", val: int, /) -> None: ... + @overload + def __init__(self: "Box[str]", val: str, /) -> None: ... + def __init__(self, val: int | str, /) -> None: + pass + +def process(items: Iterable[tuple[Box[int] | Box[str], int]]) -> None: ... + +process(((Box("x"), 1),)) + "#, +); + testcase!( test_init_overload_with_self, r#"