Skip to content

Commit

Permalink
[red-knot] type narrowing (#12706)
Browse files Browse the repository at this point in the history
Extend the `UseDefMap` to also track which constraints (provided by e.g.
`if` tests) apply to each visible definition.

Uses a custom `BitSet` and `BitSetArray` to track which constraints
apply to which definitions, while keeping data inline as much as
possible.
  • Loading branch information
carljm authored Aug 16, 2024
1 parent a9847af commit 6359e55
Show file tree
Hide file tree
Showing 12 changed files with 1,054 additions and 231 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/red_knot_python_semantic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ salsa = { workspace = true }
tracing = { workspace = true }
rustc-hash = { workspace = true }
hashbrown = { workspace = true }
smallvec = { workspace = true }
static_assertions = { workspace = true }

[build-dependencies]
path-slash = { workspace = true }
Expand Down
146 changes: 77 additions & 69 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@ use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolTable,
};
use crate::semantic_index::use_def::UseDefMap;
use crate::Db;

pub(crate) use self::use_def::UseDefMap;

pub mod ast_ids;
mod builder;
pub mod definition;
pub mod expression;
pub mod symbol;
mod use_def;

pub(crate) use self::use_def::{DefinitionWithConstraints, DefinitionWithConstraintsIterator};

type SymbolMap = hashbrown::HashMap<ScopedSymbolId, (), ()>;

/// Returns the semantic index for `file`.
Expand Down Expand Up @@ -310,12 +311,29 @@ mod tests {
use ruff_text_size::{Ranged, TextRange};

use crate::db::tests::TestDb;
use crate::semantic_index::ast_ids::HasScopedUseId;
use crate::semantic_index::definition::DefinitionKind;
use crate::semantic_index::symbol::{FileScopeId, Scope, ScopeKind, SymbolTable};
use crate::semantic_index::ast_ids::{HasScopedUseId, ScopedUseId};
use crate::semantic_index::definition::{Definition, DefinitionKind};
use crate::semantic_index::symbol::{
FileScopeId, Scope, ScopeKind, ScopedSymbolId, SymbolTable,
};
use crate::semantic_index::use_def::UseDefMap;
use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map};
use crate::Db;

impl UseDefMap<'_> {
fn first_public_definition(&self, symbol: ScopedSymbolId) -> Option<Definition<'_>> {
self.public_definitions(symbol)
.next()
.map(|constrained_definition| constrained_definition.definition)
}

fn first_use_definition(&self, use_id: ScopedUseId) -> Option<Definition<'_>> {
self.use_definitions(use_id)
.next()
.map(|constrained_definition| constrained_definition.definition)
}
}

struct TestCase {
db: TestDb,
file: File,
Expand Down Expand Up @@ -374,9 +392,7 @@ mod tests {
let foo = global_table.symbol_id_by_name("foo").unwrap();

let use_def = use_def_map(&db, scope);
let [definition] = use_def.public_definitions(foo) else {
panic!("expected one definition");
};
let definition = use_def.first_public_definition(foo).unwrap();
assert!(matches!(definition.node(&db), DefinitionKind::Import(_)));
}

Expand Down Expand Up @@ -411,13 +427,13 @@ mod tests {
);

let use_def = use_def_map(&db, scope);
let [definition] = use_def.public_definitions(
global_table
.symbol_id_by_name("foo")
.expect("symbol to exist"),
) else {
panic!("expected one definition");
};
let definition = use_def
.first_public_definition(
global_table
.symbol_id_by_name("foo")
.expect("symbol to exist"),
)
.unwrap();
assert!(matches!(
definition.node(&db),
DefinitionKind::ImportFrom(_)
Expand All @@ -438,11 +454,9 @@ mod tests {
"a symbol used but not defined in a scope should have only the used flag"
);
let use_def = use_def_map(&db, scope);
let [definition] =
use_def.public_definitions(global_table.symbol_id_by_name("x").expect("symbol exists"))
else {
panic!("expected one definition");
};
let definition = use_def
.first_public_definition(global_table.symbol_id_by_name("x").expect("symbol exists"))
.unwrap();
assert!(matches!(
definition.node(&db),
DefinitionKind::Assignment(_)
Expand Down Expand Up @@ -477,11 +491,9 @@ y = 2
assert_eq!(names(&class_table), vec!["x"]);

let use_def = index.use_def_map(class_scope_id);
let [definition] =
use_def.public_definitions(class_table.symbol_id_by_name("x").expect("symbol exists"))
else {
panic!("expected one definition");
};
let definition = use_def
.first_public_definition(class_table.symbol_id_by_name("x").expect("symbol exists"))
.unwrap();
assert!(matches!(
definition.node(&db),
DefinitionKind::Assignment(_)
Expand Down Expand Up @@ -515,13 +527,13 @@ y = 2
assert_eq!(names(&function_table), vec!["x"]);

let use_def = index.use_def_map(function_scope_id);
let [definition] = use_def.public_definitions(
function_table
.symbol_id_by_name("x")
.expect("symbol exists"),
) else {
panic!("expected one definition");
};
let definition = use_def
.first_public_definition(
function_table
.symbol_id_by_name("x")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
definition.node(&db),
DefinitionKind::Assignment(_)
Expand Down Expand Up @@ -557,26 +569,26 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):

let use_def = index.use_def_map(function_scope_id);
for name in ["a", "b", "c", "d"] {
let [definition] = use_def.public_definitions(
function_table
.symbol_id_by_name(name)
.expect("symbol exists"),
) else {
panic!("Expected parameter definition for {name}");
};
let definition = use_def
.first_public_definition(
function_table
.symbol_id_by_name(name)
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
definition.node(&db),
DefinitionKind::ParameterWithDefault(_)
));
}
for name in ["args", "kwargs"] {
let [definition] = use_def.public_definitions(
function_table
.symbol_id_by_name(name)
.expect("symbol exists"),
) else {
panic!("Expected parameter definition for {name}");
};
let definition = use_def
.first_public_definition(
function_table
.symbol_id_by_name(name)
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(definition.node(&db), DefinitionKind::Parameter(_)));
}
}
Expand Down Expand Up @@ -605,22 +617,22 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):

let use_def = index.use_def_map(lambda_scope_id);
for name in ["a", "b", "c", "d"] {
let [definition] = use_def
.public_definitions(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
else {
panic!("Expected parameter definition for {name}");
};
let definition = use_def
.first_public_definition(
lambda_table.symbol_id_by_name(name).expect("symbol exists"),
)
.unwrap();
assert!(matches!(
definition.node(&db),
DefinitionKind::ParameterWithDefault(_)
));
}
for name in ["args", "kwargs"] {
let [definition] = use_def
.public_definitions(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
else {
panic!("Expected parameter definition for {name}");
};
let definition = use_def
.first_public_definition(
lambda_table.symbol_id_by_name(name).expect("symbol exists"),
)
.unwrap();
assert!(matches!(definition.node(&db), DefinitionKind::Parameter(_)));
}
}
Expand Down Expand Up @@ -691,9 +703,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let element_use_id =
element.scoped_use_id(&db, comprehension_scope_id.to_scope_id(&db, file));

let [definition] = use_def.use_definitions(element_use_id) else {
panic!("expected one definition")
};
let definition = use_def.first_use_definition(element_use_id).unwrap();
let DefinitionKind::Comprehension(comprehension) = definition.node(&db) else {
panic!("expected generator definition")
};
Expand Down Expand Up @@ -790,13 +800,13 @@ def func():
assert_eq!(names(&func2_table), vec!["y"]);

let use_def = index.use_def_map(FileScopeId::global());
let [definition] = use_def.public_definitions(
global_table
.symbol_id_by_name("func")
.expect("symbol exists"),
) else {
panic!("expected one definition");
};
let definition = use_def
.first_public_definition(
global_table
.symbol_id_by_name("func")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(definition.node(&db), DefinitionKind::Function(_)));
}

Expand Down Expand Up @@ -897,9 +907,7 @@ class C[T]:
};
let x_use_id = x_use_expr_name.scoped_use_id(&db, scope);
let use_def = use_def_map(&db, scope);
let [definition] = use_def.use_definitions(x_use_id) else {
panic!("expected one definition");
};
let definition = use_def.first_use_definition(x_use_id).unwrap();
let DefinitionKind::Assignment(assignment) = definition.node(&db) else {
panic!("should be an assignment definition")
};
Expand Down
23 changes: 16 additions & 7 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ impl<'db> SemanticIndexBuilder<'db> {
self.current_use_def_map_mut().restore(state);
}

fn flow_merge(&mut self, state: &FlowSnapshot) {
fn flow_merge(&mut self, state: FlowSnapshot) {
self.current_use_def_map_mut().merge(state);
}

Expand Down Expand Up @@ -195,9 +195,16 @@ impl<'db> SemanticIndexBuilder<'db> {
definition
}

fn add_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> {
let expression = self.add_standalone_expression(constraint_node);
self.current_use_def_map_mut().record_constraint(expression);

expression
}

/// Record an expression that needs to be a Salsa ingredient, because we need to infer its type
/// standalone (type narrowing tests, RHS of an assignment.)
fn add_standalone_expression(&mut self, expression_node: &ast::Expr) {
fn add_standalone_expression(&mut self, expression_node: &ast::Expr) -> Expression<'db> {
let expression = Expression::new(
self.db,
self.file,
Expand All @@ -210,6 +217,7 @@ impl<'db> SemanticIndexBuilder<'db> {
);
self.expressions_by_node
.insert(expression_node.into(), expression);
expression
}

fn with_type_params(
Expand Down Expand Up @@ -476,6 +484,7 @@ where
ast::Stmt::If(node) => {
self.visit_expr(&node.test);
let pre_if = self.flow_snapshot();
self.add_constraint(&node.test);
self.visit_body(&node.body);
let mut post_clauses: Vec<FlowSnapshot> = vec![];
for clause in &node.elif_else_clauses {
Expand All @@ -488,7 +497,7 @@ where
self.visit_elif_else_clause(clause);
}
for post_clause_state in post_clauses {
self.flow_merge(&post_clause_state);
self.flow_merge(post_clause_state);
}
let has_else = node
.elif_else_clauses
Expand All @@ -497,7 +506,7 @@ where
if !has_else {
// if there's no else clause, then it's possible we took none of the branches,
// and the pre_if state can reach here
self.flow_merge(&pre_if);
self.flow_merge(pre_if);
}
}
ast::Stmt::While(node) => {
Expand All @@ -515,13 +524,13 @@ where

// We may execute the `else` clause without ever executing the body, so merge in
// the pre-loop state before visiting `else`.
self.flow_merge(&pre_loop);
self.flow_merge(pre_loop);
self.visit_body(&node.orelse);

// Breaking out of a while loop bypasses the `else` clause, so merge in the break
// states after visiting `else`.
for break_state in break_states {
self.flow_merge(&break_state);
self.flow_merge(break_state);
}
}
ast::Stmt::Break(_) => {
Expand Down Expand Up @@ -631,7 +640,7 @@ where
let post_body = self.flow_snapshot();
self.flow_restore(pre_if);
self.visit_expr(orelse);
self.flow_merge(&post_body);
self.flow_merge(post_body);
}
ast::Expr::ListComp(
list_comprehension @ ast::ExprListComp {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub(crate) struct Expression<'db> {
/// The expression node.
#[no_eq]
#[return_ref]
pub(crate) node: AstNodeRef<ast::Expr>,
pub(crate) node_ref: AstNodeRef<ast::Expr>,

#[no_eq]
count: countme::Count<Expression<'static>>,
Expand Down
Loading

0 comments on commit 6359e55

Please sign in to comment.