diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index aceb30d10..4e91340a6 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -146,7 +146,9 @@ class JITSharpEdgeError(RuntimeError): def _general_jit_sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS: sharp_edges: SHARP_EDGES_OPTIONS = get_jit_ctx().sharp_edges - s: str = f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!" + s: str = ( + f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!" + ) if sharp_edges is SHARP_EDGES_OPTIONS.ERROR: return do_raise(JITSharpEdgeError(s)) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 23249058e..dd0955aa1 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -3156,18 +3156,15 @@ def amin(a, /, dim=None, keepdim: bool = False): # NOTE: Using name `torch_max` to avoid conflict with Python's `max` @overload -def torch_max(a: TensorLike, /) -> TensorLike: - ... +def torch_max(a: TensorLike, /) -> TensorLike: ... @overload -def torch_max(a: TensorLike, /, dim: NumberLike, keepdim: bool = False) -> tuple[TensorLike, TensorLike]: - ... +def torch_max(a: TensorLike, /, dim: NumberLike, keepdim: bool = False) -> tuple[TensorLike, TensorLike]: ... @overload -def torch_max(a: TensorLike, b: TensorLike, /) -> TensorLike: - ... +def torch_max(a: TensorLike, b: TensorLike, /) -> TensorLike: ... @torchsymbol(torch.max, is_method=True, method_name="max", id="torch.max")