Skip to content

Commit

Permalink
[red-knot] Emit a diagnostic if the value of a starred expression or …
Browse files Browse the repository at this point in the history
…a `yield from` expression is not iterable (#13240)
  • Loading branch information
AlexWaygood authored Sep 4, 2024
1 parent 46a4573 commit 0512428
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 25 deletions.
118 changes: 108 additions & 10 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use infer::TypeInferenceBuilder;
use ruff_db::files::File;
use ruff_python_ast as ast;

Expand Down Expand Up @@ -400,28 +401,42 @@ impl<'db> Type<'db> {
/// for y in x:
/// pass
/// ```
///
/// Returns `None` if `self` represents a type that is not iterable.
fn iterate(&self, db: &'db dyn Db) -> Option<Type<'db>> {
fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> {
// `self` represents the type of the iterable;
// `__iter__` and `__next__` are both looked up on the class of the iterable:
let type_of_class = self.to_meta_type(db);
let iterable_meta_type = self.to_meta_type(db);

let dunder_iter_method = type_of_class.member(db, "__iter__");
let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
if !dunder_iter_method.is_unbound() {
let iterator_ty = dunder_iter_method.call(db)?;
let Some(iterator_ty) = dunder_iter_method.call(db) else {
return IterationOutcome::NotIterable {
not_iterable_ty: *self,
};
};

let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method.call(db);
return dunder_next_method
.call(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
});
}

// Although it's not considered great practice,
// classes that define `__getitem__` are also iterable,
// even if they do not define `__iter__`.
//
// TODO this is only valid if the `__getitem__` method is annotated as
// TODO(Alex) this is only valid if the `__getitem__` method is annotated as
// accepting `int` or `SupportsIndex`
let dunder_get_item_method = type_of_class.member(db, "__getitem__");
dunder_get_item_method.call(db)
let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");

dunder_get_item_method
.call(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
})
}

#[must_use]
Expand Down Expand Up @@ -463,6 +478,28 @@ impl<'db> Type<'db> {
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IterationOutcome<'db> {
Iterable { element_ty: Type<'db> },
NotIterable { not_iterable_ty: Type<'db> },
}

impl<'db> IterationOutcome<'db> {
fn unwrap_with_diagnostic(
self,
iterable_node: ast::AnyNodeRef,
inference_builder: &mut TypeInferenceBuilder<'db>,
) -> Type<'db> {
match self {
Self::Iterable { element_ty } => element_ty,
Self::NotIterable { not_iterable_ty } => {
inference_builder.not_iterable_diagnostic(iterable_node, not_iterable_ty);
Type::Unknown
}
}
}
}

#[salsa::interned]
pub struct FunctionType<'db> {
/// name of the function at definition
Expand Down Expand Up @@ -789,4 +826,65 @@ mod tests {
&["Object of type 'NotIterable' is not iterable"],
);
}

#[test]
fn starred_expressions_must_be_iterable() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
class NotIterable: pass
class Iterator:
def __next__(self) -> int:
return 42
class Iterable:
def __iter__(self) -> Iterator:
x = [*NotIterable()]
y = [*Iterable()]
",
)
.unwrap();

let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}

#[test]
fn yield_from_expression_must_be_iterable() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
class NotIterable: pass
class Iterator:
def __next__(self) -> int:
return 42
class Iterable:
def __iter__(self) -> Iterator:
def generator_function():
yield from Iterable()
yield from NotIterable()
",
)
.unwrap();

let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}
}
40 changes: 25 additions & 15 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ impl<'db> TypeInference<'db> {
/// Similarly, when we encounter a standalone-inferable expression (right-hand side of an
/// assignment, type narrowing guard), we use the [`infer_expression_types()`] query to ensure we
/// don't infer its types more than once.
struct TypeInferenceBuilder<'db> {
pub(super) struct TypeInferenceBuilder<'db> {
db: &'db dyn Db,
index: &'db SemanticIndex<'db>,
region: InferenceRegion<'db>,
Expand Down Expand Up @@ -1029,6 +1029,18 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_body(orelse);
}

/// Emit a diagnostic declaring that the object represented by `node` is not iterable
pub(super) fn not_iterable_diagnostic(&mut self, node: AnyNodeRef, not_iterable_ty: Type<'db>) {
self.add_diagnostic(
node,
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
not_iterable_ty.display(self.db)
),
);
}

fn infer_for_statement_definition(
&mut self,
target: &ast::ExprName,
Expand All @@ -1042,17 +1054,9 @@ impl<'db> TypeInferenceBuilder<'db> {
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));

let loop_var_value_ty = iterable_ty.iterate(self.db).unwrap_or_else(|| {
self.add_diagnostic(
iterable.into(),
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
iterable_ty.display(self.db)
),
);
Type::Unknown
});
let loop_var_value_ty = iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(iterable.into(), self);

self.types
.expressions
Expand Down Expand Up @@ -1812,7 +1816,10 @@ impl<'db> TypeInferenceBuilder<'db> {
ctx: _,
} = starred;

self.infer_expression(value);
let iterable_ty = self.infer_expression(value);
iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(value.as_ref().into(), self);

// TODO
Type::Unknown
Expand All @@ -1830,9 +1837,12 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_yield_from_expression(&mut self, yield_from: &ast::ExprYieldFrom) -> Type<'db> {
let ast::ExprYieldFrom { range: _, value } = yield_from;

self.infer_expression(value);
let iterable_ty = self.infer_expression(value);
iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(value.as_ref().into(), self);

// TODO get type from awaitable
// TODO get type from `ReturnType` of generator
Type::Unknown
}

Expand Down

0 comments on commit 0512428

Please sign in to comment.