Skip to content

Commit

Permalink
[red-knot] Add symbols for for loop variables (#13075)
Browse files Browse the repository at this point in the history
## Summary

This PR adds symbols introduced by `for` loops to red-knot:
- `x` in `for x in range(10): pass`
- `x` and `y` in `for x, y in d.items(): pass`
- `a`, `b`, `c` and `d` in `for [((a,), b), (c, d)] in foo: pass`

## Test Plan

Several tests added, and the assertion in the benchmarks has been
updated.

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
  • Loading branch information
AlexWaygood and MichaReiser authored Aug 23, 2024
1 parent 99df859 commit d19fd1b
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 18 deletions.
52 changes: 52 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1095,4 +1095,56 @@ match subject:
vec!["subject", "a", "b", "c", "d", "f", "e", "h", "g", "Foo", "i", "j", "k", "l"]
);
}

#[test]
fn for_loops_single_assignment() {
let TestCase { db, file } = test_case("for x in a: pass");
let scope = global_scope(&db, file);
let global_table = symbol_table(&db, scope);

assert_eq!(&names(&global_table), &["a", "x"]);

let use_def = use_def_map(&db, scope);
let definition = use_def
.first_public_definition(global_table.symbol_id_by_name("x").unwrap())
.unwrap();

assert!(matches!(definition.node(&db), DefinitionKind::For(_)));
}

#[test]
fn for_loops_simple_unpacking() {
let TestCase { db, file } = test_case("for (x, y) in a: pass");
let scope = global_scope(&db, file);
let global_table = symbol_table(&db, scope);

assert_eq!(&names(&global_table), &["a", "x", "y"]);

let use_def = use_def_map(&db, scope);
let x_definition = use_def
.first_public_definition(global_table.symbol_id_by_name("x").unwrap())
.unwrap();
let y_definition = use_def
.first_public_definition(global_table.symbol_id_by_name("y").unwrap())
.unwrap();

assert!(matches!(x_definition.node(&db), DefinitionKind::For(_)));
assert!(matches!(y_definition.node(&db), DefinitionKind::For(_)));
}

#[test]
fn for_loops_complex_unpacking() {
let TestCase { db, file } = test_case("for [((a,) b), (c, d)] in e: pass");
let scope = global_scope(&db, file);
let global_table = symbol_table(&db, scope);

assert_eq!(&names(&global_table), &["e", "a", "b", "c", "d"]);

let use_def = use_def_map(&db, scope);
let definition = use_def
.first_public_definition(global_table.symbol_id_by_name("a").unwrap())
.unwrap();

assert!(matches!(definition.node(&db), DefinitionKind::For(_)));
}
}
39 changes: 38 additions & 1 deletion crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::ast_ids::AstIdsBuilder;
use crate::semantic_index::definition::{
AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef, Definition, DefinitionNodeKey,
DefinitionNodeRef, ImportFromDefinitionNodeRef,
DefinitionNodeRef, ForStmtDefinitionNodeRef, ImportFromDefinitionNodeRef,
};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{
Expand Down Expand Up @@ -578,6 +578,27 @@ where
ast::Stmt::Break(_) => {
self.loop_break_states.push(self.flow_snapshot());
}

ast::Stmt::For(
for_stmt @ ast::StmtFor {
range: _,
is_async: _,
target,
iter,
body,
orelse,
},
) => {
// TODO add control flow similar to `ast::Stmt::While` above
self.add_standalone_expression(iter);
self.visit_expr(iter);
debug_assert!(self.current_assignment.is_none());
self.current_assignment = Some(for_stmt.into());
self.visit_expr(target);
self.current_assignment = None;
self.visit_body(body);
self.visit_body(orelse);
}
_ => {
walk_stmt(self, stmt);
}
Expand Down Expand Up @@ -624,6 +645,15 @@ where
Some(CurrentAssignment::AugAssign(aug_assign)) => {
self.add_definition(symbol, aug_assign);
}
Some(CurrentAssignment::For(node)) => {
self.add_definition(
symbol,
ForStmtDefinitionNodeRef {
iterable: &node.iter,
target: name_node,
},
);
}
Some(CurrentAssignment::Named(named)) => {
// TODO(dhruvmanila): If the current scope is a comprehension, then the
// named expression is implicitly nonlocal. This is yet to be
Expand Down Expand Up @@ -796,6 +826,7 @@ enum CurrentAssignment<'a> {
Assign(&'a ast::StmtAssign),
AnnAssign(&'a ast::StmtAnnAssign),
AugAssign(&'a ast::StmtAugAssign),
For(&'a ast::StmtFor),
Named(&'a ast::ExprNamed),
Comprehension {
node: &'a ast::Comprehension,
Expand All @@ -822,6 +853,12 @@ impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> {
}
}

impl<'a> From<&'a ast::StmtFor> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtFor) -> Self {
Self::For(value)
}
}

impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
fn from(value: &'a ast::ExprNamed) -> Self {
Self::Named(value)
Expand Down
46 changes: 46 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ impl<'db> Definition<'db> {
pub(crate) enum DefinitionNodeRef<'a> {
Import(&'a ast::Alias),
ImportFrom(ImportFromDefinitionNodeRef<'a>),
For(ForStmtDefinitionNodeRef<'a>),
Function(&'a ast::StmtFunctionDef),
Class(&'a ast::StmtClassDef),
NamedExpression(&'a ast::ExprNamed),
Expand Down Expand Up @@ -92,6 +93,12 @@ impl<'a> From<ImportFromDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
}
}

impl<'a> From<ForStmtDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(value: ForStmtDefinitionNodeRef<'a>) -> Self {
Self::For(value)
}
}

impl<'a> From<AssignmentDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: AssignmentDefinitionNodeRef<'a>) -> Self {
Self::Assignment(node_ref)
Expand Down Expand Up @@ -134,6 +141,12 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> {
pub(crate) target: &'a ast::ExprName,
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct ForStmtDefinitionNodeRef<'a> {
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::ExprName,
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::Comprehension,
Expand Down Expand Up @@ -174,6 +187,12 @@ impl DefinitionNodeRef<'_> {
DefinitionNodeRef::AugmentedAssignment(augmented_assignment) => {
DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment))
}
DefinitionNodeRef::For(ForStmtDefinitionNodeRef { iterable, target }) => {
DefinitionKind::For(ForStmtDefinitionKind {
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
})
}
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => {
DefinitionKind::Comprehension(ComprehensionDefinitionKind {
node: AstNodeRef::new(parsed, node),
Expand Down Expand Up @@ -212,6 +231,10 @@ impl DefinitionNodeRef<'_> {
}) => target.into(),
Self::AnnotatedAssignment(node) => node.into(),
Self::AugmentedAssignment(node) => node.into(),
Self::For(ForStmtDefinitionNodeRef {
iterable: _,
target,
}) => target.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(),
Self::Parameter(node) => match node {
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
Expand All @@ -232,6 +255,7 @@ pub enum DefinitionKind {
Assignment(AssignmentDefinitionKind),
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
For(ForStmtDefinitionKind),
Comprehension(ComprehensionDefinitionKind),
Parameter(AstNodeRef<ast::Parameter>),
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
Expand Down Expand Up @@ -302,6 +326,22 @@ impl WithItemDefinitionKind {
}
}

#[derive(Clone, Debug)]
pub struct ForStmtDefinitionKind {
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
}

impl ForStmtDefinitionKind {
pub(crate) fn iterable(&self) -> &ast::Expr {
self.iterable.node()
}

pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
}
}

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub(crate) struct DefinitionNodeKey(NodeKey);

Expand Down Expand Up @@ -347,6 +387,12 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey {
}
}

impl From<&ast::StmtFor> for DefinitionNodeKey {
fn from(value: &ast::StmtFor) -> Self {
Self(NodeKey::from_node(value))
}
}

impl From<&ast::Comprehension> for DefinitionNodeKey {
fn from(node: &ast::Comprehension) -> Self {
Self(NodeKey::from_node(node))
Expand Down
47 changes: 45 additions & 2 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ pub(crate) struct TypeInference<'db> {
}

impl<'db> TypeInference<'db> {
#[allow(unused)]
pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> {
self.expressions[&expression]
}
Expand Down Expand Up @@ -317,6 +316,13 @@ impl<'db> TypeInferenceBuilder<'db> {
DefinitionKind::AugmentedAssignment(augmented_assignment) => {
self.infer_augment_assignment_definition(augmented_assignment.node(), definition);
}
DefinitionKind::For(for_statement_definition) => {
self.infer_for_statement_definition(
for_statement_definition.target(),
for_statement_definition.iterable(),
definition,
);
}
DefinitionKind::NamedExpression(named_expression) => {
self.infer_named_expression_definition(named_expression.node(), definition);
}
Expand Down Expand Up @@ -865,11 +871,48 @@ impl<'db> TypeInferenceBuilder<'db> {
} = for_statement;

self.infer_expression(iter);
self.infer_expression(target);
// TODO more complex assignment targets
if let ast::Expr::Name(name) = &**target {
self.infer_definition(name);
} else {
self.infer_expression(target);
}
self.infer_body(body);
self.infer_body(orelse);
}

fn infer_for_statement_definition(
&mut self,
target: &ast::ExprName,
iterable: &ast::Expr,
definition: Definition<'db>,
) {
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db, expression);
self.extend(result);
let iterable_ty = self
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));

// TODO(Alex): only a valid iterable if the *type* of `iterable_ty` has an `__iter__`
// member (dunders are never looked up on an instance)
let _dunder_iter_ty = iterable_ty.member(self.db, &ast::name::Name::from("__iter__"));

// TODO(Alex):
// - infer the return type of the `__iter__` method, which gives us the iterator
// - lookup the `__next__` method on the iterator
// - infer the return type of the iterator's `__next__` method,
// which gives us the type of the variable being bound here
// (...or the type of the object being unpacked into multiple definitions, if it's something like
// `for k, v in d.items(): ...`)
let loop_var_value_ty = Type::Unknown;

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), loop_var_value_ty);
self.types.definitions.insert(definition, loop_var_value_ty);
}

fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) {
let ast::StmtWhile {
range: _,
Expand Down
15 changes: 0 additions & 15 deletions crates/ruff_benchmark/benches/red_knot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,8 @@ static EXPECTED_DIAGNOSTICS: &[&str] = &[
"Use double quotes for strings",
"Use double quotes for strings",
"Use double quotes for strings",
"/src/tomllib/_parser.py:153:22: Name 'key' used when not defined.",
"/src/tomllib/_parser.py:153:27: Name 'flag' used when not defined.",
"/src/tomllib/_parser.py:159:16: Name 'k' used when not defined.",
"/src/tomllib/_parser.py:161:25: Name 'k' used when not defined.",
"/src/tomllib/_parser.py:168:16: Name 'k' used when not defined.",
"/src/tomllib/_parser.py:169:22: Name 'k' used when not defined.",
"/src/tomllib/_parser.py:170:25: Name 'k' used when not defined.",
"/src/tomllib/_parser.py:180:16: Name 'k' used when not defined.",
"/src/tomllib/_parser.py:182:31: Name 'k' used when not defined.",
"/src/tomllib/_parser.py:206:16: Name 'k' used when not defined.",
"/src/tomllib/_parser.py:207:22: Name 'k' used when not defined.",
"/src/tomllib/_parser.py:208:25: Name 'k' used when not defined.",
"/src/tomllib/_parser.py:330:32: Name 'header' used when not defined.",
"/src/tomllib/_parser.py:330:41: Name 'key' used when not defined.",
"/src/tomllib/_parser.py:333:26: Name 'cont_key' used when not defined.",
"/src/tomllib/_parser.py:334:71: Name 'cont_key' used when not defined.",
"/src/tomllib/_parser.py:337:31: Name 'cont_key' used when not defined.",
"/src/tomllib/_parser.py:628:75: Name 'e' used when not defined.",
"/src/tomllib/_parser.py:686:23: Name 'parse_float' used when not defined.",
];
Expand Down

0 comments on commit d19fd1b

Please sign in to comment.