Skip to content

Commit

Permalink
Pull all types in corpus tests (#12919)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaReiser authored Aug 17, 2024
1 parent 25f5ae4 commit dd0a7ec
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 46 deletions.
30 changes: 20 additions & 10 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use ruff_index::IndexVec;
use ruff_python_ast as ast;
use ruff_python_ast::name::Name;
use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor};
use ruff_python_ast::AnyParameterRef;

use crate::ast_node_ref::AstNodeRef;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
Expand Down Expand Up @@ -309,6 +310,23 @@ impl<'db> SemanticIndexBuilder<'db> {
}
}

fn declare_parameter(&mut self, parameter: AnyParameterRef) {
let symbol =
self.add_or_update_symbol(parameter.name().id().clone(), SymbolFlags::IS_DEFINED);

let definition = self.add_definition(symbol, parameter);

if let AnyParameterRef::NonVariadic(with_default) = parameter {
// Insert a mapping from the parameter to the same definition.
// This ensures that calling `HasTy::ty` on the inner parameter returns
// a valid type (and doesn't panic)
self.definitions_by_node.insert(
DefinitionNodeRef::from(AnyParameterRef::Variadic(&with_default.parameter)).key(),
definition,
);
}
}

pub(super) fn build(mut self) -> SemanticIndex<'db> {
let module = self.module;
self.visit_body(module.suite());
Expand Down Expand Up @@ -399,11 +417,7 @@ where

// Add symbols and definitions for the parameters to the function scope.
for parameter in &*function_def.parameters {
let symbol = builder.add_or_update_symbol(
parameter.name().id().clone(),
SymbolFlags::IS_DEFINED,
);
builder.add_definition(symbol, parameter);
builder.declare_parameter(parameter);
}

builder.visit_body(&function_def.body);
Expand Down Expand Up @@ -618,11 +632,7 @@ where
// Add symbols and definitions for the parameters to the lambda scope.
if let Some(parameters) = &lambda.parameters {
for parameter in &**parameters {
let symbol = self.add_or_update_symbol(
parameter.name().id().clone(),
SymbolFlags::IS_DEFINED,
);
self.add_definition(symbol, parameter);
self.declare_parameter(parameter);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ impl AssignmentDefinitionKind {
pub(crate) fn assignment(&self) -> &ast::StmtAssign {
self.assignment.node()
}

pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
}
}

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
Expand Down
39 changes: 17 additions & 22 deletions crates/red_knot_python_semantic/src/semantic_model.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use ruff_db::files::{File, FilePath};
use ruff_db::source::line_index;
use ruff_python_ast as ast;
use ruff_python_ast::{Expr, ExpressionRef, StmtClassDef};
use ruff_python_ast::{Expr, ExpressionRef};
use ruff_source_file::LineIndex;

use crate::module_name::ModuleName;
Expand Down Expand Up @@ -147,29 +147,24 @@ impl HasTy for ast::Expr {
}
}

impl HasTy for ast::StmtFunctionDef {
fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
let index = semantic_index(model.db, model.file);
let definition = index.definition(self);
definition_ty(model.db, definition)
}
}

impl HasTy for StmtClassDef {
fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
let index = semantic_index(model.db, model.file);
let definition = index.definition(self);
definition_ty(model.db, definition)
}
macro_rules! impl_definition_has_ty {
($ty: ty) => {
impl HasTy for $ty {
#[inline]
fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
let index = semantic_index(model.db, model.file);
let definition = index.definition(self);
definition_ty(model.db, definition)
}
}
};
}

impl HasTy for ast::Alias {
fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
let index = semantic_index(model.db, model.file);
let definition = index.definition(self);
definition_ty(model.db, definition)
}
}
impl_definition_has_ty!(ast::StmtFunctionDef);
impl_definition_has_ty!(ast::StmtClassDef);
impl_definition_has_ty!(ast::Alias);
impl_definition_has_ty!(ast::Parameter);
impl_definition_has_ty!(ast::ParameterWithDefault);

#[cfg(test)]
mod tests {
Expand Down
27 changes: 22 additions & 5 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use salsa::plumbing::AsId;
use ruff_db::files::File;
use ruff_db::parsed::parsed_module;
use ruff_python_ast as ast;
use ruff_python_ast::{ExprContext, TypeParams};
use ruff_python_ast::{Expr, ExprContext};

use crate::builtins::builtins_scope;
use crate::module_name::ModuleName;
Expand Down Expand Up @@ -294,7 +294,11 @@ impl<'db> TypeInferenceBuilder<'db> {
);
}
DefinitionKind::Assignment(assignment) => {
self.infer_assignment_definition(assignment.assignment(), definition);
self.infer_assignment_definition(
assignment.target(),
assignment.assignment(),
definition,
);
}
DefinitionKind::AnnotatedAssignment(annotated_assignment) => {
self.infer_annotated_assignment_definition(annotated_assignment.node(), definition);
Expand Down Expand Up @@ -706,6 +710,7 @@ impl<'db> TypeInferenceBuilder<'db> {

fn infer_assignment_definition(
&mut self,
target: &ast::ExprName,
assignment: &ast::StmtAssign,
definition: Definition<'db>,
) {
Expand All @@ -715,6 +720,9 @@ impl<'db> TypeInferenceBuilder<'db> {
let value_ty = self
.types
.expression_ty(assignment.value.scoped_ast_id(self.db, self.scope));
self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), value_ty);
self.types.definitions.insert(definition, value_ty);
}

Expand Down Expand Up @@ -999,6 +1007,9 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::Expr::NumberLiteral(literal) => self.infer_number_literal_expression(literal),
ast::Expr::BooleanLiteral(literal) => self.infer_boolean_literal_expression(literal),
ast::Expr::StringLiteral(literal) => self.infer_string_literal_expression(literal),
ast::Expr::BytesLiteral(bytes_literal) => {
self.infer_bytes_literal_expression(bytes_literal)
}
ast::Expr::FString(fstring) => self.infer_fstring_expression(fstring),
ast::Expr::EllipsisLiteral(literal) => self.infer_ellipsis_literal_expression(literal),
ast::Expr::Tuple(tuple) => self.infer_tuple_expression(tuple),
Expand All @@ -1025,8 +1036,7 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression),
ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from),
ast::Expr::Await(await_expression) => self.infer_await_expression(await_expression),

_ => todo!("expression type resolution for {:?}", expression),
Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"),
};

let expr_id = expression.scoped_ast_id(self.db, self.scope);
Expand Down Expand Up @@ -1063,6 +1073,12 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::Unknown
}

#[allow(clippy::unused_self)]
fn infer_bytes_literal_expression(&mut self, _literal: &ast::ExprBytesLiteral) -> Type<'db> {
// TODO
Type::Unknown
}

fn infer_fstring_expression(&mut self, fstring: &ast::ExprFString) -> Type<'db> {
let ast::ExprFString { range: _, value } = fstring;

Expand Down Expand Up @@ -1630,7 +1646,7 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::Unknown
}

fn infer_type_parameters(&mut self, type_parameters: &TypeParams) {
fn infer_type_parameters(&mut self, type_parameters: &ast::TypeParams) {
let ast::TypeParams {
range: _,
type_params,
Expand Down Expand Up @@ -1677,6 +1693,7 @@ impl<'db> TypeInferenceBuilder<'db> {
#[cfg(test)]
mod tests {
use anyhow::Context;

use ruff_db::files::{system_path_to_file, File};
use ruff_db::parsed::parsed_module;
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
Expand Down
106 changes: 97 additions & 9 deletions crates/red_knot_workspace/tests/check.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use red_knot_python_semantic::{ProgramSettings, PythonVersion, SearchPathSettings};
use red_knot_python_semantic::{
HasTy, ProgramSettings, PythonVersion, SearchPathSettings, SemanticModel,
};
use red_knot_workspace::db::RootDatabase;
use red_knot_workspace::lint::lint_semantic;
use red_knot_workspace::workspace::WorkspaceMetadata;
use ruff_db::files::system_path_to_file;
use ruff_db::system::{OsSystem, SystemPathBuf};
use ruff_db::files::{system_path_to_file, File};
use ruff_db::parsed::parsed_module;
use ruff_db::system::{OsSystem, SystemPath, SystemPathBuf};
use ruff_python_ast::visitor::source_order;
use ruff_python_ast::visitor::source_order::SourceOrderVisitor;
use ruff_python_ast::{Alias, Expr, Parameter, ParameterWithDefault, Stmt};
use std::fs;
use std::path::PathBuf;

Expand All @@ -28,17 +33,100 @@ fn setup_db(workspace_root: SystemPathBuf) -> anyhow::Result<RootDatabase> {
#[allow(clippy::print_stdout)]
fn corpus_no_panic() -> anyhow::Result<()> {
let corpus = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("resources/test/corpus");
let system_corpus =
SystemPathBuf::from_path_buf(corpus.clone()).expect("corpus path to be UTF8");
let db = setup_db(system_corpus.clone())?;
let system_corpus = SystemPath::from_std_path(&corpus).expect("corpus path to be UTF8");
let db = setup_db(system_corpus.to_path_buf())?;

for path in fs::read_dir(&corpus).expect("corpus to be a directory") {
let path = path.expect("path to not be an error").path();
println!("checking {path:?}");
let path = SystemPathBuf::from_path_buf(path.clone()).expect("path to be UTF-8");
// this test is only asserting that we can run the lint without a panic
// this test is only asserting that we can pull every expression type without a panic
// (and some non-expressions that clearly define a single type)
let file = system_path_to_file(&db, path).expect("file to exist");
lint_semantic(&db, file);

pull_types(&db, file);
}
Ok(())
}

fn pull_types(db: &RootDatabase, file: File) {
let mut visitor = PullTypesVisitor::new(db, file);

let ast = parsed_module(db, file);

visitor.visit_body(ast.suite());
}

struct PullTypesVisitor<'db> {
model: SemanticModel<'db>,
}

impl<'db> PullTypesVisitor<'db> {
fn new(db: &'db RootDatabase, file: File) -> Self {
Self {
model: SemanticModel::new(db, file),
}
}
}

impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> {
fn visit_stmt(&mut self, stmt: &Stmt) {
match stmt {
Stmt::FunctionDef(function) => {
let _ty = function.ty(&self.model);
}
Stmt::ClassDef(class) => {
let _ty = class.ty(&self.model);
}
Stmt::AnnAssign(_)
| Stmt::Return(_)
| Stmt::Delete(_)
| Stmt::Assign(_)
| Stmt::AugAssign(_)
| Stmt::TypeAlias(_)
| Stmt::For(_)
| Stmt::While(_)
| Stmt::If(_)
| Stmt::With(_)
| Stmt::Match(_)
| Stmt::Raise(_)
| Stmt::Try(_)
| Stmt::Assert(_)
| Stmt::Import(_)
| Stmt::ImportFrom(_)
| Stmt::Global(_)
| Stmt::Nonlocal(_)
| Stmt::Expr(_)
| Stmt::Pass(_)
| Stmt::Break(_)
| Stmt::Continue(_)
| Stmt::IpyEscapeCommand(_) => {}
}

source_order::walk_stmt(self, stmt);
}

fn visit_expr(&mut self, expr: &Expr) {
let _ty = expr.ty(&self.model);

source_order::walk_expr(self, expr);
}

fn visit_parameter(&mut self, parameter: &Parameter) {
let _ty = parameter.ty(&self.model);

source_order::walk_parameter(self, parameter);
}

fn visit_parameter_with_default(&mut self, parameter_with_default: &ParameterWithDefault) {
let _ty = parameter_with_default.ty(&self.model);

source_order::walk_parameter_with_default(self, parameter_with_default);
}

fn visit_alias(&mut self, alias: &Alias) {
let _ty = alias.ty(&self.model);

source_order::walk_alias(self, alias);
}
}

0 comments on commit dd0a7ec

Please sign in to comment.