Skip to content

Commit

Permalink
Improved parentheses handling
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwannheden committed Oct 17, 2024
1 parent c1409c7 commit dc5791f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
23 changes: 16 additions & 7 deletions rewrite/rewrite/python/_parser_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ParserVisitor(ast.NodeVisitor):

_source: str
_cursor: int
_parentheses_stack: List[Tuple[Callable[[J, Space], j.Parentheses], int]]
_parentheses_stack: List[Tuple[Callable[[J, Space], j.Parentheses], int, ast.AST]]

@property
def _source_after_cursor(self) -> str:
Expand Down Expand Up @@ -189,7 +189,7 @@ def visit_ClassDef(self, node):
name = self.__convert_name(node.name)
save_cursor = self._cursor
interfaces_prefix = self.__whitespace()
if self._source[self._cursor] == '(' and node.bases:
if node.bases and self.__cursor_at('('):
self.__skip('(')
interfaces = JContainer(
interfaces_prefix,
Expand All @@ -198,7 +198,7 @@ def visit_ClassDef(self, node):
enumerate(node.bases)],
Markers.EMPTY
)
elif self._source[self._cursor] == '(':
elif self.__cursor_at('('):
self.__skip('(')
interfaces = JContainer(
interfaces_prefix,
Expand Down Expand Up @@ -1628,15 +1628,15 @@ def visit_Subscript(self, node):
def visit_Tuple(self, node):
prefix = self.__whitespace()

if self._source[self._cursor] == '(' and node.elts:
if self.__cursor_at('(') and node.elts:
save_cursor = self._cursor
elements = JContainer(
Space.EMPTY,
[self.__pad_list_element(self.__convert(e), last=i == len(node.elts) - 1) for i, e in enumerate(node.elts)],
Markers.EMPTY
)

if self._cursor < len(self._source) and self._source[self._cursor] == ')':
if self.__cursor_at(')'):
# we need to backtrack as the parentheses belonged to a nested element
elements = None
self._cursor = save_cursor
Expand All @@ -1646,7 +1646,7 @@ def visit_Tuple(self, node):
elements = None

if elements is None:
omit_parens = self._source[self._cursor] != '('
omit_parens = not self.__cursor_at('(')
if not omit_parens:
self._cursor += 1
elements = JContainer(
Expand Down Expand Up @@ -1755,7 +1755,7 @@ def __convert_internal(self, node, recursion) -> Optional[J]:
prefix,
Markers.EMPTY,
self.__pad_right(e.with_prefix(expr_prefix), r)
), self._cursor))
), self._cursor, node))
# handle nested parens
result = recursion(node)
else:
Expand All @@ -1773,9 +1773,15 @@ def __convert_internal(self, node, recursion) -> Optional[J]:
self._cursor += 1
result = self._parentheses_stack.pop()[0](result, suffix)
else:
if len(self._parentheses_stack) > 0:
while len(self._parentheses_stack) > 0 and self._parentheses_stack[-1][2] == node:
self._parentheses_stack.pop()
self._cursor = save_cursor_2
return result
else:
if not self.__cursor_at('(') and len(self._parentheses_stack) > 0 and self._parentheses_stack[-1][1] == self._cursor:
self._parentheses_stack.pop()
self._cursor -= 1
return self.visit(cast(ast.AST, node))
else:
return None
Expand Down Expand Up @@ -2144,3 +2150,6 @@ def __map_fstring(self, node: ast.JoinedStr, prefix: Space, tok: TokenInfo, toke
delimiter,
parts
), tok)

def __cursor_at(self, s: str):
return self._cursor < len(self._source) and (len(s) == 1 and self._source[self._cursor] == s or self._source.startswith(s, self._cursor))
26 changes: 26 additions & 0 deletions rewrite/tests/python/all/if_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,32 @@ def foo(b):
)


def test_if_with_tuple_1():
# language=python
rewrite_run(
python(
"""\
import sys
if (sys.version_info[0], sys.version_info[1]) < (3, 8):
pass
"""
)
)


def test_if_with_tuple_2():
# language=python
rewrite_run(
python(
"""\
import sys
if ((sys.version_info[0], sys.version_info[1]) < (3, 8)):
pass
"""
)
)


def test_else_multiple():
# language=python
rewrite_run(
Expand Down

0 comments on commit dc5791f

Please sign in to comment.