From 0362ccffb4f18c1bbec767d66fe648bc1eda03ae Mon Sep 17 00:00:00 2001 From: Byeongman Lee Date: Wed, 20 Dec 2023 11:34:23 +0900 Subject: [PATCH] #62 Add GitHub folder downloader (#63) --- README.md | 9 +++++++ requirements.txt | 3 +-- tools/github_download.py | 57 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tools/github_download.py diff --git a/README.md b/README.md index b5a4d4ac..aa10c6cb 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,15 @@ pip install -e . ## Quick Start +### Download config folder from netspresso-trainer + +If you want to train the trainer as a yaml file, download the config folder and use it. + +```bash +python tools/github_download.py --repo Nota-NetsPresso/netspresso-trainer --path config +``` + + ### Login To use the PyNetsPresso, please enter the email and password registered in NetsPresso. diff --git a/requirements.txt b/requirements.txt index c7a82607..232c3d97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,5 @@ requests==2.30.0 email-validator==2.0.0 pytz==2023.3 typing_extensions==4.5.0 -torch>=1.13.0 -torchvision>=0.14.0 netspresso_trainer==0.0.10 +PyGithub>=2.1.1 diff --git a/tools/github_download.py b/tools/github_download.py new file mode 100644 index 00000000..98800749 --- /dev/null +++ b/tools/github_download.py @@ -0,0 +1,57 @@ +from argparse import ArgumentParser +from pathlib import Path + +import requests +from github import ContentFile, Github, Repository +from loguru import logger + + +class GithubDownloader: + def __init__(self, repo: Repository) -> None: + self.github_client = Github() + self.repo = self.github_client.get_repo(repo) + + def download(self, content: ContentFile, out: str): + r = requests.get(content.download_url) + output_path = Path(out) / content.path + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "wb") as f: + logger.info(f"Downloading {content.path} to {output_path}") + f.write(r.content) + + def download_folder(self, folder: str, out: str, recursive: bool = True): + contents = self.repo.get_contents(folder) + for content in contents: + if content.download_url is None: + if recursive: + self.download_folder(content.path, out, recursive) + continue + self.download(content, out) + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--repo", help="The repo where the file or folder is stored") + parser.add_argument("--path", help="The folder or file you want to download") + parser.add_argument( + "-o", + "--out", + default="./", + required=False, + help="Path to folder you want to download " + "to. Default is current folder + " + "'downloads'", + ) + parser.add_argument( + "-f", + "--file", + action="store_true", + help="Set flag to download a single file, instead of a " "folder.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + github_downloader = GithubDownloader(repo=args.repo) + github_downloader.download_folder(args.path, args.out)