-
Notifications
You must be signed in to change notification settings - Fork 0
/
kaggle_downloader.py
147 lines (116 loc) · 4.98 KB
/
kaggle_downloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
from pathlib import Path
import yaml
import logging
from typing import Optional
import kaggle
from zipfile import ZipFile
class KaggleDataDownloader:
    """Download, cache, and validate the Malimg dataset from Kaggle.

    Configuration is read from a YAML file that must provide at least::

        kaggle:
          download_path: <directory to download into>
          dataset_name: <dataset identifier passed to the Kaggle API>
        dataset:
          num_classes: <expected number of class sub-directories per split>
    """

    # Name of the top-level folder the validated dataset lives in.
    _DATASET_DIR = 'malimg_dataset'
    # Splits that must exist inside the dataset folder.
    _SPLITS = ('train', 'val', 'test')

    def __init__(self, config_path: str = "config/config.yaml") -> None:
        """
        Initialize downloader with configuration.

        Loads the YAML config, verifies Kaggle credentials, and creates the
        download directory (including parents) if it does not exist.

        Args:
            config_path: Path to the YAML configuration file.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If the config is empty / not a mapping, or Kaggle
                credentials cannot be found.
            KeyError: If ``kaggle.download_path`` is missing from the config.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        # Attach our own handler only when nothing up the logger hierarchy
        # would emit these records. Otherwise each new instance (and any root
        # handler installed by logging.basicConfig) prints every message again.
        if not self.logger.hasHandlers():
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            ))
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)

        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)
        # safe_load returns None for an empty file. Fail with a clear message
        # instead of a TypeError on the first subscript below.
        if not isinstance(self.config, dict):
            raise ValueError(
                f"Config file {config_path} is empty or not a mapping"
            )

        self._check_kaggle_credentials()
        self.download_path = Path(self.config['kaggle']['download_path'])
        self.download_path.mkdir(parents=True, exist_ok=True)

    def _check_kaggle_credentials(self) -> None:
        """
        Verify Kaggle API credentials are properly set up.

        Credentials are accepted from any of:

        * ``~/.kaggle/kaggle.json``
        * ``$KAGGLE_CONFIG_DIR/kaggle.json`` (the Kaggle client's
          config-directory override)
        * the ``KAGGLE_USERNAME`` and ``KAGGLE_KEY`` environment variables
          (both must be non-empty)

        Raises:
            ValueError: If none of the above is available.
        """
        # NOTE(review): the kaggle package may authenticate when the
        # module-level `import kaggle` runs. In that case missing credentials
        # fail there, before this friendlier check is reached. Confirm for the
        # pinned kaggle version.
        candidates = [Path('~/.kaggle/kaggle.json').expanduser()]
        config_dir = os.getenv('KAGGLE_CONFIG_DIR')
        if config_dir:
            candidates.append(Path(config_dir) / 'kaggle.json')
        if any(path.is_file() for path in candidates):
            return
        if os.getenv('KAGGLE_USERNAME') and os.getenv('KAGGLE_KEY'):
            return
        raise ValueError(
            "Kaggle API credentials not found. Please follow these steps:\n"
            "1. Go to https://kaggle.com/account\n"
            "2. Click on 'Create API Token'\n"
            "3. Move the downloaded kaggle.json to ~/.kaggle/\n"
            " OR set KAGGLE_USERNAME and KAGGLE_KEY environment variables"
        )

    def download_malimg(self) -> Path:
        """
        Download Malimg dataset from Kaggle unless it is already present.

        A fresh download is unzipped into ``download_path`` and then checked
        by ``_validate_download``. An already-present dataset is returned
        without re-validation.

        Returns:
            Path: Directory the dataset was downloaded into.

        Raises:
            KeyError: If ``kaggle.dataset_name`` is missing from the config.
            ValueError: If the downloaded dataset fails validation.
            Exception: Any error raised by the Kaggle API is logged and
                re-raised unchanged.
        """
        dataset_name = self.config['kaggle']['dataset_name']
        if self._check_dataset_exists():
            self.logger.info("Dataset already downloaded")
            return self.download_path

        try:
            self.logger.info(f"Downloading dataset: {dataset_name}")
            kaggle.api.dataset_download_files(
                dataset_name,
                path=self.download_path,
                unzip=True,
            )
            self.logger.info(
                f"Dataset downloaded successfully to {self.download_path}"
            )
            self._validate_download()
        except Exception as e:
            # Failures are errors, not informational messages. Re-raise so
            # callers still see the original exception.
            self.logger.error(f"Error downloading dataset: {e}")
            raise
        return self.download_path

    def _check_dataset_exists(self) -> bool:
        """
        Check if dataset is already downloaded.

        The dataset counts as present when either:

        * ``malimg_dataset/`` contains all of the train/val/test split
          directories (the layout ``_validate_download`` expects), or
        * both ``malimg_dataset.zip`` and ``malimg_paper.pdf`` exist in
          ``download_path`` (the original marker files).

        Returns:
            bool: True if dataset exists
        """
        dataset_folder = self.download_path / self._DATASET_DIR
        if all((dataset_folder / split).is_dir() for split in self._SPLITS):
            return True
        expected_files = ['malimg_dataset.zip', 'malimg_paper.pdf']
        return all((self.download_path / f).exists() for f in expected_files)

    def _validate_download(self) -> None:
        """
        Validate downloaded dataset structure and content.

        Expected structure under ``download_path`` (e.g. data/raw/malimg/)::

            malimg_dataset/
            ├── train/
            ├── val/
            └── test/

        Each split must contain at least one ``.png`` file (searched
        recursively) and exactly ``dataset.num_classes`` class
        sub-directories.

        Raises:
            ValueError: If dataset structure is invalid
            KeyError: If ``dataset.num_classes`` is missing from the config.
        """
        self.logger.info("Starting dataset validation...")
        dataset_folder = self.download_path / self._DATASET_DIR
        # is_dir() rather than a literal glob / exists(): a stray *file*
        # with the folder's name must not pass as the dataset folder.
        if not dataset_folder.is_dir():
            raise ValueError("Dataset folder 'malimg_dataset' not found")

        num_classes = self.config['dataset']['num_classes']
        for split in self._SPLITS:
            split_path = dataset_folder / split
            if not split_path.is_dir():
                raise ValueError(f"Missing {split} directory in dataset")
            if not any(split_path.rglob('*.png')):
                raise ValueError(f"No image files found in {split} directory")
            class_count = sum(1 for d in split_path.iterdir() if d.is_dir())
            if class_count != num_classes:
                raise ValueError(
                    f"Expected {num_classes} classes in {split} split, "
                    f"but found {class_count}"
                )

        self.logger.info("Dataset validation successful")
        self.logger.info(f"Found all required splits: {list(self._SPLITS)}")
        self.logger.info(f"Number of classes per split: {num_classes}")
def setup_dataset(config_path: str = "config/config.yaml") -> Path:
    """
    Download the Malimg dataset if it is not already present.

    Args:
        config_path: Path to the YAML configuration file handed to
            ``KaggleDataDownloader``. Defaults to the same path the
            downloader itself uses, so existing zero-argument calls behave
            exactly as before.

    Returns:
        Path: Directory the dataset was downloaded into.
    """
    downloader = KaggleDataDownloader(config_path)
    return downloader.download_malimg()
if __name__ == "__main__":
    # Script entry point: enable INFO-level root logging first, then fetch
    # (or reuse) the dataset described by the default config file.
    logging.basicConfig(level=logging.INFO)
    setup_dataset()