Skip to content

Commit

Permalink
only break long lines in unpacking (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Mar 23, 2024
1 parent 8f65c9b commit 456be01
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,24 @@ def unpack_sequence_meta(x: Sequence | CollectionProxy, l: int, /) -> list:
return list(_collectify(y) for y in x)


# TODO Review using multi-line unpacks more cleverly
def _make_parts_into_line_or_lines(parts: list[str], out: list[str] | None = None) -> list[str]:
if out is None:
lines = []
else:
lines = out
line_parts = []
pos = 0
for p in parts:
if pos and pos + len(p) > 80:
lines.append("".join(line_parts) + "\\")
line_parts = []
line_parts.append(p)
pos += len(p)

lines.append("".join(line_parts))
return lines


# TODO Possibly put the length in the code to show the requirement
def unpack_sequence_printer(
bsym: BoundSymbol, out_printables: Any, arg_printables: Sequence[Printable], kwarg_printables: dict[str, Printable]
Expand All @@ -754,12 +771,10 @@ def unpack_sequence_printer(
if len(bsym.output) == 0:
return f"# {call_str} (empty sequence)"

lines = []
for out in out_printables:
line = f"{codeutils.prettyprint(out, literals_as_underscores=True)}, \\"
lines.append(line)
parts = [f"{codeutils.prettyprint(out, literals_as_underscores=True)}, " for out in out_printables]
parts.append(f"= {call_str}")

lines.append(f"= {call_str}")
lines = _make_parts_into_line_or_lines(parts)
return lines


Expand Down Expand Up @@ -812,12 +827,10 @@ def _unpack_tuple_printer(
if len(bsym.output) == 0:
return f"# {call_str} (empty tuple)"

lines = []
for out in out_printables:
line = f"{codeutils.prettyprint(out, literals_as_underscores=True)}, \\"
lines.append(line)
parts = [f"{codeutils.prettyprint(out, literals_as_underscores=True)}, " for out in out_printables]
parts.append(f"= {call_str}")

lines.append(f"= {call_str}")
lines = _make_parts_into_line_or_lines(parts)
return lines


Expand Down Expand Up @@ -865,12 +878,10 @@ def _unpack_list_printer(
if len(bsym.output) == 0:
return f"# {call_str} (empty list)"

lines = []
for out in out_printables:
line = f"{codeutils.prettyprint(out, literals_as_underscores=True)}, \\"
lines.append(line)
parts = [f"{codeutils.prettyprint(out, literals_as_underscores=True)}, " for out in out_printables]
parts.append(f"= {call_str}")

lines.append(f"= {call_str}")
lines = _make_parts_into_line_or_lines(parts)
return lines


Expand Down

0 comments on commit 456be01

Please sign in to comment.