Skip to content

Commit 1cf2a1c

Browse files
authored
fix: handle bad urls with the requests library
1 parent c843686 commit 1cf2a1c

File tree

1 file changed

+54
-37
lines changed

1 file changed

+54
-37
lines changed

trapdata/ml/utils.py

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
import os
99
import pathlib
1010
import re
11+
import requests
1112
import tempfile
1213
import time
1314
import urllib.error
14-
import urllib.request
1515
from urllib.parse import urlparse
1616
from dataclasses import dataclass
1717
from typing import TYPE_CHECKING, Optional
@@ -45,53 +45,70 @@ def get_device(device_str=None) -> torch.device:
4545

4646

4747
def get_or_download_file(
    path, destination_dir=None, prefix=None, suffix=None, timeout=60
) -> pathlib.Path:
    """
    Return a local copy of a file, downloading it with ``requests`` if needed.

    If ``path`` already exists on the local filesystem it is returned as-is.
    Otherwise ``path`` is treated as a URL and downloaded into the destination
    directory. A file that was already downloaded there is reused instead of
    being fetched again.

    Args:
        path (str): URL or local path to the file.
        destination_dir (str, optional): Directory to save the downloaded file.
            Falls back to the ``LOCAL_WEIGHTS_PATH`` environment variable.
        prefix (str, optional): Sub-directory of the destination directory to
            store the file in.
        suffix (str, optional): Extension that replaces the extension of the
            downloaded filename (e.g. ``".pt"``).
        timeout (float or tuple, optional): Timeout in seconds passed to
            ``requests.get`` so a stalled server cannot block forever.

    Returns:
        pathlib.Path: Path to the local file.

    Raises:
        ValueError: If ``path`` is empty, no destination directory is
            configured, or no filename can be derived from ``path``.
        requests.exceptions.RequestException: If the download fails.

    >>> filepath = get_or_download_file("https://example.uk/images/31-20230919033000-snapshot.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256")  # doctest: +SKIP
    >>> filepath.name  # doctest: +SKIP
    '31-20230919033000-snapshot.jpg'
    >>> filename = get_or_download_file("https://example.com/file with spaces.zip")  # doctest: +SKIP
    """
    if not path:
        raise ValueError("Specify a URL or path to fetch file from.")

    # If path is a local path instead of a URL, just return it
    if os.path.exists(path):
        return pathlib.Path(path)

    destination_dir = destination_dir or os.environ.get("LOCAL_WEIGHTS_PATH")

    if not destination_dir:
        raise ValueError(
            "No destination directory specified by LOCAL_WEIGHTS_PATH or app settings."
        )

    destination_dir = pathlib.Path(destination_dir)
    if prefix:
        destination_dir = destination_dir / prefix
    if not destination_dir.exists():
        logger.info(f"Creating local directory {str(destination_dir)}")
        destination_dir.mkdir(parents=True, exist_ok=True)

    # Take the filename from the URL *path* only, so query strings and
    # fragments (e.g. signed S3 parameters) never end up in the local name.
    fname = pathlib.Path(urlparse(str(path)).path).name
    if not fname:
        # Without this guard the destination directory itself would be
        # treated as the "already downloaded" file below.
        raise ValueError(f"Could not determine a filename from {path}")
    local_filepath = destination_dir / fname

    if suffix:
        local_filepath = local_filepath.with_suffix(suffix)

    if local_filepath.exists():
        logger.info(f"Using existing {local_filepath}")
        return local_filepath

    logger.info(f"Downloading {path} to {local_filepath}")

    # Stream into a sibling ".part" file and rename it only once complete, so
    # an interrupted download is never mistaken for a cached file later on.
    partial_filepath = local_filepath.with_name(local_filepath.name + ".part")
    try:
        # The context manager releases the connection even if streaming fails.
        with requests.get(path, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raises an HTTPError for bad responses

            with open(partial_filepath, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)

        os.replace(partial_filepath, local_filepath)

    except requests.exceptions.RequestException as e:
        logger.error(f"Error downloading file: {e}")
        raise

    finally:
        # After a successful rename the partial file no longer exists; on any
        # failure this removes the incomplete download.
        if partial_filepath.exists():
            partial_filepath.unlink()

    logger.info(f"Downloaded to {local_filepath}")
    return local_filepath
95112

96113

97114
def decode_base64_string(string) -> io.BytesIO:

0 commit comments

Comments
 (0)