From da65a5ebd49247926a8ea40f82cb520ce78de83a Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sun, 5 Apr 2026 23:59:09 +0900 Subject: [PATCH] fix --- pyrefly/lib/lsp/wasm/completion.rs | 6 + pyrefly/lib/state/lsp/dict_completions.rs | 325 ++++++++++++++++++++-- pyrefly/lib/test/lsp/completion.rs | 62 +++++ 3 files changed, 376 insertions(+), 17 deletions(-) diff --git a/pyrefly/lib/lsp/wasm/completion.rs b/pyrefly/lib/lsp/wasm/completion.rs index e746c730f4..1429b20a18 100644 --- a/pyrefly/lib/lsp/wasm/completion.rs +++ b/pyrefly/lib/lsp/wasm/completion.rs @@ -1105,6 +1105,12 @@ impl Transaction<'_> { &mut result, in_string_literal, ); + self.add_dict_value_literal_completions( + handle, + mod_module.as_ref(), + position, + &mut result, + ); let dict_key_claimed = self.add_dict_key_completions( handle, mod_module.as_ref(), diff --git a/pyrefly/lib/state/lsp/dict_completions.rs b/pyrefly/lib/state/lsp/dict_completions.rs index 8857417758..ff47e9adc9 100644 --- a/pyrefly/lib/state/lsp/dict_completions.rs +++ b/pyrefly/lib/state/lsp/dict_completions.rs @@ -73,6 +73,61 @@ impl DictKeyLiteralContext { } impl<'a> Transaction<'a> { + fn named_target_type(&self, handle: &Handle, expr: &Expr) -> Option { + let Expr::Name(name) = expr else { + return None; + }; + let short_id = ShortIdentifier::expr_name(name); + let bindings = self.get_bindings(handle)?; + let bound_key = Key::BoundName(short_id); + if bindings.is_valid_key(&bound_key) { + return self.get_type(handle, &bound_key); + } + let def_key = Key::Definition(short_id); + if bindings.is_valid_key(&def_key) { + self.get_type(handle, &def_key) + } else { + None + } + } + + fn dict_literal_expected_type( + &self, + handle: &Handle, + module: &ModModule, + dict: &ExprDict, + ) -> Option { + for node in Ast::locate_node(module, dict.range().start()) { + match node { + AnyNodeRef::StmtAnnAssign(assign) + if assign + .value + .as_ref() + .is_some_and(|value| value.range() == dict.range()) => + { + return self.named_target_type(handle, assign.target.as_ref()); + } + AnyNodeRef::StmtAssign(assign) + if assign.value.range() == dict.range() && assign.targets.len() == 1 => + { + return self.named_target_type(handle, &assign.targets[0]); + } + _ => {} + } + } + None + } + + fn dict_literal_contextual_type( + &self, + handle: &Handle, + module: &ModModule, + dict: &ExprDict, + ) -> Option { + self.dict_literal_expected_type(handle, module, dict) + .or_else(|| self.get_type_trace(handle, dict.range())) + } + fn type_contains_typed_dict(ty: &Type) -> bool { match ty { Type::TypedDict(_) | Type::PartialTypedDict(_) => true, @@ -83,6 +138,130 @@ impl<'a> Transaction<'a> { } } + fn typed_dict_members(base_type: Type) -> Vec { + let mut members = Vec::new(); + let mut stack = vec![base_type]; + while let Some(ty) = stack.pop() { + match ty { + Type::TypedDict(_) | Type::PartialTypedDict(_) => members.push(ty), + Type::Union(box Union { + members: union_members, + .. + }) => { + stack.extend(union_members.into_iter()); + } + _ => {} + } + } + members + } + + fn typed_dict_field_type_in_member<'b>( + solver: &crate::alt::answers_solver::AnswersSolver< + crate::state::lsp::TransactionHandle<'b>, + >, + member: &Type, + key: &str, + ) -> Option { + let typed_dict = match member { + Type::TypedDict(td) | Type::PartialTypedDict(td) => td, + _ => return None, + }; + solver + .type_order() + .typed_dict_fields(typed_dict) + .iter() + .find_map(|(name, field)| (name == key).then(|| field.ty.clone())) + } + + fn narrowed_typed_dict_members_for_dict_literal( + &self, + handle: &Handle, + module: &ModModule, + dict: &ExprDict, + skip_key_range: Option, + skip_value_range: Option, + ) -> Option> { + let base_type = self.dict_literal_contextual_type(handle, module, dict)?; + self.ad_hoc_solve(handle, "dict_literal_typed_dict_members", |solver| { + let members = Self::typed_dict_members(base_type); + if members.is_empty() { + return Vec::new(); + } + let narrowed = members + .iter() + .filter(|member| { + dict.items.iter().all(|item| { + let Some(key_expr) = item.key.as_ref() else { + return true; + }; + let value_expr = &item.value; + let Expr::StringLiteral(key_lit) = key_expr else { + return true; + }; + if skip_key_range == Some(key_lit.range()) + || skip_value_range == Some(value_expr.range()) + { + return true; + } + let Some(field_ty) = Self::typed_dict_field_type_in_member( + &solver, + member, + key_lit.value.to_str(), + ) else { + return false; + }; + let Some(value_ty) = self.get_type_trace(handle, value_expr.range()) else { + return true; + }; + solver.is_subset_eq(&value_ty, &field_ty) + }) + }) + .cloned() + .collect::>(); + if narrowed.is_empty() { + members + } else { + narrowed + } + }) + } + + fn typed_dict_field_type_from_members( + &self, + handle: &Handle, + members: Vec, + key: &str, + ) -> Option { + self.ad_hoc_solve(handle, "typed_dict_field_type", |solver| { + let field_types = members + .iter() + .filter_map(|member| Self::typed_dict_field_type_in_member(&solver, member, key)) + .collect::>(); + match field_types.len() { + 0 => None, + 1 => field_types.into_iter().next(), + _ => Some(solver.unions(field_types)), + } + }) + .flatten() + } + + fn dict_literal_present_keys( + dict: &ExprDict, + skip_key_range: Option, + ) -> BTreeMap { + dict.items + .iter() + .filter_map(|item| { + let Expr::StringLiteral(lit) = item.key.as_ref()? else { + return None; + }; + (skip_key_range != Some(lit.range())).then(|| (lit.value.to_string(), ())) + }) + .collect() + } + fn expr_has_typed_dict_type(&self, handle: &Handle, expr: &Expr) -> bool { self.get_type_trace(handle, expr.range()) .map(|ty| Self::type_contains_typed_dict(&ty)) @@ -244,6 +423,58 @@ impl<'a> Transaction<'a> { best.map(|(_, _, dict, literal)| (dict, literal)) } + fn dict_literal_value_string_literal_at( + module: &ModModule, + position: TextSize, + ) -> Option<(ExprDict, ExprStringLiteral, ExprStringLiteral)> { + let nodes = Ast::locate_node(module, position); + let mut best: Option<(u8, TextSize, ExprDict, ExprStringLiteral, ExprStringLiteral)> = None; + for node in nodes { + let AnyNodeRef::ExprDict(dict) = node else { + continue; + }; + let mut best_in_dict: Option<(u8, TextSize, ExprStringLiteral, ExprStringLiteral)> = + None; + for item in &dict.items { + let Some(Expr::StringLiteral(key_lit)) = item.key.as_ref() else { + continue; + }; + let Expr::StringLiteral(value_lit) = &item.value else { + continue; + }; + let (priority, dist) = Self::string_literal_priority(position, value_lit.range()); + let should_update = match &best_in_dict { + Some((best_prio, best_dist, _, _)) => { + priority < *best_prio || (priority == *best_prio && dist < *best_dist) + } + None => true, + }; + if should_update { + best_in_dict = Some((priority, dist, key_lit.clone(), value_lit.clone())); + if priority == 0 && dist == TextSize::from(0) { + break; + } + } + } + let Some((priority, dist, key_lit, value_lit)) = best_in_dict else { + continue; + }; + let should_update = match &best { + Some((best_prio, best_dist, _, _, _)) => { + priority < *best_prio || (priority == *best_prio && dist < *best_dist) + } + None => true, + }; + if should_update { + best = Some((priority, dist, dict.clone(), key_lit, value_lit)); + if priority == 0 && dist == TextSize::from(0) { + break; + } + } + } + best.map(|(_, _, dict, key_lit, value_lit)| (dict, key_lit, value_lit)) + } + fn expression_facets(expr: &Expr) -> Option<(Identifier, Vec)> { let mut facets = Vec::new(); let mut current = expr; @@ -279,25 +510,52 @@ impl<'a> Transaction<'a> { ) -> Option> { self.ad_hoc_solve(handle, "typed_dict_keys", |solver| { let mut map = BTreeMap::new(); - let mut stack = vec![base_type]; - while let Some(ty) = stack.pop() { - match ty { - Type::TypedDict(td) | Type::PartialTypedDict(td) => { - for (name, field) in solver.type_order().typed_dict_fields(&td) { - map.entry(name.to_string()) - .or_insert_with(|| field.ty.clone()); - } - } - Type::Union(box Union { members, .. }) => { - stack.extend(members.into_iter()); - } - _ => {} + for member in Self::typed_dict_members(base_type) { + let typed_dict = match member { + Type::TypedDict(td) | Type::PartialTypedDict(td) => td, + _ => continue, + }; + for (name, field) in solver.type_order().typed_dict_fields(&typed_dict) { + map.entry(name.to_string()) + .or_insert_with(|| field.ty.clone()); } } map }) } + pub(crate) fn add_dict_value_literal_completions( + &self, + handle: &Handle, + module: &ModModule, + position: TextSize, + completions: &mut Vec, + ) { + let Some((dict, key_lit, value_lit)) = + Self::dict_literal_value_string_literal_at(module, position) + else { + return; + }; + if position < value_lit.range().start() || position > value_lit.range().end() { + return; + } + let Some(members) = self.narrowed_typed_dict_members_for_dict_literal( + handle, + module, + &dict, + Some(key_lit.range()), + Some(value_lit.range()), + ) else { + return; + }; + let Some(field_ty) = + self.typed_dict_field_type_from_members(handle, members, key_lit.value.to_str()) + else { + return; + }; + Self::add_literal_completions_from_type(&field_ty, completions, true); + } + /// Adds dict key completions for the given position. Returns `true` if this function /// claimed the position (i.e., we are inside a dict/TypedDict key string literal), in /// which case the caller should skip overload-based literal completions to avoid showing @@ -366,12 +624,45 @@ impl<'a> Transaction<'a> { } } - // For key access we query the container expression; for literals we query the - // literal itself to pick up contextual TypedDict typing from assignments. - if let Some(base_type) = self.get_type_trace(handle, context.base_range()) - && let Some(typed_keys) = self.collect_typed_dict_keys(handle, base_type) + let dict_literal_members = match &context { + DictKeyLiteralContext::DictLiteral { dict, literal } => self + .narrowed_typed_dict_members_for_dict_literal( + handle, + module, + dict, + Some(literal.range()), + None, + ), + DictKeyLiteralContext::KeyAccess { .. } => None, + }; + + // For key access we query the container expression; for literals we recover the + // contextual type because incomplete dict literals may infer as plain `dict[...]`. + if let Some(base_type) = match (&context, dict_literal_members.as_ref()) { + (DictKeyLiteralContext::DictLiteral { .. }, Some(members)) => self + .ad_hoc_solve( + handle, + "dict_literal_typed_dict_union", + |solver| match members.len() { + 0 => None, + 1 => members.first().cloned(), + _ => Some(solver.unions(members.clone())), + }, + ) + .flatten(), + _ => self.get_type_trace(handle, context.base_range()), + } && let Some(typed_keys) = self.collect_typed_dict_keys(handle, base_type) { + let present_keys = match &context { + DictKeyLiteralContext::DictLiteral { dict, literal } => { + Self::dict_literal_present_keys(dict, Some(literal.range())) + } + DictKeyLiteralContext::KeyAccess { .. } => BTreeMap::new(), + }; for (key, ty) in typed_keys { + if present_keys.contains_key(&key) { + continue; + } let entry = suggestions.entry(key).or_insert(None); if entry.is_none() { *entry = Some(ty); diff --git a/pyrefly/lib/test/lsp/completion.rs b/pyrefly/lib/test/lsp/completion.rs index 3e2617ebe2..996d8851cc 100644 --- a/pyrefly/lib/test/lsp/completion.rs +++ b/pyrefly/lib/test/lsp/completion.rs @@ -350,6 +350,68 @@ cfg: Config = {"": 1} assert!(report.contains("- (Field) name: str")); } +#[test] +fn dict_value_completion_from_discriminated_typed_dict_union_literal() { + let code = r#" +from typing import Literal, TypedDict + +class Foo(TypedDict): + kind: Literal["foo"] + foo_value: int + +class Bar(TypedDict): + kind: Literal["bar"] + bar_value: str + +type FooBar = Foo | Bar + +item: FooBar = { + "kind": "|", +# ^ +} +"#; + let report = + get_batched_lsp_operations_report_allow_error(&[("main", code)], get_default_test_report()); + let report = strip_ansi(&report); + assert!( + report.contains("- (Value) 'bar': Literal['bar']"), + "{report}" + ); + assert!( + report.contains("- (Value) 'foo': Literal['foo']"), + "{report}" + ); +} + +#[test] +fn dict_key_completion_from_discriminated_typed_dict_union_literal() { + let code = r#" +from typing import Literal, TypedDict + +class Foo(TypedDict): + kind: Literal["foo"] + foo_value: int + +class Bar(TypedDict): + kind: Literal["bar"] + bar_value: str + +type FooBar = Foo | Bar + +item: FooBar = { + "kind": "foo", + "": 0, +# ^ +} +"#; + let (handles, state) = mk_multi_file_state(&[("main", code)], Require::Exports, false); + let handle = handles.get("main").unwrap(); + let position = extract_cursors_for_test(code)[0]; + let txn = state.transaction(); + let labels = dict_field_labels(&txn, handle, position); + assert_eq!(labels, vec!["foo_value".to_owned()]); +} + #[test] fn dot_complete_with_deprecated_method() { let code = r#"