From d19fd1b91c0465637e94aaba13dbca06fbd43cca Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Fri, 23 Aug 2024 23:40:27 +0100 Subject: [PATCH] [red-knot] Add symbols for `for` loop variables (#13075) ## Summary This PR adds symbols introduced by `for` loops to red-knot: - `x` in `for x in range(10): pass` - `x` and `y` in `for x, y in d.items(): pass` - `a`, `b`, `c` and `d` in `for [((a,), b), (c, d)] in foo: pass` ## Test Plan Several tests added, and the assertion in the benchmarks has been updated. --------- Co-authored-by: Micha Reiser --- .../src/semantic_index.rs | 52 +++++++++++++++++++ .../src/semantic_index/builder.rs | 39 +++++++++++++- .../src/semantic_index/definition.rs | 46 ++++++++++++++++ .../src/types/infer.rs | 47 ++++++++++++++++- crates/ruff_benchmark/benches/red_knot.rs | 15 ------ 5 files changed, 181 insertions(+), 18 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index d7603ddef387a..7c35f3c3fdfa8 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -1095,4 +1095,56 @@ match subject: vec!["subject", "a", "b", "c", "d", "f", "e", "h", "g", "Foo", "i", "j", "k", "l"] ); } + + #[test] + fn for_loops_single_assignment() { + let TestCase { db, file } = test_case("for x in a: pass"); + let scope = global_scope(&db, file); + let global_table = symbol_table(&db, scope); + + assert_eq!(&names(&global_table), &["a", "x"]); + + let use_def = use_def_map(&db, scope); + let definition = use_def + .first_public_definition(global_table.symbol_id_by_name("x").unwrap()) + .unwrap(); + + assert!(matches!(definition.node(&db), DefinitionKind::For(_))); + } + + #[test] + fn for_loops_simple_unpacking() { + let TestCase { db, file } = test_case("for (x, y) in a: pass"); + let scope = global_scope(&db, file); + let global_table = symbol_table(&db, scope); + + assert_eq!(&names(&global_table), 
&["a", "x", "y"]); + + let use_def = use_def_map(&db, scope); + let x_definition = use_def + .first_public_definition(global_table.symbol_id_by_name("x").unwrap()) + .unwrap(); + let y_definition = use_def + .first_public_definition(global_table.symbol_id_by_name("y").unwrap()) + .unwrap(); + + assert!(matches!(x_definition.node(&db), DefinitionKind::For(_))); + assert!(matches!(y_definition.node(&db), DefinitionKind::For(_))); + } + + #[test] + fn for_loops_complex_unpacking() { + let TestCase { db, file } = test_case("for [((a,), b), (c, d)] in e: pass"); + let scope = global_scope(&db, file); + let global_table = symbol_table(&db, scope); + + assert_eq!(&names(&global_table), &["e", "a", "b", "c", "d"]); + + let use_def = use_def_map(&db, scope); + let definition = use_def + .first_public_definition(global_table.symbol_id_by_name("a").unwrap()) + .unwrap(); + + assert!(matches!(definition.node(&db), DefinitionKind::For(_))); + } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 5db6aca1d8dfa..9805d25019007 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -15,7 +15,7 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::ast_ids::AstIdsBuilder; use crate::semantic_index::definition::{ AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef, Definition, DefinitionNodeKey, - DefinitionNodeRef, ImportFromDefinitionNodeRef, + DefinitionNodeRef, ForStmtDefinitionNodeRef, ImportFromDefinitionNodeRef, }; use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ @@ -578,6 +578,27 @@ where ast::Stmt::Break(_) => { self.loop_break_states.push(self.flow_snapshot()); } + + ast::Stmt::For( + for_stmt @ ast::StmtFor { + range: _, + is_async: _, + target, + iter, + body, + orelse, + }, + ) => { + // TODO add control 
flow similar to `ast::Stmt::While` above + self.add_standalone_expression(iter); + self.visit_expr(iter); + debug_assert!(self.current_assignment.is_none()); + self.current_assignment = Some(for_stmt.into()); + self.visit_expr(target); + self.current_assignment = None; + self.visit_body(body); + self.visit_body(orelse); + } _ => { walk_stmt(self, stmt); } @@ -624,6 +645,15 @@ where Some(CurrentAssignment::AugAssign(aug_assign)) => { self.add_definition(symbol, aug_assign); } + Some(CurrentAssignment::For(node)) => { + self.add_definition( + symbol, + ForStmtDefinitionNodeRef { + iterable: &node.iter, + target: name_node, + }, + ); + } Some(CurrentAssignment::Named(named)) => { // TODO(dhruvmanila): If the current scope is a comprehension, then the // named expression is implicitly nonlocal. This is yet to be @@ -796,6 +826,7 @@ enum CurrentAssignment<'a> { Assign(&'a ast::StmtAssign), AnnAssign(&'a ast::StmtAnnAssign), AugAssign(&'a ast::StmtAugAssign), + For(&'a ast::StmtFor), Named(&'a ast::ExprNamed), Comprehension { node: &'a ast::Comprehension, @@ -822,6 +853,12 @@ impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> { } } +impl<'a> From<&'a ast::StmtFor> for CurrentAssignment<'a> { + fn from(value: &'a ast::StmtFor) -> Self { + Self::For(value) + } +} + impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> { fn from(value: &'a ast::ExprNamed) -> Self { Self::Named(value) diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 68c56f763fb0c..24b4b8e23f108 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -39,6 +39,7 @@ impl<'db> Definition<'db> { pub(crate) enum DefinitionNodeRef<'a> { Import(&'a ast::Alias), ImportFrom(ImportFromDefinitionNodeRef<'a>), + For(ForStmtDefinitionNodeRef<'a>), Function(&'a ast::StmtFunctionDef), Class(&'a 
ast::StmtClassDef), NamedExpression(&'a ast::ExprNamed), @@ -92,6 +93,12 @@ impl<'a> From> for DefinitionNodeRef<'a> { } } +impl<'a> From<ForStmtDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> { + fn from(value: ForStmtDefinitionNodeRef<'a>) -> Self { + Self::For(value) + } +} + impl<'a> From<AssignmentDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> { fn from(node_ref: AssignmentDefinitionNodeRef<'a>) -> Self { Self::Assignment(node_ref) @@ -134,6 +141,12 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> { pub(crate) target: &'a ast::ExprName, } +#[derive(Copy, Clone, Debug)] +pub(crate) struct ForStmtDefinitionNodeRef<'a> { + pub(crate) iterable: &'a ast::Expr, + pub(crate) target: &'a ast::ExprName, +} + #[derive(Copy, Clone, Debug)] pub(crate) struct ComprehensionDefinitionNodeRef<'a> { pub(crate) node: &'a ast::Comprehension, @@ -174,6 +187,12 @@ impl DefinitionNodeRef<'_> { DefinitionNodeRef::AugmentedAssignment(augmented_assignment) => { DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment)) } + DefinitionNodeRef::For(ForStmtDefinitionNodeRef { iterable, target }) => { + DefinitionKind::For(ForStmtDefinitionKind { + iterable: AstNodeRef::new(parsed.clone(), iterable), + target: AstNodeRef::new(parsed, target), + }) + } DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => { DefinitionKind::Comprehension(ComprehensionDefinitionKind { node: AstNodeRef::new(parsed, node), @@ -212,6 +231,10 @@ impl DefinitionNodeRef<'_> { }) => target.into(), Self::AnnotatedAssignment(node) => node.into(), Self::AugmentedAssignment(node) => node.into(), + Self::For(ForStmtDefinitionNodeRef { + iterable: _, + target, + }) => target.into(), Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(), Self::Parameter(node) => match node { ast::AnyParameterRef::Variadic(parameter) => parameter.into(), @@ -232,6 +255,7 @@ pub enum DefinitionKind { Assignment(AssignmentDefinitionKind), AnnotatedAssignment(AstNodeRef), AugmentedAssignment(AstNodeRef), + 
For(ForStmtDefinitionKind), Comprehension(ComprehensionDefinitionKind), Parameter(AstNodeRef), ParameterWithDefault(AstNodeRef), @@ -302,6 +326,22 @@ impl WithItemDefinitionKind { } } +#[derive(Clone, Debug)] +pub struct ForStmtDefinitionKind { + iterable: AstNodeRef<ast::Expr>, + target: AstNodeRef<ast::ExprName>, +} + +impl ForStmtDefinitionKind { + pub(crate) fn iterable(&self) -> &ast::Expr { + self.iterable.node() + } + + pub(crate) fn target(&self) -> &ast::ExprName { + self.target.node() + } +} + #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub(crate) struct DefinitionNodeKey(NodeKey); @@ -347,6 +387,12 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey { } } +impl From<&ast::StmtFor> for DefinitionNodeKey { + fn from(value: &ast::StmtFor) -> Self { + Self(NodeKey::from_node(value)) + } +} + impl From<&ast::Comprehension> for DefinitionNodeKey { fn from(node: &ast::Comprehension) -> Self { Self(NodeKey::from_node(node)) diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 45803a61ce8da..80603b7cf7a5c 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -138,7 +138,6 @@ pub(crate) struct TypeInference<'db> { } impl<'db> TypeInference<'db> { - #[allow(unused)] pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> { self.expressions[&expression] } @@ -317,6 +316,13 @@ impl<'db> TypeInferenceBuilder<'db> { DefinitionKind::AugmentedAssignment(augmented_assignment) => { self.infer_augment_assignment_definition(augmented_assignment.node(), definition); } + DefinitionKind::For(for_statement_definition) => { + self.infer_for_statement_definition( + for_statement_definition.target(), + for_statement_definition.iterable(), + definition, + ); + } DefinitionKind::NamedExpression(named_expression) => { self.infer_named_expression_definition(named_expression.node(), definition); } @@ -865,11 +871,48 @@ impl<'db> 
TypeInferenceBuilder<'db> { } = for_statement; self.infer_expression(iter); - self.infer_expression(target); + // TODO more complex assignment targets + if let ast::Expr::Name(name) = &**target { + self.infer_definition(name); + } else { + self.infer_expression(target); + } self.infer_body(body); self.infer_body(orelse); } + fn infer_for_statement_definition( + &mut self, + target: &ast::ExprName, + iterable: &ast::Expr, + definition: Definition<'db>, + ) { + let expression = self.index.expression(iterable); + let result = infer_expression_types(self.db, expression); + self.extend(result); + let iterable_ty = self + .types + .expression_ty(iterable.scoped_ast_id(self.db, self.scope)); + + // TODO(Alex): only a valid iterable if the *type* of `iterable_ty` has an `__iter__` + // member (dunders are never looked up on an instance) + let _dunder_iter_ty = iterable_ty.member(self.db, &ast::name::Name::from("__iter__")); + + // TODO(Alex): + // - infer the return type of the `__iter__` method, which gives us the iterator + // - lookup the `__next__` method on the iterator + // - infer the return type of the iterator's `__next__` method, + // which gives us the type of the variable being bound here + // (...or the type of the object being unpacked into multiple definitions, if it's something like + // `for k, v in d.items(): ...`) + let loop_var_value_ty = Type::Unknown; + + self.types + .expressions + .insert(target.scoped_ast_id(self.db, self.scope), loop_var_value_ty); + self.types.definitions.insert(definition, loop_var_value_ty); + } + fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) { let ast::StmtWhile { range: _, diff --git a/crates/ruff_benchmark/benches/red_knot.rs b/crates/ruff_benchmark/benches/red_knot.rs index 312b2b2310313..fa927cdc97be0 100644 --- a/crates/ruff_benchmark/benches/red_knot.rs +++ b/crates/ruff_benchmark/benches/red_knot.rs @@ -40,23 +40,8 @@ static EXPECTED_DIAGNOSTICS: &[&str] = &[ "Use double quotes for strings", "Use 
double quotes for strings", "Use double quotes for strings", - "/src/tomllib/_parser.py:153:22: Name 'key' used when not defined.", - "/src/tomllib/_parser.py:153:27: Name 'flag' used when not defined.", - "/src/tomllib/_parser.py:159:16: Name 'k' used when not defined.", - "/src/tomllib/_parser.py:161:25: Name 'k' used when not defined.", - "/src/tomllib/_parser.py:168:16: Name 'k' used when not defined.", - "/src/tomllib/_parser.py:169:22: Name 'k' used when not defined.", - "/src/tomllib/_parser.py:170:25: Name 'k' used when not defined.", - "/src/tomllib/_parser.py:180:16: Name 'k' used when not defined.", - "/src/tomllib/_parser.py:182:31: Name 'k' used when not defined.", - "/src/tomllib/_parser.py:206:16: Name 'k' used when not defined.", - "/src/tomllib/_parser.py:207:22: Name 'k' used when not defined.", - "/src/tomllib/_parser.py:208:25: Name 'k' used when not defined.", "/src/tomllib/_parser.py:330:32: Name 'header' used when not defined.", "/src/tomllib/_parser.py:330:41: Name 'key' used when not defined.", - "/src/tomllib/_parser.py:333:26: Name 'cont_key' used when not defined.", - "/src/tomllib/_parser.py:334:71: Name 'cont_key' used when not defined.", - "/src/tomllib/_parser.py:337:31: Name 'cont_key' used when not defined.", "/src/tomllib/_parser.py:628:75: Name 'e' used when not defined.", "/src/tomllib/_parser.py:686:23: Name 'parse_float' used when not defined.", ];