@@ -240,7 +240,7 @@ def test_register_definitions(self):
             catalog=None,
             database=None,
             configuration={},
-            libraries=[LibrariesGlob(include="subdir1/*")],
+            libraries=[LibrariesGlob(include="subdir1/**")],
         )
         with tempfile.TemporaryDirectory() as temp_dir:
             outer_dir = Path(temp_dir)
@@ -283,7 +283,7 @@ def test_register_definitions_file_raises_error(self):
             catalog=None,
             database=None,
             configuration={},
-            libraries=[LibrariesGlob(include="*")],
+            libraries=[LibrariesGlob(include="./**")],
         )
         with tempfile.TemporaryDirectory() as temp_dir:
             outer_dir = Path(temp_dir)
@@ -301,7 +301,7 @@ def test_register_definitions_unsupported_file_extension_matches_glob(self):
             catalog=None,
             database=None,
             configuration={},
-            libraries=[LibrariesGlob(include="*")],
+            libraries=[LibrariesGlob(include="./**")],
         )
         with tempfile.TemporaryDirectory() as temp_dir:
             outer_dir = Path(temp_dir)
@@ -451,6 +451,70 @@ def test_parse_table_list_with_spaces(self):
         result = parse_table_list("table1, table2 , table3")
         self.assertEqual(result, ["table1", "table2", "table3"])
 
+    def test_valid_glob_patterns(self):
+        """Test that valid glob patterns are accepted."""
+        from pyspark.pipelines.cli import validate_patch_glob_pattern
+
+        cases = {
+            # Simple file paths
+            "src/main.py": "src/main.py",
+            "data/file.sql": "data/file.sql",
+            # Folder paths ending with /** (normalized)
+            "src/**": "src/**/*",
+            "transformations/**": "transformations/**/*",
+            "notebooks/production/**": "notebooks/production/**/*",
+        }
+
+        for pattern, expected in cases.items():
+            with self.subTest(pattern=pattern):
+                self.assertEqual(validate_patch_glob_pattern(pattern), expected)
+
+    def test_invalid_glob_patterns(self):
+        """Test that invalid glob patterns are rejected."""
+        from pyspark.pipelines.cli import validate_patch_glob_pattern
+
+        invalid_patterns = [
+            "transformations/**/*.py",
+            "src/**/utils/*.py",
+            "*/main.py",
+            "src/*/test/*.py",
+            "**/*.py",
+            "data/*/file.sql",
+        ]
+
+        for pattern in invalid_patterns:
+            with self.subTest(pattern=pattern):
+                with self.assertRaises(PySparkException) as context:
+                    validate_patch_glob_pattern(pattern)
+                self.assertEqual(
+                    context.exception.getCondition(), "PIPELINE_SPEC_INVALID_GLOB_PATTERN"
+                )
+                self.assertEqual(
+                    context.exception.getMessageParameters(), {"glob_pattern": pattern}
+                )
+
+    def test_pipeline_spec_with_invalid_glob_pattern(self):
+        """Test that pipeline spec with invalid glob pattern is rejected."""
+        with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
+            tmpfile.write(
+                """
+                {
+                    "name": "test_pipeline",
+                    "libraries": [
+                        {"glob": {"include": "transformations/**/*.py"}}
+                    ]
+                }
+                """
+            )
+            tmpfile.flush()
+            with self.assertRaises(PySparkException) as context:
+                load_pipeline_spec(Path(tmpfile.name))
+            self.assertEqual(context.exception.getCondition(), "PIPELINE_SPEC_INVALID_GLOB_PATTERN")
+            self.assertEqual(
+                context.exception.getMessageParameters(),
+                {"glob_pattern": "transformations/**/*.py"},
+            )
+
 
 if __name__ == "__main__":
     try:
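For reference, the new tests pin down the contract of `validate_patch_glob_pattern`: plain file paths pass through unchanged, folder globs ending in `/**` are normalized to `<folder>/**/*`, and any other wildcard usage raises a `PySparkException` with condition `PIPELINE_SPEC_INVALID_GLOB_PATTERN`. Below is a minimal sketch of validation logic consistent with these tests; the actual implementation in `pyspark/pipelines/cli.py` may differ, and the exception construction here is an assumption about how the condition is raised.

```python
from pyspark.errors import PySparkException


def validate_patch_glob_pattern(pattern: str) -> str:
    # Folder glob: "src/**" is normalized to "src/**/*" so downstream
    # globbing matches files recursively. The prefix before "/**" must
    # itself be wildcard-free ("src/**/utils/**" would be rejected below).
    if pattern.endswith("/**") and "*" not in pattern[:-3]:
        return pattern + "/*"
    # Any other wildcard placement ("**/*.py", "*/main.py", ...) is rejected.
    if "*" in pattern:
        raise PySparkException(
            # Assumes this condition is registered in the error-conditions file.
            errorClass="PIPELINE_SPEC_INVALID_GLOB_PATTERN",
            messageParameters={"glob_pattern": pattern},
        )
    # A plain file path such as "src/main.py" is returned unchanged.
    return pattern
```

Read this way, the fixture updates in the first three hunks follow directly: `subdir1/**` and `./**` are the two wildcard forms the spec still accepts, whereas the old `subdir1/*` and bare `*` would now be rejected.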