diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index f57a1747aae64..37430d95c3aa6 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -7,9 +7,11 @@ use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId}; use crate::semantic_index::{global_scope, symbol_table, use_def_map}; use crate::{Db, FxOrderSet}; +mod builder; mod display; mod infer; +pub(crate) use self::builder::UnionBuilder; pub(crate) use self::infer::{infer_definition_types, infer_scope_types}; /// Infer the public type of a symbol (its type as seen from outside its scope). @@ -91,14 +93,14 @@ pub(crate) fn definitions_ty<'db>( }; if let Some(second) = all_types.next() { - let mut builder = UnionTypeBuilder::new(db); + let mut builder = UnionBuilder::new(db); builder = builder.add(first).add(second); for variant in all_types { builder = builder.add(variant); } - Type::Union(builder.build()) + builder.build() } else { first } @@ -117,7 +119,7 @@ pub enum Type<'db> { /// name does not exist or is not bound to any value (this represents an error, but with some /// leniency options it could be silently resolved to Unknown in some cases) Unbound, - /// the None object (TODO remove this in favor of Instance(types.NoneType) + /// the None object -- TODO remove this in favor of Instance(types.NoneType) None, /// a specific function object Function(FunctionType<'db>), @@ -127,8 +129,11 @@ pub enum Type<'db> { Class(ClassType<'db>), /// the set of Python objects with the given class in their __class__'s method resolution order Instance(ClassType<'db>), + /// the set of objects in any of the types in the union Union(UnionType<'db>), + /// the set of objects in all of the types in the intersection Intersection(IntersectionType<'db>), + /// An integer literal IntLiteral(i64), /// A boolean literal, either `True` or `False`. BooleanLiteral(bool), @@ -159,15 +164,13 @@ impl<'db> Type<'db> { // TODO MRO? get_own_instance_member, get_instance_member todo!("attribute lookup on Instance type") } - Type::Union(union) => Type::Union( - union - .elements(db) - .iter() - .fold(UnionTypeBuilder::new(db), |builder, element_ty| { - builder.add(element_ty.member(db, name)) - }) - .build(), - ), + Type::Union(union) => union + .elements(db) + .iter() + .fold(UnionBuilder::new(db), |builder, element_ty| { + builder.add(element_ty.member(db, name)) + }) + .build(), Type::Intersection(_) => { // TODO perform the get_member on each type in the intersection // TODO return the intersection of those results @@ -251,7 +254,7 @@ impl<'db> ClassType<'db> { #[salsa::interned] pub struct UnionType<'db> { - /// the union type includes values in any of these types + /// The union type includes values in any of these types. elements: FxOrderSet>, } @@ -261,48 +264,15 @@ impl<'db> UnionType<'db> { } } -struct UnionTypeBuilder<'db> { - elements: FxOrderSet>, - db: &'db dyn Db, -} - -impl<'db> UnionTypeBuilder<'db> { - fn new(db: &'db dyn Db) -> Self { - Self { - db, - elements: FxOrderSet::default(), - } - } - - /// Adds a type to this union. - fn add(mut self, ty: Type<'db>) -> Self { - match ty { - Type::Union(union) => { - self.elements.extend(&union.elements(self.db)); - } - _ => { - self.elements.insert(ty); - } - } - - self - } - - fn build(self) -> UnionType<'db> { - UnionType::new(self.db, self.elements) - } -} - -// Negation types aren't expressible in annotations, and are most likely to arise from type -// narrowing along with intersections (e.g. `if not isinstance(...)`), so we represent them -// directly in intersections rather than as a separate type. This sacrifices some efficiency in the -// case where a Not appears outside an intersection (unclear when that could even happen, but we'd -// have to represent it as a single-element intersection if it did) in exchange for better -// efficiency in the within-intersection case. #[salsa::interned] pub struct IntersectionType<'db> { - // the intersection type includes only values in all of these types + /// The intersection type includes only values in all of these types. positive: FxOrderSet>, - // the intersection type does not include any value in any of these types + + /// The intersection type does not include any value in any of these types. + /// + /// Negation types aren't expressible in annotations, and are most likely to arise from type + /// narrowing along with intersections (e.g. `if not isinstance(...)`), so we represent them + /// directly in intersections rather than as a separate type. negative: FxOrderSet>, } diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs new file mode 100644 index 0000000000000..9f8af0f295160 --- /dev/null +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -0,0 +1,429 @@ +//! Smart builders for union and intersection types. +//! +//! Invariants we maintain here: +//! * No single-element union types (should just be the contained type instead.) +//! * No single-positive-element intersection types. Single-negative-element are OK, we don't +//! have a standalone negation type so there's no other representation for this. +//! * The same type should never appear more than once in a union or intersection. (This should +//! be expanded to cover subtyping -- see below -- but for now we only implement it for type +//! identity.) +//! * Disjunctive normal form (DNF): the tree of unions and intersections can never be deeper +//! than a union-of-intersections. Unions cannot contain other unions (the inner union just +//! flattens into the outer one), intersections cannot contain other intersections (also +//! flattens), and intersections cannot contain unions (the intersection distributes over the +//! union, inverting it into a union-of-intersections). +//! +//! The implication of these invariants is that a [`UnionBuilder`] does not necessarily build a +//! [`Type::Union`]. For example, if only one type is added to the [`UnionBuilder`], `build()` will +//! just return that type directly. The same is true for [`IntersectionBuilder`]; for example, if a +//! union type is added to the intersection, it will distribute and [`IntersectionBuilder::build`] +//! may end up returning a [`Type::Union`] of intersections. +//! +//! In the future we should have these additional invariants, but they aren't implemented yet: +//! * No type in a union can be a subtype of any other type in the union (just eliminate the +//! subtype from the union). +//! * No type in an intersection can be a supertype of any other type in the intersection (just +//! eliminate the supertype from the intersection). +//! * An intersection containing two non-overlapping types should simplify to [`Type::Never`]. +use crate::types::{IntersectionType, Type, UnionType}; +use crate::{Db, FxOrderSet}; + +pub(crate) struct UnionBuilder<'db> { + elements: FxOrderSet>, + db: &'db dyn Db, +} + +impl<'db> UnionBuilder<'db> { + pub(crate) fn new(db: &'db dyn Db) -> Self { + Self { + db, + elements: FxOrderSet::default(), + } + } + + /// Adds a type to this union. + pub(crate) fn add(mut self, ty: Type<'db>) -> Self { + match ty { + Type::Union(union) => { + self.elements.extend(&union.elements(self.db)); + } + Type::Never => {} + _ => { + self.elements.insert(ty); + } + } + + self + } + + pub(crate) fn build(self) -> Type<'db> { + match self.elements.len() { + 0 => Type::Never, + 1 => self.elements[0], + _ => Type::Union(UnionType::new(self.db, self.elements)), + } + } +} + +#[allow(unused)] +#[derive(Clone)] +pub(crate) struct IntersectionBuilder<'db> { + // Really this builds a union-of-intersections, because we always keep our set-theoretic types + // in disjunctive normal form (DNF), a union of intersections. In the simplest case there's + // just a single intersection in this vector, and we are building a single intersection type, + // but if a union is added to the intersection, we'll distribute ourselves over that union and + // create a union of intersections. + intersections: Vec>, + db: &'db dyn Db, +} + +impl<'db> IntersectionBuilder<'db> { + #[allow(dead_code)] + fn new(db: &'db dyn Db) -> Self { + Self { + db, + intersections: vec![InnerIntersectionBuilder::new()], + } + } + + fn empty(db: &'db dyn Db) -> Self { + Self { + db, + intersections: vec![], + } + } + + #[allow(dead_code)] + fn add_positive(mut self, ty: Type<'db>) -> Self { + if let Type::Union(union) = ty { + // Distribute ourself over this union: for each union element, clone ourself and + // intersect with that union element, then create a new union-of-intersections with all + // of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2` + // and we add `T3 | T4` to the intersection, we don't get `T1 & T2 & (T3 | T4)` (that's + // not in DNF), we distribute the union and get `(T1 & T3) | (T2 & T3) | (T1 & T4) | + // (T2 & T4)`. If `self` is already a union-of-intersections `(T1 & T2) | (T3 & T4)` + // and we add `T5 | T6` to it, that flattens all the way out to `(T1 & T2 & T5) | (T1 & + // T2 & T6) | (T3 & T4 & T5) ...` -- you get the idea. + union + .elements(self.db) + .iter() + .map(|elem| self.clone().add_positive(*elem)) + .fold(IntersectionBuilder::empty(self.db), |mut builder, sub| { + builder.intersections.extend(sub.intersections); + builder + }) + } else { + // If we are already a union-of-intersections, distribute the new intersected element + // across all of those intersections. + for inner in &mut self.intersections { + inner.add_positive(self.db, ty); + } + self + } + } + + #[allow(dead_code)] + fn add_negative(mut self, ty: Type<'db>) -> Self { + // See comments above in `add_positive`; this is just the negated version. + if let Type::Union(union) = ty { + union + .elements(self.db) + .iter() + .map(|elem| self.clone().add_negative(*elem)) + .fold(IntersectionBuilder::empty(self.db), |mut builder, sub| { + builder.intersections.extend(sub.intersections); + builder + }) + } else { + for inner in &mut self.intersections { + inner.add_negative(self.db, ty); + } + self + } + } + + #[allow(dead_code)] + fn build(mut self) -> Type<'db> { + // Avoid allocating the UnionBuilder unnecessarily if we have just one intersection: + if self.intersections.len() == 1 { + self.intersections.pop().unwrap().build(self.db) + } else { + let mut builder = UnionBuilder::new(self.db); + for inner in self.intersections { + builder = builder.add(inner.build(self.db)); + } + builder.build() + } + } +} + +#[allow(unused)] +#[derive(Debug, Clone, Default)] +struct InnerIntersectionBuilder<'db> { + positive: FxOrderSet>, + negative: FxOrderSet>, +} + +impl<'db> InnerIntersectionBuilder<'db> { + fn new() -> Self { + Self::default() + } + + /// Adds a positive type to this intersection. + fn add_positive(&mut self, db: &'db dyn Db, ty: Type<'db>) { + match ty { + Type::Intersection(inter) => { + let pos = inter.positive(db); + let neg = inter.negative(db); + self.positive.extend(pos.difference(&self.negative)); + self.negative.extend(neg.difference(&self.positive)); + self.positive.retain(|elem| !neg.contains(elem)); + self.negative.retain(|elem| !pos.contains(elem)); + } + _ => { + if !self.negative.remove(&ty) { + self.positive.insert(ty); + }; + } + } + } + + /// Adds a negative type to this intersection. + fn add_negative(&mut self, db: &'db dyn Db, ty: Type<'db>) { + // TODO Any/Unknown actually should not self-cancel + match ty { + Type::Intersection(intersection) => { + let pos = intersection.negative(db); + let neg = intersection.positive(db); + self.positive.extend(pos.difference(&self.negative)); + self.negative.extend(neg.difference(&self.positive)); + self.positive.retain(|elem| !neg.contains(elem)); + self.negative.retain(|elem| !pos.contains(elem)); + } + Type::Never => {} + _ => { + if !self.positive.remove(&ty) { + self.negative.insert(ty); + }; + } + } + } + + fn simplify(&mut self) { + // TODO this should be generalized based on subtyping, for now we just handle a few cases + + // Never is a subtype of all types + if self.positive.contains(&Type::Never) { + self.positive.clear(); + self.negative.clear(); + self.positive.insert(Type::Never); + } + } + + fn build(mut self, db: &'db dyn Db) -> Type<'db> { + self.simplify(); + match (self.positive.len(), self.negative.len()) { + (0, 0) => Type::Never, + (1, 0) => self.positive[0], + _ => { + self.positive.shrink_to_fit(); + self.negative.shrink_to_fit(); + Type::Intersection(IntersectionType::new(db, self.positive, self.negative)) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::{IntersectionBuilder, IntersectionType, Type, UnionBuilder, UnionType}; + use crate::db::tests::TestDb; + + fn setup_db() -> TestDb { + TestDb::new() + } + + impl<'db> UnionType<'db> { + fn elements_vec(self, db: &'db TestDb) -> Vec> { + self.elements(db).into_iter().collect() + } + } + + #[test] + fn build_union() { + let db = setup_db(); + let t0 = Type::IntLiteral(0); + let t1 = Type::IntLiteral(1); + let Type::Union(union) = UnionBuilder::new(&db).add(t0).add(t1).build() else { + panic!("expected a union"); + }; + + assert_eq!(union.elements_vec(&db), &[t0, t1]); + } + + #[test] + fn build_union_single() { + let db = setup_db(); + let t0 = Type::IntLiteral(0); + let ty = UnionBuilder::new(&db).add(t0).build(); + + assert_eq!(ty, t0); + } + + #[test] + fn build_union_empty() { + let db = setup_db(); + let ty = UnionBuilder::new(&db).build(); + + assert_eq!(ty, Type::Never); + } + + #[test] + fn build_union_never() { + let db = setup_db(); + let t0 = Type::IntLiteral(0); + let ty = UnionBuilder::new(&db).add(t0).add(Type::Never).build(); + + assert_eq!(ty, t0); + } + + #[test] + fn build_union_flatten() { + let db = setup_db(); + let t0 = Type::IntLiteral(0); + let t1 = Type::IntLiteral(1); + let t2 = Type::IntLiteral(2); + let u1 = UnionBuilder::new(&db).add(t0).add(t1).build(); + let Type::Union(union) = UnionBuilder::new(&db).add(u1).add(t2).build() else { + panic!("expected a union"); + }; + + assert_eq!(union.elements_vec(&db), &[t0, t1, t2]); + } + + impl<'db> IntersectionType<'db> { + fn pos_vec(self, db: &'db TestDb) -> Vec> { + self.positive(db).into_iter().collect() + } + + fn neg_vec(self, db: &'db TestDb) -> Vec> { + self.negative(db).into_iter().collect() + } + } + + #[test] + fn build_intersection() { + let db = setup_db(); + let t0 = Type::IntLiteral(0); + let ta = Type::Any; + let Type::Intersection(inter) = IntersectionBuilder::new(&db) + .add_positive(ta) + .add_negative(t0) + .build() + else { + panic!("expected to be an intersection"); + }; + + assert_eq!(inter.pos_vec(&db), &[ta]); + assert_eq!(inter.neg_vec(&db), &[t0]); + } + + #[test] + fn build_intersection_flatten_positive() { + let db = setup_db(); + let ta = Type::Any; + let t1 = Type::IntLiteral(1); + let t2 = Type::IntLiteral(2); + let i0 = IntersectionBuilder::new(&db) + .add_positive(ta) + .add_negative(t1) + .build(); + let Type::Intersection(inter) = IntersectionBuilder::new(&db) + .add_positive(t2) + .add_positive(i0) + .build() + else { + panic!("expected to be an intersection"); + }; + + assert_eq!(inter.pos_vec(&db), &[t2, ta]); + assert_eq!(inter.neg_vec(&db), &[t1]); + } + + #[test] + fn build_intersection_flatten_negative() { + let db = setup_db(); + let ta = Type::Any; + let t1 = Type::IntLiteral(1); + let t2 = Type::IntLiteral(2); + let i0 = IntersectionBuilder::new(&db) + .add_positive(ta) + .add_negative(t1) + .build(); + let Type::Intersection(inter) = IntersectionBuilder::new(&db) + .add_positive(t2) + .add_negative(i0) + .build() + else { + panic!("expected to be an intersection"); + }; + + assert_eq!(inter.pos_vec(&db), &[t2, t1]); + assert_eq!(inter.neg_vec(&db), &[ta]); + } + + #[test] + fn intersection_distributes_over_union() { + let db = setup_db(); + let t0 = Type::IntLiteral(0); + let t1 = Type::IntLiteral(1); + let ta = Type::Any; + let u0 = UnionBuilder::new(&db).add(t0).add(t1).build(); + + let Type::Union(union) = IntersectionBuilder::new(&db) + .add_positive(ta) + .add_positive(u0) + .build() + else { + panic!("expected a union"); + }; + let [Type::Intersection(i0), Type::Intersection(i1)] = union.elements_vec(&db)[..] else { + panic!("expected a union of two intersections"); + }; + assert_eq!(i0.pos_vec(&db), &[ta, t0]); + assert_eq!(i1.pos_vec(&db), &[ta, t1]); + } + + #[test] + fn build_intersection_self_negation() { + let db = setup_db(); + let ty = IntersectionBuilder::new(&db) + .add_positive(Type::None) + .add_negative(Type::None) + .build(); + + assert_eq!(ty, Type::Never); + } + + #[test] + fn build_intersection_simplify_negative_never() { + let db = setup_db(); + let ty = IntersectionBuilder::new(&db) + .add_positive(Type::None) + .add_negative(Type::Never) + .build(); + + assert_eq!(ty, Type::None); + } + + #[test] + fn build_intersection_simplify_positive_never() { + let db = setup_db(); + let ty = IntersectionBuilder::new(&db) + .add_positive(Type::None) + .add_positive(Type::Never) + .build(); + + assert_eq!(ty, Type::Never); + } +} diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 4e14325673afd..644a15ddd0124 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -40,7 +40,7 @@ use crate::semantic_index::symbol::{FileScopeId, NodeWithScopeKind, NodeWithScop use crate::semantic_index::SemanticIndex; use crate::types::{ builtins_symbol_ty_by_name, definitions_ty, global_symbol_ty_by_name, ClassType, FunctionType, - Name, Type, UnionTypeBuilder, + Name, Type, UnionBuilder, }; use crate::Db; @@ -1179,12 +1179,10 @@ impl<'db> TypeInferenceBuilder<'db> { let body_ty = self.infer_expression(body); let orelse_ty = self.infer_expression(orelse); - let union = UnionTypeBuilder::new(self.db) + UnionBuilder::new(self.db) .add(body_ty) .add(orelse_ty) - .build(); - - Type::Union(union) + .build() } fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) {