From d37e96a241b5bedee2bebbe70898bf89abbd8491 Mon Sep 17 00:00:00 2001 From: Yi-Lun Wu Date: Wed, 9 Nov 2022 23:42:35 +0800 Subject: [PATCH] 1. use tqdm.auto instead of tqdm 2. use param name for calc_fid_stats --- README.md | 4 +++- pytorch_gan_metrics/calc_fid_stats.py | 4 ++-- pytorch_gan_metrics/core.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 14d1fdd..718694f 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,9 @@ The results are slightly different from official implementations due to the fram - [Download](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC?usp=sharing) precalculated statistics or - Calculate statistics for your custom dataset using command line tool ```bash - python -m pytorch_gan_metrics.calc_fid_stats path/to/images path/to/statistics.npz + python -m pytorch_gan_metrics.calc_fid_stats \ + --path path/to/images \ + --stats path/to/statistics.npz ``` See [calc_fid_stats.py](./pytorch_gan_metrics/calc_fid_stats.py) for details. diff --git a/pytorch_gan_metrics/calc_fid_stats.py b/pytorch_gan_metrics/calc_fid_stats.py index d1ab65d..5876d54 100644 --- a/pytorch_gan_metrics/calc_fid_stats.py +++ b/pytorch_gan_metrics/calc_fid_stats.py @@ -10,9 +10,9 @@ parser = argparse.ArgumentParser( "A handy cli tool to calculate FID statistics.", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("path", type=str, + parser.add_argument("--path", type=str, required=True, help='path to image directory (include subfolders)') - parser.add_argument("output", type=str, + parser.add_argument("--output", type=str, required=True, help="output path") parser.add_argument("--batch_size", type=int, default=50, help="batch size") diff --git a/pytorch_gan_metrics/core.py b/pytorch_gan_metrics/core.py index abb953a..8fcbb64 100644 --- a/pytorch_gan_metrics/core.py +++ b/pytorch_gan_metrics/core.py @@ -5,7 +5,7 @@ import numpy as np import torch from scipy import linalg -from tqdm import tqdm +from tqdm.auto import tqdm from torch.utils.data import DataLoader from .inception import InceptionV3