Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] infer basic (name-based) annotation expressions #13130

Merged
merged 14 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ pub struct FunctionType<'db> {
/// name of the function at definition
pub name: ast::name::Name,

definition: Definition<'db>,

/// types of all decorators on this function
decorators: Vec<Type<'db>>,
}
Expand All @@ -344,6 +346,19 @@ impl<'db> FunctionType<'db> {
pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool {
self.decorators(db).contains(&decorator)
}

/// annotated return type for this function, if any
pub fn returns(&self, db: &'db dyn Db) -> Option<Type<'db>> {
let definition = self.definition(db);
let DefinitionKind::Function(function_stmt_node) = definition.node(db) else {
panic!("Function type definition must have `DefinitionKind::Function`")
};

function_stmt_node
.returns
.as_ref()
.map(|returns| definition_expression_ty(db, definition, returns.as_ref()))
}
}

#[salsa::interned]
Expand Down
232 changes: 224 additions & 8 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ fn infer_definition_types_cycle_recovery<'db>(
_cycle: &salsa::Cycle,
input: Definition<'db>,
) -> TypeInference<'db> {
tracing::trace!("infer_definition_types_cycle_recovery");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should change this to an actual sentence rather than the method name. Tracing messages are user facing (not trace but we might decide to make them user facing in the future).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is a partially-broken temporary crutch just to get us through until we have fixpoint iteration and full deferred resolution of annotations in the necessary places; it should go away. So I'm not too worried about this particular trace message sticking around for a long time. But this is a good point.

let mut inference = TypeInference::default();
inference.definitions.insert(input, Type::Unknown);
inference
Expand Down Expand Up @@ -420,9 +421,7 @@ impl<'db> TypeInferenceBuilder<'db> {

fn infer_region_deferred(&mut self, definition: Definition<'db>) {
match definition.node(self.db) {
DefinitionKind::Function(_function) => {
// TODO self.infer_function_deferred(function.node());
}
DefinitionKind::Function(function) => self.infer_function_deferred(function.node()),
DefinitionKind::Class(class) => self.infer_class_deferred(class.node()),
DefinitionKind::AnnotatedAssignment(_annotated_assignment) => {
// TODO self.infer_annotated_assignment_deferred(annotated_assignment.node());
Expand Down Expand Up @@ -460,7 +459,12 @@ impl<'db> TypeInferenceBuilder<'db> {
let Some(type_params) = function.type_params.as_deref() else {
panic!("function type params scope without type params");
};
self.infer_optional_expression(function.returns.as_deref());

// TODO: this should also be applied to parameter annotations.
if !self.is_stub() {
self.infer_optional_expression(function.returns.as_deref());
}

self.infer_type_parameters(type_params);
self.infer_parameters(&function.parameters);
}
Expand Down Expand Up @@ -549,14 +553,23 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_expression(default);
}

// If there are type params, parameters and returns are evaluated in that scope.
// If there are type params, parameters and returns are evaluated in that scope, that is, in
// `infer_function_type_params`, rather than here.
if type_params.is_none() {
self.infer_parameters(parameters);
self.infer_optional_expression(returns.as_deref());

// TODO: this should also be applied to parameter annotations.
if !self.is_stub() {
self.infer_optional_annotation_expression(returns.as_deref());
}
}

let function_ty =
Type::Function(FunctionType::new(self.db, name.id.clone(), decorator_tys));
let function_ty = Type::Function(FunctionType::new(
self.db,
name.id.clone(),
definition,
decorator_tys,
));

self.types.definitions.insert(definition, function_ty);
}
Expand Down Expand Up @@ -670,6 +683,13 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}

fn infer_function_deferred(&mut self, function: &ast::StmtFunctionDef) {
if self.is_stub() {
self.types.has_deferred = true;
self.infer_optional_annotation_expression(function.returns.as_deref());
}
}

fn infer_class_deferred(&mut self, class: &ast::StmtClassDef) {
if self.is_stub() {
self.types.has_deferred = true;
Expand Down Expand Up @@ -1297,6 +1317,13 @@ impl<'db> TypeInferenceBuilder<'db> {
expression.map(|expr| self.infer_expression(expr))
}

fn infer_optional_annotation_expression(
&mut self,
expr: Option<&ast::Expr>,
) -> Option<Type<'db>> {
expr.map(|expr| self.infer_annotation_expression(expr))
}

fn infer_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
let ty = match expression {
ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _ }) => Type::None,
Expand Down Expand Up @@ -2059,6 +2086,174 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}

/// Annotation expressions.
impl<'db> TypeInferenceBuilder<'db> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the motivation for placing each method in its own impl block?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was not "each method" but "each set of related methods", though at the moment it is just one method for annotation expressions and one for type expressions; that will change. It was just a way to more clearly separate inference of value expressions from inference of annotation expressions and type expressions.

This was Chris' idea and it seemed fine to me, but I'm also fine with just using comment headers to achieve the same separation.

fn infer_annotation_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
// https://typing.readthedocs.io/en/latest/spec/annotations.html#grammar-token-expression-grammar-annotation_expression
match expression {
// TODO: parse the expression and check whether it is a string annotation, since they
// can be annotation expressions distinct from type expressions.
// https://typing.readthedocs.io/en/latest/spec/annotations.html#string-annotations
ast::Expr::StringLiteral(_literal) => Type::Unknown,

// Annotation expressions also get special handling for `*args` and `**kwargs`.
ast::Expr::Starred(starred) => self.infer_starred_expression(starred),

// All other annotation expressions are (possibly) valid type expressions, so handle
// them there instead.
type_expr => self.infer_type_expression(type_expr),
}
}
}

/// Type expressions
impl<'db> TypeInferenceBuilder<'db> {
fn infer_type_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
// https://typing.readthedocs.io/en/latest/spec/annotations.html#grammar-token-expression-grammar-type_expression
// TODO: this does not include any of the special forms, and is only a
// stub of the forms other than a standalone name in scope.

let ty = match expression {
ast::Expr::Name(name) => {
debug_assert!(
name.ctx.is_load(),
"name in a type expression is always 'load' but got: '{:?}'",
name.ctx
);

let instance = self.infer_name_expression(name).instance();
instance
carljm marked this conversation as resolved.
Show resolved Hide resolved
}

ast::Expr::NoneLiteral(_literal) => Type::None,

// TODO: parse the expression and check whether it is a string annotation.
// https://typing.readthedocs.io/en/latest/spec/annotations.html#string-annotations
ast::Expr::StringLiteral(_literal) => Type::Unknown,

// TODO: an Ellipsis literal *on its own* does not have any meaning in annotation
// expressions, but is meaningful in the context of a number of special forms.
ast::Expr::EllipsisLiteral(_literal) => Type::Unknown,

// Other literals do not have meaningful values in the annotation expression context.
// However, we will we want to handle these differently when working with special forms,
// since (e.g.) `123` is not valid in an annotation expression but `Literal[123]` is.
ast::Expr::BytesLiteral(_literal) => Type::Unknown,
ast::Expr::NumberLiteral(_literal) => Type::Unknown,
ast::Expr::BooleanLiteral(_literal) => Type::Unknown,

// Forms which are invalid in the context of annotation expressions: we infer their
// nested expressions as normal expressions, but the type of the top-level expression is
// always `Type::Unknown` in these cases.
ast::Expr::BoolOp(bool_op) => {
self.infer_boolean_expression(bool_op);
Type::Unknown
}
ast::Expr::Named(named) => {
self.infer_named_expression(named);
Type::Unknown
}
ast::Expr::BinOp(binary) => {
self.infer_binary_expression(binary);
Type::Unknown
}
ast::Expr::UnaryOp(unary) => {
self.infer_unary_expression(unary);
Type::Unknown
}
ast::Expr::Lambda(lambda_expression) => {
self.infer_lambda_expression(lambda_expression);
Type::Unknown
}
ast::Expr::If(if_expression) => {
self.infer_if_expression(if_expression);
Type::Unknown
}
ast::Expr::Dict(dict) => {
self.infer_dict_expression(dict);
Type::Unknown
}
ast::Expr::Set(set) => {
self.infer_set_expression(set);
Type::Unknown
}
ast::Expr::ListComp(listcomp) => {
self.infer_list_comprehension_expression(listcomp);
Type::Unknown
}
ast::Expr::SetComp(setcomp) => {
self.infer_set_comprehension_expression(setcomp);
Type::Unknown
}
ast::Expr::DictComp(dictcomp) => {
self.infer_dict_comprehension_expression(dictcomp);
Type::Unknown
}
ast::Expr::Generator(generator) => {
self.infer_generator_expression(generator);
Type::Unknown
}
ast::Expr::Await(await_expression) => {
self.infer_await_expression(await_expression);
Type::Unknown
}
ast::Expr::Yield(yield_expression) => {
self.infer_yield_expression(yield_expression);
Type::Unknown
}
ast::Expr::YieldFrom(yield_from) => {
self.infer_yield_from_expression(yield_from);
Type::Unknown
}
ast::Expr::Compare(compare) => {
self.infer_compare_expression(compare);
Type::Unknown
}
ast::Expr::Call(call_expr) => {
self.infer_call_expression(call_expr);
Type::Unknown
}
ast::Expr::FString(fstring) => {
self.infer_fstring_expression(fstring);
Type::Unknown
}
//
ast::Expr::Attribute(attribute) => {
self.infer_attribute_expression(attribute);
Type::Unknown
}
// TODO: this may be a place we need to revisit with special forms.
ast::Expr::Subscript(subscript) => {
self.infer_subscript_expression(subscript);
Type::Unknown
}
ast::Expr::Starred(starred) => {
self.infer_starred_expression(starred);
Type::Unknown
}
ast::Expr::List(list) => {
self.infer_list_expression(list);
Type::Unknown
}
ast::Expr::Tuple(tuple) => {
self.infer_tuple_expression(tuple);
Type::Unknown
}
ast::Expr::Slice(slice) => {
self.infer_slice_expression(slice);
Type::Unknown
}

ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"),
};

let expr_id = expression.scoped_ast_id(self.db, self.scope);
self.types.expressions.insert(expr_id, ty);

ty
}
}

fn format_import_from_module(level: u32, module: Option<&str>) -> String {
format!(
"{}{}",
Expand Down Expand Up @@ -2593,6 +2788,27 @@ mod tests {
Ok(())
}

#[test]
fn function_return_type() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_file("src/a.py", "def example() -> int: return 42")?;

let mod_file = system_path_to_file(&db, "src/a.py").unwrap();
let ty = global_symbol_ty_by_name(&db, mod_file, "example");
let Type::Function(function) = ty else {
panic!("example is not a function");
};

let returns = function
.returns(&db)
.expect("There is a return type on the function");

assert_eq!(returns.display(&db).to_string(), "int");

Ok(())
}

#[test]
fn resolve_union() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down
Loading