diff --git a/hubconf.py b/hubconf.py index 6c498f3..a7cbade 100644 --- a/hubconf.py +++ b/hubconf.py @@ -4,13 +4,13 @@ ## Users can get the diverse models of pytorch_gan_zoo by calling hub_model = hub.load( - '??/pytorch_gan_zoo:master', + 'facebookresearch/pytorch_gan_zoo:master', $MODEL_NAME, # config = None, useGPU = True, pretrained=False) # (Not pretrained models online yet) -Available model'names are [DCGAN, PGAN]. +Available model'names are [DCGAN, PGAN, StyleGAN]. The config option should be a dictionnary defining the training parameters of the model. See ??/pytorch_gan_zoo/models/trainer/standard_configurations to see all possible options @@ -99,6 +99,28 @@ def PGAN(pretrained=False, *args, **kwargs): return model +def StyleGAN(pretrained=False, *args, **kwargs): + """ + NVIDIA StyleGAN + pretrained (bool): load a 1024x1024 model trained on FlickrHQ + """ + from models.styleGAN import StyleGAN + if 'config' not in kwargs or kwargs['config'] is None: + kwargs['config'] = {} + + model = StyleGAN(useGPU=kwargs.get('useGPU', True), + storeAVG=True, + **kwargs['config']) + + checkpoint = 'https://dl.fbaipublicfiles.com/gan_zoo/StyleGAN/FFHQ_styleGAN-7cbdec00.pth' + if pretrained: + print("Loading default model : Flickr-HQ") + state_dict = model_zoo.load_url(checkpoint, + map_location='cpu') + model.load_state_dict(state_dict) + return model + + def DCGAN(pretrained=False, *args, **kwargs): """ DCGAN basic model