Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 82 additions & 22 deletions sphinx/domains/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,52 +109,78 @@ def type_to_xref(text: str, env: BuildEnvironment = None) -> addnodes.pending_xr

def _parse_annotation(annotation: str, env: BuildEnvironment = None) -> List[Node]:
"""Parse type annotation."""
def unparse(node: ast.AST) -> List[Node]:

def is_literal_name(node: ast.AST) -> bool:
if isinstance(node, ast.Name):
return node.id == 'Literal'
if isinstance(node, ast.Attribute):
return node.attr == 'Literal'
return False

def literal_text(value: Any) -> str:
if isinstance(value, (str, bytes)):
return repr(value)
return str(value)

def literal_node(text: str) -> nodes.literal:
return nodes.literal(text, text)

def unparse(node: ast.AST, in_literal: bool = False) -> List[Node]:
if isinstance(node, ast.Attribute):
return [nodes.Text("%s.%s" % (unparse(node.value)[0], node.attr))]
base_nodes = unparse(node.value, in_literal=in_literal)
base = ''.join(child.astext() for child in base_nodes)
text = f"{base}.{node.attr}" if base else node.attr
if in_literal:
return [literal_node(text)]
return [nodes.Text(text)]
elif isinstance(node, ast.BinOp):
result: List[Node] = unparse(node.left)
result.extend(unparse(node.op))
result.extend(unparse(node.right))
result: List[Node] = unparse(node.left, in_literal=in_literal)
result.extend(unparse(node.op, in_literal=in_literal))
result.extend(unparse(node.right, in_literal=in_literal))
return result
elif isinstance(node, ast.BitOr):
return [nodes.Text(' '), addnodes.desc_sig_punctuation('', '|'), nodes.Text(' ')]
elif isinstance(node, ast.Constant): # type: ignore
elif isinstance(node, ast.Constant): # type: ignore[attr-defined]
if node.value is Ellipsis:
return [addnodes.desc_sig_punctuation('', "...")]
else:
return [nodes.Text(node.value)]
if in_literal:
return [literal_node(literal_text(node.value))]
return [nodes.Text(str(node.value))]
elif isinstance(node, ast.Expr):
return unparse(node.value)
return unparse(node.value, in_literal=in_literal)
elif isinstance(node, ast.Index):
return unparse(node.value)
return unparse(node.value, in_literal=in_literal)
elif isinstance(node, ast.List):
result = [addnodes.desc_sig_punctuation('', '[')]
if node.elts:
# check if there are elements in node.elts to only pop the
# last element of result if the for-loop was run at least
# once
for elem in node.elts:
result.extend(unparse(elem))
result.extend(unparse(elem, in_literal=in_literal))
result.append(addnodes.desc_sig_punctuation('', ', '))
result.pop()
result.append(addnodes.desc_sig_punctuation('', ']'))
return result
elif isinstance(node, ast.Module):
return sum((unparse(e) for e in node.body), [])
return sum((unparse(e, in_literal=in_literal) for e in node.body), [])
elif isinstance(node, ast.Name):
return [nodes.Text(node.id)]
text = node.id
if in_literal:
return [literal_node(text)]
return [nodes.Text(text)]
elif isinstance(node, ast.Subscript):
result = unparse(node.value)
literal_slice = is_literal_name(node.value)
result = unparse(node.value, in_literal=in_literal)
result.append(addnodes.desc_sig_punctuation('', '['))
result.extend(unparse(node.slice))
result.extend(unparse(node.slice, in_literal=in_literal or literal_slice))
result.append(addnodes.desc_sig_punctuation('', ']'))
return result
elif isinstance(node, ast.Tuple):
if node.elts:
result = []
result: List[Node] = []
for elem in node.elts:
result.extend(unparse(elem))
result.extend(unparse(elem, in_literal=in_literal))
result.append(addnodes.desc_sig_punctuation('', ', '))
result.pop()
else:
Expand All @@ -167,7 +193,17 @@ def unparse(node: ast.AST) -> List[Node]:
if isinstance(node, ast.Ellipsis):
return [addnodes.desc_sig_punctuation('', "...")]
elif isinstance(node, ast.NameConstant):
return [nodes.Text(node.value)]
if in_literal:
return [literal_node(literal_text(node.value))]
return [nodes.Text(str(node.value))]
elif isinstance(node, ast.Num): # type: ignore[attr-defined]
if in_literal:
return [literal_node(literal_text(node.n))]
return [nodes.Text(str(node.n))]
elif isinstance(node, ast.Str): # type: ignore[attr-defined]
if in_literal:
return [literal_node(literal_text(node.s))]
return [nodes.Text(node.s)]

raise SyntaxError # unsupported syntax

Expand Down Expand Up @@ -331,15 +367,39 @@ def make_xrefs(self, rolename: str, domain: str, target: str,
split_contnode = bool(contnode and contnode.astext() == target)

results = []
literal_depth = 0
pending_literal_open = False
for sub_target in filter(None, sub_targets):
token_contnode = contnode
if split_contnode:
contnode = nodes.Text(sub_target)
token_contnode = nodes.Text(sub_target)

if delims_re.match(sub_target):
results.append(contnode or innernode(sub_target, sub_target))
results.append(token_contnode or innernode(sub_target, sub_target))
if literal_depth > 0:
literal_depth += sub_target.count('[')
literal_depth -= sub_target.count(']')
if literal_depth < 0:
literal_depth = 0
elif pending_literal_open and '[' in sub_target:
literal_depth = sub_target.count('[') or 1
pending_literal_open = False
elif pending_literal_open:
pending_literal_open = False
continue

if literal_depth > 0:
results.append(token_contnode or nodes.Text(sub_target))
continue

normalized = sub_target.strip().rsplit('.', 1)[-1]
if normalized == 'Literal':
pending_literal_open = True
else:
results.append(self.make_xref(rolename, domain, sub_target,
innernode, contnode, env, inliner, location))
pending_literal_open = False

results.append(self.make_xref(rolename, domain, sub_target,
innernode, token_contnode, env, inliner, location))

return results

Expand Down
2 changes: 2 additions & 0 deletions tests/roots/test-domain-py-literal/conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
exclude_patterns = ['_build']
nitpicky = True
22 changes: 22 additions & 0 deletions tests/roots/test-domain-py-literal/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
Literal nitpicky annotations
=============================

.. py:module:: literal_examples

.. py:class:: SomeEnum

.. py:attribute:: VALUE

.. py:function:: f(a: Literal[True]) -> Literal["x"]

.. py:function:: g(a: Literal[True, 1, "x", None])

.. py:function:: h(a: Literal[SomeEnum.VALUE])

.. py:function:: j(a: Union[Literal[True], bool])

.. py:function:: k(a: Annotated[Literal["a"], int])

.. py:function:: df(a)

:type a: Literal["A", "B"]
15 changes: 15 additions & 0 deletions tests/test_domain_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,21 @@ def test_parse_annotation(app):
assert_node(doctree[0], pending_xref, refdomain="py", reftype="obj", reftarget="None")


@pytest.mark.sphinx('dummy', testroot='domain-py-literal')
def test_literal_annotations_without_xrefs(app, warning):
app.build()
assert warning.getvalue() == ''

doctree = app.env.get_doctree('index')

reftargets = {node['reftarget'] for node in doctree.traverse(pending_xref)}
assert {'Literal', 'Union', 'Annotated', 'bool', 'int'} <= reftargets
assert not reftargets.intersection({'True', '1', "'x'", 'None', 'SomeEnum.VALUE', "'a'", "'A'", "'B'"})

literal_texts = {node.astext() for node in doctree.traverse(nodes.literal)}
assert {'True', '1', "'x'", 'None', 'SomeEnum.VALUE', "'a'"} <= literal_texts


def test_pyfunction_signature(app):
text = ".. py:function:: hello(name: str) -> str"
doctree = restructuredtext.parse(app, text)
Expand Down