Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyrefly/lib/lsp/wasm/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,12 @@ impl Transaction<'_> {
&mut result,
in_string_literal,
);
self.add_dict_value_literal_completions(
handle,
mod_module.as_ref(),
position,
&mut result,
);
let dict_key_claimed = self.add_dict_key_completions(
handle,
mod_module.as_ref(),
Expand Down
325 changes: 308 additions & 17 deletions pyrefly/lib/state/lsp/dict_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,61 @@ impl DictKeyLiteralContext {
}

impl<'a> Transaction<'a> {
fn named_target_type(&self, handle: &Handle, expr: &Expr) -> Option<Type> {
let Expr::Name(name) = expr else {
return None;
};
let short_id = ShortIdentifier::expr_name(name);
let bindings = self.get_bindings(handle)?;
let bound_key = Key::BoundName(short_id);
if bindings.is_valid_key(&bound_key) {
return self.get_type(handle, &bound_key);
}
let def_key = Key::Definition(short_id);
if bindings.is_valid_key(&def_key) {
self.get_type(handle, &def_key)
} else {
None
}
}

fn dict_literal_expected_type(
&self,
handle: &Handle,
module: &ModModule,
dict: &ExprDict,
) -> Option<Type> {
for node in Ast::locate_node(module, dict.range().start()) {
match node {
AnyNodeRef::StmtAnnAssign(assign)
if assign
.value
.as_ref()
.is_some_and(|value| value.range() == dict.range()) =>
{
return self.named_target_type(handle, assign.target.as_ref());
}
AnyNodeRef::StmtAssign(assign)
if assign.value.range() == dict.range() && assign.targets.len() == 1 =>
{
return self.named_target_type(handle, &assign.targets[0]);
}
_ => {}
}
}
None
}

fn dict_literal_contextual_type(
&self,
handle: &Handle,
module: &ModModule,
dict: &ExprDict,
) -> Option<Type> {
self.dict_literal_expected_type(handle, module, dict)
.or_else(|| self.get_type_trace(handle, dict.range()))
}

fn type_contains_typed_dict(ty: &Type) -> bool {
match ty {
Type::TypedDict(_) | Type::PartialTypedDict(_) => true,
Expand All @@ -83,6 +138,130 @@ impl<'a> Transaction<'a> {
}
}

fn typed_dict_members(base_type: Type) -> Vec<Type> {
let mut members = Vec::new();
let mut stack = vec![base_type];
while let Some(ty) = stack.pop() {
match ty {
Type::TypedDict(_) | Type::PartialTypedDict(_) => members.push(ty),
Type::Union(box Union {
members: union_members,
..
}) => {
stack.extend(union_members.into_iter());
}
_ => {}
}
}
members
}

fn typed_dict_field_type_in_member<'b>(
solver: &crate::alt::answers_solver::AnswersSolver<
crate::state::lsp::TransactionHandle<'b>,
>,
member: &Type,
key: &str,
) -> Option<Type> {
let typed_dict = match member {
Type::TypedDict(td) | Type::PartialTypedDict(td) => td,
_ => return None,
};
solver
.type_order()
.typed_dict_fields(typed_dict)
.iter()
.find_map(|(name, field)| (name == key).then(|| field.ty.clone()))
}
Comment on lines +159 to +175
Copy link

Copilot AI Apr 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typed_dict_field_type_in_member linearly scans typed_dict_fields(...) for every (member, item) pair during narrowing. Since completions run on essentially every keystroke, this can become noticeably expensive for large TypedDicts or large unions. Consider building a per-member field-name -> type map once (or otherwise avoiding repeated linear scans) inside the narrowing / key-collection paths.

Copilot uses AI. Check for mistakes.

fn narrowed_typed_dict_members_for_dict_literal(
&self,
handle: &Handle,
module: &ModModule,
dict: &ExprDict,
skip_key_range: Option<TextRange>,
skip_value_range: Option<TextRange>,
) -> Option<Vec<Type>> {
let base_type = self.dict_literal_contextual_type(handle, module, dict)?;
self.ad_hoc_solve(handle, "dict_literal_typed_dict_members", |solver| {
let members = Self::typed_dict_members(base_type);
if members.is_empty() {
return Vec::new();
}
let narrowed = members
.iter()
.filter(|member| {
dict.items.iter().all(|item| {
let Some(key_expr) = item.key.as_ref() else {
return true;
};
let value_expr = &item.value;
let Expr::StringLiteral(key_lit) = key_expr else {
return true;
};
if skip_key_range == Some(key_lit.range())
|| skip_value_range == Some(value_expr.range())
{
return true;
}
let Some(field_ty) = Self::typed_dict_field_type_in_member(
&solver,
member,
key_lit.value.to_str(),
) else {
return false;
};
let Some(value_ty) = self.get_type_trace(handle, value_expr.range()) else {
return true;
};
solver.is_subset_eq(&value_ty, &field_ty)
})
})
.cloned()
.collect::<Vec<_>>();
if narrowed.is_empty() {
members
} else {
narrowed
}
})
}

fn typed_dict_field_type_from_members(
&self,
handle: &Handle,
members: Vec<Type>,
key: &str,
) -> Option<Type> {
self.ad_hoc_solve(handle, "typed_dict_field_type", |solver| {
let field_types = members
.iter()
.filter_map(|member| Self::typed_dict_field_type_in_member(&solver, member, key))
.collect::<Vec<_>>();
match field_types.len() {
0 => None,
1 => field_types.into_iter().next(),
_ => Some(solver.unions(field_types)),
}
})
.flatten()
}

fn dict_literal_present_keys(
dict: &ExprDict,
skip_key_range: Option<TextRange>,
) -> BTreeMap<String, ()> {
dict.items
.iter()
.filter_map(|item| {
let Expr::StringLiteral(lit) = item.key.as_ref()? else {
return None;
};
(skip_key_range != Some(lit.range())).then(|| (lit.value.to_string(), ()))
})
.collect()
}
Comment on lines +250 to +263
Copy link

Copilot AI Apr 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dict_literal_present_keys uses a BTreeMap<String, ()> purely for membership checks. A BTreeSet<String> (or similar set type) would better express intent and avoid storing unused unit values.

Copilot uses AI. Check for mistakes.

fn expr_has_typed_dict_type(&self, handle: &Handle, expr: &Expr) -> bool {
self.get_type_trace(handle, expr.range())
.map(|ty| Self::type_contains_typed_dict(&ty))
Expand Down Expand Up @@ -244,6 +423,58 @@ impl<'a> Transaction<'a> {
best.map(|(_, _, dict, literal)| (dict, literal))
}

fn dict_literal_value_string_literal_at(
module: &ModModule,
position: TextSize,
) -> Option<(ExprDict, ExprStringLiteral, ExprStringLiteral)> {
let nodes = Ast::locate_node(module, position);
let mut best: Option<(u8, TextSize, ExprDict, ExprStringLiteral, ExprStringLiteral)> = None;
for node in nodes {
let AnyNodeRef::ExprDict(dict) = node else {
continue;
};
let mut best_in_dict: Option<(u8, TextSize, ExprStringLiteral, ExprStringLiteral)> =
None;
for item in &dict.items {
let Some(Expr::StringLiteral(key_lit)) = item.key.as_ref() else {
continue;
};
let Expr::StringLiteral(value_lit) = &item.value else {
continue;
};
let (priority, dist) = Self::string_literal_priority(position, value_lit.range());
let should_update = match &best_in_dict {
Some((best_prio, best_dist, _, _)) => {
priority < *best_prio || (priority == *best_prio && dist < *best_dist)
}
None => true,
};
if should_update {
best_in_dict = Some((priority, dist, key_lit.clone(), value_lit.clone()));
if priority == 0 && dist == TextSize::from(0) {
break;
}
}
}
let Some((priority, dist, key_lit, value_lit)) = best_in_dict else {
continue;
};
let should_update = match &best {
Some((best_prio, best_dist, _, _, _)) => {
priority < *best_prio || (priority == *best_prio && dist < *best_dist)
}
None => true,
};
if should_update {
best = Some((priority, dist, dict.clone(), key_lit, value_lit));
if priority == 0 && dist == TextSize::from(0) {
break;
}
}
}
best.map(|(_, _, dict, key_lit, value_lit)| (dict, key_lit, value_lit))
}

fn expression_facets(expr: &Expr) -> Option<(Identifier, Vec<FacetKind>)> {
let mut facets = Vec::new();
let mut current = expr;
Expand Down Expand Up @@ -279,25 +510,52 @@ impl<'a> Transaction<'a> {
) -> Option<BTreeMap<String, Type>> {
self.ad_hoc_solve(handle, "typed_dict_keys", |solver| {
let mut map = BTreeMap::new();
let mut stack = vec![base_type];
while let Some(ty) = stack.pop() {
match ty {
Type::TypedDict(td) | Type::PartialTypedDict(td) => {
for (name, field) in solver.type_order().typed_dict_fields(&td) {
map.entry(name.to_string())
.or_insert_with(|| field.ty.clone());
}
}
Type::Union(box Union { members, .. }) => {
stack.extend(members.into_iter());
}
_ => {}
for member in Self::typed_dict_members(base_type) {
let typed_dict = match member {
Type::TypedDict(td) | Type::PartialTypedDict(td) => td,
_ => continue,
};
for (name, field) in solver.type_order().typed_dict_fields(&typed_dict) {
map.entry(name.to_string())
.or_insert_with(|| field.ty.clone());
}
}
map
})
}

pub(crate) fn add_dict_value_literal_completions(
&self,
handle: &Handle,
module: &ModModule,
position: TextSize,
completions: &mut Vec<RankedCompletion>,
) {
let Some((dict, key_lit, value_lit)) =
Self::dict_literal_value_string_literal_at(module, position)
else {
return;
};
if position < value_lit.range().start() || position > value_lit.range().end() {
return;
}
let Some(members) = self.narrowed_typed_dict_members_for_dict_literal(
handle,
module,
&dict,
Some(key_lit.range()),
Some(value_lit.range()),
) else {
return;
};
let Some(field_ty) =
self.typed_dict_field_type_from_members(handle, members, key_lit.value.to_str())
else {
return;
};
Self::add_literal_completions_from_type(&field_ty, completions, true);
}

/// Adds dict key completions for the given position. Returns `true` if this function
/// claimed the position (i.e., we are inside a dict/TypedDict key string literal), in
/// which case the caller should skip overload-based literal completions to avoid showing
Expand Down Expand Up @@ -366,12 +624,45 @@ impl<'a> Transaction<'a> {
}
}

// For key access we query the container expression; for literals we query the
// literal itself to pick up contextual TypedDict typing from assignments.
if let Some(base_type) = self.get_type_trace(handle, context.base_range())
&& let Some(typed_keys) = self.collect_typed_dict_keys(handle, base_type)
let dict_literal_members = match &context {
DictKeyLiteralContext::DictLiteral { dict, literal } => self
.narrowed_typed_dict_members_for_dict_literal(
handle,
module,
dict,
Some(literal.range()),
None,
),
DictKeyLiteralContext::KeyAccess { .. } => None,
};

// For key access we query the container expression; for literals we recover the
// contextual type because incomplete dict literals may infer as plain `dict[...]`.
if let Some(base_type) = match (&context, dict_literal_members.as_ref()) {
(DictKeyLiteralContext::DictLiteral { .. }, Some(members)) => self
.ad_hoc_solve(
handle,
"dict_literal_typed_dict_union",
|solver| match members.len() {
0 => None,
1 => members.first().cloned(),
_ => Some(solver.unions(members.clone())),
},
)
.flatten(),
_ => self.get_type_trace(handle, context.base_range()),
} && let Some(typed_keys) = self.collect_typed_dict_keys(handle, base_type)
{
let present_keys = match &context {
DictKeyLiteralContext::DictLiteral { dict, literal } => {
Self::dict_literal_present_keys(dict, Some(literal.range()))
}
DictKeyLiteralContext::KeyAccess { .. } => BTreeMap::new(),
};
for (key, ty) in typed_keys {
if present_keys.contains_key(&key) {
continue;
}
let entry = suggestions.entry(key).or_insert(None);
if entry.is_none() {
*entry = Some(ty);
Expand Down
Loading
Loading