Skip to content

Commit bc5918a

Browse files
authored
Merge branch 'master' into master
2 parents 3a03c4b + a58654e commit bc5918a

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

src/sagemaker/modules/local_core/local_container.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ class _LocalContainer(BaseModel):
108108
container_entrypoint: Optional[List[str]]
109109
container_arguments: Optional[List[str]]
110110

111+
_temporary_folders: List[str] = []
112+
111113
def model_post_init(self, __context: Any):
112114
"""Post init method to perform custom validation and set default values."""
113115
self.hosts = [f"algo-{i}" for i in range(1, self.instance_count + 1)]
@@ -201,6 +203,13 @@ def train(
201203

202204
# Print our Job Complete line
203205
logger.info("Local training job completed, output artifacts saved to %s", artifacts)
206+
207+
shutil.rmtree(os.path.join(self.container_root, "input"))
208+
shutil.rmtree(os.path.join(self.container_root, "shared"))
209+
for host in self.hosts:
210+
shutil.rmtree(os.path.join(self.container_root, host))
211+
for folder in self._temporary_folders:
212+
shutil.rmtree(os.path.join(self.container_root, folder))
204213
return artifacts
205214

206215
def retrieve_artifacts(
@@ -540,6 +549,7 @@ def _get_data_source_local_path(self, data_source: DataSource):
540549
uri = data_source.s3_data_source.s3_uri
541550
parsed_uri = urlparse(uri)
542551
local_dir = TemporaryDirectory(prefix=os.path.join(self.container_root + "/")).name
552+
self._temporary_folders.append(local_dir)
543553
download_folder(parsed_uri.netloc, parsed_uri.path, local_dir, self.sagemaker_session)
544554
return local_dir
545555
else:

tests/integ/sagemaker/modules/train/test_local_model_trainer.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,7 @@ def test_single_container_local_mode_local_data(modules_sagemaker_session):
9292
"compressed_artifacts",
9393
"artifacts",
9494
"model",
95-
"shared",
96-
"input",
9795
"output",
98-
"algo-1",
9996
]
10097

10198
for directory in directories:
@@ -149,14 +146,16 @@ def test_single_container_local_mode_s3_data(modules_sagemaker_session):
149146
assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz"))
150147
finally:
151148
subprocess.run(["docker", "compose", "down", "-v"])
149+
150+
assert not os.path.exists(os.path.join(CWD, "shared"))
151+
assert not os.path.exists(os.path.join(CWD, "input"))
152+
assert not os.path.exists(os.path.join(CWD, "algo-1"))
153+
152154
directories = [
153155
"compressed_artifacts",
154156
"artifacts",
155157
"model",
156-
"shared",
157-
"input",
158158
"output",
159-
"algo-1",
160159
]
161160

162161
for directory in directories:
@@ -204,20 +203,20 @@ def test_multi_container_local_mode(modules_sagemaker_session):
204203

205204
model_trainer.train()
206205
assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz"))
207-
assert os.path.exists(os.path.join(CWD, "algo-1"))
208-
assert os.path.exists(os.path.join(CWD, "algo-2"))
209206

210207
finally:
211208
subprocess.run(["docker", "compose", "down", "-v"])
209+
210+
assert not os.path.exists(os.path.join(CWD, "shared"))
211+
assert not os.path.exists(os.path.join(CWD, "input"))
212+
assert not os.path.exists(os.path.join(CWD, "algo-1"))
213+
assert not os.path.exists(os.path.join(CWD, "algo-2"))
214+
212215
directories = [
213216
"compressed_artifacts",
214217
"artifacts",
215218
"model",
216-
"shared",
217-
"input",
218219
"output",
219-
"algo-1",
220-
"algo-2",
221220
]
222221

223222
for directory in directories:

0 commit comments

Comments
 (0)