Skip to content

Commit

Permalink
allow absolute path for the external script (#2565)
Browse files Browse the repository at this point in the history
* add check not allow absolute path for the external script in job creation API.

* Changed the job creation API to allow absolution path ext_scripts, but not copy the script.

* Added absolute path ext_script support for job creation.

* removed the no use import.

* Added handle for the line continous support, added unit test.

* Removed the requirement for absolute external script to have the package path.

* enhance the handling of multi-lines import. added more cases of unit test.

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <yuantingh@nvidia.com>
  • Loading branch information
yhwen and YuanTingHsieh authored Jun 6, 2024
1 parent f8f121f commit e37ff67
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 9 deletions.
2 changes: 1 addition & 1 deletion nvflare/job_config/base_app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def add_ext_script(self, ext_script: str):
if not isinstance(ext_script, str):
raise RuntimeError(f"ext_script must be type of str, but got {ext_script.__class__}")

if not os.path.exists(ext_script):
if not (os.path.isabs(ext_script) or os.path.exists(ext_script)):
raise RuntimeError(f"Could not locate external script: {ext_script}")

if not ext_script.endswith(".py"):
Expand Down
47 changes: 39 additions & 8 deletions nvflare/job_config/fed_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import json
import os
import shutil
import sys
from enum import Enum
from tempfile import TemporaryDirectory
from typing import Dict
Expand Down Expand Up @@ -155,9 +156,22 @@ def _get_server_app(self, config_dir, custom_dir, fed_app):

def _copy_ext_scripts(self, custom_dir, ext_scripts):
for script in ext_scripts:
dest_file = os.path.join(custom_dir, script)
module = "".join(script.rsplit(".py", 1)).replace(os.sep, ".")
self._copy_source_file(custom_dir, module, script, dest_file)
if os.path.exists(script):
if os.path.isabs(script):
relative_script = self._get_relative_script(script)
else:
relative_script = script
dest_file = os.path.join(custom_dir, relative_script)
module = "".join(relative_script.rsplit(".py", 1)).replace(os.sep, ".")
self._copy_source_file(custom_dir, module, script, dest_file)

def _get_relative_script(self, script):
package_path = ""
for path in sys.path:
if script.startswith(path):
if len(path) > len(package_path):
package_path = path
return script[len(package_path) + 1 :]

def _get_class_path(self, obj, custom_dir):
module = obj.__module__
Expand Down Expand Up @@ -294,15 +308,32 @@ def _get_filters(self, filters, custom_dir):
return r

def locate_imports(self, sf, dest_file):
"""Locate all the import statements from the python script, including the imports across multiple lines,
using the the line break continuing.
Args:
sf: source file
dest_file: copy to destination file
Returns:
yield all the imports within the source file
"""
os.makedirs(os.path.dirname(dest_file), exist_ok=True)
with open(dest_file, "w") as df:
trimmed = ""
for line in sf:
df.write(line)
trimmed = line.strip()
if trimmed.startswith("from ") and ("import " in trimmed):
yield trimmed
elif trimmed.startswith("import "):
yield trimmed
trimmed += line.strip()
if trimmed.endswith("\\"):
trimmed = trimmed[0:-1]
trimmed = trimmed.strip() + " "
else:
if trimmed.startswith("from ") and ("import " in trimmed):
yield trimmed
elif trimmed.startswith("import "):
yield trimmed
trimmed = ""

def _get_deploy_map(self):
deploy_map = {}
Expand Down
51 changes: 51 additions & 0 deletions tests/unit_test/data/job_config/sample_code.data
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List

from\
nvflare.fuel.f3.drivers.base_driver \
import \
BaseDriver

from nvflare.fuel.f3.drivers.connector_info import ConnectorInfo \

from nvflare.fuel.f3.drivers.driver_params import DriverCap


class WarpDriver(BaseDriver):
"""A dummy driver to test custom driver loading"""

def __init__(self):
super().__init__()

@staticmethod
def supported_transports() -> List[str]:
return ["warp"]

@staticmethod
def capabilities() -> Dict[str, Any]:
return {DriverCap.SEND_HEARTBEAT.value: True, DriverCap.SUPPORT_SSL.value: False}

def listen(self, connector: ConnectorInfo):
self.connector = connector

def connect(self, connector: ConnectorInfo):
self.connector = connector

def shutdown(self):
self.close_all()

@staticmethod
def get_urls(scheme: str, resources: dict) -> (str, str):
return "warp:enterprise"
41 changes: 41 additions & 0 deletions tests/unit_test/job_config/base_app_config_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile

import pytest

from nvflare.job_config.base_app_config import BaseAppConfig


class TestBaseAppConfig:
def setup_method(self, method):
self.app_config = BaseAppConfig()

def test_add_relative_script(self):
cwd = os.getcwd()
with tempfile.NamedTemporaryFile(dir=cwd, suffix=".py") as temp_file:
script = os.path.basename(temp_file.name)
self.app_config.add_ext_script(script)
assert script in self.app_config.ext_scripts

def test_add_ext_script(self):
script = "/scripts/sample.py"
self.app_config.add_ext_script(script)
assert script in self.app_config.ext_scripts

def test_add_ext_script_error(self):
script = "scripts/sample.py"
with pytest.raises(Exception):
self.app_config.add_ext_script(script)
34 changes: 34 additions & 0 deletions tests/unit_test/job_config/fed_job_config_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile

from nvflare.job_config.fed_job_config import FedJobConfig


class TestFedJobConfig:
def test_locate_imports(self):
job_config = FedJobConfig(job_name="job_name", min_clients=1)
cwd = os.path.dirname(__file__)
source_file = os.path.join(cwd, "../data/job_config/sample_code.data")
expected = [
"from typing import Any, Dict, List",
"from nvflare.fuel.f3.drivers.base_driver import BaseDriver",
"from nvflare.fuel.f3.drivers.connector_info import ConnectorInfo ",
"from nvflare.fuel.f3.drivers.driver_params import DriverCap",
]
with open(source_file, "r") as sf:
with tempfile.NamedTemporaryFile(dir=cwd, suffix=".py") as dest_file:
imports = list(job_config.locate_imports(sf, dest_file=dest_file.name))
assert imports == expected

0 comments on commit e37ff67

Please sign in to comment.