Skip to content

Commit

Permalink
v0.3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
w86763777 committed Jul 9, 2021
1 parent 0d34197 commit 7991c96
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 11 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ __pycache__

build
dist
pytorch_gan_metrics.egg-info
pytorch_gan_metrics-0.3.0
pytorch_gan_metrics.egg-info
.tox
2 changes: 1 addition & 1 deletion pytorch_gan_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
get_inception_score_and_fid_from_directory
]

__version__ = '0.2.0'
__version__ = '0.3.0'
7 changes: 5 additions & 2 deletions pytorch_gan_metrics/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
from .inception import InceptionV3


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def get_inception_feature(
images: Union[List[torch.FloatTensor], DataLoader],
dims: List[int],
batch_size: int = 50,
use_torch: bool = False,
verbose: bool = False
verbose: bool = False,
device: torch.device = torch.device('cuda:0'),
) -> Union[torch.FloatTensor, np.ndarray]:
"""Calculate Inception Score and FID.
For each image, only a forward propagation is required to
Expand All @@ -38,6 +39,7 @@ def get_inception_feature(
accelerated by GPU.
verbose: Set verbose to False for disabling progress bar. Otherwise,
the progress bar is showing when calculating activations.
device: the torch device which is used to calculate inception feature
Returns:
inception_score: float tuple, (mean, std)
fid: float
Expand Down Expand Up @@ -146,6 +148,7 @@ def calculate_frechet_inception_distance(
sigma: np.ndarray,
use_torch: bool = False,
eps: float = 1e-6,
device: torch.device = torch.device('cuda:0'),
) -> float:
if use_torch:
m1 = torch.mean(acts, axis=0)
Expand Down
3 changes: 2 additions & 1 deletion pytorch_gan_metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def get_inception_score_and_fid(
enableb, the backend linalg is implemented by torch, the
results are not guaranteed to be consistent with numpy, but
the speed can be accelerated by GPU.
**kwargs: Please refer to `get_inception_feature` for other arguments.
**kwargs: Please refer to `core.get_inception_feature` for other
arguments.
Returns:
inception_score: float tuple, (mean, std)
fid: float
Expand Down
8 changes: 2 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
import setuptools


import pytorch_gan_metrics


def read(rel_path):
base_path = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(base_path, rel_path), 'r') as f:
Expand All @@ -14,7 +11,7 @@ def read(rel_path):
if __name__ == '__main__':
setuptools.setup(
name='pytorch_gan_metrics',
version=pytorch_gan_metrics.__version__,
version='0.3.0',
author='Yi-Lun Wu',
author_email='w86763777@gmail.com',
description=('Package for calculating GAN metrics using Pytorch'),
Expand All @@ -34,12 +31,11 @@ def read(rel_path):
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
],
python_requires='>=3.5',
python_requires='>=3.6',
install_requires=[
'tqdm',
'scipy==1.5.4',
'torch>=1.5.0,<=1.8.1',
'torchvision>=0.6.0,<=0.9.1',
],
setup_requires=['flake8'],
)
6 changes: 6 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[tox]
envlist = py38

[testenv]
commands =
python -m test.test

0 comments on commit 7991c96

Please sign in to comment.