Skip to content

Commit 6af15f3

Browse files
Add comments and test
1 parent 4828214 commit 6af15f3

File tree

2 files changed

+93
-6
lines changed

2 files changed

+93
-6
lines changed

compute_horde/compute_horde/base/volume.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@ def __str__(self):
2323
class HuggingfaceVolume(pydantic.BaseModel):
2424
volume_type: Literal[VolumeType.huggingface_volume] = VolumeType.huggingface_volume
2525
repo_id: str
26-
repo_type: str | None = None
26+
repo_type: str | None = (
27+
None # Set to "dataset" or "space" for a dataset or space repository; None or "model" for a model repository.
28+
)
2729
revision: str | None = None # Git revision id: branch name / tag / commit hash
2830
relative_path: str | None = None
29-
allow_patterns: str | list[str] | None = None
31+
allow_patterns: str | list[str] | None = (
32+
None # If provided, only files matching at least one pattern are downloaded.
33+
)
3034

3135
def is_safe(self) -> bool:
3236
return True

executor/app/src/compute_horde_executor/executor/tests/integration/test_main_loop.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@
2020

2121
payload = "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(32))
2222

23+
24+
def mock_download(local_dir, **kwargs):
25+
with open(local_dir / "payload.txt", "w") as file:
26+
file.write(payload)
27+
28+
2329
in_memory_output = io.BytesIO()
2430
zipf = zipfile.ZipFile(in_memory_output, "w")
2531
zipf.writestr("payload.txt", payload)
@@ -170,10 +176,6 @@ def test_huggingface_volume():
170176
repo_id = "huggingface/model"
171177
revision = "main"
172178

173-
def mock_download(local_dir, **kwargs):
174-
with open(local_dir / "payload.txt", "w") as file:
175-
file.write(payload)
176-
177179
with patch(
178180
"compute_horde_executor.executor.management.commands.run_executor.snapshot_download",
179181
mock_download,
@@ -231,6 +233,87 @@ def mock_download(local_dir, **kwargs):
231233
]
232234

233235

236+
def test_huggingface_volume_dataset():
237+
# Arrange
238+
repo_id = "huggingface/model"
239+
revision = "main"
240+
repo_type = "dataset"
241+
file_patterns = [
242+
"default/train/001/01JJK16EFPA7HWY3Z7MWZ4A6N9.parquet",
243+
"default/train/003/01JJK12V8K1A65RD75NSWRGECK.parquet",
244+
"default/train/003/01JJKJ49N4NBSS3YJG35XQ9XPB.parquet",
245+
"default/train/004/01JJKB151AFCB1TGJDCXCTBZPW.parquet",
246+
"default/train/004/01JJKCK5DPEH61SBY8MC2NXTRM.parquet",
247+
"default/train/004/01JJKNQADWRJYBPKKZGJHSKRSC.parquet",
248+
"default/train/005/01JJKQ9QGPBV5KW18ZSM3VBTSX.parquet",
249+
"default/train/008/01JJK3G51DHYK1N0JHPJQS3GFR.parquet",
250+
]
251+
252+
with patch(
253+
"compute_horde_executor.executor.management.commands.run_executor.snapshot_download",
254+
side_effect=mock_download,
255+
) as mock_snapshot_download:
256+
command = CommandTested(
257+
iter(
258+
[
259+
json.dumps(
260+
{
261+
"message_type": "V0PrepareJobRequest",
262+
"base_docker_image_name": "backenddevelopersltd/compute-horde-job-echo:v0-latest",
263+
"timeout_seconds": None,
264+
"volume_type": "huggingface_volume",
265+
"job_uuid": job_uuid,
266+
}
267+
),
268+
json.dumps(
269+
{
270+
"message_type": "V0RunJobRequest",
271+
"docker_image_name": "backenddevelopersltd/compute-horde-job-echo:v0-latest",
272+
"docker_run_cmd": [],
273+
"docker_run_options_preset": "none",
274+
"volume": {
275+
"volume_type": "huggingface_volume",
276+
"repo_id": repo_id,
277+
"repo_type": repo_type,
278+
"revision": revision,
279+
"allow_patterns": file_patterns,
280+
},
281+
"job_uuid": job_uuid,
282+
}
283+
),
284+
]
285+
)
286+
)
287+
288+
# Act
289+
command.handle()
290+
291+
# Assert
292+
assert [json.loads(msg) for msg in command.miner_client_for_tests.transport.sent_messages] == [
293+
{
294+
"message_type": "V0ReadyRequest",
295+
"job_uuid": job_uuid,
296+
},
297+
{
298+
"message_type": "V0MachineSpecsRequest",
299+
"specs": mock.ANY,
300+
"job_uuid": job_uuid,
301+
},
302+
{
303+
"message_type": "V0FinishedRequest",
304+
"docker_process_stdout": payload,
305+
"docker_process_stderr": mock.ANY,
306+
"job_uuid": job_uuid,
307+
},
308+
]
309+
310+
_, kwargs = mock_snapshot_download.call_args
311+
assert kwargs["repo_id"] == repo_id
312+
assert kwargs["revision"] == revision
313+
assert kwargs["repo_type"] == repo_type
314+
assert kwargs["allow_patterns"] == file_patterns
315+
316+
234317
def test_zip_url_volume(httpx_mock: HTTPXMock):
235318
zip_url = "https://localhost/payload.txt"
236319
httpx_mock.add_response(url=zip_url, content=zip_contents)

0 commit comments

Comments
 (0)