|
8 | 8 | import os
|
9 | 9 | import pathlib
|
10 | 10 | import re
|
| 11 | +import requests |
11 | 12 | import tempfile
|
12 | 13 | import time
|
13 | 14 | import urllib.error
|
14 |
| -import urllib.request |
15 | 15 | from urllib.parse import urlparse
|
16 | 16 | from dataclasses import dataclass
|
17 | 17 | from typing import TYPE_CHECKING, Optional
|
@@ -45,53 +45,70 @@ def get_device(device_str=None) -> torch.device:
|
45 | 45 |
|
46 | 46 |
|
def get_or_download_file(
    path, destination_dir=None, prefix=None, suffix=None
) -> pathlib.Path:
    """
    Return a local copy of a file, downloading it first if `path` is a URL.

    If `path` already exists on the local filesystem it is returned as-is.
    Otherwise it is treated as a URL: the file name is taken from the URL's
    path component (query string and fragment are ignored, so signed URLs map
    to a stable name) and the file is stored under the destination directory.
    If a file with that name is already there, it is reused without
    downloading again.

    Args:
        path (str): URL or local path to the file.
        destination_dir (str, optional): Directory to save the downloaded
            file in. Falls back to the LOCAL_WEIGHTS_PATH environment variable.
        prefix (str, optional): Sub-directory appended to the destination
            directory.
        suffix (str, optional): Extension that replaces the file name's
            existing extension (passed to `pathlib.Path.with_suffix`).

    Returns:
        pathlib.Path: Path to the local file.

    Raises:
        ValueError: If `path` is empty, no destination directory is
            configured, or no file name can be derived from the URL.
        requests.exceptions.RequestException: If the download fails.

    >>> filepath = get_or_download_file("https://example.uk/images/31-20230919033000-snapshot.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256")  # doctest: +SKIP
    >>> filepath.name  # doctest: +SKIP
    '31-20230919033000-snapshot.jpg'
    """
    if not path:
        raise ValueError("Specify a URL or path to fetch file from.")

    # If path is a local path instead of a URL, just return it
    if os.path.exists(path):
        return pathlib.Path(path)

    destination_dir = destination_dir or os.environ.get("LOCAL_WEIGHTS_PATH")
    if not destination_dir:
        raise ValueError(
            "No destination directory specified by LOCAL_WEIGHTS_PATH or app settings."
        )

    # Use only the URL's path component so query strings (e.g. signed-URL
    # parameters) and fragments never end up in the cached file name.
    fname = pathlib.Path(urlparse(str(path)).path).name
    if not fname:
        # Without a name, local_filepath would be the directory itself and
        # would be mistaken for an already-downloaded file.
        raise ValueError(f"Could not determine a file name from {path}")

    destination_dir = pathlib.Path(destination_dir)
    if prefix:
        destination_dir = destination_dir / prefix
    if not destination_dir.exists():
        logger.info(f"Creating local directory {str(destination_dir)}")
        destination_dir.mkdir(parents=True, exist_ok=True)

    local_filepath = destination_dir / fname
    if suffix:
        local_filepath = local_filepath.with_suffix(suffix)

    if local_filepath.exists():
        logger.info(f"Using existing {local_filepath}")
        return local_filepath

    logger.info(f"Downloading {path} to {local_filepath}")

    # Stream into a sibling ".part" file and rename on success, so an
    # interrupted download is never picked up later as a valid cached file.
    partial_filepath = local_filepath.with_name(local_filepath.name + ".part")
    try:
        # (connect, read) timeouts: the read timeout applies per socket read,
        # not to the whole transfer, so large files still download fine.
        with requests.get(path, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()  # Raises an HTTPError for bad responses
            with open(partial_filepath, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        os.replace(partial_filepath, local_filepath)
    except requests.exceptions.RequestException as e:
        logger.error(f"Error downloading file: {e}")
        raise
    finally:
        # After a successful rename the partial file no longer exists; on any
        # failure this removes the incomplete download.
        if partial_filepath.exists():
            partial_filepath.unlink()

    logger.info(f"Downloaded to {local_filepath}")
    return local_filepath
95 | 112 |
|
96 | 113 |
|
97 | 114 | def decode_base64_string(string) -> io.BytesIO:
|
|
0 commit comments