From 0b4a60a460db254f94e26fc540479b59c4d6520f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 6 Dec 2024 16:02:54 -0500 Subject: [PATCH] Use `sys.base_prefix` and `sys.prefix` to filter out modules in fast registry (#2985) Signed-off-by: Thomas J. Fan --- flytekit/tools/ignore.py | 2 +- flytekit/tools/script_mode.py | 60 +++++++--------- tests/flytekit/unit/tools/test_script_mode.py | 70 ++++++++++++++++++- 3 files changed, 95 insertions(+), 37 deletions(-) diff --git a/flytekit/tools/ignore.py b/flytekit/tools/ignore.py index e41daf0904..e2aefef596 100644 --- a/flytekit/tools/ignore.py +++ b/flytekit/tools/ignore.py @@ -48,7 +48,7 @@ def _list_ignored(self) -> Dict: out = subprocess.run(["git", "ls-files", "-io", "--exclude-standard"], cwd=self.root, capture_output=True) if out.returncode == 0: return dict.fromkeys(out.stdout.decode("utf-8").split("\n")[:-1]) - logger.warning(f"Could not determine ignored files due to:\n{out.stderr}\nNot applying any filters") + logger.info(f"Could not determine ignored files due to:\n{out.stderr}\nNot applying any filters") return {} logger.info("No git executable found, not applying any filters") return {} diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index 5dd54e90ab..6580fa6462 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -188,23 +188,39 @@ def list_all_files(source_path: str, deref_symlinks, ignore_group: Optional[Igno return all_files +def _file_is_in_directory(file: str, directory: str) -> bool: + """Return True if file is in directory and in its children.""" + try: + return os.path.commonpath([file, directory]) == directory + except ValueError as e: + # ValueError is raised by windows if the paths are not from the same drive + logger.debug(f"{file} and {directory} are not in the same drive: {str(e)}") + return False + + def list_imported_modules_as_files(source_path: str, modules: List[ModuleType]) -> List[str]: """Copies modules into destination that are in modules. The module files are copied only if: 1. Not a site-packages. These are installed packages and not user files. - 2. Not in the bin. These are also installed and not user files. + 2. Not in the sys.base_prefix or sys.prefix. These are also installed and not user files. 3. Does not share a common path with the source_path. """ # source path is the folder holding the main script. # but in register/package case, there are multiple folders. # identify a common root amongst the packages listed? - site_packages = site.getsitepackages() - site_packages_set = set(site_packages) - bin_directory = os.path.dirname(sys.executable) files = [] flytekit_root = os.path.dirname(flytekit.__file__) + # These directories contain installed packages or modules from the Python standard library. + # If a module is from these directories, then they are not user files. + invalid_directories = [ + flytekit_root, + sys.prefix, + sys.base_prefix, + site.getusersitepackages(), + ] + site.getsitepackages() + for mod in modules: try: mod_file = mod.__file__ @@ -214,37 +230,11 @@ def list_imported_modules_as_files(source_path: str, modules: List[ModuleType]) if mod_file is None: continue - # Check to see if mod_file is in site_packages or bin_directory, which are - # installed packages & libraries that are not user files. This happens when - # there is a virtualenv like `.venv` in the working directory. - try: - # Do not upload code if it is from the flytekit library - if os.path.commonpath([flytekit_root, mod_file]) == flytekit_root: - continue - - if os.path.commonpath(site_packages + [mod_file]) in site_packages_set: - # Do not upload files from site-packages - continue - - if os.path.commonpath([bin_directory, mod_file]) == bin_directory: - # Do not upload from the bin directory - continue - - except ValueError: - # ValueError is raised by windows if the paths are not from the same drive - # If the files are not in the same drive, then mod_file is not - # in the site-packages or bin directory. - pass + if any(_file_is_in_directory(mod_file, directory) for directory in invalid_directories): + continue - try: - common_path = os.path.commonpath([mod_file, source_path]) - if common_path != source_path: - # Do not upload files that do not share a common directory with the source - continue - except ValueError: - # ValueError is raised by windows if the paths are not from the same drive - # If they are not in the same directory, then they do not share a common path, - # so we do not upload the file. + if not _file_is_in_directory(mod_file, source_path): + # Only upload files where the module file in the source directory continue files.append(mod_file) @@ -256,7 +246,7 @@ def add_imported_modules_from_source(source_path: str, destination: str, modules """Copies modules into destination that are in modules. The module files are copied only if: 1. Not a site-packages. These are installed packages and not user files. - 2. Not in the bin. These are also installed and not user files. + 2. Not in the sys.base_prefix or sys.prefix. These are also installed and not user files. 3. Does not share a common path with the source_path. """ # source path is the folder holding the main script. diff --git a/tests/flytekit/unit/tools/test_script_mode.py b/tests/flytekit/unit/tools/test_script_mode.py index e664368b25..6d14cac107 100644 --- a/tests/flytekit/unit/tools/test_script_mode.py +++ b/tests/flytekit/unit/tools/test_script_mode.py @@ -4,11 +4,18 @@ import sys import tempfile from pathlib import Path +from types import ModuleType +from unittest.mock import patch import pytest -from flytekit.tools.script_mode import compress_scripts, hash_file, add_imported_modules_from_source, get_all_modules, list_all_files +import flytekit from flytekit.core.tracker import import_module_from_file +from flytekit.tools.script_mode import compress_scripts, hash_file, add_imported_modules_from_source, get_all_modules, \ + list_all_files +from flytekit.tools.script_mode import ( + list_imported_modules_as_files, +) MAIN_WORKFLOW = """ from flytekit import task, workflow @@ -243,6 +250,67 @@ def test_get_all_modules(tmp_path): workflow_file.write_text(WORKFLOW_CONTENT) assert n_sys_modules + 1 == len(get_all_modules(os.fspath(source_dir), "my_workflows.main")) + +@patch("flytekit.tools.script_mode.sys") +@patch("site.getsitepackages") +def test_list_imported_modules_as_files(mock_getsitepackage, mock_sys, tmp_path): + + bin_directory = Path(os.path.dirname(sys.executable)) + flytekit_root = Path(os.path.dirname(flytekit.__file__)) + source_path = tmp_path / "project" + + # Site packages should be executed + site_packages = [ + str(source_path / ".venv" / "lib" / "python3.10" / "site-packages"), + str(source_path / ".venv" / "local" / "lib" / "python3.10" / "dist-packages"), + str(source_path / ".venv" / "lib" / "python3" / "dist-packages"), + str(source_path / ".venv" / "lib" / "python3.10" / "dist-packages"), + ] + mock_getsitepackage.return_value = site_packages + + # lib module that should be excluded, even if it is in the same roto as source_path + lib_path = source_path / "micromamba" / "envs" / "my-env" + lib_modules = [ + (ModuleType("lib_module"), str(lib_path / "module.py")) + ] + # mock the sys prefix to be in the source path + mock_sys.prefix = str(lib_path) + + # bin module that should be excluded + bin_modules = [ + (ModuleType("bin_module"), str(bin_directory / "bin" / "module.py")) + ] + # site modules that should be excluded + site_modules = [ + (ModuleType("site_module_1"), str(Path(site_packages[0]) / "package" / "module_1.py")), + (ModuleType("site_module_2"), str(Path(site_packages[1]) / "package" / "module_2.py")), + (ModuleType("site_module_3"), str(Path(site_packages[2]) / "package" / "module_3.py")), + (ModuleType("site_module_4"), str(Path(site_packages[3]) / "package" / "module_4.py")), + ] + + # local modules that should be included + local_modules = [ + (ModuleType("local_module_1"), str(source_path / "package_a" / "module_1.py")), + (ModuleType("local_module_2"), str(source_path / "package_a" / "module_2.py")), + (ModuleType("local_module_3"), str(source_path / "package_b" / "module_3.py")), + (ModuleType("local_module_4"), str(source_path / "package_b" / "module_4.py")), + ] + flyte_modules = [ + (ModuleType("flyte_module"), str(flytekit_root / "package" / "module.py")) + ] + + module_path_pairs = local_modules + flyte_modules + bin_modules + lib_modules + site_modules + + for m, p in module_path_pairs: + m.__file__ = p + + modules = [m for m, _ in module_path_pairs] + + file_list = list_imported_modules_as_files(str(source_path), modules) + + assert sorted(file_list) == sorted([p for _, p in local_modules]) + + @pytest.mark.skipif( sys.platform == "win32", reason="Skip if running on windows since Unix Domain Sockets do not exist in that OS",