From ef9bd17cf9bbb40e2987173e3d4b28007fd85305 Mon Sep 17 00:00:00 2001 From: Hood Chatham Date: Wed, 21 Aug 2024 13:28:59 +0200 Subject: [PATCH] Add Python formatting with Ruff --- .github/workflows/lint.yml | 3 +++ tools/cross/format.py | 25 ++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5a8e6aa895e8..e22ca9a0c1f0 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -27,6 +27,9 @@ jobs: - name: Install project deps with pnpm run: | pnpm i + - name: Install Ruff + run: | + pip install ruff - name: Lint run: | python3 ./tools/cross/format.py --check diff --git a/tools/cross/format.py b/tools/cross/format.py index 0d26a3d9507f..39869d28384b 100644 --- a/tools/cross/format.py +++ b/tools/cross/format.py @@ -3,6 +3,7 @@ import logging import os import re +import shutil import subprocess from argparse import ArgumentParser, Namespace from typing import List, Optional, Tuple, Callable @@ -11,6 +12,7 @@ CLANG_FORMAT = os.environ.get("CLANG_FORMAT", "clang-format") PRETTIER = os.environ.get("PRETTIER", "node_modules/.bin/prettier") +RUFF = os.environ.get("RUFF", "ruff") def parse_args() -> Namespace: @@ -78,7 +80,7 @@ def filter_files_by_exts( return [ file for file in files - if file.startswith(dir_path + "/") and file.endswith(exts) + if (dir_path == "." or file.startswith(dir_path + "/")) and file.endswith(exts) ] @@ -102,6 +104,26 @@ def prettier(files: List[str], check: bool = False) -> bool: return result.returncode == 0 +def ruff(files: List[str], check: bool = False) -> bool: + if files and not shutil.which(RUFF): + msg = "Cannot find ruff, will not format Python" + if check: + # In ci, fail. + logging.error(msg) + return False + else: + # In a local checkout, let it go. If the user wants Python + # formatting they can install ruff and run again. + logging.warning(msg) + return True + return not check + cmd = [RUFF, "format"] + if check: + cmd.append("--check") + result = subprocess.run(cmd + files) + return result.returncode == 0 + + def git_get_modified_files( target: str, source: Optional[str], staged: bool ) -> List[str]: @@ -147,6 +169,7 @@ class FormatConfig: formatter=prettier, ), FormatConfig(directory="src", extensions=(".json",), formatter=prettier), + FormatConfig(directory=".", extensions=(".py",), formatter=ruff), # TODO: lint bazel files ]