Commit 8d8ffe1

simplify pattern glob matching
1 parent 8c422f9 commit 8d8ffe1

4 files changed: +115 -7 lines changed


python/pyspark/errors/error-conditions.json

Lines changed: 5 additions & 0 deletions
@@ -896,6 +896,11 @@
       "No pipeline.yaml or pipeline.yml file provided in arguments or found in directory `<dir_path>` or readable ancestor directories."
     ]
   },
+  "PIPELINE_SPEC_INVALID_GLOB_PATTERN": {
+    "message": [
+      "Invalid glob pattern `<glob_pattern>` in libraries. Only file paths or folder paths ending with /** are allowed."
+    ]
+  },
   "PIPELINE_SPEC_MISSING_REQUIRED_FIELD": {
     "message": [
       "Pipeline spec missing required field `<field_name>`."

python/pyspark/pipelines/cli.py

Lines changed: 42 additions & 1 deletion
@@ -23,6 +23,7 @@
 """
 from contextlib import contextmanager
 import argparse
+import glob
 import importlib.util
 import os
 import yaml
@@ -58,6 +59,32 @@ class LibrariesGlob:
     include: str
 
 
+def validate_patch_glob_pattern(glob_pattern: str) -> str:
+    """Validates that a glob pattern is allowed.
+
+    Only allows:
+    - File paths (paths without any wildcards)
+    - Folder paths ending with /** (recursive directory patterns)
+
+    Disallows complex glob patterns like transformations/**/*.py.
+    """
+    # Check if it's a simple file path (no wildcards at all).
+    if not glob.has_magic(glob_pattern):
+        return glob_pattern
+
+    # Check if it's a folder path ending with /**; the prefix itself must be wildcard-free.
+    if glob_pattern.endswith("/**"):
+        prefix = glob_pattern[:-3]
+        if not glob.has_magic(prefix):
+            # Append "/*" so the pattern matches everything under the directory recursively.
+            return glob_pattern + "/*"
+
+    raise PySparkException(
+        errorClass="PIPELINE_SPEC_INVALID_GLOB_PATTERN",
+        messageParameters={"glob_pattern": glob_pattern},
+    )
+
+
 @dataclass(frozen=True)
 class PipelineSpec:
     """Spec for a pipeline.
@@ -75,6 +102,16 @@ class PipelineSpec:
     configuration: Mapping[str, str]
     libraries: Sequence[LibrariesGlob]
 
+    def __post_init__(self) -> None:
+        """Validate libraries automatically after instantiation."""
+        validated = [
+            LibrariesGlob(validate_patch_glob_pattern(lib.include)) for lib in self.libraries
+        ]
+
+        # If normalization changed anything, patch it into the frozen dataclass.
+        if tuple(validated) != tuple(self.libraries):
+            object.__setattr__(self, "libraries", tuple(validated))
+
 
 def find_pipeline_spec(current_dir: Path) -> Path:
     """Looks in the current directory and its ancestors for a pipeline spec file."""
@@ -180,7 +217,11 @@ def register_definitions(
     log_with_curr_timestamp(f"Loading definitions. Root directory: '{path}'.")
     for libraries_glob in spec.libraries:
         glob_expression = libraries_glob.include
-        matching_files = [p for p in path.glob(glob_expression) if p.is_file()]
+        matching_files = [
+            p
+            for p in path.glob(glob_expression)
+            if p.is_file() and "__pycache__" not in p.parts  # ignore generated Python cache
+        ]
         log_with_curr_timestamp(
             f"Found {len(matching_files)} files matching glob '{glob_expression}'"
         )
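
The p.parts check treats __pycache__ as a whole path component, so bytecode caches anywhere under a matched folder are skipped without excluding files whose names merely contain the substring; a quick demonstration (hypothetical paths):

from pathlib import Path

print("__pycache__" in Path("transformations/__pycache__/util.cpython-311.pyc").parts)  # True
print("__pycache__" in Path("transformations/not__pycache__file.py").parts)             # False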

python/pyspark/pipelines/init_cli.py

Lines changed: 1 addition & 3 deletions
@@ -21,9 +21,7 @@
 name: {{ name }}
 libraries:
   - glob:
-      include: transformations/**/*.py
-  - glob:
-      include: transformations/**/*.sql
+      include: transformations/**
 """
 
 PYTHON_EXAMPLE = """from pyspark import pipelines as dp
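
With normalization, the single transformations/** entry expands to transformations/**/*, which picks up the .py and .sql files the two old globs matched separately; a quick check (hypothetical layout):

from pathlib import Path

root = Path("my_pipeline")  # hypothetical project root
(root / "transformations").mkdir(parents=True, exist_ok=True)
(root / "transformations" / "a.py").touch()
(root / "transformations" / "b.sql").touch()

print(sorted(root.glob("transformations/**/*")))
# [PosixPath('my_pipeline/transformations/a.py'), PosixPath('my_pipeline/transformations/b.sql')]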

python/pyspark/pipelines/tests/test_cli.py

Lines changed: 67 additions & 3 deletions
@@ -240,7 +240,7 @@ def test_register_definitions(self):
             catalog=None,
             database=None,
             configuration={},
-            libraries=[LibrariesGlob(include="subdir1/*")],
+            libraries=[LibrariesGlob(include="subdir1/**")],
         )
         with tempfile.TemporaryDirectory() as temp_dir:
             outer_dir = Path(temp_dir)
@@ -283,7 +283,7 @@ def test_register_definitions_file_raises_error(self):
             catalog=None,
             database=None,
             configuration={},
-            libraries=[LibrariesGlob(include="*")],
+            libraries=[LibrariesGlob(include="./**")],
         )
         with tempfile.TemporaryDirectory() as temp_dir:
             outer_dir = Path(temp_dir)
@@ -301,7 +301,7 @@ def test_register_definitions_unsupported_file_extension_matches_glob(self):
             catalog=None,
             database=None,
             configuration={},
-            libraries=[LibrariesGlob(include="*")],
+            libraries=[LibrariesGlob(include="./**")],
         )
         with tempfile.TemporaryDirectory() as temp_dir:
             outer_dir = Path(temp_dir)
@@ -451,6 +451,70 @@ def test_parse_table_list_with_spaces(self):
         result = parse_table_list("table1, table2 , table3")
         self.assertEqual(result, ["table1", "table2", "table3"])
 
+    def test_valid_glob_patterns(self):
+        """Test that valid glob patterns are accepted."""
+        from pyspark.pipelines.cli import validate_patch_glob_pattern
+
+        cases = {
+            # Simple file paths
+            "src/main.py": "src/main.py",
+            "data/file.sql": "data/file.sql",
+            # Folder paths ending with /** (normalized)
+            "src/**": "src/**/*",
+            "transformations/**": "transformations/**/*",
+            "notebooks/production/**": "notebooks/production/**/*",
+        }
+
+        for pattern, expected in cases.items():
+            with self.subTest(pattern=pattern):
+                self.assertEqual(validate_patch_glob_pattern(pattern), expected)
+
+    def test_invalid_glob_patterns(self):
+        """Test that invalid glob patterns are rejected."""
+        from pyspark.pipelines.cli import validate_patch_glob_pattern
+
+        invalid_patterns = [
+            "transformations/**/*.py",
+            "src/**/utils/*.py",
+            "*/main.py",
+            "src/*/test/*.py",
+            "**/*.py",
+            "data/*/file.sql",
+        ]
+
+        for pattern in invalid_patterns:
+            with self.subTest(pattern=pattern):
+                with self.assertRaises(PySparkException) as context:
+                    validate_patch_glob_pattern(pattern)
+                self.assertEqual(
+                    context.exception.getCondition(), "PIPELINE_SPEC_INVALID_GLOB_PATTERN"
+                )
+                self.assertEqual(
+                    context.exception.getMessageParameters(), {"glob_pattern": pattern}
+                )
+
+    def test_pipeline_spec_with_invalid_glob_pattern(self):
+        """Test that a pipeline spec with an invalid glob pattern is rejected."""
+        with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
+            tmpfile.write(
+                """
+                {
+                    "name": "test_pipeline",
+                    "libraries": [
+                        {"glob": {"include": "transformations/**/*.py"}}
+                    ]
+                }
+                """
+            )
+            tmpfile.flush()
+            with self.assertRaises(PySparkException) as context:
+                load_pipeline_spec(Path(tmpfile.name))
+            self.assertEqual(
+                context.exception.getCondition(), "PIPELINE_SPEC_INVALID_GLOB_PATTERN"
+            )
+            self.assertEqual(
+                context.exception.getMessageParameters(),
+                {"glob_pattern": "transformations/**/*.py"},
+            )
 
 
 if __name__ == "__main__":
     try:
