From b4aca780a3aec8f7af98bcaea58b63ef963928c3 Mon Sep 17 00:00:00 2001 From: Felix Mosheev <9304194+felixmosh@users.noreply.github.com> Date: Thu, 1 May 2025 19:58:36 +0300 Subject: [PATCH] fix: parse Content-Disposition properly, closes #414 --- runpod/serverless/utils/rp_download.py | 43 +++++++----- .../test_utils/test_download.py | 66 ++++++++++++------- 2 files changed, 66 insertions(+), 43 deletions(-) diff --git a/runpod/serverless/utils/rp_download.py b/runpod/serverless/utils/rp_download.py index baa0d13c..3904d65e 100644 --- a/runpod/serverless/utils/rp_download.py +++ b/runpod/serverless/utils/rp_download.py @@ -11,8 +11,7 @@ import uuid import zipfile from concurrent.futures import ThreadPoolExecutor -from email import message_from_string -from typing import List, Union +from typing import List, Union, Dict from urllib.parse import urlparse import backoff @@ -35,28 +34,36 @@ def calculate_chunk_size(file_size: int) -> int: return 1024 * 1024 * 10 # 10 MB +def extract_disposition_params(content_disposition: str) -> Dict[str, str]: + parts = (p.strip() for p in content_disposition.split(";")) + + params = { + key.strip().lower(): value.strip().strip('"') + for part in parts + if "=" in part + for key, value in [part.split("=", 1)] + } + + return params + + def download_files_from_urls(job_id: str, urls: Union[str, List[str]]) -> List[str]: """ Accepts a single URL or a list of URLs and downloads the files. Returns the list of downloaded file absolute paths. Saves the files in a directory called "downloaded_files" in the job directory. """ - download_directory = os.path.abspath( - os.path.join("jobs", job_id, "downloaded_files") - ) + download_directory = os.path.abspath(os.path.join("jobs", job_id, "downloaded_files")) os.makedirs(download_directory, exist_ok=True) @backoff.on_exception(backoff.expo, RequestException, max_tries=3) def download_file(url: str, path_to_save: str) -> str: - with SyncClientSession().get( - url, headers=HEADERS, stream=True, timeout=5 - ) as response: + with SyncClientSession().get(url, headers=HEADERS, stream=True, timeout=5) as response: response.raise_for_status() content_disposition = response.headers.get("Content-Disposition") file_extension = "" if content_disposition: - msg = message_from_string(f"Content-Disposition: {content_disposition}") - params = dict(msg.items()) + params = extract_disposition_params(content_disposition) file_extension = os.path.splitext(params.get("filename", ""))[1] # If no extension could be determined from 'Content-Disposition', get it from the URL @@ -113,15 +120,15 @@ def file(file_url: str) -> dict: download_response = SyncClientSession().get(file_url, headers=HEADERS, timeout=30) - original_file_name = [] - if "Content-Disposition" in download_response.headers.keys(): - original_file_name = re.findall( - "filename=(.+)", download_response.headers["Content-Disposition"] - ) + content_disposition = download_response.headers.get("Content-Disposition") - if len(original_file_name) > 0: - original_file_name = original_file_name[0] - else: + original_file_name = "" + if content_disposition: + params = extract_disposition_params(content_disposition) + + original_file_name = params.get("filename", "") + + if not original_file_name: download_path = urlparse(file_url).path original_file_name = os.path.basename(download_path) diff --git a/tests/test_serverless/test_utils/test_download.py b/tests/test_serverless/test_utils/test_download.py index 6678f452..10c5bcd0 100644 --- a/tests/test_serverless/test_utils/test_download.py +++ b/tests/test_serverless/test_utils/test_download.py @@ -1,4 +1,4 @@ -""" Tests for runpod | serverless | modules | download.py """ +"""Tests for runpod | serverless | modules | download.py""" # pylint: disable=R0903,W0613 @@ -17,6 +17,7 @@ URL_LIST = [ "https://example.com/picture.jpg", "https://example.com/picture.jpg?X-Amz-Signature=123", + "https://example.com/file_without_extension", ] JOB_ID = "job_123" @@ -75,9 +76,7 @@ def test_calculate_chunk_size(self): self.assertEqual(calculate_chunk_size(1024), 1024) self.assertEqual(calculate_chunk_size(1024 * 1024), 1024) self.assertEqual(calculate_chunk_size(1024 * 1024 * 1024), 1024 * 1024) - self.assertEqual( - calculate_chunk_size(1024 * 1024 * 1024 * 10), 1024 * 1024 * 10 - ) + self.assertEqual(calculate_chunk_size(1024 * 1024 * 1024 * 10), 1024 * 1024 * 10) @patch("os.makedirs", return_value=None) @patch("runpod.http_client.SyncClientSession.get", side_effect=mock_requests_get) @@ -86,29 +85,26 @@ def test_download_files_from_urls(self, mock_open_file, mock_get, mock_makedirs) """ Tests download_files_from_urls """ + urls = ["https://example.com/picture.jpg", "https://example.com/file_without_extension"] downloaded_files = download_files_from_urls( JOB_ID, - [ - "https://example.com/picture.jpg", - ], + urls, ) - self.assertEqual(len(downloaded_files), 1) + self.assertEqual(len(downloaded_files), len(urls)) - # Check that the url was called with SyncClientSession.get - self.assertIn("https://example.com/picture.jpg", mock_get.call_args_list[0][0]) + for index, url in enumerate(urls): + # Check that the url was called with SyncClientSession.get + self.assertIn(url, mock_get.call_args_list[index][0]) - # Check that the file has the correct extension - self.assertTrue(downloaded_files[0].endswith(".jpg")) + # Check that the file has the correct extension + self.assertTrue(downloaded_files[index].endswith(".jpg")) - mock_open_file.assert_called_once_with(downloaded_files[0], "wb") - mock_makedirs.assert_called_once_with( - os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True - ) + mock_open_file.assert_any_call(downloaded_files[index], "wb") - string_download_file = download_files_from_urls( - JOB_ID, "https://example.com/picture.jpg" - ) + mock_makedirs.assert_called_once_with(os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True) + + string_download_file = download_files_from_urls(JOB_ID, "https://example.com/picture.jpg") self.assertTrue(string_download_file[0].endswith(".jpg")) # Check if None is returned when url is None @@ -124,9 +120,7 @@ def test_download_files_from_urls(self, mock_open_file, mock_get, mock_makedirs) @patch("os.makedirs", return_value=None) @patch("runpod.http_client.SyncClientSession.get", side_effect=mock_requests_get) @patch("builtins.open", new_callable=mock_open) - def test_download_files_from_urls_signed( - self, mock_open_file, mock_get, mock_makedirs - ): + def test_download_files_from_urls_signed(self, mock_open_file, mock_get, mock_makedirs): """ Tests download_files_from_urls with signed urls """ @@ -147,9 +141,7 @@ def test_download_files_from_urls_signed( self.assertTrue(downloaded_files[0].endswith(".jpg")) mock_open_file.assert_called_once_with(downloaded_files[0], "wb") - mock_makedirs.assert_called_once_with( - os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True - ) + mock_makedirs.assert_called_once_with(os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True) class FileDownloaderTestCase(unittest.TestCase): @@ -179,6 +171,30 @@ def test_download_file(self, mock_file, mock_get): # Check that the file was written correctly mock_file().write.assert_called_once_with(b"file content") + @patch("runpod.serverless.utils.rp_download.SyncClientSession.get") + @patch("builtins.open", new_callable=mock_open) + def test_download_file(self, mock_file, mock_get): + """ + Tests download_file using filename from Content-Disposition + """ + # Mock the response from SyncClientSession.get + mock_response = MagicMock() + mock_response.content = b"file content" + mock_response.headers = {"Content-Disposition": 'inline; filename="test_file.txt"'} + mock_get.return_value = mock_response + + # Call the function with a test URL + result = file("http://test.com/file_without_extension") + + # Check the result + self.assertEqual(result["type"], "txt") + self.assertEqual(result["original_name"], "test_file.txt") + self.assertTrue(result["file_path"].endswith(".txt")) + self.assertIsNone(result["extracted_path"]) + + # Check that the file was written correctly + mock_file().write.assert_called_once_with(b"file content") + @patch("runpod.serverless.utils.rp_download.SyncClientSession.get") @patch("builtins.open", new_callable=mock_open) @patch("runpod.serverless.utils.rp_download.zipfile.ZipFile")