python/pyspark/errors/error-conditions.json (5 additions, 0 deletions)
@@ -896,6 +896,11 @@
"No pipeline.yaml or pipeline.yml file provided in arguments or found in directory `<dir_path>` or readable ancestor directories."
]
},
"PIPELINE_SPEC_INVALID_GLOB_PATTERN": {
"message": [
"Invalid glob pattern `<glob_pattern>` in libraries. Only file paths, or folder paths ending with /** are allowed."
]
},
"PIPELINE_SPEC_MISSING_REQUIRED_FIELD": {
"message": [
"Pipeline spec missing required field `<field_name>`."
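The entry above is only a message template; the `<glob_pattern>` placeholder is filled in from messageParameters when the CLI raises the condition. A minimal sketch of how it surfaces to callers, mirroring the way cli.py raises it and the tests below inspect it:

from pyspark.errors import PySparkException

exc = PySparkException(
    errorClass="PIPELINE_SPEC_INVALID_GLOB_PATTERN",
    messageParameters={"glob_pattern": "transformations/**/*.py"},
)

print(exc.getCondition())          # PIPELINE_SPEC_INVALID_GLOB_PATTERN
print(exc.getMessageParameters())  # {'glob_pattern': 'transformations/**/*.py'}
# str(exc) renders the message template with `<glob_pattern>` substituted.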
python/pyspark/pipelines/cli.py (42 additions, 1 deletion)
@@ -23,6 +23,7 @@
"""
from contextlib import contextmanager
import argparse
import glob
import importlib.util
import os
import yaml
@@ -58,6 +59,32 @@ class LibrariesGlob:
include: str


def validate_patch_glob_pattern(glob_pattern: str) -> str:
"""Validates that a glob pattern is allowed.

Only allows:
- File paths (paths without wildcards except for the filename)
- Folder paths ending with /** (recursive directory patterns)

Disallows complex glob patterns like transformations/**/*.py
"""
# Check if it's a simple file path (no wildcards at all)
if not glob.has_magic(glob_pattern):
return glob_pattern

# Check if it's a folder path ending with /**
if glob_pattern.endswith("/**"):
prefix = glob_pattern[:-3]
if not glob.has_magic(prefix):
# append "/*" to match everything under the directory recursively
return glob_pattern + "/*"

raise PySparkException(
errorClass="PIPELINE_SPEC_INVALID_GLOB_PATTERN",
messageParameters={"glob_pattern": glob_pattern},
)


@dataclass(frozen=True)
class PipelineSpec:
"""Spec for a pipeline.
@@ -75,6 +102,16 @@ class PipelineSpec:
configuration: Mapping[str, str]
libraries: Sequence[LibrariesGlob]

def __post_init__(self) -> None:
"""Validate libraries automatically after instantiation."""
validated = [
LibrariesGlob(validate_patch_glob_pattern(lib.include)) for lib in self.libraries
]

# If normalization changed anything, write it back into the frozen dataclass
if tuple(validated) != tuple(self.libraries):
object.__setattr__(self, "libraries", tuple(validated))


def find_pipeline_spec(current_dir: Path) -> Path:
"""Looks in the current directory and its ancestors for a pipeline spec file."""
@@ -180,7 +217,11 @@ def register_definitions(
log_with_curr_timestamp(f"Loading definitions. Root directory: '{path}'.")
for libraries_glob in spec.libraries:
glob_expression = libraries_glob.include
-matching_files = [p for p in path.glob(glob_expression) if p.is_file()]
+matching_files = [
+    p
+    for p in path.glob(glob_expression)
+    if p.is_file() and "__pycache__" not in p.parts  # ignore generated python cache
+]
log_with_curr_timestamp(
f"Found {len(matching_files)} files matching glob '{glob_expression}'"
)
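Taken together, validate_patch_glob_pattern and the __post_init__ hook mean a spec's library globs are checked and normalized as soon as the spec object is built. A minimal sketch of the validator's behavior, importing it from pyspark.pipelines.cli as the tests below do:

from pyspark.errors import PySparkException
from pyspark.pipelines.cli import validate_patch_glob_pattern

# Plain file paths pass through unchanged.
assert validate_patch_glob_pattern("transformations/main.py") == "transformations/main.py"

# Folder patterns ending in /** are normalized so they match files recursively.
assert validate_patch_glob_pattern("transformations/**") == "transformations/**/*"

# Anything else is rejected with the new error condition.
try:
    validate_patch_glob_pattern("transformations/**/*.py")
except PySparkException as e:
    assert e.getCondition() == "PIPELINE_SPEC_INVALID_GLOB_PATTERN"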
python/pyspark/pipelines/init_cli.py (1 addition, 3 deletions)
@@ -21,9 +21,7 @@
name: {{ name }}
libraries:
  - glob:
-      include: transformations/**/*.py
-  - glob:
-      include: transformations/**/*.sql
+      include: transformations/**
"""

PYTHON_EXAMPLE = """from pyspark import pipelines as dp
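The template now ships a single transformations/** glob instead of separate *.py and *.sql globs: after normalization to transformations/**/*, register_definitions matches every file under the directory regardless of extension. A small pathlib illustration of the matching logic; the project path here is hypothetical:

from pathlib import Path

project_root = Path("/tmp/my_pipeline")  # hypothetical project directory

# Matches transformations/foo.py as well as transformations/nested/bar.sql,
# while skipping anything under __pycache__, as register_definitions does above.
matching_files = [
    p
    for p in project_root.glob("transformations/**/*")
    if p.is_file() and "__pycache__" not in p.parts
]
print(matching_files)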
python/pyspark/pipelines/tests/test_cli.py (67 additions, 3 deletions)
@@ -240,7 +240,7 @@ def test_register_definitions(self):
catalog=None,
database=None,
configuration={},
-libraries=[LibrariesGlob(include="subdir1/*")],
+libraries=[LibrariesGlob(include="subdir1/**")],
)
with tempfile.TemporaryDirectory() as temp_dir:
outer_dir = Path(temp_dir)
@@ -283,7 +283,7 @@ def test_register_definitions_file_raises_error(self):
catalog=None,
database=None,
configuration={},
-libraries=[LibrariesGlob(include="*")],
+libraries=[LibrariesGlob(include="./**")],
)
with tempfile.TemporaryDirectory() as temp_dir:
outer_dir = Path(temp_dir)
@@ -301,7 +301,7 @@ def test_register_definitions_unsupported_file_extension_matches_glob(self):
catalog=None,
database=None,
configuration={},
-libraries=[LibrariesGlob(include="*")],
+libraries=[LibrariesGlob(include="./**")],
)
with tempfile.TemporaryDirectory() as temp_dir:
outer_dir = Path(temp_dir)
@@ -451,6 +451,70 @@ def test_parse_table_list_with_spaces(self):
result = parse_table_list("table1, table2 , table3")
self.assertEqual(result, ["table1", "table2", "table3"])

def test_valid_glob_patterns(self):
"""Test that valid glob patterns are accepted."""
from pyspark.pipelines.cli import validate_patch_glob_pattern

cases = {
# Simple file paths
"src/main.py": "src/main.py",
"data/file.sql": "data/file.sql",
# Folder paths ending with /** (normalized)
"src/**": "src/**/*",
"transformations/**": "transformations/**/*",
"notebooks/production/**": "notebooks/production/**/*",
}

for pattern, expected in cases.items():
with self.subTest(pattern=pattern):
self.assertEqual(validate_patch_glob_pattern(pattern), expected)

def test_invalid_glob_patterns(self):
"""Test that invalid glob patterns are rejected."""
from pyspark.pipelines.cli import validate_patch_glob_pattern

invalid_patterns = [
"transformations/**/*.py",
"src/**/utils/*.py",
"*/main.py",
"src/*/test/*.py",
"**/*.py",
"data/*/file.sql",
]

for pattern in invalid_patterns:
with self.subTest(pattern=pattern):
with self.assertRaises(PySparkException) as context:
validate_patch_glob_pattern(pattern)
self.assertEqual(
context.exception.getCondition(), "PIPELINE_SPEC_INVALID_GLOB_PATTERN"
)
self.assertEqual(
context.exception.getMessageParameters(), {"glob_pattern": pattern}
)

def test_pipeline_spec_with_invalid_glob_pattern(self):
"""Test that pipeline spec with invalid glob pattern is rejected."""
with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
tmpfile.write(
"""
{
"name": "test_pipeline",
"libraries": [
{"glob": {"include": "transformations/**/*.py"}}
]
}
"""
)
tmpfile.flush()
with self.assertRaises(PySparkException) as context:
load_pipeline_spec(Path(tmpfile.name))
self.assertEqual(context.exception.getCondition(), "PIPELINE_SPEC_INVALID_GLOB_PATTERN")
self.assertEqual(
context.exception.getMessageParameters(),
{"glob_pattern": "transformations/**/*.py"},
)


if __name__ == "__main__":
try:
@@ -55,8 +55,13 @@ class EndToEndAPISuite extends PipelineTest with APITest with SparkConnectServer
// Create each source file in the temporary directory
sources.foreach { file =>
val filePath = Paths.get(file.name)
-val fileName = filePath.getFileName.toString
-val tempFilePath = projectDir.resolve(fileName)
+val tempFilePath = projectDir.resolve(filePath)
+
+// Create any necessary parent directories
+val parentDir = tempFilePath.getParent
+if (parentDir != null) {
+  Files.createDirectories(parentDir)
+}

// Create the file with the specified contents
Files.write(tempFilePath, file.contents.getBytes("UTF-8"))
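The suite change above stops flattening sources to their bare file names: resolving the full relative path preserves nested directory layouts (which the new folder globs rely on), and parent directories are created before the file is written. The same idea sketched in Python for illustration only, since the suite itself is Scala; the paths are hypothetical:

from pathlib import Path

project_dir = Path("/tmp/e2e-project")          # hypothetical temporary project directory
source_name = "transformations/nested/defs.py"  # hypothetical nested source file

target = project_dir / source_name  # keep the nested layout instead of only the file name
target.parent.mkdir(parents=True, exist_ok=True)  # create any missing parent directories
target.write_text("from pyspark import pipelines as dp\n")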