Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions runpod/serverless/utils/rp_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
import uuid
import zipfile
from concurrent.futures import ThreadPoolExecutor
from email import message_from_string
from typing import List, Union
from typing import List, Union, Dict
from urllib.parse import urlparse

import backoff
Expand All @@ -35,28 +34,36 @@ def calculate_chunk_size(file_size: int) -> int:
return 1024 * 1024 * 10 # 10 MB


def extract_disposition_params(content_disposition: str) -> Dict[str, str]:
parts = (p.strip() for p in content_disposition.split(";"))

params = {
key.strip().lower(): value.strip().strip('"')
for part in parts
if "=" in part
for key, value in [part.split("=", 1)]
}

return params


def download_files_from_urls(job_id: str, urls: Union[str, List[str]]) -> List[str]:
"""
Accepts a single URL or a list of URLs and downloads the files.
Returns the list of downloaded file absolute paths.
Saves the files in a directory called "downloaded_files" in the job directory.
"""
download_directory = os.path.abspath(
os.path.join("jobs", job_id, "downloaded_files")
)
download_directory = os.path.abspath(os.path.join("jobs", job_id, "downloaded_files"))
os.makedirs(download_directory, exist_ok=True)

@backoff.on_exception(backoff.expo, RequestException, max_tries=3)
def download_file(url: str, path_to_save: str) -> str:
with SyncClientSession().get(
url, headers=HEADERS, stream=True, timeout=5
) as response:
with SyncClientSession().get(url, headers=HEADERS, stream=True, timeout=5) as response:
response.raise_for_status()
content_disposition = response.headers.get("Content-Disposition")
file_extension = ""
if content_disposition:
msg = message_from_string(f"Content-Disposition: {content_disposition}")
params = dict(msg.items())
params = extract_disposition_params(content_disposition)
file_extension = os.path.splitext(params.get("filename", ""))[1]

# If no extension could be determined from 'Content-Disposition', get it from the URL
Expand Down Expand Up @@ -113,15 +120,15 @@ def file(file_url: str) -> dict:

download_response = SyncClientSession().get(file_url, headers=HEADERS, timeout=30)

original_file_name = []
if "Content-Disposition" in download_response.headers.keys():
original_file_name = re.findall(
"filename=(.+)", download_response.headers["Content-Disposition"]
)
content_disposition = download_response.headers.get("Content-Disposition")

if len(original_file_name) > 0:
original_file_name = original_file_name[0]
else:
original_file_name = ""
if content_disposition:
params = extract_disposition_params(content_disposition)

original_file_name = params.get("filename", "")

if not original_file_name:
download_path = urlparse(file_url).path
original_file_name = os.path.basename(download_path)

Expand Down
66 changes: 41 additions & 25 deletions tests/test_serverless/test_utils/test_download.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" Tests for runpod | serverless | modules | download.py """
"""Tests for runpod | serverless | modules | download.py"""

# pylint: disable=R0903,W0613

Expand All @@ -17,6 +17,7 @@
URL_LIST = [
"https://example.com/picture.jpg",
"https://example.com/picture.jpg?X-Amz-Signature=123",
"https://example.com/file_without_extension",
]

JOB_ID = "job_123"
Expand Down Expand Up @@ -75,9 +76,7 @@ def test_calculate_chunk_size(self):
self.assertEqual(calculate_chunk_size(1024), 1024)
self.assertEqual(calculate_chunk_size(1024 * 1024), 1024)
self.assertEqual(calculate_chunk_size(1024 * 1024 * 1024), 1024 * 1024)
self.assertEqual(
calculate_chunk_size(1024 * 1024 * 1024 * 10), 1024 * 1024 * 10
)
self.assertEqual(calculate_chunk_size(1024 * 1024 * 1024 * 10), 1024 * 1024 * 10)

@patch("os.makedirs", return_value=None)
@patch("runpod.http_client.SyncClientSession.get", side_effect=mock_requests_get)
Expand All @@ -86,29 +85,26 @@ def test_download_files_from_urls(self, mock_open_file, mock_get, mock_makedirs)
"""
Tests download_files_from_urls
"""
urls = ["https://example.com/picture.jpg", "https://example.com/file_without_extension"]
downloaded_files = download_files_from_urls(
JOB_ID,
[
"https://example.com/picture.jpg",
],
urls,
)

self.assertEqual(len(downloaded_files), 1)
self.assertEqual(len(downloaded_files), len(urls))

# Check that the url was called with SyncClientSession.get
self.assertIn("https://example.com/picture.jpg", mock_get.call_args_list[0][0])
for index, url in enumerate(urls):
# Check that the url was called with SyncClientSession.get
self.assertIn(url, mock_get.call_args_list[index][0])

# Check that the file has the correct extension
self.assertTrue(downloaded_files[0].endswith(".jpg"))
# Check that the file has the correct extension
self.assertTrue(downloaded_files[index].endswith(".jpg"))

mock_open_file.assert_called_once_with(downloaded_files[0], "wb")
mock_makedirs.assert_called_once_with(
os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True
)
mock_open_file.assert_any_call(downloaded_files[index], "wb")

string_download_file = download_files_from_urls(
JOB_ID, "https://example.com/picture.jpg"
)
mock_makedirs.assert_called_once_with(os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True)

string_download_file = download_files_from_urls(JOB_ID, "https://example.com/picture.jpg")
self.assertTrue(string_download_file[0].endswith(".jpg"))

# Check if None is returned when url is None
Expand All @@ -124,9 +120,7 @@ def test_download_files_from_urls(self, mock_open_file, mock_get, mock_makedirs)
@patch("os.makedirs", return_value=None)
@patch("runpod.http_client.SyncClientSession.get", side_effect=mock_requests_get)
@patch("builtins.open", new_callable=mock_open)
def test_download_files_from_urls_signed(
self, mock_open_file, mock_get, mock_makedirs
):
def test_download_files_from_urls_signed(self, mock_open_file, mock_get, mock_makedirs):
"""
Tests download_files_from_urls with signed urls
"""
Expand All @@ -147,9 +141,7 @@ def test_download_files_from_urls_signed(
self.assertTrue(downloaded_files[0].endswith(".jpg"))

mock_open_file.assert_called_once_with(downloaded_files[0], "wb")
mock_makedirs.assert_called_once_with(
os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True
)
mock_makedirs.assert_called_once_with(os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True)


class FileDownloaderTestCase(unittest.TestCase):
Expand Down Expand Up @@ -179,6 +171,30 @@ def test_download_file(self, mock_file, mock_get):
# Check that the file was written correctly
mock_file().write.assert_called_once_with(b"file content")

@patch("runpod.serverless.utils.rp_download.SyncClientSession.get")
@patch("builtins.open", new_callable=mock_open)
def test_download_file(self, mock_file, mock_get):
"""
Tests download_file using filename from Content-Disposition
"""
# Mock the response from SyncClientSession.get
mock_response = MagicMock()
mock_response.content = b"file content"
mock_response.headers = {"Content-Disposition": 'inline; filename="test_file.txt"'}
mock_get.return_value = mock_response

# Call the function with a test URL
result = file("http://test.com/file_without_extension")

# Check the result
self.assertEqual(result["type"], "txt")
self.assertEqual(result["original_name"], "test_file.txt")
self.assertTrue(result["file_path"].endswith(".txt"))
self.assertIsNone(result["extracted_path"])

# Check that the file was written correctly
mock_file().write.assert_called_once_with(b"file content")

@patch("runpod.serverless.utils.rp_download.SyncClientSession.get")
@patch("builtins.open", new_callable=mock_open)
@patch("runpod.serverless.utils.rp_download.zipfile.ZipFile")
Expand Down
Loading