Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix generating invalid trailing commas in import statements #869

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions libcst/_nodes/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,18 +1172,24 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "ImportAlias":
comma=visit_sentinel(self, "comma", self.comma, visitor),
)

def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None:
def _codegen_impl(
self,
state: CodegenState,
default_comma: bool = False,
comma_is_valid: bool = True,
) -> None:
with state.record_syntactic_position(self):
self.name._codegen(state)
asname = self.asname
if asname is not None:
asname._codegen(state)

comma = self.comma
if comma is MaybeSentinel.DEFAULT and default_comma:
state.add_token(", ")
elif isinstance(comma, Comma):
comma._codegen(state)
if comma_is_valid:
comma = self.comma
if default_comma and comma is MaybeSentinel.DEFAULT:
state.add_token(", ")
elif isinstance(comma, Comma):
comma._codegen(state)

def _name(self, node: CSTNode) -> str:
# Unrolled version of get_full_name_for_node to avoid circular imports.
Expand Down Expand Up @@ -1400,9 +1406,17 @@ def _codegen_impl(
if lpar is not None:
lpar._codegen(state)
if isinstance(names, Sequence):
lastname = len(names) - 1
has_parens = self.rpar is not None
last_i = len(names) - 1
for i, name in enumerate(names):
name._codegen(state, default_comma=(i != lastname))
is_last = i == last_i
name._codegen(
state,
# Unless we're wrappend in parens we can't output a trailing
# comma because that would be invalid code
comma_is_valid=has_parens or not is_last,
default_comma=not is_last,
)
if isinstance(names, ImportStar):
names._codegen(state)
rpar = self.rpar
Expand Down
18 changes: 16 additions & 2 deletions libcst/_nodes/tests/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ class ImportFromCreateTest(CSTNodeTest):
),
"code": "from foo import bar, baz",
},
# Trailing comma
# Trailing comma is stripped if no parens (to avoid generating invalid code)
{
"node": cst.ImportFrom(
module=cst.Name("foo"),
Expand All @@ -412,9 +412,23 @@ class ImportFromCreateTest(CSTNodeTest):
cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()),
),
),
"code": "from foo import bar,baz,",
"code": "from foo import bar,baz",
"expected_position": CodeRange((1, 0), (1, 23)),
},
# Trailing comma is preserved if there are parens
{
"node": cst.ImportFrom(
module=cst.Name("foo"),
names=(
cst.ImportAlias(cst.Name("bar"), comma=cst.Comma()),
cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()),
),
lpar=cst.LeftParen(),
rpar=cst.RightParen(),
),
"code": "from foo import (bar,baz,)",
"expected_position": CodeRange((1, 0), (1, 26)),
},
# Star import statement
{
"node": cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()),
Expand Down