49 changes: 48 additions & 1 deletion dev/lint-python
@@ -18,6 +18,8 @@
# define test binaries + versions
FLAKE8_BUILD="flake8"
MINIMUM_FLAKE8="3.9.0"
RUFF_BUILD="ruff"
MINIMUM_RUFF="0.14.0"
MINIMUM_MYPY="1.8.0"
MYPY_BUILD="mypy"
PYTEST_BUILD="pytest"
@@ -52,6 +54,9 @@ while (( "$#" )); do
--flake8)
FLAKE8_TEST=true
;;
--ruff)
RUFF_TEST=true
;;
--mypy)
MYPY_TEST=true
;;
@@ -69,7 +74,7 @@ while (( "$#" )); do
shift
done

if [[ -z "$COMPILE_TEST$BLACK_TEST$PYSPARK_CUSTOM_ERRORS_CHECK_TEST$FLAKE8_TEST$MYPY_TEST$MYPY_EXAMPLES_TEST$MYPY_DATA_TEST" ]]; then
if [[ -z "$COMPILE_TEST$BLACK_TEST$PYSPARK_CUSTOM_ERRORS_CHECK_TEST$FLAKE8_TEST$RUFF_TEST$MYPY_TEST$MYPY_EXAMPLES_TEST$MYPY_DATA_TEST" ]]; then
COMPILE_TEST=true
BLACK_TEST=true
PYSPARK_CUSTOM_ERRORS_CHECK_TEST=true
@@ -270,6 +275,45 @@ flake8 checks failed."
fi
}

function ruff_test {
local RUFF_VERSION=
local EXPECTED_RUFF=
local RUFF_REPORT=
local RUFF_STATUS=

if ! hash "$RUFF_BUILD" 2> /dev/null; then
echo "The ruff command was not found. Skipping for now."
return
fi

_RUFF_VERSION=($($RUFF_BUILD --version))
RUFF_VERSION="${_RUFF_VERSION[1]}"
EXPECTED_RUFF="$(satisfies_min_version $RUFF_VERSION $MINIMUM_RUFF)"

if [[ "$EXPECTED_RUFF" == "False" ]]; then
echo "\
The minimum ruff version needs to be $MINIMUM_RUFF. Your current version is $RUFF_VERSION

ruff checks failed."
exit 1
fi

echo "starting $RUFF_BUILD test..."
RUFF_REPORT=$( ($RUFF_BUILD check --config dev/pyproject.toml) 2>&1)
RUFF_STATUS=$?

if [ "$RUFF_STATUS" -ne 0 ]; then
echo "ruff checks failed:"
echo "$RUFF_REPORT"
echo "$RUFF_STATUS"
exit "$RUFF_STATUS"
else
echo "ruff checks passed."
echo
fi

}

function black_test {
local BLACK_REPORT=
local BLACK_STATUS=
@@ -335,6 +379,9 @@ fi
if [[ "$FLAKE8_TEST" == "true" ]]; then
flake8_test
fi
if [[ "$RUFF_TEST" == "true" ]]; then
ruff_test
fi
if [[ "$MYPY_TEST" == "true" ]] || [[ "$MYPY_EXAMPLES_TEST" == "true" ]] || [[ "$MYPY_DATA_TEST" == "true" ]]; then
mypy_test
fi
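With this change the ruff check can be run on its own via the new --ruff flag. A minimal usage sketch (assuming the script is run from the repository root; the version detection above also assumes ruff --version prints a two-token line such as "ruff 0.14.0", which is what ${_RUFF_VERSION[1]} relies on):

    # run only the new ruff check
    dev/lint-python --ruff

    # or invoke ruff directly with the same shared config
    ruff check --config dev/pyproject.toml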
52 changes: 52 additions & 0 deletions dev/pyproject.toml
@@ -24,6 +24,58 @@ testpaths = [
"pyspark/ml/typing",
]

[tool.ruff]
exclude = [
"*/target/*",
"**/*.ipynb",
"docs/.local_ruby_bundle/",
"*python/pyspark/cloudpickle/*.py",
"*python/pyspark/ml/deepspeed/tests/*.py",
"*python/docs/build/*",
"*python/docs/source/conf.py",
"*python/.eggs/*",
"dist/*",
".git/*",
"*python/pyspark/sql/pandas/functions.pyi",
"*python/pyspark/sql/column.pyi",
"*python/pyspark/worker.pyi",
"*python/pyspark/java_gateway.pyi",
"*python/pyspark/sql/connect/proto/*",
"*python/pyspark/sql/streaming/proto/*",
"*venv*/*",
]

[tool.ruff.lint]
ignore = [
"E203", # Skip as black formatter adds a whitespace around ':'.
"E402", # Module top level import is disabled for optional import check, etc.
# TODO
"E721", # Use isinstance for type comparison, too many for now.
"E741", # Ambiguous variables like l, I or O.
]

[tool.ruff.lint.per-file-ignores]
# E501 is ignored as shared.py is auto-generated.
"python/pyspark/ml/param/shared.py" = ["E501"]
# E501 is ignored as we should keep the json string format in error_classes.py.
"python/pyspark/errors/error_classes.py" = ["E501"]
# Examples contain some unused variables.
"examples/src/main/python/sql/datasource.py" = ["F841"]
# Exclude * imports in test files
"python/pyspark/errors/tests/*.py" = ["F403"]
"python/pyspark/logger/tests/*.py" = ["F403"]
"python/pyspark/logger/tests/connect/*.py" = ["F403"]
"python/pyspark/ml/tests/*.py" = ["F403"]
"python/pyspark/mllib/tests/*.py" = ["F403"]
"python/pyspark/pandas/tests/*.py" = ["F401", "F403"]
"python/pyspark/pandas/tests/connect/*.py" = ["F401", "F403"]
"python/pyspark/resource/tests/*.py" = ["F403"]
"python/pyspark/sql/tests/*.py" = ["F403"]
"python/pyspark/streaming/tests/*.py" = ["F403"]
"python/pyspark/tests/*.py" = ["F403"]
"python/pyspark/testing/*.py" = ["F401"]
"python/pyspark/testing/tests/*.py" = ["F403"]

[tool.black]
# When changing the version, we have to update
# GitHub workflow version and dev/reformat-python
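The E203 ignore reflects a known formatter conflict: black inserts spaces around ':' in slices with complex bounds, which pycodestyle's E203 would flag. A minimal illustration (hypothetical snippet, not from the PR):

    data = list(range(10))
    offset, size = 2, 5
    # black keeps the spaces around ':' here; E203 would flag them, so it is ignored
    chunk = data[offset + 1 : offset + size]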
2 changes: 1 addition & 1 deletion python/pyspark/ml/evaluation.py
@@ -710,7 +710,7 @@ def setParams(

def isLargerBetter(self) -> bool:
"""Override this function to make it run on connect"""
return not self.getMetricName() in [
return self.getMetricName() not in [
"weightedFalsePositiveRate",
"falsePositiveRateByLabel",
"logLoss",
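The evaluation.py rewrite is the idiomatic membership test that pycodestyle (and ruff) enforce as E713: `x not in y` rather than `not x in y`. Both evaluate identically; a minimal illustration with hypothetical values:

    metric = "logLoss"
    ok = not metric in ["logLoss"]   # legal, but flagged (E713)
    ok = metric not in ["logLoss"]   # equivalent and idiomatic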
8 changes: 4 additions & 4 deletions python/pyspark/sql/connect/merge.py
@@ -19,7 +19,7 @@
check_dependencies(__name__)

import sys
from typing import Dict, Optional, TYPE_CHECKING, List, Callable
from typing import Dict, Optional, TYPE_CHECKING, Callable

from pyspark.sql.connect import proto
from pyspark.sql.connect.column import Column
@@ -73,9 +73,9 @@ def __init__(

self._callback = callback if callback is not None else lambda _: None
self._schema_evolution_enabled = False
self._matched_actions = list() # type: List[proto.MergeAction]
self._not_matched_actions = list() # type: List[proto.MergeAction]
self._not_matched_by_source_actions = list() # type: List[proto.MergeAction]
self._matched_actions: list[proto.MergeAction] = list()
self._not_matched_actions: list[proto.MergeAction] = list()
self._not_matched_by_source_actions: list[proto.MergeAction] = list()

def whenMatched(self, condition: Optional[Column] = None) -> "MergeIntoWriter.WhenMatched":
return self.WhenMatched(self, condition)
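The merge.py hunk swaps `# type:` comments for inline PEP 526 annotations on the builtin `list` generic (PEP 585, Python 3.9+), which is also why `List` disappears from the `typing` import above. A minimal sketch of the two styles (hypothetical names):

    class SomeAction: ...

    # old style: a type comment, visible only to tools that parse comments
    actions = list()  # type: List[SomeAction]

    # new style: an inline annotation on the builtin generic (Python 3.9+)
    actions: list[SomeAction] = []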
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/arrow/test_arrow_udf_typehints.py
@@ -461,7 +461,7 @@ def multiply_pandas(a: pd.Series, b: pd.Series) -> pd.Series:


if __name__ == "__main__":
from pyspark.sql.tests.arrow.test_arrow_udf_typehints import * # noqa: #401
from pyspark.sql.tests.arrow.test_arrow_udf_typehints import *  # noqa: F401

try:
import xmlrunner
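A note on the fix: flake8 and ruff recognize suppression codes in the form `# noqa: F401`; the stray `#` before the code in the original comment meant it was not a valid rule reference. The corrected form, illustrated on a hypothetical line:

    # suppresses only F401 ("imported but unused") on this line:
    from pyspark.sql.functions import *  # noqa: F401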
python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
@@ -444,7 +444,7 @@ def multiply_arrow(a: pa.Array, b: pa.Array) -> pa.Array:


if __name__ == "__main__":
from pyspark.sql.tests.pandas.test_pandas_udf_typehints import * # noqa: #401
from pyspark.sql.tests.pandas.test_pandas_udf_typehints import *  # noqa: F401

try:
import xmlrunner
python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py
@@ -367,7 +367,7 @@ def func(col: "Union[pd.Series, pd.DataFrame]", *, col2: "pd.DataFrame") -> "pd.


if __name__ == "__main__":
from pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations import * # noqa: #401
from pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations import *  # noqa: F401

try:
import xmlrunner
2 changes: 0 additions & 2 deletions python/pyspark/taskcontext.py
@@ -252,8 +252,6 @@ def resources(self) -> Dict[str, "ResourceInformation"]:
dict
a dictionary of a string resource name, and :class:`ResourceInformation`.
"""
from pyspark.resource import ResourceInformation

return cast(Dict[str, "ResourceInformation"], self._resources)


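The deleted import was dead at runtime: both the return type and the `cast` name `ResourceInformation` as a string forward reference, so no runtime binding is needed, and ruff flags the function-local import as unused (F401). A minimal sketch of the pattern, assuming the module-level TYPE_CHECKING import such files typically carry:

    from typing import TYPE_CHECKING, Dict, cast

    if TYPE_CHECKING:  # assumed pattern; the import is only needed by type checkers
        from pyspark.resource import ResourceInformation

    class TaskContextSketch:
        _resources: Dict[str, "ResourceInformation"] = {}

        def resources(self) -> Dict[str, "ResourceInformation"]:
            # string forward references need no runtime import, so the
            # function-local import was dead code (ruff F401)
            return cast(Dict[str, "ResourceInformation"], self._resources)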
2 changes: 1 addition & 1 deletion python/pyspark/util.py
@@ -726,7 +726,7 @@ class PyLocalIterable:
def __init__(self, _sock_info: "JavaArray", _serializer: "Serializer"):
port: int
auth_secret: str
jsocket_auth_server: "JavaObject"
self.jsocket_auth_server: "JavaObject"
port, auth_secret, self.jsocket_auth_server = _sock_info
self._sockfile, self._sock = _create_local_socket((port, auth_secret))
self._serializer = _serializer
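The util.py fix matters because a bare `jsocket_auth_server: "JavaObject"` inside `__init__` only declares a local name that is never assigned; the tuple unpacking on the next line assigns `self.jsocket_auth_server`, so the annotation must target the attribute. A minimal sketch of the difference (hypothetical class):

    class Sketch:
        def __init__(self, sock_info: tuple):
            # a bare "jsocket_auth_server: ..." line would annotate a local
            # that is never assigned; annotate the instance attribute instead:
            self.jsocket_auth_server: "JavaObject"
            port, auth_secret, self.jsocket_auth_server = sock_info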