diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index c9b698a13ba04..75bb4468b1139 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -336,6 +336,8 @@ pub struct FunctionType<'db> { /// name of the function at definition pub name: ast::name::Name, + definition: Definition<'db>, + /// types of all decorators on this function decorators: Vec>, } @@ -344,6 +346,19 @@ impl<'db> FunctionType<'db> { pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool { self.decorators(db).contains(&decorator) } + + /// annotated return type for this function, if any + pub fn returns(&self, db: &'db dyn Db) -> Option> { + let definition = self.definition(db); + let DefinitionKind::Function(function_stmt_node) = definition.node(db) else { + panic!("Function type definition must have `DefinitionKind::Function`") + }; + + function_stmt_node + .returns + .as_ref() + .map(|returns| definition_expression_ty(db, definition, returns.as_ref())) + } } #[salsa::interned] diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index c0e2df24a5334..0ce3b7d253e28 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -77,6 +77,7 @@ fn infer_definition_types_cycle_recovery<'db>( _cycle: &salsa::Cycle, input: Definition<'db>, ) -> TypeInference<'db> { + tracing::trace!("infer_definition_types_cycle_recovery"); let mut inference = TypeInference::default(); inference.definitions.insert(input, Type::Unknown); inference @@ -420,9 +421,7 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_region_deferred(&mut self, definition: Definition<'db>) { match definition.node(self.db) { - DefinitionKind::Function(_function) => { - // TODO self.infer_function_deferred(function.node()); - } + DefinitionKind::Function(function) => self.infer_function_deferred(function.node()), DefinitionKind::Class(class) => self.infer_class_deferred(class.node()), DefinitionKind::AnnotatedAssignment(_annotated_assignment) => { // TODO self.infer_annotated_assignment_deferred(annotated_assignment.node()); @@ -460,7 +459,12 @@ impl<'db> TypeInferenceBuilder<'db> { let Some(type_params) = function.type_params.as_deref() else { panic!("function type params scope without type params"); }; - self.infer_optional_expression(function.returns.as_deref()); + + // TODO: this should also be applied to parameter annotations. + if !self.is_stub() { + self.infer_optional_expression(function.returns.as_deref()); + } + self.infer_type_parameters(type_params); self.infer_parameters(&function.parameters); } @@ -549,14 +553,23 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expression(default); } - // If there are type params, parameters and returns are evaluated in that scope. + // If there are type params, parameters and returns are evaluated in that scope, that is, in + // `infer_function_type_params`, rather than here. if type_params.is_none() { self.infer_parameters(parameters); - self.infer_optional_expression(returns.as_deref()); + + // TODO: this should also be applied to parameter annotations. + if !self.is_stub() { + self.infer_optional_annotation_expression(returns.as_deref()); + } } - let function_ty = - Type::Function(FunctionType::new(self.db, name.id.clone(), decorator_tys)); + let function_ty = Type::Function(FunctionType::new( + self.db, + name.id.clone(), + definition, + decorator_tys, + )); self.types.definitions.insert(definition, function_ty); } @@ -670,6 +683,13 @@ impl<'db> TypeInferenceBuilder<'db> { } } + fn infer_function_deferred(&mut self, function: &ast::StmtFunctionDef) { + if self.is_stub() { + self.types.has_deferred = true; + self.infer_optional_annotation_expression(function.returns.as_deref()); + } + } + fn infer_class_deferred(&mut self, class: &ast::StmtClassDef) { if self.is_stub() { self.types.has_deferred = true; @@ -1297,6 +1317,13 @@ impl<'db> TypeInferenceBuilder<'db> { expression.map(|expr| self.infer_expression(expr)) } + fn infer_optional_annotation_expression( + &mut self, + expr: Option<&ast::Expr>, + ) -> Option> { + expr.map(|expr| self.infer_annotation_expression(expr)) + } + fn infer_expression(&mut self, expression: &ast::Expr) -> Type<'db> { let ty = match expression { ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _ }) => Type::None, @@ -2059,6 +2086,173 @@ impl<'db> TypeInferenceBuilder<'db> { } } +/// Annotation expressions. +impl<'db> TypeInferenceBuilder<'db> { + fn infer_annotation_expression(&mut self, expression: &ast::Expr) -> Type<'db> { + // https://typing.readthedocs.io/en/latest/spec/annotations.html#grammar-token-expression-grammar-annotation_expression + match expression { + // TODO: parse the expression and check whether it is a string annotation, since they + // can be annotation expressions distinct from type expressions. + // https://typing.readthedocs.io/en/latest/spec/annotations.html#string-annotations + ast::Expr::StringLiteral(_literal) => Type::Unknown, + + // Annotation expressions also get special handling for `*args` and `**kwargs`. + ast::Expr::Starred(starred) => self.infer_starred_expression(starred), + + // All other annotation expressions are (possibly) valid type expressions, so handle + // them there instead. + type_expr => self.infer_type_expression(type_expr), + } + } +} + +/// Type expressions +impl<'db> TypeInferenceBuilder<'db> { + fn infer_type_expression(&mut self, expression: &ast::Expr) -> Type<'db> { + // https://typing.readthedocs.io/en/latest/spec/annotations.html#grammar-token-expression-grammar-type_expression + // TODO: this does not include any of the special forms, and is only a + // stub of the forms other than a standalone name in scope. + + let ty = match expression { + ast::Expr::Name(name) => { + debug_assert!( + name.ctx.is_load(), + "name in a type expression is always 'load' but got: '{:?}'", + name.ctx + ); + + self.infer_name_expression(name).instance() + } + + ast::Expr::NoneLiteral(_literal) => Type::None, + + // TODO: parse the expression and check whether it is a string annotation. + // https://typing.readthedocs.io/en/latest/spec/annotations.html#string-annotations + ast::Expr::StringLiteral(_literal) => Type::Unknown, + + // TODO: an Ellipsis literal *on its own* does not have any meaning in annotation + // expressions, but is meaningful in the context of a number of special forms. + ast::Expr::EllipsisLiteral(_literal) => Type::Unknown, + + // Other literals do not have meaningful values in the annotation expression context. + // However, we will we want to handle these differently when working with special forms, + // since (e.g.) `123` is not valid in an annotation expression but `Literal[123]` is. + ast::Expr::BytesLiteral(_literal) => Type::Unknown, + ast::Expr::NumberLiteral(_literal) => Type::Unknown, + ast::Expr::BooleanLiteral(_literal) => Type::Unknown, + + // Forms which are invalid in the context of annotation expressions: we infer their + // nested expressions as normal expressions, but the type of the top-level expression is + // always `Type::Unknown` in these cases. + ast::Expr::BoolOp(bool_op) => { + self.infer_boolean_expression(bool_op); + Type::Unknown + } + ast::Expr::Named(named) => { + self.infer_named_expression(named); + Type::Unknown + } + ast::Expr::BinOp(binary) => { + self.infer_binary_expression(binary); + Type::Unknown + } + ast::Expr::UnaryOp(unary) => { + self.infer_unary_expression(unary); + Type::Unknown + } + ast::Expr::Lambda(lambda_expression) => { + self.infer_lambda_expression(lambda_expression); + Type::Unknown + } + ast::Expr::If(if_expression) => { + self.infer_if_expression(if_expression); + Type::Unknown + } + ast::Expr::Dict(dict) => { + self.infer_dict_expression(dict); + Type::Unknown + } + ast::Expr::Set(set) => { + self.infer_set_expression(set); + Type::Unknown + } + ast::Expr::ListComp(listcomp) => { + self.infer_list_comprehension_expression(listcomp); + Type::Unknown + } + ast::Expr::SetComp(setcomp) => { + self.infer_set_comprehension_expression(setcomp); + Type::Unknown + } + ast::Expr::DictComp(dictcomp) => { + self.infer_dict_comprehension_expression(dictcomp); + Type::Unknown + } + ast::Expr::Generator(generator) => { + self.infer_generator_expression(generator); + Type::Unknown + } + ast::Expr::Await(await_expression) => { + self.infer_await_expression(await_expression); + Type::Unknown + } + ast::Expr::Yield(yield_expression) => { + self.infer_yield_expression(yield_expression); + Type::Unknown + } + ast::Expr::YieldFrom(yield_from) => { + self.infer_yield_from_expression(yield_from); + Type::Unknown + } + ast::Expr::Compare(compare) => { + self.infer_compare_expression(compare); + Type::Unknown + } + ast::Expr::Call(call_expr) => { + self.infer_call_expression(call_expr); + Type::Unknown + } + ast::Expr::FString(fstring) => { + self.infer_fstring_expression(fstring); + Type::Unknown + } + // + ast::Expr::Attribute(attribute) => { + self.infer_attribute_expression(attribute); + Type::Unknown + } + // TODO: this may be a place we need to revisit with special forms. + ast::Expr::Subscript(subscript) => { + self.infer_subscript_expression(subscript); + Type::Unknown + } + ast::Expr::Starred(starred) => { + self.infer_starred_expression(starred); + Type::Unknown + } + ast::Expr::List(list) => { + self.infer_list_expression(list); + Type::Unknown + } + ast::Expr::Tuple(tuple) => { + self.infer_tuple_expression(tuple); + Type::Unknown + } + ast::Expr::Slice(slice) => { + self.infer_slice_expression(slice); + Type::Unknown + } + + ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"), + }; + + let expr_id = expression.scoped_ast_id(self.db, self.scope); + self.types.expressions.insert(expr_id, ty); + + ty + } +} + fn format_import_from_module(level: u32, module: Option<&str>) -> String { format!( "{}{}", @@ -2593,6 +2787,27 @@ mod tests { Ok(()) } + #[test] + fn function_return_type() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_file("src/a.py", "def example() -> int: return 42")?; + + let mod_file = system_path_to_file(&db, "src/a.py").unwrap(); + let ty = global_symbol_ty_by_name(&db, mod_file, "example"); + let Type::Function(function) = ty else { + panic!("example is not a function"); + }; + + let returns = function + .returns(&db) + .expect("There is a return type on the function"); + + assert_eq!(returns.display(&db).to_string(), "int"); + + Ok(()) + } + #[test] fn resolve_union() -> anyhow::Result<()> { let mut db = setup_db();