Skip to content

Commit

Permalink
Add comments and test
Browse files Browse the repository at this point in the history
  • Loading branch information
slawomir-gorawski-reef committed Jan 27, 2025
1 parent 4828214 commit 1e38772
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 5 deletions.
5 changes: 4 additions & 1 deletion compute_horde/compute_horde/base/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ def __str__(self):
class HuggingfaceVolume(pydantic.BaseModel):
volume_type: Literal[VolumeType.huggingface_volume] = VolumeType.huggingface_volume
repo_id: str
# Set to "dataset" or "space" for a dataset or space, None or "model" for a model.
repo_type: str | None = None
revision: str | None = None # Git revision id: branch name / tag / commit hash
# Git revision id: branch name / tag / commit hash
revision: str | None = None
relative_path: str | None = None
# If provided, only files matching at least one pattern are downloaded.
allow_patterns: str | list[str] | None = None

def is_safe(self) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@

# Random 32-character token (uppercase ASCII letters and digits) shared by the tests as file content.
_payload_alphabet = string.ascii_uppercase + string.digits
payload = "".join(random.choice(_payload_alphabet) for _ in range(32))


def mock_download(local_dir, **kwargs):
    """Stand-in for ``snapshot_download``: write ``payload`` to ``payload.txt`` inside *local_dir*.

    Any additional keyword arguments are accepted and ignored.
    """
    target = local_dir / "payload.txt"
    with open(target, "w") as out:
        out.write(payload)


# Write `payload` as the entry "payload.txt" into a zip archive backed by an in-memory buffer.
in_memory_output = io.BytesIO()
zipf = zipfile.ZipFile(in_memory_output, "w")
zipf.writestr("payload.txt", payload)
Expand Down Expand Up @@ -170,10 +176,6 @@ def test_huggingface_volume():
repo_id = "huggingface/model"
revision = "main"

def mock_download(local_dir, **kwargs):
with open(local_dir / "payload.txt", "w") as file:
file.write(payload)

with patch(
"compute_horde_executor.executor.management.commands.run_executor.snapshot_download",
mock_download,
Expand Down Expand Up @@ -231,6 +233,87 @@ def mock_download(local_dir, **kwargs):
]


def test_huggingface_volume_dataset():
    """Run a job whose volume is a Hugging Face dataset restricted by ``allow_patterns``.

    ``snapshot_download`` in ``run_executor`` is replaced by a mock whose side effect is
    ``mock_download``. The test checks the messages recorded on the test miner client's
    transport, and that ``repo_id``, ``revision``, ``repo_type`` and ``allow_patterns``
    reach ``snapshot_download`` as keyword arguments.
    """
    # Arrange
    repo_id = "huggingface/dataset"
    revision = "main"
    repo_type = "dataset"
    file_patterns = [
        "default/train/001/01JJK16EFPA7HWY3Z7MWZ4A6N9.parquet",
        "default/train/003/01JJK12V8K1A65RD75NSWRGECK.parquet",
        "default/train/003/01JJKJ49N4NBSS3YJG35XQ9XPB.parquet",
        "default/train/004/01JJKB151AFCB1TGJDCXCTBZPW.parquet",
        "default/train/004/01JJKCK5DPEH61SBY8MC2NXTRM.parquet",
        "default/train/004/01JJKNQADWRJYBPKKZGJHSKRSC.parquet",
        "default/train/005/01JJKQ9QGPBV5KW18ZSM3VBTSX.parquet",
        "default/train/008/01JJK3G51DHYK1N0JHPJQS3GFR.parquet",
    ]
    docker_image = "backenddevelopersltd/compute-horde-job-echo:v0-latest"

    prepare_request = {
        "message_type": "V0PrepareJobRequest",
        "base_docker_image_name": docker_image,
        "timeout_seconds": None,
        "volume_type": "huggingface_volume",
        "job_uuid": job_uuid,
    }
    run_request = {
        "message_type": "V0RunJobRequest",
        "docker_image_name": docker_image,
        "docker_run_cmd": [],
        "docker_run_options_preset": "none",
        "volume": {
            "volume_type": "huggingface_volume",
            "repo_id": repo_id,
            "repo_type": repo_type,
            "revision": revision,
            "allow_patterns": file_patterns,
        },
        "job_uuid": job_uuid,
    }
    expected_messages = [
        {
            "message_type": "V0ReadyRequest",
            "job_uuid": job_uuid,
        },
        {
            "message_type": "V0MachineSpecsRequest",
            "specs": mock.ANY,
            "job_uuid": job_uuid,
        },
        {
            "message_type": "V0FinishedRequest",
            "docker_process_stdout": payload,
            "docker_process_stderr": mock.ANY,
            "job_uuid": job_uuid,
        },
    ]

    snapshot_download_target = (
        "compute_horde_executor.executor.management.commands.run_executor.snapshot_download"
    )
    with patch(snapshot_download_target, side_effect=mock_download) as mock_snapshot_download:
        incoming = [json.dumps(prepare_request), json.dumps(run_request)]
        command = CommandTested(iter(incoming))

        # Act
        command.handle()

    # Assert
    sent_messages = [
        json.loads(msg) for msg in command.miner_client_for_tests.transport.sent_messages
    ]
    assert sent_messages == expected_messages

    # The volume settings must be forwarded to snapshot_download as keyword arguments.
    _, kwargs = mock_snapshot_download.call_args
    assert kwargs["repo_id"] == repo_id
    assert kwargs["revision"] == revision
    assert kwargs["repo_type"] == repo_type
    assert kwargs["allow_patterns"] == file_patterns


def test_zip_url_volume(httpx_mock: HTTPXMock):
zip_url = "https://localhost/payload.txt"
httpx_mock.add_response(url=zip_url, content=zip_contents)
Expand Down

0 comments on commit 1e38772

Please sign in to comment.