Skip to content

Commit 79af2e3

Browse files
[pre-commit.ci] pre-commit suggestions (#76)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6d6a880 commit 79af2e3

File tree

6 files changed

+31
-26
lines changed

6 files changed

+31
-26
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ ci:
88

99
repos:
1010
- repo: https://github.com/pre-commit/pre-commit-hooks
11-
rev: v4.4.0
11+
rev: v4.5.0
1212
hooks:
1313
- id: end-of-file-fixer
1414
- id: trailing-whitespace
@@ -24,22 +24,22 @@ repos:
2424
- id: detect-private-key
2525

2626
- repo: https://github.com/asottile/pyupgrade
27-
rev: v3.11.1
27+
rev: v3.15.2
2828
hooks:
2929
- id: pyupgrade
3030
args: ["--py310-plus"]
3131
name: Upgrade code
3232
exclude: "examples|thunder/tests/test_interpreter.py|thunder/tests/test_jit_general.py"
3333

3434
- repo: https://github.com/codespell-project/codespell
35-
rev: v2.2.5
35+
rev: v2.2.6
3636
hooks:
3737
- id: codespell
3838
additional_dependencies: [tomli]
3939
#args: ["--write-changes"] # uncomment if you want to get automatic fixing
4040

4141
- repo: https://github.com/psf/black
42-
rev: 23.9.1
42+
rev: 24.3.0
4343
hooks:
4444
- id: black
4545
name: Black code
@@ -61,7 +61,7 @@ repos:
6161
- id: sphinx-lint
6262

6363
- repo: https://github.com/asottile/yesqa
64-
rev: v1.4.0
64+
rev: v1.5.0
6565
hooks:
6666
- id: yesqa
6767

thunder/benchmarks/distributed.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,11 @@ def parse_args() -> argparse.Namespace:
322322
ResultFormatter(
323323
model_name=args.model,
324324
base_name="torch_fsdp",
325-
suffix=str(sharding_strategy).lower() + "-bucketing_" + "block"
326-
if auto_wrap_policy is not None
327-
else "none",
325+
suffix=(
326+
str(sharding_strategy).lower() + "-bucketing_" + "block"
327+
if auto_wrap_policy is not None
328+
else "none"
329+
),
328330
dtype=args.dtype,
329331
world_size=world_size,
330332
total_callable_construction_time=total_cct,
@@ -352,9 +354,11 @@ def parse_args() -> argparse.Namespace:
352354
ResultFormatter(
353355
model_name=args.model,
354356
base_name="torch_compile_fsdp",
355-
suffix=str(sharding_strategy).lower() + "-bucketing_" + "block"
356-
if auto_wrap_policy is not None
357-
else "none",
357+
suffix=(
358+
str(sharding_strategy).lower() + "-bucketing_" + "block"
359+
if auto_wrap_policy is not None
360+
else "none"
361+
),
358362
dtype=args.dtype,
359363
world_size=world_size,
360364
total_callable_construction_time=total_cct,

thunder/core/jit_ext.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,9 @@ class ThunderSharpEdgeError(RuntimeError):
377377
def _sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS:
378378
sharp_edges: SHARP_EDGES_OPTIONS = get_minimal_ctx().sharp_edges
379379

380-
s: str = f"{desc} is a sharp edge that cannot be translated to a thunder program unless using interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON."
380+
s: str = (
381+
f"{desc} is a sharp edge that cannot be translated to a thunder program unless using interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON."
382+
)
381383

382384
if sharp_edges is SHARP_EDGES_OPTIONS.ERROR:
383385
return do_raise(ThunderSharpEdgeError(s))
@@ -469,7 +471,9 @@ class JITSharpEdgeError(RuntimeError):
469471
def _general_jit_sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS:
470472
sharp_edges: SHARP_EDGES_OPTIONS = get_minimal_ctx().sharp_edges
471473

472-
s: str = f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!"
474+
s: str = (
475+
f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!"
476+
)
473477

474478
if sharp_edges is SHARP_EDGES_OPTIONS.ERROR:
475479
return do_raise(JITSharpEdgeError(s))

thunder/core/utils.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -743,16 +743,13 @@ class FrozenDict(_UserDictT[T, T1], Mapping[T, T1]):
743743
"""
744744

745745
@overload
746-
def __init__(self, data: Mapping[T, T1]) -> None:
747-
...
746+
def __init__(self, data: Mapping[T, T1]) -> None: ...
748747

749748
@overload
750-
def __init__(self, data: Iterable[T, T1]) -> None:
751-
...
749+
def __init__(self, data: Iterable[T, T1]) -> None: ...
752750

753751
@overload
754-
def __init__(self, *args: Any, **kwargs: Any) -> None:
755-
...
752+
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
756753

757754
def __init__(self, *args, **kwargs) -> None:
758755
super().__init__(*args, **kwargs)
@@ -834,8 +831,7 @@ def _safe_zip_gen(*args):
834831

835832

836833
@overload
837-
def safe_zip(x: Iterable[T], y: Iterable[T1], /) -> Iterable[tuple[T, T1]]:
838-
...
834+
def safe_zip(x: Iterable[T], y: Iterable[T1], /) -> Iterable[tuple[T, T1]]: ...
839835

840836

841837
def safe_zip(*args):

thunder/distributed/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,12 @@ def prefer_comm_over_other_over_wait_over_allgather(eligible_nodes: list[Node])
8484
# nodes over "wait_prim_impl", pick "all_gather_prim_impl" last.
8585
def key(node: Node) -> int:
8686
match node.bsym.sym.id:
87-
case (wait_prim_impl.id | unpack_for_fsdp_prim_impl.id):
87+
case wait_prim_impl.id | unpack_for_fsdp_prim_impl.id:
8888
return len(order_in_trace)
89-
case (reduce_scatter_prim_impl.id | all_reduce_prim_impl.id):
89+
case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id:
9090
# Prefer larger communication ops over smaller ones
9191
return -node.bsym.args[0].numel
92-
case (all_gather_prim_impl.id):
92+
case all_gather_prim_impl.id:
9393
return len(order_in_trace) + order_in_trace[node.bsym]
9494
case _:
9595
# Prefer nodes that are earlier in the trace
@@ -141,9 +141,9 @@ def prefer_comm_over_other_over_wait(eligible_nodes: list[Node]) -> int:
141141
# nodes over "wait_prim_impl"
142142
def key(node: Node) -> int:
143143
match node.bsym.sym.id:
144-
case (wait_prim_impl.id):
144+
case wait_prim_impl.id:
145145
return len(order_in_trace)
146-
case (reduce_scatter_prim_impl.id | all_reduce_prim_impl.id | all_gather_prim_impl.id):
146+
case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id | all_gather_prim_impl.id:
147147
# Prefer larger communication ops over smaller ones
148148
return -node.bsym.args[0].numel
149149
case _:

thunder/tests/litgpt_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py"""
2+
23
import torch
34
import torch.nn as nn
45

0 commit comments

Comments
 (0)