Skip to content

Commit

Permalink
[pre-commit.ci] pre-commit suggestions (#76)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
pre-commit-ci[bot] authored Mar 26, 2024
1 parent 6d6a880 commit 79af2e3
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 26 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
Expand All @@ -24,22 +24,22 @@ repos:
- id: detect-private-key

- repo: https://github.com/asottile/pyupgrade
rev: v3.11.1
rev: v3.15.2
hooks:
- id: pyupgrade
args: ["--py310-plus"]
name: Upgrade code
exclude: "examples|thunder/tests/test_interpreter.py|thunder/tests/test_jit_general.py"

- repo: https://github.com/codespell-project/codespell
rev: v2.2.5
rev: v2.2.6
hooks:
- id: codespell
additional_dependencies: [tomli]
#args: ["--write-changes"] # uncomment if you want to get automatic fixing

- repo: https://github.com/psf/black
rev: 23.9.1
rev: 24.3.0
hooks:
- id: black
name: Black code
Expand All @@ -61,7 +61,7 @@ repos:
- id: sphinx-lint

- repo: https://github.com/asottile/yesqa
rev: v1.4.0
rev: v1.5.0
hooks:
- id: yesqa

Expand Down
16 changes: 10 additions & 6 deletions thunder/benchmarks/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,11 @@ def parse_args() -> argparse.Namespace:
ResultFormatter(
model_name=args.model,
base_name="torch_fsdp",
suffix=str(sharding_strategy).lower() + "-bucketing_" + "block"
if auto_wrap_policy is not None
else "none",
suffix=(
str(sharding_strategy).lower() + "-bucketing_" + "block"
if auto_wrap_policy is not None
else "none"
),
dtype=args.dtype,
world_size=world_size,
total_callable_construction_time=total_cct,
Expand Down Expand Up @@ -352,9 +354,11 @@ def parse_args() -> argparse.Namespace:
ResultFormatter(
model_name=args.model,
base_name="torch_compile_fsdp",
suffix=str(sharding_strategy).lower() + "-bucketing_" + "block"
if auto_wrap_policy is not None
else "none",
suffix=(
str(sharding_strategy).lower() + "-bucketing_" + "block"
if auto_wrap_policy is not None
else "none"
),
dtype=args.dtype,
world_size=world_size,
total_callable_construction_time=total_cct,
Expand Down
8 changes: 6 additions & 2 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,9 @@ class ThunderSharpEdgeError(RuntimeError):
def _sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS:
sharp_edges: SHARP_EDGES_OPTIONS = get_minimal_ctx().sharp_edges

s: str = f"{desc} is a sharp edge that cannot be translated to a thunder program unless using interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON."
s: str = (
f"{desc} is a sharp edge that cannot be translated to a thunder program unless using interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON."
)

if sharp_edges is SHARP_EDGES_OPTIONS.ERROR:
return do_raise(ThunderSharpEdgeError(s))
Expand Down Expand Up @@ -469,7 +471,9 @@ class JITSharpEdgeError(RuntimeError):
def _general_jit_sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS:
sharp_edges: SHARP_EDGES_OPTIONS = get_minimal_ctx().sharp_edges

s: str = f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!"
s: str = (
f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!"
)

if sharp_edges is SHARP_EDGES_OPTIONS.ERROR:
return do_raise(JITSharpEdgeError(s))
Expand Down
12 changes: 4 additions & 8 deletions thunder/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,16 +743,13 @@ class FrozenDict(_UserDictT[T, T1], Mapping[T, T1]):
"""

@overload
def __init__(self, data: Mapping[T, T1]) -> None:
...
def __init__(self, data: Mapping[T, T1]) -> None: ...

@overload
def __init__(self, data: Iterable[T, T1]) -> None:
...
def __init__(self, data: Iterable[T, T1]) -> None: ...

@overload
def __init__(self, *args: Any, **kwargs: Any) -> None:
...
def __init__(self, *args: Any, **kwargs: Any) -> None: ...

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -834,8 +831,7 @@ def _safe_zip_gen(*args):


@overload
def safe_zip(x: Iterable[T], y: Iterable[T1], /) -> Iterable[tuple[T, T1]]:
...
def safe_zip(x: Iterable[T], y: Iterable[T1], /) -> Iterable[tuple[T, T1]]: ...


def safe_zip(*args):
Expand Down
10 changes: 5 additions & 5 deletions thunder/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ def prefer_comm_over_other_over_wait_over_allgather(eligible_nodes: list[Node])
# nodes over "wait_prim_impl", pick "all_gather_prim_impl" last.
def key(node: Node) -> int:
match node.bsym.sym.id:
case (wait_prim_impl.id | unpack_for_fsdp_prim_impl.id):
case wait_prim_impl.id | unpack_for_fsdp_prim_impl.id:
return len(order_in_trace)
case (reduce_scatter_prim_impl.id | all_reduce_prim_impl.id):
case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id:
# Prefer larger communication ops over smaller ones
return -node.bsym.args[0].numel
case (all_gather_prim_impl.id):
case all_gather_prim_impl.id:
return len(order_in_trace) + order_in_trace[node.bsym]
case _:
# Prefer nodes that are earlier in the trace
Expand Down Expand Up @@ -141,9 +141,9 @@ def prefer_comm_over_other_over_wait(eligible_nodes: list[Node]) -> int:
# nodes over "wait_prim_impl"
def key(node: Node) -> int:
match node.bsym.sym.id:
case (wait_prim_impl.id):
case wait_prim_impl.id:
return len(order_in_trace)
case (reduce_scatter_prim_impl.id | all_reduce_prim_impl.id | all_gather_prim_impl.id):
case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id | all_gather_prim_impl.id:
# Prefer larger communication ops over smaller ones
return -node.bsym.args[0].numel
case _:
Expand Down
1 change: 1 addition & 0 deletions thunder/tests/litgpt_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py"""

import torch
import torch.nn as nn

Expand Down

0 comments on commit 79af2e3

Please sign in to comment.