|
20 | 20 |
|
21 | 21 | payload = "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(32))
|
22 | 22 |
|
| 23 | + |
| 24 | +def mock_download(local_dir, **kwargs): |
| 25 | + with open(local_dir / "payload.txt", "w") as file: |
| 26 | + file.write(payload) |
| 27 | + |
| 28 | + |
23 | 29 | in_memory_output = io.BytesIO()
|
24 | 30 | zipf = zipfile.ZipFile(in_memory_output, "w")
|
25 | 31 | zipf.writestr("payload.txt", payload)
|
@@ -170,10 +176,6 @@ def test_huggingface_volume():
|
170 | 176 | repo_id = "huggingface/model"
|
171 | 177 | revision = "main"
|
172 | 178 |
|
173 |
| - def mock_download(local_dir, **kwargs): |
174 |
| - with open(local_dir / "payload.txt", "w") as file: |
175 |
| - file.write(payload) |
176 |
| - |
177 | 179 | with patch(
|
178 | 180 | "compute_horde_executor.executor.management.commands.run_executor.snapshot_download",
|
179 | 181 | mock_download,
|
@@ -231,6 +233,87 @@ def mock_download(local_dir, **kwargs):
|
231 | 233 | ]
|
232 | 234 |
|
233 | 235 |
|
| 236 | +def test_huggingface_volume_dataset(): |
| 237 | + # Arrange |
| 238 | + repo_id = "huggingface/model" |
| 239 | + revision = "main" |
| 240 | + repo_type = "dataset" |
| 241 | + file_patterns = [ |
| 242 | + "default/train/001/01JJK16EFPA7HWY3Z7MWZ4A6N9.parquet", |
| 243 | + "default/train/003/01JJK12V8K1A65RD75NSWRGECK.parquet", |
| 244 | + "default/train/003/01JJKJ49N4NBSS3YJG35XQ9XPB.parquet", |
| 245 | + "default/train/004/01JJKB151AFCB1TGJDCXCTBZPW.parquet", |
| 246 | + "default/train/004/01JJKCK5DPEH61SBY8MC2NXTRM.parquet", |
| 247 | + "default/train/004/01JJKNQADWRJYBPKKZGJHSKRSC.parquet", |
| 248 | + "default/train/005/01JJKQ9QGPBV5KW18ZSM3VBTSX.parquet", |
| 249 | + "default/train/008/01JJK3G51DHYK1N0JHPJQS3GFR.parquet", |
| 250 | + ] |
| 251 | + |
| 252 | + with patch( |
| 253 | + "compute_horde_executor.executor.management.commands.run_executor.snapshot_download", |
| 254 | + side_effect=mock_download, |
| 255 | + ) as mock_snapshot_download: |
| 256 | + command = CommandTested( |
| 257 | + iter( |
| 258 | + [ |
| 259 | + json.dumps( |
| 260 | + { |
| 261 | + "message_type": "V0PrepareJobRequest", |
| 262 | + "base_docker_image_name": "backenddevelopersltd/compute-horde-job-echo:v0-latest", |
| 263 | + "timeout_seconds": None, |
| 264 | + "volume_type": "huggingface_volume", |
| 265 | + "job_uuid": job_uuid, |
| 266 | + } |
| 267 | + ), |
| 268 | + json.dumps( |
| 269 | + { |
| 270 | + "message_type": "V0RunJobRequest", |
| 271 | + "docker_image_name": "backenddevelopersltd/compute-horde-job-echo:v0-latest", |
| 272 | + "docker_run_cmd": [], |
| 273 | + "docker_run_options_preset": "none", |
| 274 | + "volume": { |
| 275 | + "volume_type": "huggingface_volume", |
| 276 | + "repo_id": repo_id, |
| 277 | + "repo_type": repo_type, |
| 278 | + "revision": revision, |
| 279 | + "allow_patterns": file_patterns, |
| 280 | + }, |
| 281 | + "job_uuid": job_uuid, |
| 282 | + } |
| 283 | + ), |
| 284 | + ] |
| 285 | + ) |
| 286 | + ) |
| 287 | + |
| 288 | + # Act |
| 289 | + command.handle() |
| 290 | + |
| 291 | + # Assert |
| 292 | + assert [json.loads(msg) for msg in command.miner_client_for_tests.transport.sent_messages] == [ |
| 293 | + { |
| 294 | + "message_type": "V0ReadyRequest", |
| 295 | + "job_uuid": job_uuid, |
| 296 | + }, |
| 297 | + { |
| 298 | + "message_type": "V0MachineSpecsRequest", |
| 299 | + "specs": mock.ANY, |
| 300 | + "job_uuid": job_uuid, |
| 301 | + }, |
| 302 | + { |
| 303 | + "message_type": "V0FinishedRequest", |
| 304 | + "docker_process_stdout": payload, |
| 305 | + "docker_process_stderr": mock.ANY, |
| 306 | + "job_uuid": job_uuid, |
| 307 | + }, |
| 308 | + ] |
| 309 | + |
| 310 | + _, kwargs = mock_snapshot_download.call_args |
| 311 | + assert kwargs["repo_id"] == repo_id |
| 312 | + assert kwargs["revision"] == revision |
| 313 | + assert kwargs["repo_type"] == repo_type |
| 314 | + assert kwargs["allow_patterns"] == file_patterns |
| 315 | + |
| 316 | + |
234 | 317 | def test_zip_url_volume(httpx_mock: HTTPXMock):
|
235 | 318 | zip_url = "https://localhost/payload.txt"
|
236 | 319 | httpx_mock.add_response(url=zip_url, content=zip_contents)
|
|
0 commit comments