Skip to content

Commit

Permalink
Use sys.base_prefix and sys.prefix to filter out modules in fast …
Browse files Browse the repository at this point in the history
…registry (#2985)

Signed-off-by: Thomas J. Fan <thomasjpfan@gmail.com>
  • Loading branch information
thomasjpfan authored Dec 6, 2024
1 parent 3a42c8c commit 0b4a60a
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 37 deletions.
2 changes: 1 addition & 1 deletion flytekit/tools/ignore.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _list_ignored(self) -> Dict:
out = subprocess.run(["git", "ls-files", "-io", "--exclude-standard"], cwd=self.root, capture_output=True)
if out.returncode == 0:
return dict.fromkeys(out.stdout.decode("utf-8").split("\n")[:-1])
logger.warning(f"Could not determine ignored files due to:\n{out.stderr}\nNot applying any filters")
logger.info(f"Could not determine ignored files due to:\n{out.stderr}\nNot applying any filters")
return {}
logger.info("No git executable found, not applying any filters")
return {}
Expand Down
60 changes: 25 additions & 35 deletions flytekit/tools/script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,23 +188,39 @@ def list_all_files(source_path: str, deref_symlinks, ignore_group: Optional[Igno
return all_files


def _file_is_in_directory(file: str, directory: str) -> bool:
"""Return True if file is in directory and in its children."""
try:
return os.path.commonpath([file, directory]) == directory
except ValueError as e:
# ValueError is raised by windows if the paths are not from the same drive
logger.debug(f"{file} and {directory} are not in the same drive: {str(e)}")
return False


def list_imported_modules_as_files(source_path: str, modules: List[ModuleType]) -> List[str]:
"""Copies modules into destination that are in modules. The module files are copied only if:
1. Not a site-packages. These are installed packages and not user files.
2. Not in the bin. These are also installed and not user files.
2. Not in the sys.base_prefix or sys.prefix. These are also installed and not user files.
3. Does not share a common path with the source_path.
"""
# source path is the folder holding the main script.
# but in register/package case, there are multiple folders.
# identify a common root amongst the packages listed?

site_packages = site.getsitepackages()
site_packages_set = set(site_packages)
bin_directory = os.path.dirname(sys.executable)
files = []
flytekit_root = os.path.dirname(flytekit.__file__)

# These directories contain installed packages or modules from the Python standard library.
# If a module is from these directories, then they are not user files.
invalid_directories = [
flytekit_root,
sys.prefix,
sys.base_prefix,
site.getusersitepackages(),
] + site.getsitepackages()

for mod in modules:
try:
mod_file = mod.__file__
Expand All @@ -214,37 +230,11 @@ def list_imported_modules_as_files(source_path: str, modules: List[ModuleType])
if mod_file is None:
continue

# Check to see if mod_file is in site_packages or bin_directory, which are
# installed packages & libraries that are not user files. This happens when
# there is a virtualenv like `.venv` in the working directory.
try:
# Do not upload code if it is from the flytekit library
if os.path.commonpath([flytekit_root, mod_file]) == flytekit_root:
continue

if os.path.commonpath(site_packages + [mod_file]) in site_packages_set:
# Do not upload files from site-packages
continue

if os.path.commonpath([bin_directory, mod_file]) == bin_directory:
# Do not upload from the bin directory
continue

except ValueError:
# ValueError is raised by windows if the paths are not from the same drive
# If the files are not in the same drive, then mod_file is not
# in the site-packages or bin directory.
pass
if any(_file_is_in_directory(mod_file, directory) for directory in invalid_directories):
continue

try:
common_path = os.path.commonpath([mod_file, source_path])
if common_path != source_path:
# Do not upload files that do not share a common directory with the source
continue
except ValueError:
# ValueError is raised by windows if the paths are not from the same drive
# If they are not in the same directory, then they do not share a common path,
# so we do not upload the file.
if not _file_is_in_directory(mod_file, source_path):
# Only upload files where the module file in the source directory
continue

files.append(mod_file)
Expand All @@ -256,7 +246,7 @@ def add_imported_modules_from_source(source_path: str, destination: str, modules
"""Copies modules into destination that are in modules. The module files are copied only if:
1. Not a site-packages. These are installed packages and not user files.
2. Not in the bin. These are also installed and not user files.
2. Not in the sys.base_prefix or sys.prefix. These are also installed and not user files.
3. Does not share a common path with the source_path.
"""
# source path is the folder holding the main script.
Expand Down
70 changes: 69 additions & 1 deletion tests/flytekit/unit/tools/test_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
import sys
import tempfile
from pathlib import Path
from types import ModuleType
from unittest.mock import patch

import pytest

from flytekit.tools.script_mode import compress_scripts, hash_file, add_imported_modules_from_source, get_all_modules, list_all_files
import flytekit
from flytekit.core.tracker import import_module_from_file
from flytekit.tools.script_mode import compress_scripts, hash_file, add_imported_modules_from_source, get_all_modules, \
list_all_files
from flytekit.tools.script_mode import (
list_imported_modules_as_files,
)

MAIN_WORKFLOW = """
from flytekit import task, workflow
Expand Down Expand Up @@ -243,6 +250,67 @@ def test_get_all_modules(tmp_path):
workflow_file.write_text(WORKFLOW_CONTENT)
assert n_sys_modules + 1 == len(get_all_modules(os.fspath(source_dir), "my_workflows.main"))


@patch("flytekit.tools.script_mode.sys")
@patch("site.getsitepackages")
def test_list_imported_modules_as_files(mock_getsitepackage, mock_sys, tmp_path):

bin_directory = Path(os.path.dirname(sys.executable))
flytekit_root = Path(os.path.dirname(flytekit.__file__))
source_path = tmp_path / "project"

# Site packages should be executed
site_packages = [
str(source_path / ".venv" / "lib" / "python3.10" / "site-packages"),
str(source_path / ".venv" / "local" / "lib" / "python3.10" / "dist-packages"),
str(source_path / ".venv" / "lib" / "python3" / "dist-packages"),
str(source_path / ".venv" / "lib" / "python3.10" / "dist-packages"),
]
mock_getsitepackage.return_value = site_packages

# lib module that should be excluded, even if it is in the same roto as source_path
lib_path = source_path / "micromamba" / "envs" / "my-env"
lib_modules = [
(ModuleType("lib_module"), str(lib_path / "module.py"))
]
# mock the sys prefix to be in the source path
mock_sys.prefix = str(lib_path)

# bin module that should be excluded
bin_modules = [
(ModuleType("bin_module"), str(bin_directory / "bin" / "module.py"))
]
# site modules that should be excluded
site_modules = [
(ModuleType("site_module_1"), str(Path(site_packages[0]) / "package" / "module_1.py")),
(ModuleType("site_module_2"), str(Path(site_packages[1]) / "package" / "module_2.py")),
(ModuleType("site_module_3"), str(Path(site_packages[2]) / "package" / "module_3.py")),
(ModuleType("site_module_4"), str(Path(site_packages[3]) / "package" / "module_4.py")),
]

# local modules that should be included
local_modules = [
(ModuleType("local_module_1"), str(source_path / "package_a" / "module_1.py")),
(ModuleType("local_module_2"), str(source_path / "package_a" / "module_2.py")),
(ModuleType("local_module_3"), str(source_path / "package_b" / "module_3.py")),
(ModuleType("local_module_4"), str(source_path / "package_b" / "module_4.py")),
]
flyte_modules = [
(ModuleType("flyte_module"), str(flytekit_root / "package" / "module.py"))
]

module_path_pairs = local_modules + flyte_modules + bin_modules + lib_modules + site_modules

for m, p in module_path_pairs:
m.__file__ = p

modules = [m for m, _ in module_path_pairs]

file_list = list_imported_modules_as_files(str(source_path), modules)

assert sorted(file_list) == sorted([p for _, p in local_modules])


@pytest.mark.skipif(
sys.platform == "win32",
reason="Skip if running on windows since Unix Domain Sockets do not exist in that OS",
Expand Down

0 comments on commit 0b4a60a

Please sign in to comment.